whisper.rn 0.3.8 → 0.4.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml.h CHANGED
@@ -65,7 +65,7 @@
65
65
  // wsp_ggml_set_f32(a, 3.0f);
66
66
  // wsp_ggml_set_f32(b, 4.0f);
67
67
  //
68
- // wsp_ggml_graph_compute(ctx0, &gf);
68
+ // wsp_ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
69
69
  //
70
70
  // printf("f = %f\n", wsp_ggml_get_f32_1d(f, 0));
71
71
  //
@@ -130,13 +130,16 @@
130
130
  // The data of the tensor is accessed via the "data" pointer. For example:
131
131
  //
132
132
  // {
133
- // struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 2, 3);
133
+ // const int nx = 2;
134
+ // const int ny = 3;
134
135
  //
135
- // // a[1, 2] = 1.0f;
136
- // *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
136
+ // struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, nx, ny);
137
137
  //
138
- // // a[2, 0] = 2.0f;
139
- // *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
138
+ // for (int y = 0; y < ny; y++) {
139
+ // for (int x = 0; x < nx; x++) {
140
+ // *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
141
+ // }
142
+ // }
140
143
  //
141
144
  // ...
142
145
  // }
@@ -183,6 +186,23 @@
183
186
  # define WSP_GGML_API
184
187
  #endif
185
188
 
189
+ // TODO: support for clang
190
+ #ifdef __GNUC__
191
+ # define WSP_GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
192
+ #elif defined(_MSC_VER)
193
+ # define WSP_GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
194
+ #else
195
+ # define WSP_GGML_DEPRECATED(func, hint) func
196
+ #endif
197
+
198
+ #ifndef __GNUC__
199
+ # define WSP_GGML_ATTRIBUTE_FORMAT(...)
200
+ #elif defined(__MINGW32__)
201
+ # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
202
+ #else
203
+ # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
204
+ #endif
205
+
186
206
  #include <stdint.h>
187
207
  #include <stddef.h>
188
208
  #include <stdbool.h>
@@ -197,12 +217,29 @@
197
217
  #define WSP_GGML_MAX_NODES 4096
198
218
  #define WSP_GGML_MAX_PARAMS 256
199
219
  #define WSP_GGML_MAX_CONTEXTS 64
200
- #define WSP_GGML_MAX_OPT 4
201
- #define WSP_GGML_MAX_NAME 48
220
+ #define WSP_GGML_MAX_SRC 6
221
+ #define WSP_GGML_MAX_NAME 64
222
+ #define WSP_GGML_MAX_OP_PARAMS 32
202
223
  #define WSP_GGML_DEFAULT_N_THREADS 4
203
224
 
225
+ #if UINTPTR_MAX == 0xFFFFFFFF
226
+ #define WSP_GGML_MEM_ALIGN 4
227
+ #else
228
+ #define WSP_GGML_MEM_ALIGN 16
229
+ #endif
230
+
231
+ #define WSP_GGML_EXIT_SUCCESS 0
232
+ #define WSP_GGML_EXIT_ABORTED 1
233
+
234
+ #define GGUF_MAGIC 0x46554747 // "GGUF"
235
+ #define GGUF_VERSION 2
236
+
237
+ #define GGUF_DEFAULT_ALIGNMENT 32
238
+
204
239
  #define WSP_GGML_UNUSED(x) (void)(x)
205
240
 
241
+ #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
242
+
206
243
  #define WSP_GGML_ASSERT(x) \
