whisper.rn 0.4.0-rc.10 → 0.4.0-rc.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. package/android/src/main/CMakeLists.txt +9 -3
  2. package/cpp/amx/amx.cpp +220 -0
  3. package/cpp/amx/amx.h +8 -0
  4. package/cpp/amx/common.h +91 -0
  5. package/cpp/amx/mmq.cpp +2511 -0
  6. package/cpp/amx/mmq.h +10 -0
  7. package/cpp/ggml-alloc.c +6 -14
  8. package/cpp/ggml-backend-impl.h +50 -11
  9. package/cpp/ggml-backend-reg.cpp +409 -31
  10. package/cpp/ggml-backend.cpp +9 -3
  11. package/cpp/ggml-backend.h +18 -0
  12. package/cpp/ggml-common.h +41 -43
  13. package/cpp/ggml-cpp.h +1 -0
  14. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
  15. package/cpp/ggml-cpu-aarch64.h +2 -24
  16. package/cpp/ggml-cpu-impl.h +171 -11
  17. package/cpp/ggml-cpu-quants.c +1812 -389
  18. package/cpp/ggml-cpu-traits.cpp +36 -0
  19. package/cpp/ggml-cpu-traits.h +38 -0
  20. package/cpp/ggml-cpu.c +1432 -610
  21. package/cpp/ggml-cpu.cpp +131 -141
  22. package/cpp/ggml-cpu.h +10 -50
  23. package/cpp/ggml-impl.h +27 -11
  24. package/cpp/ggml-metal-impl.h +39 -0
  25. package/cpp/ggml-metal.h +1 -1
  26. package/cpp/ggml-metal.m +1031 -359
  27. package/cpp/ggml-opt.cpp +854 -0
  28. package/cpp/ggml-opt.h +216 -0
  29. package/cpp/ggml-quants.c +0 -9
  30. package/cpp/ggml-threading.h +4 -2
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +501 -1537
  33. package/cpp/ggml.h +144 -171
  34. package/cpp/gguf.cpp +1329 -0
  35. package/cpp/gguf.h +202 -0
  36. package/cpp/whisper.cpp +254 -114
  37. package/cpp/whisper.h +6 -3
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/version.json +1 -1
  40. package/package.json +2 -1
  41. package/src/version.json +1 -1
  42. package/whisper-rn.podspec +2 -2
  43. package/cpp/README.md +0 -4
  44. package/cpp/ggml-aarch64.c +0 -129
  45. package/cpp/ggml-aarch64.h +0 -19
  46. package/cpp/ggml-backend.cpp.rej +0 -12
package/cpp/ggml.c CHANGED
@@ -8,7 +8,10 @@
8
8
 
9
9
  // FIXME: required here for quantization functions
10
10
  #include "ggml-quants.h"
11
- #include "ggml-aarch64.h"
11
+
12
+ #ifdef WSP_GGML_USE_CPU_HBM
13
+ #include <hbwmalloc.h>
14
+ #endif
12
15
 
13
16
  #if defined(_MSC_VER) || defined(__MINGW32__)
