whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
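
The remainder of this page is the rendered diff for entry 115, data/ext/{whisper.cpp → src/whisper.cpp} (+661 -492), which carries the substance of the release: the compile-time per-vendor #ifdef blocks (CUDA, Metal, SYCL) are replaced by ggml's runtime backend registry and graph scheduler (ggml_backend_sched), a runtime flash_attn context option is threaded through the encoder and decoder graphs, and the FFT/mel front end is reworked to avoid per-call allocations.
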
@@ -1,29 +1,19 @@
  #include "whisper.h"
 
- #ifdef WHISPER_USE_COREML
- #include "coreml/whisper-encoder.h"
- #endif
+ #include "ggml-cpu.h"
 
- #ifdef GGML_USE_METAL
- #include "ggml-metal.h"
- #endif
-
- #ifdef GGML_USE_CUDA
- #include "ggml-cuda.h"
- #endif
+ #include "ggml.h"
+ #include "ggml-alloc.h"
+ #include "ggml-backend.h"
 
- #ifdef GGML_USE_SYCL
- #include "ggml-sycl.h"
+ #ifdef WHISPER_USE_COREML
+ #include "coreml/whisper-encoder.h"
  #endif
 
  #ifdef WHISPER_USE_OPENVINO
  #include "openvino/whisper-openvino-encoder.h"
  #endif
 
- #include "ggml.h"
- #include "ggml-alloc.h"
- #include "ggml-backend.h"
-
  #include <atomic>
  #include <algorithm>
  #include <cassert>
@@ -41,6 +31,7 @@
  #include <regex>
  #include <random>
  #include <functional>
+ #include <codecvt>
 
  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -147,8 +138,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
  } \
  } while (0)
 
- //#define WHISPER_USE_FLASH_ATTN
- //#define WHISPER_USE_FLASH_FF
  #define WHISPER_MAX_DECODERS 8
  #define WHISPER_MAX_NODES 4096
 
@@ -162,7 +151,7 @@ static bool ggml_graph_compute_helper(
  int n_threads,
  ggml_abort_callback abort_callback,
  void * abort_callback_data) {
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+ struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
 
  plan.abort_callback = abort_callback;
  plan.abort_callback_data = abort_callback_data;
@@ -176,18 +165,24 @@ static bool ggml_graph_compute_helper(
  }
 
  static bool ggml_graph_compute_helper(
- struct ggml_backend * backend,
+ ggml_backend_sched_t sched,
  struct ggml_cgraph * graph,
  int n_threads) {
- if (ggml_backend_is_cpu(backend)) {
- ggml_backend_cpu_set_n_threads(backend, n_threads);
- }
- #ifdef GGML_USE_METAL
- if (ggml_backend_is_metal(backend)) {
- ggml_backend_metal_set_n_cb(backend, n_threads);
+
+ for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+ ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+ ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+ auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+ if (fn_set_n_threads) {
+ fn_set_n_threads(backend, n_threads);
+ }
  }
- #endif
- return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
+
+ bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+ ggml_backend_sched_reset(sched);
+ return t;
  }
 
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
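 
The rewritten helper above is the core of the backend migration: instead of a compile-time branch per vendor, an optional capability such as the thread count is resolved at runtime through ggml_backend_reg_get_proc_address. A minimal, self-contained sketch of the same lifecycle follows — compute_once is a hypothetical name, the ggml-backend calls are the ones visible in this diff, and error handling is reduced to booleans:

    #include "ggml-backend.h"

    #include <vector>

    // Sketch: build a scheduler over a set of backends, allocate one graph,
    // push the thread count to every backend that exposes the hook, compute, reset.
    static bool compute_once(std::vector<ggml_backend_t> & backends,
                             struct ggml_cgraph * gf, int n_threads, int max_nodes) {
        ggml_backend_sched_t sched = ggml_backend_sched_new(
            backends.data(), nullptr, backends.size(), max_nodes, false);

        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
            ggml_backend_sched_free(sched);
            return false; // compute buffer allocation failed
        }

        for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
            ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
            ggml_backend_dev_t dev = ggml_backend_get_device(backend);
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;

            // optional capability, typically present on the CPU backend only
            auto * set_n_threads = (ggml_backend_set_n_threads_t)
                ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
            if (set_n_threads) {
                set_n_threads(backend, n_threads);
            }
        }

        const bool ok = ggml_backend_sched_graph_compute(sched, gf) == GGML_STATUS_SUCCESS;
        ggml_backend_sched_reset(sched);
        ggml_backend_sched_free(sched);
        return ok;
    }

Note that whisper.cpp itself does not recreate the scheduler per call; as the diff below shows, it keeps one long-lived scheduler per graph type (sched_conv, sched_encode, sched_cross, sched_decode) in whisper_state.
 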
@@ -363,6 +358,7 @@ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15},
  static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
  static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
  static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+ static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
 
  static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
  { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
@@ -376,6 +372,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
  { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
  { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
  { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
+ { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
  };
 
  static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
@@ -502,33 +499,41 @@ struct whisper_pair {
  whisper_pair() : first(A()), second(B()) {}
  };
 
- // ggml_allocr wrapper for whisper usage
- struct whisper_allocr {
- ggml_gallocr_t alloc = nullptr;
+ // ggml_backend_sched wrapper for whisper usage
+ struct whisper_sched {
+ ggml_backend_sched_t sched = nullptr;
 
  std::vector<uint8_t> meta;
  };
 
- static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
- return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0);
+ static size_t whisper_sched_size(struct whisper_sched & allocr) {
+ size_t size = allocr.meta.size();
+ for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+ ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+ size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+ }
+ return size;
  }
 
  // measure the memory usage of a graph and prepare the allocr's internal data buffer
- static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
- auto & alloc = allocr.alloc;
+ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+ auto & sched = allocr.sched;
  auto & meta = allocr.meta;
 
- alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+ sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
 
  meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
  // since there are dependencies between the different graphs,
  // we need to allocate them instead of only reserving to get the correct compute buffer size
- if (!ggml_gallocr_alloc_graph(alloc, get_graph())) {
+ if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
  // failed to allocate the compute buffer
  WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
  return false;
  }
+
+ ggml_backend_sched_reset(sched);
+
  return true;
  }
 
@@ -671,9 +676,9 @@ struct whisper_kv_cache {
  struct ggml_tensor * k;
  struct ggml_tensor * v;
 
- struct ggml_context * ctx = nullptr;
-
  ggml_backend_buffer_t buffer = nullptr;
+
+ std::vector<uint8_t> ctx_buf;
  };
 
  struct whisper_model {
@@ -802,6 +807,9 @@ struct whisper_state {
  int32_t n_fail_p = 0; // number of logprob threshold failures
  int32_t n_fail_h = 0; // number of entropy threshold failures
 
+ // number of decoders for which we have constructed the KV cache
+ int32_t kv_self_n_dec = 0;
+
  // unified self-attention KV cache for all decoders
  whisper_kv_cache kv_self;
 
@@ -809,21 +817,22 @@ struct whisper_state {
  // shared between all decoders
  whisper_kv_cache kv_cross;
 
+ // padded buffer for flash-attention
+ whisper_kv_cache kv_pad;
+
  whisper_mel mel;
 
  whisper_batch batch;
 
  whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
- ggml_backend_t backend = nullptr;
+ std::vector<ggml_backend_t> backends;
 
- // ggml-alloc:
  // - stores meta info about the intermediate tensors into the `meta` buffers
- // - stores the actual tensor data into the `data` buffers
- whisper_allocr alloc_conv;
- whisper_allocr alloc_encode;
- whisper_allocr alloc_cross;
- whisper_allocr alloc_decode;
+ whisper_sched sched_conv;
+ whisper_sched sched_encode;
+ whisper_sched sched_cross;
+ whisper_sched sched_decode;
 
  // result of the encoder
  struct ggml_tensor * embd_conv = nullptr;
@@ -858,6 +867,7 @@ struct whisper_state {
  whisper_token tid_last;
 
  std::vector<float> energy; // PCM signal energy
+ float no_speech_prob = 0.0f;
 
  // [EXPERIMENTAL] Token-level timestamps with DTW
  whisper_aheads_masks aheads_masks;
@@ -882,8 +892,6 @@ struct whisper_context {
 
  whisper_state * state = nullptr;
 
- ggml_backend_t backend = nullptr;
-
  std::string path_model; // populated by whisper_init_from_file_with_params()
  };
 
@@ -901,21 +909,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
  BYTESWAP_VALUE(dest);
  }
 
- static bool kv_cache_init(
- const struct whisper_hparams & hparams,
+ static bool whisper_kv_cache_init(
  struct whisper_kv_cache & cache,
  ggml_backend_t backend,
  ggml_type wtype,
+ int64_t n_text_state,
+ int64_t n_text_layer,
  int n_ctx) {
- const int64_t n_text_state = hparams.n_text_state;
- const int64_t n_text_layer = hparams.n_text_layer;
-
  const int64_t n_mem = n_text_layer*n_ctx;
  const int64_t n_elements = n_text_state*n_mem;
 
+ cache.ctx_buf.resize(2*ggml_tensor_overhead());
+
  struct ggml_init_params params = {
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
- /*.mem_buffer =*/ nullptr,
+ /*.mem_size =*/ cache.ctx_buf.size(),
+ /*.mem_buffer =*/ cache.ctx_buf.data(),
  /*.no_alloc =*/ true,
  };
 
@@ -925,29 +933,31 @@ static bool kv_cache_init(
  cache.cells.clear();
  cache.cells.resize(n_ctx);
 
- cache.ctx = ggml_init(params);
+ struct ggml_context * ctx = ggml_init(params);
 
- if (!cache.ctx) {
+ if (!ctx) {
  WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
  return false;
  }
 
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+ cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+ cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
 
- cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
+ cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
  if (!cache.buffer) {
  WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
  return false;
  }
 
+ ggml_backend_buffer_clear(cache.buffer, 0);
+
+ ggml_free(ctx);
+
  return true;
  }
 
- static void kv_cache_free(struct whisper_kv_cache & cache) {
- ggml_free(cache.ctx);
+ static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
  ggml_backend_buffer_free(cache.buffer);
- cache.ctx = nullptr;
  }
 
  static bool whisper_kv_cache_find_slot(
@@ -1018,6 +1028,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
  cache.cells[i].seq_id.clear();
  }
  cache.head = 0;
+
+ ggml_backend_buffer_clear(cache.buffer, 0);
  }
 
  static void whisper_kv_cache_seq_rm(
@@ -1068,6 +1080,26 @@ static void whisper_kv_cache_seq_cp(
  }
  }
 
+ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+ if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
+ return 1u;
+ }
+
+ #ifdef GGML_USE_METAL
+ if (wctx.params.use_gpu) {
+ return 32u;
+ }
+ #endif
+
+ #ifdef GGML_USE_CUDA
+ if (wctx.params.use_gpu) {
+ return 256u;
+ }
+ #endif
+
+ return 1u;
+ }
+
  // [EXPERIMENTAL] Token-level timestamps with DTW
  static bool aheads_masks_init(
  const whisper_context_params & cparams,
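 
whisper_kv_cache_get_padding above pairs with the GGML_PAD call in whisper_decode_internal further down: with flash attention enabled, the effective KV length is rounded up to a multiple of 256 on CUDA builds and 32 on Metal builds, and left untouched (padding 1) otherwise. A tiny worked example, assuming only that GGML_PAD(x, n) rounds x up to the next multiple of n (round_up below is an illustrative stand-in, not the ggml macro):

    #include <cstdint>

    // Illustrative stand-in for GGML_PAD(x, n): round x up to a multiple of n.
    static uint32_t round_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    // 37 live KV cells on a CUDA build with flash attention (pad = 256):
    //   round_up(37, 256) == 256 -> kv_self.n = min(kv_self.size, max(256u, 256u))
    // the same cells on a Metal build (pad = 32):
    //   round_up(37, 32)  == 64  -> kv_self.n = min(kv_self.size, max(32u, 64u)) == 64
 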
@@ -1199,49 +1231,71 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
  return size;
  }
 
- static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
- ggml_backend_t backend_gpu = NULL;
+ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
- // initialize the backends
- #ifdef GGML_USE_CUDA
  if (params.use_gpu) {
- WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
- backend_gpu = ggml_backend_cuda_init(params.gpu_device);
- if (!backend_gpu) {
- WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ }
+ return result;
+ }
  }
  }
- #endif
 
- #ifdef GGML_USE_METAL
- if (params.use_gpu) {
- WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
- ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
- backend_gpu = ggml_backend_metal_init();
- if (!backend_gpu) {
- WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
- } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
- WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
- ggml_backend_free(backend_gpu);
- backend_gpu = NULL;
- }
+ return nullptr;
+ }
+
+ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+ std::vector<ggml_backend_t> result;
+
+ ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
+
+ if (backend_gpu) {
+ result.push_back(backend_gpu);
  }
- #endif
 
- #ifdef GGML_USE_SYCL
- if (params.use_gpu) {
- WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
- backend_gpu = ggml_backend_sycl_init(params.gpu_device);
- if (!backend_gpu) {
- WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
+ // ACCEL backends
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ if (!backend) {
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ continue;
+ }
+ result.push_back(backend);
  }
  }
- #endif
 
- if (backend_gpu) {
- return backend_gpu;
+ GGML_UNUSED(params);
+
+ result.push_back(ggml_backend_cpu_init());
+
+ return result;
+ }
+
+ static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+ if (!params.use_gpu) {
+ return ggml_backend_cpu_buffer_type();
  }
- return ggml_backend_cpu_init();
+
+ // if we have a GPU device - use it
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+ WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+ return ggml_backend_dev_buffer_type(dev);
+ }
+ }
+
+ return ggml_backend_cpu_buffer_type();
  }
 
  // load the model from a ggml file
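 
Device selection is now data-driven: whisper_backend_init walks ggml's device registry and takes the first GPU device, then every ACCEL device, and always appends a CPU backend, while whisper_default_buffer_type applies the same preference to decide where the model weights live. A short sketch of the enumeration (list_devices is a hypothetical helper; the registry calls are the ones used above):

    #include "ggml-backend.h"

    #include <cstdio>

    // Sketch: print every device the vendored ggml build registered,
    // in the order whisper_backend_init would consider them.
    static void list_devices(void) {
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);

            const char * kind = "other";
            switch (ggml_backend_dev_type(dev)) {
                case GGML_BACKEND_DEVICE_TYPE_GPU:   kind = "GPU";   break;
                case GGML_BACKEND_DEVICE_TYPE_ACCEL: kind = "ACCEL"; break;
                default:                             break;
            }

            printf("device %zu: %s (%s) [%s]\n", i,
                ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), kind);
        }
    }
 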
@@ -1668,21 +1722,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  }
  }
 
- wctx.backend = whisper_backend_init(wctx.params);
- if (!wctx.backend) {
- WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
- return false;
- }
-
  // allocate tensors in the backend buffers
- model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend);
+ model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
  if (!model.buffer) {
  WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
  return false;
  }
 
  size_t size_main = ggml_backend_buffer_get_size(model.buffer);
- WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+ WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
 
  // load weights
  {
@@ -1777,6 +1825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  }
  }
 
+ ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
  wctx.t_load_us = ggml_time_us() - t_start_us;
 
  return true;
@@ -1812,8 +1862,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
  const int n_mels = hparams.n_mels;
 
  struct ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_conv.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+ /*.mem_size =*/ wstate.sched_conv.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
  /*.no_alloc =*/ true,
  };
 
@@ -1847,6 +1897,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
  ggml_build_forward_expand(gf, mel);
 
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+ ggml_set_input(cur); // the external encoder will write into this tensor
 
  ggml_set_name(cur, "embd_enc");
  wstate.embd_enc = cur;
@@ -1872,9 +1923,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
  const int n_head = hparams.n_audio_head;
  const int n_layer = hparams.n_audio_layer;
 
+ const int n_state_head = n_state/n_head;
+
+ auto & kv_pad = wstate.kv_pad;
+
+ WHISPER_ASSERT(!!kv_pad.buffer);
+
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
  struct ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_encode.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+ /*.mem_size =*/ wstate.sched_encode.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
  /*.no_alloc =*/ true,
  };
 
@@ -1884,7 +1943,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
- const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
+ const float KQscale = 1.0f/sqrtf(float(n_state_head));
 
  // ===================================================================
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1993,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
  Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
 
- //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
+ //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
 
  // note: no bias for Key
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
  layer.attn_k_w,
  cur);
 
- //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
+ //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
 
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
  layer.attn_v_w,
@@ -1951,70 +2010,60 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
  // ------
 
- #ifdef WHISPER_USE_FLASH_ATTN
  struct ggml_tensor * Q =
  ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
  0, 2, 1, 3);
 
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ if (wctx.params.flash_attn) {
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
 
- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+ struct ggml_tensor * K =
+ ggml_view_3d(ctx0, kv_pad.k,
+ n_state_head, n_ctx_pad, n_head,
+ ggml_element_size(kv_pad.k)*n_state,
+ ggml_element_size(kv_pad.k)*n_state_head,
+ 0);
 
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
- #else
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_pad.v,
+ n_state_head, n_ctx_pad, n_head,
+ ggml_element_size(kv_pad.v)*n_state,
+ ggml_element_size(kv_pad.v)*n_state_head,
+ 0);
 
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
 
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+ } else {
+ struct ggml_tensor * K =
+ ggml_permute(ctx0,
+ ggml_cast(ctx0,
+ ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+ wctx.itype),
+ 0, 2, 1, 3);
 
- struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
- );
+ struct ggml_tensor * V =
+ ggml_cast(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ Vcur,
+ n_state_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ wctx.itype);
 
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
- #endif
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
+ }
  }
 
  // projection
@@ -2043,11 +2092,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
  layer.mlp_ln_b);
  }
 
- #ifdef WHISPER_USE_FLASH_FF
- cur = ggml_flash_ff(ctx0,
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
- #else
  // fully connected
  cur = ggml_mul_mat(ctx0,
  layer.mlp_0_w,
@@ -2064,7 +2108,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
  cur);
 
  cur = ggml_add(ctx0, cur, layer.mlp_1_b);
- #endif
  }
 
  inpL = ggml_add(ctx0, cur, inpFF);
@@ -2113,9 +2156,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
  const int n_state = hparams.n_audio_state;
  const int n_head = hparams.n_audio_head;
 
+ const int n_state_head = n_state/n_head;
+
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
  struct ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_cross.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+ /*.mem_size =*/ wstate.sched_cross.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
  /*.no_alloc =*/ true,
  };
 
@@ -2125,18 +2172,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
- const float Kscale = pow(float(n_state) / n_head, -0.25);
+ const float Kscale = pow(float(n_state_head), -0.25);
 
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
  auto & layer = model.layers_decoder[il];
 
- struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
  layer.cross_attn_k_w,
  cur);
 
  Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
- struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
  layer.cross_attn_v_w,
  cur);
 
@@ -2144,15 +2191,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
  Vcross,
  layer.cross_attn_v_b);
 
- Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+ struct ggml_tensor * k;
+ struct ggml_tensor * v;
 
- struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
- n_state*n_ctx,
- (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+ if (wctx.params.flash_attn) {
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
- struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
- ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
- (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+ v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+ } else {
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+ v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+ }
 
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2186,11 +2243,11 @@ static bool whisper_encode_internal(
 
  // conv
  {
- auto & alloc = wstate.alloc_conv.alloc;
+ auto & sched = wstate.sched_conv.sched;
 
  ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
 
- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
@@ -2223,7 +2280,7 @@ static bool whisper_encode_internal(
  }
 
  if (!whisper_encode_external(wstate)) {
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  } else {
@@ -2237,32 +2294,32 @@ static bool whisper_encode_internal(
 
  // encoder
  if (!whisper_encode_external(wstate)) {
- auto & alloc = wstate.alloc_encode.alloc;
+ auto & sched = wstate.sched_encode.sched;
 
  ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
 
- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
 
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }
 
  // cross
  {
- auto & alloc = wstate.alloc_cross.alloc;
+ auto & sched = wstate.sched_cross.sched;
 
  ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
 
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }
@@ -2284,24 +2341,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
  auto & kv_self = wstate.kv_self;
 
- WHISPER_ASSERT(!!kv_self.ctx);
+ WHISPER_ASSERT(!!kv_self.buffer);
 
  const int n_ctx = kv_self.size;
  const int n_state = hparams.n_text_state;
  const int n_head = hparams.n_text_head;
  const int n_layer = hparams.n_text_layer;
 
+ const int n_state_head = n_state/n_head;
+
  const int n_tokens = batch.n_tokens;
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+ const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
+ const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
 
  //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
  struct ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_decode.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+ /*.mem_size =*/ wstate.sched_decode.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
  /*.no_alloc =*/ true,
  };
 
@@ -2317,12 +2378,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
  ggml_set_name(position, "position");
  ggml_set_input(position);
 
- const float KQscale = pow(float(n_state)/n_head, -0.25);
+ const float KQscale = pow(float(n_state_head), -0.25);
 
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
  ggml_set_name(KQ_mask, "KQ_mask");
  ggml_set_input(KQ_mask);
 
+ struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
  // token encoding + position encoding
  struct ggml_tensor * cur =
  ggml_add(ctx0,
@@ -2378,12 +2441,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
  Vcur,
  layer.attn_v_b);
 
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+ struct ggml_tensor * k;
+ struct ggml_tensor * v;
+
+ if (wctx.params.flash_attn) {
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+ v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+ (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+ } else {
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
 
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
- ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+ v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+ }
 
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2393,40 +2469,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
  struct ggml_tensor * Q =
  ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
  0, 2, 1, 3);
 
  struct ggml_tensor * K =
  ggml_view_3d(ctx0, kv_self.k,
- n_state/n_head, n_kv, n_head,
+ n_state_head, n_kv, n_head,
  ggml_element_size(kv_self.k)*n_state,
- ggml_element_size(kv_self.k)*n_state/n_head,
+ ggml_element_size(kv_self.k)*n_state_head,
  ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ if (wctx.params.flash_attn) {
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_state_head, n_kv, n_head,
+ ggml_element_size(kv_self.v)*n_state,
+ ggml_element_size(kv_self.v)*n_state_head,
+ ggml_element_size(kv_self.v)*n_state*n_ctx*il);
 
- //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
 
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
- struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+ } else {
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
 
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, kv_self.v,
- n_kv, n_state/n_head, n_head,
- n_ctx*ggml_element_size(kv_self.v),
- n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
- n_ctx*ggml_element_size(kv_self.v)*n_state*il);
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_kv, n_state_head, n_head,
+ n_ctx*ggml_element_size(kv_self.v),
+ n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+ n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+ }
  }
 
  // projection
@@ -2465,80 +2547,75 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
  Qcur,
  layer.cross_attn_q_b);
 
- Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
- // Kcross is already scaled
- struct ggml_tensor * Kcross =
- ggml_view_3d(ctx0, wstate.kv_cross.k,
- n_state/n_head, n_audio_ctx, n_head,
- ggml_element_size(wstate.kv_cross.k)*n_state,
- ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
- ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
- //struct ggml_tensor * Vcross =
- // ggml_reshape_3d(ctx0,
- // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
- // n_state/n_head, n_head, n_audio_ctx);
-
- //struct ggml_tensor * V_trans =
- // ggml_cpy(ctx0,
- // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
- // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, wstate.kv_cross.v,
- n_audio_ctx, n_state/n_head, n_head,
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
- // ------
-
  struct ggml_tensor * Q =
  ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
  0, 2, 1, 3);
 
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
-
- //struct ggml_tensor * KQ_scaled =
- // ggml_scale(ctx0,
- // KQ,
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
- // );
+ if (wctx.params.flash_attn) {
+ struct ggml_tensor * Kcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
+ n_state_head, n_audio_ctx_pad, n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state,
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
 
- // no masking for cross-attention
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+ struct ggml_tensor * Vcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ n_state_head, n_audio_ctx_pad, n_head,
+ ggml_element_size(wstate.kv_cross.v)*n_state,
+ ggml_element_size(wstate.kv_cross.v)*n_state_head,
+ ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
 
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
+ cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
 
- // [EXPERIMENTAL] Token-level timestamps with DTW
- if (wctx.params.dtw_token_timestamps) {
- if (wstate.aheads_masks.m[il] != nullptr) {
- struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
- aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
- aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
- if (aheads_cross_QKs == NULL) {
- aheads_cross_QKs = aheads_KQs;
- } else {
- aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+ } else {
+ struct ggml_tensor * Kcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
+ n_state_head, n_audio_ctx, n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state,
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+ struct ggml_tensor * Vcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ n_audio_ctx, n_state_head, n_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+ // ------
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ if (wctx.params.dtw_token_timestamps) {
+ if (wstate.aheads_masks.m[il] != nullptr) {
+ struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+ aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+ aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+ if (aheads_cross_QKs == NULL) {
+ aheads_cross_QKs = aheads_KQs;
+ } else {
+ aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
+ }
  }
  }
- }
 
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
 
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
- // cur = KQV_merged.contiguous().view(n_state, n_tokens)
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+ }
  }
 
  // projection
@@ -2671,18 +2748,20 @@ static bool whisper_decode_internal(
  return false;
  }
 
- kv_self.n = whisper_kv_cache_cell_max(kv_self);
+ const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+ kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
  }
 
  // decoder
  {
- auto & alloc = wstate.alloc_decode.alloc;
+ auto & sched = wstate.sched_decode.sched;
 
  ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
@@ -2705,9 +2784,10 @@ static bool whisper_decode_internal(
  struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
 
  auto & kv_self = wstate.kv_self;
- const int32_t n_kv = kv_self.n;
 
- wstate.inp_mask.resize(n_kv*n_tokens);
+ const int32_t n_kv = kv_self.n;
+
+ wstate.inp_mask.resize(ggml_nelements(KQ_mask));
 
  float * data = wstate.inp_mask.data();
  memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2723,14 +2803,20 @@ static bool whisper_decode_internal(
  }
  }
  }
+
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+ }
+ }
  }
 
  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
  }
 
- logits = gf->nodes[gf->n_nodes - 1];
+ logits = ggml_graph_node(gf, -1);
 
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }
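 
Two details in the mask setup above are worth a note: wstate.inp_mask is now sized from ggml_nelements(KQ_mask) rather than n_kv*n_tokens, because the mask tensor is allocated with its row count padded to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); and the added loop writes -INFINITY into those padding rows so that, when the mask is cast to F16 and handed to ggml_flash_attn_ext, the extra query rows are fully masked and their (discarded) outputs stay well defined. The value of GGML_KQ_MASK_PAD comes from the vendored ggml headers and is not shown in this diff.
 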
@@ -2784,29 +2870,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
  }
 
  #define SIN_COS_N_COUNT WHISPER_N_FFT
- static float sin_vals[SIN_COS_N_COUNT];
- static float cos_vals[SIN_COS_N_COUNT];
+ namespace {
+ struct whisper_global_cache {
+ // In FFT, we frequently use sine and cosine operations with the same values.
+ // We can use precalculated values to speed up the process.
+ float sin_vals[SIN_COS_N_COUNT];
+ float cos_vals[SIN_COS_N_COUNT];
+
+ // Hann window (Use cosf to eliminate difference)
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+ float hann_window[WHISPER_N_FFT];
+
+ whisper_global_cache() {
+ fill_sin_cos_table();
+ fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+ }
+
+ void fill_sin_cos_table() {
+ for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+ sin_vals[i] = sinf(theta);
+ cos_vals[i] = cosf(theta);
+ }
+ }
 
- // In FFT, we frequently use sine and cosine operations with the same values.
- // We can use precalculated values to speed up the process.
- static void fill_sin_cos_table() {
- static bool is_filled = false;
- if (is_filled) return;
- for (int i = 0; i < SIN_COS_N_COUNT; i++) {
- double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
- sin_vals[i] = sinf(theta);
- cos_vals[i] = cosf(theta);
+ void fill_hann_window(int length, bool periodic, float * output) {
+ int offset = -1;
+ if (periodic) {
+ offset = 0;
+ }
+ for (int i = 0; i < length; i++) {
+ output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+ }
  }
- is_filled = true;
+ } global_cache;
  }
 
  // naive Discrete Fourier Transform
  // input is real-valued
  // output is complex-valued
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
- int N = in.size();
-
- out.resize(N*2);
+ static void dft(const float* in, int N, float* out) {
  const int sin_cos_step = SIN_COS_N_COUNT / N;
 
  for (int k = 0; k < N; k++) {
@@ -2815,8 +2919,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
 
  for (int n = 0; n < N; n++) {
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
- re += in[n]*cos_vals[idx]; // cos(t)
- im -= in[n]*sin_vals[idx]; // sin(t)
+ re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+ im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
  }
 
  out[k*2 + 0] = re;
@@ -2828,47 +2932,38 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
  // poor man's implementation - use something better
  // input is real-valued
  // output is complex-valued
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
- out.resize(in.size()*2);
-
- int N = in.size();
-
+ static void fft(float* in, int N, float* out) {
  if (N == 1) {
  out[0] = in[0];
  out[1] = 0;
  return;
  }
 
- if (N%2 == 1) {
- dft(in, out);
+ const int half_N = N / 2;
+ if (N - half_N*2 == 1) {
+ dft(in, N, out);
  return;
  }
 
- std::vector<float> even;
- std::vector<float> odd;
-
- even.reserve(N/2);
- odd.reserve(N/2);
-
- for (int i = 0; i < N; i++) {
- if (i % 2 == 0) {
- even.push_back(in[i]);
- } else {
- odd.push_back(in[i]);
- }
+ float* even = in + N;
+ for (int i = 0; i < half_N; ++i) {
+ even[i]= in[2*i];
  }
+ float* even_fft = out + 2 * N;
+ fft(even, half_N, even_fft);
 
- std::vector<float> even_fft;
- std::vector<float> odd_fft;
-
- fft(even, even_fft);
- fft(odd, odd_fft);
+ float* odd = even;
+ for (int i = 0; i < half_N; ++i) {
+ odd[i] = in[2*i + 1];
+ }
+ float* odd_fft = even_fft + N;
+ fft(odd, half_N, odd_fft);
 
  const int sin_cos_step = SIN_COS_N_COUNT / N;
- for (int k = 0; k < N/2; k++) {
+ for (int k = 0; k < half_N; k++) {
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
- float re = cos_vals[idx]; // cos(t)
- float im = -sin_vals[idx]; // sin(t)
+ float re = global_cache.cos_vals[idx]; // cos(t)
+ float im = -global_cache.sin_vals[idx]; // sin(t)
 
  float re_odd = odd_fft[2*k + 0];
  float im_odd = odd_fft[2*k + 1];
@@ -2876,52 +2971,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2876
2971
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2877
2972
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2878
2973
 
2879
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2880
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2881
- }
2882
- }
2883
-
2884
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2885
- if (output.size() < static_cast<size_t>(length)) {
2886
- output.resize(length);
2974
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2975
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2887
2976
  }
2888
- int offset = -1;
2889
- if (periodic) {
2890
- offset = 0;
2891
- }
2892
- for (int i = 0; i < length; i++) {
2893
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
2894
- }
2895
-
2896
- return true;
2897
2977
  }
2898
2978
 
2899
- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
2979
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
2900
2980
  int n_samples, int frame_size, int frame_step, int n_threads,
2901
2981
  const whisper_filters & filters, whisper_mel & mel) {
2902
- std::vector<float> fft_in(frame_size, 0.0);
2903
- std::vector<float> fft_out(2 * frame_size);
2982
+ std::vector<float> fft_in(frame_size * 2, 0.0);
2983
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
2984
+
2904
2985
  int n_fft = filters.n_fft;
2905
2986
  int i = ith;
2906
2987
 
2907
2988
  // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2908
- assert( n_fft == 1 + (frame_size / 2) );
2909
-
2989
+ assert(n_fft == 1 + (frame_size / 2));
2990
+
2910
2991
  // calculate FFT only when fft_in are not all zero
2911
2992
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2912
2993
  const int offset = i * frame_step;
2913
2994
 
2914
- // apply Hanning window (~10% faster)
2995
+ // apply Hann window (~10% faster)
2915
2996
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2916
2997
  fft_in[j] = hann[j] * samples[offset + j];
2917
2998
  }
2999
+
2918
3000
  // fill the rest with zeros
2919
3001
  if (n_samples - offset < frame_size) {
2920
3002
  std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
2921
3003
  }
2922
3004
 
2923
3005
  // FFT
2924
- fft(fft_in, fft_out);
3006
+ fft(fft_in.data(), frame_size, fft_out.data());
2925
3007
 
2926
3008
  // Calculate modulus^2 of complex numbers
2927
3009
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@@ -2932,7 +3014,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2932
3014
  // mel spectrogram
2933
3015
  for (int j = 0; j < mel.n_mel; j++) {
2934
3016
  double sum = 0.0;
2935
-
2936
3017
  // unroll loop (suggested by GH user @lunixbochs)
2937
3018
  int k = 0;
2938
3019
  for (k = 0; k < n_fft - 3; k += 4) {
@@ -2942,14 +3023,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2942
3023
  fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
2943
3024
  fft_out[k + 3] * filters.data[j * n_fft + k + 3];
2944
3025
  }
2945
-
2946
3026
  // handle n_fft remainder
2947
3027
  for (; k < n_fft; k++) {
2948
3028
  sum += fft_out[k] * filters.data[j * n_fft + k];
2949
3029
  }
2950
-
2951
3030
  sum = log10(std::max(sum, 1e-10));
2952
-
2953
3031
  mel.data[j * mel.n_len + i] = sum;
2954
3032
  }
2955
3033
  }
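The rewritten fft recurses in place: each call of size N copies its even/odd halves into the upper half of in and writes sub-results past its own 2N output slots, so no per-call vectors are allocated. The worker-thread buffer sizes follow from that layout; summing the scratch a call of size M can touch across all recursion levels gives (our reading of the code, not a bound stated in the source):

    input  : M + M/2 + M/4 + ...                       < 2*M  ->  fft_in  is frame_size * 2
    output : 2*M (result) + 3*M * (1 + 1/2 + 1/4 + ...) = 8*M  ->  fft_out is frame_size * 2 * 2 * 2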
@@ -2978,12 +3056,9 @@ static bool log_mel_spectrogram(
  whisper_mel & mel) {
  const int64_t t_start_us = ggml_time_us();
 
- // Hanning window (Use cosf to eliminate difference)
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
- std::vector<float> hann;
- hann_window(frame_size, true, hann);
-
+ // Hann window
+ WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+ const float * hann = global_cache.hann_window;
 
  // Calculate the length of padding
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3008,12 +3083,11 @@ static bool log_mel_spectrogram(
  mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
  mel.data.resize(mel.n_mel * mel.n_len);
 
-
  {
  std::vector<std::thread> workers(n_threads - 1);
  for (int iw = 0; iw < n_threads - 1; ++iw) {
  workers[iw] = std::thread(
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+ log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
  std::cref(filters), std::ref(mel));
  }
@@ -3173,23 +3247,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
  #endif
 
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
- fill_sin_cos_table();
-
  whisper_state * state = new whisper_state;
 
- state->backend = whisper_backend_init(ctx->params);
- if (!state->backend) {
+ state->backends = whisper_backend_init(ctx->params);
+ if (state->backends.empty()) {
  WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
 
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
- const int factor = 3;
-
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+ // at this point, we don't know yet how many decoders will be used
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+ state->kv_self_n_dec = 1;
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
@@ -3199,8 +3273,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
  }
 
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+ if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
@@ -3210,9 +3287,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
  }
 
+ if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_audio_state,
+ 1,
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return nullptr;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+ WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
+ }
+
  // [EXPERIMENTAL] Token-level timestamps with DTW
  if (ctx->params.dtw_token_timestamps) {
- if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+ if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
  WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
  whisper_free_state(state);
  return nullptr;
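All three caches are now sized with GGML_PAD, which rounds its first argument up to a multiple of the second; in ggml.h it is defined essentially as shown below. With Whisper's text context of 448 tokens (a property of the model family, not stated in this diff), the self-attention cache gets 512 cells:

    // GGML_PAD from ggml.h: integer round-up to a multiple of n
    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    // GGML_PAD(448, 256) = (448 + 255) / 256 * 256 = 2 * 256 = 512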
@@ -3255,7 +3346,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
  // conv allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
  [&]() {
  return whisper_build_graph_conv(*ctx, *state);
  });
@@ -3266,12 +3357,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }
 
- WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
  }
 
  // encoder allocator
  if (!whisper_encode_external(*state)) {
- bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
  [&]() {
  return whisper_build_graph_encoder(*ctx, *state);
  });
@@ -3282,12 +3373,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }
 
- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
  }
 
  // cross allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
  [&]() {
  return whisper_build_graph_cross(*ctx, *state);
  });
@@ -3298,12 +3389,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }
 
- WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
  }
 
  // decoder allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
  [&]() {
  const auto & hparams = ctx->model.hparams;
 
@@ -3322,19 +3413,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }
 
- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
  }
 
  return state;
  }
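Each stage now owns a ggml_backend_sched rather than a plain graph allocator; the scheduler is primed by building the graph once up front so worst-case compute buffers can be reserved across the available backends. A rough sketch of the measure-then-reserve pattern behind whisper_sched_graph_init, with the internals guessed at rather than copied (the real helper is not shown in this diff):

    // Illustrative sketch only; signature and name are assumptions.
    // ggml_backend_sched_reserve() is the real ggml API used for measuring.
    template <typename GraphBuilder>
    bool sched_graph_init_sketch(ggml_backend_sched_t sched, GraphBuilder && build_graph) {
        ggml_cgraph * gf = build_graph();             // build worst-case graph once
        return ggml_backend_sched_reserve(sched, gf); // pre-allocate compute buffers
    }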
 
- int whisper_ctx_init_openvino_encoder(
+ int whisper_ctx_init_openvino_encoder_with_state(
  struct whisper_context * ctx,
+ struct whisper_state * state,
  const char * model_path,
  const char * device,
  const char * cache_dir) {
  #ifndef WHISPER_USE_OPENVINO
  (void)(ctx);
+ (void)(state);
  (void)(model_path);
  (void)(device);
  (void)(cache_dir);
@@ -3365,8 +3458,8 @@ int whisper_ctx_init_openvino_encoder(
  WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
  WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
- ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
- if (!ctx->state->ctx_openvino) {
+ state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+ if (!state->ctx_openvino) {
  WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
  return 1;
  } else {
@@ -3377,9 +3470,18 @@ int whisper_ctx_init_openvino_encoder(
  #endif
  }
 
+ int whisper_ctx_init_openvino_encoder(
+ struct whisper_context * ctx,
+ const char * model_path,
+ const char * device,
+ const char * cache_dir) {
+ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
+ }
+
  struct whisper_context_params whisper_context_default_params() {
  struct whisper_context_params result = {
  /*.use_gpu =*/ true,
+ /*.flash_attn =*/ false,
  /*.gpu_device =*/ 0,
 
  /*.dtw_token_timestamps =*/ false,
@@ -3396,8 +3498,14 @@ struct whisper_context_params whisper_context_default_params() {
 
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
  WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
-
+ #ifdef _MSC_VER
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+ std::wstring path_model_wide = converter.from_bytes(path_model);
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
+ #else
  auto fin = std::ifstream(path_model, std::ios::binary);
+ #endif
  if (!fin) {
  WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
  return nullptr;
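The _MSC_VER branch converts the UTF-8 path to UTF-16 so std::ifstream can open non-ASCII filenames on Windows. Worth noting that std::wstring_convert and std::codecvt_utf8 have been deprecated since C++17; one possible replacement (not what this code does) is the Win32 conversion API:

    // Sketch: UTF-8 -> UTF-16 via Win32, as an alternative to std::wstring_convert.
    #include <string>
    #include <windows.h>

    static std::wstring utf8_to_wide(const std::string & s) {
        const int n = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), (int) s.size(), nullptr, 0);
        std::wstring w(n, L'\0');
        MultiByteToWideChar(CP_UTF8, 0, s.c_str(), (int) s.size(), &w[0], n);
        return w;
    }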
@@ -3472,6 +3580,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
  ggml_time_init();
 
+ if (params.flash_attn && params.dtw_token_timestamps) {
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+ params.dtw_token_timestamps = false;
+ }
+
+ WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+ WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
+ WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+ WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
  whisper_context * ctx = new whisper_context;
  ctx->params = params;
 
@@ -3558,8 +3678,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
  void whisper_free_state(struct whisper_state * state) {
  if (state) {
- kv_cache_free(state->kv_self);
- kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_self);
+ whisper_kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_pad);
 
  #ifdef WHISPER_USE_COREML
  if (state->ctx_coreml != nullptr) {
@@ -3577,12 +3698,14 @@ void whisper_free_state(struct whisper_state * state) {
 
  whisper_batch_free(state->batch);
 
- ggml_gallocr_free(state->alloc_conv.alloc);
- ggml_gallocr_free(state->alloc_encode.alloc);
- ggml_gallocr_free(state->alloc_cross.alloc);
- ggml_gallocr_free(state->alloc_decode.alloc);
+ ggml_backend_sched_free(state->sched_conv.sched);
+ ggml_backend_sched_free(state->sched_encode.sched);
+ ggml_backend_sched_free(state->sched_cross.sched);
+ ggml_backend_sched_free(state->sched_decode.sched);
 
- ggml_backend_free(state->backend);
+ for (auto & backend : state->backends) {
+ ggml_backend_free(backend);
+ }
 
  // [EXPERIMENTAL] Token-level timestamps with DTW
  aheads_masks_free(state->aheads_masks);
@@ -3599,8 +3722,6 @@ void whisper_free(struct whisper_context * ctx) {
 
  whisper_free_state(ctx->state);
 
- ggml_backend_free(ctx->backend);
-
  delete ctx;
  }
  }
@@ -3630,30 +3751,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
  return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
  }
 
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
- WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
- return -1;
- }
-
- return 0;
- }
-
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
- }
-
- // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
- // TODO
-
  int whisper_set_mel_with_state(
  struct whisper_context * ctx,
  struct whisper_state * state,
@@ -3742,7 +3839,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
  return -whisper_tokenize(ctx, text, NULL, 0);
  }
 
- int whisper_lang_max_id() {
+ int whisper_lang_max_id(void) {
  auto max_id = 0;
  for (const auto & kv : g_lang) {
  max_id = std::max(max_id, kv.second.first);
@@ -4011,6 +4108,19 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
  return ctx->vocab.token_transcribe;
  }
 
+ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
+ if (ctx->state == nullptr) {
+ return nullptr;
+ }
+ whisper_timings * timings = new whisper_timings;
+ timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
+ timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
+ timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
+ timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
+ timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
+ return timings;
+ }
+
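The new whisper_get_timings returns per-stage averages in milliseconds (total time divided by call count, with std::max(1, n) guarding against division by zero). The struct is allocated with new and no matching free function appears in this diff, so we assume the caller releases it with delete:

    // Usage sketch after a whisper_full() call; ownership assumed to be the caller's.
    whisper_timings * t = whisper_get_timings(ctx);
    if (t != nullptr) {
        fprintf(stderr, "encode: %.2f ms, decode: %.2f ms\n", t->encode_ms, t->decode_ms);
        delete t;
    }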
 
  void whisper_print_timings(struct whisper_context * ctx) {
  const int64_t t_end_us = ggml_time_us();
 
@@ -4078,17 +4188,14 @@ const char * whisper_print_system_info(void) {
  s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
  s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
  s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
- s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
  s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
  s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
  s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
- s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
  s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
- s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
- s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;
+ s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
 
  return s.c_str();
  }
@@ -4099,7 +4206,7 @@ const char * whisper_print_system_info(void) {
 
  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
  // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
- std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
  const char * src,
  whisper_partial_utf8 partial_start) {
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4513,7 +4620,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
 
  ////////////////////////////////////////////////////////////////////////////
 
- struct whisper_context_params * whisper_context_default_params_by_ref() {
+ struct whisper_context_params * whisper_context_default_params_by_ref(void) {
  struct whisper_context_params params = whisper_context_default_params();
 
  struct whisper_context_params* result = new whisper_context_params();
@@ -4554,7 +4661,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
  /*.split_on_word =*/ false,
  /*.max_tokens =*/ 0,
 
- /*.speed_up =*/ false,
  /*.debug_mode =*/ false,
  /*.audio_ctx =*/ 0,
 
@@ -4720,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
  "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
  };
 
+ static void whisper_compute_logprobs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ std::vector<float> & logprobs) {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+
+ static void whisper_compute_probs(
+ const std::vector<float> & logits,
+ const int n_logits,
+ const std::vector<float> & logprobs,
+ std::vector<float> & probs) {
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] == -INFINITY) {
+ probs[i] = 0.0f;
+ } else {
+ probs[i] = expf(logprobs[i]);
+ }
+ }
+ }
+
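whisper_compute_logprobs is the standard numerically stable log-softmax: subtracting the maximum logit before exponentiating keeps expf from overflowing, and the shift cancels in the final expression. In LaTeX form:

    \log p_i = x_i - \log\sum_j e^{x_j}
             = x_i - \Big( m + \log\sum_j e^{x_j - m} \Big), \qquad m = \max_k x_k

Since every exponent x_j - m is at most 0, each e^{x_j - m} lies in (0, 1] and the sum stays finite in float. whisper_compute_probs then simply exponentiates the logprobs, mapping masked logits (-INFINITY) to probability 0.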
 
  // process the logits for the selected decoder
  // - applies logit filters
  // - computes logprobs and probs
@@ -4781,7 +4923,7 @@ static void whisper_process_logits(
 
  // suppress sot and nosp tokens
  logits[vocab.token_sot] = -INFINITY;
- logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+ logits[vocab.token_nosp] = -INFINITY;
 
  // [TDRZ] when tinydiarize is disabled, suppress solm token
  if (params.tdrz_enable == false) {
@@ -4880,24 +5022,7 @@ static void whisper_process_logits(
  }
 
  // populate the logprobs array (log_softmax)
- {
- const float logit_max = *std::max_element(logits.begin(), logits.end());
- float logsumexp = 0.0f;
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logsumexp += expf(logits[i] - logit_max);
- }
- }
- logsumexp = logf(logsumexp) + logit_max;
-
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logprobs[i] = logits[i] - logsumexp;
- } else {
- logprobs[i] = -INFINITY;
- }
- }
- }
+ whisper_compute_logprobs(logits, n_logits, logprobs);
 
  // if sum of probability over timestamps is above any other token, sample timestamp
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -4955,15 +5080,7 @@ static void whisper_process_logits(
  }
 
  // compute probs
- {
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] == -INFINITY) {
- probs[i] = 0.0f;
- } else {
- probs[i] = expf(logprobs[i]);
- }
- }
- }
+ whisper_compute_probs(logits, n_logits, logprobs, probs);
 
  #if 0
  // print first 100 logits - token string : logit
@@ -5228,15 +5345,9 @@ int whisper_full_with_state(
 
  if (n_samples > 0) {
  // compute log mel spectrogram
- if (params.speed_up) {
- // TODO: Replace PV with more advanced algorithm
+ if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
- } else {
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
- WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
- return -2;
- }
+ return -2;
  }
  }
 
@@ -5273,7 +5384,7 @@ int whisper_full_with_state(
  // if length of spectrogram is less than 1.0s (100 frames), then return
  // basically don't process anything that is less than 1.0s
  // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
- if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
+ if (seek_end < seek_start + 100) {
  WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
  return 0;
  }
@@ -5518,13 +5629,46 @@ int whisper_full_with_state(
  }
  WHISPER_LOG_DEBUG("\n\n");
 
+ // recreate the KV cache if the number of decoders has changed
+ if (state->kv_self_n_dec < n_decoders_cur) {
+ WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+ whisper_kv_cache_free(state->kv_self);
+
+ // overallocate to workaround KV cache fragmentation issues
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return -7;
+ }
+
+ state->kv_self_n_dec = n_decoders_cur;
+ }
+
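A quick worked example of the sizing above, assuming the usual Whisper text context of n_text_ctx = 448 (a property of the models, not stated in this diff):

    GGML_PAD(448, 256)      = 512 cells per decoder
    n_decoders_cur = 5  ->  factor = 5 + 2 = 7
    cache size              = 512 * 7 = 3584 cells

The comment above attributes the extra headroom to KV-cache fragmentation when several decoders share one cache between sampling passes.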
  whisper_kv_cache_clear(state->kv_self);
 
  whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -7;
+ return -8;
+ }
+
+ // Calculate no_speech probability after first decode.
+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+ {
+ const int n_logits = ctx->vocab.id_to_token.size();
+ std::vector<float> logprobs(n_logits);
+ std::vector<float> probs(n_logits);
+
+ whisper_compute_logprobs(state->logits, n_logits, logprobs);
+ whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+ state->no_speech_prob = probs[whisper_token_nosp(ctx)];
  }
 
  {
@@ -5824,7 +5968,7 @@ int whisper_full_with_state(
 
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -8;
+ return -9;
  }
 
  const int64_t t_start_sample_us = ggml_time_us();
@@ -5918,8 +6062,9 @@ int whisper_full_with_state(
  if (it != (int) temperatures.size() - 1) {
  const auto & decoder = state->decoders[best_decoder_id];
 
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
- WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+ if (decoder.failed ||
+ (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+ WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
  success = false;
  state->n_fail_p++;
  }
@@ -5940,7 +6085,7 @@ int whisper_full_with_state(
  {
  const auto & best_decoder = state->decoders[best_decoder_id];
 
- const auto seek_delta = best_decoder.seek_delta;
+ auto seek_delta = best_decoder.seek_delta;
  const auto result_len = best_decoder.sequence.result_len;
 
  const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -5948,6 +6093,9 @@ int whisper_full_with_state(
  // [EXPERIMENTAL] Token-level timestamps with DTW
  const auto n_segments_before = state->result_all.size();
 
+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+ best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
  // update prompt_past
@@ -5956,11 +6104,11 @@ int whisper_full_with_state(
  prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
  }
 
- for (int i = 0; i < result_len; ++i) {
+ for (int i = 0; i < result_len && !is_no_speech; ++i) {
  prompt_past.push_back(tokens_cur[i].id);
  }
 
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
  int i0 = 0;
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
@@ -5985,8 +6133,8 @@ int whisper_full_with_state(
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
 
  if (!text.empty()) {
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
+ const auto tt0 = t0;
+ const auto tt1 = t1;
 
  if (params.print_realtime) {
  if (params.print_timestamps) {
@@ -6014,7 +6162,7 @@ int whisper_full_with_state(
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
  }
  }
- if (params.new_segment_callback) {
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
  }
  }
@@ -6032,8 +6180,8 @@ int whisper_full_with_state(
  if (!text.empty()) {
  const auto t1 = seek + seek_delta;
 
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
+ const auto tt0 = t0;
+ const auto tt1 = t1;
 
  if (params.print_realtime) {
  if (params.print_timestamps) {
@@ -6059,7 +6207,7 @@ int whisper_full_with_state(
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
  }
  }
- if (params.new_segment_callback) {
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
  }
  }
@@ -6068,14 +6216,28 @@ int whisper_full_with_state(
  // FIXME: will timestamp offsets be correct?
  // [EXPERIMENTAL] Token-level timestamps with DTW
  {
- const auto n_segments = state->result_all.size() - n_segments_before;
+ const int n_segments = state->result_all.size() - n_segments_before;
  if (ctx->params.dtw_token_timestamps && n_segments) {
  const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
  whisper_exp_compute_token_level_timestamps_dtw(
  ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+ if (params.new_segment_callback) {
+ for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+ params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+ }
+ }
  }
  }
 
+ // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
+ const bool single_timestamp_ending = tokens_cur.size() > 1 &&
+ tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
+ tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
+ if (single_timestamp_ending) {
+ WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
+ seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
+ }
+
  // update audio window
  seek += seek_delta;
 
@@ -6835,7 +6997,7 @@ static void whisper_exp_compute_token_level_timestamps(
  k++;
  }
  tokens[j].t1 = sample_to_timestamp(k);
- if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+ if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
  tokens[j].t1 = tokens[j + 1].t0;
  } else {
  s1 = k;
@@ -6998,10 +7160,11 @@ struct median_filter_user_data {
  int filter_width;
  };
 
- static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+ if (ith != 0) {
+ return;
+ }
  int filter_width = ((median_filter_user_data *) userdata)->filter_width;
- WHISPER_ASSERT(nth == 1);
- WHISPER_ASSERT(ith == 0);
  WHISPER_ASSERT(filter_width < a->ne[2]);
  WHISPER_ASSERT(filter_width % 2);
  WHISPER_ASSERT(ggml_n_dims(a) == 3);
@@ -7124,7 +7287,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  // operation (after median filter)
  // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
  // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
- w = ggml_norm(gctx, w, 1e-9);
+ w = ggml_norm(gctx, w, 1e-9f);
  w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
 
  // Pass median filter - this is done over AUDIO_TOKENS dimension.
@@ -7196,6 +7359,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
  g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
  g_state.log_callback_user_data = user_data;
+ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
  }
 
  GGML_ATTRIBUTE_FORMAT(2, 3)
@@ -7219,6 +7383,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
  static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
  (void) level;
  (void) user_data;
+ #ifndef WHISPER_DEBUG
+ if (level == GGML_LOG_LEVEL_DEBUG) {
+ return;
+ }
+ #endif
  fputs(text, stderr);
  fflush(stderr);
  }