whisper.rn 0.4.1 → 0.4.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. package/android/src/main/java/com/rnwhisper/RNWhisper.java +24 -18
  2. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -57
  3. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  9. package/cpp/ggml-backend.cpp +36 -18
  10. package/cpp/ggml-backend.h +1 -1
  11. package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
  12. package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
  13. package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
  14. package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
  15. package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
  16. package/cpp/ggml-cpu/common.h +3 -2
  17. package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
  18. package/cpp/ggml-cpu/ggml-cpu.c +95 -17
  19. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  20. package/cpp/ggml-cpu/ops.cpp +775 -74
  21. package/cpp/ggml-cpu/ops.h +7 -0
  22. package/cpp/ggml-cpu/quants.c +25 -24
  23. package/cpp/ggml-cpu/repack.cpp +15 -14
  24. package/cpp/ggml-cpu/simd-mappings.h +211 -33
  25. package/cpp/ggml-cpu/vec.cpp +26 -2
  26. package/cpp/ggml-cpu/vec.h +99 -45
  27. package/cpp/ggml-cpu.h +2 -0
  28. package/cpp/ggml-impl.h +125 -183
  29. package/cpp/ggml-metal-impl.h +27 -0
  30. package/cpp/ggml-metal.m +298 -41
  31. package/cpp/ggml-quants.c +6 -6
  32. package/cpp/ggml-whisper-sim.metallib +0 -0
  33. package/cpp/ggml-whisper.metallib +0 -0
  34. package/cpp/ggml.c +269 -40
  35. package/cpp/ggml.h +122 -2
  36. package/cpp/gguf.cpp +5 -1
  37. package/cpp/whisper.cpp +4 -0
  38. package/cpp/whisper.h +2 -0
  39. package/ios/RNWhisper.mm +35 -38
  40. package/ios/RNWhisperVadContext.h +1 -1
  41. package/ios/RNWhisperVadContext.mm +2 -6
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  74. package/package.json +1 -1
package/cpp/ggml-quants.c CHANGED
@@ -568,14 +568,14 @@ static float make_qkx2_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x
568
568
  }
569
569
  float iscale = nmax/(max - min);
570
570
  float scale = 1/iscale;
571
- float best_mad = 0;
571
+ float best_error = 0;
572
572
  for (int i = 0; i < n; ++i) {
573
573
  int l = nearest_int(iscale*(x[i] - min));
574
574
  L[i] = MAX(0, MIN(nmax, l));
575
575
  float diff = scale * L[i] + min - x[i];
576
576
  diff = use_mad ? fabsf(diff) : diff * diff;
577
577
  float w = weights[i];
578
- best_mad += w * diff;
578
+ best_error += w * diff;
579
579
  }
580
580
  if (nstep < 1) {
581
581
  *the_min = -min;
@@ -601,18 +601,18 @@ static float make_qkx2_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x
601
601
  this_min = 0;
602
602
  this_scale = sum_xl / sum_l2;
603
603
  }
604
- float mad = 0;
604
+ float cur_error = 0;
605
605
  for (int i = 0; i < n; ++i) {
606
606
  float diff = this_scale * Laux[i] + this_min - x[i];
607
607
  diff = use_mad ? fabsf(diff) : diff * diff;
608
608
  float w = weights[i];
609
- mad += w * diff;
609
+ cur_error += w * diff;
610
610
  }