14
17
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -125,6 +128,10 @@ static void wsp_ggml_print_backtrace_symbols(void) {
125
128
  #endif
126
129
 
127
130
  static void wsp_ggml_print_backtrace(void) {
131
+ const char * WSP_GGML_NO_BACKTRACE = getenv("WSP_GGML_NO_BACKTRACE");
132
+ if (WSP_GGML_NO_BACKTRACE) {
133
+ return;
134
+ }
128
135
  char attach[32];
129
136
  snprintf(attach, sizeof(attach), "attach %d", getpid());
130
137
  int pid = fork();
@@ -233,7 +240,11 @@ void wsp_ggml_log_callback_default(enum wsp_ggml_log_level level, const char * t
233
240
 
234
241
 
235
242
  void * wsp_ggml_aligned_malloc(size_t size) {
243
+ #if defined(__s390x__)
244
+ const int alignment = 256;
245
+ #else
236
246
  const int alignment = 64;
247
+ #endif
237
248
 
238
249
  #if defined(_MSC_VER) || defined(__MINGW32__)
239
250
  return _aligned_malloc(size, alignment);
@@ -788,32 +799,23 @@ static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
788
799
  .to_float = (wsp_ggml_to_float_t) wsp_ggml_bf16_to_fp32_row,
789
800
  .from_float_ref = (wsp_ggml_from_float_t) wsp_ggml_fp32_to_bf16_row_ref,
790
801
  },
791
- [WSP_GGML_TYPE_Q4_0_4_4] = {
792
- .type_name = "q4_0_4x4",
793
- .blck_size = QK4_0,
794
- .blck_size_interleave = 4,
795
- .type_size = sizeof(block_q4_0),
796
- .is_quantized = true,
797
- .to_float = NULL,
798
- .from_float_ref = NULL,
802
+ [31] = { // WSP_GGML_TYPE_Q4_0_4_4
803
+ .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
804
+ .blck_size = 0,
805
+ .type_size = 0,
806
+ .is_quantized = false,
799
807
  },
800
- [WSP_GGML_TYPE_Q4_0_4_8] = {
801
- .type_name = "q4_0_4x8",
802
- .blck_size = QK4_0,
803
- .blck_size_interleave = 8,
804
- .type_size = sizeof(block_q4_0),
805
- .is_quantized = true,
806
- .to_float = NULL,
807
- .from_float_ref = NULL,
808
+ [32] = { // WSP_GGML_TYPE_Q4_0_4_8
809
+ .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
810
+ .blck_size = 0,
811
+ .type_size = 0,
812
+ .is_quantized = false,
808
813
  },
809
- [WSP_GGML_TYPE_Q4_0_8_8] = {
810
- .type_name = "q4_0_8x8",
811
- .blck_size = QK4_0,
812
- .blck_size_interleave = 8,
813
- .type_size = sizeof(block_q4_0),
814
- .is_quantized = true,
815
- .to_float = NULL,
816
- .from_float_ref = NULL,
814
+ [33] = { // WSP_GGML_TYPE_Q4_0_8_8
815
+ .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
816
+ .blck_size = 0,
817
+ .type_size = 0,
818
+ .is_quantized = false,
817
819
  },
818
820
  [WSP_GGML_TYPE_TQ1_0] = {
819
821
  .type_name = "tq1_0",
@@ -831,6 +833,24 @@ static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
831
833
  .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_tq2_0,
832
834
  .from_float_ref = (wsp_ggml_from_float_t) wsp_quantize_row_tq2_0_ref,
833
835
  },
836
+ [36] = { // WSP_GGML_TYPE_IQ4_NL_4_4
837
+ .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
838
+ .blck_size = 0,
839
+ .type_size = 0,
840
+ .is_quantized = false,
841
+ },
842
+ [37] = { // WSP_GGML_TYPE_IQ4_NL_4_8
843
+ .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
844
+ .blck_size = 0,
845
+ .type_size = 0,
846
+ .is_quantized = false,
847
+ },
848
+ [38] = { // WSP_GGML_TYPE_IQ4_NL_8_8
849
+ .type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
850
+ .blck_size = 0,
851
+ .type_size = 0,
852
+ .is_quantized = false,
853
+ },
834
854
  };
835
855
 
836
856
  const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type) {
@@ -941,6 +961,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
941
961
  "POOL_2D_BACK",
942
962
  "UPSCALE",
943
963
  "PAD",
964
+ "PAD_REFLECT_1D",
944
965
  "ARANGE",
945
966
  "TIMESTEP_EMBEDDING",
946
967
  "ARGSORT",
@@ -955,6 +976,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
955
976
  "GET_REL_POS",
956
977
  "ADD_REL_POS",
957
978
  "RWKV_WKV6",
979
+ "GATED_LINEAR_ATTN",
958
980
 
959
981
  "UNARY",
960
982
 
@@ -974,7 +996,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
974
996
  "OPT_STEP_ADAMW",
975
997
  };
976
998
 
977
- static_assert(WSP_GGML_OP_COUNT == 81, "WSP_GGML_OP_COUNT != 81");
999
+ static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
978
1000
 
979
1001
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
980
1002
  "none",
@@ -1036,6 +1058,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1036
1058
  "pool_2d_back(x)",
1037
1059
  "upscale(x)",
1038
1060
  "pad(x)",
1061
+ "pad_reflect_1d(x)",
1039
1062
  "arange(start, stop, step)",
1040
1063
  "timestep_embedding(timesteps, dim, max_period)",
1041
1064
  "argsort(x)",
@@ -1050,6 +1073,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1050
1073
  "get_rel_pos(x)",
1051
1074
  "add_rel_pos(x)",
1052
1075
  "rwkv_wkv6(k, v, r, tf, td, s)",
1076
+ "gated_linear_attn(k, v, q, gate, s)",
1053
1077
 
1054
1078
  "unary(x)",
1055
1079
 
@@ -1069,7 +1093,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1069
1093
  "adamw(x)",
1070
1094
  };
1071
1095
 
1072
- static_assert(WSP_GGML_OP_COUNT == 81, "WSP_GGML_OP_COUNT != 81");
1096
+ static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
1073
1097
 
1074
1098
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1075
1099
 
@@ -1259,9 +1283,6 @@ enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) {
1259
1283
  case WSP_GGML_FTYPE_MOSTLY_IQ4_XS: wtype = WSP_GGML_TYPE_IQ4_XS; break;
1260
1284
  case WSP_GGML_FTYPE_MOSTLY_IQ3_S: wtype = WSP_GGML_TYPE_IQ3_S; break;
1261
1285
  case WSP_GGML_FTYPE_MOSTLY_IQ2_S: wtype = WSP_GGML_TYPE_IQ2_S; break;
1262
- case WSP_GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = WSP_GGML_TYPE_Q4_0_4_4; break;
1263
- case WSP_GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = WSP_GGML_TYPE_Q4_0_4_8; break;
1264
- case WSP_GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = WSP_GGML_TYPE_Q4_0_8_8; break;
1265
1286
  case WSP_GGML_FTYPE_UNKNOWN: wtype = WSP_GGML_TYPE_COUNT; break;
1266
1287
  case WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = WSP_GGML_TYPE_COUNT; break;
1267
1288
  }
@@ -1362,7 +1383,7 @@ bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct ws
1362
1383
  (t0->nb[3] == t1->nb[3]);
1363
1384
  }
1364
1385
 
1365
- // check if t1 can be represented as a repeatition of t0
1386
+ // check if t1 can be represented as a repetition of t0
1366
1387
  bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) {
1367
1388
  static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
1368
1389
 
@@ -1577,15 +1598,8 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl(
1577
1598
 
1578
1599
  struct wsp_ggml_tensor * const result = (struct wsp_ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
1579
1600
 
1580
- #ifdef __clang__
1581
- // temporary until wsp_ggml_tensor::backend is removed
1582
- #pragma clang diagnostic push
1583
- #pragma clang diagnostic ignored "-Wdeprecated-declarations"
1584
- #endif
1585
-
1586
1601
  *result = (struct wsp_ggml_tensor) {
1587
1602
  /*.type =*/ type,
1588
- /*.backend =*/ WSP_GGML_BACKEND_TYPE_CPU,
1589
1603
  /*.buffer =*/ NULL,
1590
1604
  /*.ne =*/ { 1, 1, 1, 1 },
1591
1605
  /*.nb =*/ { 0, 0, 0, 0 },
@@ -1601,10 +1615,6 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl(
1601
1615
  /*.padding =*/ { 0 },
1602
1616
  };
1603
1617
 
1604
- #ifdef __clang__
1605
- #pragma clang diagnostic pop
1606
- #endif
1607
-
1608
1618
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
1609
1619
  //WSP_GGML_ASSERT_ALIGNED(result->data);
1610
1620
 
@@ -2255,6 +2265,7 @@ struct wsp_ggml_tensor * wsp_ggml_argmax(
2255
2265
  struct wsp_ggml_context * ctx,
2256
2266
  struct wsp_ggml_tensor * a) {
2257
2267
  WSP_GGML_ASSERT(wsp_ggml_is_matrix(a));
2268
+ WSP_GGML_ASSERT(a->ne[0] <= INT32_MAX);
2258
2269
 
2259
2270
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, a->ne[1]);
2260
2271
 
@@ -3447,12 +3458,14 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
3447
3458
  return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3448
3459
  }
3449
3460
 
3450
- // wsp_ggml_soft_max_back
3461
+ // wsp_ggml_soft_max_ext_back
3451
3462
 
3452
- static struct wsp_ggml_tensor * wsp_ggml_soft_max_back_impl(
3463
+ static struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_impl(
3453
3464
  struct wsp_ggml_context * ctx,
3454
3465
  struct wsp_ggml_tensor * a,
3455
3466
  struct wsp_ggml_tensor * b,
3467
+ float scale,
3468
+ float max_bias,
3456
3469
  bool inplace) {
3457
3470
  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
3458
3471
 
@@ -3460,21 +3473,28 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_back_impl(
3460
3473
  result->src[0] = a;
3461
3474
  result->src[1] = b;
3462
3475
 
3476
+ memcpy((float *) result->op_params + 0, &scale, sizeof(float));
3477
+ memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
3478
+
3463
3479
  return result;
3464
3480
  }
3465
3481
 
3466
- struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
3482
+ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back(
3467
3483
  struct wsp_ggml_context * ctx,
3468
3484
  struct wsp_ggml_tensor * a,
3469
- struct wsp_ggml_tensor * b) {
3470
- return wsp_ggml_soft_max_back_impl(ctx, a, b, false);
3485
+ struct wsp_ggml_tensor * b,
3486
+ float scale,
3487
+ float max_bias) {
3488
+ return wsp_ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
3471
3489
  }
3472
3490
 
3473
- struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace(
3491
+ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_inplace(
3474
3492
  struct wsp_ggml_context * ctx,
3475
3493
  struct wsp_ggml_tensor * a,
3476
- struct wsp_ggml_tensor * b) {
3477
- return wsp_ggml_soft_max_back_impl(ctx, a, b, true);
3494
+ struct wsp_ggml_tensor * b,
3495
+ float scale,
3496
+ float max_bias) {
3497
+ return wsp_ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
3478
3498
  }
3479
3499
 
3480
3500
  // wsp_ggml_rope
@@ -3505,15 +3525,18 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
3505
3525
  WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
3506
3526
  }
3507
3527
 
3528
+ int sections[4] = {0, 0, 0, 0};
3529
+
3508
3530
  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
3509
3531
 
3510
- int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3532
+ int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3511
3533
  memcpy(params + 5, &freq_base, sizeof(float));
3512
3534
  memcpy(params + 6, &freq_scale, sizeof(float));
3513
3535
  memcpy(params + 7, &ext_factor, sizeof(float));
3514
3536
  memcpy(params + 8, &attn_factor, sizeof(float));
3515
3537
  memcpy(params + 9, &beta_fast, sizeof(float));
3516
3538
  memcpy(params + 10, &beta_slow, sizeof(float));
3539
+ memcpy(params + 11, &sections, sizeof(int)*4);
3517
3540
  wsp_ggml_set_op_params(result, params, sizeof(params));
3518
3541
 
3519
3542
  result->op = WSP_GGML_OP_ROPE;
@@ -3535,6 +3558,53 @@ struct wsp_ggml_tensor * wsp_ggml_rope(
3535
3558
  );
3536
3559
  }
3537
3560
 
3561
+ struct wsp_ggml_tensor * wsp_ggml_rope_multi(
3562
+ struct wsp_ggml_context * ctx,
3563
+ struct wsp_ggml_tensor * a,
3564
+ struct wsp_ggml_tensor * b,
3565
+ struct wsp_ggml_tensor * c,
3566
+ int n_dims,
3567
+ int sections[4],
3568
+ int mode,
3569
+ int n_ctx_orig,
3570
+ float freq_base,
3571
+ float freq_scale,
3572
+ float ext_factor,
3573
+ float attn_factor,
3574
+ float beta_fast,
3575
+ float beta_slow) {
3576
+ // Multimodal Rotary Position Embedding
3577
+ WSP_GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
3578
+
3579
+ WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
3580
+ WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
3581
+ WSP_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
3582
+
3583
+ if (c) {
3584
+ WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_F32);
3585
+ WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
3586
+ }
3587
+
3588
+ struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
3589
+
3590
+ int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3591
+ memcpy(params + 5, &freq_base, sizeof(float));
3592
+ memcpy(params + 6, &freq_scale, sizeof(float));
3593
+ memcpy(params + 7, &ext_factor, sizeof(float));
3594
+ memcpy(params + 8, &attn_factor, sizeof(float));
3595
+ memcpy(params + 9, &beta_fast, sizeof(float));
3596
+ memcpy(params + 10, &beta_slow, sizeof(float));
3597
+ memcpy(&params[11], sections, sizeof(int)*4);
3598
+ wsp_ggml_set_op_params(result, params, sizeof(params));
3599
+
3600
+ result->op = WSP_GGML_OP_ROPE;
3601
+ result->src[0] = a;
3602
+ result->src[1] = b;
3603
+ result->src[2] = c;
3604
+
3605
+ return result;
3606
+ }
3607
+
3538
3608
  struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
3539
3609
  struct wsp_ggml_context * ctx,
3540
3610
  struct wsp_ggml_tensor * a,
@@ -3642,7 +3712,7 @@ void wsp_ggml_rope_yarn_corr_dims(
3642
3712
 
3643
3713
  // wsp_ggml_rope_back
3644
3714
 
3645
- struct wsp_ggml_tensor * wsp_ggml_rope_back(
3715
+ struct wsp_ggml_tensor * wsp_ggml_rope_ext_back(
3646
3716
  struct wsp_ggml_context * ctx,
3647
3717
  struct wsp_ggml_tensor * a,
3648
3718
  struct wsp_ggml_tensor * b,
@@ -3656,29 +3726,32 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
3656
3726
  float attn_factor,
3657
3727
  float beta_fast,
3658
3728
  float beta_slow) {
3659
- WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
3660
- WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
3661
- WSP_GGML_ASSERT(a->ne[2] == b->ne[0]);
3662
-
3663
- struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
3664
-
3665
- int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3666
- memcpy(params + 5, &freq_base, sizeof(float));
3667
- memcpy(params + 6, &freq_scale, sizeof(float));
3668
- memcpy(params + 7, &ext_factor, sizeof(float));
3669
- memcpy(params + 8, &attn_factor, sizeof(float));
3670
- memcpy(params + 9, &beta_fast, sizeof(float));
3671
- memcpy(params + 10, &beta_slow, sizeof(float));
3672
- wsp_ggml_set_op_params(result, params, sizeof(params));
3673
-
3674
- result->op = WSP_GGML_OP_ROPE_BACK;
3675
- result->src[0] = a;
3676
- result->src[1] = b;
3677
- result->src[2] = c;
3678
-
3729
+ struct wsp_ggml_tensor * result = wsp_ggml_rope_ext(
3730
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3731
+ result->op = WSP_GGML_OP_ROPE_BACK;
3679
3732
  return result;
3680
3733
  }
3681
3734
 
3735
+ struct wsp_ggml_tensor * wsp_ggml_rope_multi_back(
3736
+ struct wsp_ggml_context * ctx,
3737
+ struct wsp_ggml_tensor * a,
3738
+ struct wsp_ggml_tensor * b,
3739
+ struct wsp_ggml_tensor * c,
3740
+ int n_dims,
3741
+ int sections[4],
3742
+ int mode,
3743
+ int n_ctx_orig,
3744
+ float freq_base,
3745
+ float freq_scale,
3746
+ float ext_factor,
3747
+ float attn_factor,
3748
+ float beta_fast,
3749
+ float beta_slow) {
3750
+ struct wsp_ggml_tensor * result = wsp_ggml_rope_multi(
3751
+ ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3752
+ result->op = WSP_GGML_OP_ROPE_BACK;
3753
+ return result;
3754
+ }
3682
3755
  // wsp_ggml_clamp
3683
3756
 
3684
3757
  struct wsp_ggml_tensor * wsp_ggml_clamp(
@@ -3698,104 +3771,10 @@ struct wsp_ggml_tensor * wsp_ggml_clamp(
3698
3771
  return result;
3699
3772
  }
3700
3773
 
3701
- // wsp_ggml_conv_1d
3702
-
3703
3774
  static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3704
3775
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
3705
3776
  }
3706
3777
 
3707
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
3708
- struct wsp_ggml_context * ctx,
3709
- struct wsp_ggml_tensor * a,
3710
- struct wsp_ggml_tensor * b,
3711
- int s0,
3712
- int p0,
3713
- int d0) {
3714
- struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, WSP_GGML_TYPE_F16); // [N, OL, IC * K]
3715
-
3716
- struct wsp_ggml_tensor * result =
3717
- wsp_ggml_mul_mat(ctx,
3718
- wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
3719
- wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
3720
-
3721
- result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
3722
-
3723
- return result;
3724
- }
3725
-
3726
- // wsp_ggml_conv_1d_ph
3727
-
3728
- struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
3729
- struct wsp_ggml_context * ctx,
3730
- struct wsp_ggml_tensor * a,
3731
- struct wsp_ggml_tensor * b,
3732
- int s,
3733
- int d) {
3734
- return wsp_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
3735
- }
3736
-
3737
- // wsp_ggml_conv_transpose_1d
3738
-
3739
- static int64_t wsp_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3740
- return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
3741
- }
3742
-
3743
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
3744
- struct wsp_ggml_context * ctx,
3745
- struct wsp_ggml_tensor * a,
3746
- struct wsp_ggml_tensor * b,
3747
- int s0,
3748
- int p0,
3749
- int d0) {
3750
- WSP_GGML_ASSERT(wsp_ggml_is_matrix(b));
3751
- WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
3752
- WSP_GGML_ASSERT(a->ne[3] == 1);
3753
-
3754
- WSP_GGML_ASSERT(p0 == 0);
3755
- WSP_GGML_ASSERT(d0 == 1);
3756
-
3757
- const int64_t ne[4] = {
3758
- wsp_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
3759
- a->ne[1], b->ne[2], 1,
3760
- };
3761
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
3762
-
3763
- int32_t params[] = { s0, p0, d0 };
3764
- wsp_ggml_set_op_params(result, params, sizeof(params));
3765
-
3766
- result->op = WSP_GGML_OP_CONV_TRANSPOSE_1D;
3767
- result->src[0] = a;
3768
- result->src[1] = b;
3769
-
3770
- return result;
3771
- }
3772
-
3773
- // wsp_ggml_conv_depthwise
3774
-
3775
- struct wsp_ggml_tensor * wsp_ggml_conv_depthwise_2d(
3776
- struct wsp_ggml_context * ctx,
3777
- struct wsp_ggml_tensor * a,
3778
- struct wsp_ggml_tensor * b,
3779
- int s0,
3780
- int s1,
3781
- int p0,
3782
- int p1,
3783
- int d0,
3784
- int d1) {
3785
- struct wsp_ggml_tensor * new_a = wsp_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
3786
- struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, new_a,
3787
- wsp_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
3788
- s0, s1, p0, p1, d0, d1, true, WSP_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
3789
- struct wsp_ggml_tensor * new_b = wsp_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
3790
-
3791
- new_a = wsp_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
3792
- struct wsp_ggml_tensor * result = wsp_ggml_mul_mat(ctx, new_a, new_b);
3793
- result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
3794
-
3795
- return result;
3796
- }
3797
- // wsp_ggml_conv_2d
3798
-
3799
3778
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
3800
3779
  // a: [OC,IC, KH, KW]
3801
3780
  // b: [N, IC, IH, IW]
@@ -3812,10 +3791,11 @@ struct wsp_ggml_tensor * wsp_ggml_im2col(
3812
3791
  int d1,
3813
3792
  bool is_2D,
3814
3793
  enum wsp_ggml_type dst_type) {
3815
- if(is_2D) {
3794
+ if (is_2D) {
3816
3795
  WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
3817
3796
  } else {
3818
- WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
3797
+ //WSP_GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
3798
+ WSP_GGML_ASSERT(b->ne[1] == a->ne[1]);
3819
3799
  WSP_GGML_ASSERT(b->ne[3] == 1);
3820
3800
  }
3821
3801
 
@@ -3866,51 +3846,178 @@ struct wsp_ggml_tensor * wsp_ggml_im2col_back(
3866
3846
  return result;
3867
3847
  }
3868
3848
 
3869
- // a: [OC,IC, KH, KW]
3870
- // b: [N, IC, IH, IW]
3871
- // result: [N, OC, OH, OW]
3872
- struct wsp_ggml_tensor * wsp_ggml_conv_2d(
3849
+ // wsp_ggml_conv_1d
3850
+
3851
+ struct wsp_ggml_tensor * wsp_ggml_conv_1d(
3873
3852
  struct wsp_ggml_context * ctx,
3874
3853
  struct wsp_ggml_tensor * a,
3875
3854
  struct wsp_ggml_tensor * b,
3876
3855
  int s0,
3877
- int s1,
3878
3856
  int p0,
3879
- int p1,
3880
- int d0,
3881
- int d1) {
3882
- struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
3857
+ int d0) {
3858
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, WSP_GGML_TYPE_F16); // [N, OL, IC * K]
3883
3859
 
3884
3860
  struct wsp_ggml_tensor * result =
3885
3861
  wsp_ggml_mul_mat(ctx,
3886
- wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
3887
- wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
3888
-
3889
- result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
3890
- result = wsp_ggml_cont(ctx, wsp_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
3862
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
3863
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
3891
3864
 
3865
+ result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
3892
3866
 
3893
3867
  return result;
3894
3868
  }
3895
3869
 
3896
- // wsp_ggml_conv_2d_sk_p0
3870
+ // wsp_ggml_conv_1d_ph
3897
3871
 
3898
- struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
3872
+ struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
3899
3873
  struct wsp_ggml_context * ctx,
3900
3874
  struct wsp_ggml_tensor * a,
3901
- struct wsp_ggml_tensor * b) {
3902
- return wsp_ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
3875
+ struct wsp_ggml_tensor * b,
3876
+ int s,
3877
+ int d) {
3878
+ return wsp_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
3903
3879
  }
3904
3880
 
3905
- // wsp_ggml_conv_2d_s1_ph
3881
+ // wsp_ggml_conv_1d_dw
3906
3882
 
3907
- struct wsp_ggml_tensor * wsp_ggml_conv_2d_s1_ph(
3883
+ struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
3908
3884
  struct wsp_ggml_context * ctx,
3909
3885
  struct wsp_ggml_tensor * a,
3910
- struct wsp_ggml_tensor * b) {
3911
- return wsp_ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
3912
- }
3913
-
3886
+ struct wsp_ggml_tensor * b,
3887
+ int s0,
3888
+ int p0,
3889
+ int d0) {
3890
+ struct wsp_ggml_tensor * new_a = wsp_ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
3891
+ struct wsp_ggml_tensor * new_b = wsp_ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
3892
+
3893
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, WSP_GGML_TYPE_F16);
3894
+
3895
+ struct wsp_ggml_tensor * result = wsp_ggml_mul_mat(ctx, im2col, a);
3896
+
3897
+ result = wsp_ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
3898
+
3899
+ return result;
3900
+ }
3901
+
3902
+ // wsp_ggml_conv_1d_dw_ph
3903
+
3904
+ struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw_ph(
3905
+ struct wsp_ggml_context * ctx,
3906
+ struct wsp_ggml_tensor * a,
3907
+ struct wsp_ggml_tensor * b,
3908
+ int s0,
3909
+ int d0) {
3910
+ return wsp_ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
3911
+ }
3912
+
3913
+ // wsp_ggml_conv_transpose_1d
3914
+
3915
+ static int64_t wsp_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3916
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
3917
+ }
3918
+
3919
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
3920
+ struct wsp_ggml_context * ctx,
3921
+ struct wsp_ggml_tensor * a,
3922
+ struct wsp_ggml_tensor * b,
3923
+ int s0,
3924
+ int p0,
3925
+ int d0) {
3926
+ WSP_GGML_ASSERT(wsp_ggml_is_matrix(b));
3927
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
3928
+ WSP_GGML_ASSERT(a->ne[3] == 1);
3929
+
3930
+ WSP_GGML_ASSERT(p0 == 0);
3931
+ WSP_GGML_ASSERT(d0 == 1);
3932
+
3933
+ const int64_t ne[4] = {
3934
+ wsp_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
3935
+ a->ne[1], b->ne[2], 1,
3936
+ };
3937
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
3938
+
3939
+ int32_t params[] = { s0, p0, d0 };
3940
+ wsp_ggml_set_op_params(result, params, sizeof(params));
3941
+
3942
+ result->op = WSP_GGML_OP_CONV_TRANSPOSE_1D;
3943
+ result->src[0] = a;
3944
+ result->src[1] = b;
3945
+
3946
+ return result;
3947
+ }
3948
+
3949
+ // wsp_ggml_conv_2d
3950
+
3951
+ // a: [OC,IC, KH, KW]
3952
+ // b: [N, IC, IH, IW]
3953
+ // result: [N, OC, OH, OW]
3954
+ struct wsp_ggml_tensor * wsp_ggml_conv_2d(
3955
+ struct wsp_ggml_context * ctx,
3956
+ struct wsp_ggml_tensor * a,
3957
+ struct wsp_ggml_tensor * b,
3958
+ int s0,
3959
+ int s1,
3960
+ int p0,
3961
+ int p1,
3962
+ int d0,
3963
+ int d1) {
3964
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
3965
+
3966
+ struct wsp_ggml_tensor * result =
3967
+ wsp_ggml_mul_mat(ctx,
3968
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
3969
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
3970
+
3971
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
3972
+ result = wsp_ggml_cont(ctx, wsp_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
3973
+
3974
+
3975
+ return result;
3976
+ }
3977
+
3978
+ // wsp_ggml_conv_2d_sk_p0
3979
+
3980
+ struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
3981
+ struct wsp_ggml_context * ctx,
3982
+ struct wsp_ggml_tensor * a,
3983
+ struct wsp_ggml_tensor * b) {
3984
+ return wsp_ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
3985
+ }
3986
+
3987
+ // wsp_ggml_conv_2d_s1_ph
3988
+
3989
+ struct wsp_ggml_tensor * wsp_ggml_conv_2d_s1_ph(
3990
+ struct wsp_ggml_context * ctx,
3991
+ struct wsp_ggml_tensor * a,
3992
+ struct wsp_ggml_tensor * b) {
3993
+ return wsp_ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
3994
+ }
3995
+
3996
+ // wsp_ggml_conv_2d_dw
3997
+
3998
+ struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw(
3999
+ struct wsp_ggml_context * ctx,
4000
+ struct wsp_ggml_tensor * a,
4001
+ struct wsp_ggml_tensor * b,
4002
+ int s0,
4003
+ int s1,
4004
+ int p0,
4005
+ int p1,
4006
+ int d0,
4007
+ int d1) {
4008
+ struct wsp_ggml_tensor * new_a = wsp_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
4009
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, new_a,
4010
+ wsp_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
4011
+ s0, s1, p0, p1, d0, d1, true, WSP_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
4012
+ struct wsp_ggml_tensor * new_b = wsp_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
4013
+
4014
+ new_a = wsp_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
4015
+ struct wsp_ggml_tensor * result = wsp_ggml_mul_mat(ctx, new_a, new_b);
4016
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
4017
+
4018
+ return result;
4019
+ }
4020
+
3914
4021
  // wsp_ggml_conv_transpose_2d_p0
3915
4022
 
3916
4023
  static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4087,6 +4194,37 @@ struct wsp_ggml_tensor * wsp_ggml_pad(
4087
4194
  return result;
4088
4195
  }
4089
4196
 
4197
+ // wsp_ggml_pad_reflect_1d
4198
+
4199
+ struct wsp_ggml_tensor * wsp_ggml_pad_reflect_1d(
4200
+ struct wsp_ggml_context * ctx,
4201
+ struct wsp_ggml_tensor * a,
4202
+ int p0,
4203
+ int p1) {
4204
+ WSP_GGML_ASSERT(p0 >= 0);
4205
+ WSP_GGML_ASSERT(p1 >= 0);
4206
+
4207
+ WSP_GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
4208
+ WSP_GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
4209
+
4210
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
4211
+ WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32);
4212
+
4213
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
4214
+ a->ne[0] + p0 + p1,
4215
+ a->ne[1],
4216
+ a->ne[2],
4217
+ a->ne[3]);
4218
+
4219
+ int32_t params[] = { p0, p1 };
4220
+ wsp_ggml_set_op_params(result, params, sizeof(params));
4221
+
4222
+ result->op = WSP_GGML_OP_PAD_REFLECT_1D;
4223
+ result->src[0] = a;
4224
+
4225
+ return result;
4226
+ }
4227
+
4090
4228
  // wsp_ggml_arange
4091
4229
 
4092
4230
  struct wsp_ggml_tensor * wsp_ggml_arange(
@@ -4138,6 +4276,7 @@ struct wsp_ggml_tensor * wsp_ggml_argsort(
4138
4276
  struct wsp_ggml_context * ctx,
4139
4277
  struct wsp_ggml_tensor * a,
4140
4278
  enum wsp_ggml_sort_order order) {
4279
+ WSP_GGML_ASSERT(a->ne[0] <= INT32_MAX);
4141
4280
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, WSP_GGML_MAX_DIMS, a->ne);
4142
4281
 
4143
4282
  wsp_ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -4512,15 +4651,13 @@ struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
4512
4651
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(state));
4513
4652
 
4514
4653
  const int64_t S = k->ne[0];
4515
- const int64_t H = k->ne[2];
4516
- const int64_t n_tokens = k->ne[3];
4654
+ const int64_t H = k->ne[1];
4655
+ const int64_t n_tokens = k->ne[2];
4517
4656
  const int64_t n_seqs = state->ne[1];
4518
4657
  {
4519
- WSP_GGML_ASSERT(k->ne[1] == 1);
4520
- WSP_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
4521
- WSP_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
4522
- // TODO: RWKV v4 and v5
4523
- WSP_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
4658
+ WSP_GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4659
+ WSP_GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
4660
+ WSP_GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
4524
4661
  WSP_GGML_ASSERT(wsp_ggml_nelements(state) == S * S * H * n_seqs);
4525
4662
  }
4526
4663
 
@@ -4539,6 +4676,49 @@ struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
4539
4676
  return result;
4540
4677
  }