207
244
  do { \
208
245
  if (!(x)) { \
@@ -239,8 +276,9 @@
239
276
  extern "C" {
240
277
  #endif
241
278
 
242
- #ifdef __ARM_NEON
243
- // we use the built-in 16-bit float type
279
+ #if defined(__ARM_NEON) && defined(__CUDACC__)
280
+ typedef half wsp_ggml_fp16_t;
281
+ #elif defined(__ARM_NEON)
244
282
  typedef __fp16 wsp_ggml_fp16_t;
245
283
  #else
246
284
  typedef uint16_t wsp_ggml_fp16_t;
@@ -250,8 +288,8 @@ extern "C" {
250
288
  WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
251
289
  WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
252
290
 
253
- WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, size_t n);
254
- WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, size_t n);
291
+ WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
292
+ WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
255
293
 
256
294
  struct wsp_ggml_object;
257
295
  struct wsp_ggml_context;
@@ -324,20 +362,12 @@ extern "C" {
324
362
  WSP_GGML_OP_ARGMAX,
325
363
  WSP_GGML_OP_REPEAT,
326
364
  WSP_GGML_OP_REPEAT_BACK,
327
- WSP_GGML_OP_ABS,
328
- WSP_GGML_OP_SGN,
329
- WSP_GGML_OP_NEG,
330
- WSP_GGML_OP_STEP,
331
- WSP_GGML_OP_TANH,
332
- WSP_GGML_OP_ELU,
333
- WSP_GGML_OP_RELU,
334
- WSP_GGML_OP_GELU,
335
- WSP_GGML_OP_GELU_QUICK,
336
- WSP_GGML_OP_SILU,
365
+ WSP_GGML_OP_CONCAT,
337
366
  WSP_GGML_OP_SILU_BACK,
338
367
  WSP_GGML_OP_NORM, // normalize
339
368
  WSP_GGML_OP_RMS_NORM,
340
369
  WSP_GGML_OP_RMS_NORM_BACK,
370
+ WSP_GGML_OP_GROUP_NORM,
341
371
 
342
372
  WSP_GGML_OP_MUL_MAT,
343
373
  WSP_GGML_OP_OUT_PROD,
@@ -363,16 +393,29 @@ extern "C" {
363
393
  WSP_GGML_OP_CLAMP,
364
394
  WSP_GGML_OP_CONV_1D,
365
395
  WSP_GGML_OP_CONV_2D,
396
+ WSP_GGML_OP_CONV_TRANSPOSE_2D,
397
+ WSP_GGML_OP_POOL_1D,
398
+ WSP_GGML_OP_POOL_2D,
399
+
400
+ WSP_GGML_OP_UPSCALE, // nearest interpolate
366
401
 
367
402
  WSP_GGML_OP_FLASH_ATTN,
368
403
  WSP_GGML_OP_FLASH_FF,
369
404
  WSP_GGML_OP_FLASH_ATTN_BACK,
370
405
  WSP_GGML_OP_WIN_PART,
371
406
  WSP_GGML_OP_WIN_UNPART,
407
+ WSP_GGML_OP_GET_REL_POS,
408
+ WSP_GGML_OP_ADD_REL_POS,
409
+
410
+ WSP_GGML_OP_UNARY,
372
411
 
373
412
  WSP_GGML_OP_MAP_UNARY,
374
413
  WSP_GGML_OP_MAP_BINARY,
375
414
 
415
+ WSP_GGML_OP_MAP_CUSTOM1_F32,
416
+ WSP_GGML_OP_MAP_CUSTOM2_F32,
417
+ WSP_GGML_OP_MAP_CUSTOM3_F32,
418
+
376
419
  WSP_GGML_OP_MAP_CUSTOM1,
377
420
  WSP_GGML_OP_MAP_CUSTOM2,
378
421
  WSP_GGML_OP_MAP_CUSTOM3,
@@ -383,6 +426,24 @@ extern "C" {
383
426
  WSP_GGML_OP_COUNT,
384
427
  };
385
428
 
429
+ enum wsp_ggml_unary_op {
430
+ WSP_GGML_UNARY_OP_ABS,
431
+ WSP_GGML_UNARY_OP_SGN,
432
+ WSP_GGML_UNARY_OP_NEG,
433
+ WSP_GGML_UNARY_OP_STEP,
434
+ WSP_GGML_UNARY_OP_TANH,
435
+ WSP_GGML_UNARY_OP_ELU,
436
+ WSP_GGML_UNARY_OP_RELU,
437
+ WSP_GGML_UNARY_OP_GELU,
438
+ WSP_GGML_UNARY_OP_GELU_QUICK,
439
+ WSP_GGML_UNARY_OP_SILU,
440
+ };
441
+
442
+ enum wsp_ggml_object_type {
443
+ WSP_GGML_OBJECT_TENSOR,
444
+ WSP_GGML_OBJECT_GRAPH,
445
+ WSP_GGML_OBJECT_WORK_BUFFER
446
+ };
386
447
 
387
448
  // ggml object
388
449
  struct wsp_ggml_object {
@@ -391,7 +452,9 @@ extern "C" {
391
452
 
392
453
  struct wsp_ggml_object * next;
393
454
 
394
- char padding[8];
455
+ enum wsp_ggml_object_type type;
456
+
457
+ char padding[4];
395
458
  };
396
459
 
397
460
  static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
@@ -411,21 +474,22 @@ extern "C" {
411
474
  // compute data
412
475
  enum wsp_ggml_op op;
413
476
 
477
+ // op params - allocated as int32_t for alignment
478
+ int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];
479
+
414
480
  bool is_param;
415
481
 
416
482
  struct wsp_ggml_tensor * grad;
417
- struct wsp_ggml_tensor * src0;
418
- struct wsp_ggml_tensor * src1;
419
- struct wsp_ggml_tensor * opt[WSP_GGML_MAX_OPT];
420
-
421
- // thread scheduling
422
- int n_tasks;
483
+ struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];
423
484
 
424
485
  // performance
425
486
  int perf_runs;
426
487
  int64_t perf_cycles;
427
488
  int64_t perf_time_us;
428
489
 
490
+ struct wsp_ggml_tensor * view_src;
491
+ size_t view_offs;
492
+
429
493
  void * data;
430
494
 
431
495
  char name[WSP_GGML_MAX_NAME];
@@ -437,25 +501,46 @@ extern "C" {
437
501
 
438
502
  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
439
503
 
504
+ // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
505
+ // since https://github.com/ggerganov/ggml/issues/287
506
+ struct wsp_ggml_cplan {
507
+ size_t work_size; // size of work buffer, calculated by `wsp_ggml_graph_plan()`
508
+ uint8_t * work_data; // work buffer, to be allocated by the caller before calling `wsp_ggml_graph_compute()`
509
+
510
+ int n_threads;
511
+
512
+ // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
513
+ int n_tasks[WSP_GGML_MAX_NODES];
514
+
515
+ // abort wsp_ggml_graph_compute when true
516
+ bool (*abort_callback)(void * data);
517
+ void * abort_callback_data;
518
+ };
519
+
520
+ // next prime after WSP_GGML_MAX_NODES
521
+ // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099
522
+ // a prime larger than WSP_GGML_MAX_NODES * 2 (nodes + leafs)
523
+ #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273
524
+
440
525
  // computation graph
441
526
  struct wsp_ggml_cgraph {
442
527
  int n_nodes;
443
528
  int n_leafs;
444
- int n_threads;
445
-
446
- size_t work_size;
447
- struct wsp_ggml_tensor * work;
448
529
 
449
530
  struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES];
450
531
  struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES];
451
532
  struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES];
452
533
 
534
+ void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
535
+
453
536
  // performance
454
537
  int perf_runs;
455
538
  int64_t perf_cycles;
456
539
  int64_t perf_time_us;
457
540
  };
458
541
 
542
+ static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph);
543
+
459
544
  // scratch buffer
460
545
  struct wsp_ggml_scratch {
461
546
  size_t offs;
@@ -509,6 +594,7 @@ extern "C" {
509
594
  WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
510
595
  WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
511
596
  WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
597
+ WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
512
598
  WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split);
513
599
 
514
600
  WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type);
@@ -517,6 +603,7 @@ extern "C" {
517
603
 
518
604
  WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
519
605
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
606
+ WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);
520
607
 
521
608
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
522
609
 
@@ -529,6 +616,8 @@ extern "C" {
529
616
  WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
530
617
  WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
531
618
 
619
+ WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
620
+
532
621
  // use this to compute the memory overhead of a tensor
533
622
  WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);
534
623
 
@@ -540,6 +629,7 @@ extern "C" {
540
629
  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);
541
630
 
542
631
  WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch);
