whisper.rn 0.3.0-rc.7 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/whisper.cpp CHANGED
@@ -1,8 +1,12 @@
1
1
  #include "whisper.h"
2
- #if WHISPER_USE_COREML
2
+ #ifdef WHISPER_USE_COREML
3
3
  #include "coreml/whisper-encoder.h"
4
4
  #endif
5
5
 
6
+ #if WHISPER_USE_OPENVINO
7
+ #include "openvino/whisper-openvino-encoder.h"
8
+ #endif
9
+
6
10
  #include "ggml.h"
7
11
 
8
12
  #include <algorithm>
@@ -10,6 +14,7 @@
10
14
  #define _USE_MATH_DEFINES
11
15
  #include <cmath>
12
16
  #include <cstdio>
17
+ #include <cstdarg>
13
18
  #include <cstring>
14
19
  #include <fstream>
15
20
  #include <map>
@@ -19,6 +24,10 @@
19
24
  #include <regex>
20
25
  #include <random>
21
26
 
27
+ #if defined(_MSC_VER)
28
+ #pragma warning(disable: 4244 4267) // possible loss of data
29
+ #endif
30
+
22
31
  #if defined(GGML_BIG_ENDIAN)
23
32
  #include <bit>
24
33
 
@@ -84,7 +93,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
84
93
  #define WHISPER_ASSERT(x) \
85
94
  do { \
86
95
  if (!(x)) { \
87
- fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
96
+ log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
88
97
  abort(); \
89
98
  } \
90
99
  } while (0)
