whispercpp 1.3.0 → 1.3.1

Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -1,29 +1,19 @@
 #include "whisper.h"
 
-#ifdef WHISPER_USE_COREML
-#include "coreml/whisper-encoder.h"
-#endif
+#include "ggml-cpu.h"
 
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
-#ifdef GGML_USE_SYCL
-#include "ggml-sycl.h"
+#ifdef WHISPER_USE_COREML
+#include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
 
-#include "ggml.h"
-#include "ggml-alloc.h"
-#include "ggml-backend.h"
-
 #include <atomic>
 #include <algorithm>
 #include <cassert>
@@ -41,6 +31,7 @@
 #include <regex>
 #include <random>
 #include <functional>
+#include <codecvt>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -147,8 +138,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
     } \
 } while (0)
 
-//#define WHISPER_USE_FLASH_ATTN
-//#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
@@ -162,7 +151,7 @@ static bool ggml_graph_compute_helper(
         int n_threads,
         ggml_abort_callback abort_callback,
         void * abort_callback_data) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
 
     plan.abort_callback = abort_callback;
     plan.abort_callback_data = abort_callback_data;
@@ -176,18 +165,24 @@ static bool ggml_graph_compute_helper(
 }
 
 static bool ggml_graph_compute_helper(
-        struct ggml_backend * backend,
+        ggml_backend_sched_t sched,
         struct ggml_cgraph * graph,
         int n_threads) {
-    if (ggml_backend_is_cpu(backend)) {
-        ggml_backend_cpu_set_n_threads(backend, n_threads);
-    }
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(backend)) {
-        ggml_backend_metal_set_n_cb(backend, n_threads);
+
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (fn_set_n_threads) {
+            fn_set_n_threads(backend, n_threads);
+        }
     }
-#endif
-    return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
+
+    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+    ggml_backend_sched_reset(sched);
+    return t;
 }
 
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
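
Note on the hunk above: the rewritten helper no longer hard-codes per-backend thread APIs (ggml_backend_cpu_set_n_threads, ggml_backend_metal_set_n_cb); it asks each backend's registry for an optional "ggml_backend_set_n_threads" entry point. A minimal sketch of the same lookup for a single standalone backend (illustrative only — the helper name below is made up, the ggml calls are the public ggml-backend API):

    // sketch: apply a thread count to one backend, if it supports it
    static void set_threads_if_supported(ggml_backend_t backend, int n_threads) {
        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
        if (!reg) {
            return; // backend exposes no registry - nothing to configure
        }
        auto * fn = (ggml_backend_set_n_threads_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (fn) {
            fn(backend, n_threads); // CPU-like backends provide this; GPU backends typically don't
        }
    }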
@@ -363,6 +358,7 @@ static const whisper_ahead g_aheads_medium[]   = { {13, 15}, {15, 4}, {15, 15},
 static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
 static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
 static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
 
 static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
     { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
@@ -376,6 +372,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
     { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
     { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
     { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
+    { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
 };
 
 static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
@@ -502,33 +499,41 @@ struct whisper_pair {
     whisper_pair() : first(A()), second(B()) {}
 };
 
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_gallocr_t alloc = nullptr;
+// ggml_backend_sched wrapper for whisper usage
+struct whisper_sched {
+    ggml_backend_sched_t sched = nullptr;
 
     std::vector<uint8_t> meta;
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0);
+static size_t whisper_sched_size(struct whisper_sched & allocr) {
+    size_t size = allocr.meta.size();
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+        size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+    }
+    return size;
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc = allocr.alloc;
+static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & sched = allocr.sched;
     auto & meta  = allocr.meta;
 
-    alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
     // since there are dependencies between the different graphs,
     // we need to allocate them instead of only reserving to get the correct compute buffer size
-    if (!ggml_gallocr_alloc_graph(alloc, get_graph())) {
+    if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
        // failed to allocate the compute buffer
        WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
        return false;
    }
+
+    ggml_backend_sched_reset(sched);
+
    return true;
}

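Note: whisper_sched above is a thin wrapper over ggml_backend_sched, which splits a single graph across several backends. A rough standalone sketch of the lifecycle it implements (illustrative; `backends` and `build_graph()` are assumed to exist):

    ggml_backend_sched_t sched = ggml_backend_sched_new(
            backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);

    // allocating a worst-case graph once sizes the per-backend compute buffers
    if (ggml_backend_sched_alloc_graph(sched, build_graph())) {
        for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
            ggml_backend_t b = ggml_backend_sched_get_backend(sched, i);
            printf("compute buffer[%d]: %zu bytes\n", i, ggml_backend_sched_get_buffer_size(sched, b));
        }
    }
    ggml_backend_sched_reset(sched); // ready for the real graphs
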
@@ -671,9 +676,9 @@ struct whisper_kv_cache {
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
-    struct ggml_context * ctx = nullptr;
-
     ggml_backend_buffer_t buffer = nullptr;
+
+    std::vector<uint8_t> ctx_buf;
 };
 
 struct whisper_model {
@@ -802,6 +807,9 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // number of decoders for which we have constructed the KV cache
+    int32_t kv_self_n_dec = 0;
+
     // unified self-attention KV cache for all decoders
     whisper_kv_cache kv_self;
 
@@ -809,21 +817,22 @@ struct whisper_state {
     // shared between all decoders
     whisper_kv_cache kv_cross;
 
+    // padded buffer for flash-attention
+    whisper_kv_cache kv_pad;
+
     whisper_mel mel;
 
     whisper_batch batch;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    ggml_backend_t backend = nullptr;
+    std::vector<ggml_backend_t> backends;
 
-    // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
-    // - stores the actual tensor data into the `data` buffers
-    whisper_allocr alloc_conv;
-    whisper_allocr alloc_encode;
-    whisper_allocr alloc_cross;
-    whisper_allocr alloc_decode;
+    whisper_sched sched_conv;
+    whisper_sched sched_encode;
+    whisper_sched sched_cross;
+    whisper_sched sched_decode;
 
     // result of the encoder
     struct ggml_tensor * embd_conv = nullptr;
@@ -858,6 +867,7 @@ struct whisper_state {
     whisper_token tid_last;
 
     std::vector<float> energy; // PCM signal energy
+    float no_speech_prob = 0.0f;
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     whisper_aheads_masks aheads_masks;
@@ -882,8 +892,6 @@ struct whisper_context {
 
     whisper_state * state = nullptr;
 
-    ggml_backend_t backend = nullptr;
-
     std::string path_model; // populated by whisper_init_from_file_with_params()
 };
 
@@ -901,21 +909,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     BYTESWAP_VALUE(dest);
 }
 
-static bool kv_cache_init(
-        const struct whisper_hparams & hparams,
+static bool whisper_kv_cache_init(
         struct whisper_kv_cache & cache,
         ggml_backend_t backend,
         ggml_type wtype,
+        int64_t n_text_state,
+        int64_t n_text_layer,
         int n_ctx) {
-    const int64_t n_text_state = hparams.n_text_state;
-    const int64_t n_text_layer = hparams.n_text_layer;
-
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
+    cache.ctx_buf.resize(2*ggml_tensor_overhead());
+
     struct ggml_init_params params = {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
+        /*.mem_size   =*/ cache.ctx_buf.size(),
+        /*.mem_buffer =*/ cache.ctx_buf.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -925,29 +933,31 @@ static bool kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
-    cache.ctx = ggml_init(params);
+    struct ggml_context * ctx = ggml_init(params);
 
-    if (!cache.ctx) {
+    if (!ctx) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
         return false;
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
 
-    cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
+    cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
     if (!cache.buffer) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
         return false;
     }
 
+    ggml_backend_buffer_clear(cache.buffer, 0);
+
+    ggml_free(ctx);
+
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    ggml_free(cache.ctx);
+static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
     ggml_backend_buffer_free(cache.buffer);
-    cache.ctx = nullptr;
 }
 
 static bool whisper_kv_cache_find_slot(
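
Note: the rewritten whisper_kv_cache_init() follows the usual ggml "no_alloc" pattern — the context only records tensor metadata (into cache.ctx_buf, which must outlive the tensors), the data lives in one backend buffer, and the context itself can be freed immediately. A condensed sketch of the pattern (illustrative; `n_elements` and `backend` are assumed):

    std::vector<uint8_t> ctx_buf(2*ggml_tensor_overhead()); // room for two tensor headers

    struct ggml_init_params ip = {
        /*.mem_size   =*/ ctx_buf.size(),
        /*.mem_buffer =*/ ctx_buf.data(),
        /*.no_alloc   =*/ true,              // metadata only, no data allocation
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
    struct ggml_tensor * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);

    // one buffer on the target backend holds the data of both tensors
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    ggml_backend_buffer_clear(buf, 0); // start from a zeroed cache
    ggml_free(ctx); // k/v stay valid: headers live in ctx_buf, data in buf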
@@ -1018,6 +1028,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+
+    ggml_backend_buffer_clear(cache.buffer, 0);
 }
 
 static void whisper_kv_cache_seq_rm(
@@ -1068,6 +1080,26 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+    if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
+        return 1u;
+    }
+
+#ifdef GGML_USE_METAL
+    if (wctx.params.use_gpu) {
+        return 32u;
+    }
+#endif
+
+#ifdef GGML_USE_CUDA
+    if (wctx.params.use_gpu) {
+        return 256u;
+    }
+#endif
+
+    return 1u;
+}
+
 // [EXPERIMENTAL] Token-level timestamps with DTW
 static bool aheads_masks_init(
         const whisper_context_params & cparams,
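
Note: GGML_PAD(x, n) rounds x up to the next multiple of n; the values returned above are the block sizes the flash-attention kernels expect. A quick arithmetic check (illustrative):

    // #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))
    GGML_PAD(300, 256); // == 512 (CUDA padding)
    GGML_PAD(300,  32); // == 320 (Metal padding)
    GGML_PAD(300,   1); // == 300 (no flash-attn: sizes unchanged)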
@@ -1199,49 +1231,71 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
     return size;
 }
 
-static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
-    ggml_backend_t backend_gpu = NULL;
+static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
-    // initialize the backends
-#ifdef GGML_USE_CUDA
     if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init(params.gpu_device);
-        if (!backend_gpu) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+                ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+                if (!result) {
+                    WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                }
+                return result;
+            }
         }
     }
-#endif
 
-#ifdef GGML_USE_METAL
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
-        ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-        backend_gpu = ggml_backend_metal_init();
-        if (!backend_gpu) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
-            WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
-            ggml_backend_free(backend_gpu);
-            backend_gpu = NULL;
-        }
+    return nullptr;
+}
+
+static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+    std::vector<ggml_backend_t> result;
+
+    ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
+
+    if (backend_gpu) {
+        result.push_back(backend_gpu);
     }
-#endif
 
-#ifdef GGML_USE_SYCL
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
-        backend_gpu = ggml_backend_sycl_init(params.gpu_device);
-        if (!backend_gpu) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
+    // ACCEL backends
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+            WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (!backend) {
+                WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                continue;
+            }
+            result.push_back(backend);
         }
     }
-#endif
 
-    if (backend_gpu) {
-        return backend_gpu;
+    GGML_UNUSED(params);
+
+    result.push_back(ggml_backend_cpu_init());
+
+    return result;
+}
+
+static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+    if (!params.use_gpu) {
+        return ggml_backend_cpu_buffer_type();
     }
-    return ggml_backend_cpu_init();
+
+    // if we have a GPU device - use it
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+            WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+            return ggml_backend_dev_buffer_type(dev);
+        }
+    }
+
+    return ggml_backend_cpu_buffer_type();
 }
 
 // load the model from a ggml file
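
Note: backend selection is now discovery-based — the compile-time #ifdef ladder is gone, and whatever devices the linked ggml build registers are enumerated at runtime. A standalone sketch of the same enumeration (illustrative, just printing what is available):

    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        const char * type =
            ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU   ? "GPU"   :
            ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL ? "ACCEL" : "CPU";
        printf("device %zu: %s (%s) [%s]\n", i,
               ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), type);
    }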
@@ -1668,21 +1722,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.backend = whisper_backend_init(wctx.params);
-    if (!wctx.backend) {
-        WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
-        return false;
-    }
-
     // allocate tensors in the backend buffers
-    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend);
+    model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
     if (!model.buffer) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
         return false;
     }
 
     size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
 
     // load weights
     {
@@ -1777,6 +1825,8 @@
         }
     }
 
+    ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1812,8 +1862,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_conv.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+        /*.mem_size   =*/ wstate.sched_conv.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1847,6 +1897,7 @@
     ggml_build_forward_expand(gf, mel);
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    ggml_set_input(cur); // the external encoder will write into this tensor
 
     ggml_set_name(cur, "embd_enc");
     wstate.embd_enc = cur;
@@ -1872,9 +1923,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
+    const int n_state_head = n_state/n_head;
+
+    auto & kv_pad = wstate.kv_pad;
+
+    WHISPER_ASSERT(!!kv_pad.buffer);
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_encode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1884,7 +1943,7 @@
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
-    const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
+    const float KQscale = 1.0f/sqrtf(float(n_state_head));
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1993,14 @@
 
             Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
 
-            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
+            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
+            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
@@ -1951,70 +2010,60 @@
 
             // ------
 
-#ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+            if (wctx.params.flash_attn) {
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+                struct ggml_tensor * K =
+                    ggml_view_3d(ctx0, kv_pad.k,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.k)*n_state,
+                            ggml_element_size(kv_pad.k)*n_state_head,
+                            0);
 
-            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
-#else
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_pad.v,
+                            n_state_head, n_ctx_pad, n_head,
+                            ggml_element_size(kv_pad.v)*n_state,
+                            ggml_element_size(kv_pad.v)*n_state_head,
+                            0);
 
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+            } else {
+                struct ggml_tensor * K =
+                    ggml_permute(ctx0,
+                            ggml_cast(ctx0,
+                                ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+                                wctx.itype),
+                            0, 2, 1, 3);
 
-            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
-                        );
+                struct ggml_tensor * V =
+                    ggml_cast(ctx0,
+                            ggml_permute(ctx0,
+                                ggml_reshape_3d(ctx0,
+                                    Vcur,
+                                    n_state_head, n_head, n_ctx),
+                                1, 2, 0, 3),
+                            wctx.itype);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-#endif
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
+            }
         }
 
         // projection
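
Note: in the flash-attention path above, K/V are first copied into the zero-initialized kv_pad buffer and then viewed with n_ctx_pad rows, because ggml_flash_attn_ext expects the KV length padded to a block multiple. Worked numbers (illustrative):

    // encoder context for a full 30 s mel window:
    //   n_ctx     = 1500
    //   n_ctx_pad = GGML_PAD(1500, 256)  // == 1536
    // rows 1500..1535 of K/V come from the cleared kv_pad buffer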
@@ -2043,11 +2092,6 @@
                         layer.mlp_ln_b);
             }
 
-#ifdef WHISPER_USE_FLASH_FF
-            cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
@@ -2064,7 +2108,6 @@
                     cur);
 
             cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
@@ -2113,9 +2156,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
 
+    const int n_state_head = n_state/n_head;
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+        /*.mem_size   =*/ wstate.sched_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2125,18 +2172,18 @@
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
-    const float  Kscale = pow(float(n_state) / n_head, -0.25);
+    const float  Kscale = pow(float(n_state_head), -0.25);
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];
 
-        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_k_w,
                 cur);
 
         Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
-        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
                 cur);
 
@@ -2144,15 +2191,25 @@
                 Vcross,
                 layer.cross_attn_v_b);
 
-        Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+        struct ggml_tensor * k;
+        struct ggml_tensor * v;
 
-        struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
-                n_state*n_ctx,
-                (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+        if (wctx.params.flash_attn) {
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
-        struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
-                (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
-                (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+            v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+        } else {
+            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+            k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+                    (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+            v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+        }
 
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2186,11 +2243,11 @@ static bool whisper_encode_internal(
 
     // conv
     {
-        auto & alloc = wstate.alloc_conv.alloc;
+        auto & sched = wstate.sched_conv.sched;
 
         ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
@@ -2223,7 +2280,7 @@
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2237,32 +2294,32 @@
 
     // encoder
     if (!whisper_encode_external(wstate)) {
-        auto & alloc = wstate.alloc_encode.alloc;
+        auto & sched = wstate.sched_encode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
 
     // cross
     {
-        auto & alloc = wstate.alloc_cross.alloc;
+        auto & sched = wstate.sched_cross.sched;
 
         ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
@@ -2284,24 +2341,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     auto & kv_self = wstate.kv_self;
 
-    WHISPER_ASSERT(!!kv_self.ctx);
+    WHISPER_ASSERT(!!kv_self.buffer);
 
     const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
+    const int n_state_head = n_state/n_head;
+
     const int n_tokens    = batch.n_tokens;
     const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    const int32_t n_kv    = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+    const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+    const int32_t n_kv    = worst_case ? n_ctx            : kv_self.n;
+    const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
 
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2317,12 +2378,14 @@
     ggml_set_name(position, "position");
     ggml_set_input(position);
 
-    const float KQscale = pow(float(n_state)/n_head, -0.25);
+    const float KQscale = pow(float(n_state_head), -0.25);
 
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
     ggml_set_name(KQ_mask, "KQ_mask");
     ggml_set_input(KQ_mask);
 
+    struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2378,12 +2441,25 @@
                         Vcur,
                         layer.attn_v_b);
 
-            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+            struct ggml_tensor * k;
+            struct ggml_tensor * v;
+
+            if (wctx.params.flash_attn) {
+                k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                        (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+                v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+                        (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+            } else {
+                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+
+                k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+                        (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
 
-            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
-            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
-                    (   n_ctx)*ggml_element_size(kv_self.v),
-                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+                v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+            }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2393,40 +2469,46 @@
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_kv, n_head,
+                        n_state_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
-                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_state_head, n_kv, n_head,
+                            ggml_element_size(kv_self.v)*n_state,
+                            ggml_element_size(kv_self.v)*n_state_head,
+                            ggml_element_size(kv_self.v)*n_state*n_ctx*il);
 
-            //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
 
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
-            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
 
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, kv_self.v,
-                        n_kv, n_state/n_head, n_head,
-                        n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
+                struct ggml_tensor * V =
+                    ggml_view_3d(ctx0, kv_self.v,
+                            n_kv, n_state_head, n_head,
+                            n_ctx*ggml_element_size(kv_self.v),
+                            n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+                            n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+            }
         }
 
         // projection
@@ -2465,80 +2547,75 @@
                         Qcur,
                         layer.cross_attn_q_b);
 
-            Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
-            // Kcross is already scaled
-            struct ggml_tensor * Kcross =
-                ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, n_audio_ctx, n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state,
-                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
-            //struct ggml_tensor * Vcross =
-            //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, n_audio_ctx);
-
-            //struct ggml_tensor * V_trans =
-            //    ggml_cpy(ctx0,
-            //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        n_audio_ctx, n_state/n_head, n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
-            // ------
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
-
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            if (wctx.params.flash_attn) {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
 
-            // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_state_head, n_audio_ctx_pad, n_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state,
+                            ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
+                cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
 
-            // [EXPERIMENTAL] Token-level timestamps with DTW
-            if (wctx.params.dtw_token_timestamps) {
-                if (wstate.aheads_masks.m[il] != nullptr) {
-                    struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
-                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
-                    aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
-                    if (aheads_cross_QKs == NULL) {
-                        aheads_cross_QKs = aheads_KQs;
-                    } else {
-                        aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+            } else {
+                struct ggml_tensor * Kcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.k,
+                            n_state_head, n_audio_ctx, n_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state,
+                            ggml_element_size(wstate.kv_cross.k)*n_state_head,
+                            ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+                struct ggml_tensor * Vcross =
+                    ggml_view_3d(ctx0, wstate.kv_cross.v,
+                            n_audio_ctx, n_state_head, n_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+                            n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+                // ------
+
+                // K * Q
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+                // [EXPERIMENTAL] Token-level timestamps with DTW
+                if (wctx.params.dtw_token_timestamps) {
+                    if (wstate.aheads_masks.m[il] != nullptr) {
+                        struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                        aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                        aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                        if (aheads_cross_QKs == NULL) {
+                            aheads_cross_QKs = aheads_KQs;
+                        } else {
+                            aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
+                        }
                     }
                 }
-            }
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
 
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+            }
         }
 
         // projection
@@ -2671,18 +2748,20 @@ static bool whisper_decode_internal(
             return false;
         }
 
-        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+        kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
         //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
         //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
     }
 
     // decoder
     {
-        auto & alloc = wstate.alloc_decode.alloc;
+        auto & sched = wstate.sched_decode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
@@ -2705,9 +2784,10 @@
         struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
 
         auto & kv_self = wstate.kv_self;
-        const int32_t n_kv = kv_self.n;
 
-        wstate.inp_mask.resize(n_kv*n_tokens);
+        const int32_t n_kv = kv_self.n;
+
+        wstate.inp_mask.resize(ggml_nelements(KQ_mask));
 
         float * data = wstate.inp_mask.data();
         memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2723,14 +2803,20 @@
                     }
                 }
             }
+
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                }
+            }
         }
 
         ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
     }
 
-    logits = gf->nodes[gf->n_nodes - 1];
+    logits = ggml_graph_node(gf, -1);
 
-    if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+    if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
         return false;
     }
 }
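
Note: KQ_mask now has GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) rows rather than n_tokens, so the extra rows must be fully masked — that is what the added loop does. A compact sketch of the resulting layout (illustrative; cell_visible() is a made-up stand-in for the per-cell sequence/position check in the loop above):

    const int rows = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
    std::vector<float> mask(n_kv * rows, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            if (!cell_visible(i, j)) {
                mask[i*n_kv + j] = -INFINITY; // other sequence / future position
            }
        }
    }
    for (int i = n_tokens; i < rows; ++i) {   // padded rows attend to nothing
        for (int j = 0; j < n_kv; ++j) {
            mask[i*n_kv + j] = -INFINITY;
        }
    }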
@@ -2784,29 +2870,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 }
 
 #define SIN_COS_N_COUNT WHISPER_N_FFT
-static float sin_vals[SIN_COS_N_COUNT];
-static float cos_vals[SIN_COS_N_COUNT];
+namespace {
+struct whisper_global_cache {
+    // In FFT, we frequently use sine and cosine operations with the same values.
+    // We can use precalculated values to speed up the process.
+    float sin_vals[SIN_COS_N_COUNT];
+    float cos_vals[SIN_COS_N_COUNT];
+
+    // Hann window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    float hann_window[WHISPER_N_FFT];
+
+    whisper_global_cache() {
+        fill_sin_cos_table();
+        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+    }
+
+    void fill_sin_cos_table() {
+        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+            sin_vals[i] = sinf(theta);
+            cos_vals[i] = cosf(theta);
+        }
+    }
 
-// In FFT, we frequently use sine and cosine operations with the same values.
-// We can use precalculated values to speed up the process.
-static void fill_sin_cos_table() {
-    static bool is_filled = false;
-    if (is_filled) return;
-    for (int i = 0; i < SIN_COS_N_COUNT; i++) {
-        double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
-        sin_vals[i] = sinf(theta);
-        cos_vals[i] = cosf(theta);
+    void fill_hann_window(int length, bool periodic, float * output) {
+        int offset = -1;
+        if (periodic) {
+            offset = 0;
+        }
+        for (int i = 0; i < length; i++) {
+            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+        }
     }
-    is_filled = true;
+} global_cache;
 }
 
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
-static void dft(const std::vector<float> & in, std::vector<float> & out) {
-    int N = in.size();
-
-    out.resize(N*2);
+static void dft(const float* in, int N, float* out) {
     const int sin_cos_step = SIN_COS_N_COUNT / N;
 
     for (int k = 0; k < N; k++) {
@@ -2815,8 +2919,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
 
         for (int n = 0; n < N; n++) {
             int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
-            re += in[n]*cos_vals[idx]; // cos(t)
-            im -= in[n]*sin_vals[idx]; // sin(t)
+            re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
         }
 
         out[k*2 + 0] = re;
@@ -2828,47 +2932,38 @@
 // poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
-static void fft(const std::vector<float> & in, std::vector<float> & out) {
-    out.resize(in.size()*2);
-
-    int N = in.size();
-
+static void fft(float* in, int N, float* out) {
     if (N == 1) {
         out[0] = in[0];
         out[1] = 0;
         return;
     }
 
-    if (N%2 == 1) {
-        dft(in, out);
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out);
         return;
     }
 
-    std::vector<float> even;
-    std::vector<float> odd;
-
-    even.reserve(N/2);
-    odd.reserve(N/2);
-
-    for (int i = 0; i < N; i++) {
-        if (i % 2 == 0) {
-            even.push_back(in[i]);
-        } else {
-            odd.push_back(in[i]);
-        }
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i]= in[2*i];
     }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft);
 
-    std::vector<float> even_fft;
-    std::vector<float> odd_fft;
-
-    fft(even, even_fft);
-    fft(odd, odd_fft);
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft);
 
     const int sin_cos_step = SIN_COS_N_COUNT / N;
-    for (int k = 0; k < N/2; k++) {
+    for (int k = 0; k < half_N; k++) {
         int idx = k * sin_cos_step; // t = 2*M_PI*k/N
-        float re = cos_vals[idx]; // cos(t)
-        float im = -sin_vals[idx]; // sin(t)
+        float re = global_cache.cos_vals[idx]; // cos(t)
+        float im = -global_cache.sin_vals[idx]; // sin(t)
 
         float re_odd = odd_fft[2*k + 0];
         float im_odd = odd_fft[2*k + 1];
@@ -2876,52 +2971,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
- }
- }
-
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
- if (output.size() < static_cast<size_t>(length)) {
- output.resize(length);
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
  }
- int offset = -1;
- if (periodic) {
- offset = 0;
- }
- for (int i = 0; i < length; i++) {
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
- }
-
- return true;
  }
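
Why the new fft() takes raw pointers and no longer resizes: it reuses the tails of the caller's buffers as recursion scratch. The even/odd halves are packed at in + N, and the even/odd sub-results at out + 2*N and out + 3*N. A sketch of the resulting capacity requirement, assuming N is a power of two (hypothetical helper, not part of the diff):

    // upper bound on the floats fft(in, N, out) touches in `out`: the odd
    // sub-FFT is rooted at offset 3*N and dominates, giving
    // 3*N + 3*N/2 + ... + 2 < 6*N; `in` similarly needs N + N/2 + ... < 2*N
    static size_t fft_out_extent(size_t N) {
        return N == 1 ? 2 : 3 * N + fft_out_extent(N / 2);
    }

The worker thread below allocates fft_in(frame_size * 2) and fft_out(frame_size * 2 * 2 * 2), i.e. 2*N and 8*N floats, which satisfies both bounds.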

- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
  int n_samples, int frame_size, int frame_step, int n_threads,
  const whisper_filters & filters, whisper_mel & mel) {
- std::vector<float> fft_in(frame_size, 0.0);
- std::vector<float> fft_out(2 * frame_size);
+ std::vector<float> fft_in(frame_size * 2, 0.0);
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
+
  int n_fft = filters.n_fft;
  int i = ith;

  // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
- assert( n_fft == 1 + (frame_size / 2) );
-
+ assert(n_fft == 1 + (frame_size / 2));
+
  // calculate FFT only when fft_in are not all zero
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
  const int offset = i * frame_step;

- // apply Hanning window (~10% faster)
+ // apply Hann window (~10% faster)
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
  fft_in[j] = hann[j] * samples[offset + j];
  }
+
  // fill the rest with zeros
  if (n_samples - offset < frame_size) {
  std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
  }

  // FFT
- fft(fft_in, fft_out);
+ fft(fft_in.data(), frame_size, fft_out.data());

  // Calculate modulus^2 of complex numbers
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@@ -2932,7 +3014,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
  // mel spectrogram
  for (int j = 0; j < mel.n_mel; j++) {
  double sum = 0.0;
-
  // unroll loop (suggested by GH user @lunixbochs)
  int k = 0;
  for (k = 0; k < n_fft - 3; k += 4) {
@@ -2942,14 +3023,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
  fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
  fft_out[k + 3] * filters.data[j * n_fft + k + 3];
  }
-
  // handle n_fft remainder
  for (; k < n_fft; k++) {
  sum += fft_out[k] * filters.data[j * n_fft + k];
  }
-
  sum = log10(std::max(sum, 1e-10));
-
  mel.data[j * mel.n_len + i] = sum;
  }
  }
@@ -2978,12 +3056,9 @@ static bool log_mel_spectrogram(
  whisper_mel & mel) {
  const int64_t t_start_us = ggml_time_us();

- // Hanning window (Use cosf to eliminate difference)
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
- std::vector<float> hann;
- hann_window(frame_size, true, hann);
-
+ // Hann window
+ WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+ const float * hann = global_cache.hann_window;

  // Calculate the length of padding
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3008,12 +3083,11 @@ static bool log_mel_spectrogram(
  mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
  mel.data.resize(mel.n_mel * mel.n_len);

-
  {
  std::vector<std::thread> workers(n_threads - 1);
  for (int iw = 0; iw < n_threads - 1; ++iw) {
  workers[iw] = std::thread(
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+ log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
  std::cref(filters), std::ref(mel));
  }
@@ -3173,23 +3247,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
  #endif

  struct whisper_state * whisper_init_state(whisper_context * ctx) {
- fill_sin_cos_table();
-
  whisper_state * state = new whisper_state;

- state->backend = whisper_backend_init(ctx->params);
- if (!state->backend) {
+ state->backends = whisper_backend_init(ctx->params);
+ if (state->backends.empty()) {
  WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }

- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
- const int factor = 3;
-
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+ // at this point, we don't know yet how many decoders will be used
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+ state->kv_self_n_dec = 1;
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
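
For reference, ggml's padding helper rounds up to a multiple of the alignment, roughly #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)). With the stock Whisper text context of 448 tokens this sizes the initial self-attention cache for 512 positions. A compile-time sketch (hypothetical macro name, to avoid clashing with ggml.h):

    #define GGML_PAD_SKETCH(x, n) (((x) + (n) - 1) / (n) * (n))
    static_assert(GGML_PAD_SKETCH(448, 256) == 512, "default n_text_ctx rounds up to 512");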
@@ -3199,8 +3273,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
  }

- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+ if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
@@ -3210,9 +3287,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
  }

+ if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_audio_state,
+ 1,
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return nullptr;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+ WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
+ }
+
  // [EXPERIMENTAL] Token-level timestamps with DTW
  if (ctx->params.dtw_token_timestamps) {
- if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+ if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
  WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
  whisper_free_state(state);
  return nullptr;
@@ -3255,7 +3346,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

  // conv allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
  [&]() {
  return whisper_build_graph_conv(*ctx, *state);
  });
@@ -3266,12 +3357,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
  }

  // encoder allocator
  if (!whisper_encode_external(*state)) {
- bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
  [&]() {
  return whisper_build_graph_encoder(*ctx, *state);
  });
@@ -3282,12 +3373,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
  }

  // cross allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
  [&]() {
  return whisper_build_graph_cross(*ctx, *state);
  });
@@ -3298,12 +3389,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
  }

  // decoder allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
  [&]() {
  const auto & hparams = ctx->model.hparams;

@@ -3322,19 +3413,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
  }

  return state;
  }

- int whisper_ctx_init_openvino_encoder(
+ int whisper_ctx_init_openvino_encoder_with_state(
  struct whisper_context * ctx,
+ struct whisper_state * state,
  const char * model_path,
  const char * device,
  const char * cache_dir) {
  #ifndef WHISPER_USE_OPENVINO
  (void)(ctx);
+ (void)(state);
  (void)(model_path);
  (void)(device);
  (void)(cache_dir);
@@ -3365,8 +3458,8 @@ int whisper_ctx_init_openvino_encoder(
  WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
  WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

- ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
- if (!ctx->state->ctx_openvino) {
+ state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+ if (!state->ctx_openvino) {
  WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
  return 1;
  } else {
@@ -3377,9 +3470,18 @@ int whisper_ctx_init_openvino_encoder(
  #endif
  }

+ int whisper_ctx_init_openvino_encoder(
+ struct whisper_context * ctx,
+ const char * model_path,
+ const char * device,
+ const char * cache_dir) {
+ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
+ }
+
  struct whisper_context_params whisper_context_default_params() {
  struct whisper_context_params result = {
  /*.use_gpu =*/ true,
+ /*.flash_attn =*/ false,
  /*.gpu_device =*/ 0,

  /*.dtw_token_timestamps =*/ false,
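
A minimal usage sketch for the new flash_attn flag (the model path is illustrative; note that the init code further below disables DTW token timestamps when both are requested):

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.flash_attn = true; // new in this release
    struct whisper_context * wctx = whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);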
@@ -3396,8 +3498,14 @@ struct whisper_context_params whisper_context_default_params() {

  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
  WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
-
+ #ifdef _MSC_VER
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+ std::wstring path_model_wide = converter.from_bytes(path_model);
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
+ #else
  auto fin = std::ifstream(path_model, std::ios::binary);
+ #endif
  if (!fin) {
  WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
  return nullptr;
@@ -3472,6 +3580,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
  ggml_time_init();

+ if (params.flash_attn && params.dtw_token_timestamps) {
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+ params.dtw_token_timestamps = false;
+ }
+
+ WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+ WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
+ WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+ WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
  whisper_context * ctx = new whisper_context;
  ctx->params = params;

@@ -3558,8 +3678,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa

  void whisper_free_state(struct whisper_state * state) {
  if (state) {
- kv_cache_free(state->kv_self);
- kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_self);
+ whisper_kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_pad);

  #ifdef WHISPER_USE_COREML
  if (state->ctx_coreml != nullptr) {
@@ -3577,12 +3698,14 @@ void whisper_free_state(struct whisper_state * state) {

  whisper_batch_free(state->batch);

- ggml_gallocr_free(state->alloc_conv.alloc);
- ggml_gallocr_free(state->alloc_encode.alloc);
- ggml_gallocr_free(state->alloc_cross.alloc);
- ggml_gallocr_free(state->alloc_decode.alloc);
+ ggml_backend_sched_free(state->sched_conv.sched);
+ ggml_backend_sched_free(state->sched_encode.sched);
+ ggml_backend_sched_free(state->sched_cross.sched);
+ ggml_backend_sched_free(state->sched_decode.sched);

- ggml_backend_free(state->backend);
+ for (auto & backend : state->backends) {
+ ggml_backend_free(backend);
+ }

  // [EXPERIMENTAL] Token-level timestamps with DTW
  aheads_masks_free(state->aheads_masks);
@@ -3599,8 +3722,6 @@ void whisper_free(struct whisper_context * ctx) {

  whisper_free_state(ctx->state);

- ggml_backend_free(ctx->backend);
-
  delete ctx;
  }
  }
@@ -3630,30 +3751,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
  return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
  }

- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
- WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
- return -1;
- }
-
- return 0;
- }
-
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
- }
-
- // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
- // TODO
-
  int whisper_set_mel_with_state(
  struct whisper_context * ctx,
  struct whisper_state * state,
@@ -3742,7 +3839,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
  return -whisper_tokenize(ctx, text, NULL, 0);
  }

- int whisper_lang_max_id() {
+ int whisper_lang_max_id(void) {
  auto max_id = 0;
  for (const auto & kv : g_lang) {
  max_id = std::max(max_id, kv.second.first);
@@ -4011,6 +4108,19 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
  return ctx->vocab.token_transcribe;
  }

+ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
+ if (ctx->state == nullptr) {
+ return nullptr;
+ }
+ whisper_timings * timings = new whisper_timings;
+ timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
+ timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
+ timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
+ timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
+ timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
+ return timings;
+ }
+
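
A caller-side sketch for the new accessor (hypothetical caller; `wctx` is an initialized whisper_context). The struct is allocated with new and ownership passes to the caller:

    whisper_timings * t = whisper_get_timings(wctx);
    if (t != nullptr) {
        printf("encode: %.2f ms, decode: %.2f ms\n", t->encode_ms, t->decode_ms);
        delete t; // no dedicated free function is visible in this diff
    }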
  void whisper_print_timings(struct whisper_context * ctx) {
  const int64_t t_end_us = ggml_time_us();

@@ -4078,17 +4188,14 @@ const char * whisper_print_system_info(void) {
  s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
  s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
  s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
- s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
  s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
  s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
  s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
- s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
  s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
- s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
- s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;
+ s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";

  return s.c_str();
  }
@@ -4099,7 +4206,7 @@ const char * whisper_print_system_info(void) {

  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
  // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
- std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
  const char * src,
  whisper_partial_utf8 partial_start) {
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4513,7 +4620,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar

  ////////////////////////////////////////////////////////////////////////////

- struct whisper_context_params * whisper_context_default_params_by_ref() {
+ struct whisper_context_params * whisper_context_default_params_by_ref(void) {
  struct whisper_context_params params = whisper_context_default_params();

  struct whisper_context_params* result = new whisper_context_params();
@@ -4554,7 +4661,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
  /*.split_on_word =*/ false,
  /*.max_tokens =*/ 0,

- /*.speed_up =*/ false,
  /*.debug_mode =*/ false,
  /*.audio_ctx =*/ 0,

@@ -4720,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
  "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
  };

+ static void whisper_compute_logprobs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ std::vector<float> & logprobs) {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+
+ static void whisper_compute_probs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ const std::vector<float> & logprobs,
+ std::vector<float> & probs) {
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] == -INFINITY) {
+ probs[i] = 0.0f;
+ } else {
+ probs[i] = expf(logprobs[i]);
+ }
+ }
+ }
+
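
The logit_max subtraction in whisper_compute_logprobs() is the standard numerically stable log-softmax: with m = max_i l_i, log(sum_i exp(l_i)) = m + log(sum_i exp(l_i - m)), so every exponent is at most 0 and expf cannot overflow. Each logprob then falls out as l_i - logsumexp, and whisper_compute_probs() recovers probabilities as exp(logprob).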
  // process the logits for the selected decoder
  // - applies logit filters
  // - computes logprobs and probs
@@ -4781,7 +4923,7 @@ static void whisper_process_logits(

  // suppress sot and nosp tokens
  logits[vocab.token_sot] = -INFINITY;
- logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+ logits[vocab.token_nosp] = -INFINITY;

  // [TDRZ] when tinydiarize is disabled, suppress solm token
  if (params.tdrz_enable == false) {
@@ -4880,24 +5022,7 @@ static void whisper_process_logits(
  }

  // populate the logprobs array (log_softmax)
- {
- const float logit_max = *std::max_element(logits.begin(), logits.end());
- float logsumexp = 0.0f;
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logsumexp += expf(logits[i] - logit_max);
- }
- }
- logsumexp = logf(logsumexp) + logit_max;
-
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logprobs[i] = logits[i] - logsumexp;
- } else {
- logprobs[i] = -INFINITY;
- }
- }
- }
+ whisper_compute_logprobs(logits, n_logits, logprobs);

  // if sum of probability over timestamps is above any other token, sample timestamp
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -4955,15 +5080,7 @@ static void whisper_process_logits(
  }

  // compute probs
- {
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] == -INFINITY) {
- probs[i] = 0.0f;
- } else {
- probs[i] = expf(logprobs[i]);
- }
- }
- }
+ whisper_compute_probs(logits, n_logits, logprobs, probs);

  #if 0
  // print first 100 logits - token string : logit
@@ -5228,15 +5345,9 @@ int whisper_full_with_state(

  if (n_samples > 0) {
  // compute log mel spectrogram
- if (params.speed_up) {
- // TODO: Replace PV with more advanced algorithm
+ if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
- } else {
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
- WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
- return -2;
- }
+ return -2;
  }
  }

@@ -5273,7 +5384,7 @@ int whisper_full_with_state(
  // if length of spectrogram is less than 1.0s (100 frames), then return
  // basically don't process anything that is less than 1.0s
  // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
- if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
+ if (seek_end < seek_start + 100) {
  WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
  return 0;
  }
@@ -5518,13 +5629,46 @@ int whisper_full_with_state(
  }
  WHISPER_LOG_DEBUG("\n\n");

+ // recreate the KV cache if the number of decoders has changed
+ if (state->kv_self_n_dec < n_decoders_cur) {
+ WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+ whisper_kv_cache_free(state->kv_self);
+
+ // overallocate to workaround KV cache fragmentation issues
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return -7;
+ }
+
+ state->kv_self_n_dec = n_decoders_cur;
+ }
+
  whisper_kv_cache_clear(state->kv_self);

  whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);

  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -7;
+ return -8;
+ }
+
+ // Calculate no_speech probability after first decode.
+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+ {
+ const int n_logits = ctx->vocab.id_to_token.size();
+ std::vector<float> logprobs(n_logits);
+ std::vector<float> probs(n_logits);
+
+ whisper_compute_logprobs(state->logits, n_logits, logprobs);
+ whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+ state->no_speech_prob = probs[whisper_token_nosp(ctx)];
  }

  {
@@ -5824,7 +5968,7 @@ int whisper_full_with_state(

  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -8;
+ return -9;
  }

  const int64_t t_start_sample_us = ggml_time_us();
@@ -5918,8 +6062,9 @@ int whisper_full_with_state(
  if (it != (int) temperatures.size() - 1) {
  const auto & decoder = state->decoders[best_decoder_id];

- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
- WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+ if (decoder.failed ||
+ (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+ WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
  success = false;
  state->n_fail_p++;
  }
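
This mirrors the temperature-fallback rule in upstream openai/whisper: a low average logprob alone no longer forces a retry at the next temperature when the chunk also scores above the no-speech threshold, i.e. when the low confidence is attributable to silence rather than to a bad transcription.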
@@ -5940,7 +6085,7 @@ int whisper_full_with_state(
  {
  const auto & best_decoder = state->decoders[best_decoder_id];

- const auto seek_delta = best_decoder.seek_delta;
+ auto seek_delta = best_decoder.seek_delta;
  const auto result_len = best_decoder.sequence.result_len;

  const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -5948,6 +6093,9 @@ int whisper_full_with_state(
  // [EXPERIMENTAL] Token-level timestamps with DTW
  const auto n_segments_before = state->result_all.size();

+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+ best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);

  // update prompt_past
@@ -5956,11 +6104,11 @@ int whisper_full_with_state(
  prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
  }

- for (int i = 0; i < result_len; ++i) {
+ for (int i = 0; i < result_len && !is_no_speech; ++i) {
  prompt_past.push_back(tokens_cur[i].id);
  }

- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
  int i0 = 0;
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));

@@ -5985,8 +6133,8 @@ int whisper_full_with_state(
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));

  if (!text.empty()) {
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
+ const auto tt0 = t0;
+ const auto tt1 = t1;

  if (params.print_realtime) {
  if (params.print_timestamps) {
@@ -6014,7 +6162,7 @@ int whisper_full_with_state(
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
  }
  }
- if (params.new_segment_callback) {
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
  }
  }
@@ -6032,8 +6180,8 @@ int whisper_full_with_state(
  if (!text.empty()) {
  const auto t1 = seek + seek_delta;

- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
+ const auto tt0 = t0;
+ const auto tt1 = t1;

  if (params.print_realtime) {
  if (params.print_timestamps) {
@@ -6059,7 +6207,7 @@ int whisper_full_with_state(
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
  }
  }
- if (params.new_segment_callback) {
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
  }
  }
@@ -6068,14 +6216,28 @@ int whisper_full_with_state(
  // FIXME: will timestamp offsets be correct?
  // [EXPERIMENTAL] Token-level timestamps with DTW
  {
- const auto n_segments = state->result_all.size() - n_segments_before;
+ const int n_segments = state->result_all.size() - n_segments_before;
  if (ctx->params.dtw_token_timestamps && n_segments) {
  const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
  whisper_exp_compute_token_level_timestamps_dtw(
  ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+ if (params.new_segment_callback) {
+ for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+ params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+ }
+ }
  }
  }

+ // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
+ const bool single_timestamp_ending = tokens_cur.size() > 1 &&
+ tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
+ tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
+ if (single_timestamp_ending) {
+ WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
+ seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
+ }
+
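
The single_timestamp_ending check follows openai/whisper's chunk handling (see the PR linked above): when the sequence ends in a lone timestamp token, there is no "start of next segment" timestamp to resume from, so the decoder is assumed to have consumed the rest of the window and seek_delta is widened to skip the entire 30 s chunk, clamped to the remaining audio.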
  // update audio window
  seek += seek_delta;

@@ -6835,7 +6997,7 @@ static void whisper_exp_compute_token_level_timestamps(
  k++;
  }
  tokens[j].t1 = sample_to_timestamp(k);
- if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+ if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
  tokens[j].t1 = tokens[j + 1].t0;
  } else {
  s1 = k;
@@ -6998,10 +7160,11 @@ struct median_filter_user_data {
  int filter_width;
  };

- static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+ if (ith != 0) {
+ return;
+ }
  int filter_width = ((median_filter_user_data *) userdata)->filter_width;
- WHISPER_ASSERT(nth == 1);
- WHISPER_ASSERT(ith == 0);
  WHISPER_ASSERT(filter_width < a->ne[2]);
  WHISPER_ASSERT(filter_width % 2);
  WHISPER_ASSERT(ggml_n_dims(a) == 3);
@@ -7124,7 +7287,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  // operation (after median filter)
  // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
  // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
- w = ggml_norm(gctx, w, 1e-9);
+ w = ggml_norm(gctx, w, 1e-9f);
  w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);

  // Pass median filter - this is done over AUDIO_TOKENS dimension.
@@ -7196,6 +7359,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
  g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
  g_state.log_callback_user_data = user_data;
+ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
  }

  GGML_ATTRIBUTE_FORMAT(2, 3)
@@ -7219,6 +7383,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
  static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
  (void) level;
  (void) user_data;
+ #ifndef WHISPER_DEBUG
+ if (level == GGML_LOG_LEVEL_DEBUG) {
+ return;
+ }
+ #endif
  fputs(text, stderr);
  fflush(stderr);
  }
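
Since whisper_log_set() now forwards the callback to ggml_log_set(), backend messages (CUDA, Metal, etc.) arrive through the same hook as whisper's own logging. A minimal client-side sketch, assuming only the public whisper.h/ggml.h APIs:

    static void my_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        if (level != GGML_LOG_LEVEL_DEBUG) {
            fputs(text, stderr); // whisper and ggml backend messages both land here
        }
    }

    // during startup:
    // whisper_log_set(my_log_cb, nullptr);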