632
+ WSP_GGML_API bool wsp_ggml_get_no_alloc(struct wsp_ggml_context * ctx);
543
633
  WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc);
544
634
 
545
635
  WSP_GGML_API void * wsp_ggml_get_mem_buffer (const struct wsp_ggml_context * ctx);
@@ -582,7 +672,7 @@ extern "C" {
582
672
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value);
583
673
 
584
674
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
585
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
675
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);
586
676
 
587
677
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);
588
678
 
@@ -599,9 +689,12 @@ extern "C" {
599
689
  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
600
690
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
601
691
 
602
- WSP_GGML_API const char * wsp_ggml_get_name(const struct wsp_ggml_tensor * tensor);
603
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name(struct wsp_ggml_tensor * tensor, const char * name);
604
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name(struct wsp_ggml_tensor * tensor, const char * fmt, ...);
692
+ WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
693
+
694
+ WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * tensor);
695
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name);
696
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
697
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name( struct wsp_ggml_tensor * tensor, const char * fmt, ...);
605
698
 
606
699
  //
607
700
  // operations on tensors with backpropagation
@@ -611,6 +704,11 @@ extern "C" {
611
704
  struct wsp_ggml_context * ctx,
612
705
  struct wsp_ggml_tensor * a);
613
706
 
707
+ // in-place, returns view(a)
708
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_inplace(
709
+ struct wsp_ggml_context * ctx,
710
+ struct wsp_ggml_tensor * a);
711
+
614
712
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add(
615
713
  struct wsp_ggml_context * ctx,
616
714
  struct wsp_ggml_tensor * a,
@@ -735,6 +833,13 @@ extern "C" {
735
833
  struct wsp_ggml_tensor * a,
736
834
  struct wsp_ggml_tensor * b);
737
835
 
836
+ // concat a and b on dim 2
837
+ // used in stable-diffusion
838
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
839
+ struct wsp_ggml_context * ctx,
840
+ struct wsp_ggml_tensor * a,
841
+ struct wsp_ggml_tensor * b);
842
+
738
843
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
739
844
  struct wsp_ggml_context * ctx,
740
845
  struct wsp_ggml_tensor * a);
@@ -824,29 +929,46 @@ extern "C" {
824
929
  struct wsp_ggml_tensor * b);
825
930
 
826
931
  // normalize along rows
827
- // TODO: eps is hardcoded to 1e-5 for now
828
932
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
829
933
  struct wsp_ggml_context * ctx,
830
- struct wsp_ggml_tensor * a);
934
+ struct wsp_ggml_tensor * a,
935
+ float eps);
831
936
 
832
937
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm_inplace(
833
938
  struct wsp_ggml_context * ctx,
834
- struct wsp_ggml_tensor * a);
939
+ struct wsp_ggml_tensor * a,
940
+ float eps);
835
941
 
836
942
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm(
837
943
  struct wsp_ggml_context * ctx,
838
- struct wsp_ggml_tensor * a);
944
+ struct wsp_ggml_tensor * a,
945
+ float eps);
839
946
 
840
947
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm_inplace(
841
948
  struct wsp_ggml_context * ctx,
842
- struct wsp_ggml_tensor * a);
949
+ struct wsp_ggml_tensor * a,
950
+ float eps);
951
+
952
+ // group normalize along ne0*ne1*n_groups
953
+ // used in stable-diffusion
954
+ // TODO: eps is hardcoded to 1e-6 for now
955
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
956
+ struct wsp_ggml_context * ctx,
957
+ struct wsp_ggml_tensor * a,
958
+ int n_groups);
959
+
960
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
961
+ struct wsp_ggml_context * ctx,
962
+ struct wsp_ggml_tensor * a,
963
+ int n_groups);
843
964
 