4541
4678
 
4679
+ // wsp_ggml_gated_linear_attn
4680
+
4681
+ struct wsp_ggml_tensor * wsp_ggml_gated_linear_attn( // builds a WSP_GGML_OP_GATED_LINEAR_ATTN node from k, v, q, g and a state tensor
4682
+ struct wsp_ggml_context * ctx, // context the result tensor is allocated from
4683
+ struct wsp_ggml_tensor * k, // its ne[0], ne[1], ne[2] define S, H and n_tokens below
4684
+ struct wsp_ggml_tensor * v, // must match k in ne[0..2]
4685
+ struct wsp_ggml_tensor * q, // must match k in ne[0..2]
4686
+ struct wsp_ggml_tensor * g, // must match k in ne[0..2]
4687
+ struct wsp_ggml_tensor * state, // must hold exactly S * S * H * n_seqs elements
4688
+ float scale) { // stored as op param 0; how it is applied is up to the op kernel (not visible here)
4689
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(k)); // all five inputs must be contiguous
4690
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(v));
4691
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(q));
4692
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(g));
4693
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(state));
4694
+
4695
+ const int64_t S = k->ne[0]; // presumably the per-head size - confirm
4696
+ const int64_t H = k->ne[1]; // presumably the number of heads - confirm
4697
+ const int64_t n_tokens = k->ne[2];
4698
+ const int64_t n_seqs = state->ne[1];
4699
+ {
4700
+ WSP_GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4701
+ WSP_GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
4702
+ WSP_GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
4703
+ WSP_GGML_ASSERT(wsp_ggml_nelements(state) == S * S * H * n_seqs);
4704
+ }
4705
+
4706
+ // concat output and new_state
4707
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; // n_tokens rows plus S * n_seqs rows, each S * H wide
4708
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); // result is always F32
4709
+
4710
+ wsp_ggml_set_op_params_f32(result, 0, scale);
4711
+
4712
+ result->op = WSP_GGML_OP_GATED_LINEAR_ATTN;
4713
+ result->src[0] = k; // source order is k, v, q, g, state
4714
+ result->src[1] = v;
4715
+ result->src[2] = q;
4716
+ result->src[3] = g;
4717
+ result->src[4] = state;
4718
+
4719
+ return result;
4720
+ }
4721
+
4542
4722
  // wsp_ggml_unary
4543
4723
 
4544
4724
  static struct wsp_ggml_tensor * wsp_ggml_unary_impl(
@@ -4913,10 +5093,10 @@ struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
4913
5093
  struct wsp_ggml_tensor * a,
4914
5094
  struct wsp_ggml_tensor * b,
4915
5095
  struct wsp_ggml_tensor * c) {
4916
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b));
4917
- WSP_GGML_ASSERT(wsp_ggml_is_scalar(c));
5096
+ WSP_GGML_ASSERT(wsp_ggml_is_scalar(a));
5097
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(b, c));
4918
5098
 
4919
- struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
5099
+ struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, b);
4920
5100
 
4921
5101
  result->op = WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK;
4922
5102
  result->src[0] = a;
@@ -5019,8 +5199,10 @@ static void wsp_ggml_hash_map_free(struct hash_map * map) {
5019
5199
  }
5020
5200
 
5021
5201
  // utility functions to change gradients
5022
- // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
5023
- // else if a is in zero_table, replace a
5202
+ // isrc is the index of tensor in cgraph->visited_hash_set.keys
5203
+ // the corresponding gradient (accumulators) are also at position isrc
5204
+ // if tensor has a gradient accumulator, modify that accumulator in-place
5205
+ // else if there is no gradient for tensor, set the corresponding value
5024
5206
  // else, just add/subtract/etc. the gradients
5025
5207
 
