whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml.c CHANGED
@@ -943,6 +943,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "COS",
     "SUM",
     "SUM_ROWS",
+    "CUMSUM",
     "MEAN",
     "ARGMAX",
     "COUNT_EQUAL",
@@ -998,6 +999,8 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "TRI",
+    "FILL",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -1010,6 +1013,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
+    "SOLVE_TRI",
 
     "UNARY",
 
@@ -1027,7 +1031,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
+static_assert(WSP_GGML_OP_COUNT == 94, "WSP_GGML_OP_COUNT != 94");
 
 static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "none",
@@ -1047,6 +1051,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "cos(x)",
     "Σx",
     "Σx_k",
+    "cumsum(x)",
     "Σx/n",
     "argmax(x)",
     "count_equal(x)",
@@ -1102,6 +1107,8 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
+    "tri(x)",
+    "fill(x, c)",
 
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
@@ -1114,6 +1121,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
+    "A X = B, A triangular, solve X",
 
     "unary(x)",
 
@@ -1131,7 +1139,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
+static_assert(WSP_GGML_OP_COUNT == 94, "WSP_GGML_OP_COUNT != 94");
 
 static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
 
@@ -1150,11 +1158,17 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
     "HARDSWISH",
     "HARDSIGMOID",
     "EXP",
+    "EXPM1",
+    "SOFTPLUS",
     "GELU_ERF",
+    "XIELU",
+    "FLOOR",
+    "CEIL",
+    "ROUND",
+    "TRUNC",
 };
 
-static_assert(WSP_GGML_UNARY_OP_COUNT == 15, "WSP_GGML_UNARY_OP_COUNT != 15");
-
+static_assert(WSP_GGML_UNARY_OP_COUNT == 22, "WSP_GGML_UNARY_OP_COUNT != 22");
 
 static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
     "REGLU",
@@ -2262,6 +2276,30 @@ struct wsp_ggml_tensor * wsp_ggml_log_inplace(
     return wsp_ggml_log_impl(ctx, a, true);
 }
 
+struct wsp_ggml_tensor * wsp_ggml_expm1(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_EXPM1);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_expm1_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_EXPM1);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_softplus(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_SOFTPLUS);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_softplus_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_SOFTPLUS);
+}
+
 // wsp_ggml_sin
 
 static struct wsp_ggml_tensor * wsp_ggml_sin_impl(
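Note: the four wrappers above only tag a node with the new unary ops; the kernels themselves land in the CPU and Metal backends (unary-ops.cpp, ggml-metal-ops.cpp in the file list at the top). A minimal sketch of driving one of them through the ggml graph API follows — the wsp_-prefixed init/graph helpers are assumed to mirror upstream ggml, and this is illustrative code, not part of the package:

#include "ggml.h"

// Sketch: compute softplus(x) = log(1 + exp(x)) elementwise on a tiny tensor.
static void softplus_demo(void) {
    struct wsp_ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);

    struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = -2.0f; xd[1] = -0.5f; xd[2] = 0.5f; xd[3] = 2.0f;

    struct wsp_ggml_tensor * y = wsp_ggml_softplus(ctx, x); // or wsp_ggml_expm1(ctx, x)

    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
    wsp_ggml_build_forward_expand(gf, y);
    wsp_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    // y->data now holds softplus of each input element
    wsp_ggml_free(ctx);
}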
@@ -2345,6 +2383,21 @@ struct wsp_ggml_tensor * wsp_ggml_sum_rows(
     return result;
 }
 
+// wsp_ggml_cumsum
+
+struct wsp_ggml_tensor * wsp_ggml_cumsum(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
+
+    result->op     = WSP_GGML_OP_CUMSUM;
+    result->src[0] = a;
+
+    return result;
+}
+
 // wsp_ggml_mean
 
 struct wsp_ggml_tensor * wsp_ggml_mean(
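wsp_ggml_cumsum clones the input's shape and only accepts F32, per the assert above. A fragment in the same style as the softplus sketch; the expected output assumes upstream ggml's semantics of an inclusive prefix sum along the first dimension:

// {1, 2, 3, 4} -> {1, 3, 6, 10}
static struct wsp_ggml_tensor * cumsum_demo(struct wsp_ggml_context * ctx) {
    struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = 1.0f; xd[1] = 2.0f; xd[2] = 3.0f; xd[3] = 4.0f;

    return wsp_ggml_cumsum(ctx, x); // build + compute as in the sketch above
}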
@@ -2660,6 +2713,29 @@ struct wsp_ggml_tensor * wsp_ggml_silu_inplace(
     return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_SILU);
 }
 
+// wsp_ggml_xielu
+
+struct wsp_ggml_tensor * wsp_ggml_xielu(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float alpha_n,
+        float alpha_p,
+        float beta,
+        float eps) {
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
+
+    wsp_ggml_set_op_params_i32(result, 0, (int32_t) WSP_GGML_UNARY_OP_XIELU);
+    wsp_ggml_set_op_params_f32(result, 1, beta + wsp_ggml_compute_softplus_f32(alpha_n));
+    wsp_ggml_set_op_params_f32(result, 2, wsp_ggml_compute_softplus_f32(alpha_p));
+    wsp_ggml_set_op_params_f32(result, 3, beta);
+    wsp_ggml_set_op_params_f32(result, 4, eps);
+
+    result->op     = WSP_GGML_OP_UNARY;
+    result->src[0] = a;
+
+    return result;
+}
+
 // wsp_ggml_silu_back
 
 struct wsp_ggml_tensor * wsp_ggml_silu_back(
@@ -2734,6 +2810,62 @@ static struct wsp_ggml_tensor * wsp_ggml_glu_impl(
     return result;
 }
 
+// wsp_ggml_floor
+
+struct wsp_ggml_tensor * wsp_ggml_floor(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_FLOOR);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_floor_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_FLOOR);
+}
+
+// wsp_ggml_ceil
+
+struct wsp_ggml_tensor * wsp_ggml_ceil(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_CEIL);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_ceil_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_CEIL);
+}
+
+// wsp_ggml_round
+
+struct wsp_ggml_tensor * wsp_ggml_round(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_ROUND);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_round_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_ROUND);
+}
+
+// wsp_ggml_trunc
+
+struct wsp_ggml_tensor * wsp_ggml_trunc(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_TRUNC);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_trunc_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_TRUNC);
+}
+
 struct wsp_ggml_tensor * wsp_ggml_glu(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor  * a,
@@ -3837,6 +3969,15 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
     return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * mask,
+        float scale,
+        float max_bias) {
+    return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
+}
+
 void wsp_ggml_soft_max_add_sinks(
         struct wsp_ggml_tensor * a,
         struct wsp_ggml_tensor * sinks) {
@@ -4944,6 +5085,61 @@ struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
     return result;
 }
 
+// wsp_ggml_tri
+
+struct wsp_ggml_tensor * wsp_ggml_tri(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        enum wsp_ggml_tri_type    type) {
+    WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
+    WSP_GGML_ASSERT(a->ne[0] == a->ne[1]);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
+
+    wsp_ggml_set_op_params_i32(result, 0, type);
+
+    result->op     = WSP_GGML_OP_TRI;
+    result->src[0] = a;
+
+    return result;
+}
+
+// wsp_ggml_fill
+
+static struct wsp_ggml_tensor * wsp_ggml_fill_impl(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float c,
+        bool inplace) {
+    WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
+
+    struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
+
+    wsp_ggml_set_op_params_f32(result, 0, c);
+
+    result->op     = WSP_GGML_OP_FILL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_fill(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float c) {
+    return wsp_ggml_fill_impl(ctx, a, c, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_fill_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float c) {
+    return wsp_ggml_fill_impl(ctx, a, c, true);
+}
+
 // wsp_ggml_argsort
 
 struct wsp_ggml_tensor * wsp_ggml_argsort(
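wsp_ggml_tri requires a contiguous, square F32 matrix and wsp_ggml_fill a contiguous F32 tensor, per the asserts above. A fragment combining the two; reading WSP_GGML_TRI_TYPE_LOWER_DIAG as "keep the lower triangle including the diagonal" follows the enum naming in ggml.h and is an assumption here:

// Build a lower-triangular matrix of ones (e.g. the shape of a causal mask).
static struct wsp_ggml_tensor * tril_ones(struct wsp_ggml_context * ctx) {
    struct wsp_ggml_tensor * m = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 8, 8);

    // same shape as m, every element set to 1.0f
    struct wsp_ggml_tensor * ones = wsp_ggml_fill(ctx, m, 1.0f);

    // zero everything above the diagonal
    return wsp_ggml_tri(ctx, ones, WSP_GGML_TRI_TYPE_LOWER_DIAG);
}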
@@ -5798,6 +5994,41 @@ struct wsp_ggml_tensor * wsp_ggml_opt_step_sgd(
     return result;
 }
 
+// solve_tri
+
+struct wsp_ggml_tensor * wsp_ggml_solve_tri(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        bool left,
+        bool lower,
+        bool uni) {
+    WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
+
+    // A must be square and lower diagonal
+    WSP_GGML_ASSERT(a->ne[0] == a->ne[1]);
+    // B must have same outer dimension as A
+    WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
+
+    // batch dimensions must be equal
+    WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
+    WSP_GGML_ASSERT(a->ne[3] == b->ne[3]);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(b));
+
+    WSP_GGML_ASSERT(lower && left && !uni); // TODO: support other variants
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
+
+    result->op     = WSP_GGML_OP_SOLVE_TRI;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) {
@@ -6370,6 +6601,16 @@ static void wsp_ggml_compute_backward(
                     wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_mul(ctx, tensor, grad));
                 }
             } break;
+            case WSP_GGML_UNARY_OP_EXPM1: {
+                if (src0_needs_grads) {
+                    wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_mul(ctx, grad, wsp_ggml_exp(ctx, src0)));
+                }
+            } break;
+            case WSP_GGML_UNARY_OP_SOFTPLUS: {
+                if (src0_needs_grads) {
+                    wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_mul(ctx, grad, wsp_ggml_sigmoid(ctx, src0)));
+                }
+            } break;
             default: {
                 fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
                     __func__, wsp_ggml_unary_op_name(wsp_ggml_get_unary_op(tensor)));
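The two backward cases added above are the chain rule applied to the stored input src0, using the elementwise derivatives

    d/dx expm1(x)    = d/dx (e^x - 1)    = e^x
    d/dx softplus(x) = d/dx log(1 + e^x) = e^x / (1 + e^x) = sigmoid(x)

which is why the incoming gradient is multiplied by wsp_ggml_exp(ctx, src0) and wsp_ggml_sigmoid(ctx, src0) respectively.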
@@ -6880,6 +7121,78 @@ void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph) {
     WSP_GGML_LOG_INFO("========================================\n");
 }
 
+static int wsp_ggml_node_list_find_tensor(const struct wsp_ggml_cgraph * cgraph,
+                                          const int * idxs,
+                                          int count,
+                                          const struct wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(cgraph && idxs);
+    for (int i = 0; i < count; ++i) {
+        const int node_idx = idxs[i];
+
+        if (node_idx >= cgraph->n_nodes) {
+            return -1;
+        }
+        if (cgraph->nodes[node_idx] == tensor) {
+            return i;
+        }
+    }
+    return -1;
+}
+
+bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
+                                    const int * node_idxs,
+                                    int count,
+                                    const enum wsp_ggml_op * ops,
+                                    const int * outputs,
+                                    int num_outputs) {
+    WSP_GGML_ASSERT(outputs && num_outputs > 0);
+
+    for (int i = 0; i < count; ++i) {
+        if (node_idxs[i] >= cgraph->n_nodes) {
+            return false;
+        }
+
+        const struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
+
+        if (node->op != ops[i]) {
+            return false;
+        }
+
+        if (wsp_ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
+            continue;
+        }
+
+        if (node->flags & WSP_GGML_TENSOR_FLAG_OUTPUT) {
+            return false;
+        }
+
+        int subgraph_uses = 0;
+        for (int j = i + 1; j < count; ++j) {
+            const struct wsp_ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
+            for (int src_idx = 0; src_idx < WSP_GGML_MAX_SRC; src_idx++) {
+                if (other_node->src[src_idx] == node) {
+                    subgraph_uses++;
+                }
+            }
+        }
+
+        if (subgraph_uses != wsp_ggml_node_get_use_count(cgraph, node_idxs[i])) {
+            return false;
+        }
+
+        // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
+        struct wsp_ggml_tensor * view_src = node->view_src;
+        while (view_src) {
+            if (wsp_ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
+                return false;
+            }
+            view_src = view_src->view_src;
+        }
+    }
+
+    return true;
+}
+
 // check if node is part of the graph
 static bool wsp_ggml_graph_find(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) {
     if (cgraph == NULL) {
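wsp_ggml_can_fuse_subgraph_ext answers whether a candidate chain of nodes matches an expected op sequence without any intermediate value escaping the subgraph, other than the declared outputs (which, per wsp_ggml_node_list_find_tensor above, are also node indices). A hedged sketch of a call site — the indices and the op pair are hypothetical:

// Can nodes {i, i+1} (a MUL_MAT feeding an ADD) be fused, with the ADD's
// tensor as the only value visible outside the pair?
static bool can_fuse_mul_mat_add(const struct wsp_ggml_cgraph * cgraph, int i) {
    const int              idxs[2] = { i, i + 1 };
    const enum wsp_ggml_op ops[2]  = { WSP_GGML_OP_MUL_MAT, WSP_GGML_OP_ADD };
    const int              outs[1] = { i + 1 };

    return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, 2, ops, outs, 1);
}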
package/cpp/ggml.h CHANGED
@@ -237,9 +237,12 @@
 #define WSP_GGML_EXIT_SUCCESS 0
 #define WSP_GGML_EXIT_ABORTED 1
 
+// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
+#define WSP_GGML_ROPE_TYPE_NORMAL 0
 #define WSP_GGML_ROPE_TYPE_NEOX   2
 #define WSP_GGML_ROPE_TYPE_MROPE  8
 #define WSP_GGML_ROPE_TYPE_VISION 24
+#define WSP_GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 
 #define WSP_GGML_MROPE_SECTIONS 4
 
@@ -472,6 +475,7 @@ extern "C" {
         WSP_GGML_OP_COS,
         WSP_GGML_OP_SUM,
         WSP_GGML_OP_SUM_ROWS,
+        WSP_GGML_OP_CUMSUM,
         WSP_GGML_OP_MEAN,
         WSP_GGML_OP_ARGMAX,
         WSP_GGML_OP_COUNT_EQUAL,
@@ -527,6 +531,8 @@ extern "C" {
         WSP_GGML_OP_TIMESTEP_EMBEDDING,
         WSP_GGML_OP_ARGSORT,
         WSP_GGML_OP_LEAKY_RELU,
+        WSP_GGML_OP_TRI,
+        WSP_GGML_OP_FILL,
 
         WSP_GGML_OP_FLASH_ATTN_EXT,
         WSP_GGML_OP_FLASH_ATTN_BACK,
@@ -539,6 +545,7 @@ extern "C" {
         WSP_GGML_OP_RWKV_WKV6,
         WSP_GGML_OP_GATED_LINEAR_ATTN,
         WSP_GGML_OP_RWKV_WKV7,
+        WSP_GGML_OP_SOLVE_TRI,
 
         WSP_GGML_OP_UNARY,
 
@@ -573,7 +580,14 @@ extern "C" {
         WSP_GGML_UNARY_OP_HARDSWISH,
         WSP_GGML_UNARY_OP_HARDSIGMOID,
         WSP_GGML_UNARY_OP_EXP,
+        WSP_GGML_UNARY_OP_EXPM1,
+        WSP_GGML_UNARY_OP_SOFTPLUS,
         WSP_GGML_UNARY_OP_GELU_ERF,
+        WSP_GGML_UNARY_OP_XIELU,
+        WSP_GGML_UNARY_OP_FLOOR,
+        WSP_GGML_UNARY_OP_CEIL,
+        WSP_GGML_UNARY_OP_ROUND,
+        WSP_GGML_UNARY_OP_TRUNC,
 
         WSP_GGML_UNARY_OP_COUNT,
     };
@@ -612,6 +626,13 @@ extern "C" {
         WSP_GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    enum wsp_ggml_tri_type {
+        WSP_GGML_TRI_TYPE_UPPER_DIAG = 0,
+        WSP_GGML_TRI_TYPE_UPPER      = 1,
+        WSP_GGML_TRI_TYPE_LOWER_DIAG = 2,
+        WSP_GGML_TRI_TYPE_LOWER      = 3
+    };
+
     struct wsp_ggml_init_params {
         // memory pool
         size_t mem_size; // bytes
@@ -949,6 +970,22 @@ extern "C" {
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_expm1(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_expm1_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_softplus(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_softplus_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
@@ -975,6 +1012,10 @@ extern "C" {
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cumsum(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
     // mean along rows
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mean(
             struct wsp_ggml_context * ctx,
@@ -1148,6 +1189,58 @@ extern "C" {
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_floor(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_floor_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ceil(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ceil_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_round(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_round_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    /**
+     * Truncates the fractional part of each element in the tensor (towards zero).
+     * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
+     * Similar to std::trunc in C/C++.
+     */
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_trunc(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_trunc_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+
+
+    // xIELU activation function
+    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
+    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+    // that constrain the positive and negative source alpha values respectively
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_xielu(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            float alpha_n,
+            float alpha_p,
+            float beta,
+            float eps);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
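Spelled out, the formula in the xIELU doc comment above is

    y = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * [x > 0]
    c_a(a) = softplus(a),  c_b(a, b) = softplus(a) + b

and the constructor in the ggml.c hunk earlier pre-computes the constrained coefficients on the host — op param 1 holds beta + softplus(alpha_n) and op param 2 holds softplus(alpha_p), via wsp_ggml_compute_softplus_f32 — so the backend kernels never evaluate softplus themselves.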
@@ -1615,6 +1708,13 @@ extern "C" {
             float scale,
             float max_bias);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * mask,
+            float scale,
+            float max_bias);
+
     WSP_GGML_API void wsp_ggml_soft_max_add_sinks(
             struct wsp_ggml_tensor * a,
             struct wsp_ggml_tensor * sinks);
@@ -2041,6 +2141,7 @@ extern "C" {
     enum wsp_ggml_scale_mode {
         WSP_GGML_SCALE_MODE_NEAREST  = 0,
         WSP_GGML_SCALE_MODE_BILINEAR = 1,
+        WSP_GGML_SCALE_MODE_BICUBIC  = 2,
 
         WSP_GGML_SCALE_MODE_COUNT
     };
@@ -2119,6 +2220,23 @@ extern "C" {
             int shift2,
             int shift3);
 
+    // Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
+    // zeroes everywhere outside the masked area
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_tri(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            enum wsp_ggml_tri_type    type);
+
+    // Fill tensor a with constant c
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_fill(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            float c);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_fill_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            float c);
 
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
@@ -2288,6 +2406,27 @@ extern "C" {
             struct wsp_ggml_tensor  * b,
             struct wsp_ggml_tensor  * state);
 
+    /* Solves a specific equation of the form Ax=B, where A is a triangular matrix
+     * without zeroes on the diagonal (i.e. invertible).
+     * B can have any number of columns, but must have the same number of rows as A
+     * If A is [n, n] and B is [n, m], then the result will be [n, m] as well
+     * Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
+     * where n > 100 sparingly, pre-chunk if necessary.
+     *
+     * If left = false, solves xA=B instead
+     * If lower = false, assumes upper triangular instead
+     * If uni = true, assumes diagonal of A to be all ones (will override actual values)
+     *
+     * TODO: currently only lower, right, non-unitriangular variant is implemented
+     */
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_solve_tri(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b,
+            bool left,
+            bool lower,
+            bool uni);
+
     // custom operators
 
     typedef void (*wsp_ggml_custom1_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int nth, void * userdata);
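A usage sketch for wsp_ggml_solve_tri, restricted to the one variant the assert in the ggml.c hunk currently accepts (left && lower && !uni); the shapes assume ggml's [ne0 = columns, ne1 = rows] layout, and the whole fragment is illustrative rather than code from the package:

// Solve A X = B for X, with A lower triangular and a non-zero diagonal.
static struct wsp_ggml_tensor * solve_lower(struct wsp_ggml_context * ctx) {
    const int n = 16; // A is n x n
    const int m = 4;  // number of right-hand sides (columns of B and X)

    struct wsp_ggml_tensor * A = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n, n);
    struct wsp_ggml_tensor * B = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, m, n);
    // ... fill A (lower triangle and diagonal) and B ...

    // X has the same shape as B
    return wsp_ggml_solve_tri(ctx, A, B, /*left =*/ true, /*lower =*/ true, /*uni =*/ false);
}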