844
965
  // a - x
845
966
  // b - dy
846
967
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm_back(
847
968
  struct wsp_ggml_context * ctx,
848
969
  struct wsp_ggml_tensor * a,
849
- struct wsp_ggml_tensor * b);
970
+ struct wsp_ggml_tensor * b,
971
+ float eps);
850
972
 
851
973
  // A: n columns, m rows
852
974
  // B: n columns, p rows (i.e. we transpose it internally)
@@ -934,11 +1056,22 @@ extern "C" {
934
1056
  struct wsp_ggml_tensor * a,
935
1057
  struct wsp_ggml_tensor * b);
936
1058
 
1059
+ // a -> b, in-place, return view(b)
1060
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy_inplace(
1061
+ struct wsp_ggml_context * ctx,
1062
+ struct wsp_ggml_tensor * a,
1063
+ struct wsp_ggml_tensor * b);
1064
+
937
1065
  // make contiguous
938
1066
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont(
939
1067
  struct wsp_ggml_context * ctx,
940
1068
  struct wsp_ggml_tensor * a);
941
1069
 
1070
+ // make contiguous, in-place
1071
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_inplace(
1072
+ struct wsp_ggml_context * ctx,
1073
+ struct wsp_ggml_tensor * a);
1074
+
942
1075
  // return view(a), b specifies the new shape
943
1076
  // TODO: when we start computing gradient, make a copy instead of view
944
1077
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape(
@@ -1107,6 +1240,37 @@ extern "C" {
1107
1240
  int mode,
1108
1241
  int n_ctx);
1109
1242
 
1243
+ // custom RoPE
1244
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
1245
+ struct wsp_ggml_context * ctx,
1246
+ struct wsp_ggml_tensor * a,
1247
+ int n_past,
1248
+ int n_dims,
1249
+ int mode,
1250
+ int n_ctx,
1251
+ float freq_base,
1252
+ float freq_scale);
1253
+
1254
+ // in-place, returns view(a)
1255
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
1256
+ struct wsp_ggml_context * ctx,
1257
+ struct wsp_ggml_tensor * a,
1258
+ int n_past,
1259
+ int n_dims,
1260
+ int mode,
1261
+ int n_ctx,
1262
+ float freq_base,
1263
+ float freq_scale);
1264
+
1265
+ // xPos RoPE, in-place, returns view(a)
1266
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
1267
+ struct wsp_ggml_context * ctx,
1268
+ struct wsp_ggml_tensor * a,
1269
+ int n_past,
1270
+ int n_dims,
1271
+ float base,
1272
+ bool down);
1273
+
1110
1274
  // rotary position embedding backward, i.e. compute dx from dy
1111
1275
  // a - dy
1112
1276
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
@@ -1114,7 +1278,12 @@ extern "C" {
1114
1278
  struct wsp_ggml_tensor * a,
1115
1279
  int n_past,
1116
1280
  int n_dims,
1117
- int mode);
1281
+ int mode,
1282
+ int n_ctx,
1283
+ float freq_base,
1284
+ float freq_scale,
1285
+ float xpos_base,
1286
+ bool xpos_down);
1118
1287
 
1119
1288
  // alibi position embedding
1120
1289
  // in-place, returns view(a)
@@ -1141,6 +1310,15 @@ extern "C" {
1141
1310
  int p0, // padding
1142
1311
  int d0); // dilation
1143
1312
 
1313
+ // conv_1d with padding = half
1314
+ // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1315
+ WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
1316
+ struct wsp_ggml_context * ctx,
1317
+ struct wsp_ggml_tensor * a,
1318
+ struct wsp_ggml_tensor * b,
1319
+ int s,
1320
+ int d);
1321
+
1144
1322
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
1145
1323
  struct wsp_ggml_context * ctx,
1146
1324
  struct wsp_ggml_tensor * a,
@@ -1152,14 +1330,70 @@ extern "C" {
1152
1330
  int d0,
1153
1331
  int d1);
1154
1332
 