611
- if (mad < best_mad) {
611
+ if (cur_error < best_error) {
612
612
  for (int i = 0; i < n; ++i) {
613
613
  L[i] = Laux[i];
614
614
  }
615
- best_mad = mad;
615
+ best_error = cur_error;
616
616
  scale = this_scale;
617
617
  min = this_min;
618
618
  }
package/cpp/ggml-whisper-sim.metallib CHANGED
Binary file
package/cpp/ggml-whisper.metallib CHANGED
Binary file
package/cpp/ggml.c CHANGED
@@ -61,9 +61,6 @@
61
61
  #define m512i(p) (__m512i)(p)
62
62
  #endif
63
63
 
64
- // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
65
- float wsp_ggml_table_f32_f16[1 << 16];
66
-
67
64
  #if defined(__linux__) || \
68
65
  defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
69
66
  (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
@@ -936,6 +933,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
936
933
  "TRANSPOSE",
937
934
  "GET_ROWS",
938
935
  "GET_ROWS_BACK",
936
+ "SET_ROWS",
939
937
  "DIAG",
940
938
  "DIAG_MASK_INF",
941
939
  "DIAG_MASK_ZERO",
@@ -947,6 +945,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
947
945
  "CONV_TRANSPOSE_1D",
948
946
  "IM2COL",
949
947
  "IM2COL_BACK",
948
+ "CONV_2D",
950
949
  "CONV_2D_DW",
951
950
  "CONV_TRANSPOSE_2D",
952
951
  "POOL_1D",
@@ -984,9 +983,11 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
984
983
  "CROSS_ENTROPY_LOSS",
985
984
  "CROSS_ENTROPY_LOSS_BACK",
986
985
  "OPT_STEP_ADAMW",
986
+
987
+ "GLU",
987
988
  };
988
989
 
989
- static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
990
+ static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
990
991
 
991
992
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
992
993
  "none",
@@ -1032,6 +1033,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1032
1033
  "transpose(x)",
1033
1034
  "get_rows(x)",
1034
1035
  "get_rows_back(x)",
1036
+ "set_rows(x)",
1035
1037
  "diag(x)",
1036
1038
  "diag_mask_inf(x)",
1037
1039
  "diag_mask_zero(x)",
@@ -1043,6 +1045,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1043
1045
  "conv_transpose_1d(x)",
1044
1046
  "im2col(x)",
1045
1047
  "im2col_back(x)",
1048
+ "conv_2d(x)",
1046
1049
  "conv_2d_dw(x)",
1047
1050
  "conv_transpose_2d(x)",
1048
1051
  "pool_1d(x)",
@@ -1080,9 +1083,11 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1080
1083
  "cross_entropy_loss(x,y)",
1081
1084
  "cross_entropy_loss_back(x,y)",
1082
1085
  "adamw(x)",
1086
+
1087
+ "glu(x)",
1083
1088
  };
1084
1089
 
1085
- static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
1090
+ static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
1086
1091
 
1087
1092
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1088
1093
 
@@ -1108,6 +1113,15 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
1108
1113
  static_assert(WSP_GGML_UNARY_OP_COUNT == 15, "WSP_GGML_UNARY_OP_COUNT != 15");
1109
1114
 
1110
1115
 
1116
+ static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
1117
+ "REGLU",
1118
+ "GEGLU",
1119
+ "SWIGLU",
1120
+ };
1121
+
1122
+ static_assert(WSP_GGML_GLU_OP_COUNT == 3, "WSP_GGML_GLU_OP_COUNT != 3");
1123
+
1124
+
1111
1125
  static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
1112
1126
  static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
1113
1127
 
@@ -1210,11 +1224,19 @@ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
1210
1224
  return WSP_GGML_UNARY_OP_NAME[op];
1211
1225
  }
1212
1226
 
1227
+ const char * wsp_ggml_glu_op_name(enum wsp_ggml_glu_op op) {
1228
+ return WSP_GGML_GLU_OP_NAME[op];
1229
+ }
1230
+
1213
1231
  const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
1214
1232
  if (t->op == WSP_GGML_OP_UNARY) {
1215
1233
  enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
1216
1234
  return wsp_ggml_unary_op_name(uop);
1217
1235
  }
1236
+ if (t->op == WSP_GGML_OP_GLU) {
1237
+ enum wsp_ggml_glu_op gop = wsp_ggml_get_glu_op(t);
1238
+ return wsp_ggml_glu_op_name(gop);
1239
+ }
1218
1240
  return wsp_ggml_op_name(t->op);
1219
1241
  }
1220
1242
 
@@ -1351,6 +1373,12 @@ bool wsp_ggml_is_contiguous_channels(const struct wsp_ggml_tensor * tensor) {
1351
1373
  tensor->nb[2] == wsp_ggml_type_size(tensor->type);
1352
1374
  }
1353
1375
 
