whisper.rn 0.4.0-rc.8 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/ggml-aarch64.c +3209 -0
  9. package/cpp/ggml-aarch64.h +39 -0
  10. package/cpp/ggml-alloc.c +725 -517
  11. package/cpp/ggml-alloc.h +47 -65
  12. package/cpp/ggml-backend-impl.h +166 -55
  13. package/cpp/ggml-backend.cpp +2635 -0
  14. package/cpp/ggml-backend.h +202 -85
  15. package/cpp/ggml-common.h +1853 -0
  16. package/cpp/ggml-cpu-impl.h +614 -0
  17. package/cpp/ggml-impl.h +143 -180
  18. package/cpp/ggml-metal.h +13 -11
  19. package/cpp/ggml-metal.m +2955 -1632
  20. package/cpp/ggml-quants.c +9824 -3263
  21. package/cpp/ggml-quants.h +133 -248
  22. package/cpp/ggml-whisper.metallib +0 -0
  23. package/cpp/ggml.c +8482 -5142
  24. package/cpp/ggml.h +633 -349
  25. package/cpp/rn-whisper.cpp +91 -0
  26. package/cpp/rn-whisper.h +2 -0
  27. package/cpp/whisper.cpp +1427 -658
  28. package/cpp/whisper.h +84 -28
  29. package/ios/RNWhisper.mm +124 -37
  30. package/ios/RNWhisperAudioUtils.h +1 -0
  31. package/ios/RNWhisperAudioUtils.m +20 -13
  32. package/ios/RNWhisperContext.h +3 -2
  33. package/ios/RNWhisperContext.mm +39 -7
  34. package/jest/mock.js +9 -1
  35. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  36. package/lib/commonjs/index.js +48 -19
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/NativeRNWhisper.js.map +1 -1
  40. package/lib/module/index.js +48 -19
  41. package/lib/module/index.js.map +1 -1
  42. package/lib/module/version.json +1 -1
  43. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  44. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  45. package/lib/typescript/index.d.ts +25 -3
  46. package/lib/typescript/index.d.ts.map +1 -1
  47. package/package.json +6 -5
  48. package/src/NativeRNWhisper.ts +12 -3
  49. package/src/index.ts +63 -24
  50. package/src/version.json +1 -1
  51. package/whisper-rn.podspec +9 -2
  52. package/cpp/ggml-backend.c +0 -1718
  53. package/cpp/ggml-metal-whisper.metal +0 -5820
package/cpp/ggml.h CHANGED
@@ -187,16 +187,6 @@
  #    define WSP_GGML_API
  #endif
 
- #ifdef WSP_GGML_MULTIPLATFORM
- #    if defined(_WIN32)
- #        define WSP_GGML_CALL
- #    else
- #        define WSP_GGML_CALL __attribute__((__ms_abi__))
- #    endif
- #else
- #    define WSP_GGML_CALL
- #endif
-
  // TODO: support for clang
  #ifdef __GNUC__
  #    define WSP_GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
@@ -214,26 +204,30 @@
  #    define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
  #endif
 
- #include <stdint.h>
- #include <stddef.h>
  #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdio.h>
 
  #define WSP_GGML_FILE_MAGIC   0x67676d6c // "ggml"
- #define WSP_GGML_FILE_VERSION 1
+ #define WSP_GGML_FILE_VERSION 2
 
  #define WSP_GGML_QNT_VERSION        2    // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
  #define WSP_GGML_MAX_DIMS           4
  #define WSP_GGML_MAX_PARAMS         2048
- #define WSP_GGML_MAX_CONTEXTS       64
  #define WSP_GGML_MAX_SRC            10
+ #define WSP_GGML_MAX_N_THREADS      512
+ #define WSP_GGML_MAX_OP_PARAMS      64
+
  #ifndef WSP_GGML_MAX_NAME
- #define WSP_GGML_MAX_NAME           64
+ #    define WSP_GGML_MAX_NAME       64
  #endif
- #define WSP_GGML_MAX_OP_PARAMS      64
+
  #define WSP_GGML_DEFAULT_N_THREADS  4
  #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
+
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
  #else
@@ -243,6 +237,8 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1
 
+ #define WSP_GGML_ROPE_TYPE_NEOX 2
+
  #define WSP_GGUF_MAGIC "GGUF"
 
  #define WSP_GGUF_VERSION 3
@@ -253,26 +249,27 @@
 
  #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
- #define WSP_GGML_ASSERT(x) \
-     do { \
-         if (!(x)) { \
-             fflush(stdout); \
-             fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-             wsp_ggml_print_backtrace(); \
-             abort(); \
-         } \
-     } while (0)
-
  #ifndef NDEBUG
- #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached")
+ #    define WSP_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
  #elif defined(__GNUC__)
- #define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #    define WSP_GGML_UNREACHABLE() __builtin_unreachable()
  #elif defined(_MSC_VER)
- #define WSP_GGML_UNREACHABLE() __assume(0)
+ #    define WSP_GGML_UNREACHABLE() __assume(0)
  #else
- #define WSP_GGML_UNREACHABLE() ((void) 0)
+ #    define WSP_GGML_UNREACHABLE() ((void) 0)
  #endif
 
+ #ifdef __cplusplus
+ #    define WSP_GGML_NORETURN [[noreturn]]
+ #elif defined(_MSC_VER)
+ #    define WSP_GGML_NORETURN __declspec(noreturn)
+ #else
+ #    define WSP_GGML_NORETURN _Noreturn
+ #endif
+
+ #define WSP_GGML_ABORT(...) wsp_ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+ #define WSP_GGML_ASSERT(x) if (!(x)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", #x)
+
  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
  //
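Note on this hunk: WSP_GGML_ASSERT no longer prints and aborts inline; it now routes through wsp_ggml_abort(), which is declared later in this diff and takes a printf-style message. A minimal sketch of how caller code can use the two macros, assuming only the declarations visible in this header (the check_mono_16k helper is hypothetical, not part of the package):

    #include "ggml.h" // brings in WSP_GGML_ASSERT / WSP_GGML_ABORT as of this release

    // Hypothetical input check for a whisper-style audio pipeline.
    static void check_mono_16k(int n_channels, int sample_rate) {
        WSP_GGML_ASSERT(n_channels == 1); // on failure, forwards to wsp_ggml_abort() with file/line
        if (sample_rate != 16000) {
            // printf-style format, validated via WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
            WSP_GGML_ABORT("unsupported sample rate: %d", sample_rate);
        }
    }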
@@ -311,51 +308,87 @@
      WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
      WSP_GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)
 
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
+     WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+     WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+     WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+     WSP_GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif
 
- #if defined(__ARM_NEON) && defined(__CUDACC__)
-     typedef half wsp_ggml_fp16_t;
- #elif defined(__ARM_NEON) && !defined(_MSC_VER)
-     typedef __fp16 wsp_ggml_fp16_t;
- #else
-     typedef uint16_t wsp_ggml_fp16_t;
- #endif
+     WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
+     WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);
 
- // convert FP16 <-> FP32
- WSP_GGML_API float       wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
- WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
+     enum wsp_ggml_status {
+         WSP_GGML_STATUS_ALLOC_FAILED = -2,
+         WSP_GGML_STATUS_FAILED       = -1,
+         WSP_GGML_STATUS_SUCCESS      = 0,
+         WSP_GGML_STATUS_ABORTED      = 1,
+     };
+
+     // get wsp_ggml_status name string
+     WSP_GGML_API const char * wsp_ggml_status_to_string(enum wsp_ggml_status status);
 
- WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
- WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
+     // ieee 754-2008 half-precision float16
+     // todo: make this not an integral type
+     typedef uint16_t wsp_ggml_fp16_t;
+     WSP_GGML_API float           wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t);
+     WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float);
+     WSP_GGML_API void            wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t *, float *, int64_t);
+     WSP_GGML_API void            wsp_ggml_fp32_to_fp16_row(const float *, wsp_ggml_fp16_t *, int64_t);
+
+     // google brain half-precision bfloat16
+     typedef struct { uint16_t bits; } wsp_ggml_bf16_t;
+     WSP_GGML_API wsp_ggml_bf16_t wsp_ggml_fp32_to_bf16(float);
+     WSP_GGML_API float           wsp_ggml_bf16_to_fp32(wsp_ggml_bf16_t); // consider just doing << 16
+     WSP_GGML_API void            wsp_ggml_bf16_to_fp32_row(const wsp_ggml_bf16_t *, float *, int64_t);
+     WSP_GGML_API void            wsp_ggml_fp32_to_bf16_row_ref(const float *, wsp_ggml_bf16_t *, int64_t);
+     WSP_GGML_API void            wsp_ggml_fp32_to_bf16_row(const float *, wsp_ggml_bf16_t *, int64_t);
 
      struct wsp_ggml_object;
      struct wsp_ggml_context;
+     struct wsp_ggml_cgraph;
 
+     // NOTE: always add types at the end of the enum to keep backward compatibility
      enum wsp_ggml_type {
-         WSP_GGML_TYPE_F32 = 0,
-         WSP_GGML_TYPE_F16 = 1,
-         WSP_GGML_TYPE_Q4_0 = 2,
-         WSP_GGML_TYPE_Q4_1 = 3,
+         WSP_GGML_TYPE_F32     = 0,
+         WSP_GGML_TYPE_F16     = 1,
+         WSP_GGML_TYPE_Q4_0    = 2,
+         WSP_GGML_TYPE_Q4_1    = 3,
          // WSP_GGML_TYPE_Q4_2 = 4, support has been removed
-         // WSP_GGML_TYPE_Q4_3 (5) support has been removed
-         WSP_GGML_TYPE_Q5_0 = 6,
-         WSP_GGML_TYPE_Q5_1 = 7,
-         WSP_GGML_TYPE_Q8_0 = 8,
-         WSP_GGML_TYPE_Q8_1 = 9,
-         // k-quantizations
-         WSP_GGML_TYPE_Q2_K = 10,
-         WSP_GGML_TYPE_Q3_K = 11,
-         WSP_GGML_TYPE_Q4_K = 12,
-         WSP_GGML_TYPE_Q5_K = 13,
-         WSP_GGML_TYPE_Q6_K = 14,
-         WSP_GGML_TYPE_Q8_K = 15,
+         // WSP_GGML_TYPE_Q4_3 = 5, support has been removed
+         WSP_GGML_TYPE_Q5_0    = 6,
+         WSP_GGML_TYPE_Q5_1    = 7,
+         WSP_GGML_TYPE_Q8_0    = 8,
+         WSP_GGML_TYPE_Q8_1    = 9,
+         WSP_GGML_TYPE_Q2_K    = 10,
+         WSP_GGML_TYPE_Q3_K    = 11,
+         WSP_GGML_TYPE_Q4_K    = 12,
+         WSP_GGML_TYPE_Q5_K    = 13,
+         WSP_GGML_TYPE_Q6_K    = 14,
+         WSP_GGML_TYPE_Q8_K    = 15,
          WSP_GGML_TYPE_IQ2_XXS = 16,
          WSP_GGML_TYPE_IQ2_XS  = 17,
-         WSP_GGML_TYPE_I8,
-         WSP_GGML_TYPE_I16,
-         WSP_GGML_TYPE_I32,
+         WSP_GGML_TYPE_IQ3_XXS = 18,
+         WSP_GGML_TYPE_IQ1_S   = 19,
+         WSP_GGML_TYPE_IQ4_NL  = 20,
+         WSP_GGML_TYPE_IQ3_S   = 21,
+         WSP_GGML_TYPE_IQ2_S   = 22,
+         WSP_GGML_TYPE_IQ4_XS  = 23,
+         WSP_GGML_TYPE_I8      = 24,
+         WSP_GGML_TYPE_I16     = 25,
+         WSP_GGML_TYPE_I32     = 26,
+         WSP_GGML_TYPE_I64     = 27,
+         WSP_GGML_TYPE_F64     = 28,
+         WSP_GGML_TYPE_IQ1_M   = 29,
+         WSP_GGML_TYPE_BF16    = 30,
+         WSP_GGML_TYPE_Q4_0_4_4 = 31,
+         WSP_GGML_TYPE_Q4_0_4_8 = 32,
+         WSP_GGML_TYPE_Q4_0_8_8 = 33,
+         WSP_GGML_TYPE_TQ1_0   = 34,
+         WSP_GGML_TYPE_TQ2_0   = 35,
          WSP_GGML_TYPE_COUNT,
      };
 
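Two things land in this hunk besides the enum growth: wsp_ggml_fp16_t is now always a plain uint16_t, the row-conversion helpers take int64_t element counts instead of int, and a bfloat16 type arrives. A short round-trip sketch, assuming only the declarations above (values chosen to be exactly representable in bf16):

    #include "ggml.h"

    void bf16_roundtrip_demo(void) {
        float src[4] = { 1.0f, -2.5f, 0.125f, 3.75f };
        wsp_ggml_bf16_t packed[4];
        float back[4];

        wsp_ggml_fp32_to_bf16_row(src, packed, 4);  // element counts are int64_t now
        wsp_ggml_bf16_to_fp32_row(packed, back, 4); // back[i] == src[i] here; bf16 keeps ~7 mantissa bits
    }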
@@ -366,29 +399,40 @@ extern "C" {
      };
 
      enum wsp_ggml_backend_type {
-         WSP_GGML_BACKEND_CPU = 0,
-         WSP_GGML_BACKEND_GPU = 10,
-         WSP_GGML_BACKEND_GPU_SPLIT = 20,
+         WSP_GGML_BACKEND_TYPE_CPU = 0,
+         WSP_GGML_BACKEND_TYPE_GPU = 10,
+         WSP_GGML_BACKEND_TYPE_GPU_SPLIT = 20,
      };
 
      // model file types
      enum wsp_ggml_ftype {
-         WSP_GGML_FTYPE_UNKNOWN = -1,
-         WSP_GGML_FTYPE_ALL_F32 = 0,
-         WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+         WSP_GGML_FTYPE_UNKNOWN        = -1,
+         WSP_GGML_FTYPE_ALL_F32        = 0,
+         WSP_GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
          WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-         WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
-         WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
          WSP_GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
          WSP_GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+         WSP_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
      };
 
      // available tensor operations:
@@ -405,10 +449,13 @@
          WSP_GGML_OP_SQR,
          WSP_GGML_OP_SQRT,
          WSP_GGML_OP_LOG,
+         WSP_GGML_OP_SIN,
+         WSP_GGML_OP_COS,
          WSP_GGML_OP_SUM,
          WSP_GGML_OP_SUM_ROWS,
          WSP_GGML_OP_MEAN,
          WSP_GGML_OP_ARGMAX,
+         WSP_GGML_OP_COUNT_EQUAL,
          WSP_GGML_OP_REPEAT,
          WSP_GGML_OP_REPEAT_BACK,
          WSP_GGML_OP_CONCAT,
@@ -439,25 +486,30 @@
          WSP_GGML_OP_SOFT_MAX_BACK,
          WSP_GGML_OP_ROPE,
          WSP_GGML_OP_ROPE_BACK,
-         WSP_GGML_OP_ALIBI,
          WSP_GGML_OP_CLAMP,
          WSP_GGML_OP_CONV_TRANSPOSE_1D,
          WSP_GGML_OP_IM2COL,
+         WSP_GGML_OP_IM2COL_BACK,
          WSP_GGML_OP_CONV_TRANSPOSE_2D,
          WSP_GGML_OP_POOL_1D,
          WSP_GGML_OP_POOL_2D,
+         WSP_GGML_OP_POOL_2D_BACK,
          WSP_GGML_OP_UPSCALE, // nearest interpolate
          WSP_GGML_OP_PAD,
+         WSP_GGML_OP_ARANGE,
+         WSP_GGML_OP_TIMESTEP_EMBEDDING,
          WSP_GGML_OP_ARGSORT,
          WSP_GGML_OP_LEAKY_RELU,
 
-         WSP_GGML_OP_FLASH_ATTN,
-         WSP_GGML_OP_FLASH_FF,
+         WSP_GGML_OP_FLASH_ATTN_EXT,
          WSP_GGML_OP_FLASH_ATTN_BACK,
+         WSP_GGML_OP_SSM_CONV,
+         WSP_GGML_OP_SSM_SCAN,
          WSP_GGML_OP_WIN_PART,
          WSP_GGML_OP_WIN_UNPART,
          WSP_GGML_OP_GET_REL_POS,
          WSP_GGML_OP_ADD_REL_POS,
+         WSP_GGML_OP_RWKV_WKV,
 
          WSP_GGML_OP_UNARY,
 
@@ -474,6 +526,7 @@
 
          WSP_GGML_OP_CROSS_ENTROPY_LOSS,
          WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+         WSP_GGML_OP_OPT_STEP_ADAMW,
 
          WSP_GGML_OP_COUNT,
      };
@@ -486,44 +539,45 @@
          WSP_GGML_UNARY_OP_TANH,
          WSP_GGML_UNARY_OP_ELU,
          WSP_GGML_UNARY_OP_RELU,
+         WSP_GGML_UNARY_OP_SIGMOID,
          WSP_GGML_UNARY_OP_GELU,
          WSP_GGML_UNARY_OP_GELU_QUICK,
          WSP_GGML_UNARY_OP_SILU,
+         WSP_GGML_UNARY_OP_HARDSWISH,
+         WSP_GGML_UNARY_OP_HARDSIGMOID,
+         WSP_GGML_UNARY_OP_EXP,
 
          WSP_GGML_UNARY_OP_COUNT,
      };
 
      enum wsp_ggml_object_type {
-         WSP_GGML_OBJECT_TENSOR,
-         WSP_GGML_OBJECT_GRAPH,
-         WSP_GGML_OBJECT_WORK_BUFFER
+         WSP_GGML_OBJECT_TYPE_TENSOR,
+         WSP_GGML_OBJECT_TYPE_GRAPH,
+         WSP_GGML_OBJECT_TYPE_WORK_BUFFER
      };
 
      enum wsp_ggml_log_level {
-         WSP_GGML_LOG_LEVEL_ERROR = 2,
-         WSP_GGML_LOG_LEVEL_WARN = 3,
-         WSP_GGML_LOG_LEVEL_INFO = 4,
-         WSP_GGML_LOG_LEVEL_DEBUG = 5
+         WSP_GGML_LOG_LEVEL_NONE  = 0,
+         WSP_GGML_LOG_LEVEL_INFO  = 1,
+         WSP_GGML_LOG_LEVEL_WARN  = 2,
+         WSP_GGML_LOG_LEVEL_ERROR = 3,
+         WSP_GGML_LOG_LEVEL_DEBUG = 4,
+         WSP_GGML_LOG_LEVEL_CONT  = 5, // continue previous log
      };
 
-     // ggml object
-     struct wsp_ggml_object {
-         size_t offs;
-         size_t size;
-
-         struct wsp_ggml_object * next;
-
-         enum wsp_ggml_object_type type;
-
-         char padding[4];
+     // this tensor...
+     enum wsp_ggml_tensor_flag {
+         WSP_GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+         WSP_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+         WSP_GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+         WSP_GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
      };
 
-     static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
-
      // n-dimensional tensor
      struct wsp_ggml_tensor {
-         enum wsp_ggml_type         type;
-         enum wsp_ggml_backend_type backend;
+         enum wsp_ggml_type type;
+
+         WSP_GGML_DEPRECATED(enum wsp_ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
 
          struct wsp_ggml_backend_buffer * buffer;
 
@@ -539,16 +593,12 @@
          // op params - allocated as int32_t for alignment
          int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
-         bool is_param;
+         int32_t flags;
 
          struct wsp_ggml_tensor * grad;
          struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];
 
-         // performance
-         int     perf_runs;
-         int64_t perf_cycles;
-         int64_t perf_time_us;
-
+         // source tensor and offset for views
          struct wsp_ggml_tensor * view_src;
          size_t                   view_offs;
 
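The per-tensor bool is_param becomes a general flags bitfield; the bits are the wsp_ggml_tensor_flag values introduced earlier in this diff. A small sketch of the equivalent check (hypothetical helper, not from the package):

    #include "ggml.h"

    // The old `t->is_param` test becomes a bit test on `t->flags`.
    static bool tensor_is_trainable(const struct wsp_ggml_tensor * t) {
        return (t->flags & WSP_GGML_TENSOR_FLAG_PARAM) != 0;
    }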
@@ -558,11 +608,39 @@
 
          void * extra; // extra things e.g. for ggml-cuda.cu
 
-         char padding[8];
+         // char padding[4];
      };
 
      static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
 
+     // Abort callback
+     // If not NULL, called before ggml computation
+     // If it returns true, the computation is aborted
+     typedef bool (*wsp_ggml_abort_callback)(void * data);
+
+     // Scheduling priorities
+     enum wsp_ggml_sched_priority {
+         WSP_GGML_SCHED_PRIO_NORMAL,
+         WSP_GGML_SCHED_PRIO_MEDIUM,
+         WSP_GGML_SCHED_PRIO_HIGH,
+         WSP_GGML_SCHED_PRIO_REALTIME
+     };
+
+     // Threadpool params
+     // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
+     struct wsp_ggml_threadpool_params {
+         bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
+         int n_threads;                        // number of threads
+         enum wsp_ggml_sched_priority prio;    // thread priority
+         uint32_t poll;                        // polling level (0 - no polling, 100 - aggressive polling)
+         bool strict_cpu;                      // strict cpu placement
+         bool paused;                          // start in paused state
+     };
+
+     struct wsp_ggml_threadpool; // forward declaration, see ggml.c
+
+     typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;
+
      // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
      // since https://github.com/ggerganov/ggml/issues/287
      struct wsp_ggml_cplan {
@@ -570,44 +648,15 @@
          uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
 
          int n_threads;
+         struct wsp_ggml_threadpool * threadpool;
 
          // abort wsp_ggml_graph_compute when true
-         bool (*abort_callback)(void * data);
-         void * abort_callback_data;
-     };
-
-     enum wsp_ggml_cgraph_eval_order {
-         WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
-         WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
-         WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
-     };
-
-     struct wsp_ggml_hash_set {
-         size_t size;
-         struct wsp_ggml_tensor ** keys;
-     };
-
-     // computation graph
-     struct wsp_ggml_cgraph {
-         int size;
-         int n_nodes;
-         int n_leafs;
-
-         struct wsp_ggml_tensor ** nodes;
-         struct wsp_ggml_tensor ** grads;
-         struct wsp_ggml_tensor ** leafs;
-
-         struct wsp_ggml_hash_set visited_hash_table;
-
-         enum wsp_ggml_cgraph_eval_order order;
-
-         // performance
-         int perf_runs;
-         int64_t perf_cycles;
-         int64_t perf_time_us;
+         wsp_ggml_abort_callback abort_callback;
+         void *                  abort_callback_data;
      };
 
      // scratch buffer
+     // TODO: deprecate and remove
      struct wsp_ggml_scratch {
          size_t offs;
          size_t size;
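Note that wsp_ggml_cgraph is now opaque (its internals, hash set, and eval order moved out of the public header), and the cplan gains a threadpool pointer alongside the typed abort callback. A hedged sketch wiring both into a compute plan, using only functions declared elsewhere in this diff (error handling and work-buffer allocation elided):

    #include "ggml.h"

    // Returning true from the callback aborts wsp_ggml_graph_compute().
    static bool should_abort(void * data) {
        return *(volatile bool *) data;
    }

    // Sketch: assumes `graph` was already built; the threadpool API is
    // declared near the end of this header.
    void run_with_threadpool(struct wsp_ggml_cgraph * graph, volatile bool * abort_flag) {
        struct wsp_ggml_threadpool_params tpp = wsp_ggml_threadpool_params_default(4);
        struct wsp_ggml_threadpool * tp = wsp_ggml_threadpool_new(&tpp);

        struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, tpp.n_threads, tp);
        plan.abort_callback      = should_abort;
        plan.abort_callback_data = (void *) abort_flag;
        // ... allocate plan.work_data if plan.work_size > 0, then:
        // enum wsp_ggml_status st = wsp_ggml_graph_compute(graph, &plan);

        wsp_ggml_threadpool_free(tp);
    }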
@@ -621,27 +670,25 @@
          bool no_alloc; // don't allocate memory for the tensor data
      };
 
-
-     // compute types
-
-     // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
-     // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
-     enum wsp_ggml_task_type {
-         WSP_GGML_TASK_INIT = 0,
-         WSP_GGML_TASK_COMPUTE,
-         WSP_GGML_TASK_FINALIZE,
+     // numa strategies
+     enum wsp_ggml_numa_strategy {
+         WSP_GGML_NUMA_STRATEGY_DISABLED   = 0,
+         WSP_GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+         WSP_GGML_NUMA_STRATEGY_ISOLATE    = 2,
+         WSP_GGML_NUMA_STRATEGY_NUMACTL    = 3,
+         WSP_GGML_NUMA_STRATEGY_MIRROR     = 4,
+         WSP_GGML_NUMA_STRATEGY_COUNT
      };
 
-     struct wsp_ggml_compute_params {
-         enum wsp_ggml_task_type type;
+     //
+     // GUID
+     //
 
-         // ith = thread index, nth = number of threads
-         int ith, nth;
+     // GUID types
+     typedef uint8_t wsp_ggml_guid[16];
+     typedef wsp_ggml_guid * wsp_ggml_guid_t;
 
-         // work buffer for all threads
-         size_t wsize;
-         void * wdata;
-     };
+     WSP_GGML_API bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b);
 
      // misc
 
@@ -651,59 +698,71 @@
      WSP_GGML_API int64_t wsp_ggml_cycles(void);
      WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);
 
-     WSP_GGML_API void    wsp_ggml_print_backtrace(void);
+     // accepts a UTF-8 path, even on Windows
+     WSP_GGML_API FILE *  wsp_ggml_fopen(const char * fname, const char * mode);
 
-     WSP_GGML_API void    wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
+     WSP_GGML_API void    wsp_ggml_numa_init(enum wsp_ggml_numa_strategy numa); // call once for better performance on NUMA systems
      WSP_GGML_API bool    wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
 
      WSP_GGML_API void    wsp_ggml_print_object (const struct wsp_ggml_object * obj);
      WSP_GGML_API void    wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);
 
-     WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nrows     (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API WSP_GGML_CALL size_t  wsp_ggml_nbytes    (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               size_t  wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
+     WSP_GGML_API int64_t wsp_ggml_nelements  (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API int64_t wsp_ggml_nrows      (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API size_t  wsp_ggml_nbytes     (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API size_t  wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
 
-     WSP_GGML_API WSP_GGML_CALL int    wsp_ggml_blck_size(enum wsp_ggml_type type);
-     WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_type_size(enum wsp_ggml_type type);             // size in bytes for all elements in a block
-     WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row
+     WSP_GGML_API int64_t wsp_ggml_blck_size(enum wsp_ggml_type type);
+     WSP_GGML_API size_t  wsp_ggml_type_size(enum wsp_ggml_type type);             // size in bytes for all elements in a block
+     WSP_GGML_API size_t  wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row
 
      WSP_GGML_DEPRECATED(
      WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
      "use wsp_ggml_row_size() instead");
 
-     WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_type_name(enum wsp_ggml_type type);
-     WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_name  (enum wsp_ggml_op   op);
-     WSP_GGML_API               const char * wsp_ggml_op_symbol(enum wsp_ggml_op   op);
+     WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
+     WSP_GGML_API const char * wsp_ggml_op_name  (enum wsp_ggml_op   op);
+     WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op   op);
 
-     WSP_GGML_API               const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
-     WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
+     WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+     WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
 
-     WSP_GGML_API WSP_GGML_CALL size_t  wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API size_t  wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
 
-     WSP_GGML_API WSP_GGML_CALL bool    wsp_ggml_is_quantized(enum wsp_ggml_type type);
+     WSP_GGML_API bool    wsp_ggml_is_quantized(enum wsp_ggml_type type);
 
      // TODO: temporary until model loading of ggml examples is refactored
      WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype);
 
-     WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_permuted  (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               bool wsp_ggml_is_scalar    (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               bool wsp_ggml_is_vector    (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               bool wsp_ggml_is_matrix    (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               bool wsp_ggml_is_3d        (const struct wsp_ggml_tensor * tensor);
-     WSP_GGML_API               int  wsp_ggml_n_dims       (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
+     WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_permuted  (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_empty     (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_scalar    (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_vector    (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_matrix    (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_3d        (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API int  wsp_ggml_n_dims       (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
+
+     WSP_GGML_API bool wsp_ggml_is_contiguous  (const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
+     WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
+     WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
 
-     WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+     WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+     WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+
+     WSP_GGML_API bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
 
      // use this to compute the memory overhead of a tensor
      WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);
 
+     WSP_GGML_API bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size_t nbytes);
+
      // main
 
-     WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params);
-     WSP_GGML_API void                      wsp_ggml_free(struct wsp_ggml_context * ctx);
+     WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init (struct wsp_ggml_init_params params);
+     WSP_GGML_API void                      wsp_ggml_reset(struct wsp_ggml_context * ctx);
+     WSP_GGML_API void                      wsp_ggml_free (struct wsp_ggml_context * ctx);
 
      WSP_GGML_API size_t  wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);
 
@@ -780,7 +839,7 @@
      WSP_GGML_API void *  wsp_ggml_get_data    (const struct wsp_ggml_tensor * tensor);
      WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
 
-     WSP_GGML_API WSP_GGML_CALL enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
+     WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
 
      WSP_GGML_API const char *             wsp_ggml_get_name   (const struct wsp_ggml_tensor * tensor);
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name   (      struct wsp_ggml_tensor * tensor, const char * name);
@@ -901,6 +960,22 @@
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a);
 
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin_inplace(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos_inplace(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
      // return scalar
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum(
              struct wsp_ggml_context * ctx,
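The new elementwise sin/cos ops slot into graph building like any other unary op. A minimal end-to-end sketch, assuming a CPU-only build and the usual ggml tensor constructors (wsp_ggml_new_tensor_1d is standard ggml API, not shown in this diff); error handling is elided:

    #include "ggml.h"

    void sin_demo(void) {
        struct wsp_ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
        struct wsp_ggml_context * ctx = wsp_ggml_init(params);

        struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 8);
        struct wsp_ggml_tensor * y = wsp_ggml_sin(ctx, x); // WSP_GGML_OP_SIN node

        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
        wsp_ggml_build_forward_expand(gf, y);

        float * xd = wsp_ggml_get_data_f32(x);
        for (int i = 0; i < 8; i++) xd[i] = 0.25f * (float) i;

        struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(gf, WSP_GGML_DEFAULT_N_THREADS, NULL);
        // allocate plan.work_data first if plan.work_size > 0
        wsp_ggml_graph_compute(gf, &plan);
        // wsp_ggml_get_data_f32(y)[i] now holds sin(xd[i])

        wsp_ggml_free(ctx);
    }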
@@ -921,6 +996,12 @@
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a);
 
+     // count number of equal elements in a and b
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_count_equal(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a,
+             struct wsp_ggml_tensor  * b);
+
      // if a is the same shape as b, and a is not parameter, return a
      // otherwise, return a new tensor: repeat(a) to fit in b
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat(
@@ -934,12 +1015,13 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b);
 
-     // concat a and b on dim 2
+     // concat a and b along dim
      // used in stable-diffusion
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b);
+             struct wsp_ggml_tensor  * b,
+             int                       dim);
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
              struct wsp_ggml_context * ctx,
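Callers migrating across this change can preserve the old behavior by passing the previously hard-coded dimension explicitly. A sketch (hypothetical wrapper, not from the package):

    // wsp_ggml_concat() now takes the concat dimension as an argument;
    // the old API always concatenated along dim 2.
    struct wsp_ggml_tensor * concat_like_before(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * a,
            struct wsp_ggml_tensor  * b) {
        return wsp_ggml_concat(ctx, a, b, 2);
    }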
@@ -1001,6 +1083,14 @@
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a);
 
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid_inplace(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a);
@@ -1032,6 +1122,24 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b);
 
+     // hardswish(x) = x * relu6(x + 3) / 6
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardswish(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     // hardsigmoid(x) = relu6(x + 3) / 6
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardsigmoid(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a);
+
      // normalize along rows
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
              struct wsp_ggml_context * ctx,
@@ -1055,16 +1163,17 @@
 
      // group normalize along ne0*ne1*n_groups
      // used in stable-diffusion
-     // TODO: eps is hardcoded to 1e-6 for now
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
-             int                       n_groups);
+             int                       n_groups,
+             float                     eps);
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
-             int                       n_groups);
+             int                       n_groups,
+             float                     eps);
 
      // a - x
      // b - dy
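The removed TODO explains this signature change: eps used to be hard-coded to 1e-6 inside the op and is now an explicit argument. A sketch of the backward-compatible call (hypothetical wrapper):

    // Passing the previously hard-coded epsilon preserves the old behavior.
    struct wsp_ggml_tensor * group_norm_like_before(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * a,
            int                       n_groups) {
        return wsp_ggml_group_norm(ctx, a, n_groups, 1e-6f);
    }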
@@ -1089,14 +1198,11 @@
              enum wsp_ggml_prec        prec);
 
      // indirect matrix multiplication
-     // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * const as[],
-             int                       n_as,
-             struct wsp_ggml_tensor  * ids,
-             int                       id,
-             struct wsp_ggml_tensor  * b);
+             struct wsp_ggml_tensor  * as,
+             struct wsp_ggml_tensor  * b,
+             struct wsp_ggml_tensor  * ids);
 
      // A: m columns, n rows,
      // B: p columns, n rows,
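wsp_ggml_mul_mat_id no longer takes a C array of expert tensors plus a scalar id: the experts are stacked into a single tensor `as` and `ids` selects experts per token (the mixture-of-experts formulation used upstream). A hedged sketch of the call shape; the comments state assumptions, not checked semantics:

    // `experts` is assumed to stack the expert weight matrices along dim 2;
    // `ids` is assumed to be an int32 tensor of expert indices.
    struct wsp_ggml_tensor * moe_matmul(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * experts,
            struct wsp_ggml_tensor  * x,
            struct wsp_ggml_tensor  * ids) {
        return wsp_ggml_mul_mat_id(ctx, experts, x, ids);
    }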
@@ -1129,7 +1235,7 @@
              size_t                    nb1,
              size_t                    nb2,
              size_t                    nb3,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      // b -> view(a,offset,nb1,nb2,3), return view(a)
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_inplace(
@@ -1139,19 +1245,19 @@
              size_t                    nb1,
              size_t                    nb2,
              size_t                    nb3,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      // b -> view(a,offset,nb1,nb2,3), return modified a
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d(
@@ -1159,7 +1265,7 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              size_t                    nb1,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      // b -> view(a,offset,nb1,nb2,3), return view(a)
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
@@ -1167,7 +1273,7 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              size_t                    nb1,
-             size_t                    offset);
+             size_t                    offset); // in bytes
 
      // a -> b, return view(b)
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
@@ -1302,14 +1408,14 @@
      // supports 3D: a->ne[2] == b->ne[1]
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b);
+             struct wsp_ggml_tensor  * a,  // data
+             struct wsp_ggml_tensor  * b); // row indices
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
-             struct wsp_ggml_tensor  * c);
+             struct wsp_ggml_tensor  * a,  // gradients of wsp_ggml_get_rows result
+             struct wsp_ggml_tensor  * b,  // row indices
+             struct wsp_ggml_tensor  * c); // data for wsp_ggml_get_rows, only used for its shape
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
              struct wsp_ggml_context * ctx,
@@ -1348,13 +1454,15 @@
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a);
 
-     // fused soft_max(a*scale + mask)
+     // fused soft_max(a*scale + mask*(ALiBi slope))
      // mask is optional
+     // max_bias = 0.0f for no ALiBi
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * mask,
-             float                     scale);
+             float                     scale,
+             float                     max_bias);
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
              struct wsp_ggml_context * ctx,
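Per the updated comment, the new max_bias parameter enables ALiBi inside the fused softmax; zero disables it. A sketch of the backward-compatible call (hypothetical wrapper):

    // max_bias = 0.0f reproduces the old no-ALiBi behavior.
    struct wsp_ggml_tensor * soft_max_like_before(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * a,
            struct wsp_ggml_tensor  * mask,
            float                     scale) {
        return wsp_ggml_soft_max_ext(ctx, a, mask, scale, 0.0f);
    }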
@@ -1368,9 +1476,8 @@
              struct wsp_ggml_tensor  * b);
 
      // rotary position embedding
-     // if mode & 1 == 1, skip n_past elements (DEPRECATED)
-     // if mode & 2 == 1, GPT-NeoX style
-     // if mode & 4 == 1, ChatGLM style
+     // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
+     // if (mode & WSP_GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
      //
      // b is an int32 vector with size a->ne[2], it contains the positions
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
@@ -1378,8 +1485,7 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              int                       n_dims,
-             int                       mode,
-             int                       n_ctx);
+             int                       mode);
 
      // in-place, returns view(a)
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
@@ -1387,18 +1493,18 @@
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              int                       n_dims,
-             int                       mode,
-             int                       n_ctx);
+             int                       mode);
 
      // custom RoPE
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
+     // c is freq factors (e.g. phi3-128k), (optional)
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
+             struct wsp_ggml_tensor  * c,
              int                       n_dims,
              int                       mode,
-             int                       n_ctx,
-             int                       n_orig_ctx,
+             int                       n_ctx_orig,
              float                     freq_base,
              float                     freq_scale,
              float                     ext_factor,
@@ -1407,14 +1513,14 @@
              float                     beta_slow);
 
      // in-place, returns view(a)
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
+             struct wsp_ggml_tensor  * c,
              int                       n_dims,
              int                       mode,
-             int                       n_ctx,
-             int                       n_orig_ctx,
+             int                       n_ctx_orig,
              float                     freq_base,
              float                     freq_scale,
              float                     ext_factor,
@@ -1422,46 +1528,56 @@
              float                     beta_fast,
              float                     beta_slow);
 
-     // compute correction dims for YaRN RoPE scaling
-     WSP_GGML_CALL void wsp_ggml_rope_yarn_corr_dims(
-         int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
-
-     // xPos RoPE, in-place, returns view(a)
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
+     WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              int                       n_dims,
-             float                     base,
-             bool                      down);
+             int                       mode,
+             int                       n_ctx_orig,
+             float                     freq_base,
+             float                     freq_scale,
+             float                     ext_factor,
+             float                     attn_factor,
+             float                     beta_fast,
+             float                     beta_slow),
+         "use wsp_ggml_rope_ext instead");
 
-     // rotary position embedding backward, i.e compute dx from dy
-     // a - dy
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
+     WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              struct wsp_ggml_tensor  * b,
              int                       n_dims,
              int                       mode,
-             int                       n_ctx,
-             int                       n_orig_ctx,
+             int                       n_ctx_orig,
              float                     freq_base,
              float                     freq_scale,
              float                     ext_factor,
              float                     attn_factor,
              float                     beta_fast,
-             float                     beta_slow,
-             float                     xpos_base,
-             bool                      xpos_down);
+             float                     beta_slow),
+         "use wsp_ggml_rope_ext_inplace instead");
 
-     // alibi position embedding
-     // in-place, returns view(a)
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi(
+     // compute correction dims for YaRN RoPE scaling
+     void wsp_ggml_rope_yarn_corr_dims(
+         int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+     // rotary position embedding backward, i.e compute dx from dy
+     // a - dy
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             int                       n_past,
-             int                       n_head,
-             float                     bias_max);
+             struct wsp_ggml_tensor  * a, // gradients of wsp_ggml_rope result
+             struct wsp_ggml_tensor  * b, // positions
+             struct wsp_ggml_tensor  * c, // freq factors
+             int                       n_dims,
+             int                       mode,
+             int                       n_ctx_orig,
+             float                     freq_base,
+             float                     freq_scale,
+             float                     ext_factor,
+             float                     attn_factor,
+             float                     beta_fast,
+             float                     beta_slow);
 
      // clamp
      // in-place, returns view(a)
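The RoPE migration in the hunks above is mechanical: wsp_ggml_rope_custom() is deprecated in favor of wsp_ggml_rope_ext(), which adds an optional freq-factors tensor `c` and drops the unused n_ctx. A sketch of the equivalent call (hypothetical wrapper; per the comment above, NULL freq factors match the old behavior):

    #include "ggml.h"

    struct wsp_ggml_tensor * rope_like_before(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * a,
            struct wsp_ggml_tensor  * pos, // int32 positions, size a->ne[2]
            int n_dims, int mode, int n_ctx_orig,
            float freq_base, float freq_scale,
            float ext_factor, float attn_factor,
            float beta_fast, float beta_slow) {
        return wsp_ggml_rope_ext(ctx, a, pos, /*c =*/ NULL,
                n_dims, mode, n_ctx_orig,
                freq_base, freq_scale, ext_factor, attn_factor,
                beta_fast, beta_slow);
    }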
@@ -1471,22 +1587,49 @@
              float                     min,
              float                     max);
 
+     // im2col
+     // converts data into a format that effectively results in a convolution when combined with matrix multiplication
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
-             int                       s0,
-             int                       s1,
-             int                       p0,
-             int                       p1,
-             int                       d0,
-             int                       d1,
-             bool                      is_2D);
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
+             int                       s0, // stride dimension 0
+             int                       s1, // stride dimension 1
+             int                       p0, // padding dimension 0
+             int                       p1, // padding dimension 1
+             int                       d0, // dilation dimension 0
+             int                       d1, // dilation dimension 1
+             bool                      is_2D,
+             enum wsp_ggml_type        dst_type);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_back(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // gradient of im2col output
+             int64_t                 * ne, // shape of im2col input
+             int                       s0, // stride dimension 0
+             int                       s1, // stride dimension 1
+             int                       p0, // padding dimension 0
+             int                       p1, // padding dimension 1
+             int                       d0, // dilation dimension 0
+             int                       d1, // dilation dimension 1
+             bool                      is_2D);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_depthwise_2d(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
+             int                       s0, // stride dimension 0
+             int                       s1, // stride dimension 1
+             int                       p0, // padding dimension 0
+             int                       p1, // padding dimension 1
+             int                       d0, // dilation dimension 0
+             int                       d1); // dilation dimension 1
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
              int                       s0, // stride
              int                       p0, // padding
              int                       d0); // dilation
@@ -1495,29 +1638,29 @@
      // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
      WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
-             int                       s,
-             int                       d);
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
+             int                       s,  // stride
+             int                       d); // dilation
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
-             int                       s0,
-             int                       p0,
-             int                       d0);
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
+             int                       s0, // stride
+             int                       p0, // padding
+             int                       d0); // dilation
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b,
-             int                       s0,
-             int                       s1,
-             int                       p0,
-             int                       p1,
-             int                       d0,
-             int                       d1);
+             struct wsp_ggml_tensor  * a,  // convolution kernel
+             struct wsp_ggml_tensor  * b,  // data
+             int                       s0, // stride dimension 0
+             int                       s1, // stride dimension 1
+             int                       p0, // padding dimension 0
+             int                       p1, // padding dimension 1
+             int                       d0, // dilation dimension 0
+             int                       d1); // dilation dimension 1
 
 
      // kernel size is a->ne[0] x a->ne[1]
@@ -1579,13 +1722,37 @@
              float                     p0,
              float                     p1);
 
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a,
+             struct wsp_ggml_tensor  * af, // "a"/input used in forward pass
+             enum wsp_ggml_op_pool     op,
+             int                       k0,
+             int                       k1,
+             int                       s0,
+             int                       s1,
+             float                     p0,
+             float                     p1);
+
      // nearest interpolate
+     // multiplies ne0 and ne1 by scale factor
      // used in stable-diffusion
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              int                       scale_factor);
 
+     // nearest interpolate
+     // nearest interpolate to specified dimensions
+     // used in tortoise.cpp
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * a,
+             int                       ne0,
+             int                       ne1,
+             int                       ne2,
+             int                       ne3);
+
      // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
              struct wsp_ggml_context * ctx,
@@ -1595,10 +1762,19 @@
              int                       p2,
              int                       p3);
 
+     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+     // timesteps: [N,]
+     // return: [N, dim]
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * timesteps,
+             int                       dim,
+             int                       max_period);
+
      // sort rows
      enum wsp_ggml_sort_order {
-         WSP_GGML_SORT_ASC,
-         WSP_GGML_SORT_DESC,
+         WSP_GGML_SORT_ORDER_ASC,
+         WSP_GGML_SORT_ORDER_DESC,
      };
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
@@ -1606,19 +1782,40 @@
              struct wsp_ggml_tensor  * a,
              enum wsp_ggml_sort_order  order);
 
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_arange(
+             struct wsp_ggml_context * ctx,
+             float                     start,
+             float                     stop,
+             float                     step);
+
      // top k elements per row
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * a,
              int                       k);
 
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
+ #define WSP_GGML_KQ_MASK_PAD 32
+
+     // q:    [n_embd, n_batch,     n_head,    1]
+     // k:    [n_embd, n_kv,        n_head_kv, 1]
+     // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+     // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
+     // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * q,
              struct wsp_ggml_tensor  * k,
              struct wsp_ggml_tensor  * v,
-             bool                      masked);
+             struct wsp_ggml_tensor  * mask,
+             float                     scale,
+             float                     max_bias,
+             float                     logit_softcap);
 
+     WSP_GGML_API void wsp_ggml_flash_attn_ext_set_prec(
+             struct wsp_ggml_tensor  * a,
+             enum wsp_ggml_prec        prec);
+
+     // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
              struct wsp_ggml_context * ctx,
              struct wsp_ggml_tensor  * q,
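The fused attention call above replaces the old boolean-masked wsp_ggml_flash_attn(). A hedged usage sketch following the shape comments in the hunk; the scale convention (1/sqrt of the per-head embedding size) is the usual attention scaling, stated as an assumption rather than taken from this package:

    #include "ggml.h"
    #include <math.h>

    // max_bias = 0.0f and logit_softcap = 0.0f disable ALiBi and soft-capping.
    struct wsp_ggml_tensor * attn(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor  * q,
            struct wsp_ggml_tensor  * k,
            struct wsp_ggml_tensor  * v,
            struct wsp_ggml_tensor  * mask) { // padded per WSP_GGML_KQ_MASK_PAD
        const float scale = 1.0f / sqrtf((float) q->ne[0]); // ne[0] = n_embd per head
        return wsp_ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f);
    }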
@@ -1627,13 +1824,19 @@
              struct wsp_ggml_tensor  * d,
              bool                      masked);
 
-     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff(
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
              struct wsp_ggml_context * ctx,
-             struct wsp_ggml_tensor  * a,
-             struct wsp_ggml_tensor  * b0,
-             struct wsp_ggml_tensor  * b1,
-             struct wsp_ggml_tensor  * c0,
-             struct wsp_ggml_tensor  * c1);
+             struct wsp_ggml_tensor  * sx,
+             struct wsp_ggml_tensor  * c);
+
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * s,
+             struct wsp_ggml_tensor  * x,
+             struct wsp_ggml_tensor  * dt,
+             struct wsp_ggml_tensor  * A,
+             struct wsp_ggml_tensor  * B,
+             struct wsp_ggml_tensor  * C);
 
      // partition into non-overlapping windows with padding if needed
      // example:
@@ -1685,6 +1888,15 @@
              struct wsp_ggml_tensor  * pw,
              struct wsp_ggml_tensor  * ph);
 
+     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv(
+             struct wsp_ggml_context * ctx,
+             struct wsp_ggml_tensor  * k,
+             struct wsp_ggml_tensor  * v,
+             struct wsp_ggml_tensor  * r,
+             struct wsp_ggml_tensor  * tf,
+             struct wsp_ggml_tensor  * td,
+             struct wsp_ggml_tensor  * state);
+
      // custom operators
 
      typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1768,7 +1980,8 @@
      typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
      typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);
 
-     #define WSP_GGML_N_TASKS_MAX -1
+     #define WSP_GGML_N_TASKS_MAX (-1)
+     // n_tasks == WSP_GGML_N_TASKS_MAX means to use max number of tasks
 
      WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
              struct wsp_ggml_context * ctx,
@@ -1821,48 +2034,87 @@ extern "C" {
1821
2034
  // loss function
1822
2035
 
1823
2036
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
1824
- struct wsp_ggml_context * ctx,
1825
- struct wsp_ggml_tensor * a,
1826
- struct wsp_ggml_tensor * b);
2037
+ struct wsp_ggml_context * ctx,
2038
+ struct wsp_ggml_tensor * a, // logits
2039
+ struct wsp_ggml_tensor * b); // labels
1827
2040
 
1828
2041
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
1829
- struct wsp_ggml_context * ctx,
1830
- struct wsp_ggml_tensor * a,
1831
- struct wsp_ggml_tensor * b,
1832
- struct wsp_ggml_tensor * c);
2042
+ struct wsp_ggml_context * ctx,
2043
+ struct wsp_ggml_tensor * a, // logits
2044
+ struct wsp_ggml_tensor * b, // labels
2045
+ struct wsp_ggml_tensor * c); // gradients of cross_entropy_loss result
2046
+
2047
+ // AdamW optimizer step
2048
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
2049
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
2050
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
2051
+ struct wsp_ggml_context * ctx,
2052
+ struct wsp_ggml_tensor * a,
2053
+ struct wsp_ggml_tensor * grad,
2054
+ float alpha,
2055
+ float beta1,
2056
+ float beta2,
2057
+ float eps,
2058
+ float wd); // weight decay

  //
  // automatic differentiation
  //

- WSP_GGML_API void wsp_ggml_set_param(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * tensor);
-
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);

  WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
- WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
+ WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool accumulate);
+
+ WSP_GGML_API void wsp_ggml_build_opt_adamw(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_cgraph * gf,
+ struct wsp_ggml_cgraph * gb,
+ float alpha,
+ float beta1,
+ float beta2,
+ float eps,
+ float wd); // weight decay

  // graph allocation in a context
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
- WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
- WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+
+ WSP_GGML_API int wsp_ggml_graph_size (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_node (struct wsp_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+ WSP_GGML_API struct wsp_ggml_tensor ** wsp_ggml_graph_nodes (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API int wsp_ggml_graph_n_nodes(struct wsp_ggml_cgraph * cgraph);
+
+ WSP_GGML_API void wsp_ggml_graph_add_node(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
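A short sketch of the new accessors (graph is hypothetical, and the tensor name field is assumed to be populated); the negative-index form follows the comment above, so -1 addresses the final node:

    #include <stdio.h>

    // Print every node, then fetch the last one via the negative-index form.
    static void dump_nodes(struct wsp_ggml_cgraph * graph) {
        for (int i = 0; i < wsp_ggml_graph_n_nodes(graph); ++i) {
            printf("node %d: %s\n", i, wsp_ggml_graph_node(graph, i)->name);
        }
        struct wsp_ggml_tensor * last = wsp_ggml_graph_node(graph, -1); // nodes[n_nodes - 1]
        printf("last: %s\n", last->name);
    }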

  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
  WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);

+ WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
+ WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
+ WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
+ WSP_GGML_API struct wsp_ggml_threadpool * wsp_ggml_threadpool_new (struct wsp_ggml_threadpool_params * params);
+ WSP_GGML_API void wsp_ggml_threadpool_free (struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API int wsp_ggml_threadpool_get_n_threads(struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API void wsp_ggml_threadpool_pause (struct wsp_ggml_threadpool * threadpool);
+ WSP_GGML_API void wsp_ggml_threadpool_resume (struct wsp_ggml_threadpool * threadpool);
+
  // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
- WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (const struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
- WSP_GGML_API int wsp_ggml_graph_compute( struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
+ WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan(
+ const struct wsp_ggml_cgraph * cgraph,
+ int n_threads, /* = WSP_GGML_DEFAULT_N_THREADS */
+ struct wsp_ggml_threadpool * threadpool /* = NULL */ );
+ WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);

  // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
- WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
+ WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
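Taken together, these additions change the compute path: plans are built against an optional explicit threadpool, and both compute entry points now report an wsp_ggml_status instead of a bare int. A hedged usage sketch (the WSP_GGML_STATUS_SUCCESS enumerator is assumed from ggml's status enum, which this hunk does not show):

    #include <stdlib.h>

    static enum wsp_ggml_status compute_with_pool(struct wsp_ggml_cgraph * graph) {
        struct wsp_ggml_threadpool_params tpp = wsp_ggml_threadpool_params_default(4);
        struct wsp_ggml_threadpool * tp = wsp_ggml_threadpool_new(&tpp);

        // NULL may be passed instead of tp to fall back to internally managed threads
        struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, 4, tp);
        if (plan.work_size > 0) {
            plan.work_data = malloc(plan.work_size); // caller-owned scratch, per the comment above
        }
        const enum wsp_ggml_status st = wsp_ggml_graph_compute(graph, &plan);

        free(plan.work_data);
        wsp_ggml_threadpool_free(tp);
        return st; // compare against WSP_GGML_STATUS_SUCCESS (assumed enumerator name)
    }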

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);

@@ -1891,8 +2143,8 @@ extern "C" {

  // optimization methods
  enum wsp_ggml_opt_type {
- WSP_GGML_OPT_ADAM,
- WSP_GGML_OPT_LBFGS,
+ WSP_GGML_OPT_TYPE_ADAM,
+ WSP_GGML_OPT_TYPE_LBFGS,
  };

  // linesearch methods
@@ -1906,12 +2158,12 @@ extern "C" {

  // optimization return values
  enum wsp_ggml_opt_result {
- WSP_GGML_OPT_OK = 0,
- WSP_GGML_OPT_DID_NOT_CONVERGE,
- WSP_GGML_OPT_NO_CONTEXT,
- WSP_GGML_OPT_INVALID_WOLFE,
- WSP_GGML_OPT_FAIL,
- WSP_GGML_OPT_CANCEL,
+ WSP_GGML_OPT_RESULT_OK = 0,
+ WSP_GGML_OPT_RESULT_DID_NOT_CONVERGE,
+ WSP_GGML_OPT_RESULT_NO_CONTEXT,
+ WSP_GGML_OPT_RESULT_INVALID_WOLFE,
+ WSP_GGML_OPT_RESULT_FAIL,
+ WSP_GGML_OPT_RESULT_CANCEL,

  WSP_GGML_LINESEARCH_FAIL = -128,
  WSP_GGML_LINESEARCH_MINIMUM_STEP,
@@ -1923,6 +2175,10 @@ extern "C" {
  typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
  typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);

+ // Set callback for all future logging events.
+ // If this is not called, or NULL is supplied, everything is output on stderr.
+ WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);
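A minimal sketch of routing log output through wsp_ggml_log_set (my_log_sink is hypothetical); per the comment above, passing NULL restores the stderr default:

    #include <stdio.h>

    static void my_log_sink(enum wsp_ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        fputs(text, (FILE *) user_data); // ggml log text typically carries its own newlines
    }

    // wsp_ggml_log_set(my_log_sink, stderr);
    // wsp_ggml_log_set(NULL, NULL); // back to default stderr output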
+
  // optimization parameters
  //
  // see ggml.c (wsp_ggml_opt_default_params) for default values
@@ -2061,6 +2317,12 @@ extern "C" {
  wsp_ggml_opt_callback callback,
  void * callback_data);

+ //
+ // tensor flags
+ //
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
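These flags mark graph I/O so allocators and backends can treat the tensors specially (e.g. keep them out of reused scratch memory); a small sketch with hypothetical tensors:

    wsp_ggml_set_input(inp_tokens); // fed by the caller before compute
    wsp_ggml_set_output(logits);    // read back by the caller after compute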
+
  //
  // quantization
  //
@@ -2077,25 +2339,18 @@ extern "C" {
  WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type);
  WSP_GGML_API void wsp_ggml_wsp_quantize_free(void);

- // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
-
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
-
  // some quantization type cannot be used without an importance matrix
  WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type);

  // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory)
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst,
- int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(
+ enum wsp_ggml_type type,
+ const float * src,
+ void * dst,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
+ const float * imatrix);
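The per-type wsp_ggml_wsp_quantize_q*_* helpers removed above fold into this one entry point (note the int64_t widening and the dropped histogram argument). A hedged sketch quantizing a block of f32 rows (src_f32 and dst_buf are hypothetical buffers; the imatrix may be NULL unless wsp_ggml_wsp_quantize_requires_imatrix says otherwise):

    const int64_t nrows = 32, n_per_row = 4096;
    // dst_buf must be large enough for nrows quantized rows of the chosen type
    size_t written = wsp_ggml_wsp_quantize_chunk(
        WSP_GGML_TYPE_Q4_0, src_f32, dst_buf,
        /*start*/ 0, nrows, n_per_row, /*imatrix*/ NULL);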

  //
  // gguf
@@ -2171,6 +2426,9 @@ extern "C" {
  WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);
  WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type (const struct wsp_gguf_context * ctx, int i);

+ // removes key if it exists
+ WSP_GGML_API void wsp_gguf_remove_key(struct wsp_gguf_context * ctx, const char * key);
+
  // overrides existing values or adds a new one
  WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
  WSP_GGML_API void wsp_gguf_set_val_i8 (struct wsp_gguf_context * ctx, const char * key, int8_t val);
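A small sketch combining the new removal API with the setters declared here (gctx and the key names are hypothetical); per the comments above, removal is a no-op when the key is absent, and setting an existing key overrides it:

    wsp_gguf_remove_key(gctx, "some.stale.key");
    wsp_gguf_set_val_u8(gctx, "some.flag", 1);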
@@ -2230,20 +2488,33 @@ extern "C" {
  WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
  WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_avx512_bf16(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_amx_int8 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
  WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_sve (void);
  WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
  WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
  WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
  WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
  WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
  WSP_GGML_API int wsp_ggml_cpu_has_blas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_cublas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_clblast (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_cuda (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_vulkan (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_kompute (void);
  WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void);
  WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void);
  WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_riscv_v (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_sycl (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_rpc (void);
  WSP_GGML_API int wsp_ggml_cpu_has_vsx (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_matmul_int8(void);
+ WSP_GGML_API int wsp_ggml_cpu_has_cann (void);
+ WSP_GGML_API int wsp_ggml_cpu_has_llamafile (void);
+
+ // get the sve vector length in bytes
+ WSP_GGML_API int wsp_ggml_cpu_get_sve_cnt(void);
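Each predicate returns non-zero when the corresponding feature was compiled in or detected; a tiny report built from the new additions:

    #include <stdio.h>

    static void print_cpu_features(void) {
        printf("sve: %d (vector length: %d bytes)\n",
               wsp_ggml_cpu_has_sve(), wsp_ggml_cpu_get_sve_cnt());
        printf("avx512_bf16: %d  amx_int8: %d  riscv_v: %d\n",
               wsp_ggml_cpu_has_avx512_bf16(), wsp_ggml_cpu_has_amx_int8(),
               wsp_ggml_cpu_has_riscv_v());
    }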

  //
  // Internal types and functions exposed for tests and benchmarks
@@ -2255,23 +2526,36 @@ extern "C" {
  #else
  #define WSP_GGML_RESTRICT restrict
  #endif
- typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_vec_dot_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
-
- typedef struct {
- const char * type_name;
- int blck_size;
- size_t type_size;
- bool is_quantized;
- wsp_ggml_to_float_t to_float;
- wsp_ggml_from_float_t from_float;
- wsp_ggml_from_float_t from_float_reference;
- wsp_ggml_vec_dot_t vec_dot;
- enum wsp_ggml_type vec_dot_type;
- } wsp_ggml_type_traits_t;
-
- WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+ typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+ typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t k);
+ typedef void (*wsp_ggml_from_float_to_mat_t)
+ (const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+ typedef void (*wsp_ggml_vec_dot_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x, size_t bx,
+ const void * WSP_GGML_RESTRICT y, size_t by, int nrc);
+ typedef void (*wsp_ggml_gemv_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
+ const void * WSP_GGML_RESTRICT y, int nr, int nc);
+ typedef void (*wsp_ggml_gemm_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
+ const void * WSP_GGML_RESTRICT y, int nr, int nc);
+
+ struct wsp_ggml_type_traits {
+ const char * type_name;
+ int64_t blck_size;
+ int64_t blck_size_interleave; // interleave elements in blocks
+ size_t type_size;
+ bool is_quantized;
+ wsp_ggml_to_float_t to_float;
+ wsp_ggml_from_float_t from_float;
+ wsp_ggml_from_float_t from_float_ref;
+ wsp_ggml_from_float_to_mat_t from_float_to_mat;
+ wsp_ggml_vec_dot_t vec_dot;
+ enum wsp_ggml_type vec_dot_type;
+ int64_t nrows; // number of rows to process simultaneously
+ int64_t ncols; // number of columns to process simultaneously
+ wsp_ggml_gemv_t gemv;
+ wsp_ggml_gemm_t gemm;
+ };
+
+ WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);
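The lookup now returns a const pointer instead of a struct copy (and drops the "internal" naming). A hedged sketch computing a quantized row size from the traits (n_per_row is hypothetical and assumed to be a multiple of blck_size; interleaved types may differ):

    const struct wsp_ggml_type_traits * tt = wsp_ggml_get_type_traits(WSP_GGML_TYPE_Q4_0);
    const int64_t n_per_row = 4096;
    const size_t row_bytes = (size_t) (n_per_row / tt->blck_size) * tt->type_size;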

  #ifdef __cplusplus
  }