whisper.rn 0.4.0-rc.3 → 0.4.0-rc.4

Files changed (44)
  1. package/android/src/main/CMakeLists.txt +2 -0
  2. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  3. package/android/src/main/java/com/rnwhisper/WhisperContext.java +3 -3
  4. package/android/src/main/jni.cpp +6 -2
  5. package/cpp/ggml-alloc.c +413 -280
  6. package/cpp/ggml-alloc.h +67 -8
  7. package/cpp/ggml-backend-impl.h +87 -0
  8. package/cpp/ggml-backend.c +950 -0
  9. package/cpp/ggml-backend.h +136 -0
  10. package/cpp/ggml-impl.h +243 -0
  11. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
  12. package/cpp/ggml-metal.h +21 -0
  13. package/cpp/ggml-metal.m +623 -234
  14. package/cpp/ggml-quants.c +7377 -0
  15. package/cpp/ggml-quants.h +224 -0
  16. package/cpp/ggml.c +3773 -4455
  17. package/cpp/ggml.h +279 -146
  18. package/cpp/whisper.cpp +182 -103
  19. package/cpp/whisper.h +48 -11
  20. package/ios/RNWhisper.mm +8 -2
  21. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  22. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  23. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  24. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  25. package/ios/RNWhisperContext.h +5 -1
  26. package/ios/RNWhisperContext.mm +76 -10
  27. package/jest/mock.js +1 -1
  28. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  29. package/lib/commonjs/index.js +28 -9
  30. package/lib/commonjs/index.js.map +1 -1
  31. package/lib/commonjs/version.json +1 -1
  32. package/lib/module/NativeRNWhisper.js.map +1 -1
  33. package/lib/module/index.js +28 -9
  34. package/lib/module/index.js.map +1 -1
  35. package/lib/module/version.json +1 -1
  36. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  37. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  38. package/lib/typescript/index.d.ts +7 -2
  39. package/lib/typescript/index.d.ts.map +1 -1
  40. package/package.json +1 -1
  41. package/src/NativeRNWhisper.ts +8 -1
  42. package/src/index.ts +29 -17
  43. package/src/version.json +1 -1
  44. package/whisper-rn.podspec +1 -2
package/cpp/whisper.cpp CHANGED
@@ -120,6 +120,7 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
  //#define WHISPER_USE_FLASH_ATTN
  //#define WHISPER_USE_FLASH_FF
  #define WHISPER_MAX_DECODERS 16
+ #define WHISPER_MAX_NODES 4096
 
  //
  // ggml helpers
@@ -192,6 +193,15 @@ enum e_model {
  MODEL_LARGE,
  };
 
+ static const std::map<e_model, std::string> g_model_name = {
+ { MODEL_UNKNOWN, "unknown" },
+ { MODEL_TINY, "tiny" },
+ { MODEL_BASE, "base" },
+ { MODEL_SMALL, "small" },
+ { MODEL_MEDIUM, "medium" },
+ { MODEL_LARGE, "large" },
+ };
+
  static const std::map<std::string, std::pair<int, std::string>> g_lang = {
  { "en", { 0, "english", } },
  { "zh", { 1, "chinese", } },
@@ -292,6 +302,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
  { "ba", { 96, "bashkir", } },
  { "jw", { 97, "javanese", } },
  { "su", { 98, "sundanese", } },
+ { "yue", { 99, "cantonese", } },
  };
 
  static const size_t MB = 1ull*1024*1024;
@@ -401,7 +412,11 @@ struct whisper_vocab {
  id token_beg = 50363; // begin timestamps
 
  bool is_multilingual() const {
- return n_vocab == 51865;
+ return n_vocab >= 51865;
+ }
+
+ int num_languages() const {
+ return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
  }
  };
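
A quick check of the arithmetic introduced above, using only the constants that appear in this hunk (a sketch, not code shipped in the package): is_multilingual() now also accepts the larger large-v3 vocabulary, and num_languages() derives the language count from n_vocab instead of hard-coding it.

    // num_languages() = n_vocab - 51765 - (is_multilingual() ? 1 : 0)
    static_assert(51865 - 51765 - 1 ==  99, "pre-v3 multilingual vocab -> 99 languages");
    static_assert(51866 - 51765 - 1 == 100, "large-v3 vocab -> 100 languages, including the new \"yue\" entry");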
 
@@ -663,7 +678,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct
  auto & meta = allocr.meta;
  auto & data = allocr.data;
 
- meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
+ meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
 
  alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
 
@@ -735,7 +750,7 @@ struct whisper_state {
 
  int lang_id = 0; // english by default
 
- std::string path_model; // populated by whisper_init_from_file()
+ std::string path_model; // populated by whisper_init_from_file_with_params()
  #ifdef WHISPER_USE_COREML
  whisper_coreml_context * ctx_coreml = nullptr;
  #endif
@@ -769,10 +784,8 @@ struct whisper_context {
  whisper_vocab vocab;
  whisper_state * state = nullptr;
 
- std::string path_model; // populated by whisper_init_from_file()
- #ifdef WHISPER_USE_COREML
- bool load_coreml = true;
- #endif
+ std::string path_model; // populated by whisper_init_from_file_with_params()
+ whisper_context_params params;
  };
 
  static void whisper_default_log(const char * text) {
@@ -923,6 +936,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
  assert(hparams.n_text_state == hparams.n_audio_state);
 
+ std::string mver = "";
+
  if (hparams.n_audio_layer == 4) {
  model.type = e_model::MODEL_TINY;
  }
@@ -941,6 +956,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
  if (hparams.n_audio_layer == 32) {
  model.type = e_model::MODEL_LARGE;
+
+ if (hparams.n_vocab == 51866) {
+ mver = " v3";
+ }
  }
 
  const int32_t qntvr = hparams.ftype / WSP_GGML_QNT_VERSION_FACTOR;
@@ -969,7 +988,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  log("%s: n_mels = %d\n", __func__, hparams.n_mels);
  log("%s: ftype = %d\n", __func__, model.hparams.ftype);
  log("%s: qntvr = %d\n", __func__, qntvr);
- log("%s: type = %d\n", __func__, model.type);
+ log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
 
  // print memory requirements
  {
@@ -1040,13 +1059,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  if (vocab.is_multilingual()) {
  vocab.token_eot++;
  vocab.token_sot++;
- vocab.token_translate++;
- vocab.token_transcribe++;
- vocab.token_solm++;
- vocab.token_prev++;
- vocab.token_nosp++;
- vocab.token_not++;
- vocab.token_beg++;
+
+ // account for variable number of language tokens
+ const int dt = vocab.num_languages() - 98;
+
+ vocab.token_translate += dt;
+ vocab.token_transcribe += dt;
+ vocab.token_solm += dt;
+ vocab.token_prev += dt;
+ vocab.token_nosp += dt;
+ vocab.token_not += dt;
+ vocab.token_beg += dt;
  }
 
  if (n_vocab < model.hparams.n_vocab) {
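
The dt offset above generalizes the old one-by-one increments: the special-token ids sit after the language tokens, so they shift by however many language tokens the vocabulary defines beyond the 98 assumed by the default ids. Worked out with the constants from this diff (sketch only):

    constexpr int dt_pre_v3   =  99 - 98; // == 1, identical to the previous "++" shift
    constexpr int dt_large_v3 = 100 - 98; // == 2, large-v3 carries one extra language token
    static_assert(dt_pre_v3 == 1 && dt_large_v3 == 2, "special-token shift sanity check");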
@@ -1075,6 +1098,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  vocab.id_to_token[i] = word;
  }
  }
+
+ log("%s: n_langs = %d\n", __func__, vocab.num_languages());
  }
 
  size_t ctx_size = 0;
@@ -1619,7 +1644,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
 
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
 
- wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
  wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
@@ -2037,7 +2062,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
 
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
 
- wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
  wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
@@ -2856,8 +2881,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
  }
 
+
  #ifdef WHISPER_USE_COREML
- if (ctx->load_coreml) { // Not in correct layer for easy patch
+ if (ctx->params.use_coreml) {
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
  log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
@@ -2873,7 +2899,7 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch
  } else {
  log("%s: Core ML model loaded\n", __func__);
  }
- }
+ }
  #endif
 
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2934,59 +2960,64 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch
  }
 
  #ifdef WSP_GGML_USE_METAL
- state->ctx_metal = wsp_ggml_metal_init(1);
- if (!state->ctx_metal) {
- log("%s: wsp_ggml_metal_init() failed\n", __func__);
- delete state;
- return nullptr;
+ if (ctx->params.use_gpu) {
+ state->ctx_metal = wsp_ggml_metal_init(1);
+ if (!state->ctx_metal) {
+ log("%s: wsp_ggml_metal_init() failed\n", __func__);
+ delete state;
+ return nullptr;
+ }
  }
 
- log("%s: Metal context initialized\n", __func__);
+ if (state->ctx_metal) {
+ log("%s: Metal context initialized\n", __func__);
 
- // this allocates all Metal resources and memory buffers
+ // this allocates all Metal resources and memory buffers
 
- void * data_ptr = NULL;
- size_t data_size = 0;
+ void * data_ptr = NULL;
+ size_t data_size = 0;
 
- // TODO: add mmap support
- //if (params.use_mmap) {
- // data_ptr = ctx->model.mapping->addr;
- // data_size = ctx->model.mapping->size;
- //} else {
- // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
- // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
- //}
+ // TODO: add mmap support
+ //if (params.use_mmap) {
+ // data_ptr = ctx->model.mapping->addr;
+ // data_size = ctx->model.mapping->size;
+ //} else {
+ // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
+ // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
+ //}
 
- data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
- data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
+ data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
+ data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
 
- const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
+ const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
 
- log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+ log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
 
  #define WHISPER_METAL_CHECK_BUF(result) \
- if (!(result)) { \
- log("%s: failed to add metal buffer\n", __func__); \
- delete state; \
- return nullptr; \
- }
+ if (!(result)) { \
+ log("%s: failed to add metal buffer\n", __func__); \
+ delete state; \
+ return nullptr; \
+ }
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
  #undef WHISPER_METAL_CHECK_BUF
+
+ }
  #endif
 
  state->rng = std::mt19937(0);
@@ -2994,23 +3025,6 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch
  return state;
  }
 
- #ifdef WHISPER_USE_COREML
- struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
- whisper_context * ctx = whisper_init_from_file_no_state(path_model);
- if (!ctx) {
- return nullptr;
- }
- ctx->load_coreml = false;
- ctx->state = whisper_init_state(ctx);
- if (!ctx->state) {
- whisper_free(ctx);
- return nullptr;
- }
-
- return ctx;
- }
- #endif
-
  int whisper_ctx_init_openvino_encoder(
  struct whisper_context * ctx,
  const char * model_path,
@@ -3060,7 +3074,15 @@ int whisper_ctx_init_openvino_encoder(
  #endif
  }
 
- struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
+ struct whisper_context_params whisper_context_default_params() {
+ struct whisper_context_params result = {
+ /*.use_gpu =*/ true,
+ /*.use_coreml =*/ false,
+ };
+ return result;
+ }
+
+ struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
  log("%s: loading model from '%s'\n", __func__, path_model);
 
  auto fin = std::ifstream(path_model, std::ios::binary);
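
whisper_context_default_params() and the *_with_params initializers above are the new way to pass the GPU and Core ML switches into a context. A minimal usage sketch against this API (the model path is a placeholder and whisper_full() is elided):

    #include "whisper.h"

    int main() {
        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.use_gpu    = true;   // gates wsp_ggml_metal_init() inside whisper_init_state()
        cparams.use_coreml = false;  // the Core ML encoder stays opt-in

        struct whisper_context * wctx =
            whisper_init_from_file_with_params("ggml-base.bin", cparams); // placeholder path
        if (wctx == nullptr) {
            return 1;
        }

        // ... run whisper_full(wctx, ...) as before ...

        whisper_free(wctx);
        return 0;
    }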
@@ -3089,7 +3111,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
  fin->close();
  };
 
- auto ctx = whisper_init_no_state(&loader);
+ auto ctx = whisper_init_with_params_no_state(&loader, params);
 
  if (ctx) {
  ctx->path_model = path_model;
@@ -3098,7 +3120,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
  return ctx;
  }
 
- struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
  struct buf_context {
  uint8_t* buffer;
  size_t size;
@@ -3132,13 +3154,14 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
 
  loader.close = [](void * /*ctx*/) { };
 
- return whisper_init_no_state(&loader);
+ return whisper_init_with_params_no_state(&loader, params);
  }
 
- struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
  wsp_ggml_time_init();
 
  whisper_context * ctx = new whisper_context;
+ ctx->params = params;
 
  if (!whisper_model_load(loader, *ctx)) {
  loader->close(loader->context);
@@ -3152,8 +3175,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loa
  return ctx;
  }
 
- struct whisper_context * whisper_init_from_file(const char * path_model) {
- whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+ struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
+ whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
  if (!ctx) {
  return nullptr;
  }
@@ -3167,8 +3190,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
  return ctx;
  }
 
- struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
- whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
+ struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
+ whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
  if (!ctx) {
  return nullptr;
  }
@@ -3182,8 +3205,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
  return ctx;
  }
 
- struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
- whisper_context * ctx = whisper_init_no_state(loader);
+ struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
+ whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
  if (!ctx) {
  return nullptr;
  }
@@ -3197,6 +3220,30 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
  return ctx;
  }
 
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
+ return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
+ }
+
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+ return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
+ }
+
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+ return whisper_init_with_params(loader, whisper_context_default_params());
+ }
+
+ struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
+ return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
+ }
+
+ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+ return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
+ }
+
+ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+ return whisper_init_with_params_no_state(loader, whisper_context_default_params());
+ }
+
  void whisper_free_state(struct whisper_state * state)
  {
  if (state) {
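
The six wrappers above keep the existing entry points source-compatible: each old initializer now forwards to its *_with_params counterpart with whisper_context_default_params(). Equivalence sketch (placeholder path, null-checks kept explicit):

    static bool old_and_new_init_agree(const char * path /* e.g. "ggml-tiny.bin" */) {
        struct whisper_context * a = whisper_init_from_file(path);  // old entry point, now a thin wrapper
        struct whisper_context * b = whisper_init_from_file_with_params(path, whisper_context_default_params());

        const bool ok = (a != nullptr) && (b != nullptr);

        if (a) whisper_free(a);
        if (b) whisper_free(b);
        return ok;
    }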
@@ -3251,6 +3298,12 @@ void whisper_free(struct whisper_context * ctx) {
  }
  }
 
+ void whisper_free_context_params(struct whisper_context_params * params) {
+ if (params) {
+ delete params;
+ }
+ }
+
  void whisper_free_params(struct whisper_full_params * params) {
  if (params) {
  delete params;
@@ -3258,7 +3311,7 @@ void whisper_free_params(struct whisper_full_params * params) {
  }
 
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
  log("%s: failed to compute mel spectrogram\n", __func__);
  return -1;
  }
@@ -3272,7 +3325,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
  log("%s: failed to compute mel spectrogram\n", __func__);
  return -1;
  }
@@ -3295,13 +3348,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float *
  // TODO
 
  int whisper_set_mel_with_state(
- struct whisper_context * /*ctx*/,
+ struct whisper_context * ctx,
  struct whisper_state * state,
  const float * data,
  int n_len,
  int n_mel) {
- if (n_mel != WHISPER_N_MEL) {
- log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+ if (n_mel != ctx->model.filters.n_mel) {
+ log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
  return -1;
  }
 
@@ -3665,6 +3718,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
  }
 
  void whisper_reset_timings(struct whisper_context * ctx) {
+ ctx->t_start_us = wsp_ggml_time_us();
  if (ctx->state != nullptr) {
  ctx->state->t_sample_us = 0;
  ctx->state->t_encode_us = 0;
@@ -3719,6 +3773,14 @@ const char * whisper_print_system_info(void) {
 
  ////////////////////////////////////////////////////////////////////////////
 
+ struct whisper_context_params * whisper_context_default_params_by_ref() {
+ struct whisper_context_params params = whisper_context_default_params();
+
+ struct whisper_context_params* result = new whisper_context_params();
+ *result = params;
+ return result;
+ }
+
  struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
  struct whisper_full_params params = whisper_full_default_params(strategy);
 
@@ -3795,8 +3857,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
  /*.encoder_begin_callback =*/ nullptr,
  /*.encoder_begin_callback_user_data =*/ nullptr,
 
- /*.abort_callback =*/ nullptr,
- /*.abort_callback_user_data =*/ nullptr,
+ /*.abort_callback =*/ nullptr,
+ /*.abort_callback_user_data =*/ nullptr,
 
  /*.logits_filter_callback =*/ nullptr,
  /*.logits_filter_callback_user_data =*/ nullptr,
@@ -3964,6 +4026,7 @@ static void whisper_process_logits(
  // suppress task tokens
  logits[vocab.token_translate] = -INFINITY;
  logits[vocab.token_transcribe] = -INFINITY;
+ logits[vocab.token_prev] = -INFINITY;
 
  if (params.logits_filter_callback) {
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4530,17 +4593,19 @@ int whisper_full_with_state(
 
  // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
  #ifdef WSP_GGML_USE_METAL
+ if (state->ctx_metal) {
  #define WHISPER_METAL_CHECK_BUF(result) \
- if (!(result)) { \
- log("%s: failed to add metal buffer\n", __func__); \
- return 0; \
- }
+ if (!(result)) { \
+ log("%s: failed to add metal buffer\n", __func__); \
+ return 0; \
+ }
 
- const std::string kv_name = "kv_self_" + std::to_string(j);
- auto & kv_self = decoder.kv_self;
+ const std::string kv_name = "kv_self_" + std::to_string(j);
+ auto & kv_self = decoder.kv_self;
 
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
  #undef WHISPER_METAL_CHECK_BUF
+ }
  #endif
  }
  }
@@ -4557,7 +4622,7 @@ int whisper_full_with_state(
 
  // initial prompt
  if (!params.prompt_tokens && params.initial_prompt) {
- prompt_tokens.resize(2048);
+ prompt_tokens.resize(1024);
  prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
  params.prompt_tokens = prompt_tokens.data();
  params.prompt_n_tokens = prompt_tokens.size();
@@ -4582,6 +4647,7 @@ int whisper_full_with_state(
 
  // these tokens determine the task that will be performed
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+
  if (whisper_is_multilingual(ctx)) {
  const int lang_id = whisper_lang_id(params.language);
  state->lang_id = lang_id;
@@ -4593,6 +4659,17 @@ int whisper_full_with_state(
  }
  }
 
+ {
+ const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+
+ // distilled models require the "no_timestamps" token
+ // TODO: add input parameter (#1229)
+ if (is_distil) {
+ log("%s: using distilled model - forcing no_timestamps\n", __func__);
+ prompt_init.push_back(whisper_token_not(ctx));
+ }
+ }
+
  int seek = seek_start;
 
  std::vector<whisper_token> prompt;
@@ -5454,7 +5531,7 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
  // b: N*N*sizeof(float)
  // c: N*N*sizeof(float)
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
- std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());
  std::vector<uint8_t> work;
 
  // put a bunch of random data in the buffer
@@ -5505,17 +5582,19 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
 
  struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b);
 
- struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
+ struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
+
+ wsp_ggml_build_forward_expand(gf, c);
 
  double tsum = 0.0;
 
  // heat-up
- wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
+ wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
 
  for (int i = 0; i < n_max; ++i) {
  const int64_t t0 = wsp_ggml_time_us();
 
- wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
+ wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
 
  const int64_t t1 = wsp_ggml_time_us();