1155
- // conv_1d with padding = half
1156
- // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1157
- WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
1333
+
1334
+ // kernel size is a->ne[0] x a->ne[1]
1335
+ // stride is equal to kernel size
1336
+ // padding is zero
1337
+ // example:
1338
+ // a: 16 16 3 768
1339
+ // b: 1024 1024 3 1
1340
+ // res: 64 64 768 1
1341
+ // used in sam
1342
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
1343
+ struct wsp_ggml_context * ctx,
1344
+ struct wsp_ggml_tensor * a,
1345
+ struct wsp_ggml_tensor * b);
1346
+
1347
+ // kernel size is a->ne[0] x a->ne[1]
1348
+ // stride is 1
1349
+ // padding is half
1350
+ // example:
1351
+ // a: 3 3 256 256
1352
+ // b: 64 64 256 1
1353
+ // res: 64 64 256 1
1354
+ // used in sam
1355
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_s1_ph(
1356
+ struct wsp_ggml_context * ctx,
1357
+ struct wsp_ggml_tensor * a,
1358
+ struct wsp_ggml_tensor * b);
1359
+
1360
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_2d_p0(
1158
1361
  struct wsp_ggml_context * ctx,
1159
1362
  struct wsp_ggml_tensor * a,
1160
1363
  struct wsp_ggml_tensor * b,
1161
- int s,
1162
- int d);
1364
+ int stride);
1365
+
1366
+ enum wsp_ggml_op_pool {
1367
+ WSP_GGML_OP_POOL_MAX,
1368
+ WSP_GGML_OP_POOL_AVG,
1369
+ WSP_GGML_OP_POOL_COUNT,
1370
+ };
1371
+
1372
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_1d(
1373
+ struct wsp_ggml_context * ctx,
1374
+ struct wsp_ggml_tensor * a,
1375
+ enum wsp_ggml_op_pool op,
1376
+ int k0, // kernel size
1377
+ int s0, // stride
1378
+ int p0); // padding
1379
+
1380
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d(
1381
+ struct wsp_ggml_context * ctx,
1382
+ struct wsp_ggml_tensor * a,
1383
+ enum wsp_ggml_op_pool op,
1384
+ int k0,
1385
+ int k1,
1386
+ int s0,
1387
+ int s1,
1388
+ int p0,
1389
+ int p1);
1390
+
1391
+ // nearest interpolate
1392
+ // used in stable-diffusion
1393
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
1394
+ struct wsp_ggml_context * ctx,
1395
+ struct wsp_ggml_tensor * a,
1396
+ int scale_factor);
1163
1397
 
1164
1398
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
1165
1399
  struct wsp_ggml_context * ctx,
@@ -1204,6 +1438,37 @@ extern "C" {
1204
1438
  int h0,
1205
1439
  int w);
1206
1440
 
1441
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_unary(
1442
+ struct wsp_ggml_context * ctx,
1443
+ struct wsp_ggml_tensor * a,
1444
+ enum wsp_ggml_unary_op op);
1445
+
1446
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_unary_inplace(
1447
+ struct wsp_ggml_context * ctx,
1448
+ struct wsp_ggml_tensor * a,
1449
+ enum wsp_ggml_unary_op op);
1450
+
1451
+ // used in sam
1452
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rel_pos(
1453
+ struct wsp_ggml_context * ctx,
1454
+ struct wsp_ggml_tensor * a,
1455
+ int qh,
1456
+ int kh);
1457
+
1458
+ // used in sam
1459
+
1460
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
1461
+ struct wsp_ggml_context * ctx,
1462
+ struct wsp_ggml_tensor * a,
1463
+ struct wsp_ggml_tensor * pw,
1464
+ struct wsp_ggml_tensor * ph);
1465
+
1466
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos_inplace(
1467
+ struct wsp_ggml_context * ctx,
1468
+ struct wsp_ggml_tensor * a,
1469
+ struct wsp_ggml_tensor * pw,
1470
+ struct wsp_ggml_tensor * ph);
1471
+
1207
1472
  // custom operators
1208
1473
 
1209
1474
  typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1213,63 +1478,129 @@ extern "C" {
1213
1478
  typedef void (*wsp_ggml_custom2_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1214
1479
  typedef void (*wsp_ggml_custom3_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1215
1480
 
1216
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_f32(
1481
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_f32(
1217
1482
  struct wsp_ggml_context * ctx,
1218
1483
  struct wsp_ggml_tensor * a,
1219
- wsp_ggml_unary_op_f32_t fun);
1484
+ wsp_ggml_unary_op_f32_t fun),
1485
+ "use wsp_ggml_map_custom1 instead");
1220
1486
 
1221
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32(
1487
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32(
1222
1488
  struct wsp_ggml_context * ctx,
1223
1489
  struct wsp_ggml_tensor * a,
1224
- wsp_ggml_unary_op_f32_t fun);
1490
+ wsp_ggml_unary_op_f32_t fun),
1491
+ "use wsp_ggml_map_custom1_inplace instead");
1225
1492
 
1226
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_f32(
1493
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_f32(
1227
1494
  struct wsp_ggml_context * ctx,
1228
1495
  struct wsp_ggml_tensor * a,
1229
1496
  struct wsp_ggml_tensor * b,
1230
- wsp_ggml_binary_op_f32_t fun);
1497
+ wsp_ggml_binary_op_f32_t fun),
1498
+ "use wsp_ggml_map_custom2 instead");
1231
1499
 
1232
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32(
1500
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32(
1233
1501
  struct wsp_ggml_context * ctx,
1234
1502
  struct wsp_ggml_tensor * a,
1235
1503
  struct wsp_ggml_tensor * b,
1236
- wsp_ggml_binary_op_f32_t fun);
1504
+ wsp_ggml_binary_op_f32_t fun),
1505
+ "use wsp_ggml_map_custom2_inplace instead");
1237
1506
 
1238
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32(
1507
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32(
1239
1508
  struct wsp_ggml_context * ctx,
1240
1509
  struct wsp_ggml_tensor * a,
1241
- wsp_ggml_custom1_op_f32_t fun);
1510
+ wsp_ggml_custom1_op_f32_t fun),
1511
+ "use wsp_ggml_map_custom1 instead");
1242
1512
 