1376
+ bool wsp_ggml_is_contiguous_rows(const struct wsp_ggml_tensor * tensor) {
1377
+ return
1378
+ tensor->ne[0] == wsp_ggml_blck_size(tensor->type) ||
1379
+ tensor->nb[0] == wsp_ggml_type_size(tensor->type);
1380
+ }
1381
+
1354
1382
  static inline bool wsp_ggml_is_padded_1d(const struct wsp_ggml_tensor * tensor) {
1355
1383
  static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
1356
1384
 
@@ -1422,14 +1450,6 @@ struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) {
1422
1450
  // initialize time system (required on Windows)
1423
1451
  wsp_ggml_time_init();
1424
1452
 
1425
- for (int i = 0; i < (1 << 16); ++i) {
1426
- union {
1427
- uint16_t u16;
1428
- wsp_ggml_fp16_t fp16;
1429
- } u = {i};
1430
- wsp_ggml_table_f32_f16[i] = WSP_GGML_COMPUTE_FP16_TO_FP32(u.fp16);
1431
- }
1432
-
1433
1453
  is_first_call = false;
1434
1454
  }
1435
1455
 
@@ -1733,6 +1753,11 @@ enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tens
1733
1753
  return (enum wsp_ggml_unary_op) wsp_ggml_get_op_params_i32(tensor, 0);
1734
1754
  }
1735
1755
 
1756
+ enum wsp_ggml_glu_op wsp_ggml_get_glu_op(const struct wsp_ggml_tensor * tensor) {
1757
+ WSP_GGML_ASSERT(tensor->op == WSP_GGML_OP_GLU);
1758
+ return (enum wsp_ggml_glu_op) wsp_ggml_get_op_params_i32(tensor, 0);
1759
+ }
1760
+
1736
1761
  const char * wsp_ggml_get_name(const struct wsp_ggml_tensor * tensor) {
1737
1762
  return tensor->name;
1738
1763
  }
@@ -2612,6 +2637,114 @@ struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
2612
2637
  return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_EXP);
2613
2638
  }
2614
2639
 
2640
+ // wsp_ggml_glu
2641
+
2642
+ static struct wsp_ggml_tensor * wsp_ggml_glu_impl(
2643
+ struct wsp_ggml_context * ctx,
2644
+ struct wsp_ggml_tensor * a,
2645
+ struct wsp_ggml_tensor * b,
2646
+ enum wsp_ggml_glu_op op,
2647
+ bool swapped) {
2648
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(a));
2649
+
2650
+ if (b) {
2651
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(b));
2652
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b));
2653
+ WSP_GGML_ASSERT(a->type == b->type);
2654
+ }
2655
+
2656
+ int64_t ne[WSP_GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < WSP_GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2657
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, WSP_GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
2658
+
2659
+ wsp_ggml_set_op_params_i32(result, 0, (int32_t) op);
2660
+ wsp_ggml_set_op_params_i32(result, 1, (int32_t) swapped);
2661
+
2662
+ result->op = WSP_GGML_OP_GLU;
2663
+ result->src[0] = a;
2664
+ result->src[1] = b;
2665
+
2666
+ return result;
2667
+ }
2668
+
2669
+ struct wsp_ggml_tensor * wsp_ggml_glu(
2670
+ struct wsp_ggml_context * ctx,
2671
+ struct wsp_ggml_tensor * a,
2672
+ enum wsp_ggml_glu_op op,
2673
+ bool swapped) {
2674
+ return wsp_ggml_glu_impl(ctx, a, NULL, op, swapped);
2675
+ }
2676
+
2677
+ struct wsp_ggml_tensor * wsp_ggml_glu_split(
2678
+ struct wsp_ggml_context * ctx,
2679
+ struct wsp_ggml_tensor * a,
2680
+ struct wsp_ggml_tensor * b,
2681
+ enum wsp_ggml_glu_op op) {
2682
+ return wsp_ggml_glu_impl(ctx, a, b, op, false);
2683
+ }
2684
+
2685
+ // wsp_ggml_reglu
2686
+
2687
+ struct wsp_ggml_tensor * wsp_ggml_reglu(
2688
+ struct wsp_ggml_context * ctx,
2689
+ struct wsp_ggml_tensor * a) {
2690
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_REGLU, false);
2691
+ }
2692
+
2693
+ struct wsp_ggml_tensor * wsp_ggml_reglu_swapped(
2694
+ struct wsp_ggml_context * ctx,
2695
+ struct wsp_ggml_tensor * a) {
2696
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_REGLU, true);
2697
+ }
2698
+
2699
+ struct wsp_ggml_tensor * wsp_ggml_reglu_split(
2700
+ struct wsp_ggml_context * ctx,
2701
+ struct wsp_ggml_tensor * a,
2702
+ struct wsp_ggml_tensor * b) {
2703
+ return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_REGLU, false);
2704
+ }
2705
+
2706
+ // wsp_ggml_geglu
2707
+
2708
+ struct wsp_ggml_tensor * wsp_ggml_geglu(
2709
+ struct wsp_ggml_context * ctx,
2710
+ struct wsp_ggml_tensor * a) {
2711
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU, false);
2712
+ }
2713
+
2714
+ struct wsp_ggml_tensor * wsp_ggml_geglu_swapped(
2715
+ struct wsp_ggml_context * ctx,
2716
+ struct wsp_ggml_tensor * a) {
2717
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU, true);
2718
+ }
2719
+
2720
+ struct wsp_ggml_tensor * wsp_ggml_geglu_split(
2721
+ struct wsp_ggml_context * ctx,
2722
+ struct wsp_ggml_tensor * a,
2723
+ struct wsp_ggml_tensor * b) {
2724
+ return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU, false);
2725
+ }
2726
+
2727
+ // wsp_ggml_swiglu
2728
+
2729
+ struct wsp_ggml_tensor * wsp_ggml_swiglu(
2730
+ struct wsp_ggml_context * ctx,
2731
+ struct wsp_ggml_tensor * a) {
2732
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_SWIGLU, false);
2733
+ }
2734
+
2735
+ struct wsp_ggml_tensor * wsp_ggml_swiglu_swapped(
2736
+ struct wsp_ggml_context * ctx,
2737
+ struct wsp_ggml_tensor * a) {
2738
+ return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_SWIGLU, true);
2739
+ }
2740
+
2741
+ struct wsp_ggml_tensor * wsp_ggml_swiglu_split(
2742
+ struct wsp_ggml_context * ctx,
2743
+ struct wsp_ggml_tensor * a,
2744
+ struct wsp_ggml_tensor * b) {
2745
+ return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU, false);
2746
+ }
2747
+
2615
2748
  // wsp_ggml_norm