@@ -376,16 +385,18 @@ struct whisper_vocab {
376
385
  std::map<token, id> token_to_id;
377
386
  std::map<id, token> id_to_token;
378
387
 
379
- id token_eot = 50256;
380
- id token_sot = 50257;
381
- id token_prev = 50360;
382
- id token_solm = 50361; // ??
383
- id token_not = 50362; // no timestamps
384
- id token_beg = 50363;
385
-
386
- // available tasks
387
- static const id token_translate = 50358;
388
- static const id token_transcribe = 50359;
388
+ // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
389
+ id token_eot = 50256;
390
+ id token_sot = 50257;
391
+ // task tokens (used only for multilingual models)
392
+ id token_translate = 50357;
393
+ id token_transcribe = 50358;
394
+ // other special tokens
395
+ id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
396
+ id token_prev = 50360;
397
+ id token_nosp = 50361;
398
+ id token_not = 50362; // no timestamps
399
+ id token_beg = 50363; // begin timestamps
389
400
 
390
401
  bool is_multilingual() const {
391
402
  return n_vocab == 51865;
@@ -399,6 +410,8 @@ struct whisper_segment {
399
410
  std::string text;
400
411
 
401
412
  std::vector<whisper_token_data> tokens;
413
+
414
+ bool speaker_turn_next;
402
415
  };
403
416
 
404
417
  // medium
@@ -652,6 +665,10 @@ struct whisper_state {
652
665
  whisper_coreml_context * ctx_coreml = nullptr;
653
666
  #endif
654
667
 
668
+ #ifdef WHISPER_USE_OPENVINO
669
+ whisper_openvino_context * ctx_openvino = nullptr;
670
+ #endif
671
+
655
672
  // [EXPERIMENTAL] token-level timestamps data
656
673
  int64_t t_beg = 0;
657
674
  int64_t t_last = 0;
@@ -707,6 +724,21 @@ struct whisper_context {
707
724
  std::string path_model; // populated by whisper_init_from_file()
708
725
  };
709
726
 
727
+ static void whisper_default_log(const char * text) {
728
+ fprintf(stderr, "%s", text);
729
+ }
730
+
731
+ static whisper_log_callback whisper_log = whisper_default_log;
732
+
733
+ static void log(const char * fmt, ...) {
734
+ if (!whisper_log) return;
735
+ char buf[1024];
736
+ va_list args;
737
+ va_start(args, fmt);
738
+ vsnprintf(buf, sizeof(buf), fmt, args);
739
+ whisper_log(buf);
740
+ }
741
+
710
742
  template<typename T>
711
743
  static void read_safe(whisper_model_loader * loader, T & dest) {
712
744
  loader->read(loader->context, &dest, sizeof(T));
@@ -730,7 +762,7 @@ static bool kv_cache_init(
730
762
  cache.ctx = ggml_init(params);
731
763
 
732
764
  if (!cache.ctx) {
733
- fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
765
+ log("%s: failed to allocate memory for kv cache\n", __func__);
734
766
  return false;
735
767
  }
736
768
 
@@ -766,7 +798,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
766
798
  cache.ctx = ggml_init(params);
767
799
 
768
800
  if (!cache.ctx) {
769
- fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
801
+ log("%s: failed to allocate memory for kv cache\n", __func__);
770
802
  return false;
771
803
  }
772
804
 
@@ -795,7 +827,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
795
827
  // see the convert-pt-to-ggml.py script for details
796
828
  //
797
829
  static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
798
- fprintf(stderr, "%s: loading model\n", __func__);
830
+ log("%s: loading model\n", __func__);
799
831
 
800
832
  const int64_t t_start_us = ggml_time_us();
801
833
 
@@ -808,8 +840,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
808
840
  {
809
841
  uint32_t magic;
810
842
  read_safe(loader, magic);
811
- if (magic != 0x67676d6c) {
812
- fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
843
+ if (magic != GGML_FILE_MAGIC) {
844
+ log("%s: invalid model data (bad magic)\n", __func__);
813
845
  return false;
814
846
  }
815
847
  }
@@ -860,25 +892,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
860
892
  // in order to save memory and also to speed up the computation
861
893
  wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
862
894
  if (wctx.wtype == GGML_TYPE_COUNT) {
863
- fprintf(stderr, "%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
895
+ log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
864
896
  return false;
865
897
  }
866
898
 
867
899
  const size_t scale = model.hparams.ftype ? 1 : 2;
868
900
 
869
- fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
870
- fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
871
- fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
872
- fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
873
- fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
874
- fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
875
- fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
876
- fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
877
- fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
878
- fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
879
- fprintf(stderr, "%s: ftype = %d\n", __func__, model.hparams.ftype);
880
- fprintf(stderr, "%s: qntvr = %d\n", __func__, qntvr);
881
- fprintf(stderr, "%s: type = %d\n", __func__, model.type);
901
+ log("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
902
+ log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
903
+ log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
904
+ log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
905
+ log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
906
+ log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
907
+ log("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
908
+ log("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
909
+ log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
910
+ log("%s: n_mels = %d\n", __func__, hparams.n_mels);
911
+ log("%s: ftype = %d\n", __func__, model.hparams.ftype);
912
+ log("%s: qntvr = %d\n", __func__, qntvr);
913
+ log("%s: type = %d\n", __func__, model.type);
882
914
 
883
915
  // print memory requirements
884
916
  {
@@ -896,7 +928,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
896
928
  const size_t mem_required_decoder =
897
929
  scale*MEM_REQ_KV_SELF.at(model.type);
898
930
 
899
- fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
931
+ log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
900
932
  mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
901
933
  }
902
934
 
@@ -928,7 +960,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
928
960
  read_safe(loader, n_vocab);
929
961
 
930
962
  //if (n_vocab != model.hparams.n_vocab) {
931
- // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
963
+ // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
932
964
  // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
933
965
  // return false;
934
966
  //}
@@ -948,7 +980,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
948
980
  word.assign(&tmp[0], tmp.size());
949
981
  } else {
950
982
  // seems like we have an empty-string token in multi-language models (i = 50256)
951
- //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
983
+ //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
952
984
  word = "";
953
985
  }
954
986
 
@@ -962,14 +994,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
962
994
  if (vocab.is_multilingual()) {
963
995
  vocab.token_eot++;
964
996
  vocab.token_sot++;
965
- vocab.token_prev++;
997
+ vocab.token_translate++;
998
+ vocab.token_transcribe++;
966
999
  vocab.token_solm++;
1000
+ vocab.token_prev++;
1001
+ vocab.token_nosp++;
967
1002
  vocab.token_not++;
968
1003
  vocab.token_beg++;
969
1004
  }
970
1005
 
971
1006
  if (n_vocab < model.hparams.n_vocab) {
972
- fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
1007
+ log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
973
1008
  for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
974
1009
  if (i > vocab.token_beg) {
975
1010
  word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -977,8 +1012,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
977
1012
  word = "[_EOT_]";
978
1013
  } else if (i == vocab.token_sot) {
979
1014
  word = "[_SOT_]";
1015
+ } else if (i == vocab.token_solm) {
1016
+ word = "[_SOLM_]";
980
1017
  } else if (i == vocab.token_prev) {
981
1018
  word = "[_PREV_]";
1019
+ } else if (i == vocab.token_nosp) {
1020
+ word = "[_NOSP_]";
982
1021
  } else if (i == vocab.token_not) {
983
1022
  word = "[_NOT_]";
984
1023
  } else if (i == vocab.token_beg) {
@@ -1104,7 +1143,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1104
1143
 
1105
1144
  ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
1106
1145
 
1107
- fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
1146
+ log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
1108
1147
  }
1109
1148
 
1110
1149
  // create the ggml context
@@ -1117,7 +1156,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1117
1156
 
1118
1157
  model.ctx = ggml_init(params);
1119
1158
  if (!model.ctx) {
1120
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
1159
+ log("%s: ggml_init() failed\n", __func__);
1121
1160
  return false;
1122
1161
  }
1123
1162
  }
@@ -1350,20 +1389,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1350
1389
  name.assign(&tmp[0], tmp.size());
1351
1390
 
1352
1391
  if (model.tensors.find(name) == model.tensors.end()) {
1353
- fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
1392
+ log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
1354
1393
  return false;
1355
1394
  }
1356
1395
 
1357
1396
  auto tensor = model.tensors[name.data()];
1358
1397
  if (ggml_nelements(tensor) != nelements) {
1359
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1360
- fprintf(stderr, "%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1398
+ log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1399
+ log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1361
1400
  __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1362
1401
  return false;
1363
1402
  }
1364
1403
 
1365
1404
  if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1366
- fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1405
+ log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1367
1406
  __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1368
1407
  return false;
1369
1408
  }
@@ -1371,7 +1410,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1371
1410
  const size_t bpe = ggml_type_size(ggml_type(ttype));
1372
1411
 
1373
1412
  if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1374
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1413
+ log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1375
1414
  __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1376
1415
  return false;
1377
1416
  }
@@ -1384,12 +1423,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1384
1423
  model.n_loaded++;
1385
1424
  }
1386
1425
 
1387
- fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
1426
+ log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
1388
1427
 
1389
1428
  if (model.n_loaded == 0) {
1390
- fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1429
+ log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1391
1430
  } else if (model.n_loaded != (int) model.tensors.size()) {
1392
- fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
1431
+ log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
1393
1432
  return false;
1394
1433
  }
1395
1434
  }
@@ -1463,12 +1502,18 @@ static bool whisper_encode_internal(
1463
1502
  const bool use_coreml = wstate.ctx_coreml != nullptr;
1464
1503
  #endif
1465
1504
 
1466
- if (!use_coreml) {
1505
+ #ifndef WHISPER_USE_OPENVINO
1506
+ const bool use_openvino = false;
1507
+ #else
1508
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
1509
+ #endif
1510
+
1511
+ if (!use_coreml && !use_openvino) {
1467
1512
  // convolution + gelu
1468
1513
  {
1469
1514
  wstate.use_buf(ctx0, 1);
1470
1515
 
1471
- cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1516
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1472
1517
  cur = ggml_add(ctx0,
1473
1518
  ggml_repeat(ctx0,
1474
1519
  model.e_conv_1_b,
@@ -1479,7 +1524,7 @@ static bool whisper_encode_internal(
1479
1524
 
1480
1525
  wstate.use_buf(ctx0, 0);
1481
1526
 
1482
- cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1527
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1483
1528
  cur = ggml_add(ctx0,
1484
1529
  ggml_repeat(ctx0,
1485
1530
  model.e_conv_2_b,
@@ -1762,8 +1807,7 @@ static bool whisper_encode_internal(
1762
1807
  }
1763
1808
  }
1764
1809
  #ifdef WHISPER_USE_COREML
1765
- else
1766
- {
1810
+ else if (use_coreml) {
1767
1811
  wstate.use_buf(ctx0, -1);
1768
1812
 
1769
1813
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
@@ -1771,6 +1815,17 @@ static bool whisper_encode_internal(
1771
1815
  whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1772
1816
  }
1773
1817
  #endif
1818
+ #ifdef WHISPER_USE_OPENVINO
1819
+ else if (use_openvino) {
1820
+ wstate.use_buf(ctx0, -1);
1821
+
1822
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1823
+
1824
+ if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
1825
+ return false;
1826
+ }
1827
+ }
1828
+ #endif
1774
1829
 
1775
1830
  // cur
1776
1831
  //{
@@ -2577,7 +2632,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2577
2632
  --j;
2578
2633
  }
2579
2634
  if (!found) {
2580
- fprintf(stderr, "unknown token \n");
2635
+ log("unknown token\n");
2581
2636
  ++i;
2582
2637
  }
2583
2638
  }
@@ -2613,47 +2668,72 @@ static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
2613
2668
  }
2614
2669
  #endif
2615
2670
 
2671
+ #ifdef WHISPER_USE_OPENVINO
2672
+ // replace .bin with-encoder-openvino.xml
2673
+ static std::string whisper_openvino_get_path_encoder(std::string path_bin) {
2674
+ auto pos = path_bin.rfind('.');
2675
+ if (pos != std::string::npos) {
2676
+ path_bin = path_bin.substr(0, pos);
2677
+ }
2678
+
2679
+ path_bin += "-encoder-openvino.xml";
2680
+
2681
+ return path_bin;
2682
+ }
2683
+
2684
+ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
2685
+ auto pos = path_bin.rfind('.');
2686
+ if (pos != std::string::npos) {
2687
+ path_bin = path_bin.substr(0, pos);
2688
+ }
2689
+
2690
+ path_bin += "-encoder-openvino-cache";
2691
+
2692
+ return path_bin;
2693
+ }
2694
+ #endif
2695
+
2616
2696
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2617
2697
  whisper_state * state = new whisper_state;
2618
2698
 
2619
2699
  const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
2620
2700
 
2621
2701
  if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2622
- fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2702
+ log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2623
2703
  delete state;
2624
2704
  return nullptr;
2625
2705
  }
2626
2706
 
2627
2707
  {
2628
2708
  const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2629
- fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2709
+ log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2630
2710
  }
2631
2711
 
2632
2712
  if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2633
- fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2713
+ log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2634
2714
  delete state;
2635
2715
  return nullptr;
2636
2716
  }
2637
2717
 
2638
2718
  {
2639
2719
  const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
2640
- fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2720
+ log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2641
2721
  }
2642
2722
 
2643
2723
  #ifdef WHISPER_USE_COREML
2644
2724
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2645
2725
 
2646
- fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2647
- fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
2726
+ log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2727
+ log("%s: first run on a device may take a while ...\n", __func__);
2648
2728
 
2649
2729
  state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2650
2730
  if (!state->ctx_coreml) {
2651
- fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2731
+ log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2652
2732
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2653
2733
  return nullptr;
2654
2734
  #endif
2655
2735
  } else {
2656
- fprintf(stderr, "%s: Core ML model loaded\n", __func__);
2736
+ log("%s: Core ML model loaded\n", __func__);
2657
2737
  }
2658
2738
  #endif
2659
2739
 
@@ -2679,13 +2759,62 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2679
2759
  return state;
2680
2760
  }
2681
2761
 
2762
+ int whisper_ctx_init_openvino_encoder(
2763
+ struct whisper_context * ctx,
2764
+ const char * model_path,
2765
+ const char * device,
2766
+ const char * cache_dir) {
2767
+ #ifndef WHISPER_USE_OPENVINO
2768
+ (void)(ctx);
2769
+ (void)(model_path);
2770
+ (void)(device);
2771
+ (void)(cache_dir);
2772
+
2773
+ return 1;
2774
+ #else
2775
+ if (!model_path && ctx->path_model.empty()) {
2776
+ log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
2777
+ return 1;
2778
+ }
2779
+
2780
+ std::string path_encoder;
2781
+ if (!model_path) {
2782
+ //if model_path is not set, attempt to find it in the same directory as ggml-<model>.bin model
2783
+ path_encoder = whisper_openvino_get_path_encoder(ctx->path_model);
2784
+ } else {
2785
+ path_encoder = model_path;
2786
+ }
2787
+
2788
+ std::string path_cache;
2789
+ if (!cache_dir) {
2790
+ //if cache_dir is not set, set it as a dir residing next to ggml-<model>.bin
2791
+ path_cache = whisper_openvino_get_path_cache(ctx->path_model);
2792
+ } else {
2793
+ path_cache = cache_dir;
2794
+ }
2795
+
2796
+ log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
2797
+ log("%s: first run on a device may take a while ...\n", __func__);
2798
+
2799
+ ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
2800
+ if (!ctx->state->ctx_openvino) {
2801
+ log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
2802
+ return 1;
2803
+ } else {
2804
+ log("%s: OpenVINO model loaded\n", __func__);
2805
+ }
2806
+
2807
+ return 0;
2808
+ #endif
2809
+ }
2810
+
2682
2811
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2683
2812
 
2684
- fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
2813
+ log("%s: loading model from '%s'\n", __func__, path_model);
2685
2814
 
2686
2815
  auto fin = std::ifstream(path_model, std::ios::binary);
2687
2816
  if (!fin) {
2688
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
2817
+ log("%s: failed to open '%s'\n", __func__, path_model);
2689
2818
  return nullptr;
2690
2819
  }
2691
2820
 
@@ -2727,7 +2856,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
2727
2856
 
2728
2857
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
2729
2858
 
2730
- fprintf(stderr, "%s: loading model from buffer\n", __func__);
2859
+ log("%s: loading model from buffer\n", __func__);
2731
2860
 
2732
2861
  whisper_model_loader loader = {};
2733
2862
 
@@ -2762,7 +2891,7 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
2762
2891
 
2763
2892
  if (!whisper_model_load(loader, *ctx)) {
2764
2893
  loader->close(loader->context);
2765
- fprintf(stderr, "%s: failed to load model\n", __func__);
2894
+ log("%s: failed to load model\n", __func__);
2766
2895
  delete ctx;
2767
2896
  return nullptr;
2768
2897
  }
@@ -2833,6 +2962,13 @@ void whisper_free_state(struct whisper_state * state)
2833
2962
  }
2834
2963
  #endif
2835
2964
 
2965
+ #ifdef WHISPER_USE_OPENVINO
2966
+ if (state->ctx_openvino != nullptr) {
2967
+ whisper_openvino_free(state->ctx_openvino);
2968
+ state->ctx_openvino = nullptr;
2969
+ }
2970
+ #endif
2971
+
2836
2972
  delete state;
2837
2973
  }
2838
2974
  }
@@ -2860,7 +2996,7 @@ void whisper_free_params(struct whisper_full_params * params) {
2860
2996
 
2861
2997
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2862
2998
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
2863
- fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2999
+ log("%s: failed to compute mel spectrogram\n", __func__);
2864
3000
  return -1;
2865
3001
  }
2866
3002
 
@@ -2874,7 +3010,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
2874
3010
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2875
3011
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2876
3012
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
2877
- fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
3013
+ log("%s: failed to compute mel spectrogram\n", __func__);
2878
3014
  return -1;
2879
3015
  }
2880
3016
 
@@ -2893,7 +3029,7 @@ int whisper_set_mel_with_state(
2893
3029
  int n_len,
2894
3030
  int n_mel) {
2895
3031
  if (n_mel != WHISPER_N_MEL) {
2896
- fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
3032
+ log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2897
3033
  return -1;
2898
3034
  }
2899
3035
 
@@ -2917,7 +3053,7 @@ int whisper_set_mel(
2917
3053
 
2918
3054
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
2919
3055
  if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
2920
- fprintf(stderr, "%s: failed to eval\n", __func__);
3056
+ log("%s: failed to eval\n", __func__);
2921
3057
  return -1;
2922
3058
  }
2923
3059
 
@@ -2926,7 +3062,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
2926
3062
 
2927
3063
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2928
3064
  if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
2929
- fprintf(stderr, "%s: failed to eval\n", __func__);
3065
+ log("%s: failed to eval\n", __func__);
2930
3066
  return -1;
2931
3067
  }
2932
3068
 
@@ -2937,7 +3073,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
2937
3073
  const int selected_decoder_id = 0;
2938
3074
 
2939
3075
  if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2940
- fprintf(stderr, "%s: failed to eval\n", __func__);
3076
+ log("%s: failed to eval\n", __func__);
2941
3077
  return 1;
2942
3078
  }
2943
3079
 
@@ -2949,13 +3085,13 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
2949
3085
  const int selected_decoder_id = 0;
2950
3086
 
2951
3087
  if (ctx->state == nullptr) {
2952
- fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
3088
+ log("%s: ERROR state was not loaded.\n", __func__);
2953
3089
  return false;
2954
3090
  }
2955
3091
 
2956
3092
 
2957
3093
  if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2958
- fprintf(stderr, "%s: failed to eval\n", __func__);
3094
+ log("%s: failed to eval\n", __func__);
2959
3095
  return 1;
2960
3096
  }
2961
3097
 
@@ -2966,7 +3102,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
2966
3102
  const auto res = tokenize(ctx->vocab, text);
2967
3103
 
2968
3104
  if (n_max_tokens < (int) res.size()) {
2969
- fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3105
+ log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
2970
3106
  return -1;
2971
3107
  }
2972
3108
 
@@ -2994,7 +3130,7 @@ int whisper_lang_id(const char * lang) {
2994
3130
  }
2995
3131
  }
2996
3132
 
2997
- fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
3133
+ log("%s: unknown language '%s'\n", __func__, lang);
2998
3134
  return -1;
2999
3135
  }