5026
5208
  static void wsp_ggml_add_or_set(
@@ -5028,11 +5210,14 @@ static void wsp_ggml_add_or_set(
5028
5210
  struct wsp_ggml_cgraph * cgraph,
5029
5211
  size_t isrc,
5030
5212
  struct wsp_ggml_tensor * tensor) {
5213
+ struct wsp_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
5214
+ WSP_GGML_ASSERT(src);
5031
5215
  if (cgraph->grads[isrc]) {
5032
- cgraph->grads[isrc] = wsp_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5216
+ cgraph->grads[isrc] = wsp_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
5033
5217
  } else {
5034
5218
  cgraph->grads[isrc] = tensor;
5035
5219
  }
5220
+ wsp_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
5036
5221
  wsp_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5037
5222
  }
5038
5223
 
@@ -5040,18 +5225,20 @@ static void wsp_ggml_acc_or_set(
5040
5225
  struct wsp_ggml_context * ctx,
5041
5226
  struct wsp_ggml_cgraph * cgraph,
5042
5227
  size_t isrc,
5043
- struct wsp_ggml_tensor * src,
5044
5228
  struct wsp_ggml_tensor * tensor,
5045
5229
  const size_t nb1,
5046
5230
  const size_t nb2,
5047
5231
  const size_t nb3,
5048
5232
  const size_t offset) {
5233
+ struct wsp_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
5234
+ WSP_GGML_ASSERT(src);
5049
5235
  if (cgraph->grads[isrc]) {
5050
5236
  cgraph->grads[isrc] = wsp_ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
5051
5237
  } else {
5052
5238
  struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
5053
5239
  cgraph->grads[isrc] = wsp_ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
5054
5240
  }
5241
+ wsp_ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
5055
5242
  wsp_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5056
5243
  }
5057
5244
 
@@ -5059,13 +5246,15 @@ static void wsp_ggml_add1_or_set(
5059
5246
  struct wsp_ggml_context * ctx,
5060
5247
  struct wsp_ggml_cgraph * cgraph,
5061
5248
  size_t isrc,
5062
- struct wsp_ggml_tensor * src,
5063
5249
  struct wsp_ggml_tensor * tensor) {
5250
+ struct wsp_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
5251
+ WSP_GGML_ASSERT(src);
5064
5252
  if (cgraph->grads[isrc]) {
5065
5253
  cgraph->grads[isrc] = wsp_ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5066
5254
  } else {
5067
5255
  cgraph->grads[isrc] = wsp_ggml_repeat(ctx, tensor, src);
5068
5256
  }
5257
+ wsp_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
5069
5258
  wsp_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5070
5259
  }
5071
5260
 
@@ -5074,16 +5263,19 @@ static void wsp_ggml_sub_or_set(
5074
5263
  struct wsp_ggml_cgraph * cgraph,
5075
5264
  size_t isrc,
5076
5265
  struct wsp_ggml_tensor * tensor) {
5266
+ struct wsp_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
5267
+ WSP_GGML_ASSERT(src);
5077
5268
  if (cgraph->grads[isrc]) {
5078
5269
  cgraph->grads[isrc] = wsp_ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5079
5270
  } else {
5080
5271
  cgraph->grads[isrc] = wsp_ggml_neg(ctx, tensor);
5081
5272
  }
5273
+ wsp_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
5082
5274
  wsp_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5083
5275
  }
5084
5276
 
5085
5277
  static void wsp_ggml_compute_backward(
5086
- struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i, bool * grads_needed) {
5278
+ struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i, const bool * grads_needed) {
5087
5279
  struct wsp_ggml_tensor * tensor = cgraph->nodes[i];
5088
5280
  struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(cgraph, tensor);
5089
5281
 
@@ -5095,12 +5287,12 @@ static void wsp_ggml_compute_backward(
5095
5287
  struct wsp_ggml_tensor * src1 = tensor->src[1];
5096
5288
  struct wsp_ggml_tensor * src2 = tensor->src[2];
5097
5289
  struct wsp_ggml_hash_set * hash_set = &cgraph->visited_hash_set;
5098
- const size_t isrc0 = wsp_ggml_hash_find(hash_set, src0);
5099
- const size_t isrc1 = wsp_ggml_hash_find(hash_set, src1);
5100
- const size_t isrc2 = wsp_ggml_hash_find(hash_set, src2);
5101
- const bool src0_needs_grads = isrc0 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
5102
- const bool src1_needs_grads = isrc1 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
5103
- const bool src2_needs_grads = isrc2 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
5290
+ const size_t isrc0 = src0 ? wsp_ggml_hash_find(hash_set, src0) : (size_t) -1;
5291
+ const size_t isrc1 = src1 ? wsp_ggml_hash_find(hash_set, src1) : (size_t) -1;
5292
+ const size_t isrc2 = src2 ? wsp_ggml_hash_find(hash_set, src2) : (size_t) -1;
5293
+ const bool src0_needs_grads = src0 && isrc0 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
5294
+ const bool src1_needs_grads = src1 && isrc1 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
5295
+ const bool src2_needs_grads = src2 && isrc2 != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
5104
5296
 
5105
5297
  switch (tensor->op) {
5106
5298
  case WSP_GGML_OP_DUP: {
@@ -5155,7 +5347,7 @@ static void wsp_ggml_compute_backward(
5155
5347
  } break;
5156
5348
  case WSP_GGML_OP_MUL: {
5157
5349
  if (src0_needs_grads) {
5158
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_mul(ctx, src1, grad));
5350
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_mul(ctx, grad, src1));
5159
5351
  }
5160
5352
  if (src1_needs_grads) {
5161
5353
  struct wsp_ggml_tensor * tmp = wsp_ggml_mul(ctx, src0, grad);
@@ -5200,7 +5392,7 @@ static void wsp_ggml_compute_backward(
5200
5392
  } break;
5201
5393
  case WSP_GGML_OP_SUM: {
5202
5394
  if (src0_needs_grads) {
5203
- wsp_ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
5395
+ wsp_ggml_add1_or_set(ctx, cgraph, isrc0, grad);
5204
5396
  }
5205
5397
  } break;
5206
5398
  case WSP_GGML_OP_SUM_ROWS: {
@@ -5210,7 +5402,7 @@ static void wsp_ggml_compute_backward(
5210
5402
  } break;
5211
5403
  case WSP_GGML_OP_MEAN: {
5212
5404
  if (src0_needs_grads) {
5213
- wsp_ggml_add1_or_set(ctx, cgraph, isrc0, src0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5405
+ wsp_ggml_add1_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5214
5406
  }
5215
5407
  } break;
5216
5408
  case WSP_GGML_OP_REPEAT: {
@@ -5227,7 +5419,7 @@ static void wsp_ggml_compute_backward(
5227
5419
  if (src0_needs_grads) {
5228
5420
  float eps;
5229
5421
  memcpy(&eps, tensor->op_params, sizeof(float));
5230
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_rms_norm_back(ctx, src0, grad, eps));
5422
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_rms_norm_back(ctx, grad, src0, eps));
5231
5423
  }
5232
5424
  } break;
5233
5425
  case WSP_GGML_OP_MUL_MAT: {
@@ -5247,21 +5439,25 @@ static void wsp_ggml_compute_backward(
5247
5439
  // src1.shape [n,p,qq,rr]
5248
5440
 
5249
5441
  if (src0_needs_grads) {
5250
- struct wsp_ggml_tensor * s1_tg =
5442
+ WSP_GGML_ASSERT(grad->ne[2] == src1->ne[2]);
5443
+ WSP_GGML_ASSERT(grad->ne[3] == src1->ne[3]);
5444
+ struct wsp_ggml_tensor * tmp =
5251
5445
  wsp_ggml_out_prod(ctx, // [n,m,qq,rr]
5252
5446
  src1, // [n,p,qq,rr]
5253
5447
  grad); // [m,p,qq,rr]
5254
- const int64_t qq = s1_tg->ne[2];
5255
- const int64_t rr = s1_tg->ne[3];
5256
- const int64_t q1 = src0->ne[2];
5257
- const int64_t r1 = src0->ne[3];
5258
- const bool ne2_broadcasted = qq > q1;
5259
- const bool ne3_broadcasted = rr > r1;
5260
- if (ne2_broadcasted || ne3_broadcasted) {
5261
- // sum broadcast repetitions of s1_tg into shape of src0
5262
- s1_tg = wsp_ggml_repeat_back(ctx, s1_tg, src0);
5448
+ if (!wsp_ggml_are_same_shape(tmp, src0)) {
5449
+ WSP_GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
5450
+ WSP_GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
5451
+ WSP_GGML_ASSERT(tmp->ne[3] == 1);
5452
+
5453
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
5454
+ const size_t nb2 = tmp->nb[2] * nr2;
5455
+ const size_t nb3 = tmp->nb[2];
5456
+
5457
+ tmp = wsp_ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
5458
+ tmp = wsp_ggml_repeat_back(ctx, tmp, src0);
5263
5459
  }
5264
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
5460
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, tmp);
5265
5461
  }
5266
5462
  if (src1_needs_grads) {
5267
5463
  wsp_ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5330,7 +5526,9 @@ static void wsp_ggml_compute_backward(
5330
5526
  if (src0_needs_grads) {
5331
5527
  WSP_GGML_ASSERT(!cgraph->grads[isrc0] || wsp_ggml_is_contiguous(cgraph->grads[isrc0]));
5332
5528
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(grad));
5333
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5529
+ WSP_GGML_ASSERT(wsp_ggml_nelements(tensor) == wsp_ggml_nelements(src0));
5530
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0,
5531
+ wsp_ggml_are_same_shape(tensor, src0) ? grad : wsp_ggml_reshape(ctx, grad, src0));
5334
5532
  }
5335
5533
  } break;
5336
5534
  case WSP_GGML_OP_RESHAPE: {
@@ -5363,7 +5561,7 @@ static void wsp_ggml_compute_backward(
5363
5561
  nb3 = (nb3 / n0) * ng;
5364
5562
  }
5365
5563
 
5366
- wsp_ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
5564
+ wsp_ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
5367
5565
  }
5368
5566
  } break;
5369
5567
  case WSP_GGML_OP_PERMUTE: {
@@ -5410,7 +5608,13 @@ static void wsp_ggml_compute_backward(
5410
5608
  } break;
5411
5609
  case WSP_GGML_OP_SOFT_MAX: {
5412
5610
  if (src0_needs_grads) {
5413
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_soft_max_back(ctx, grad, tensor));
5611
+ float scale = 1.0f;
5612
+ float max_bias = 0.0f;
5613
+
5614
+ memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
5615
+ memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
5616
+
5617
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
5414
5618
  }
5415
5619
  WSP_GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5416
5620
  } break;
@@ -5422,6 +5626,7 @@ static void wsp_ggml_compute_backward(
5422
5626
  //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5423
5627
  const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5424
5628
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5629
+ int sections[4] = {0, 0, 0, 0};
5425
5630
 
5426
5631
  memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5427
5632
  memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@@ -5429,10 +5634,14 @@ static void wsp_ggml_compute_backward(
5429
5634
  memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5430
5635
  memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5431
5636
  memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5432
-
5433
- wsp_ggml_add_or_set(ctx, cgraph, isrc0,
5434
- wsp_ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
5435
- freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
5637
+ memcpy(&sections, tensor->op_params + 11, sizeof(sections));
5638
+
5639
+ struct wsp_ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
5640
+ wsp_ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
5641
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
5642
+ wsp_ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
5643
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
5644
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
5436
5645
  }
5437
5646
  WSP_GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5438
5647
  } break;
@@ -5446,7 +5655,7 @@ static void wsp_ggml_compute_backward(
5446
5655
  const int32_t d1 = wsp_ggml_get_op_params_i32(tensor, 5);
5447
5656
  const bool is_2D = wsp_ggml_get_op_params_i32(tensor, 6) == 1;
5448
5657
 
5449
- wsp_ggml_add_or_set(ctx, cgraph, isrc1, wsp_ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5658
+ wsp_ggml_add_or_set(ctx, cgraph, isrc1, wsp_ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5450
5659
  }
5451
5660
  } break;
5452
5661
  case WSP_GGML_OP_POOL_2D: {
@@ -5489,7 +5698,7 @@ static void wsp_ggml_compute_backward(
5489
5698
  } break;
5490
5699
  case WSP_GGML_UNARY_OP_SILU: {
5491
5700
  if (src0_needs_grads) {
5492
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_silu_back(ctx, src0, grad));
5701
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_silu_back(ctx, grad, src0));
5493
5702
  }
5494
5703
  } break;
5495
5704
  case WSP_GGML_UNARY_OP_EXP: {
@@ -5506,7 +5715,7 @@ static void wsp_ggml_compute_backward(
5506
5715
  } break;
5507
5716
  case WSP_GGML_OP_CROSS_ENTROPY_LOSS: {
5508
5717
  if (src0_needs_grads) {
5509
- wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
5718
+ wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
5510
5719
  }
5511
5720
  WSP_GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5512
5721
  } break;
@@ -5597,10 +5806,9 @@ void wsp_ggml_build_backward_expand(
5597
5806
 
5598
5807
  const int n_nodes_f = cgraph->n_nodes;
5599
5808
 
5600
- const size_t hash_size = wsp_ggml_hash_size(2*cgraph->size);
5601
- memset(cgraph->grads, 0, hash_size*sizeof(struct wsp_ggml_tensor *));
5602
- memset(cgraph->grad_accs, 0, hash_size*sizeof(struct wsp_ggml_tensor *));
5603
- bool * grads_needed = calloc(hash_size, sizeof(bool));
5809
+ memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct wsp_ggml_tensor *));
5810
+ memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct wsp_ggml_tensor *));
5811
+ bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
5604
5812
 
5605
5813
  {
5606
5814
  bool any_params = false;
@@ -5621,7 +5829,7 @@ void wsp_ggml_build_backward_expand(
5621
5829
  continue;
5622
5830
  }
5623
5831
 
5624
- bool node_needs_grad = node->flags & WSP_GGML_TENSOR_FLAG_PARAM;
5832
+ bool node_needs_grad = (node->flags & WSP_GGML_TENSOR_FLAG_PARAM) || (node->flags & WSP_GGML_TENSOR_FLAG_LOSS);
5625
5833
  bool ignore_src[WSP_GGML_MAX_SRC] = {false};
5626
5834
  switch (node->op) {
5627
5835
  // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -5638,7 +5846,7 @@ void wsp_ggml_build_backward_expand(
5638
5846
  } break;
5639
5847
 
5640
5848
  // gradients in node->src[1] for one reason or another have no effect on output gradients
5641
- case WSP_GGML_OP_CPY: // gradients in CPY target are irrelevant
5849
+ case WSP_GGML_OP_CPY: // gradients in CPY target are irrelevant
5642
5850
  case WSP_GGML_OP_GET_ROWS: // row indices not differentiable
5643
5851
  case WSP_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
5644
5852
  case WSP_GGML_OP_ROPE: // positions not differentiable
@@ -5665,9 +5873,12 @@ void wsp_ggml_build_backward_expand(
5665
5873
  node->op == WSP_GGML_OP_RESHAPE || node->op == WSP_GGML_OP_PERMUTE || node->op == WSP_GGML_OP_TRANSPOSE);
5666
5874
 
5667
5875
  const size_t igrad = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
5876
+ WSP_GGML_ASSERT(igrad != WSP_GGML_HASHSET_FULL);
5877
+ WSP_GGML_ASSERT(wsp_ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
5668
5878
  if ((accumulate && (node->flags & WSP_GGML_TENSOR_FLAG_PARAM)) || (node->flags & WSP_GGML_TENSOR_FLAG_LOSS)) {
5669
- cgraph->grads[igrad] = wsp_ggml_dup_tensor(ctx_static, node);
5670
- cgraph->grad_accs[igrad] = cgraph->grads[igrad];
5879
+ cgraph->grad_accs[igrad] = wsp_ggml_dup_tensor(ctx_static, node);
5880
+ cgraph->grads[igrad] = cgraph->grad_accs[igrad];
5881
+ wsp_ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
5671
5882
  }
5672
5883
  grads_needed[igrad] = true;
5673
5884
  }
@@ -5761,15 +5972,15 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) {
5761
5972
 
5762
5973
  struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
5763
5974
  struct wsp_ggml_cgraph cgraph = {
5764
- /*.size =*/ 0,
5765
- /*.n_nodes =*/ i1 - i0,
5766
- /*.n_leafs =*/ 0,
5767
- /*.nodes =*/ cgraph0->nodes + i0,
5768
- /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
5769
- /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
5770
- /*.leafs =*/ NULL,
5771
- /*.hash_table =*/ { 0, NULL, NULL },
5772
- /*.order =*/ cgraph0->order,
5975
+ /*.size =*/ 0,
5976
+ /*.n_nodes =*/ i1 - i0,
5977
+ /*.n_leafs =*/ 0,
5978
+ /*.nodes =*/ cgraph0->nodes + i0,
5979
+ /*.grads =*/ NULL, // gradients would need visited_hash_set
5980
+ /*.grad_accs =*/ NULL,
5981
+ /*.leafs =*/ NULL,
5982
+ /*.visited_hash_set =*/ { 0, NULL, NULL },
5983
+ /*.order =*/ cgraph0->order,
5773
5984
  };
5774
5985
 
5775
5986
  return cgraph;
@@ -5799,12 +6010,22 @@ void wsp_ggml_graph_cpy(struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * d
5799
6010
  }
5800
6011
  }
5801
6012
 
6013
+ if (dst->grads) {
6014
+ memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct wsp_ggml_tensor *));
6015
+ memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct wsp_ggml_tensor *));
6016
+ }
5802
6017
  if (src->grads) {
5803
6018
  WSP_GGML_ASSERT(dst->grads != NULL);
5804
6019
  WSP_GGML_ASSERT(dst->grad_accs != NULL);
5805
6020
  for (int i = 0; i < src->n_nodes; ++i) {
5806
6021
  const size_t igrad_src = wsp_ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
5807
6022
  const size_t igrad_dst = wsp_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
6023
+
6024
+ WSP_GGML_ASSERT(igrad_src != WSP_GGML_HASHSET_FULL);
6025
+ WSP_GGML_ASSERT(wsp_ggml_bitset_get(src->visited_hash_set.used, igrad_src));
6026
+ WSP_GGML_ASSERT(igrad_dst != WSP_GGML_HASHSET_FULL);
6027
+ WSP_GGML_ASSERT(wsp_ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
6028
+
5808
6029
  dst->grads[igrad_dst] = src->grads[igrad_src];
5809
6030
  dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
5810
6031
  }
@@ -5839,12 +6060,8 @@ void wsp_ggml_graph_reset(struct wsp_ggml_cgraph * cgraph) {
5839
6060
 
5840
6061
  if (node->op == WSP_GGML_OP_OPT_STEP_ADAMW) {
5841
6062
  // clear momenta
5842
- if (node->src[2]->data) {
5843
- wsp_ggml_set_zero(node->src[2]);
5844
- }
5845
- if (node->src[3]->data) {
5846
- wsp_ggml_set_zero(node->src[3]);
5847
- }
6063
+ wsp_ggml_set_zero(node->src[2]);
6064
+ wsp_ggml_set_zero(node->src[3]);
5848
6065
  }
5849
6066
 
5850
6067
  // initial gradients of loss should be 1, 0 otherwise
@@ -5923,12 +6140,12 @@ struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(const struct wsp_ggml_cgraph
5923
6140
 
5924
6141
  struct wsp_ggml_tensor * wsp_ggml_graph_get_grad(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) { // returns the gradient tensor recorded for `node`, or NULL
5925
6142
  const size_t igrad = wsp_ggml_hash_find(&cgraph->visited_hash_set, node); // slot of `node` in the graph's visited hash set
5926
- return igrad != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
6143
+ return igrad != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL; // NULL if the lookup fails, the slot is unused, or the graph has no grads array
5927
6144
  }
5928
6145
 
5929
6146
  struct wsp_ggml_tensor * wsp_ggml_graph_get_grad_acc(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) { // returns the gradient accumulator recorded for `node`, or NULL
5930
6147
  const size_t igrad = wsp_ggml_hash_find(&cgraph->visited_hash_set, node); // slot of `node` in the graph's visited hash set
5931
- return igrad != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
6148
+ return igrad != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL; // NULL if the lookup fails, the slot is unused, or the graph has no grad_accs array
5932
6149
  }
5933
6150
 
5934
6151
  void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph) {
@@ -6240,9 +6457,6 @@ size_t wsp_ggml_wsp_quantize_chunk(
6240
6457
  case WSP_GGML_TYPE_IQ1_M: result = wsp_quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6241
6458
  case WSP_GGML_TYPE_IQ4_NL: result = wsp_quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6242
6459
  case WSP_GGML_TYPE_IQ4_XS: result = wsp_quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6243
- case WSP_GGML_TYPE_Q4_0_4_4: result = wsp_quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6244
- case WSP_GGML_TYPE_Q4_0_4_8: result = wsp_quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6245
- case WSP_GGML_TYPE_Q4_0_8_8: result = wsp_quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6246
6460
  case WSP_GGML_TYPE_F16:
6247
6461
  {
6248
6462
  size_t elemsize = sizeof(wsp_ggml_fp16_t);
@@ -6272,1280 +6486,30 @@ size_t wsp_ggml_wsp_quantize_chunk(
6272
6486
 
6273
6487
  ////////////////////////////////////////////////////////////////////////////////
6274
6488
 
6275
- struct wsp_gguf_str {
6276
- uint64_t n; // GGUFv2
6277
- char * data;
6278
- };
6279
-
6280
- static const size_t WSP_GGUF_TYPE_SIZE[WSP_GGUF_TYPE_COUNT] = {
6281
- [WSP_GGUF_TYPE_UINT8] = sizeof(uint8_t),
6282
- [WSP_GGUF_TYPE_INT8] = sizeof(int8_t),
6283
- [WSP_GGUF_TYPE_UINT16] = sizeof(uint16_t),
6284
- [WSP_GGUF_TYPE_INT16] = sizeof(int16_t),
6285
- [WSP_GGUF_TYPE_UINT32] = sizeof(uint32_t),
6286
- [WSP_GGUF_TYPE_INT32] = sizeof(int32_t),
6287
- [WSP_GGUF_TYPE_FLOAT32] = sizeof(float),
6288
- [WSP_GGUF_TYPE_BOOL] = sizeof(bool),
6289
- [WSP_GGUF_TYPE_STRING] = sizeof(struct wsp_gguf_str),
6290
- [WSP_GGUF_TYPE_UINT64] = sizeof(uint64_t),
6291
- [WSP_GGUF_TYPE_INT64] = sizeof(int64_t),
6292
- [WSP_GGUF_TYPE_FLOAT64] = sizeof(double),
6293
- [WSP_GGUF_TYPE_ARRAY] = 0, // undefined
6294
- };
6295
- static_assert(WSP_GGUF_TYPE_COUNT == 13, "WSP_GGUF_TYPE_COUNT != 13");
6296
-
6297
- static const char * WSP_GGUF_TYPE_NAME[WSP_GGUF_TYPE_COUNT] = {
6298
- [WSP_GGUF_TYPE_UINT8] = "u8",
6299
- [WSP_GGUF_TYPE_INT8] = "i8",
6300
- [WSP_GGUF_TYPE_UINT16] = "u16",
6301
- [WSP_GGUF_TYPE_INT16] = "i16",
6302
- [WSP_GGUF_TYPE_UINT32] = "u32",
6303
- [WSP_GGUF_TYPE_INT32] = "i32",
6304
- [WSP_GGUF_TYPE_FLOAT32] = "f32",
6305
- [WSP_GGUF_TYPE_BOOL] = "bool",
6306
- [WSP_GGUF_TYPE_STRING] = "str",
6307
- [WSP_GGUF_TYPE_ARRAY] = "arr",
6308
- [WSP_GGUF_TYPE_UINT64] = "u64",
6309
- [WSP_GGUF_TYPE_INT64] = "i64",
6310
- [WSP_GGUF_TYPE_FLOAT64] = "f64",
6311
- };
6312
- static_assert(WSP_GGUF_TYPE_COUNT == 13, "WSP_GGUF_TYPE_COUNT != 13");
6313
-
6314
- union wsp_gguf_value {
6315
- uint8_t uint8;
6316
- int8_t int8;
6317
- uint16_t uint16;
6318
- int16_t int16;
6319
- uint32_t uint32;
6320
- int32_t int32;
6321
- float float32;
6322
- uint64_t uint64;
6323
- int64_t int64;
6324
- double float64;
6325
- bool bool_;
6326
-
6327
- struct wsp_gguf_str str;
6328
-
6329
- struct {
6330
- enum wsp_gguf_type type;
6331
-
6332
- uint64_t n; // GGUFv2
6333
- void * data;
6334
- } arr;
6335
- };
6336
-
6337
- struct wsp_gguf_kv {
6338
- struct wsp_gguf_str key;
6339
-
6340
- enum wsp_gguf_type type;
6341
- union wsp_gguf_value value;
6342
- };
6343
-
6344
- struct wsp_gguf_header {
6345
- char magic[4];
6346
-
6347
- uint32_t version;
6348
- uint64_t n_tensors; // GGUFv2
6349
- uint64_t n_kv; // GGUFv2
6350
- };
6351
-
6352
- struct wsp_gguf_tensor_info {
6353
- struct wsp_gguf_str name;
6354
-
6355
- uint32_t n_dims;
6356
- uint64_t ne[WSP_GGML_MAX_DIMS];
6357
-
6358
- enum wsp_ggml_type type;
6359
-
6360
- uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
6361
-
6362
- // for writing API
6363
- const void * data;
6364
- size_t size;
6365
- };
6366
-
6367
- struct wsp_gguf_context {
6368
- struct wsp_gguf_header header;
6369
-
6370
- struct wsp_gguf_kv * kv;
6371
- struct wsp_gguf_tensor_info * infos;
6372
-
6373
- size_t alignment;
6374
- size_t offset; // offset of `data` from beginning of file
6375
- size_t size; // size of `data` in bytes
6376
-
6377
- //uint8_t * padding;
6378
- void * data;
6379
- };
6380
-
6381
- static size_t wsp_gguf_type_size(enum wsp_gguf_type type) {
6382
- WSP_GGML_ASSERT(0 <= type && type < WSP_GGUF_TYPE_COUNT);
6383
- return WSP_GGUF_TYPE_SIZE[type];
6384
- }
6385
-
6386
- static bool wsp_gguf_tensor_info_sanitize(struct wsp_gguf_tensor_info * info) {
6387
- if (info->n_dims > WSP_GGML_MAX_DIMS) {
6388
- fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
6389
- return false;
6390
- }
6391
-
6392
- if (info->type < 0 || info->type >= WSP_GGML_TYPE_COUNT) {
6393
- fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
6394
- return false;
6395
- }
6396
-
6397
- if (strlen(info->name.data) >= WSP_GGML_MAX_NAME) {
6398
- fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
6399
- return false;
6400
- }
6401
-
6402
- for (uint32_t i = 0; i < info->n_dims; ++i) {
6403
- if (info->ne[i] <= 0) {
6404
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
6405
- return false;
6406
- }
6407
- }
6408
-
6409
- // prevent overflow for total number of elements
6410
- if (INT64_MAX/info->ne[1] <= info->ne[0]) {
6411
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
6412
- return false;
6413
- }
6414
-
6415
- if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
6416
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
6417
- return false;
6418
- }
6419
-
6420
- if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
6421
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
6422
- return false;
6423
- }
6424
-
6425
- return true;
6489
+ void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data) { // installs the log callback and its user data in g_logger_state
6490
+ g_logger_state.log_callback = log_callback ? log_callback : wsp_ggml_log_callback_default; // a NULL callback selects wsp_ggml_log_callback_default
6491
+ g_logger_state.log_callback_user_data = user_data; // stored as given, even when the default callback was selected
6426
6492
  }
6427
6493
 
6428
- static bool wsp_gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
6429
- const size_t n = fread(dst, 1, size, file);
6430
- *offset += n;
6431
- return n == size;
6494
+ void wsp_ggml_threadpool_params_init(struct wsp_ggml_threadpool_params * p, int n_threads) { // fills *p with default settings for the given thread count; p must be non-NULL (dereferenced unchecked)
6495
+ p->n_threads = n_threads; // the only field taken from the caller
6496
+ p->prio = 0; // default priority (usually means normal or inherited)
6497
+ p->poll = 50; // hybrid-polling enabled
6498
+ p->strict_cpu = false; // no strict placement (all threads share same cpumask)
6499
+ p->paused = false; // threads are ready to go
6500
+ memset(p->cpumask, 0, WSP_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited); NOTE(review): byte count assumes 1-byte cpumask elements - confirm
6432
6501
  }
6433
6502
 
6434
- static bool wsp_gguf_fread_str(FILE * file, struct wsp_gguf_str * p, size_t * offset) {
6435
- p->n = 0;
6436
- p->data = NULL;
6437
-
6438
- bool ok = true;
6439
-
6440
- ok = ok && wsp_gguf_fread_el(file, &p->n, sizeof(p->n), offset);
6441
-
6442
- // early exit if string length is invalid, prevents from integer overflow
6443
- if (p->n == SIZE_MAX) {
6444
- fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
6445
- return false;
6446
- }
6447
-
6448
- p->data = calloc(p->n + 1, 1);
6449
- if (!p->data) {
6450
- fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
6451
- return false;
6452
- }
6453
-
6454
- ok = ok && wsp_gguf_fread_el(file, p->data, p->n, offset);
6455
-
6456
- return ok;
6503
+ struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads) {
6504
+ struct wsp_ggml_threadpool_params p;
6505
+ wsp_ggml_threadpool_params_init(&p, n_threads);
6506
+ return p;
6457
6507
  }
6458
6508
 
6459
- static void wsp_gguf_free_kv(struct wsp_gguf_kv * kv) {
6460
- if (kv->key.data) {
6461
- WSP_GGML_FREE(kv->key.data);
6462
- }
6463
-
6464
- if (kv->type == WSP_GGUF_TYPE_STRING) {
6465
- if (kv->value.str.data) {
6466
- WSP_GGML_FREE(kv->value.str.data);
6467
- }
6468
- }
6469
-
6470
- if (kv->type == WSP_GGUF_TYPE_ARRAY) {
6471
- if (kv->value.arr.data) {
6472
- if (kv->value.arr.type == WSP_GGUF_TYPE_STRING) {
6473
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
6474
- struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[j];
6475
- if (str->data) {
6476
- WSP_GGML_FREE(str->data);
6477
- }
6478
- }
6479
- }
6480
- WSP_GGML_FREE(kv->value.arr.data);
6481
- }
6482
- }
6483
- }
6484
-
6485
- struct wsp_gguf_context * wsp_gguf_init_empty(void) {
6486
- struct wsp_gguf_context * ctx = calloc(1, sizeof(struct wsp_gguf_context));
6487
- if (!ctx) {
6488
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
6489
- return NULL;
6490
- }
6491
-
6492
- memcpy(ctx->header.magic, WSP_GGUF_MAGIC, sizeof(ctx->header.magic));
6493
- ctx->header.version = WSP_GGUF_VERSION;
6494
- ctx->header.n_tensors = 0;
6495
- ctx->header.n_kv = 0;
6496
-
6497
- ctx->kv = NULL;
6498
- ctx->infos = NULL;
6499
-
6500
- ctx->alignment = WSP_GGUF_DEFAULT_ALIGNMENT;
6501
- ctx->offset = 0;
6502
- ctx->size = 0;
6503
-
6504
- ctx->data = NULL;
6505
-
6506
- return ctx;
6507
- }
6508
-
6509
- struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params) {
6510
- FILE * file = wsp_ggml_fopen(fname, "rb");
6511
- if (!file) {
6512
- fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
6513
- return NULL;
6514
- }
6515
-
6516
- // offset from start of file
6517
- size_t offset = 0;
6518
-
6519
- char magic[4];
6520
-
6521
- // check the magic before making allocations
6522
- {
6523
- wsp_gguf_fread_el(file, &magic, sizeof(magic), &offset);
6524
-
6525
- for (uint32_t i = 0; i < sizeof(magic); i++) {
6526
- if (magic[i] != WSP_GGUF_MAGIC[i]) {
6527
- fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
6528
- fclose(file);
6529
- return NULL;
6530
- }
6531
- }
6532
- }
6533
-
6534
- bool ok = true;
6535
-
6536
- struct wsp_gguf_context * ctx = calloc(1, sizeof(struct wsp_gguf_context));
6537
- if (!ctx) {
6538
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
6539
- fclose(file);
6540
- return NULL;
6541
- }
6542
-
6543
- // read the header
6544
- {
6545
- strncpy(ctx->header.magic, magic, 4);
6546
-
6547
- ctx->kv = NULL;
6548
- ctx->infos = NULL;
6549
- ctx->data = NULL;
6550
-
6551
- ok = ok && wsp_gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
6552
- ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
6553
- ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
6554
-
6555
- if (ctx->header.version == 1) {
6556
- fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
6557
- fclose(file);
6558
- wsp_gguf_free(ctx);
6559
- return NULL;
6560
- }
6561
-
6562
- // sanity-checks to prevent from integer/buffer overflows
6563
-
6564
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct wsp_gguf_tensor_info));
6565
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/wsp_ggml_tensor_overhead());
6566
- ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct wsp_gguf_kv));
6567
-
6568
- if (!ok) {
6569
- fprintf(stderr, "%s: failed to read header\n", __func__);
6570
- fclose(file);
6571
- wsp_gguf_free(ctx);
6572
- return NULL;
6573
- }
6574
- }
6575
-
6576
- // read the kv pairs
6577
- {
6578
- const uint64_t n_kv = ctx->header.n_kv;
6579
-
6580
- ctx->kv = calloc(n_kv, sizeof(struct wsp_gguf_kv));
6581
- if (!ctx->kv) {
6582
- fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
6583
- fclose(file);
6584
- wsp_gguf_free(ctx);
6585
- return NULL;
6586
- }
6587
-
6588
- for (uint64_t i = 0; i < n_kv; ++i) {
6589
- struct wsp_gguf_kv * kv = &ctx->kv[i];
6590
-
6591
- //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
6592
-
6593
- ok = ok && wsp_gguf_fread_str(file, &kv->key, &offset);
6594
- ok = ok && wsp_gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
6595
-
6596
- //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
6597
-
6598
- switch (kv->type) {
6599
- case WSP_GGUF_TYPE_UINT8: ok = ok && wsp_gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break;
6600
- case WSP_GGUF_TYPE_INT8: ok = ok && wsp_gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break;
6601
- case WSP_GGUF_TYPE_UINT16: ok = ok && wsp_gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break;
6602
- case WSP_GGUF_TYPE_INT16: ok = ok && wsp_gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break;
6603
- case WSP_GGUF_TYPE_UINT32: ok = ok && wsp_gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
6604
- case WSP_GGUF_TYPE_INT32: ok = ok && wsp_gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
6605
- case WSP_GGUF_TYPE_FLOAT32: ok = ok && wsp_gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
6606
- case WSP_GGUF_TYPE_UINT64: ok = ok && wsp_gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
6607
- case WSP_GGUF_TYPE_INT64: ok = ok && wsp_gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
6608
- case WSP_GGUF_TYPE_FLOAT64: ok = ok && wsp_gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
6609
- case WSP_GGUF_TYPE_BOOL: ok = ok && wsp_gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
6610
- case WSP_GGUF_TYPE_STRING: ok = ok && wsp_gguf_fread_str(file, &kv->value.str, &offset); break;
6611
- case WSP_GGUF_TYPE_ARRAY:
6612
- {
6613
- ok = ok && wsp_gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
6614
- ok = ok && wsp_gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
6615
-
6616
- switch (kv->value.arr.type) {
6617
- case WSP_GGUF_TYPE_UINT8:
6618
- case WSP_GGUF_TYPE_INT8:
6619
- case WSP_GGUF_TYPE_UINT16:
6620
- case WSP_GGUF_TYPE_INT16:
6621
- case WSP_GGUF_TYPE_UINT32:
6622
- case WSP_GGUF_TYPE_INT32:
6623
- case WSP_GGUF_TYPE_FLOAT32:
6624
- case WSP_GGUF_TYPE_UINT64:
6625
- case WSP_GGUF_TYPE_INT64:
6626
- case WSP_GGUF_TYPE_FLOAT64:
6627
- case WSP_GGUF_TYPE_BOOL:
6628
- {
6629
- // prevent from integer overflow in the malloc below
6630
- if (kv->value.arr.n >= SIZE_MAX/wsp_gguf_type_size(kv->value.arr.type)) {
6631
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
6632
- fclose(file);
6633
- wsp_gguf_free(ctx);
6634
- return NULL;
6635
- }
6636
-
6637
- kv->value.arr.data = calloc(kv->value.arr.n, wsp_gguf_type_size(kv->value.arr.type));
6638
- if (!kv->value.arr.data) {
6639
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
6640
- fclose(file);
6641
- wsp_gguf_free(ctx);
6642
- return NULL;
6643
- }
6644
-
6645
- ok = ok && wsp_gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * wsp_gguf_type_size(kv->value.arr.type), &offset);
6646
- } break;
6647
- case WSP_GGUF_TYPE_STRING:
6648
- {
6649
- // prevent from integer overflow in the malloc below
6650
- if (kv->value.arr.n >= SIZE_MAX/sizeof(struct wsp_gguf_str)) {
6651
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
6652
- fclose(file);
6653
- wsp_gguf_free(ctx);
6654
- return NULL;
6655
- }
6656
-
6657
- kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct wsp_gguf_str));
6658
- if (!kv->value.arr.data) {
6659
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
6660
- fclose(file);
6661
- wsp_gguf_free(ctx);
6662
- return NULL;
6663
- }
6664
-
6665
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
6666
- ok = ok && wsp_gguf_fread_str(file, &((struct wsp_gguf_str *) kv->value.arr.data)[j], &offset);
6667
- }
6668
- } break;
6669
- case WSP_GGUF_TYPE_ARRAY:
6670
- default:
6671
- {
6672
- fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
6673
- ok = false;
6674
- } break;
6675
- }
6676
- } break;
6677
- default:
6678
- {
6679
- fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
6680
- ok = false;
6681
- } break;
6682
- }
6683
-
6684
- if (!ok) {
6685
- break;
6686
- }
6687
- }
6688
-
6689
- if (!ok) {
6690
- fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
6691
- fclose(file);
6692
- wsp_gguf_free(ctx);
6693
- return NULL;
6694
- }
6695
- }
6696
-
6697
- // read the tensor infos
6698
- if (ctx->header.n_tensors > 0) {
6699
- ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct wsp_gguf_tensor_info));
6700
- if (!ctx->infos) {
6701
- fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
6702
- fclose(file);
6703
- wsp_gguf_free(ctx);
6704
- return NULL;
6705
- }
6706
-
6707
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6708
- struct wsp_gguf_tensor_info * info = &ctx->infos[i];
6709
-
6710
- for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
6711
- info->ne[j] = 1;
6712
- }
6713
-
6714
- ok = ok && wsp_gguf_fread_str(file, &info->name, &offset);
6715
- ok = ok && wsp_gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
6716
-
6717
- ok = ok && (info->n_dims <= WSP_GGML_MAX_DIMS);
6718
-
6719
- for (uint32_t j = 0; j < info->n_dims; ++j) {
6720
- ok = ok && wsp_gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
6721
- }
6722
-
6723
- ok = ok && wsp_gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
6724
- ok = ok && wsp_gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
6725
-
6726
- ok = ok && wsp_gguf_tensor_info_sanitize(info);
6727
-
6728
- // make sure there is no duplicated tensor names
6729
- for (uint64_t j = 0; j < i && ok; ++j) {
6730
- if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
6731
- fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
6732
- ok = false;
6733
- }
6734
- }
6735
-
6736
- if (!ok) {
6737
- fprintf(stderr, "%s: failed to read tensor info\n", __func__);
6738
- fclose(file);
6739
- wsp_gguf_free(ctx);
6740
- return NULL;
6741
- }
6742
- }
6743
- }
6744
-
6745
- ctx->alignment = WSP_GGUF_DEFAULT_ALIGNMENT;
6746
-
6747
- int alignment_idx = wsp_gguf_find_key(ctx, "general.alignment");
6748
- if (alignment_idx != -1) {
6749
- ctx->alignment = wsp_gguf_get_val_u32(ctx, alignment_idx);
6750
- }
6751
-
6752
- // we require the data section to be aligned, so take into account any padding
6753
- {
6754
- const size_t offset_pad = offset % ctx->alignment;
6755
-
6756
- if (offset_pad != 0) {
6757
- offset += ctx->alignment - offset_pad;
6758
- fseek(file, offset, SEEK_SET);
6759
- }
6760
- }
6761
-
6762
- // store the current file offset - this is where the data section starts
6763
- ctx->offset = offset;
6764
-
6765
- // compute the total size of the data section, taking into account the alignment
6766
- {
6767
- ctx->size = 0;
6768
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6769
- struct wsp_gguf_tensor_info * info = &ctx->infos[i];
6770
-
6771
- const int64_t ne =
6772
- (int64_t) info->ne[0] *
6773
- (int64_t) info->ne[1] *
6774
- (int64_t) info->ne[2] *
6775
- (int64_t) info->ne[3];
6776
-
6777
- if (wsp_ggml_blck_size(info->type) == 0 || ne % wsp_ggml_blck_size(info->type) != 0) {
6778
- fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
6779
- __func__, info->name.data, (int) info->type, wsp_ggml_type_name(info->type), ne, wsp_ggml_blck_size(info->type));
6780
- fclose(file);
6781
- wsp_gguf_free(ctx);
6782
- return NULL;
6783
- }
6784
-
6785
- const size_t size_cur = wsp_ggml_row_size(info->type, ne);
6786
-
6787
- ctx->size += WSP_GGML_PAD(size_cur, ctx->alignment);
6788
- }
6789
- }
6790
-
6791
- // load the tensor data only if requested
6792
- if (params.ctx != NULL) {
6793
- // if the provided wsp_gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
6794
- // otherwise, we load the binary blob into the created wsp_ggml_context as well, and point the "data" members of
6795
- // the wsp_ggml_tensor structs to the appropriate locations in the binary blob
6796
-
6797
- // compute the exact size needed for the new wsp_ggml_context
6798
- const size_t mem_size =
6799
- params.no_alloc ?
6800
- (ctx->header.n_tensors )*wsp_ggml_tensor_overhead() :
6801
- (ctx->header.n_tensors + 1)*wsp_ggml_tensor_overhead() + ctx->size;
6802
-
6803
- struct wsp_ggml_init_params pdata = {
6804
- .mem_size = mem_size,
6805
- .mem_buffer = NULL,
6806
- .no_alloc = params.no_alloc,
6807
- };
6808
-
6809
- *params.ctx = wsp_ggml_init(pdata);
6810
- if (*params.ctx == NULL) {
6811
- fprintf(stderr, "%s: failed to initialize context\n", __func__);
6812
- fclose(file);
6813
- wsp_gguf_free(ctx);
6814
- return NULL;
6815
- }
6816
-
6817
- struct wsp_ggml_context * ctx_data = *params.ctx;
6818
-
6819
- struct wsp_ggml_tensor * data = NULL;
6820
-
6821
- if (!params.no_alloc) {
6822
- data = wsp_ggml_new_tensor_1d(ctx_data, WSP_GGML_TYPE_I8, ctx->size);
6823
-
6824
- ok = ok && data != NULL;
6825
-
6826
- // read the binary blob with the tensor data
6827
- ok = ok && wsp_gguf_fread_el(file, data->data, ctx->size, &offset);
6828
-
6829
- if (!ok) {
6830
- fprintf(stderr, "%s: failed to read tensor data\n", __func__);
6831
- fclose(file);
6832
- wsp_ggml_free(ctx_data);
6833
- wsp_gguf_free(ctx);
6834
- return NULL;
6835
- }
6836
-
6837
- ctx->data = data->data;
6838
- }
6839
-
6840
- wsp_ggml_set_no_alloc(ctx_data, true);
6841
-
6842
- // create the tensors
6843
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6844
- const int64_t ne[WSP_GGML_MAX_DIMS] = {
6845
- ctx->infos[i].ne[0],
6846
- ctx->infos[i].ne[1],
6847
- ctx->infos[i].ne[2],
6848
- ctx->infos[i].ne[3],
6849
- };
6850
-
6851
- struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
6852
-
6853
- ok = ok && cur != NULL;
6854
-
6855
- if (!ok) {
6856
- break;
6857
- }
6858
-
6859
- wsp_ggml_set_name(cur, ctx->infos[i].name.data);
6860
-
6861
- // point the data member to the appropriate location in the binary blob using the tensor infos
6862
- if (!params.no_alloc) {
6863
- //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
6864
- cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
6865
- }
6866
- }
6867
-
6868
- if (!ok) {
6869
- fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
6870
- fclose(file);
6871
- wsp_ggml_free(ctx_data);
6872
- wsp_gguf_free(ctx);
6873
- return NULL;
6874
- }
6875
-
6876
- wsp_ggml_set_no_alloc(ctx_data, params.no_alloc);
6877
- }
6878
-
6879
- fclose(file);
6880
-
6881
- return ctx;
6882
- }
6883
-
6884
- void wsp_gguf_free(struct wsp_gguf_context * ctx) {
6885
- if (ctx == NULL) {
6886
- return;
6887
- }
6888
-
6889
- if (ctx->kv) {
6890
- // free string memory - not great..
6891
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
6892
- wsp_gguf_free_kv(&ctx->kv[i]);
6893
- }
6894
-
6895
- WSP_GGML_FREE(ctx->kv);
6896
- }
6897
-
6898
- if (ctx->infos) {
6899
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6900
- struct wsp_gguf_tensor_info * info = &ctx->infos[i];
6901
-
6902
- if (info->name.data) {
6903
- WSP_GGML_FREE(info->name.data);
6904
- }
6905
- }
6906
-
6907
- WSP_GGML_FREE(ctx->infos);
6908
- }
6909
-
6910
- WSP_GGML_FREE(ctx);
6911
- }
6912
-
6913
- const char * wsp_gguf_type_name(enum wsp_gguf_type type) {
6914
- return WSP_GGUF_TYPE_NAME[type];
6915
- }
6916
-
6917
- int wsp_gguf_get_version(const struct wsp_gguf_context * ctx) {
6918
- return ctx->header.version;
6919
- }
6920
-
6921
- size_t wsp_gguf_get_alignment(const struct wsp_gguf_context * ctx) {
6922
- return ctx->alignment;
6923
- }
6924
-
6925
- size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx) {
6926
- return ctx->offset;
6927
- }
6928
-
6929
- void * wsp_gguf_get_data(const struct wsp_gguf_context * ctx) {
6930
- return ctx->data;
6931
- }
6932
-
6933
- int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx) {
6934
- return ctx->header.n_kv;
6935
- }
6936
-
6937
- int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key) {
6938
- // return -1 if key not found
6939
- int keyfound = -1;
6940
-
6941
- const int n_kv = wsp_gguf_get_n_kv(ctx);
6942
-
6943
- for (int i = 0; i < n_kv; ++i) {
6944
- if (strcmp(key, wsp_gguf_get_key(ctx, i)) == 0) {
6945
- keyfound = i;
6946
- break;
6947
- }
6948
- }
6949
-
6950
- return keyfound;
6951
- }
6952
-
6953
- const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int key_id) {
6954
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6955
- return ctx->kv[key_id].key.data;
6956
- }
6957
-
6958
- enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int key_id) {
6959
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6960
- return ctx->kv[key_id].type;
6961
- }
6962
-
6963
- enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id) {
6964
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6965
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
6966
- return ctx->kv[key_id].value.arr.type;
6967
- }
6968
-
6969
- const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id) {
6970
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6971
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
6972
- return ctx->kv[key_id].value.arr.data;
6973
- }
6974
-
6975
- const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_id, int i) {
6976
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6977
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
6978
- struct wsp_gguf_kv * kv = &ctx->kv[key_id];
6979
- struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[i];
6980
- return str->data;
6981
- }
6982
-
6983
- int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int key_id) {
6984
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6985
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
6986
- return ctx->kv[key_id].value.arr.n;
6987
- }
6988
-
6989
- uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int key_id) {
6990
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6991
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT8);
6992
- return ctx->kv[key_id].value.uint8;
6993
- }
6994
-
6995
- int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int key_id) {
6996
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
6997
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT8);
6998
- return ctx->kv[key_id].value.int8;
6999
- }
7000
-
7001
- uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int key_id) {
7002
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7003
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT16);
7004
- return ctx->kv[key_id].value.uint16;
7005
- }
7006
-
7007
- int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int key_id) {
7008
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7009
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT16);
7010
- return ctx->kv[key_id].value.int16;
7011
- }
7012
-
7013
- uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int key_id) {
7014
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7015
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT32);
7016
- return ctx->kv[key_id].value.uint32;
7017
- }
7018
-
7019
- int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int key_id) {
7020
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7021
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT32);
7022
- return ctx->kv[key_id].value.int32;
7023
- }
7024
-
7025
- float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int key_id) {
7026
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7027
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT32);
7028
- return ctx->kv[key_id].value.float32;
7029
- }
7030
-
7031
- uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int key_id) {
7032
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7033
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT64);
7034
- return ctx->kv[key_id].value.uint64;
7035
- }
7036
-
7037
- int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int key_id) {
7038
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7039
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT64);
7040
- return ctx->kv[key_id].value.int64;
7041
- }
7042
-
7043
- double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int key_id) {
7044
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7045
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT64);
7046
- return ctx->kv[key_id].value.float64;
7047
- }
7048
-
7049
- bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id) {
7050
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7051
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_BOOL);
7052
- return ctx->kv[key_id].value.bool_;
7053
- }
7054
-
7055
- const char * wsp_gguf_get_val_str(const struct wsp_gguf_context * ctx, int key_id) {
7056
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7057
- WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_STRING);
7058
- return ctx->kv[key_id].value.str.data;
7059
- }
7060
-
7061
- const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id) {
7062
- WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
7063
- WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_ARRAY);
7064
- WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_STRING);
7065
- return &ctx->kv[key_id].value;
7066
- }
7067
-
7068
- int wsp_gguf_get_n_tensors(const struct wsp_gguf_context * ctx) {
7069
- return ctx->header.n_tensors;
7070
- }
7071
-
7072
- int wsp_gguf_find_tensor(const struct wsp_gguf_context * ctx, const char * name) {
7073
- // return -1 if tensor not found
7074
- int tensorfound = -1;
7075
-
7076
- const int n_tensors = wsp_gguf_get_n_tensors(ctx);
7077
-
7078
- for (int i = 0; i < n_tensors; ++i) {
7079
- if (strcmp(name, wsp_gguf_get_tensor_name(ctx, i)) == 0) {
7080
- tensorfound = i;
7081
- break;
7082
- }
7083
- }
7084
-
7085
- return tensorfound;
7086
- }
7087
-
7088
- size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i) {
7089
- return ctx->infos[i].offset;
7090
- }
7091
-
7092
- char * wsp_gguf_get_tensor_name(const struct wsp_gguf_context * ctx, int i) {
7093
- return ctx->infos[i].name.data;
7094
- }
7095
-
7096
- enum wsp_ggml_type wsp_gguf_get_tensor_type(const struct wsp_gguf_context * ctx, int i) {
7097
- return ctx->infos[i].type;
7098
- }
7099
-
7100
- // returns the index
7101
- static int wsp_gguf_get_or_add_key(struct wsp_gguf_context * ctx, const char * key) {
7102
- const int idx = wsp_gguf_find_key(ctx, key);
7103
- if (idx >= 0) {
7104
- return idx;
7105
- }
7106
-
7107
- const int n_kv = wsp_gguf_get_n_kv(ctx);
7108
-
7109
- ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct wsp_gguf_kv));
7110
- ctx->kv[n_kv].key.n = strlen(key);
7111
- ctx->kv[n_kv].key.data = strdup(key);
7112
- ctx->header.n_kv++;
7113
-
7114
- return n_kv;
7115
- }
7116
-
7117
- void wsp_gguf_remove_key(struct wsp_gguf_context * ctx, const char * key) {
7118
- const int idx = wsp_gguf_find_key(ctx, key);
7119
- if (idx >= 0) {
7120
- const int n_kv = wsp_gguf_get_n_kv(ctx);
7121
- wsp_gguf_free_kv(&ctx->kv[idx]);
7122
- for (int i = idx; i < n_kv-1; ++i) {
7123
- ctx->kv[i] = ctx->kv[i+1];
7124
- }
7125
- ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct wsp_gguf_kv));
7126
- ctx->header.n_kv--;
7127
- }
7128
- }
7129
-
7130
- void wsp_gguf_set_val_u8(struct wsp_gguf_context * ctx, const char * key, uint8_t val) {
7131
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7132
-
7133
- ctx->kv[idx].type = WSP_GGUF_TYPE_UINT8;
7134
- ctx->kv[idx].value.uint8 = val;
7135
- }
7136
-
7137
- void wsp_gguf_set_val_i8(struct wsp_gguf_context * ctx, const char * key, int8_t val) {
7138
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7139
-
7140
- ctx->kv[idx].type = WSP_GGUF_TYPE_INT8;
7141
- ctx->kv[idx].value.int8 = val;
7142
- }
7143
-
7144
- void wsp_gguf_set_val_u16(struct wsp_gguf_context * ctx, const char * key, uint16_t val) {
7145
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7146
-
7147
- ctx->kv[idx].type = WSP_GGUF_TYPE_UINT16;
7148
- ctx->kv[idx].value.uint16 = val;
7149
- }
7150
-
7151
- void wsp_gguf_set_val_i16(struct wsp_gguf_context * ctx, const char * key, int16_t val) {
7152
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7153
-
7154
- ctx->kv[idx].type = WSP_GGUF_TYPE_INT16;
7155
- ctx->kv[idx].value.int16 = val;
7156
- }
7157
-
7158
- void wsp_gguf_set_val_u32(struct wsp_gguf_context * ctx, const char * key, uint32_t val) {
7159
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7160
-
7161
- ctx->kv[idx].type = WSP_GGUF_TYPE_UINT32;
7162
- ctx->kv[idx].value.uint32 = val;
7163
- }
7164
-
7165
- void wsp_gguf_set_val_i32(struct wsp_gguf_context * ctx, const char * key, int32_t val) {
7166
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7167
-
7168
- ctx->kv[idx].type = WSP_GGUF_TYPE_INT32;
7169
- ctx->kv[idx].value.int32 = val;
7170
- }
7171
-
7172
- void wsp_gguf_set_val_f32(struct wsp_gguf_context * ctx, const char * key, float val) {
7173
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7174
-
7175
- ctx->kv[idx].type = WSP_GGUF_TYPE_FLOAT32;
7176
- ctx->kv[idx].value.float32 = val;
7177
- }
7178
-
7179
- void wsp_gguf_set_val_u64(struct wsp_gguf_context * ctx, const char * key, uint64_t val) {
7180
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7181
-
7182
- ctx->kv[idx].type = WSP_GGUF_TYPE_UINT64;
7183
- ctx->kv[idx].value.uint64 = val;
7184
- }
7185
-
7186
- void wsp_gguf_set_val_i64(struct wsp_gguf_context * ctx, const char * key, int64_t val) {
7187
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7188
-
7189
- ctx->kv[idx].type = WSP_GGUF_TYPE_INT64;
7190
- ctx->kv[idx].value.int64 = val;
7191
- }
7192
-
7193
- void wsp_gguf_set_val_f64(struct wsp_gguf_context * ctx, const char * key, double val) {
7194
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7195
-
7196
- ctx->kv[idx].type = WSP_GGUF_TYPE_FLOAT64;
7197
- ctx->kv[idx].value.float64 = val;
7198
- }
7199
-
7200
- void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool val) {
7201
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7202
-
7203
- ctx->kv[idx].type = WSP_GGUF_TYPE_BOOL;
7204
- ctx->kv[idx].value.bool_ = val;
7205
- }
7206
-
7207
- void wsp_gguf_set_val_str(struct wsp_gguf_context * ctx, const char * key, const char * val) {
7208
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7209
-
7210
- ctx->kv[idx].type = WSP_GGUF_TYPE_STRING;
7211
- ctx->kv[idx].value.str.n = strlen(val);
7212
- ctx->kv[idx].value.str.data = strdup(val);
7213
- }
7214
-
7215
- void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n) {
7216
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7217
-
7218
- ctx->kv[idx].type = WSP_GGUF_TYPE_ARRAY;
7219
- ctx->kv[idx].value.arr.type = type;
7220
- ctx->kv[idx].value.arr.n = n;
7221
- ctx->kv[idx].value.arr.data = WSP_GGML_CALLOC(n, wsp_gguf_type_size(type));
7222
- memcpy(ctx->kv[idx].value.arr.data, data, n*wsp_gguf_type_size(type));
7223
- }
7224
-
7225
- void wsp_gguf_set_arr_str(struct wsp_gguf_context * ctx, const char * key, const char ** data, int n) {
7226
- const int idx = wsp_gguf_get_or_add_key(ctx, key);
7227
-
7228
- ctx->kv[idx].type = WSP_GGUF_TYPE_ARRAY;
7229
- ctx->kv[idx].value.arr.type = WSP_GGUF_TYPE_STRING;
7230
- ctx->kv[idx].value.arr.n = n;
7231
- ctx->kv[idx].value.arr.data = WSP_GGML_CALLOC(n, sizeof(struct wsp_gguf_str));
7232
- for (int i = 0; i < n; i++) {
7233
- struct wsp_gguf_str * str = &((struct wsp_gguf_str *)ctx->kv[idx].value.arr.data)[i];
7234
- str->n = strlen(data[i]);
7235
- str->data = strdup(data[i]);
7236
- }
7237
- }
7238
-
7239
- // set or add KV pairs from another context
7240
- void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src) {
7241
- for (uint32_t i = 0; i < src->header.n_kv; i++) {
7242
- switch (src->kv[i].type) {
7243
- case WSP_GGUF_TYPE_UINT8: wsp_gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break;
7244
- case WSP_GGUF_TYPE_INT8: wsp_gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break;
7245
- case WSP_GGUF_TYPE_UINT16: wsp_gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break;
7246
- case WSP_GGUF_TYPE_INT16: wsp_gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break;
7247
- case WSP_GGUF_TYPE_UINT32: wsp_gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
7248
- case WSP_GGUF_TYPE_INT32: wsp_gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
7249
- case WSP_GGUF_TYPE_FLOAT32: wsp_gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
7250
- case WSP_GGUF_TYPE_UINT64: wsp_gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
7251
- case WSP_GGUF_TYPE_INT64: wsp_gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
7252
- case WSP_GGUF_TYPE_FLOAT64: wsp_gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
7253
- case WSP_GGUF_TYPE_BOOL: wsp_gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
7254
- case WSP_GGUF_TYPE_STRING: wsp_gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
7255
- case WSP_GGUF_TYPE_ARRAY:
7256
- {
7257
- if (src->kv[i].value.arr.type == WSP_GGUF_TYPE_STRING) {
7258
- const char ** data = WSP_GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
7259
- for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
7260
- data[j] = ((struct wsp_gguf_str *)src->kv[i].value.arr.data)[j].data;
7261
- }
7262
- wsp_gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
7263
- WSP_GGML_FREE((void *)data);
7264
- } else if (src->kv[i].value.arr.type == WSP_GGUF_TYPE_ARRAY) {
7265
- WSP_GGML_ABORT("nested arrays not supported");
7266
- } else {
7267
- wsp_gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
7268
- }
7269
- } break;
7270
- default: WSP_GGML_ABORT("invalid type");
7271
- }
7272
- }
7273
- }
7274
-
7275
- void wsp_gguf_add_tensor(
7276
- struct wsp_gguf_context * ctx,
7277
- const struct wsp_ggml_tensor * tensor) {
7278
- WSP_GGML_ASSERT(tensor);
7279
- if (wsp_gguf_find_tensor(ctx, tensor->name) != -1) {
7280
- WSP_GGML_ABORT("duplicated tensor name");
7281
- }
7282
-
7283
- const int idx = ctx->header.n_tensors;
7284
- ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct wsp_gguf_tensor_info));
7285
-
7286
- ctx->infos[idx].name.n = strlen(tensor->name);
7287
- ctx->infos[idx].name.data = strdup(tensor->name);
7288
-
7289
- for (int i = 0; i < WSP_GGML_MAX_DIMS; ++i) {
7290
- ctx->infos[idx].ne[i] = 1;
7291
- }
7292
-
7293
- ctx->infos[idx].n_dims = wsp_ggml_n_dims(tensor);
7294
- for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
7295
- ctx->infos[idx].ne[i] = tensor->ne[i];
7296
- }
7297
-
7298
- ctx->infos[idx].type = tensor->type;
7299
- ctx->infos[idx].offset = 0;
7300
- ctx->infos[idx].data = tensor->data;
7301
- ctx->infos[idx].size = wsp_ggml_nbytes(tensor);
7302
-
7303
- if (ctx->header.n_tensors > 0) {
7304
- ctx->infos[idx].offset = ctx->infos[idx - 1].offset + WSP_GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
7305
- }
7306
-
7307
- ctx->header.n_tensors++;
7308
- }
7309
-
7310
- void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type) {
7311
- const int idx = wsp_gguf_find_tensor(ctx, name);
7312
- if (idx < 0) {
7313
- WSP_GGML_ABORT("tensor not found");
7314
- }
7315
-
7316
- ctx->infos[idx].type = type;
7317
- }
7318
-
7319
- void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size) {
7320
- const int idx = wsp_gguf_find_tensor(ctx, name);
7321
- if (idx < 0) {
7322
- WSP_GGML_ABORT("tensor not found");
7323
- }
7324
-
7325
- ctx->infos[idx].data = data;
7326
- ctx->infos[idx].size = size;
7327
-
7328
- // update offsets
7329
- for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
7330
- ctx->infos[i].offset = ctx->infos[i - 1].offset + WSP_GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
7331
- }
7332
- }
7333
-
7334
- //static void wsp_gguf_fwrite_str(FILE * file, const struct wsp_gguf_str * val) {
7335
- // fwrite(&val->n, sizeof(val->n), 1, file);
7336
- // fwrite(val->data, sizeof(char), val->n, file);
7337
- //}
7338
- //
7339
- //static void wsp_gguf_fwrite_el(FILE * file, const void * val, size_t size) {
7340
- // fwrite(val, sizeof(char), size, file);
7341
- //}
7342
-
7343
// Growable byte buffer used when serializing a gguf context.
// A buffer whose `data` is NULL acts as a pure byte counter: the write helpers
// still advance `offset` but copy nothing.
struct wsp_gguf_buf {
    void * data;    // backing storage, or NULL when only the size is being computed
    size_t size;    // capacity tracked for `data`, in bytes
    size_t offset;  // number of bytes written so far (next write position)
};
7348
-
7349
- static struct wsp_gguf_buf wsp_gguf_buf_init(size_t size) {
7350
- struct wsp_gguf_buf buf = {
7351
- /*buf.data =*/ size == 0 ? NULL : WSP_GGML_CALLOC(1, size),
7352
- /*buf.size =*/ size,
7353
- /*buf.offset =*/ 0,
7354
- };
7355
-
7356
- return buf;
7357
- }
7358
-
7359
- static void wsp_gguf_buf_free(struct wsp_gguf_buf buf) {
7360
- if (buf.data) {
7361
- WSP_GGML_FREE(buf.data);
7362
- }
7363
- }
7364
-
7365
- static void wsp_gguf_buf_grow(struct wsp_gguf_buf * buf, size_t size) {
7366
- if (buf->offset + size > buf->size) {
7367
- buf->size = 1.5*(buf->offset + size);
7368
- if (buf->data) {
7369
- buf->data = realloc(buf->data, buf->size);
7370
- }
7371
- }
7372
- }
7373
-
7374
- static void wsp_gguf_bwrite_str(struct wsp_gguf_buf * buf, const struct wsp_gguf_str * val) {
7375
- wsp_gguf_buf_grow(buf, sizeof(val->n) + val->n);
7376
-
7377
- if (buf->data) {
7378
- memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
7379
- }
7380
- buf->offset += sizeof(val->n);
7381
-
7382
- if (buf->data) {
7383
- memcpy((char *) buf->data + buf->offset, val->data, val->n);
7384
- }
7385
- buf->offset += val->n;
7386
- }
7387
-
7388
- static void wsp_gguf_bwrite_el(struct wsp_gguf_buf * buf, const void * val, size_t el_size) {
7389
- wsp_gguf_buf_grow(buf, el_size);
7390
-
7391
- if (buf->data) {
7392
- memcpy((char *) buf->data + buf->offset, val, el_size);
7393
- }
7394
- buf->offset += el_size;
7395
- }
7396
-
7397
- static void wsp_gguf_write_to_buf(const struct wsp_gguf_context * ctx, struct wsp_gguf_buf * buf, bool only_meta) {
7398
- // write header
7399
- wsp_gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
7400
- wsp_gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
7401
- wsp_gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
7402
- wsp_gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv));
7403
-
7404
- // write key-value pairs
7405
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
7406
- struct wsp_gguf_kv * kv = &ctx->kv[i];
7407
-
7408
- wsp_gguf_bwrite_str(buf, &kv->key);
7409
- wsp_gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
7410
-
7411
- switch (kv->type) {
7412
- case WSP_GGUF_TYPE_UINT8: wsp_gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break;
7413
- case WSP_GGUF_TYPE_INT8: wsp_gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break;
7414
- case WSP_GGUF_TYPE_UINT16: wsp_gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break;
7415
- case WSP_GGUF_TYPE_INT16: wsp_gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break;
7416
- case WSP_GGUF_TYPE_UINT32: wsp_gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
7417
- case WSP_GGUF_TYPE_INT32: wsp_gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
7418
- case WSP_GGUF_TYPE_FLOAT32: wsp_gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
7419
- case WSP_GGUF_TYPE_UINT64: wsp_gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
7420
- case WSP_GGUF_TYPE_INT64: wsp_gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
7421
- case WSP_GGUF_TYPE_FLOAT64: wsp_gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
7422
- case WSP_GGUF_TYPE_BOOL: wsp_gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
7423
- case WSP_GGUF_TYPE_STRING: wsp_gguf_bwrite_str(buf, &kv->value.str ); break;
7424
- case WSP_GGUF_TYPE_ARRAY:
7425
- {
7426
- wsp_gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
7427
- wsp_gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) );
7428
-
7429
- switch (kv->value.arr.type) {
7430
- case WSP_GGUF_TYPE_UINT8:
7431
- case WSP_GGUF_TYPE_INT8:
7432
- case WSP_GGUF_TYPE_UINT16:
7433
- case WSP_GGUF_TYPE_INT16:
7434
- case WSP_GGUF_TYPE_UINT32:
7435
- case WSP_GGUF_TYPE_INT32:
7436
- case WSP_GGUF_TYPE_FLOAT32:
7437
- case WSP_GGUF_TYPE_UINT64:
7438
- case WSP_GGUF_TYPE_INT64:
7439
- case WSP_GGUF_TYPE_FLOAT64:
7440
- case WSP_GGUF_TYPE_BOOL:
7441
- {
7442
- wsp_gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * wsp_gguf_type_size(kv->value.arr.type));
7443
- } break;
7444
- case WSP_GGUF_TYPE_STRING:
7445
- {
7446
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
7447
- wsp_gguf_bwrite_str(buf, &((struct wsp_gguf_str *) kv->value.arr.data)[j]);
7448
- }
7449
- } break;
7450
- case WSP_GGUF_TYPE_ARRAY:
7451
- default: WSP_GGML_ABORT("invalid type");
7452
- }
7453
- } break;
7454
- default: WSP_GGML_ABORT("invalid type");
7455
- }
7456
- }
7457
-
7458
- // write tensor infos
7459
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7460
- struct wsp_gguf_tensor_info * info = &ctx->infos[i];
7461
-
7462
- wsp_gguf_bwrite_str(buf, &info->name);
7463
- wsp_gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
7464
- for (uint32_t j = 0; j < info->n_dims; ++j) {
7465
- wsp_gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
7466
- }
7467
- wsp_gguf_bwrite_el(buf, &info->type, sizeof(info->type));
7468
- wsp_gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
7469
- }
7470
-
7471
- // we require the data section to be aligned, so take into account any padding
7472
- {
7473
- const size_t offset = buf->offset;
7474
- const size_t offset_pad = WSP_GGML_PAD(offset, ctx->alignment);
7475
-
7476
- if (offset_pad != offset) {
7477
- uint8_t pad = 0;
7478
- for (size_t i = 0; i < offset_pad - offset; ++i) {
7479
- wsp_gguf_bwrite_el(buf, &pad, sizeof(pad));
7480
- }
7481
- }
7482
- }
7483
-
7484
- if (only_meta) {
7485
- return;
7486
- }
7487
-
7488
- size_t offset = 0;
7489
-
7490
- // write tensor data
7491
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7492
- struct wsp_gguf_tensor_info * info = &ctx->infos[i];
7493
-
7494
- const size_t size = info->size;
7495
- const size_t size_pad = WSP_GGML_PAD(size, ctx->alignment);
7496
-
7497
- wsp_gguf_bwrite_el(buf, info->data, size);
7498
-
7499
- if (size_pad != size) {
7500
- uint8_t pad = 0;
7501
- for (size_t j = 0; j < size_pad - size; ++j) {
7502
- wsp_gguf_bwrite_el(buf, &pad, sizeof(pad));
7503
- }
7504
- }
7505
-
7506
- WSP_GGML_ASSERT(offset == info->offset);
7507
-
7508
- offset += size_pad;
7509
- }
7510
- }
7511
-
7512
- void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta) {
7513
- FILE * file = wsp_ggml_fopen(fname, "wb");
7514
- if (!file) {
7515
- WSP_GGML_ABORT("failed to open file for writing");
7516
- }
7517
-
7518
- struct wsp_gguf_buf buf = wsp_gguf_buf_init(16*1024);
7519
-
7520
- wsp_gguf_write_to_buf(ctx, &buf, only_meta);
7521
-
7522
- fwrite(buf.data, 1, buf.offset, file);
7523
-
7524
- wsp_gguf_buf_free(buf);
7525
-
7526
- fclose(file);
7527
- }
7528
-
7529
- size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx) {
7530
- // no allocs - only compute size
7531
- struct wsp_gguf_buf buf = wsp_gguf_buf_init(0);
7532
-
7533
- wsp_gguf_write_to_buf(ctx, &buf, true);
7534
-
7535
- return buf.offset;
7536
- }
7537
-
7538
- void wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data) {
7539
- struct wsp_gguf_buf buf = wsp_gguf_buf_init(16*1024);
7540
-
7541
- wsp_gguf_write_to_buf(ctx, &buf, true);
7542
-
7543
- memcpy(data, buf.data, buf.offset);
7544
-
7545
- wsp_gguf_buf_free(buf);
7546
- }
7547
-
7548
- void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
7549
- g_logger_state.log_callback = log_callback ? log_callback : wsp_ggml_log_callback_default;
7550
- g_logger_state.log_callback_user_data = user_data;
6509
+ bool wsp_ggml_threadpool_params_match(const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1) {
6510
+ if (p0->n_threads != p1->n_threads ) return false;
6511
+ if (p0->prio != p1->prio ) return false;
6512
+ if (p0->poll != p1->poll ) return false;
6513
+ if (p0->strict_cpu != p1->strict_cpu ) return false;
6514
+ return memcmp(p0->cpumask, p1->cpumask, WSP_GGML_MAX_N_THREADS) == 0;
7551
6515
  }