2616
2749
 
2617
2750
  static struct wsp_ggml_tensor * wsp_ggml_norm_impl(
@@ -3395,6 +3528,35 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
3395
3528
  return result;
3396
3529
  }
3397
3530
 
3531
+ // wsp_ggml_set_rows
3532
+
3533
+ struct wsp_ggml_tensor * wsp_ggml_set_rows(
3534
+ struct wsp_ggml_context * ctx,
3535
+ struct wsp_ggml_tensor * a,
3536
+ struct wsp_ggml_tensor * b,
3537
+ struct wsp_ggml_tensor * c) {
3538
+ WSP_GGML_ASSERT(a->ne[0] == b->ne[0]);
3539
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
3540
+ WSP_GGML_ASSERT(a->ne[3] == b->ne[3]);
3541
+ WSP_GGML_ASSERT(b->ne[1] == c->ne[0]);
3542
+ WSP_GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
3543
+ WSP_GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3544
+ WSP_GGML_ASSERT(c->ne[3] == 1);
3545
+ WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
3546
+ WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64);
3547
+
3548
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(a));
3549
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(b));
3550
+
3551
+ struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a);
3552
+
3553
+ result->op = WSP_GGML_OP_SET_ROWS;
3554
+ result->src[0] = b;
3555
+ result->src[1] = c;
3556
+
3557
+ return result;
3558
+ }
3559
+
3398
3560
  // wsp_ggml_diag
3399
3561
 
3400
3562
  struct wsp_ggml_tensor * wsp_ggml_diag(
@@ -4131,6 +4293,44 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw_direct(
4131
4293
  return result;
4132
4294
  }
4133
4295
 
4296
+ // wsp_ggml_conv_2d_direct
4297
+
4298
+ struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
4299
+ struct wsp_ggml_context * ctx,
4300
+ struct wsp_ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
4301
+ struct wsp_ggml_tensor * b, // input data [W, H, C, N]
4302
+ int s0, // stride dimension 0
4303
+ int s1, // stride dimension 1
4304
+ int p0, // padding dimension 0
4305
+ int p1, // padding dimension 1
4306
+ int d0, // dilation dimension 0
4307
+ int d1) {// dilation dimension 1
4308
+
4309
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
4310
+ //WSP_GGML_ASSERT(a->type == b->type);
4311
+
4312
+ int64_t ne[4];
4313
+ ne[0] = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4314
+ ne[1] = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4315
+ ne[2] = a->ne[3];
4316
+ ne[3] = b->ne[3];
4317
+
4318
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, b->type, 4, ne);
4319
+
4320
+ wsp_ggml_set_op_params_i32(result, 0, s0);
4321
+ wsp_ggml_set_op_params_i32(result, 1, s1);
4322
+ wsp_ggml_set_op_params_i32(result, 2, p0);
4323
+ wsp_ggml_set_op_params_i32(result, 3, p1);
4324
+ wsp_ggml_set_op_params_i32(result, 4, d0);
4325
+ wsp_ggml_set_op_params_i32(result, 5, d1);
4326
+
4327
+ result->op = WSP_GGML_OP_CONV_2D;
4328
+ result->src[0] = a;
4329
+ result->src[1] = b;
4330
+
4331
+ return result;
4332
+ }
4333
+
4134
4334
  // wsp_ggml_conv_transpose_2d_p0