1243
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32(
1513
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32(
1244
1514
  struct wsp_ggml_context * ctx,
1245
1515
  struct wsp_ggml_tensor * a,
1246
- wsp_ggml_custom1_op_f32_t fun);
1516
+ wsp_ggml_custom1_op_f32_t fun),
1517
+ "use wsp_ggml_map_custom1_inplace instead");
1247
1518
 
1248
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32(
1519
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32(
1249
1520
  struct wsp_ggml_context * ctx,
1250
1521
  struct wsp_ggml_tensor * a,
1251
1522
  struct wsp_ggml_tensor * b,
1252
- wsp_ggml_custom2_op_f32_t fun);
1523
+ wsp_ggml_custom2_op_f32_t fun),
1524
+ "use wsp_ggml_map_custom2 instead");
1253
1525
 
1254
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32(
1526
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32(
1255
1527
  struct wsp_ggml_context * ctx,
1256
1528
  struct wsp_ggml_tensor * a,
1257
1529
  struct wsp_ggml_tensor * b,
1258
- wsp_ggml_custom2_op_f32_t fun);
1530
+ wsp_ggml_custom2_op_f32_t fun),
1531
+ "use wsp_ggml_map_custom2_inplace instead");
1259
1532
 
1260
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32(
1533
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32(
1261
1534
  struct wsp_ggml_context * ctx,
1262
1535
  struct wsp_ggml_tensor * a,
1263
1536
  struct wsp_ggml_tensor * b,
1264
1537
  struct wsp_ggml_tensor * c,
1265
- wsp_ggml_custom3_op_f32_t fun);
1538
+ wsp_ggml_custom3_op_f32_t fun),
1539
+ "use wsp_ggml_map_custom3 instead");
1266
1540
 
1267
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32(
1541
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32(
1268
1542
  struct wsp_ggml_context * ctx,
1269
1543
  struct wsp_ggml_tensor * a,
1270
1544
  struct wsp_ggml_tensor * b,
1271
1545
  struct wsp_ggml_tensor * c,
1272
- wsp_ggml_custom3_op_f32_t fun);
1546
+ wsp_ggml_custom3_op_f32_t fun),
1547
+ "use wsp_ggml_map_custom3_inplace instead");
1548
+
1549
+ // custom operators v2
1550
+
1551
+ typedef void (*wsp_ggml_custom1_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int nth, void * userdata);
1552
+ typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
1553
+ typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);
1554
+
1555
+ #define WSP_GGML_N_TASKS_MAX -1
1556
+
1557
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
1558
+ struct wsp_ggml_context * ctx,
1559
+ struct wsp_ggml_tensor * a,
1560
+ wsp_ggml_custom1_op_t fun,
1561
+ int n_tasks,
1562
+ void * userdata);
1563
+
1564
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace(
1565
+ struct wsp_ggml_context * ctx,
1566
+ struct wsp_ggml_tensor * a,
1567
+ wsp_ggml_custom1_op_t fun,
1568
+ int n_tasks,
1569
+ void * userdata);
1570
+
1571
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2(
1572
+ struct wsp_ggml_context * ctx,
1573
+ struct wsp_ggml_tensor * a,
1574
+ struct wsp_ggml_tensor * b,
1575
+ wsp_ggml_custom2_op_t fun,
1576
+ int n_tasks,
1577
+ void * userdata);
1578
+
1579
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace(
1580
+ struct wsp_ggml_context * ctx,
1581
+ struct wsp_ggml_tensor * a,
1582
+ struct wsp_ggml_tensor * b,
1583
+ wsp_ggml_custom2_op_t fun,
1584
+ int n_tasks,
1585
+ void * userdata);
1586
+
1587
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3(
1588
+ struct wsp_ggml_context * ctx,
1589
+ struct wsp_ggml_tensor * a,
1590
+ struct wsp_ggml_tensor * b,
1591
+ struct wsp_ggml_tensor * c,
1592
+ wsp_ggml_custom3_op_t fun,
1593
+ int n_tasks,
1594
+ void * userdata);
1595
+
1596
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace(
1597
+ struct wsp_ggml_context * ctx,
1598
+ struct wsp_ggml_tensor * a,
1599
+ struct wsp_ggml_tensor * b,
1600
+ struct wsp_ggml_tensor * c,
1601
+ wsp_ggml_custom3_op_t fun,
1602
+ int n_tasks,
1603
+ void * userdata);
1273
1604
 
1274
1605
  // loss function
1275
1606
 
@@ -1290,15 +1621,29 @@ extern "C" {
1290
1621
 
1291
1622
  WSP_GGML_API void wsp_ggml_set_param(
1292
1623
  struct wsp_ggml_context * ctx,
1293
- struct wsp_ggml_tensor * tensor);
1624
+ struct wsp_ggml_tensor * tensor);
1625
+
1294
1626
 
1295
- WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
1627
+ WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
1628
+ WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
1296
1629
 
1297
1630
  WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor);
1298
1631
  WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep);
1299
1632
 
