whisper.rn 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/android/src/main/jni.cpp +12 -3
  4. package/cpp/ggml-alloc.c +292 -130
  5. package/cpp/ggml-backend-impl.h +4 -4
  6. package/cpp/ggml-backend-reg.cpp +13 -5
  7. package/cpp/ggml-backend.cpp +207 -17
  8. package/cpp/ggml-backend.h +19 -1
  9. package/cpp/ggml-cpu/amx/amx.cpp +5 -2
  10. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  11. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  12. package/cpp/ggml-cpu/common.h +14 -0
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
  14. package/cpp/ggml-cpu/ggml-cpu.c +65 -44
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  16. package/cpp/ggml-cpu/ops.cpp +542 -775
  17. package/cpp/ggml-cpu/ops.h +2 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  19. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  20. package/cpp/ggml-cpu/unary-ops.h +5 -0
  21. package/cpp/ggml-cpu/vec.cpp +227 -20
  22. package/cpp/ggml-cpu/vec.h +407 -56
  23. package/cpp/ggml-cpu.h +1 -1
  24. package/cpp/ggml-impl.h +94 -12
  25. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  26. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  27. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  28. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  29. package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
  30. package/cpp/ggml-metal/ggml-metal-device.h +244 -0
  31. package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
  32. package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
  33. package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
  34. package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
  35. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  36. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  37. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +40 -40
  39. package/cpp/ggml-metal.h +1 -6
  40. package/cpp/ggml-quants.c +1 -0
  41. package/cpp/ggml.c +341 -15
  42. package/cpp/ggml.h +150 -5
  43. package/cpp/jsi/RNWhisperJSI.cpp +9 -2
  44. package/cpp/jsi/ThreadPool.h +3 -3
  45. package/cpp/rn-whisper.h +1 -0
  46. package/cpp/whisper.cpp +89 -72
  47. package/cpp/whisper.h +1 -0
  48. package/ios/CMakeLists.txt +6 -1
  49. package/ios/RNWhisperContext.mm +3 -1
  50. package/ios/RNWhisperVadContext.mm +14 -13
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  61. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  70. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  72. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  74. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  80. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  81. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  82. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  83. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  86. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  87. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  92. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  93. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  94. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  95. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  96. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  97. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  98. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  99. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  101. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  102. package/lib/commonjs/version.json +1 -1
  103. package/lib/module/NativeRNWhisper.js.map +1 -1
  104. package/lib/module/version.json +1 -1
  105. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  106. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  107. package/package.json +1 -1
  108. package/src/NativeRNWhisper.ts +2 -0
  109. package/src/version.json +1 -1
  110. package/whisper-rn.podspec +8 -9
  111. package/cpp/ggml-metal.m +0 -6779
  112. package/cpp/ggml-whisper-sim.metallib +0 -0
  113. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml.c CHANGED
@@ -982,7 +982,9 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
982
982
  "CONV_TRANSPOSE_1D",
983
983
  "IM2COL",
984
984
  "IM2COL_BACK",
985
+ "IM2COL_3D",
985
986
  "CONV_2D",
987
+ "CONV_3D",
986
988
  "CONV_2D_DW",
987
989
  "CONV_TRANSPOSE_2D",
988
990
  "POOL_1D",
@@ -1025,7 +1027,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1025
1027
  "GLU",
1026
1028
  };
1027
1029
 
1028
- static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
1030
+ static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
1029
1031
 
1030
1032
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1031
1033
  "none",
@@ -1084,7 +1086,9 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1084
1086
  "conv_transpose_1d(x)",
1085
1087
  "im2col(x)",
1086
1088
  "im2col_back(x)",
1089
+ "im2col_3d(x)",
1087
1090
  "conv_2d(x)",
1091
+ "conv_3d(x)",
1088
1092
  "conv_2d_dw(x)",
1089
1093
  "conv_transpose_2d(x)",
1090
1094
  "pool_1d(x)",
@@ -1127,7 +1131,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1127
1131
  "glu(x)",
1128
1132
  };
1129
1133
 
1130
- static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
1134
+ static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
1131
1135
 
1132
1136
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1133
1137
 
@@ -1147,10 +1151,14 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
1147
1151
  "HARDSIGMOID",
1148
1152
  "EXP",
1149
1153
  "GELU_ERF",
1154
+ "XIELU",
1155
+ "FLOOR",
1156
+ "CEIL",
1157
+ "ROUND",
1158
+ "TRUNC",
1150
1159
  };
1151
1160
 