4135
4335
 
4136
4336
  static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4247,24 +4447,21 @@ struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
4247
4447
  return result;
4248
4448
  }
4249
4449
 
4250
- // wsp_ggml_upscale
4450
+ // wsp_ggml_upscale / wsp_ggml_interpolate
4251
4451
 
4252
- static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
4452
+ static struct wsp_ggml_tensor * wsp_ggml_interpolate_impl(
4253
4453
  struct wsp_ggml_context * ctx,
4254
4454
  struct wsp_ggml_tensor * a,
4255
- int ne0,
4256
- int ne1,
4257
- int ne2,
4258
- int ne3,
4259
- enum wsp_ggml_scale_mode mode) {
4260
- WSP_GGML_ASSERT(a->ne[0] <= ne0);
4261
- WSP_GGML_ASSERT(a->ne[1] <= ne1);
4262
- WSP_GGML_ASSERT(a->ne[2] <= ne2);
4263
- WSP_GGML_ASSERT(a->ne[3] <= ne3);
4455
+ int64_t ne0,
4456
+ int64_t ne1,
4457
+ int64_t ne2,
4458
+ int64_t ne3,
4459
+ uint32_t mode) {
4460
+ WSP_GGML_ASSERT((mode & 0xFF) < WSP_GGML_SCALE_MODE_COUNT);
4264
4461
 
4265
4462
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
4266
4463
 
4267
- wsp_ggml_set_op_params_i32(result, 0, mode);
4464
+ wsp_ggml_set_op_params_i32(result, 0, (int32_t)mode);
4268
4465
 
4269
4466
  result->op = WSP_GGML_OP_UPSCALE;
4270
4467
  result->src[0] = a;
@@ -4277,7 +4474,8 @@ struct wsp_ggml_tensor * wsp_ggml_upscale(
4277
4474
  struct wsp_ggml_tensor * a,
4278
4475
  int scale_factor,
4279
4476
  enum wsp_ggml_scale_mode mode) {
4280
- return wsp_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4477
+ WSP_GGML_ASSERT(scale_factor > 1);
4478
+ return wsp_ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4281
4479
  }
4282
4480
 
4283
4481
  struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
@@ -4288,7 +4486,18 @@ struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
4288
4486
  int ne2,
4289
4487
  int ne3,
4290
4488
  enum wsp_ggml_scale_mode mode) {
4291
- return wsp_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4489
+ return wsp_ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4490
+ }
4491
+
4492
+ struct wsp_ggml_tensor * wsp_ggml_interpolate(
4493
+ struct wsp_ggml_context * ctx,
4494
+ struct wsp_ggml_tensor * a,
4495
+ int64_t ne0,
4496
+ int64_t ne1,
4497
+ int64_t ne2,
4498
+ int64_t ne3,
4499
+ uint32_t mode) {
4500
+ return wsp_ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4292
4501
  }
4293
4502
 
4294
4503
  // wsp_ggml_pad
@@ -5815,19 +6024,32 @@ static void wsp_ggml_compute_backward(
5815
6024
  WSP_GGML_ASSERT(!src2_needs_grads || wsp_ggml_are_same_shape(src2, cgraph->grads[isrc2]));
5816
6025
  }
5817
6026
 