1300
- WSP_GGML_API void wsp_ggml_graph_compute(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
1301
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph);
1633
+ // graph allocation in a context
1634
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx);
1635
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
1636
+ WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
1637
+
1638
+ // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
1639
+ // when plan.work_size > 0, caller must allocate memory for plan.work_data
1640
+ WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
1641
+ WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
1642
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph);
1643
+
1644
+ // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
1645
+ // note: the drawback of this API is that you must ensure the context has enough memory for the work data
1646
+ WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
1302
1647
 
1303
1648
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
1304
1649
 
@@ -1345,6 +1690,8 @@ extern "C" {
1345
1690
  WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
1346
1691
  };
1347
1692
 
1693
+ typedef void (*wsp_ggml_opt_callback)(void * data, float * sched);
1694
+
1348
1695
  // optimization parameters
1349
1696
  //
1350
1697
  // see ggml.c (wsp_ggml_opt_default_params) for default values
@@ -1380,12 +1727,14 @@ extern "C" {
1380
1727
 
1381
1728
  float sched; // schedule multiplier (fixed, decay or warmup)
1382
1729
  float decay; // weight decay for AdamW, use 0.0f to disable
1730
+ int decay_min_ndim; // minimum number of tensor dimensions to apply weight decay
1383
1731
  float alpha; // learning rate
1384
1732
  float beta1;
1385
1733
  float beta2;
1386
1734
  float eps; // epsilon for numerical stability
1387
1735
  float eps_f; // epsilon for convergence test
1388
1736
  float eps_g; // epsilon for convergence test
1737
+ float gclip; // gradient clipping
1389
1738
  } adam;
1390
1739
 
1391
1740
  // LBFGS parameters