3000
3136
  return g_lang.at(lang).first;
@@ -3007,7 +3143,7 @@ const char * whisper_lang_str(int id) {
3007
3143
  }
3008
3144
  }
3009
3145
 
3010
- fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
3146
+ log("%s: unknown language id %d\n", __func__, id);
3011
3147
  return nullptr;
3012
3148
  }
3013
3149
 
@@ -3020,25 +3156,25 @@ int whisper_lang_auto_detect_with_state(
3020
3156
  const int seek = offset_ms/10;
3021
3157
 
3022
3158
  if (seek < 0) {
3023
- fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3159
+ log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3024
3160
  return -1;
3025
3161
  }
3026
3162
 
3027
3163
  if (seek >= state->mel.n_len_org) {
3028
- fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3164
+ log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3029
3165
  return -2;
3030
3166
  }
3031
3167
 
3032
3168
  // run the encoder
3033
3169
  if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
3034
- fprintf(stderr, "%s: failed to encode\n", __func__);
3170
+ log("%s: failed to encode\n", __func__);
3035
3171
  return -6;
3036
3172
  }
3037
3173
 
3038
3174
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
3039
3175
 
3040
3176
  if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
3041
- fprintf(stderr, "%s: failed to decode\n", __func__);
3177
+ log("%s: failed to decode\n", __func__);
3042
3178
  return -7;
