whisper.rn 0.4.0-rc.2 → 0.4.0-rc.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/android/src/main/CMakeLists.txt +2 -0
  2. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  3. package/android/src/main/java/com/rnwhisper/WhisperContext.java +29 -15
  4. package/android/src/main/jni.cpp +6 -2
  5. package/cpp/ggml-alloc.c +413 -280
  6. package/cpp/ggml-alloc.h +67 -8
  7. package/cpp/ggml-backend-impl.h +87 -0
  8. package/cpp/ggml-backend.c +950 -0
  9. package/cpp/ggml-backend.h +136 -0
  10. package/cpp/ggml-impl.h +243 -0
  11. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
  12. package/cpp/ggml-metal.h +21 -0
  13. package/cpp/ggml-metal.m +623 -234
  14. package/cpp/ggml-quants.c +7377 -0
  15. package/cpp/ggml-quants.h +224 -0
  16. package/cpp/ggml.c +3773 -4455
  17. package/cpp/ggml.h +279 -146
  18. package/cpp/whisper.cpp +182 -103
  19. package/cpp/whisper.h +48 -11
  20. package/ios/RNWhisper.mm +8 -2
  21. package/ios/RNWhisperContext.h +6 -2
  22. package/ios/RNWhisperContext.mm +97 -26
  23. package/jest/mock.js +1 -1
  24. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  25. package/lib/commonjs/index.js +28 -9
  26. package/lib/commonjs/index.js.map +1 -1
  27. package/lib/commonjs/version.json +1 -1
  28. package/lib/module/NativeRNWhisper.js.map +1 -1
  29. package/lib/module/index.js +28 -9
  30. package/lib/module/index.js.map +1 -1
  31. package/lib/module/version.json +1 -1
  32. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  33. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  34. package/lib/typescript/index.d.ts +8 -3
  35. package/lib/typescript/index.d.ts.map +1 -1
  36. package/package.json +1 -1
  37. package/src/NativeRNWhisper.ts +8 -1
  38. package/src/index.ts +30 -18
  39. package/src/version.json +1 -1
  40. package/whisper-rn.podspec +1 -2
package/cpp/ggml.h CHANGED
@@ -58,7 +58,8 @@
  // {
  // ...
  //
- // struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
+ // struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
+ // wsp_ggml_build_forward_expand(gf, f);
  //
  // // set the input variable and parameter values
  // wsp_ggml_set_f32(x, 2.0f);
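The hunk above updates the header's own walkthrough: graphs are no longer returned by value from wsp_ggml_build_forward(); they are allocated in the context with wsp_ggml_new_graph() and populated with wsp_ggml_build_forward_expand(). A minimal sketch of the new flow, assuming the wsp_-prefixed API in package/cpp/ggml.h behaves like upstream ggml (the 16 MB arena and the f = a*x^2 + b expression are illustrative, mirroring the header comment):

    #include <stdio.h>
    #include "ggml.h"   // the wsp_-prefixed ggml bundled in package/cpp

    static void forward_example(void) {
        struct wsp_ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,   // illustrative arena size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct wsp_ggml_context * ctx = wsp_ggml_init(params);

        // f(x) = a*x^2 + b, as in the header's walkthrough
        struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
        wsp_ggml_set_param(ctx, x);
        struct wsp_ggml_tensor * a  = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
        struct wsp_ggml_tensor * b  = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1);
        struct wsp_ggml_tensor * x2 = wsp_ggml_mul(ctx, x, x);
        struct wsp_ggml_tensor * f  = wsp_ggml_add(ctx, wsp_ggml_mul(ctx, a, x2), b);

        // new in this release range: build the graph in two steps
        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
        wsp_ggml_build_forward_expand(gf, f);

        wsp_ggml_set_f32(x, 2.0f);
        wsp_ggml_set_f32(a, 3.0f);
        wsp_ggml_set_f32(b, 4.0f);

        wsp_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);
        printf("f = %f\n", wsp_ggml_get_f32_1d(f, 0));   // 16.0
        wsp_ggml_free(ctx);
    }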
@@ -213,15 +214,14 @@
  #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this

- #define WSP_GGML_MAX_DIMS 4
- #define WSP_GGML_MAX_NODES 4096
- #define WSP_GGML_MAX_PARAMS 256
- #define WSP_GGML_MAX_CONTEXTS 64
- #define WSP_GGML_MAX_SRC 6
- #define WSP_GGML_MAX_NAME 64
- #define WSP_GGML_MAX_OP_PARAMS 32
- #define WSP_GGML_DEFAULT_N_THREADS 4
-
+ #define WSP_GGML_MAX_DIMS 4
+ #define WSP_GGML_MAX_PARAMS 1024
+ #define WSP_GGML_MAX_CONTEXTS 64
+ #define WSP_GGML_MAX_SRC 6
+ #define WSP_GGML_MAX_NAME 64
+ #define WSP_GGML_MAX_OP_PARAMS 64
+ #define WSP_GGML_DEFAULT_N_THREADS 4
+ #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
  #else
@@ -231,10 +231,11 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1

- #define GGUF_MAGIC 0x46554747 // "GGUF"
- #define GGUF_VERSION 2
+ #define WSP_GGUF_MAGIC "GGUF"
+
+ #define WSP_GGUF_VERSION 3

- #define GGUF_DEFAULT_ALIGNMENT 32
+ #define WSP_GGUF_DEFAULT_ALIGNMENT 32

  #define WSP_GGML_UNUSED(x) (void)(x)

@@ -244,10 +245,21 @@
  do { \
  if (!(x)) { \
  fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- abort(); \
+ fflush(stderr); \
+ fflush(stdout); \
+ wsp_ggml_print_backtrace(); \
+ exit(1); \
  } \
  } while (0)

+ #ifndef NDEBUG
+ #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached")
+ #elif defined(__GNUC__)
+ #define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #else
+ #define WSP_GGML_UNREACHABLE() ((void) 0)
+ #endif
+
  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
  //
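Besides flushing and printing a backtrace before exiting, the hunk adds a WSP_GGML_UNREACHABLE() macro that asserts in debug builds and becomes __builtin_unreachable() (or a no-op) in release builds. A hedged sketch of the usual placement; the helper function below is made up for illustration:

    static void check_backend(enum wsp_ggml_backend_type backend) {
        switch (backend) {
            case WSP_GGML_BACKEND_CPU:
            case WSP_GGML_BACKEND_GPU:
            case WSP_GGML_BACKEND_GPU_SPLIT:
                return;
        }
        WSP_GGML_UNREACHABLE();   // asserts in debug, optimizer hint in release
    }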
@@ -318,7 +330,7 @@ extern "C" {
  WSP_GGML_TYPE_COUNT,
  };

- enum wsp_ggml_backend {
+ enum wsp_ggml_backend_type {
  WSP_GGML_BACKEND_CPU = 0,
  WSP_GGML_BACKEND_GPU = 10,
  WSP_GGML_BACKEND_GPU_SPLIT = 20,
@@ -392,7 +404,12 @@ extern "C" {
  WSP_GGML_OP_ALIBI,
  WSP_GGML_OP_CLAMP,
  WSP_GGML_OP_CONV_1D,
+ WSP_GGML_OP_CONV_1D_STAGE_0, // internal
+ WSP_GGML_OP_CONV_1D_STAGE_1, // internal
+ WSP_GGML_OP_CONV_TRANSPOSE_1D,
  WSP_GGML_OP_CONV_2D,
+ WSP_GGML_OP_CONV_2D_STAGE_0, // internal
+ WSP_GGML_OP_CONV_2D_STAGE_1, // internal
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
  WSP_GGML_OP_POOL_1D,
  WSP_GGML_OP_POOL_2D,
@@ -437,6 +454,7 @@ extern "C" {
  WSP_GGML_UNARY_OP_GELU,
  WSP_GGML_UNARY_OP_GELU_QUICK,
  WSP_GGML_UNARY_OP_SILU,
+ WSP_GGML_UNARY_OP_LEAKY
  };

  enum wsp_ggml_object_type {
@@ -445,6 +463,12 @@ extern "C" {
  WSP_GGML_OBJECT_WORK_BUFFER
  };

+ enum wsp_ggml_log_level {
+ WSP_GGML_LOG_LEVEL_ERROR = 2,
+ WSP_GGML_LOG_LEVEL_WARN = 3,
+ WSP_GGML_LOG_LEVEL_INFO = 4
+ };
+
  // ggml object
  struct wsp_ggml_object {
  size_t offs;
@@ -461,14 +485,16 @@ extern "C" {

  // n-dimensional tensor
  struct wsp_ggml_tensor {
- enum wsp_ggml_type type;
- enum wsp_ggml_backend backend;
+ enum wsp_ggml_type type;
+ enum wsp_ggml_backend_type backend;
+
+ struct wsp_ggml_backend_buffer * buffer;

  int n_dims;
  int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
  size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes:
- // nb[0] = sizeof(type)
- // nb[1] = nb[0] * ne[0] + padding
+ // nb[0] = wsp_ggml_type_size(type)
+ // nb[1] = nb[0] * (ne[0] / wsp_ggml_blck_size(type)) + padding
  // nb[i] = nb[i-1] * ne[i-1]

  // compute data
@@ -496,7 +522,7 @@ extern "C" {

  void * extra; // extra things e.g. for ggml-cuda.cu

- char padding[4];
+ char padding[12];
  };

  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
@@ -509,29 +535,35 @@ extern "C" {

  int n_threads;

- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
- int n_tasks[WSP_GGML_MAX_NODES];
-
  // abort wsp_ggml_graph_compute when true
  bool (*abort_callback)(void * data);
  void * abort_callback_data;
  };

- // next prime after WSP_GGML_MAX_NODES
- // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099
- // next prime after WSP_GGML_MAX_NODES * 2 (nodes + leafs)
- #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273
+ enum wsp_ggml_cgraph_eval_order {
+ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+ WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+ WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
+ };
+
+ struct wsp_ggml_hash_set {
+ size_t size;
+ struct wsp_ggml_tensor ** keys;
+ };

  // computation graph
  struct wsp_ggml_cgraph {
+ int size;
  int n_nodes;
  int n_leafs;

- struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES];
+ struct wsp_ggml_tensor ** nodes;
+ struct wsp_ggml_tensor ** grads;
+ struct wsp_ggml_tensor ** leafs;
+
+ struct wsp_ggml_hash_set visited_hash_table;

- void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
+ enum wsp_ggml_cgraph_eval_order order;

  // performance
  int perf_runs;
@@ -539,8 +571,6 @@ extern "C" {
  int64_t perf_time_us;
  };

- static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph);
-
  // scratch buffer
  struct wsp_ggml_scratch {
  size_t offs;
@@ -585,6 +615,8 @@ extern "C" {
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);

+ WSP_GGML_API void wsp_ggml_print_backtrace(void);
+
  WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
  WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node

@@ -674,18 +706,30 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);

+ // Context tensor enumeration and lookup
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);

+ // Converts a flat index into coordinates
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
  WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
  WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);

+ WSP_GGML_API int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ WSP_GGML_API void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
  WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
  WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);

+ WSP_GGML_API float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ WSP_GGML_API void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
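The hunk above adds context-wide tensor enumeration (wsp_ggml_get_first_tensor / wsp_ggml_get_next_tensor), coordinate-based element accessors (the *_nd variants) and wsp_ggml_unravel_index. A small sketch, assuming the wsp_-prefixed API matches upstream ggml; ctx and t4 stand in for an existing context and an F32 tensor with allocated data:

    #include <stdio.h>
    #include "ggml.h"

    static void dump_context(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * t4) {
        // walk every tensor allocated in the context
        for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx);
             t != NULL;
             t = wsp_ggml_get_next_tensor(ctx, t)) {
            printf("%-32s ne0=%lld\n", wsp_ggml_get_name(t), (long long) t->ne[0]);
        }

        // coordinate-based access instead of flat 1-d indexing
        // (assumes t4->ne[0] > 1 and t4->ne[1] > 2)
        float v = wsp_ggml_get_f32_nd(t4, 1, 2, 0, 0);
        wsp_ggml_set_f32_nd(t4, 1, 2, 0, 0, v + 1.0f);

        // map a flat element index back to (i0, i1, i2, i3)
        int64_t i0, i1, i2, i3;
        wsp_ggml_unravel_index(t4, 5, &i0, &i1, &i2, &i3);
        printf("element 5 -> (%lld, %lld, %lld, %lld)\n",
               (long long) i0, (long long) i1, (long long) i2, (long long) i3);
    }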

@@ -719,6 +763,12 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_cast(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ enum wsp_ggml_type type);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -828,6 +878,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // sums repetitions in a into shape of b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -892,6 +943,10 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -970,9 +1025,9 @@ extern "C" {
  struct wsp_ggml_tensor * b,
  float eps);

- // A: n columns, m rows
- // B: n columns, p rows (i.e. we transpose it internally)
- // result is m columns, p rows
+ // A: k columns, n rows => [ne03, ne02, n, k]
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1049,7 +1104,6 @@ extern "C" {
  size_t nb1,
  size_t offset);

-
  // a -> b, return view(b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
  struct wsp_ggml_context * ctx,
@@ -1072,6 +1126,33 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // make contiguous, with new shape
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_4d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
  // return view(a), b specifies the new shape
  // TODO: when we start computing gradient, make a copy instead of view
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape(
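The new wsp_ggml_cont_1d/2d/3d/4d helpers combine "make contiguous" with an explicit target shape. A short sketch with made-up tensor names and sizes, assuming the usual ggml semantics and an existing context ctx:

    static struct wsp_ggml_tensor * transpose_copy(struct wsp_ggml_context * ctx) {
        // w is 8 columns x 4 rows; the transpose is only a strided view
        struct wsp_ggml_tensor * w  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 8, 4);
        struct wsp_ggml_tensor * wt = wsp_ggml_transpose(ctx, w);      // non-contiguous view
        return wsp_ggml_cont_2d(ctx, wt, 4, 8);                        // contiguous copy with shape 4 x 8
    }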
@@ -1219,14 +1300,15 @@ extern "C" {
  struct wsp_ggml_tensor * b);

  // rotary position embedding
- // if mode & 1 == 1, skip n_past elements
+ // if mode & 1 == 1, skip n_past elements (DEPRECATED)
  // if mode & 2 == 1, GPT-NeoX style
  // if mode & 4 == 1, ChatGLM style
- // TODO: avoid creating a new tensor every time
+ //
+ // b is an int32 vector with size a->ne[2], it contains the positions
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx);
@@ -1235,7 +1317,7 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx);
@@ -1244,29 +1326,43 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);

  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
+ // compute correction dims for YaRN RoPE scaling
+ void wsp_ggml_rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);

  // xPos RoPE, in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  float base,
  bool down);
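Every RoPE variant now takes the token positions as an int32 tensor b (one entry per a->ne[2] row) instead of a single n_past offset. A minimal sketch of the new calling convention; ctx is an existing context with data allocation enabled, and all sizes are illustrative:

    static struct wsp_ggml_tensor * rope_example(struct wsp_ggml_context * ctx) {
        const int n_past   = 32;   // what used to be passed directly as an int
        const int n_tokens = 8;

        // head_dim x n_head x n_tokens, e.g. a Q or K tensor (sizes illustrative)
        struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_3d(ctx, WSP_GGML_TYPE_F32, 64, 12, n_tokens);

        // positions now live in an I32 tensor of length a->ne[2]
        struct wsp_ggml_tensor * pos = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, n_tokens);
        for (int i = 0; i < n_tokens; i++) {
            wsp_ggml_set_i32_1d(pos, i, n_past + i);
        }

        return wsp_ggml_rope(ctx, cur, pos, /*n_dims=*/64, /*mode=*/0, /*n_ctx=*/0);
    }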
@@ -1276,7 +1372,7 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
  int n_ctx,
@@ -1287,7 +1383,7 @@ extern "C" {

  // alibi position embedding
  // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_alibi(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  int n_past,
@@ -1296,7 +1392,7 @@ extern "C" {

  // clamp
  // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_clamp(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  float min,
@@ -1319,6 +1415,14 @@ extern "C" {
  int s,
  int d);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ int s0,
+ int p0,
+ int d0);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1377,6 +1481,8 @@ extern "C" {
  int s0, // stride
  int p0); // padding

+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1385,8 +1491,8 @@ extern "C" {
  int k1,
  int s0,
  int s1,
- int p0,
- int p1);
+ float p0,
+ float p1);

  // nearest interpolate
  // used in stable-diffusion
@@ -1627,19 +1733,22 @@ extern "C" {
  WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);

- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep);
-
  // graph allocation in a context
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_view (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i0, int i1);
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
+
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
+ WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);

  // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
  WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
- WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);

  // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
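wsp_ggml_new_graph() now allocates a WSP_GGML_DEFAULT_GRAPH_SIZE (2048) node graph without gradients; larger or gradient-carrying graphs go through wsp_ggml_new_graph_custom(), and wsp_ggml_graph_overhead_custom() reports how much context memory that metadata needs. A hedged sketch of budgeting a metadata-only context for a bigger graph; the node count and the 1 MB of slack are illustrative:

    static struct wsp_ggml_cgraph * new_big_graph(struct wsp_ggml_context ** out_ctx) {
        const size_t n_nodes = 4096;   // above the 2048 default

        struct wsp_ggml_init_params params = {
            /*.mem_size   =*/ wsp_ggml_graph_overhead_custom(n_nodes, /*grads=*/false)
                              + n_nodes * wsp_ggml_tensor_overhead()
                              + 1024 * 1024,                         // illustrative slack
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // metadata only; tensor data lives in backend buffers
        };
        struct wsp_ggml_context * ctx = wsp_ggml_init(params);

        *out_ctx = ctx;   // the context owns the graph; wsp_ggml_free() it when done
        return wsp_ggml_new_graph_custom(ctx, n_nodes, /*grads=*/false);
        // ... then create tensors and call wsp_ggml_build_forward_expand(gf, out) as before ...
    }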
@@ -1647,8 +1756,8 @@ extern "C" {

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);

- WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
+ WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);

  // print info and performance information for the graph
  WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -1656,6 +1765,16 @@ extern "C" {
  // dump the graph into a file using the dot format
  WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);

+ // build gradient checkpointing backward graph gb for gf using provided checkpoints
+ // gb_tmp will contain original backward graph with rewritten backward process nodes,
+ // but without the second forward pass nodes.
+ WSP_GGML_API void wsp_ggml_build_backward_gradient_checkpointing(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_cgraph * gf,
+ struct wsp_ggml_cgraph * gb,
+ struct wsp_ggml_cgraph * gb_tmp,
+ struct wsp_ggml_tensor * * checkpoints,
+ int n_checkpoints);
  //
  // optimization
  //
@@ -1682,6 +1801,7 @@ extern "C" {
  WSP_GGML_OPT_NO_CONTEXT,
  WSP_GGML_OPT_INVALID_WOLFE,
  WSP_GGML_OPT_FAIL,
+ WSP_GGML_OPT_CANCEL,

  WSP_GGML_LINESEARCH_FAIL = -128,
  WSP_GGML_LINESEARCH_MINIMUM_STEP,
@@ -1690,7 +1810,8 @@ extern "C" {
  WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
  };

- typedef void (*wsp_ggml_opt_callback)(void * data, float * sched);
+ typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+ typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);

  // optimization parameters
  //
@@ -1699,6 +1820,8 @@ extern "C" {
  struct wsp_ggml_opt_params {
  enum wsp_ggml_opt_type type;

+ size_t graph_size;
+
  int n_threads;

  // delta-based convergence test
@@ -1721,6 +1844,8 @@ extern "C" {
  bool print_forward_graph;
  bool print_backward_graph;

+ int n_gradient_accumulation;
+
  // ADAM parameters
  struct {
  int n_iter;
@@ -1766,6 +1891,7 @@ extern "C" {
  float loss_after;

  struct {
+ struct wsp_ggml_tensor * g; // current gradient
  struct wsp_ggml_tensor * m; // first moment
  struct wsp_ggml_tensor * v; // second moment
  struct wsp_ggml_tensor * pf; // past function values
@@ -1829,134 +1955,141 @@ extern "C" {
  // quantization
  //

+ // TODO: these would probably get removed in favor of the more general wsp_ggml_quantize_chunk
  WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
  WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
  WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
  WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
  WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);

+ WSP_GGML_API size_t wsp_ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+
  WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

  //
  // gguf
  //

- enum gguf_type {
- GGUF_TYPE_UINT8 = 0,
- GGUF_TYPE_INT8 = 1,
- GGUF_TYPE_UINT16 = 2,
- GGUF_TYPE_INT16 = 3,
- GGUF_TYPE_UINT32 = 4,
- GGUF_TYPE_INT32 = 5,
- GGUF_TYPE_FLOAT32 = 6,
- GGUF_TYPE_BOOL = 7,
- GGUF_TYPE_STRING = 8,
- GGUF_TYPE_ARRAY = 9,
- GGUF_TYPE_UINT64 = 10,
- GGUF_TYPE_INT64 = 11,
- GGUF_TYPE_FLOAT64 = 12,
- GGUF_TYPE_COUNT, // marks the end of the enum
+ enum wsp_gguf_type {
+ WSP_GGUF_TYPE_UINT8 = 0,
+ WSP_GGUF_TYPE_INT8 = 1,
+ WSP_GGUF_TYPE_UINT16 = 2,
+ WSP_GGUF_TYPE_INT16 = 3,
+ WSP_GGUF_TYPE_UINT32 = 4,
+ WSP_GGUF_TYPE_INT32 = 5,
+ WSP_GGUF_TYPE_FLOAT32 = 6,
+ WSP_GGUF_TYPE_BOOL = 7,
+ WSP_GGUF_TYPE_STRING = 8,
+ WSP_GGUF_TYPE_ARRAY = 9,
+ WSP_GGUF_TYPE_UINT64 = 10,
+ WSP_GGUF_TYPE_INT64 = 11,
+ WSP_GGUF_TYPE_FLOAT64 = 12,
+ WSP_GGUF_TYPE_COUNT, // marks the end of the enum
  };

- struct gguf_context;
+ struct wsp_gguf_context;

- struct gguf_init_params {
+ struct wsp_gguf_init_params {
  bool no_alloc;

  // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
  struct wsp_ggml_context ** ctx;
  };

- WSP_GGML_API struct gguf_context * gguf_init_empty(void);
- WSP_GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
- //WSP_GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
- WSP_GGML_API void gguf_free(struct gguf_context * ctx);
-
- WSP_GGML_API const char * gguf_type_name(enum gguf_type type);
-
- WSP_GGML_API int gguf_get_version (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
- WSP_GGML_API void * gguf_get_data (const struct gguf_context * ctx);
-
- WSP_GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
- WSP_GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
- WSP_GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
- WSP_GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
- // results are undefined if the wrong type is used for the key
- WSP_GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
- WSP_GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
- WSP_GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
- WSP_GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
- WSP_GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
- WSP_GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
+ WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_empty(void);
+ WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params);
+ //WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_buffer(..);
+
+ WSP_GGML_API void wsp_gguf_free(struct wsp_gguf_context * ctx);
+
+ WSP_GGML_API const char * wsp_gguf_type_name(enum wsp_gguf_type type);
+
+ WSP_GGML_API int wsp_gguf_get_version (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API size_t wsp_gguf_get_alignment (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API void * wsp_gguf_get_data (const struct wsp_gguf_context * ctx);
+
+ WSP_GGML_API int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key);
+ WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id);
+
+ WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id);
+
+ // will abort if the wrong type is used for the key
+ WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
+
+ WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
+ WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
+ WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
+ WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);

  // overrides existing values or adds a new one
- WSP_GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
- WSP_GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
- WSP_GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
- WSP_GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
- WSP_GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
- WSP_GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
- WSP_GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
- WSP_GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
- WSP_GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
- WSP_GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
- WSP_GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
- WSP_GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
- WSP_GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
- WSP_GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+ WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i8 (struct wsp_gguf_context * ctx, const char * key, int8_t val);
+ WSP_GGML_API void wsp_gguf_set_val_u16 (struct wsp_gguf_context * ctx, const char * key, uint16_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i16 (struct wsp_gguf_context * ctx, const char * key, int16_t val);
+ WSP_GGML_API void wsp_gguf_set_val_u32 (struct wsp_gguf_context * ctx, const char * key, uint32_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i32 (struct wsp_gguf_context * ctx, const char * key, int32_t val);
+ WSP_GGML_API void wsp_gguf_set_val_f32 (struct wsp_gguf_context * ctx, const char * key, float val);
+ WSP_GGML_API void wsp_gguf_set_val_u64 (struct wsp_gguf_context * ctx, const char * key, uint64_t val);
+ WSP_GGML_API void wsp_gguf_set_val_i64 (struct wsp_gguf_context * ctx, const char * key, int64_t val);
+ WSP_GGML_API void wsp_gguf_set_val_f64 (struct wsp_gguf_context * ctx, const char * key, double val);
+ WSP_GGML_API void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool val);
+ WSP_GGML_API void wsp_gguf_set_val_str (struct wsp_gguf_context * ctx, const char * key, const char * val);
+ WSP_GGML_API void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n);
+ WSP_GGML_API void wsp_gguf_set_arr_str (struct wsp_gguf_context * ctx, const char * key, const char ** data, int n);

  // set or add KV pairs from another context
- WSP_GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+ WSP_GGML_API void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src);

  // manage tensor info
- WSP_GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum wsp_ggml_type type);
- WSP_GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+ WSP_GGML_API void wsp_gguf_add_tensor(struct wsp_gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type);
+ WSP_GGML_API void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size);

  // writing gguf files can be done in 2 ways:
  //
- // - write the entire gguf_context to a binary file in a single pass:
+ // - write the entire wsp_gguf_context to a binary file in a single pass:
  //
- // gguf_write_to_file(ctx, fname);
+ // wsp_gguf_write_to_file(ctx, fname);
  //
  // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
  //
  // FILE * f = fopen(fname, "wb");
- // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+ // fseek(f, wsp_gguf_get_meta_size(ctx), SEEK_SET);
  // fwrite(f, ...);
- // void * data = gguf_meta_get_meta_data(ctx);
+ // void * data = wsp_gguf_meta_get_meta_data(ctx);
  // fseek(f, 0, SEEK_SET);
- // fwrite(f, data, gguf_get_meta_size(ctx));
+ // fwrite(f, data, wsp_gguf_get_meta_size(ctx));
  // free(data);
  // fclose(f);
  //

  // write the entire context to a binary file
- WSP_GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+ WSP_GGML_API void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta);

  // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
- WSP_GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
- WSP_GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+ WSP_GGML_API size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx);
+ WSP_GGML_API void wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data);

  //
  // system info
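The entire GGUF API is now namespaced with the wsp_ prefix. A small sketch of reading metadata with the renamed entry points, assuming they behave like upstream ggml's GGUF reader; the file name and key are only examples:

    #include <stdio.h>
    #include "ggml.h"

    static void print_gguf_info(const char * fname) {   // e.g. "model.gguf" (illustrative)
        struct wsp_ggml_context * meta = NULL;

        struct wsp_gguf_init_params gparams = {
            /*.no_alloc =*/ true,    // read tensor info but not tensor data
            /*.ctx      =*/ &meta,
        };

        struct wsp_gguf_context * gctx = wsp_gguf_init_from_file(fname, gparams);
        if (gctx == NULL) {
            return;
        }

        const int key_id = wsp_gguf_find_key(gctx, "general.architecture");   // key name illustrative
        if (key_id >= 0 && wsp_gguf_get_kv_type(gctx, key_id) == WSP_GGUF_TYPE_STRING) {
            printf("arch: %s\n", wsp_gguf_get_val_str(gctx, key_id));
        }
        printf("tensors: %d\n", wsp_gguf_get_n_tensors(gctx));

        wsp_gguf_free(gctx);
        wsp_ggml_free(meta);   // metadata context created because gparams.ctx was set
    }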
@@ -2008,7 +2141,7 @@ extern "C" {
  enum wsp_ggml_type vec_dot_type;
  } wsp_ggml_type_traits_t;

- wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+ WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);

  #ifdef __cplusplus
  }