@@ -1413,14 +1762,12 @@ extern "C" {
1413
1762
 
1414
1763
  bool just_initialized;
1415
1764
 
1765
+ float loss_before;
1766
+ float loss_after;
1767
+
1416
1768
  struct {
1417
- struct wsp_ggml_tensor * x; // view of the parameters
1418
- struct wsp_ggml_tensor * g1; // gradient
1419
- struct wsp_ggml_tensor * g2; // gradient squared
1420
1769
  struct wsp_ggml_tensor * m; // first moment
1421
1770
  struct wsp_ggml_tensor * v; // second moment
1422
- struct wsp_ggml_tensor * mh; // first moment hat
1423
- struct wsp_ggml_tensor * vh; // second moment hat
1424
1771
  struct wsp_ggml_tensor * pf; // past function values
1425
1772
  float fx_best;
1426
1773
  float fx_prev;
@@ -1457,10 +1804,10 @@ extern "C" {
1457
1804
 
1458
1805
  // initialize optimizer context
1459
1806
  WSP_GGML_API void wsp_ggml_opt_init(
1460
- struct wsp_ggml_context * ctx,
1807
+ struct wsp_ggml_context * ctx,
1461
1808
  struct wsp_ggml_opt_context * opt,
1462
- struct wsp_ggml_opt_params params,
1463
- int64_t nx);
1809
+ struct wsp_ggml_opt_params params,
1810
+ int64_t nx);
1464
1811
 
1465
1812
  // continue optimizing the function defined by the tensor f
1466
1813
  WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume(
@@ -1474,7 +1821,9 @@ extern "C" {
1474
1821
  struct wsp_ggml_opt_context * opt,
1475
1822
  struct wsp_ggml_tensor * f,
1476
1823
  struct wsp_ggml_cgraph * gf,
1477
- struct wsp_ggml_cgraph * gb);
1824
+ struct wsp_ggml_cgraph * gb,
1825
+ wsp_ggml_opt_callback callback,
1826
+ void * callback_data);
1478
1827
 
1479
1828
  //
1480
1829
  // quantization
@@ -1488,6 +1837,127 @@ extern "C" {
1488
1837
 
1489
1838
  WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1490
1839
 
1840
+ //
1841
+ // gguf
1842
+ //
1843
+
1844
+ enum gguf_type {
1845
+ GGUF_TYPE_UINT8 = 0,
1846
+ GGUF_TYPE_INT8 = 1,
1847
+ GGUF_TYPE_UINT16 = 2,
1848
+ GGUF_TYPE_INT16 = 3,
1849
+ GGUF_TYPE_UINT32 = 4,
1850
+ GGUF_TYPE_INT32 = 5,
1851
+ GGUF_TYPE_FLOAT32 = 6,
1852
+ GGUF_TYPE_BOOL = 7,
1853
+ GGUF_TYPE_STRING = 8,
1854
+ GGUF_TYPE_ARRAY = 9,
1855
+ GGUF_TYPE_UINT64 = 10,
1856
+ GGUF_TYPE_INT64 = 11,
1857
+ GGUF_TYPE_FLOAT64 = 12,
1858
+ GGUF_TYPE_COUNT, // marks the end of the enum
1859
+ };
1860
+
1861
+ struct gguf_context;
1862
+
1863
+ struct gguf_init_params {
1864
+ bool no_alloc;
1865
+
1866
+ // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
1867
+ struct wsp_ggml_context ** ctx;
1868
+ };
1869
+
1870
+ WSP_GGML_API struct gguf_context * gguf_init_empty(void);
1871
+ WSP_GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
1872
+ //WSP_GGML_API struct gguf_context * gguf_init_from_buffer(..);
1873
+
1874
+ WSP_GGML_API void gguf_free(struct gguf_context * ctx);
1875
+
1876
+ WSP_GGML_API const char * gguf_type_name(enum gguf_type type);
1877
+
1878
+ WSP_GGML_API int gguf_get_version (const struct gguf_context * ctx);
1879
+ WSP_GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
1880
+ WSP_GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
1881
+ WSP_GGML_API void * gguf_get_data (const struct gguf_context * ctx);
1882
+
1883
+ WSP_GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
1884
+ WSP_GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
1885
+ WSP_GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
1886
+
1887
+ WSP_GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
1888
+ WSP_GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
1889
+
1890
+ // results are undefined if the wrong type is used for the key
1891
+ WSP_GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
1892
+ WSP_GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
1893
+ WSP_GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
1894
+ WSP_GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
1895
+ WSP_GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
1896
+ WSP_GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
1897
+ WSP_GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
1898
+ WSP_GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
1899
+ WSP_GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
1900
+ WSP_GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
1901
+ WSP_GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
1902
+ WSP_GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
1903
+ WSP_GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
1904
+ WSP_GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
1905
+ WSP_GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
1906
+
1907
+ WSP_GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
1908
+ WSP_GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
1909
+ WSP_GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
1910
+ WSP_GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
1911
+
1912
+ // overrides existing values or adds a new one
1913
+ WSP_GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
1914
+ WSP_GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
1915
+ WSP_GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
1916
+ WSP_GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
1917
+ WSP_GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
1918
+ WSP_GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
1919
+ WSP_GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
1920
+ WSP_GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
1921
+ WSP_GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
1922
+ WSP_GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
1923
+ WSP_GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
1924
+ WSP_GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
1925
+ WSP_GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
1926
+ WSP_GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
1927
+
1928
+ // set or add KV pairs from another context
1929
+ WSP_GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
1930
+
1931
+ // manage tensor info
1932
+ WSP_GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
1933
+ WSP_GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum wsp_ggml_type type);
1934
+ WSP_GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
1935
+
1936
+ // writing gguf files can be done in 2 ways:
1937
+ //
1938
+ // - write the entire gguf_context to a binary file in a single pass:
1939
+ //
1940
+ // gguf_write_to_file(ctx, fname, false); // only_meta = false: write meta data and tensor data
1941
+ //
1942
+ // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
1943
+ //
1944
+ // FILE * f = fopen(fname, "wb");
1945
+ // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
1946
+ // fwrite(..., f); // write the tensor data
1947
+ // void * data = malloc(gguf_get_meta_size(ctx)); gguf_get_meta_data(ctx, data);
1948
+ // fseek(f, 0, SEEK_SET);
1949
+ // fwrite(data, 1, gguf_get_meta_size(ctx), f);
1950
+ // free(data);
1951
+ // fclose(f);
1952
+ //
1953
+
1954
+ // write the entire context to a binary file
1955
+ WSP_GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
1956
+
1957
+ // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
1958
+ WSP_GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
1959
+ WSP_GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
1960
+
1491
1961
  //
1492
1962
  // system info
1493
1963
  //
@@ -1500,6 +1970,7 @@ extern "C" {
1500
1970
  WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
1501
1971
  WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
1502
1972
  WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
1973
+ WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
1503
1974
  WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
1504
1975
  WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
1505
1976
  WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
@@ -1516,25 +1987,28 @@ extern "C" {
1516
1987
  //
1517
1988
 
1518
1989
  #ifdef __cplusplus
1519
- // restrict not standard in C++
1990
+ // restrict not standard in C++
1520
1991
  #define WSP_GGML_RESTRICT
1521
1992
  #else
1522
1993
  #define WSP_GGML_RESTRICT restrict
1523
1994
  #endif
1524
- typedef void (*dequantize_row_q_t)(const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
1525
- typedef void (*quantize_row_q_t) (const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
1526
- typedef void (*vec_dot_q_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
1995
+ typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
1996
+ typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
1997
+ typedef void (*wsp_ggml_vec_dot_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
1527
1998
 
1528
1999
  typedef struct {
1529
- dequantize_row_q_t dequantize_row_q;
1530
- quantize_row_q_t quantize_row_q;
1531
- quantize_row_q_t quantize_row_q_reference;
1532
- quantize_row_q_t quantize_row_q_dot;
1533
- vec_dot_q_t vec_dot_q;
1534
- enum wsp_ggml_type vec_dot_type;
1535
- } quantize_fns_t;
1536
-
1537
- quantize_fns_t wsp_ggml_internal_get_quantize_fn(size_t i);
2000
+ const char * type_name;
2001
+ int blck_size;
2002
+ size_t type_size;
2003
+ bool is_quantized;
2004
+ wsp_ggml_to_float_t to_float;
2005
+ wsp_ggml_from_float_t from_float;
2006
+ wsp_ggml_from_float_t from_float_reference;
2007
+ wsp_ggml_vec_dot_t vec_dot;
2008
+ enum wsp_ggml_type vec_dot_type;
2009
+ } wsp_ggml_type_traits_t;
2010
+
2011
+ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
1538
2012
 
1539
2013
  #ifdef __cplusplus
1540
2014
  }