3043
3179
  }
3044
3180
 
@@ -3204,12 +3340,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) {
3204
3340
  return ctx->vocab.token_sot;
3205
3341
  }
3206
3342
 
3343
+ whisper_token whisper_token_solm(struct whisper_context * ctx) {
3344
+ return ctx->vocab.token_solm;
3345
+ }
3346
+
3207
3347
  whisper_token whisper_token_prev(struct whisper_context * ctx) {
3208
3348
  return ctx->vocab.token_prev;
3209
3349
  }
3210
3350
 
3211
- whisper_token whisper_token_solm(struct whisper_context * ctx) {
3212
- return ctx->vocab.token_solm;
3351
+ whisper_token whisper_token_nosp(struct whisper_context * ctx) {
3352
+ return ctx->vocab.token_nosp;
3213
3353
  }
3214
3354
 
3215
3355
  whisper_token whisper_token_not(struct whisper_context * ctx) {
@@ -3224,32 +3364,32 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
3224
3364
  return whisper_token_sot(ctx) + 1 + lang_id;
3225
3365
  }
3226
3366
 
3227
- whisper_token whisper_token_translate(void) {
3228
- return whisper_vocab::token_translate;
3367
+ whisper_token whisper_token_translate(struct whisper_context * ctx) {
3368
+ return ctx->vocab.token_translate;
3229
3369
  }
3230
3370
 
3231
- whisper_token whisper_token_transcribe(void) {
3232
- return whisper_vocab::token_transcribe;
3371
+ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3372
+ return ctx->vocab.token_transcribe;
3233
3373
  }
3234
3374
 
3235
3375
  void whisper_print_timings(struct whisper_context * ctx) {
3236
3376
  const int64_t t_end_us = ggml_time_us();
3237
3377
 
3238
- fprintf(stderr, "\n");
3239
- fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3378
+ log("\n");
3379
+ log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3240
3380
  if (ctx->state != nullptr) {
3241
3381
 
3242
3382
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3243
3383
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3244
3384
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3245
3385
 
3246
- fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3247
- fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3248
- fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3249
- fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3250
- fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3386
+ log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3387
+ log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3388
+ log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3389
+ log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3390
+ log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3251
3391
  }
3252
- fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3392
+ log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3253
3393
  }
3254
3394
 
3255
3395
  void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3268,6 +3408,14 @@ static int whisper_has_coreml(void) {
3268
3408
  #endif
3269
3409
  }
3270
3410
 
3411
+ static int whisper_has_openvino(void) {
3412
+ #ifdef WHISPER_USE_OPENVINO
3413
+ return 1;
3414
+ #else
3415
+ return 0;
3416
+ #endif
3417
+ }
3418
+
3271
3419
  const char * whisper_print_system_info(void) {
3272
3420
  static std::string s;
3273
3421
 
@@ -3285,6 +3433,7 @@ const char * whisper_print_system_info(void) {
3285
3433
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
3286
3434
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
3287
3435
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3436
+ s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3288
3437
 
3289
3438
  return s.c_str();
3290
3439
  }
@@ -3301,51 +3450,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam
3301
3450
 
3302
3451
  struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
3303
3452
  struct whisper_full_params result = {
3304
- /*.strategy =*/ strategy,
3305
-
3306
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
3307
- /*.n_max_text_ctx =*/ 16384,
3308
- /*.offset_ms =*/ 0,
3309
- /*.duration_ms =*/ 0,
3310
-
3311
- /*.translate =*/ false,
3312
- /*.no_context =*/ true,
3313
- /*.single_segment =*/ false,
3314
- /*.print_special =*/ false,
3315
- /*.print_progress =*/ true,
3316
- /*.print_realtime =*/ false,
3317
- /*.print_timestamps =*/ true,
3318
-
3319
- /*.token_timestamps =*/ false,
3320
- /*.thold_pt =*/ 0.01f,
3321
- /*.thold_ptsum =*/ 0.01f,
3322
- /*.max_len =*/ 0,
3323
- /*.split_on_word =*/ false,
3324
- /*.max_tokens =*/ 0,
3325
-
3326
- /*.speed_up =*/ false,
3327
- /*.audio_ctx =*/ 0,
3328
-
3329
- /*.initial_prompt =*/ nullptr,
3330
- /*.prompt_tokens =*/ nullptr,
3331
- /*.prompt_n_tokens =*/ 0,
3332
-
3333
- /*.language =*/ "en",
3334
- /*.detect_language =*/ false,
3335
-
3336
- /*.suppress_blank =*/ true,
3453
+ /*.strategy =*/ strategy,
3454
+
3455
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
3456
+ /*.n_max_text_ctx =*/ 16384,
3457
+ /*.offset_ms =*/ 0,
3458
+ /*.duration_ms =*/ 0,
3459
+
3460
+ /*.translate =*/ false,
3461
+ /*.no_context =*/ true,
3462
+ /*.single_segment =*/ false,
3463
+ /*.print_special =*/ false,
3464
+ /*.print_progress =*/ true,
3465
+ /*.print_realtime =*/ false,
3466
+ /*.print_timestamps =*/ true,
3467
+
3468
+ /*.token_timestamps =*/ false,
3469
+ /*.thold_pt =*/ 0.01f,
3470
+ /*.thold_ptsum =*/ 0.01f,
3471
+ /*.max_len =*/ 0,
3472
+ /*.split_on_word =*/ false,
3473
+ /*.max_tokens =*/ 0,
3474
+
3475
+ /*.speed_up =*/ false,
3476
+ /*.audio_ctx =*/ 0,
3477
+
3478
+ /*.tdrz_enable =*/ false,
3479
+
3480
+ /*.initial_prompt =*/ nullptr,
3481
+ /*.prompt_tokens =*/ nullptr,
3482
+ /*.prompt_n_tokens =*/ 0,
3483
+
3484
+ /*.language =*/ "en",
3485
+ /*.detect_language =*/ false,
3486
+
3487
+ /*.suppress_blank =*/ true,
3337
3488
  /*.suppress_non_speech_tokens =*/ false,
3338
3489
 
3339
- /*.temperature =*/ 0.0f,
3340
- /*.max_initial_ts =*/ 1.0f,
3341
- /*.length_penalty =*/ -1.0f,
3490
+ /*.temperature =*/ 0.0f,
3491
+ /*.max_initial_ts =*/ 1.0f,
3492
+ /*.length_penalty =*/ -1.0f,
3342
3493
 
3343
- /*.temperature_inc =*/ 0.4f,
3344
- /*.entropy_thold =*/ 2.4f,
3345
- /*.logprob_thold =*/ -1.0f,
3346
- /*.no_speech_thold =*/ 0.6f,
3494
+ /*.temperature_inc =*/ 0.4f,
3495
+ /*.entropy_thold =*/ 2.4f,
3496
+ /*.logprob_thold =*/ -1.0f,
3497
+ /*.no_speech_thold =*/ 0.6f,
3347
3498
 
3348
- /*.greedy =*/ {
3499
+ /*.greedy =*/ {
3349
3500
  /*.best_of =*/ -1,
3350
3501
  },
3351
3502
 
@@ -3397,26 +3548,6 @@ static void whisper_exp_compute_token_level_timestamps(
3397
3548
  float thold_pt,
3398
3549
  float thold_ptsum);
3399
3550
 
3400
- // trim from start (in place)
3401
- static inline void ltrim(std::string &s) {
3402
- s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
3403
- return std::isspace(ch);
3404
- }));
3405
- }
3406
-
3407
- // trim from end (in place)
3408
- static inline void rtrim(std::string &s) {
3409
- s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
3410
- return std::isspace(ch);
3411
- }).base(), s.end());
3412
- }
3413
-
3414
- // trim from both ends (in place)
3415
- static inline void trim(std::string &s) {
3416
- rtrim(s);
3417
- ltrim(s);
3418
- }
3419
-
3420
3551
  static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3421