1152
- static_assert(WSP_GGML_UNARY_OP_COUNT == 15, "WSP_GGML_UNARY_OP_COUNT != 15");
1153
-
1161
+ static_assert(WSP_GGML_UNARY_OP_COUNT == 20, "WSP_GGML_UNARY_OP_COUNT != 20");
1154
1162
 
1155
1163
  static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
1156
1164
  "REGLU",
@@ -2656,6 +2664,29 @@ struct wsp_ggml_tensor * wsp_ggml_silu_inplace(
2656
2664
  return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_SILU);
2657
2665
  }
2658
2666
 
2667
+ // wsp_ggml_xielu
2668
+
2669
+ struct wsp_ggml_tensor * wsp_ggml_xielu(
2670
+ struct wsp_ggml_context * ctx,
2671
+ struct wsp_ggml_tensor * a,
2672
+ float alpha_n,
2673
+ float alpha_p,
2674
+ float beta,
2675
+ float eps) {
2676
+ struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
2677
+
2678
+ wsp_ggml_set_op_params_i32(result, 0, (int32_t) WSP_GGML_UNARY_OP_XIELU);
2679
+ wsp_ggml_set_op_params_f32(result, 1, beta + wsp_ggml_softplus(alpha_n));
2680
+ wsp_ggml_set_op_params_f32(result, 2, wsp_ggml_softplus(alpha_p));
2681
+ wsp_ggml_set_op_params_f32(result, 3, beta);
2682
+ wsp_ggml_set_op_params_f32(result, 4, eps);
2683
+
2684
+ result->op = WSP_GGML_OP_UNARY;
2685
+ result->src[0] = a;
2686
+
2687
+ return result;
2688
+ }
2689
+
2659
2690
  // wsp_ggml_silu_back
2660
2691
 
2661
2692
  struct wsp_ggml_tensor * wsp_ggml_silu_back(
@@ -2730,6 +2761,62 @@ static struct wsp_ggml_tensor * wsp_ggml_glu_impl(
2730
2761
  return result;
2731
2762
  }
2732
2763
 
2764
+ // wsp_ggml_floor
2765
+
2766
+ struct wsp_ggml_tensor * wsp_ggml_floor(
2767
+ struct wsp_ggml_context * ctx,
2768
+ struct wsp_ggml_tensor * a) {
2769
+ return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_FLOOR);
2770
+ }
2771
+
2772
+ struct wsp_ggml_tensor * wsp_ggml_floor_inplace(
2773
+ struct wsp_ggml_context * ctx,
2774
+ struct wsp_ggml_tensor * a) {
2775
+ return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_FLOOR);
2776
+ }
2777
+
2778
+ // wsp_ggml_ceil
2779
+
2780
+ struct wsp_ggml_tensor * wsp_ggml_ceil(
2781
+ struct wsp_ggml_context * ctx,
2782
+ struct wsp_ggml_tensor * a) {
2783
+ return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_CEIL);
2784
+ }
2785
+
2786
+ struct wsp_ggml_tensor * wsp_ggml_ceil_inplace(
2787
+ struct wsp_ggml_context * ctx,
2788
+ struct wsp_ggml_tensor * a) {
2789
+ return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_CEIL);
2790
+ }
2791
+
2792
+ //wsp_ggml_round
2793
+
2794
+ struct wsp_ggml_tensor * wsp_ggml_round(
2795
+ struct wsp_ggml_context * ctx,
2796
+ struct wsp_ggml_tensor * a) {
2797
+ return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_ROUND);
2798
+ }
2799
+
2800
+ struct wsp_ggml_tensor * wsp_ggml_round_inplace(
2801
+ struct wsp_ggml_context * ctx,
2802
+ struct wsp_ggml_tensor * a) {
2803
+ return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_ROUND);
2804
+ }
2805
+
2806
+ //wsp_ggml_trunc
2807
+
2808
+ struct wsp_ggml_tensor * wsp_ggml_trunc(
2809
+ struct wsp_ggml_context * ctx,
2810
+ struct wsp_ggml_tensor * a) {
2811
+ return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_TRUNC);
2812
+ }
2813
+
2814
+ struct wsp_ggml_tensor * wsp_ggml_trunc_inplace(
2815
+ struct wsp_ggml_context * ctx,
2816
+ struct wsp_ggml_tensor * a) {
2817
+ return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_TRUNC);
2818
+ }
2819
+
2733
2820
  struct wsp_ggml_tensor * wsp_ggml_glu(
2734
2821
  struct wsp_ggml_context * ctx,
2735
2822
  struct wsp_ggml_tensor * a,
@@ -3627,6 +3714,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
3627
3714
  struct wsp_ggml_tensor * a,
3628
3715
  struct wsp_ggml_tensor * b) {
3629
3716
  WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
3717
+ WSP_GGML_ASSERT(a->ne[3] == b->ne[2]);
3630
3718
  WSP_GGML_ASSERT(b->ne[3] == 1);
3631
3719
  WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
3632
3720
 
@@ -3680,7 +3768,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
3680
3768
  WSP_GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3681
3769
  WSP_GGML_ASSERT(c->ne[3] == 1);
3682
3770
  WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
3683
- WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64);
3771
+ WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64 || c->type == WSP_GGML_TYPE_I32);
3684
3772
 
