whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
package/cpp/ggml.h CHANGED
@@ -204,24 +204,30 @@
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
  #endif
 
- #include <stdint.h>
- #include <stddef.h>
  #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdio.h>
 
  #define WSP_GGML_FILE_MAGIC 0x67676d6c // "ggml"
- #define WSP_GGML_FILE_VERSION 1
+ #define WSP_GGML_FILE_VERSION 2
 
  #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
  #define WSP_GGML_MAX_DIMS 4
  #define WSP_GGML_MAX_PARAMS 2048
- #define WSP_GGML_MAX_CONTEXTS 64
  #define WSP_GGML_MAX_SRC 10
- #define WSP_GGML_MAX_NAME 64
+ #define WSP_GGML_MAX_N_THREADS 512
  #define WSP_GGML_MAX_OP_PARAMS 64
+
+ #ifndef WSP_GGML_MAX_NAME
+ # define WSP_GGML_MAX_NAME 64
+ #endif
+
  #define WSP_GGML_DEFAULT_N_THREADS 4
  #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
+
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
  #else
@@ -231,6 +237,8 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1
 
+ #define WSP_GGML_ROPE_TYPE_NEOX 2
+
  #define WSP_GGUF_MAGIC "GGUF"
 
  #define WSP_GGUF_VERSION 3
@@ -241,24 +249,27 @@
 
  #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
- #define WSP_GGML_ASSERT(x) \
-     do { \
-         if (!(x)) { \
-             fflush(stdout); \
-             fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-             wsp_ggml_print_backtrace(); \
-             abort(); \
-         } \
-     } while (0)
-
  #ifndef NDEBUG
- #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached")
+ # define WSP_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
  #elif defined(__GNUC__)
- #define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ # define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #elif defined(_MSC_VER)
+ # define WSP_GGML_UNREACHABLE() __assume(0)
+ #else
+ # define WSP_GGML_UNREACHABLE() ((void) 0)
+ #endif
+
+ #ifdef __cplusplus
+ # define WSP_GGML_NORETURN [[noreturn]]
+ #elif defined(_MSC_VER)
+ # define WSP_GGML_NORETURN __declspec(noreturn)
  #else
- #define WSP_GGML_UNREACHABLE() ((void) 0)
+ # define WSP_GGML_NORETURN _Noreturn
  #endif
 
+ #define WSP_GGML_ABORT(...) wsp_ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+ #define WSP_GGML_ASSERT(x) if (!(x)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", #x)
+
  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
  //
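Migration note: WSP_GGML_ASSERT no longer prints and aborts inline; failures now route through the new wsp_ggml_abort(), and WSP_GGML_ABORT gives callers printf-style fatal errors. A minimal caller-side sketch (illustrative, not part of the published diff):

#include "ggml.h"

static void check_n_threads(int n_threads) {
    // on failure, expands to:
    // wsp_ggml_abort(__FILE__, __LINE__, "WSP_GGML_ASSERT(%s) failed", "n_threads > 0")
    WSP_GGML_ASSERT(n_threads > 0);

    if (n_threads > WSP_GGML_MAX_N_THREADS) {
        WSP_GGML_ABORT("n_threads (%d) exceeds WSP_GGML_MAX_N_THREADS", n_threads);
    }
}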
@@ -297,74 +308,131 @@
  WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
  WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
+     WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+     WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+     WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+     WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif
 
- #if defined(__ARM_NEON) && defined(__CUDACC__)
- typedef half wsp_ggml_fp16_t;
- #elif defined(__ARM_NEON)
- typedef __fp16 wsp_ggml_fp16_t;
- #else
- typedef uint16_t wsp_ggml_fp16_t;
- #endif
+ WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
+ WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);
+
+ enum wsp_ggml_status {
+     WSP_GGML_STATUS_ALLOC_FAILED = -2,
+     WSP_GGML_STATUS_FAILED = -1,
+     WSP_GGML_STATUS_SUCCESS = 0,
+     WSP_GGML_STATUS_ABORTED = 1,
+ };
 
- // convert FP16 <-> FP32
- WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
- WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
+ // get wsp_ggml_status name string
+ WSP_GGML_API const char * wsp_ggml_status_to_string(enum wsp_ggml_status status);
 
- WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
- WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
+ // ieee 754-2008 half-precision float16
+ // todo: make this not an integral type
+ typedef uint16_t wsp_ggml_fp16_t;
+ WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t);
+ WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float);
+ WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float *, wsp_ggml_fp16_t *, int64_t);
+
+ // google brain half-precision bfloat16
+ typedef struct { uint16_t bits; } wsp_ggml_bf16_t;
+ WSP_GGML_API wsp_ggml_bf16_t wsp_ggml_fp32_to_bf16(float);
+ WSP_GGML_API float wsp_ggml_bf16_to_fp32(wsp_ggml_bf16_t); // consider just doing << 16
+ WSP_GGML_API void wsp_ggml_bf16_to_fp32_row(const wsp_ggml_bf16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row_ref(const float *, wsp_ggml_bf16_t *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row(const float *, wsp_ggml_bf16_t *, int64_t);
 
  struct wsp_ggml_object;
  struct wsp_ggml_context;
+ struct wsp_ggml_cgraph;
 
+ // NOTE: always add types at the end of the enum to keep backward compatibility
  enum wsp_ggml_type {
-     WSP_GGML_TYPE_F32 = 0,
-     WSP_GGML_TYPE_F16 = 1,
-     WSP_GGML_TYPE_Q4_0 = 2,
-     WSP_GGML_TYPE_Q4_1 = 3,
+     WSP_GGML_TYPE_F32      = 0,
+     WSP_GGML_TYPE_F16      = 1,
+     WSP_GGML_TYPE_Q4_0     = 2,
+     WSP_GGML_TYPE_Q4_1     = 3,
      // WSP_GGML_TYPE_Q4_2 = 4, support has been removed
-     // WSP_GGML_TYPE_Q4_3 (5) support has been removed
-     WSP_GGML_TYPE_Q5_0 = 6,
-     WSP_GGML_TYPE_Q5_1 = 7,
-     WSP_GGML_TYPE_Q8_0 = 8,
-     WSP_GGML_TYPE_Q8_1 = 9,
-     // k-quantizations
-     WSP_GGML_TYPE_Q2_K = 10,
-     WSP_GGML_TYPE_Q3_K = 11,
-     WSP_GGML_TYPE_Q4_K = 12,
-     WSP_GGML_TYPE_Q5_K = 13,
-     WSP_GGML_TYPE_Q6_K = 14,
-     WSP_GGML_TYPE_Q8_K = 15,
-     WSP_GGML_TYPE_I8,
-     WSP_GGML_TYPE_I16,
-     WSP_GGML_TYPE_I32,
+     // WSP_GGML_TYPE_Q4_3 = 5, support has been removed
+     WSP_GGML_TYPE_Q5_0     = 6,
+     WSP_GGML_TYPE_Q5_1     = 7,
+     WSP_GGML_TYPE_Q8_0     = 8,
+     WSP_GGML_TYPE_Q8_1     = 9,
+     WSP_GGML_TYPE_Q2_K     = 10,
+     WSP_GGML_TYPE_Q3_K     = 11,
+     WSP_GGML_TYPE_Q4_K     = 12,
+     WSP_GGML_TYPE_Q5_K     = 13,
+     WSP_GGML_TYPE_Q6_K     = 14,
+     WSP_GGML_TYPE_Q8_K     = 15,
+     WSP_GGML_TYPE_IQ2_XXS  = 16,
+     WSP_GGML_TYPE_IQ2_XS   = 17,
+     WSP_GGML_TYPE_IQ3_XXS  = 18,
+     WSP_GGML_TYPE_IQ1_S    = 19,
+     WSP_GGML_TYPE_IQ4_NL   = 20,
+     WSP_GGML_TYPE_IQ3_S    = 21,
+     WSP_GGML_TYPE_IQ2_S    = 22,
+     WSP_GGML_TYPE_IQ4_XS   = 23,
+     WSP_GGML_TYPE_I8       = 24,
+     WSP_GGML_TYPE_I16      = 25,
+     WSP_GGML_TYPE_I32      = 26,
+     WSP_GGML_TYPE_I64      = 27,
+     WSP_GGML_TYPE_F64      = 28,
+     WSP_GGML_TYPE_IQ1_M    = 29,
+     WSP_GGML_TYPE_BF16     = 30,
+     WSP_GGML_TYPE_Q4_0_4_4 = 31,
+     WSP_GGML_TYPE_Q4_0_4_8 = 32,
+     WSP_GGML_TYPE_Q4_0_8_8 = 33,
+     WSP_GGML_TYPE_TQ1_0    = 34,
+     WSP_GGML_TYPE_TQ2_0    = 35,
      WSP_GGML_TYPE_COUNT,
  };
 
+ // precision
+ enum wsp_ggml_prec {
+     WSP_GGML_PREC_DEFAULT,
+     WSP_GGML_PREC_F32,
+ };
+
  enum wsp_ggml_backend_type {
-     WSP_GGML_BACKEND_CPU = 0,
-     WSP_GGML_BACKEND_GPU = 10,
-     WSP_GGML_BACKEND_GPU_SPLIT = 20,
+     WSP_GGML_BACKEND_TYPE_CPU = 0,
+     WSP_GGML_BACKEND_TYPE_GPU = 10,
+     WSP_GGML_BACKEND_TYPE_GPU_SPLIT = 20,
  };
 
  // model file types
  enum wsp_ggml_ftype {
-     WSP_GGML_FTYPE_UNKNOWN = -1,
-     WSP_GGML_FTYPE_ALL_F32 = 0,
-     WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+     WSP_GGML_FTYPE_UNKNOWN         = -1,
+     WSP_GGML_FTYPE_ALL_F32         = 0,
+     WSP_GGML_FTYPE_MOSTLY_F16      = 1, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_0     = 2, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_1     = 3, // except 1d tensors
      WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-     WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
-     WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q8_0     = 7, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q5_0     = 8, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q5_1     = 9, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q2_K     = 10, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q3_K     = 11, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_K     = 12, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q5_K     = 13, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q6_K     = 14, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ2_XXS  = 15, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ2_XS   = 16, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ3_XXS  = 17, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ1_S    = 18, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ4_NL   = 19, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ3_S    = 20, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ2_S    = 21, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ4_XS   = 22, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_IQ1_M    = 23, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_BF16     = 24, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+     WSP_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
  };
 
  // available tensor operations:
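Migration note: the fp16 row converters now take int64_t counts, and a bfloat16 type arrives alongside them. A minimal sketch (the buffer size and the n <= 256 assumption are illustrative):

#include "ggml.h"

void convert_rows(const float * src, int64_t n) { // assumes n <= 256
    wsp_ggml_fp16_t fp16[256];
    wsp_ggml_bf16_t bf16[256];

    wsp_ggml_fp32_to_fp16_row(src, fp16, n); // count is int64_t now (was int)
    wsp_ggml_fp32_to_bf16_row(src, bf16, n);

    float back = wsp_ggml_bf16_to_fp32(bf16[0]); // bf16 is the high 16 bits of an f32
    (void) back;
}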
@@ -381,10 +449,13 @@ extern "C" {
      WSP_GGML_OP_SQR,
      WSP_GGML_OP_SQRT,
      WSP_GGML_OP_LOG,
+     WSP_GGML_OP_SIN,
+     WSP_GGML_OP_COS,
      WSP_GGML_OP_SUM,
      WSP_GGML_OP_SUM_ROWS,
      WSP_GGML_OP_MEAN,
      WSP_GGML_OP_ARGMAX,
+     WSP_GGML_OP_COUNT_EQUAL,
      WSP_GGML_OP_REPEAT,
      WSP_GGML_OP_REPEAT_BACK,
      WSP_GGML_OP_CONCAT,
@@ -415,25 +486,30 @@ extern "C" {
      WSP_GGML_OP_SOFT_MAX_BACK,
      WSP_GGML_OP_ROPE,
      WSP_GGML_OP_ROPE_BACK,
-     WSP_GGML_OP_ALIBI,
      WSP_GGML_OP_CLAMP,
      WSP_GGML_OP_CONV_TRANSPOSE_1D,
      WSP_GGML_OP_IM2COL,
+     WSP_GGML_OP_IM2COL_BACK,
      WSP_GGML_OP_CONV_TRANSPOSE_2D,
      WSP_GGML_OP_POOL_1D,
      WSP_GGML_OP_POOL_2D,
+     WSP_GGML_OP_POOL_2D_BACK,
      WSP_GGML_OP_UPSCALE, // nearest interpolate
      WSP_GGML_OP_PAD,
+     WSP_GGML_OP_ARANGE,
+     WSP_GGML_OP_TIMESTEP_EMBEDDING,
      WSP_GGML_OP_ARGSORT,
      WSP_GGML_OP_LEAKY_RELU,
 
-     WSP_GGML_OP_FLASH_ATTN,
-     WSP_GGML_OP_FLASH_FF,
+     WSP_GGML_OP_FLASH_ATTN_EXT,
      WSP_GGML_OP_FLASH_ATTN_BACK,
+     WSP_GGML_OP_SSM_CONV,
+     WSP_GGML_OP_SSM_SCAN,
      WSP_GGML_OP_WIN_PART,
      WSP_GGML_OP_WIN_UNPART,
      WSP_GGML_OP_GET_REL_POS,
      WSP_GGML_OP_ADD_REL_POS,
+     WSP_GGML_OP_RWKV_WKV,
 
      WSP_GGML_OP_UNARY,
 
@@ -450,6 +526,7 @@ extern "C" {
 
      WSP_GGML_OP_CROSS_ENTROPY_LOSS,
      WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+     WSP_GGML_OP_OPT_STEP_ADAMW,
 
      WSP_GGML_OP_COUNT,
  };
@@ -462,47 +539,48 @@ extern "C" {
      WSP_GGML_UNARY_OP_TANH,
      WSP_GGML_UNARY_OP_ELU,
      WSP_GGML_UNARY_OP_RELU,
+     WSP_GGML_UNARY_OP_SIGMOID,
      WSP_GGML_UNARY_OP_GELU,
      WSP_GGML_UNARY_OP_GELU_QUICK,
      WSP_GGML_UNARY_OP_SILU,
+     WSP_GGML_UNARY_OP_HARDSWISH,
+     WSP_GGML_UNARY_OP_HARDSIGMOID,
+     WSP_GGML_UNARY_OP_EXP,
 
      WSP_GGML_UNARY_OP_COUNT,
  };
 
  enum wsp_ggml_object_type {
-     WSP_GGML_OBJECT_TENSOR,
-     WSP_GGML_OBJECT_GRAPH,
-     WSP_GGML_OBJECT_WORK_BUFFER
+     WSP_GGML_OBJECT_TYPE_TENSOR,
+     WSP_GGML_OBJECT_TYPE_GRAPH,
+     WSP_GGML_OBJECT_TYPE_WORK_BUFFER
  };
 
  enum wsp_ggml_log_level {
-     WSP_GGML_LOG_LEVEL_ERROR = 2,
-     WSP_GGML_LOG_LEVEL_WARN = 3,
-     WSP_GGML_LOG_LEVEL_INFO = 4
+     WSP_GGML_LOG_LEVEL_NONE = 0,
+     WSP_GGML_LOG_LEVEL_INFO = 1,
+     WSP_GGML_LOG_LEVEL_WARN = 2,
+     WSP_GGML_LOG_LEVEL_ERROR = 3,
+     WSP_GGML_LOG_LEVEL_DEBUG = 4,
+     WSP_GGML_LOG_LEVEL_CONT = 5, // continue previous log
  };
 
- // ggml object
- struct wsp_ggml_object {
-     size_t offs;
-     size_t size;
-
-     struct wsp_ggml_object * next;
-
-     enum wsp_ggml_object_type type;
-
-     char padding[4];
+ // this tensor...
+ enum wsp_ggml_tensor_flag {
+     WSP_GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+     WSP_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+     WSP_GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+     WSP_GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
  };
 
- static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
-
  // n-dimensional tensor
  struct wsp_ggml_tensor {
-     enum wsp_ggml_type type;
-     enum wsp_ggml_backend_type backend;
+     enum wsp_ggml_type type;
+
+     WSP_GGML_DEPRECATED(enum wsp_ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
 
      struct wsp_ggml_backend_buffer * buffer;
 
-     int n_dims;
      int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
      size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes:
      // nb[0] = wsp_ggml_type_size(type)
@@ -515,16 +593,12 @@ extern "C" {
      // op params - allocated as int32_t for alignment
      int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
-     bool is_param;
+     int32_t flags;
 
      struct wsp_ggml_tensor * grad;
      struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];
 
-     // performance
-     int perf_runs;
-     int64_t perf_cycles;
-     int64_t perf_time_us;
-
+     // source tensor and offset for views
      struct wsp_ggml_tensor * view_src;
      size_t view_offs;
 
@@ -534,11 +608,39 @@ extern "C" {
 
      void * extra; // extra things e.g. for ggml-cuda.cu
 
-     char padding[12];
+     // char padding[4];
  };
 
  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
 
+ // Abort callback
+ // If not NULL, called before ggml computation
+ // If it returns true, the computation is aborted
+ typedef bool (*wsp_ggml_abort_callback)(void * data);
+
+ // Scheduling priorities
+ enum wsp_ggml_sched_priority {
+     WSP_GGML_SCHED_PRIO_NORMAL,
+     WSP_GGML_SCHED_PRIO_MEDIUM,
+     WSP_GGML_SCHED_PRIO_HIGH,
+     WSP_GGML_SCHED_PRIO_REALTIME
+ };
+
+ // Threadpool params
+ // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
+ struct wsp_ggml_threadpool_params {
+     bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
+     int n_threads; // number of threads
+     enum wsp_ggml_sched_priority prio; // thread priority
+     uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
+     bool strict_cpu; // strict cpu placement
+     bool paused; // start in paused state
+ };
+
+ struct wsp_ggml_threadpool; // forward declaration, see ggml.c
+
+ typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;
+
  // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
  // since https://github.com/ggerganov/ggml/issues/287
  struct wsp_ggml_cplan {
@@ -546,44 +648,15 @@ extern "C" {
      uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
 
      int n_threads;
+     struct wsp_ggml_threadpool * threadpool;
 
      // abort wsp_ggml_graph_compute when true
-     bool (*abort_callback)(void * data);
-     void * abort_callback_data;
- };
-
- enum wsp_ggml_cgraph_eval_order {
-     WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
-     WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
-     WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
- };
-
- struct wsp_ggml_hash_set {
-     size_t size;
-     struct wsp_ggml_tensor ** keys;
- };
-
- // computation graph
- struct wsp_ggml_cgraph {
-     int size;
-     int n_nodes;
-     int n_leafs;
-
-     struct wsp_ggml_tensor ** nodes;
-     struct wsp_ggml_tensor ** grads;
-     struct wsp_ggml_tensor ** leafs;
-
-     struct wsp_ggml_hash_set visited_hash_table;
-
-     enum wsp_ggml_cgraph_eval_order order;
-
-     // performance
-     int perf_runs;
-     int64_t perf_cycles;
-     int64_t perf_time_us;
+     wsp_ggml_abort_callback abort_callback;
+     void * abort_callback_data;
  };
 
  // scratch buffer
+ // TODO: deprecate and remove
  struct wsp_ggml_scratch {
      size_t offs;
      size_t size;
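Migration note: the cgraph internals (wsp_ggml_cgraph, wsp_ggml_hash_set) are no longer exposed in the public header, and the cplan's abort callback now uses the typed wsp_ggml_abort_callback. A sketch of wiring one up (illustrative; assumes a graph built elsewhere and omits work_data allocation):

#include "ggml.h"

static bool should_abort(void * data) {
    return *(volatile bool *) data; // returning true aborts the computation
}

enum wsp_ggml_status run_cancellable(struct wsp_ggml_cgraph * graph, volatile bool * stop) {
    struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, WSP_GGML_DEFAULT_N_THREADS, /*threadpool*/ NULL);
    plan.abort_callback      = should_abort;
    plan.abort_callback_data = (void *) stop;
    // when plan.work_size > 0, plan.work_data must be allocated here (omitted)
    return wsp_ggml_graph_compute(graph, &plan); // WSP_GGML_STATUS_ABORTED when cancelled
}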
@@ -597,27 +670,25 @@ extern "C" {
      bool no_alloc; // don't allocate memory for the tensor data
  };
 
-
- // compute types
-
- // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
- // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
- enum wsp_ggml_task_type {
-     WSP_GGML_TASK_INIT = 0,
-     WSP_GGML_TASK_COMPUTE,
-     WSP_GGML_TASK_FINALIZE,
+ // numa strategies
+ enum wsp_ggml_numa_strategy {
+     WSP_GGML_NUMA_STRATEGY_DISABLED   = 0,
+     WSP_GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+     WSP_GGML_NUMA_STRATEGY_ISOLATE    = 2,
+     WSP_GGML_NUMA_STRATEGY_NUMACTL    = 3,
+     WSP_GGML_NUMA_STRATEGY_MIRROR     = 4,
+     WSP_GGML_NUMA_STRATEGY_COUNT
  };
 
- struct wsp_ggml_compute_params {
-     enum wsp_ggml_task_type type;
+ //
+ // GUID
+ //
 
-     // ith = thread index, nth = number of threads
-     int ith, nth;
+ // GUID types
+ typedef uint8_t wsp_ggml_guid[16];
+ typedef wsp_ggml_guid * wsp_ggml_guid_t;
 
-     // work buffer for all threads
-     size_t wsize;
-     void * wdata;
- };
+ WSP_GGML_API bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b);
 
  // misc
 
@@ -627,23 +698,27 @@ extern "C" {
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);
 
- WSP_GGML_API void wsp_ggml_print_backtrace(void);
+ // accepts a UTF-8 path, even on Windows
+ WSP_GGML_API FILE * wsp_ggml_fopen(const char * fname, const char * mode);
 
- WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
+ WSP_GGML_API void wsp_ggml_numa_init(enum wsp_ggml_numa_strategy numa); // call once for better performance on NUMA systems
  WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
 
  WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj);
  WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);
 
- WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
- WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split);
+ WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
 
- WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type);
- WSP_GGML_API size_t wsp_ggml_type_size (enum wsp_ggml_type type); // size in bytes for all elements in a block
- WSP_GGML_API float wsp_ggml_type_sizef(enum wsp_ggml_type type); // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
+ WSP_GGML_API int64_t wsp_ggml_blck_size(enum wsp_ggml_type type);
+ WSP_GGML_API size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block
+ WSP_GGML_API size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row
+
+ WSP_GGML_DEPRECATED(
+     WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
+     "use wsp_ggml_row_size() instead");
 
  WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
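Migration note: wsp_ggml_type_sizef() is deprecated; wsp_ggml_row_size() returns the exact byte count without float rounding. Sketch:

#include "ggml.h"

size_t q4_row_bytes(int64_t ne) {
    // old: (size_t) (ne * wsp_ggml_type_sizef(WSP_GGML_TYPE_Q4_0)); // float rounding
    // new: exact integer arithmetic (ne should be a multiple of the block size)
    return wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, ne);
}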
@@ -660,18 +735,34 @@ extern "C" {
  WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype);
 
  WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_empty (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
+
+ WSP_GGML_API bool wsp_ggml_is_contiguous (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
+ WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
+ WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
 
- WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+
+ WSP_GGML_API bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
 
  // use this to compute the memory overhead of a tensor
  WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);
 
+ WSP_GGML_API bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size_t nbytes);
+
  // main
 
- WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params);
- WSP_GGML_API void wsp_ggml_free(struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init (struct wsp_ggml_init_params params);
+ WSP_GGML_API void wsp_ggml_reset(struct wsp_ggml_context * ctx);
+ WSP_GGML_API void wsp_ggml_free (struct wsp_ggml_context * ctx);
 
  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);
 
@@ -722,8 +813,8 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);
 
  // Context tensor enumeration and lookup
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
@@ -869,6 +960,22 @@ extern "C" {
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
 
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin_inplace(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos_inplace(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
  // return scalar
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum(
      struct wsp_ggml_context * ctx,
@@ -889,6 +996,12 @@ extern "C" {
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
 
+ // count number of equal elements in a and b
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_count_equal(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a,
+     struct wsp_ggml_tensor * b);
+
  // if a is the same shape as b, and a is not parameter, return a
  // otherwise, return a new tensor: repeat(a) to fit in b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat(
@@ -902,12 +1015,13 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b);
 
- // concat a and b on dim 2
+ // concat a and b along dim
  // used in stable-diffusion
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b);
+     struct wsp_ggml_tensor * b,
+     int dim);
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
      struct wsp_ggml_context * ctx,
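Migration note: the concat dimension is now an explicit argument (rc.7 always concatenated along dim 2). Sketch:

#include "ggml.h"

struct wsp_ggml_tensor * concat_dim2(struct wsp_ggml_context * ctx,
                                     struct wsp_ggml_tensor * a,
                                     struct wsp_ggml_tensor * b) {
    return wsp_ggml_concat(ctx, a, b, /*dim*/ 2); // pass 2 to keep the old behavior
}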
@@ -969,6 +1083,14 @@ extern "C" {
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
 
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid_inplace(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
@@ -1000,6 +1122,24 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b);
 
+ // hardswish(x) = x * relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardswish(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ // hardsigmoid(x) = relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardsigmoid(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a);
+
  // normalize along rows
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
      struct wsp_ggml_context * ctx,
@@ -1023,16 +1163,17 @@ extern "C" {
 
  // group normalize along ne0*ne1*n_groups
  // used in stable-diffusion
- // TODO: eps is hardcoded to 1e-6 for now
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     int n_groups);
+     int n_groups,
+     float eps);
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     int n_groups);
+     int n_groups,
+     float eps);
 
  // a - x
  // b - dy
@@ -1050,15 +1191,18 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b);
 
+ // change the precision of a matrix multiplication
+ // set to WSP_GGML_PREC_F32 for higher precision (useful for phi-2)
+ WSP_GGML_API void wsp_ggml_mul_mat_set_prec(
+     struct wsp_ggml_tensor * a,
+     enum wsp_ggml_prec prec);
+
  // indirect matrix multiplication
- // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * const as[],
-     int n_as,
-     struct wsp_ggml_tensor * ids,
-     int id,
-     struct wsp_ggml_tensor * b);
+     struct wsp_ggml_tensor * as,
+     struct wsp_ggml_tensor * b,
+     struct wsp_ggml_tensor * ids);
 
  // A: m columns, n rows,
  // B: p columns, n rows,
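Migration note: wsp_ggml_mul_mat_id() now takes the experts as one stacked tensor plus an ids tensor, and a per-matmul precision override is available. Sketch (function names are illustrative):

#include "ggml.h"

// old: wsp_ggml_mul_mat_id(ctx, as[], n_as, ids, id, b)
struct wsp_ggml_tensor * expert_matmul(struct wsp_ggml_context * ctx,
                                       struct wsp_ggml_tensor * as,    // stacked expert matrices
                                       struct wsp_ggml_tensor * b,
                                       struct wsp_ggml_tensor * ids) { // int32 expert indices
    return wsp_ggml_mul_mat_id(ctx, as, b, ids);
}

struct wsp_ggml_tensor * matmul_f32_accum(struct wsp_ggml_context * ctx,
                                          struct wsp_ggml_tensor * a,
                                          struct wsp_ggml_tensor * b) {
    struct wsp_ggml_tensor * out = wsp_ggml_mul_mat(ctx, a, b);
    wsp_ggml_mul_mat_set_prec(out, WSP_GGML_PREC_F32); // higher precision (e.g. phi-2)
    return out;
}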
@@ -1075,13 +1219,13 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b);
+     float s);
 
  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b);
+     float s);
 
  // b -> view(a,offset,nb1,nb2,3), return modified a
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set(
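Migration note: wsp_ggml_scale() now takes a plain float instead of a 1-element tensor. Sketch:

#include "ggml.h"

struct wsp_ggml_tensor * halve(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a) {
    // old: wsp_ggml_scale(ctx, a, wsp_ggml_new_f32(ctx, 0.5f));
    return wsp_ggml_scale(ctx, a, 0.5f);
}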
@@ -1091,7 +1235,7 @@ extern "C" {
      struct wsp_ggml_tensor * b,
      size_t nb1,
      size_t nb2,
      size_t nb3,
-     size_t offset);
+     size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return view(a)
@@ -1101,19 +1245,19 @@ extern "C" {
      size_t nb1,
      size_t nb2,
      size_t nb3,
-     size_t offset);
+     size_t offset); // in bytes
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
-     size_t offset);
+     size_t offset); // in bytes
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
-     size_t offset);
+     size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return modified a
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d(
@@ -1121,7 +1265,7 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      size_t nb1,
-     size_t offset);
+     size_t offset); // in bytes
 
  // b -> view(a,offset,nb1,nb2,3), return view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
@@ -1129,7 +1273,7 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      size_t nb1,
-     size_t offset);
+     size_t offset); // in bytes
 
  // a -> b, return view(b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
@@ -1137,22 +1281,16 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b);
 
- // a -> b, in-place, return view(b)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy_inplace(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cast(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b);
+     enum wsp_ggml_type type);
 
  // make contiguous
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
 
- // make contiguous, in-place
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_inplace(
-     struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a);
-
  // make contiguous, with new shape
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d(
      struct wsp_ggml_context * ctx,
@@ -1270,14 +1408,14 @@ extern "C" {
  // supports 3D: a->ne[2] == b->ne[1]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b);
+     struct wsp_ggml_tensor * a, // data
+     struct wsp_ggml_tensor * b); // row indices
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
-     struct wsp_ggml_tensor * c);
+     struct wsp_ggml_tensor * a, // gradients of wsp_ggml_get_rows result
+     struct wsp_ggml_tensor * b, // row indices
+     struct wsp_ggml_tensor * c); // data for wsp_ggml_get_rows, only used for its shape
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
      struct wsp_ggml_context * ctx,
@@ -1316,13 +1454,15 @@ extern "C" {
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a);
 
- // fused soft_max(a*scale + mask)
+ // fused soft_max(a*scale + mask*(ALiBi slope))
  // mask is optional
+ // max_bias = 0.0f for no ALiBi
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * mask,
-     float scale);
+     float scale,
+     float max_bias);
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
      struct wsp_ggml_context * ctx,
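Migration note: wsp_ggml_soft_max_ext() gains a max_bias parameter; the removed WSP_GGML_OP_ALIBI is effectively folded in here. Sketch:

#include "ggml.h"

struct wsp_ggml_tensor * attn_probs(struct wsp_ggml_context * ctx,
                                    struct wsp_ggml_tensor * kq,
                                    struct wsp_ggml_tensor * kq_mask,
                                    float kq_scale) {
    // max_bias = 0.0f reproduces the old, ALiBi-free behavior
    return wsp_ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, /*max_bias*/ 0.0f);
}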
@@ -1336,9 +1476,8 @@ extern "C" {
      struct wsp_ggml_tensor * b);
 
  // rotary position embedding
- // if mode & 1 == 1, skip n_past elements (DEPRECATED)
- // if mode & 2 == 1, GPT-NeoX style
- // if mode & 4 == 1, ChatGLM style
+ // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
+ // if (mode & WSP_GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
  //
  // b is an int32 vector with size a->ne[2], it contains the positions
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
@@ -1346,8 +1485,7 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      int n_dims,
-     int mode,
-     int n_ctx);
+     int mode);
 
  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
@@ -1355,18 +1493,18 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      int n_dims,
-     int mode,
-     int n_ctx);
+     int mode);
 
  // custom RoPE
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
+ // c is freq factors (e.g. phi3-128k), (optional)
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
+     struct wsp_ggml_tensor * c,
      int n_dims,
      int mode,
-     int n_ctx,
-     int n_orig_ctx,
+     int n_ctx_orig,
      float freq_base,
      float freq_scale,
      float ext_factor,
@@ -1375,14 +1513,14 @@ extern "C" {
      float beta_slow);
 
  // in-place, returns view(a)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
+     struct wsp_ggml_tensor * c,
      int n_dims,
      int mode,
-     int n_ctx,
-     int n_orig_ctx,
+     int n_ctx_orig,
      float freq_base,
      float freq_scale,
      float ext_factor,
@@ -1390,46 +1528,56 @@ extern "C" {
      float beta_fast,
      float beta_slow);
 
- // compute correction dims for YaRN RoPE scaling
- void wsp_ggml_rope_yarn_corr_dims(
-     int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
-
- // xPos RoPE, in-place, returns view(a)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      int n_dims,
-     float base,
-     bool down);
+     int mode,
+     int n_ctx_orig,
+     float freq_base,
+     float freq_scale,
+     float ext_factor,
+     float attn_factor,
+     float beta_fast,
+     float beta_slow),
+     "use wsp_ggml_rope_ext instead");
 
- // rotary position embedding backward, i.e compute dx from dy
- // a - dy
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      struct wsp_ggml_tensor * b,
      int n_dims,
      int mode,
-     int n_ctx,
-     int n_orig_ctx,
+     int n_ctx_orig,
      float freq_base,
      float freq_scale,
      float ext_factor,
      float attn_factor,
      float beta_fast,
-     float beta_slow,
-     float xpos_base,
-     bool xpos_down);
+     float beta_slow),
+     "use wsp_ggml_rope_ext_inplace instead");
 
- // alibi position embedding
- // in-place, returns view(a)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi(
+ // compute correction dims for YaRN RoPE scaling
+ void wsp_ggml_rope_yarn_corr_dims(
+     int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+ // rotary position embedding backward, i.e compute dx from dy
+ // a - dy
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     int n_past,
-     int n_head,
-     float bias_max);
+     struct wsp_ggml_tensor * a, // gradients of wsp_ggml_rope result
+     struct wsp_ggml_tensor * b, // positions
+     struct wsp_ggml_tensor * c, // freq factors
+     int n_dims,
+     int mode,
+     int n_ctx_orig,
+     float freq_base,
+     float freq_scale,
+     float ext_factor,
+     float attn_factor,
+     float beta_fast,
+     float beta_slow);
 
  // clamp
  // in-place, returns view(a)
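Migration note: wsp_ggml_rope_custom() is deprecated in favor of wsp_ggml_rope_ext(), which drops n_ctx and adds an optional freq-factors tensor c. Sketch (all numeric values are illustrative, not taken from the package):

#include <stddef.h>
#include "ggml.h"

struct wsp_ggml_tensor * apply_neox_rope(struct wsp_ggml_context * ctx,
                                         struct wsp_ggml_tensor * cur,
                                         struct wsp_ggml_tensor * pos) {
    return wsp_ggml_rope_ext(ctx, cur, pos, /*c: freq factors*/ NULL,
                             /*n_dims*/ 64, WSP_GGML_ROPE_TYPE_NEOX, /*n_ctx_orig*/ 4096,
                             /*freq_base*/ 10000.0f, /*freq_scale*/ 1.0f,
                             /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f,
                             /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);
}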
@@ -1439,22 +1587,49 @@ extern "C" {
      float min,
      float max);
 
+ // im2col
+ // converts data into a format that effectively results in a convolution when combined with matrix multiplication
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
-     int s0,
-     int s1,
-     int p0,
-     int p1,
-     int d0,
-     int d1,
-     bool is_2D);
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
+     int s0, // stride dimension 0
+     int s1, // stride dimension 1
+     int p0, // padding dimension 0
+     int p1, // padding dimension 1
+     int d0, // dilation dimension 0
+     int d1, // dilation dimension 1
+     bool is_2D,
+     enum wsp_ggml_type dst_type);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_back(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // gradient of im2col output
+     int64_t * ne, // shape of im2col input
+     int s0, // stride dimension 0
+     int s1, // stride dimension 1
+     int p0, // padding dimension 0
+     int p1, // padding dimension 1
+     int d0, // dilation dimension 0
+     int d1, // dilation dimension 1
+     bool is_2D);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_depthwise_2d(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
+     int s0, // stride dimension 0
+     int s1, // stride dimension 1
+     int p0, // padding dimension 0
+     int p1, // padding dimension 1
+     int d0, // dilation dimension 0
+     int d1); // dilation dimension 1
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
      int s0, // stride
      int p0, // padding
      int d0); // dilation
@@ -1463,29 +1638,29 @@ extern "C" {
  // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
  WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
-     int s,
-     int d);
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
+     int s, // stride
+     int d); // dilation
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
-     int s0,
-     int p0,
-     int d0);
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
+     int s0, // stride
+     int p0, // padding
+     int d0); // dilation
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
      struct wsp_ggml_context * ctx,
-     struct wsp_ggml_tensor * a,
-     struct wsp_ggml_tensor * b,
-     int s0,
-     int s1,
-     int p0,
-     int p1,
-     int d0,
-     int d1);
+     struct wsp_ggml_tensor * a, // convolution kernel
+     struct wsp_ggml_tensor * b, // data
+     int s0, // stride dimension 0
+     int s1, // stride dimension 1
+     int p0, // padding dimension 0
+     int p1, // padding dimension 1
+     int d0, // dilation dimension 0
+     int d1); // dilation dimension 1
 
 
  // kernel size is a->ne[0] x a->ne[1]
@@ -1547,13 +1722,37 @@ extern "C" {
      float p0,
      float p1);
 
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a,
+     struct wsp_ggml_tensor * af, // "a"/input used in forward pass
+     enum wsp_ggml_op_pool op,
+     int k0,
+     int k1,
+     int s0,
+     int s1,
+     float p0,
+     float p1);
+
  // nearest interpolate
+ // multiplies ne0 and ne1 by scale factor
  // used in stable-diffusion
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      int scale_factor);
 
+ // nearest interpolate
+ // nearest interpolate to specified dimensions
+ // used in tortoise.cpp
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * a,
+     int ne0,
+     int ne1,
+     int ne2,
+     int ne3);
+
  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
      struct wsp_ggml_context * ctx,
@@ -1563,10 +1762,19 @@ extern "C" {
      int p2,
      int p3);
 
+ // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+ // timesteps: [N,]
+ // return: [N, dim]
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
+     struct wsp_ggml_context * ctx,
+     struct wsp_ggml_tensor * timesteps,
+     int dim,
+     int max_period);
+
  // sort rows
  enum wsp_ggml_sort_order {
-     WSP_GGML_SORT_ASC,
-     WSP_GGML_SORT_DESC,
+     WSP_GGML_SORT_ORDER_ASC,
+     WSP_GGML_SORT_ORDER_DESC,
  };
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
@@ -1574,19 +1782,40 @@ extern "C" {
      struct wsp_ggml_tensor * a,
      enum wsp_ggml_sort_order order);
 
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_arange(
+     struct wsp_ggml_context * ctx,
+     float start,
+     float stop,
+     float step);
+
  // top k elements per row
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * a,
      int k);
 
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
+ #define WSP_GGML_KQ_MASK_PAD 32
+
+ // q: [n_embd, n_batch, n_head, 1]
+ // k: [n_embd, n_kv, n_head_kv, 1]
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * q,
      struct wsp_ggml_tensor * k,
      struct wsp_ggml_tensor * v,
-     bool masked);
+     struct wsp_ggml_tensor * mask,
+     float scale,
+     float max_bias,
+     float logit_softcap);
+
+ WSP_GGML_API void wsp_ggml_flash_attn_ext_set_prec(
+     struct wsp_ggml_tensor * a,
+     enum wsp_ggml_prec prec);
 
+ // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
      struct wsp_ggml_context * ctx,
      struct wsp_ggml_tensor * q,
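Migration note: WSP_GGML_OP_FLASH_ATTN is replaced by the fused wsp_ggml_flash_attn_ext(); tensor shapes follow the comments above, and the mask's batch dimension must be padded to WSP_GGML_KQ_MASK_PAD. A sketch (the precision override is optional; all values illustrative):

#include "ggml.h"

struct wsp_ggml_tensor * fused_attn(struct wsp_ggml_context * ctx,
                                    struct wsp_ggml_tensor * q,
                                    struct wsp_ggml_tensor * k,
                                    struct wsp_ggml_tensor * v,
                                    struct wsp_ggml_tensor * mask,
                                    float kq_scale) {
    struct wsp_ggml_tensor * out = wsp_ggml_flash_attn_ext(
        ctx, q, k, v, mask, kq_scale, /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
    wsp_ggml_flash_attn_ext_set_prec(out, WSP_GGML_PREC_F32);
    return out; // [n_embd, n_head, n_batch, 1], permuted
}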
@@ -1595,13 +1824,19 @@ extern "C" {
1595
1824
  struct wsp_ggml_tensor * d,
1596
1825
  bool masked);
1597
1826
 
1598
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff(
1827
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
1599
1828
  struct wsp_ggml_context * ctx,
1600
- struct wsp_ggml_tensor * a,
1601
- struct wsp_ggml_tensor * b0,
1602
- struct wsp_ggml_tensor * b1,
1603
- struct wsp_ggml_tensor * c0,
1604
- struct wsp_ggml_tensor * c1);
1829
+ struct wsp_ggml_tensor * sx,
1830
+ struct wsp_ggml_tensor * c);
1831
+
1832
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
1833
+ struct wsp_ggml_context * ctx,
1834
+ struct wsp_ggml_tensor * s,
1835
+ struct wsp_ggml_tensor * x,
1836
+ struct wsp_ggml_tensor * dt,
1837
+ struct wsp_ggml_tensor * A,
1838
+ struct wsp_ggml_tensor * B,
1839
+ struct wsp_ggml_tensor * C);
1605
1840
 
1606
1841
  // partition into non-overlapping windows with padding if needed
1607
1842
  // example:
@@ -1653,6 +1888,15 @@ extern "C" {
1653
1888
  struct wsp_ggml_tensor * pw,
1654
1889
  struct wsp_ggml_tensor * ph);
1655
1890
 
1891
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv(
1892
+ struct wsp_ggml_context * ctx,
1893
+ struct wsp_ggml_tensor * k,
1894
+ struct wsp_ggml_tensor * v,
1895
+ struct wsp_ggml_tensor * r,
1896
+ struct wsp_ggml_tensor * tf,
1897
+ struct wsp_ggml_tensor * td,
1898
+ struct wsp_ggml_tensor * state);
1899
+
1656
1900
  // custom operators
1657
1901
 
1658
1902
  typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1736,7 +1980,8 @@ extern "C" {
1736
1980
  typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
1737
1981
  typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);
1738
1982
 
1739
- #define WSP_GGML_N_TASKS_MAX -1
1983
+ #define WSP_GGML_N_TASKS_MAX (-1)
1984
+ // n_tasks == WSP_GGML_N_TASKS_MAX means to use max number of tasks
1740
1985
 
1741
1986
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
1742
1987
  struct wsp_ggml_context * ctx,
@@ -1789,48 +2034,87 @@ extern "C" {
1789
2034
  // loss function
1790
2035
 
1791
2036
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
1792
- struct wsp_ggml_context * ctx,
1793
- struct wsp_ggml_tensor * a,
1794
- struct wsp_ggml_tensor * b);
2037
+ struct wsp_ggml_context * ctx,
2038
+ struct wsp_ggml_tensor * a, // logits
2039
+ struct wsp_ggml_tensor * b); // labels
1795
2040
 
1796
2041
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
1797
- struct wsp_ggml_context * ctx,
1798
- struct wsp_ggml_tensor * a,
1799
- struct wsp_ggml_tensor * b,
1800
- struct wsp_ggml_tensor * c);
2042
+ struct wsp_ggml_context * ctx,
2043
+ struct wsp_ggml_tensor * a, // logits
2044
+ struct wsp_ggml_tensor * b, // labels
2045
+ struct wsp_ggml_tensor * c); // gradients of cross_entropy_loss result
2046
+
2047
+ // AdamW optimizer step
2048
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
2049
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
2050
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
2051
+ struct wsp_ggml_context * ctx,
2052
+ struct wsp_ggml_tensor * a,
2053
+ struct wsp_ggml_tensor * grad,
2054
+ float alpha,
2055
+ float beta1,
2056
+ float beta2,
2057
+ float eps,
2058
+ float wd); // weight decay
1801
2059
 
1802
2060
  //
1803
2061
  // automatic differentiation
1804
2062
  //
1805
2063
 
1806
- WSP_GGML_API void wsp_ggml_set_param(
1807
- struct wsp_ggml_context * ctx,
1808
- struct wsp_ggml_tensor * tensor);
1809
-
2064
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
2065
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
1810
2066
 
1811
2067
  WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
1812
- WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
2068
+ WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool accumulate);
2069
+
2070
+ WSP_GGML_API void wsp_ggml_build_opt_adamw(
2071
+ struct wsp_ggml_context * ctx,
2072
+ struct wsp_ggml_cgraph * gf,
2073
+ struct wsp_ggml_cgraph * gb,
2074
+ float alpha,
2075
+ float beta1,
2076
+ float beta2,
2077
+ float eps,
2078
+ float wd); // weight decay
 
  // graph allocation in a context
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
- WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
- WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+
+ WSP_GGML_API int wsp_ggml_graph_size (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_node (struct wsp_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+ WSP_GGML_API struct wsp_ggml_tensor ** wsp_ggml_graph_nodes (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API int wsp_ggml_graph_n_nodes(struct wsp_ggml_cgraph * cgraph);
+
+ WSP_GGML_API void wsp_ggml_graph_add_node(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
 
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
  WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);
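`wsp_ggml_graph_view` is gone and direct node access now goes through accessors. A sketch of walking a built graph `gf` (hypothetical handle; the tensor `name` field is assumed from upstream ggml):

```c
for (int i = 0; i < wsp_ggml_graph_n_nodes(gf); ++i) {
    printf("node %d: %s\n", i, wsp_ggml_graph_node(gf, i)->name);
}
// negative indexing counts from the end, so -1 is the last (result) node:
struct wsp_ggml_tensor * result = wsp_ggml_graph_node(gf, -1);
```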
 
+ WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
+ WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
+ WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
+ WSP_GGML_API struct wsp_ggml_threadpool * wsp_ggml_threadpool_new (struct wsp_ggml_threadpool_params * params);
+ WSP_GGML_API void wsp_ggml_threadpool_free (struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API int wsp_ggml_threadpool_get_n_threads(struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API void wsp_ggml_threadpool_pause (struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API void wsp_ggml_threadpool_resume (struct wsp_ggml_threadpool * threadpool);
+
  // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
- WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
- WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
+ WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan(
+ const struct wsp_ggml_cgraph * cgraph,
+ int n_threads, /* = WSP_GGML_DEFAULT_N_THREADS */
+ struct wsp_ggml_threadpool * threadpool /* = NULL */ );
+ WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
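`wsp_ggml_graph_plan` now accepts an optional threadpool and `wsp_ggml_graph_compute` reports a status. A sketch of the updated flow with the threadpool API added above (assumes `<stdlib.h>` for the caller-owned work buffer; passing `NULL` as the threadpool keeps the old behavior):

```c
struct wsp_ggml_threadpool_params tpp = wsp_ggml_threadpool_params_default(4);
struct wsp_ggml_threadpool * tp = wsp_ggml_threadpool_new(&tpp);

struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(gf, wsp_ggml_threadpool_get_n_threads(tp), tp);
if (plan.work_size > 0) {
    plan.work_data = malloc(plan.work_size); // caller must allocate the work memory
}
enum wsp_ggml_status st = wsp_ggml_graph_compute(gf, &plan);

free(plan.work_data);
wsp_ggml_threadpool_free(tp);
```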
 
  // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
- WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
+ WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
 
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
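The `void` return becoming a status means the undersized-context case is now reportable instead of silent. Minimal check, assuming the success value follows upstream ggml's `WSP_GGML_STATUS_SUCCESS` naming:

```c
enum wsp_ggml_status st = wsp_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);
if (st != WSP_GGML_STATUS_SUCCESS) {
    // ctx was too small for the work data, or compute failed/aborted
}
```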
 
@@ -1859,8 +2143,8 @@ extern "C" {
 
  // optimization methods
  enum wsp_ggml_opt_type {
- WSP_GGML_OPT_ADAM,
- WSP_GGML_OPT_LBFGS,
+ WSP_GGML_OPT_TYPE_ADAM,
+ WSP_GGML_OPT_TYPE_LBFGS,
  };
 
  // linesearch methods
@@ -1874,12 +2158,12 @@ extern "C" {
 
  // optimization return values
  enum wsp_ggml_opt_result {
- WSP_GGML_OPT_OK = 0,
- WSP_GGML_OPT_DID_NOT_CONVERGE,
- WSP_GGML_OPT_NO_CONTEXT,
- WSP_GGML_OPT_INVALID_WOLFE,
- WSP_GGML_OPT_FAIL,
- WSP_GGML_OPT_CANCEL,
+ WSP_GGML_OPT_RESULT_OK = 0,
+ WSP_GGML_OPT_RESULT_DID_NOT_CONVERGE,
+ WSP_GGML_OPT_RESULT_NO_CONTEXT,
+ WSP_GGML_OPT_RESULT_INVALID_WOLFE,
+ WSP_GGML_OPT_RESULT_FAIL,
+ WSP_GGML_OPT_RESULT_CANCEL,
 
  WSP_GGML_LINESEARCH_FAIL = -128,
  WSP_GGML_LINESEARCH_MINIMUM_STEP,
@@ -1891,6 +2175,10 @@ extern "C" {
  typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
  typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);
 
+ // Set callback for all future logging events.
+ // If this is not called, or NULL is supplied, everything is output on stderr.
+ WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);
+
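A sketch of routing log output into a host logger via the new setter (assumes `<stdio.h>`); the callback matches the `wsp_ggml_log_callback` typedef above:

```c
static void my_log_cb(enum wsp_ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr); // messages arrive pre-formatted
}

// somewhere during init:
wsp_ggml_log_set(my_log_cb, NULL);
```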
  // optimization parameters
  //
  // see ggml.c (wsp_ggml_opt_default_params) for default values
@@ -2030,23 +2318,39 @@ extern "C" {
  void * callback_data);
 
  //
- // quantization
+ // tensor flags
  //
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
 
- // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
-
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ //
+ // quantization
+ //
 
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+ // - wsp_ggml_wsp_quantize_init can be called multiple times with the same type
+ // it will only initialize the quantization tables for the first call or after wsp_ggml_wsp_quantize_free
+ // automatically called by wsp_ggml_wsp_quantize_chunk for convenience
+ //
+ // - wsp_ggml_wsp_quantize_free will free any memory allocated by wsp_ggml_wsp_quantize_init
+ // call this at the end of the program to avoid memory leaks
+ //
+ // note: these are thread-safe
+ //
+ WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type);
+ WSP_GGML_API void wsp_ggml_wsp_quantize_free(void);
+
+ // some quantization type cannot be used without an importance matrix
+ WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type);
+
+ // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory)
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(
+ enum wsp_ggml_type type,
+ const float * src,
+ void * dst,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
+ const float * imatrix);
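The per-type `wsp_ggml_wsp_quantize_q*` entry points and the histogram argument are gone; everything funnels through the row-based chunk API. A sketch of the new call (buffer sizing via `wsp_ggml_row_size` is an assumption carried over from upstream ggml; `src_f32`, `dst`, `nrows`, `n_per_row` are placeholders):

```c
enum wsp_ggml_type qtype = WSP_GGML_TYPE_Q4_0;
if (!wsp_ggml_wsp_quantize_requires_imatrix(qtype)) {
    // src_f32 holds nrows * n_per_row floats; dst must hold at least
    // nrows * wsp_ggml_row_size(qtype, n_per_row) bytes
    size_t written = wsp_ggml_wsp_quantize_chunk(
        qtype, src_f32, dst, /*start=*/0, nrows, n_per_row, /*imatrix=*/NULL);
    (void) written;
}
wsp_ggml_wsp_quantize_free(); // release quantization tables at shutdown
```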
 
  //
  // gguf
@@ -2116,10 +2420,14 @@ extern "C" {
  WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
 
- WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
- WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
- WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
- WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);
+ WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
+ WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
+ WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);
+ WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type (const struct wsp_gguf_context * ctx, int i);
+
+ // removes key if it exists
+ WSP_GGML_API void wsp_gguf_remove_key(struct wsp_gguf_context * ctx, const char * key);
 
  // overrides existing values or adds a new one
  WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
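With tensor types now queryable, a gguf inventory dump becomes one line per tensor. Sketch (`gctx` is a hypothetical open `wsp_gguf_context`; `wsp_ggml_type_name` is assumed from upstream ggml):

```c
for (int i = 0; i < wsp_gguf_get_n_tensors(gctx); ++i) {
    printf("%-40s %-8s offset=%zu\n",
           wsp_gguf_get_tensor_name(gctx, i),
           wsp_ggml_type_name(wsp_gguf_get_tensor_type(gctx, i)),
           wsp_gguf_get_tensor_offset(gctx, i));
}
```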
@@ -2175,24 +2483,38 @@ extern "C" {
  //
 
  WSP_GGML_API int wsp_ggml_cpu_has_avx (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_avx_vnni (void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_avx512_bf16(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_amx_int8 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
  WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_sve (void);
  WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
  WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
  WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
  WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
  WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
  WSP_GGML_API int wsp_ggml_cpu_has_blas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_cublas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_clblast (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_cuda (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_vulkan (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_kompute (void);
  WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void);
  WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_riscv_v (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_sycl (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_rpc (void);
  WSP_GGML_API int wsp_ggml_cpu_has_vsx (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_matmul_int8(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_cann (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_llamafile (void);
+
+ // get the sve vector length in bytes
+ WSP_GGML_API int wsp_ggml_cpu_get_sve_cnt(void);
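A quick capability log using a few of the new probes (each `wsp_ggml_cpu_has_*` call returns a 0/1 flag by convention):

```c
printf("avx_vnni=%d sve=%d matmul_int8=%d sve_cnt=%d\n",
       wsp_ggml_cpu_has_avx_vnni(),
       wsp_ggml_cpu_has_sve(),
       wsp_ggml_cpu_has_matmul_int8(),
       wsp_ggml_cpu_get_sve_cnt());
```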
 
  //
  // Internal types and functions exposed for tests and benchmarks
@@ -2204,23 +2526,36 @@ extern "C" {
  #else
  #define WSP_GGML_RESTRICT restrict
  #endif
- typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_vec_dot_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
-
- typedef struct {
- const char * type_name;
- int blck_size;
- size_t type_size;
- bool is_quantized;
- wsp_ggml_to_float_t to_float;
- wsp_ggml_from_float_t from_float;
- wsp_ggml_from_float_t from_float_reference;
- wsp_ggml_vec_dot_t vec_dot;
- enum wsp_ggml_type vec_dot_type;
- } wsp_ggml_type_traits_t;
-
- WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+ typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+ typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t k);
+ typedef void (*wsp_ggml_from_float_to_mat_t)
+ (const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+ typedef void (*wsp_ggml_vec_dot_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x, size_t bx,
+ const void * WSP_GGML_RESTRICT y, size_t by, int nrc);
+ typedef void (*wsp_ggml_gemv_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
+ const void * WSP_GGML_RESTRICT y, int nr, int nc);
+ typedef void (*wsp_ggml_gemm_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
+ const void * WSP_GGML_RESTRICT y, int nr, int nc);
+
+ struct wsp_ggml_type_traits {
+ const char * type_name;
+ int64_t blck_size;
+ int64_t blck_size_interleave; // interleave elements in blocks
+ size_t type_size;
+ bool is_quantized;
+ wsp_ggml_to_float_t to_float;
+ wsp_ggml_from_float_t from_float;
+ wsp_ggml_from_float_t from_float_ref;
+ wsp_ggml_from_float_to_mat_t from_float_to_mat;
+ wsp_ggml_vec_dot_t vec_dot;
+ enum wsp_ggml_type vec_dot_type;
+ int64_t nrows; // number of rows to process simultaneously
+ int64_t ncols; // number of columns to process simultaneously
+ wsp_ggml_gemv_t gemv;
+ wsp_ggml_gemm_t gemm;
+ };
+
+ WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);
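The traits accessor changed shape: it now returns a pointer to a const struct, the `_t` typedef and `internal_` prefix are gone, and block sizes widened to `int64_t`. Updated call-site sketch:

```c
const struct wsp_ggml_type_traits * tt = wsp_ggml_get_type_traits(WSP_GGML_TYPE_Q4_0);
printf("%s: block=%lld, %zu bytes/block, quantized=%d\n",
       tt->type_name,
       (long long) tt->blck_size,
       tt->type_size,
       tt->is_quantized);
```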
 
  #ifdef __cplusplus
  }