5818
- static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) {
6027
+ static size_t wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) {
5819
6028
  // check if already visited
5820
- if (wsp_ggml_hash_insert(&cgraph->visited_hash_set, node) == WSP_GGML_HASHSET_ALREADY_EXISTS) {
5821
- return;
6029
+ size_t node_hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
6030
+ WSP_GGML_ASSERT(node_hash_pos != WSP_GGML_HASHSET_FULL);
6031
+ if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
6032
+ // This is the first time we see this node in the current graph.
6033
+ cgraph->visited_hash_set.keys[node_hash_pos] = node;
6034
+ wsp_ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
6035
+ cgraph->use_counts[node_hash_pos] = 0;
6036
+ } else {
6037
+ // already visited
6038
+ return node_hash_pos;
5822
6039
  }
5823
6040
 
5824
6041
  for (int i = 0; i < WSP_GGML_MAX_SRC; ++i) {
5825
6042
  const int k =
5826
6043
  (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
5827
6044
  (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (WSP_GGML_MAX_SRC-1-i) :
5828
- /* unknown order, just fall back to using i*/ i;
5829
- if (node->src[k]) {
5830
- wsp_ggml_visit_parents(cgraph, node->src[k]);
6045
+ /* unknown order, just fall back to using i */ i;
6046
+
6047
+ struct wsp_ggml_tensor * src = node->src[k];
6048
+ if (src) {
6049
+ size_t src_hash_pos = wsp_ggml_visit_parents(cgraph, src);
6050
+
6051
+ // Update the use count for this operand.
6052
+ cgraph->use_counts[src_hash_pos]++;
5831
6053
  }
5832
6054
  }
5833
6055
 
@@ -5851,6 +6073,8 @@ static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_g
5851
6073
  cgraph->nodes[cgraph->n_nodes] = node;
5852
6074
  cgraph->n_nodes++;
5853
6075
  }
6076
+
6077
+ return node_hash_pos;
5854
6078
  }
5855
6079
 
5856
6080
  static void wsp_ggml_build_forward_impl(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor, bool expand) {
@@ -5988,6 +6212,7 @@ static size_t wsp_ggml_graph_nbytes(size_t size, bool grads) {
5988
6212
  incr_ptr_aligned(&p, sizeof(struct wsp_ggml_cgraph), 1);
5989
6213
  incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // nodes
5990
6214
  incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // leafs
6215
+ incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
5991
6216
  incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // hash keys
5992
6217
  if (grads) {
5993
6218
  incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // grads
@@ -6017,11 +6242,12 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx
6017
6242
 
6018
6243
  void * p = cgraph + 1;
6019
6244
 
6020
- struct wsp_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6021
- struct wsp_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6022
- struct wsp_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6023
- struct wsp_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
6024
- struct wsp_ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
6245
+ struct wsp_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6246
+ struct wsp_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6247
+ int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6248
+ struct wsp_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
6249
+ struct wsp_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
6250
+ struct wsp_ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
6025
6251
 
6026
6252
  wsp_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, wsp_ggml_bitset_size(hash_size) * sizeof(wsp_ggml_bitset_t), sizeof(wsp_ggml_bitset_t));
6027
6253
 
@@ -6036,6 +6262,7 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx
6036
6262
  /*.grads =*/ grads_ptr,
6037
6263
  /*.grad_accs =*/ grad_accs_ptr,
6038
6264
  /*.leafs =*/ leafs_ptr,
6265
+ /*.use_counts =*/ use_counts_ptr,
6039
6266
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
6040
6267
  /*.order =*/ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
6041
6268
  };
@@ -6062,7 +6289,8 @@ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int
6062
6289
  /*.grads =*/ NULL, // gradients would need visited_hash_set
6063
6290
  /*.grad_accs =*/ NULL,
6064
6291
  /*.leafs =*/ NULL,
6065
- /*.visited_hash_set =*/ { 0, NULL, NULL },
6292
+ /*.use_counts =*/ cgraph0->use_counts,
6293
+ /*.visited_hash_set =*/ cgraph0->visited_hash_set,
6066
6294
  /*.order =*/ cgraph0->order,
6067
6295
  };
6068
6296
 
@@ -6089,7 +6317,8 @@ void wsp_ggml_graph_cpy(struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * d
6089
6317
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
6090
6318
  // copy all hashset keys (tensors) that are in use
6091
6319
  if (wsp_ggml_bitset_get(src->visited_hash_set.used, i)) {
6092
- wsp_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6320
+ size_t new_hash_pos = wsp_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6321
+ dst->use_counts[new_hash_pos] = src->use_counts[i];
6093
6322
  }
6094
6323
  }
6095
6324