3685
3773
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(a));
3686
3774
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(b));
@@ -3690,6 +3778,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
3690
3778
  result->op = WSP_GGML_OP_SET_ROWS;
3691
3779
  result->src[0] = b;
3692
3780
  result->src[1] = c;
3781
+ result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
3693
3782
 
3694
3783
  return result;
3695
3784
  }
@@ -3831,6 +3920,15 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
3831
3920
  return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3832
3921
  }
3833
3922
 
3923
+ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_inplace(
3924
+ struct wsp_ggml_context * ctx,
3925
+ struct wsp_ggml_tensor * a,
3926
+ struct wsp_ggml_tensor * mask,
3927
+ float scale,
3928
+ float max_bias) {
3929
+ return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
3930
+ }
3931
+
3834
3932
  void wsp_ggml_soft_max_add_sinks(
3835
3933
  struct wsp_ggml_tensor * a,
3836
3934
  struct wsp_ggml_tensor * sinks) {
@@ -3930,7 +4028,7 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
3930
4028
  memcpy(params + 8, &attn_factor, sizeof(float));
3931
4029
  memcpy(params + 9, &beta_fast, sizeof(float));
3932
4030
  memcpy(params + 10, &beta_slow, sizeof(float));
3933
- if (mrope_used) {
4031
+ if (mrope_used && sections) {
3934
4032
  memcpy(params + 11, sections, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
3935
4033
  } else {
3936
4034
  memset(params + 11, 0, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
@@ -4367,6 +4465,91 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d(
4367
4465
  return result;
4368
4466
  }
4369
4467
 
4468
+ // a: [OC*IC, KD, KH, KW]
4469
+ // b: [N*IC, ID, IH, IW]
4470
+ // result: [N*OD, OH, OW, IC * KD * KH * KW]
4471
+ struct wsp_ggml_tensor * wsp_ggml_im2col_3d(
4472
+ struct wsp_ggml_context * ctx,
4473
+ struct wsp_ggml_tensor * a,
4474
+ struct wsp_ggml_tensor * b,
4475
+ int64_t IC,
4476
+ int s0, // stride width
4477
+ int s1, // stride height
4478
+ int s2, // stride depth
4479
+ int p0, // padding width
4480
+ int p1, // padding height
4481
+ int p2, // padding depth
4482
+ int d0, // dilation width
4483
+ int d1, // dilation height
4484
+ int d2, // dilation depth
4485
+ enum wsp_ggml_type dst_type) {
4486
+ const int64_t N = b->ne[3] / IC;
4487
+ const int64_t ID = b->ne[2];
4488
+ const int64_t IH = b->ne[1];
4489
+ const int64_t IW = b->ne[0];
4490
+
4491
+ const int64_t OC = a->ne[3] / IC;
4492
+ UNUSED(OC);
4493
+ const int64_t KD = a->ne[2];
4494
+ const int64_t KH = a->ne[1];
4495
+ const int64_t KW = a->ne[0];
4496
+ const int64_t OD = wsp_ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
4497
+ const int64_t OH = wsp_ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
4498
+ const int64_t OW = wsp_ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
4499
+
4500
+ WSP_GGML_ASSERT((OD > 0) && "b too small compared to a");
4501
+ WSP_GGML_ASSERT((OH > 0) && "b too small compared to a");
4502
+ WSP_GGML_ASSERT((OW > 0) && "b too small compared to a");
4503
+
4504
+
4505
+ const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
4506
+
4507
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, dst_type, 4, ne);
4508
+ int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
4509
+ wsp_ggml_set_op_params(result, params, sizeof(params));
4510
+
4511
+ result->op = WSP_GGML_OP_IM2COL_3D;
4512
+ result->src[0] = a;
4513
+ result->src[1] = b;
4514
+
4515
+ return result;
4516
+ }
4517
+
4518
+ // a: [OC*IC, KD, KH, KW]
4519
+ // b: [N*IC, ID, IH, IW]
4520
+ // result: [N*OC, OD, OH, OW]
4521
+ struct wsp_ggml_tensor * wsp_ggml_conv_3d(
4522
+ struct wsp_ggml_context * ctx,
4523
+ struct wsp_ggml_tensor * a,
4524
+ struct wsp_ggml_tensor * b,
4525
+ int64_t IC,
4526
+ int s0, // stride width
4527
+ int s1, // stride height
4528
+ int s2, // stride depth
4529
+ int p0, // padding width
4530
+ int p1, // padding height
4531
+ int p2, // padding depth
4532
+ int d0, // dilation width
4533
+ int d1, // dilation height
4534
+ int d2 // dilation depth
4535
+ ) {
4536
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
4537
+
4538
+ int64_t OC = a->ne[3] / IC;
4539
+ int64_t N = b->ne[3] / IC;
4540
+ struct wsp_ggml_tensor * result =
4541
+ wsp_ggml_mul_mat(ctx,
4542
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
4543
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
4544
+
4545
+ int64_t OD = im2col->ne[3] / N;
4546
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
4547
+ result = wsp_ggml_cont(ctx, wsp_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
4548
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
4549
+
4550
+ return result;
4551
+ }
4552
+
4370
4553
  // wsp_ggml_conv_2d_sk_p0
4371
4554
 
4372
4555
  struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
@@ -4488,6 +4671,56 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
4488
4671
  return result;
4489
4672
  }
4490
4673
 
4674
+ // wsp_ggml_conv_3d_direct
4675
+
4676
+ struct wsp_ggml_tensor * wsp_ggml_conv_3d_direct(
4677
+ struct wsp_ggml_context * ctx,
4678
+ struct wsp_ggml_tensor * a,
4679
+ struct wsp_ggml_tensor * b,
4680
+ int s0,
4681
+ int s1,
4682
+ int s2,
4683
+ int p0,
4684
+ int p1,
4685
+ int p2,
4686
+ int d0,
4687
+ int d1,
4688
+ int d2,
4689
+ int c,
4690
+ int n,
4691
+ int oc) {
4692
+
4693
+ WSP_GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
4694
+ WSP_GGML_ASSERT(b->ne[3] == (int64_t) c * n);
4695
+
4696
+ int64_t ne[4];
4697
+ ne[0] = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4698
+ ne[1] = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4699
+ ne[2] = wsp_ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
4700
+ ne[3] = (int64_t) oc * n;
4701
+
4702
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
4703
+
4704
+ wsp_ggml_set_op_params_i32(result, 0, s0);
4705
+ wsp_ggml_set_op_params_i32(result, 1, s1);
4706
+ wsp_ggml_set_op_params_i32(result, 2, s2);
4707
+ wsp_ggml_set_op_params_i32(result, 3, p0);
4708
+ wsp_ggml_set_op_params_i32(result, 4, p1);
4709
+ wsp_ggml_set_op_params_i32(result, 5, p2);
4710
+ wsp_ggml_set_op_params_i32(result, 6, d0);
4711
+ wsp_ggml_set_op_params_i32(result, 7, d1);
4712
+ wsp_ggml_set_op_params_i32(result, 8, d2);
4713
+ wsp_ggml_set_op_params_i32(result, 9, c);
4714
+ wsp_ggml_set_op_params_i32(result, 10, n);
4715
+ wsp_ggml_set_op_params_i32(result, 11, oc);
4716
+
4717
+ result->op = WSP_GGML_OP_CONV_3D;
4718
+ result->src[0] = a;
4719
+ result->src[1] = b;
4720
+
4721
+ return result;
4722
+ }
4723
+
4491
4724
  // wsp_ggml_conv_transpose_2d_p0
4492
4725
 
4493
4726
  static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4666,11 +4899,36 @@ struct wsp_ggml_tensor * wsp_ggml_pad(
4666
4899
  int p1,
4667
4900
  int p2,
4668
4901
  int p3) {
4902
+ return wsp_ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
4903
+ }
4904
+
4905
+ struct wsp_ggml_tensor * wsp_ggml_pad_ext(
4906
+ struct wsp_ggml_context * ctx,
4907
+ struct wsp_ggml_tensor * a,
4908
+ int lp0,
4909
+ int rp0,
4910
+ int lp1,
4911
+ int rp1,
4912
+ int lp2,
4913
+ int rp2,
4914
+ int lp3,
4915
+ int rp3
4916
+ ) {
4669
4917
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
4670
- a->ne[0] + p0,
4671
- a->ne[1] + p1,
4672
- a->ne[2] + p2,
4673
- a->ne[3] + p3);
4918
+ a->ne[0] + lp0 + rp0,
4919
+ a->ne[1] + lp1 + rp1,
4920
+ a->ne[2] + lp2 + rp2,
4921
+ a->ne[3] + lp3 + rp3);
4922
+
4923
+ wsp_ggml_set_op_params_i32(result, 0, lp0);
4924
+ wsp_ggml_set_op_params_i32(result, 1, rp0);
4925
+ wsp_ggml_set_op_params_i32(result, 2, lp1);
4926
+ wsp_ggml_set_op_params_i32(result, 3, rp1);
4927
+ wsp_ggml_set_op_params_i32(result, 4, lp2);
4928
+ wsp_ggml_set_op_params_i32(result, 5, rp2);
4929
+ wsp_ggml_set_op_params_i32(result, 6, lp3);
4930
+ wsp_ggml_set_op_params_i32(result, 7, rp3);
4931
+
4674
4932
 
4675
4933
  result->op = WSP_GGML_OP_PAD;
4676
4934
  result->src[0] = a;
@@ -4766,12 +5024,8 @@ struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
4766
5024
  struct wsp_ggml_tensor * timesteps,
4767
5025
  int dim,
4768
5026
  int max_period) {
4769
- int actual_dim = dim;
4770
- if (dim % 2 != 0) {
4771
- actual_dim = dim + 1;
4772
- }
4773
5027
 
4774
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
5028
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, dim, timesteps->ne[0]);
4775
5029
 
4776
5030
  wsp_ggml_set_op_params_i32(result, 0, dim);
4777
5031
  wsp_ggml_set_op_params_i32(result, 1, max_period);
@@ -6718,6 +6972,78 @@ void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph) {
6718
6972
  WSP_GGML_LOG_INFO("========================================\n");
6719
6973
  }
6720
6974
 
6975
+ static int wsp_ggml_node_list_find_tensor(const struct wsp_ggml_cgraph * cgraph,
6976
+ const int * idxs,
6977
+ int count,
6978
+ const struct wsp_ggml_tensor * tensor) {
6979
+ WSP_GGML_ASSERT(cgraph && idxs);
6980
+ for (int i = 0; i < count; ++i) {
6981
+ const int node_idx = idxs[i];
6982
+
6983
+ if (node_idx >= cgraph->n_nodes) {
6984
+ return -1;
6985
+ }
6986
+ if (cgraph->nodes[node_idx] == tensor) {
6987
+ return i;
6988
+ }
6989
+ }
6990
+ return -1;
6991
+ }
6992
+
6993
+ bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
6994
+ const int * node_idxs,
6995
+ int count,
6996
+ const enum wsp_ggml_op * ops,
6997
+ const int * outputs,
6998
+ int num_outputs) {
6999
+ WSP_GGML_ASSERT(outputs && num_outputs > 0);
7000
+
7001
+ for (int i = 0; i < count; ++i) {
7002
+ if (node_idxs[i] >= cgraph->n_nodes) {
7003
+ return false;
7004
+ }
7005
+
7006
+ const struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
7007
+
7008
+ if (node->op != ops[i]) {
7009
+ return false;
7010
+ }
7011
+
7012
+ if (wsp_ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
7013
+ continue;
7014
+ }
7015
+
7016
+ if (node->flags & WSP_GGML_TENSOR_FLAG_OUTPUT) {
7017
+ return false;
7018
+ }
7019
+
7020
+ int subgraph_uses = 0;
7021
+ for (int j = i + 1; j < count; ++j) {
7022
+ const struct wsp_ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
7023
+ for (int src_idx = 0; src_idx < WSP_GGML_MAX_SRC; src_idx++) {
7024
+ if (other_node->src[src_idx] == node) {
7025
+ subgraph_uses++;
7026
+ }
7027
+ }
7028
+ }
7029
+
7030
+ if (subgraph_uses != wsp_ggml_node_get_use_count(cgraph, node_idxs[i])) {
7031
+ return false;
7032
+ }
7033
+
7034
+ // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
7035
+ struct wsp_ggml_tensor * view_src = node->view_src;
7036
+ while (view_src) {
7037
+ if (wsp_ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
7038
+ return false;
7039
+ }
7040
+ view_src = view_src->view_src;
7041
+ }
7042
+ }
7043
+
7044
+ return true;
7045
+ }
7046
+
6721
7047
  // check if node is part of the graph
6722
7048
  static bool wsp_ggml_graph_find(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) {
6723
7049
  if (cgraph == NULL) {