3552
  if (!split_on_word) return true;
3422
3553
 
@@ -3443,14 +3574,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
3443
3574
  const int cur = strlen(txt);
3444
3575
 
3445
3576
  if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
3446
- // split here
3447
- if (split_on_word) {
3448
- trim(text);
3449
- }
3450
-
3451
3577
  state.result_all.back().text = std::move(text);
3452
3578
  state.result_all.back().t1 = token.t0;
3453
3579
  state.result_all.back().tokens.resize(i);
3580
+ state.result_all.back().speaker_turn_next = false;
3454
3581
 
3455
3582
  state.result_all.push_back({});
3456
3583
  state.result_all.back().t0 = token.t0;
@@ -3462,6 +3589,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
3462
3589
  segment.tokens.begin() + i,
3463
3590
  segment.tokens.end());
3464
3591
 
3592
+ state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
3593
+
3465
3594
  acc = 0;
3466
3595
  text = "";
3467
3596
 
@@ -3475,9 +3604,6 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
3475
3604
  }
3476
3605
  }
3477
3606
 
3478
- if (split_on_word) {
3479
- trim(text);
3480
- }
3481
3607
  state.result_all.back().text = std::move(text);
3482
3608
 
3483
3609
  return res;
@@ -3543,9 +3669,14 @@ static void whisper_process_logits(
3543
3669
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
3544
3670
  logits[vocab.token_not] = -INFINITY;
3545
3671
 
3546
- // suppress sot and solm tokens
3672
+ // suppress sot and nosp tokens
3547
3673
  logits[vocab.token_sot] = -INFINITY;
3548
- logits[vocab.token_solm] = -INFINITY;
3674
+ logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
3675
+
3676
+ // [TDRZ] when tinydiarize is disabled, suppress solm token
3677
+ if (params.tdrz_enable == false) {
3678
+ logits[vocab.token_solm] = -INFINITY;
3679
+ }
3549
3680
 
3550
3681
  // suppress task tokens
3551
3682
  logits[vocab.token_translate] = -INFINITY;
@@ -3582,7 +3713,7 @@ static void whisper_process_logits(
3582
3713
  const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
3583
3714
  const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
3584
3715
 
3585
- //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
3716
+ //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
3586
3717
 
3587
3718
  if (last_was_timestamp) {
3588
3719
  if (penultimate_was_timestamp) {
@@ -3658,7 +3789,7 @@ static void whisper_process_logits(
3658
3789
 
3659
3790
  const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
3660
3791
 
3661
- //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
3792
+ //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
3662
3793
 
3663
3794
  if (timestamp_logprob > max_text_token_logprob) {
3664
3795
  for (int i = 0; i < vocab.token_beg; ++i) {
@@ -3907,12 +4038,12 @@ int whisper_full_with_state(
3907
4038
  // compute log mel spectrogram
3908
4039
  if (params.speed_up) {
3909
4040
  if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3910
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
4041
+ log("%s: failed to compute log mel spectrogram\n", __func__);
3911
4042
  return -1;
3912
4043
  }
3913
4044
  } else {
3914
4045
  if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3915
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
4046
+ log("%s: failed to compute log mel spectrogram\n", __func__);
3916
4047
  return -2;
3917
4048
  }
3918
4049
  }
@@ -3923,13 +4054,13 @@ int whisper_full_with_state(
3923
4054
 
3924
4055
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
3925
4056
  if (lang_id < 0) {
3926
- fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
4057
+ log("%s: failed to auto-detect language\n", __func__);
3927
4058
  return -3;
3928
4059
  }
3929
4060
  state->lang_id = lang_id;
3930
4061
  params.language = whisper_lang_str(lang_id);
3931
4062
 
3932
- fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
4063
+ log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3933
4064
  if (params.detect_language) {
3934
4065
  return 0;
3935
4066
  }
@@ -3986,7 +4117,7 @@ int whisper_full_with_state(
3986
4117
  if (decoder.kv_self.ctx == nullptr) {
3987
4118
  decoder.kv_self = state->decoders[0].kv_self;
3988
4119
  if (!kv_cache_reinit(decoder.kv_self)) {
3989
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
4120
+ log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
3990
4121
  return -4;
3991
4122
  }
3992
4123
 
@@ -4030,7 +4161,7 @@ int whisper_full_with_state(
4030
4161
 
4031
4162
  // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
4032
4163
  if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
4033
- fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
4164
+ log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
4034
4165
  return -5;
4035
4166
  }
4036
4167
  state->exp_n_audio_ctx = params.audio_ctx;
@@ -4042,15 +4173,12 @@ int whisper_full_with_state(
4042
4173
  state->lang_id = lang_id;
4043
4174
  prompt_init.push_back(whisper_token_lang(ctx, lang_id));
4044
4175
  if (params.translate) {
4045
- prompt_init.push_back(whisper_token_translate());
4176
+ prompt_init.push_back(whisper_token_translate(ctx));
4046
4177
  } else {
4047
- prompt_init.push_back(whisper_token_transcribe());
4178
+ prompt_init.push_back(whisper_token_transcribe(ctx));
4048
4179
  }
4049
4180
  }
4050
4181
 
4051
- int progress_prev = 0;
4052
- int progress_step = 5;
4053
-
4054
4182
  int seek = seek_start;
4055
4183
 
4056
4184
  std::vector<whisper_token> prompt;
@@ -4077,16 +4205,11 @@ int whisper_full_with_state(
4077
4205
 
4078
4206
  // main loop
4079
4207
  while (true) {
4080
- const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
4081
- while (progress_cur >= progress_prev + progress_step) {
4082
- progress_prev += progress_step;
4083
- if (params.print_progress) {
4084
- fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
4085
- }
4086
- }
4087
4208
  if (params.progress_callback) {
4209
+ const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
4210
+
4088
4211
  params.progress_callback(
4089
- ctx, ctx->state, progress_prev, params.progress_callback_user_data);
4212
+ ctx, ctx->state, progress_cur, params.progress_callback_user_data);
4090
4213
  }
4091
4214
 
4092
4215
  // of only 1 second left, then stop
@@ -4096,14 +4219,14 @@ int whisper_full_with_state(
4096
4219
 
4097
4220
  if (params.encoder_begin_callback) {
4098
4221
  if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
4099
- fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
4222
+ log("%s: encoder_begin_callback returned false - aborting\n", __func__);
4100
4223
  break;
4101
4224
  }
4102
4225
  }
4103
4226
 
4104
4227
  // encode audio features starting at offset seek
4105
4228
  if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
4106
- fprintf(stderr, "%s: failed to encode\n", __func__);
4229
+ log("%s: failed to encode\n", __func__);
4107
4230
  return -6;
4108
4231
  }
4109
4232
 
@@ -4186,7 +4309,7 @@ int whisper_full_with_state(
4186
4309
  WHISPER_PRINT_DEBUG("\n\n");
4187
4310
 
4188
4311
  if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
4189
- fprintf(stderr, "%s: failed to decode\n", __func__);
4312
+ log("%s: failed to decode\n", __func__);
4190
4313
  return -7;
4191
4314
  }
4192
4315
 
@@ -4424,7 +4547,7 @@ int whisper_full_with_state(
4424
4547
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4425
4548
 
4426
4549
  if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4427
- fprintf(stderr, "%s: failed to decode\n", __func__);
4550
+ log("%s: failed to decode\n", __func__);
4428
4551
  return -8;
4429
4552
  }
4430
4553
 
@@ -4524,23 +4647,27 @@ int whisper_full_with_state(
4524
4647
  prompt_past.push_back(tokens_cur[i].id);
4525
4648
  }
4526
4649
 
4527
- // store the text from this iteration
4528
4650
  if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
4529
4651
  int i0 = 0;
4530
4652
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
4531
4653
 
4532
4654
  std::string text;
4655
+ bool speaker_turn_next = false;
4533
4656
 
4534
4657
  for (int i = 0; i < (int) tokens_cur.size(); i++) {
4535
4658
  //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
4536
4659
  // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
4537
4660
  // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
4538
4661
 
4539
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
4540
- } else {
4662
+ if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
4541
4663
  text += whisper_token_to_str(ctx, tokens_cur[i].id);
4542
4664
  }
4543
4665
 
4666
+ // [TDRZ] record if speaker turn was predicted after current segment
4667
+ if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
4668
+ speaker_turn_next = true;
4669
+ }
4670
+
4544
4671
  if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
4545
4672
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
4546
4673
 
@@ -4559,7 +4686,7 @@ int whisper_full_with_state(
4559
4686
 
4560
4687
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
4561
4688
 
4562
- result_all.push_back({ tt0, tt1, text, {} });
4689
+ result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
4563
4690
  for (int j = i0; j <= i; j++) {
4564
4691
  result_all.back().tokens.push_back(tokens_cur[j]);
4565
4692
  }
@@ -4585,6 +4712,7 @@ int whisper_full_with_state(
4585
4712
  i--;
4586
4713
  t0 = t1;
4587
4714
  i0 = i + 1;
4715
+ speaker_turn_next = false;
4588
4716
  }
4589
4717
  }
4590
4718
 
@@ -4603,7 +4731,7 @@ int whisper_full_with_state(
4603
4731
  }
4604
4732
  }
4605
4733
 
4606
- result_all.push_back({ tt0, tt1, text, {} });
4734
+ result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
4607
4735
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
4608
4736
  result_all.back().tokens.push_back(tokens_cur[j]);
4609
4737
  }
@@ -4741,12 +4869,12 @@ int whisper_full_parallel(
4741
4869
  ctx->state->t_decode_us /= n_processors;
4742
4870
 
4743
4871
  // print information about the audio boundaries
4744
- fprintf(stderr, "\n");
4745
- fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
4872
+ log("\n");
4873
+ log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
4746
4874
  for (int i = 0; i < n_processors - 1; ++i) {
4747
- fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
4875
+ log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
4748
4876
  }
4749
- fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
4877
+ log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
4750
4878
 
4751
4879
  return ret;
4752
4880
  }
@@ -4783,6 +4911,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
4783
4911
  return ctx->state->result_all[i_segment].t1;
4784
4912
  }
4785
4913
 
4914
+ bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
4915
+ return ctx->state->result_all[i_segment].speaker_turn_next;
4916
+ }
4917
+
4786
4918
  const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
4787
4919
  return state->result_all[i_segment].text.c_str();
4788
4920
  }
@@ -5102,7 +5234,7 @@ static void whisper_exp_compute_token_level_timestamps(
5102
5234
  const int n_samples = state.energy.size();
5103
5235
 
5104
5236
  if (n_samples == 0) {
5105
- fprintf(stderr, "%s: no signal data available\n", __func__);
5237
+ log("%s: no signal data available\n", __func__);
5106
5238
  return;
5107
5239
  }
5108
5240
 
@@ -5322,3 +5454,7 @@ static void whisper_exp_compute_token_level_timestamps(
5322
5454
  // }
5323
5455
  //}
5324
5456
  }
5457
+
5458
+ void whisper_set_log_callback(whisper_log_callback callback) {
5459
+ whisper_log = callback;
5460
+ }