whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
package/cpp/ggml.h CHANGED
@@ -58,7 +58,8 @@
  // {
  // ...
  //
- // struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
+ // struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
+ // wsp_ggml_build_forward_expand(gf, f);
  //
  // // set the input variable and parameter values
  // wsp_ggml_set_f32(x, 2.0f);
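For reference, a minimal sketch of the graph-construction flow that the updated header comment above describes. Only `wsp_ggml_new_graph`, `wsp_ggml_build_forward_expand`, `wsp_ggml_set_f32`, `wsp_ggml_get_f32_1d`, `wsp_ggml_graph_plan`, `wsp_ggml_graph_compute`, and `WSP_GGML_DEFAULT_N_THREADS` appear in this diff; the remaining identifiers are assumed to follow the usual `wsp_`-prefixed ggml API, so treat this as a sketch rather than the package's own usage.

```c
#include <stdio.h>
#include <stdlib.h>
#include "ggml.h"

// Hedged sketch: build and run f = a * x with the new graph API.
static void example_graph(void) {
    struct wsp_ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);   // assumed wsp_-prefixed ggml_init

    struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
    struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
    struct wsp_ggml_tensor * f = wsp_ggml_mul(ctx, a, x);    // f = a * x

    // old API: struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
    // new API: allocate the graph in the context, then expand it from the output tensor
    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
    wsp_ggml_build_forward_expand(gf, f);

    wsp_ggml_set_f32(x, 2.0f);
    wsp_ggml_set_f32(a, 3.0f);

    struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(gf, WSP_GGML_DEFAULT_N_THREADS);
    void * work = NULL;
    if (plan.work_size > 0) {
        work = malloc(plan.work_size);                        // caller must provide work memory, per the header
        plan.work_data = work;
    }
    wsp_ggml_graph_compute(gf, &plan);

    printf("f = %f\n", wsp_ggml_get_f32_1d(f, 0));            // 6.0

    free(work);
    wsp_ggml_free(ctx);
}
```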
@@ -213,15 +214,14 @@
  #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this

- #define WSP_GGML_MAX_DIMS 4
- #define WSP_GGML_MAX_NODES 4096
- #define WSP_GGML_MAX_PARAMS 256
- #define WSP_GGML_MAX_CONTEXTS 64
- #define WSP_GGML_MAX_SRC 6
- #define WSP_GGML_MAX_NAME 64
- #define WSP_GGML_MAX_OP_PARAMS 32
- #define WSP_GGML_DEFAULT_N_THREADS 4
-
+ #define WSP_GGML_MAX_DIMS 4
+ #define WSP_GGML_MAX_PARAMS 1024
+ #define WSP_GGML_MAX_CONTEXTS 64
+ #define WSP_GGML_MAX_SRC 6
+ #define WSP_GGML_MAX_NAME 64
+ #define WSP_GGML_MAX_OP_PARAMS 64
+ #define WSP_GGML_DEFAULT_N_THREADS 4
+ #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
  #else
@@ -231,10 +231,11 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1

- #define GGUF_MAGIC 0x46554747 // "GGUF"
- #define GGUF_VERSION 2
+ #define WSP_GGUF_MAGIC "GGUF"
+
+ #define WSP_GGUF_VERSION 3

- #define GGUF_DEFAULT_ALIGNMENT 32
+ #define WSP_GGUF_DEFAULT_ALIGNMENT 32

  #define WSP_GGML_UNUSED(x) (void)(x)

@@ -243,11 +244,21 @@
  #define WSP_GGML_ASSERT(x) \
  do { \
  if (!(x)) { \
+ fflush(stdout); \
  fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+ wsp_ggml_print_backtrace(); \
  abort(); \
  } \
  } while (0)

+ #ifndef NDEBUG
+ #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached")
+ #elif defined(__GNUC__)
+ #define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #else
+ #define WSP_GGML_UNREACHABLE() ((void) 0)
+ #endif
+
  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
  //
@@ -272,6 +283,20 @@
  const type prefix##3 = (pointer)->array[3]; \
  WSP_GGML_UNUSED(prefix##3);

+ #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif
@@ -318,7 +343,7 @@ extern "C" {
  WSP_GGML_TYPE_COUNT,
  };

- enum wsp_ggml_backend {
+ enum wsp_ggml_backend_type {
  WSP_GGML_BACKEND_CPU = 0,
  WSP_GGML_BACKEND_GPU = 10,
  WSP_GGML_BACKEND_GPU_SPLIT = 20,
@@ -370,6 +395,7 @@ extern "C" {
  WSP_GGML_OP_GROUP_NORM,

  WSP_GGML_OP_MUL_MAT,
+ WSP_GGML_OP_MUL_MAT_ID,
  WSP_GGML_OP_OUT_PROD,

  WSP_GGML_OP_SCALE,
@@ -391,13 +417,13 @@ extern "C" {
  WSP_GGML_OP_ROPE_BACK,
  WSP_GGML_OP_ALIBI,
  WSP_GGML_OP_CLAMP,
- WSP_GGML_OP_CONV_1D,
- WSP_GGML_OP_CONV_2D,
+ WSP_GGML_OP_CONV_TRANSPOSE_1D,
+ WSP_GGML_OP_IM2COL,
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
  WSP_GGML_OP_POOL_1D,
  WSP_GGML_OP_POOL_2D,
-
  WSP_GGML_OP_UPSCALE, // nearest interpolate
+ WSP_GGML_OP_ARGSORT,

  WSP_GGML_OP_FLASH_ATTN,
  WSP_GGML_OP_FLASH_FF,
@@ -437,6 +463,9 @@ extern "C" {
  WSP_GGML_UNARY_OP_GELU,
  WSP_GGML_UNARY_OP_GELU_QUICK,
  WSP_GGML_UNARY_OP_SILU,
+ WSP_GGML_UNARY_OP_LEAKY,
+
+ WSP_GGML_UNARY_OP_COUNT,
  };

  enum wsp_ggml_object_type {
@@ -445,6 +474,12 @@ extern "C" {
  WSP_GGML_OBJECT_WORK_BUFFER
  };

+ enum wsp_ggml_log_level {
+ WSP_GGML_LOG_LEVEL_ERROR = 2,
+ WSP_GGML_LOG_LEVEL_WARN = 3,
+ WSP_GGML_LOG_LEVEL_INFO = 4
+ };
+
  // ggml object
  struct wsp_ggml_object {
  size_t offs;
@@ -461,14 +496,16 @@ extern "C" {

  // n-dimensional tensor
  struct wsp_ggml_tensor {
- enum wsp_ggml_type type;
- enum wsp_ggml_backend backend;
+ enum wsp_ggml_type type;
+ enum wsp_ggml_backend_type backend;
+
+ struct wsp_ggml_backend_buffer * buffer;

  int n_dims;
  int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
  size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes:
- // nb[0] = sizeof(type)
- // nb[1] = nb[0] * ne[0] + padding
+ // nb[0] = wsp_ggml_type_size(type)
+ // nb[1] = nb[0] * (ne[0] / wsp_ggml_blck_size(type)) + padding
  // nb[i] = nb[i-1] * ne[i-1]

  // compute data
@@ -496,7 +533,7 @@ extern "C" {

  void * extra; // extra things e.g. for ggml-cuda.cu

- char padding[4];
+ char padding[12];
  };

  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
@@ -509,29 +546,35 @@ extern "C" {

  int n_threads;

- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
- int n_tasks[WSP_GGML_MAX_NODES];
-
  // abort wsp_ggml_graph_compute when true
  bool (*abort_callback)(void * data);
  void * abort_callback_data;
  };

- // next prime after WSP_GGML_MAX_NODES
- // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099
- // next prime after WSP_GGML_MAX_NODES * 2 (nodes + leafs)
- #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273
+ enum wsp_ggml_cgraph_eval_order {
+ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+ WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+ WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
+ };
+
+ struct wsp_ggml_hash_set {
+ size_t size;
+ struct wsp_ggml_tensor ** keys;
+ };

  // computation graph
  struct wsp_ggml_cgraph {
+ int size;
  int n_nodes;
  int n_leafs;

- struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES];
+ struct wsp_ggml_tensor ** nodes;
+ struct wsp_ggml_tensor ** grads;
+ struct wsp_ggml_tensor ** leafs;
+
+ struct wsp_ggml_hash_set visited_hash_table;

- void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
+ enum wsp_ggml_cgraph_eval_order order;

  // performance
  int perf_runs;
@@ -539,8 +582,6 @@ extern "C" {
  int64_t perf_time_us;
  };

- static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph);
-
  // scratch buffer
  struct wsp_ggml_scratch {
  size_t offs;
@@ -585,6 +626,8 @@ extern "C" {
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);

+ WSP_GGML_API void wsp_ggml_print_backtrace(void);
+
  WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
  WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node

@@ -605,6 +648,9 @@ extern "C" {
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
  WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);

+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
+
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);

  WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
@@ -674,18 +720,30 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);

+ // Context tensor enumeration and lookup
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);

+ // Converts a flat index into coordinates
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
  WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
  WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);

+ WSP_GGML_API int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ WSP_GGML_API void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
  WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
  WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);

+ WSP_GGML_API float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ WSP_GGML_API void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);

@@ -719,6 +777,12 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_cast(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ enum wsp_ggml_type type);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -828,6 +892,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // sums repetitions in a into shape of b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -892,6 +957,10 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -970,14 +1039,23 @@ extern "C" {
  struct wsp_ggml_tensor * b,
  float eps);

- // A: n columns, m rows
- // B: n columns, p rows (i.e. we transpose it internally)
- // result is m columns, p rows
+ // A: k columns, n rows => [ne03, ne02, n, k]
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // indirect matrix multiplication
+ // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * as[],
+ struct wsp_ggml_tensor * ids,
+ int id,
+ struct wsp_ggml_tensor * b);
+
  // A: m columns, n rows,
  // B: p columns, n rows,
  // result is m columns, p rows
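The new `wsp_ggml_mul_mat_id` declaration above is the indirect ("routed") matrix multiplication described in its comment. A rough, hypothetical sketch of a call follows; `wsp_ggml_new_tensor_1d`, `wsp_ggml_new_tensor_2d`, and the type enums are assumed from the usual `wsp_`-prefixed ggml API and are not part of this diff.

```c
// Hedged sketch: pick one of two candidate weight matrices via the ids tensor,
// assuming an existing struct wsp_ggml_context * ctx.
struct wsp_ggml_tensor * as[2] = {
    wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 64, 128),   // expert 0: k = 64 cols, n = 128 rows
    wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 64, 128),   // expert 1
};
struct wsp_ggml_tensor * ids = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 1); // holds the selected index
struct wsp_ggml_tensor * b   = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 64, 32);

// per the header comment: roughly equivalent to wsp_ggml_mul_mat(ctx, as[ids[0]], b),
// but the selection happens inside the graph
struct wsp_ggml_tensor * out = wsp_ggml_mul_mat_id(ctx, as, ids, /*id=*/ 0, b);
```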
@@ -1049,7 +1127,6 @@ extern "C" {
  size_t nb1,
  size_t offset);

-
  // a -> b, return view(b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
  struct wsp_ggml_context * ctx,
@@ -1072,6 +1149,33 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // make contiguous, with new shape
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_4d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
  // return view(a), b specifies the new shape
  // TODO: when we start computing gradient, make a copy instead of view
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape(
@@ -1207,6 +1311,14 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // fused soft_max(a*scale + mask)
+ // mask is optional
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * mask,
+ float scale);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
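The fused `wsp_ggml_soft_max_ext` above folds the scale and an optional mask into a single softmax op. A hedged fragment, assuming attention tensors `kq` and `kq_mask` already exist in a context `ctx`:

```c
// Hedged sketch: fuse masked, scaled attention softmax into one op.
float kq_scale = 0.125f;   // e.g. 1/sqrt(head_dim) with head_dim = 64
// before (separate ops, roughly): wsp_ggml_soft_max(ctx, wsp_ggml_add(ctx, kq_scaled, kq_mask))
// after (single fused op); the mask argument may be NULL when no mask is needed
struct wsp_ggml_tensor * kq_soft_max = wsp_ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
```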
@@ -1219,14 +1331,15 @@ extern "C" {
  struct wsp_ggml_tensor * b);

  // rotary position embedding
- // if mode & 1 == 1, skip n_past elements
+ // if mode & 1 == 1, skip n_past elements (DEPRECATED)
  // if mode & 2 == 1, GPT-NeoX style
  // if mode & 4 == 1, ChatGLM style
- // TODO: avoid creating a new tensor every time
+ //
+ // b is an int32 vector with size a->ne[2], it contains the positions
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx);
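Note the calling-convention change: the former `n_past` integer is replaced by a positions tensor `b` (int32, one entry per token, size `a->ne[2]`). A hedged fragment; everything except `wsp_ggml_rope` itself is assumed from the standard `wsp_`-prefixed ggml API:

```c
// Hypothetical variables: ctx, cur, n_tokens, n_past, n_rot, n_ctx.
struct wsp_ggml_tensor * pos = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, n_tokens);
int32_t * pos_data = (int32_t *) pos->data;
for (int i = 0; i < n_tokens; ++i) {
    pos_data[i] = n_past + i;      // absolute position of each token
}

// old: cur = wsp_ggml_rope(ctx, cur, n_past, n_rot, 0, n_ctx);
// new: positions are carried by a tensor, not an int argument
cur = wsp_ggml_rope(ctx, cur, pos, n_rot, 0, n_ctx);
```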
@@ -1235,7 +1348,7 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx);
@@ -1244,29 +1357,43 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);

  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
+ // compute correction dims for YaRN RoPE scaling
+ void wsp_ggml_rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);

  // xPos RoPE, in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  float base,
  bool down);
@@ -1276,18 +1403,23 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
  float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow,
  float xpos_base,
  bool xpos_down);

  // alibi position embedding
  // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_alibi(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  int n_past,
@@ -1296,12 +1428,24 @@ extern "C" {

  // clamp
  // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_clamp(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  float min,
  float max);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1,
+ bool is_2D);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1319,6 +1463,14 @@ extern "C" {
  int s,
  int d);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ int s0,
+ int p0,
+ int d0);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1377,6 +1529,8 @@ extern "C" {
  int s0, // stride
  int p0); // padding

+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1385,8 +1539,8 @@ extern "C" {
  int k1,
  int s0,
  int s1,
- int p0,
- int p1);
+ float p0,
+ float p1);

  // nearest interpolate
  // used in stable-diffusion
@@ -1395,6 +1549,23 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  int scale_factor);

+ // sort rows
+ enum wsp_ggml_sort_order {
+ WSP_GGML_SORT_ASC,
+ WSP_GGML_SORT_DESC,
+ };
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ enum wsp_ggml_sort_order order);
+
+ // top k elements per row
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int k);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * q,
@@ -1456,7 +1627,6 @@ extern "C" {
  int kh);

  // used in sam
-
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1627,19 +1797,22 @@ extern "C" {
  WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);

- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep);
-
  // graph allocation in a context
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
+ WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);

  // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
  WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
- WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);

  // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
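A hedged sketch of sizing a metadata-only context for a larger-than-default graph with the custom-graph functions declared above; `wsp_ggml_init` and its parameter struct are assumed from the usual `wsp_`-prefixed ggml API and are not part of this hunk:

```c
// Hedged sketch: reserve just enough context memory for a 4096-node graph without gradients.
size_t graph_nodes = 4096;   // example value, larger than WSP_GGML_DEFAULT_GRAPH_SIZE (2048)

struct wsp_ggml_init_params params = {
    /*.mem_size   =*/ wsp_ggml_graph_overhead_custom(graph_nodes, /*grads=*/ false),
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ true,   // tensor data is expected to live elsewhere (e.g. a backend buffer)
};
struct wsp_ggml_context * ctx_graph = wsp_ggml_init(params);

struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx_graph, graph_nodes, /*grads=*/ false);
// ... build the graph with wsp_ggml_build_forward_expand(gf, ...) ...
```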
@@ -1647,8 +1820,8 @@ extern "C" {

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);

- WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
+ WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);

  // print info and performance information for the graph
  WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -1656,6 +1829,16 @@ extern "C" {
  // dump the graph into a file using the dot format
  WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);

+ // build gradient checkpointing backward graph gb for gf using provided checkpoints
+ // gb_tmp will contain original backward graph with rewritten backward process nodes,
+ // but without the second forward pass nodes.
+ WSP_GGML_API void wsp_ggml_build_backward_gradient_checkpointing(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_cgraph * gf,
+ struct wsp_ggml_cgraph * gb,
+ struct wsp_ggml_cgraph * gb_tmp,
+ struct wsp_ggml_tensor * * checkpoints,
+ int n_checkpoints);
  //
  // optimization
  //
@@ -1682,6 +1865,7 @@ extern "C" {
  WSP_GGML_OPT_NO_CONTEXT,
  WSP_GGML_OPT_INVALID_WOLFE,
  WSP_GGML_OPT_FAIL,
+ WSP_GGML_OPT_CANCEL,

  WSP_GGML_LINESEARCH_FAIL = -128,
  WSP_GGML_LINESEARCH_MINIMUM_STEP,
@@ -1690,7 +1874,8 @@ extern "C" {
  WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
  };

- typedef void (*wsp_ggml_opt_callback)(void * data, float * sched);
+ typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+ typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);

  // optimization parameters
  //
@@ -1699,6 +1884,8 @@ extern "C" {
  struct wsp_ggml_opt_params {
  enum wsp_ggml_opt_type type;

+ size_t graph_size;
+
  int n_threads;

  // delta-based convergence test
@@ -1721,6 +1908,8 @@ extern "C" {
  bool print_forward_graph;
  bool print_backward_graph;

+ int n_gradient_accumulation;
+
  // ADAM parameters
  struct {
  int n_iter;
@@ -1766,6 +1955,7 @@ extern "C" {
  float loss_after;

  struct {
+ struct wsp_ggml_tensor * g; // current gradient
  struct wsp_ggml_tensor * m; // first moment
  struct wsp_ggml_tensor * v; // second moment
  struct wsp_ggml_tensor * pf; // past function values
@@ -1829,134 +2019,142 @@ extern "C" {
  // quantization
  //

- WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
+
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);

- WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

  //
  // gguf
  //

- enum gguf_type {
- GGUF_TYPE_UINT8 = 0,
- GGUF_TYPE_INT8 = 1,
- GGUF_TYPE_UINT16 = 2,
- GGUF_TYPE_INT16 = 3,
- GGUF_TYPE_UINT32 = 4,
- GGUF_TYPE_INT32 = 5,
- GGUF_TYPE_FLOAT32 = 6,
- GGUF_TYPE_BOOL = 7,
- GGUF_TYPE_STRING = 8,
- GGUF_TYPE_ARRAY = 9,
- GGUF_TYPE_UINT64 = 10,
- GGUF_TYPE_INT64 = 11,
- GGUF_TYPE_FLOAT64 = 12,
- GGUF_TYPE_COUNT, // marks the end of the enum
+ enum wsp_gguf_type {
+ WSP_GGUF_TYPE_UINT8 = 0,
+ WSP_GGUF_TYPE_INT8 = 1,
+ WSP_GGUF_TYPE_UINT16 = 2,
+ WSP_GGUF_TYPE_INT16 = 3,
+ WSP_GGUF_TYPE_UINT32 = 4,
+ WSP_GGUF_TYPE_INT32 = 5,
+ WSP_GGUF_TYPE_FLOAT32 = 6,
+ WSP_GGUF_TYPE_BOOL = 7,
+ WSP_GGUF_TYPE_STRING = 8,
+ WSP_GGUF_TYPE_ARRAY = 9,
+ WSP_GGUF_TYPE_UINT64 = 10,
+ WSP_GGUF_TYPE_INT64 = 11,
+ WSP_GGUF_TYPE_FLOAT64 = 12,
+ WSP_GGUF_TYPE_COUNT, // marks the end of the enum
  };

- struct gguf_context;
+ struct wsp_gguf_context;

- struct gguf_init_params {
+ struct wsp_gguf_init_params {
  bool no_alloc;

  // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
  struct wsp_ggml_context ** ctx;
  };

- WSP_GGML_API struct gguf_context * gguf_init_empty(void);
- WSP_GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
- //WSP_GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
- WSP_GGML_API void gguf_free(struct gguf_context * ctx);
-
- WSP_GGML_API const char * gguf_type_name(enum gguf_type type);
-
- WSP_GGML_API int gguf_get_version (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
- WSP_GGML_API void * gguf_get_data (const struct gguf_context * ctx);
-
- WSP_GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
- WSP_GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
- WSP_GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
- WSP_GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
- // results are undefined if the wrong type is used for the key
- WSP_GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
- WSP_GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
- WSP_GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
- WSP_GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
- WSP_GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
- WSP_GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
+ WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_empty(void);
+ WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params);
+ //WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_buffer(..);
+
+ WSP_GGML_API void wsp_gguf_free(struct wsp_gguf_context * ctx);
+
+ WSP_GGML_API const char * wsp_gguf_type_name(enum wsp_gguf_type type);
+
+ WSP_GGML_API int wsp_gguf_get_version (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API size_t wsp_gguf_get_alignment (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API void * wsp_gguf_get_data (const struct wsp_gguf_context * ctx);
+
+ WSP_GGML_API int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key);
+ WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id);
+
+ WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id);
+
+ // will abort if the wrong type is used for the key
+ WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
+
+ WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
+ WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
+ WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);

  // overrides existing values or adds a new one
- WSP_GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
- WSP_GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
- WSP_GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
- WSP_GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
- WSP_GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
- WSP_GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
- WSP_GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
- WSP_GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
- WSP_GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
- WSP_GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
- WSP_GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
- WSP_GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
- WSP_GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
- WSP_GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+ WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i8 (struct wsp_gguf_context * ctx, const char * key, int8_t val);
+ WSP_GGML_API void wsp_gguf_set_val_u16 (struct wsp_gguf_context * ctx, const char * key, uint16_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i16 (struct wsp_gguf_context * ctx, const char * key, int16_t val);
+ WSP_GGML_API void wsp_gguf_set_val_u32 (struct wsp_gguf_context * ctx, const char * key, uint32_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i32 (struct wsp_gguf_context * ctx, const char * key, int32_t val);
+ WSP_GGML_API void wsp_gguf_set_val_f32 (struct wsp_gguf_context * ctx, const char * key, float val);
+ WSP_GGML_API void wsp_gguf_set_val_u64 (struct wsp_gguf_context * ctx, const char * key, uint64_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i64 (struct wsp_gguf_context * ctx, const char * key, int64_t val);
+ WSP_GGML_API void wsp_gguf_set_val_f64 (struct wsp_gguf_context * ctx, const char * key, double val);
+ WSP_GGML_API void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool val);
+ WSP_GGML_API void wsp_gguf_set_val_str (struct wsp_gguf_context * ctx, const char * key, const char * val);
+ WSP_GGML_API void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n);
+ WSP_GGML_API void wsp_gguf_set_arr_str (struct wsp_gguf_context * ctx, const char * key, const char ** data, int n);

  // set or add KV pairs from another context
- WSP_GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+ WSP_GGML_API void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src);

  // manage tensor info
- WSP_GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum wsp_ggml_type type);
- WSP_GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+ WSP_GGML_API void wsp_gguf_add_tensor(struct wsp_gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type);
+ WSP_GGML_API void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size);

  // writing gguf files can be done in 2 ways:
  //
- // - write the entire gguf_context to a binary file in a single pass:
+ // - write the entire wsp_gguf_context to a binary file in a single pass:
  //
- // gguf_write_to_file(ctx, fname);
+ // wsp_gguf_write_to_file(ctx, fname);
  //
  // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
  //
  // FILE * f = fopen(fname, "wb");
- // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+ // fseek(f, wsp_gguf_get_meta_size(ctx), SEEK_SET);
  // fwrite(f, ...);
- // void * data = gguf_meta_get_meta_data(ctx);
+ // void * data = wsp_gguf_meta_get_meta_data(ctx);
  // fseek(f, 0, SEEK_SET);
- // fwrite(f, data, gguf_get_meta_size(ctx));
+ // fwrite(f, data, wsp_gguf_get_meta_size(ctx));
  // free(data);
  // fclose(f);
  //

  // write the entire context to a binary file
- WSP_GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+ WSP_GGML_API void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta);

  // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
- WSP_GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
- WSP_GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+ WSP_GGML_API size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API void wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data);

  //
  // system info
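A minimal sketch of the single-pass write path described in the comment above, assuming `tensor` is an existing `struct wsp_ggml_tensor *` and the key names are purely illustrative:

```c
// Hedged sketch: build a gguf context in memory and write it out in one pass.
struct wsp_gguf_context * gctx = wsp_gguf_init_empty();

wsp_gguf_set_val_str(gctx, "general.architecture", "whisper");                     // hypothetical key/value
wsp_gguf_set_val_u32(gctx, "general.quantization_version", WSP_GGML_QNT_VERSION);

wsp_gguf_add_tensor(gctx, tensor);   // registers the tensor's metadata and data

wsp_gguf_write_to_file(gctx, "model.gguf", /*only_meta=*/ false);
wsp_gguf_free(gctx);
```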
@@ -2008,7 +2206,7 @@
  enum wsp_ggml_type vec_dot_type;
  } wsp_ggml_type_traits_t;

- wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+ WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);

  #ifdef __cplusplus
  }