whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
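Note on the biggest change below: the vendored whisper.cpp/ggml sync replaces the removed wsp_ggml_allocr allocator API with the wsp_ggml_backend_sched scheduler, which now drives graph allocation and compute across multiple backends (GPU, BLAS, CPU). The following is a minimal sketch of that compute pattern, assembled only from calls that appear in the diff below; compute_with_sched is a hypothetical name, and error handling is trimmed:

// Sketch only: mirrors the scheduler pattern adopted in this diff.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static bool compute_with_sched(wsp_ggml_backend_sched_t sched,
                               struct wsp_ggml_cgraph * gf,
                               int n_threads) {
    // configure per-backend options (e.g. CPU thread count) before dispatch
    for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
        wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
        if (wsp_ggml_backend_is_cpu(backend)) {
            wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
        }
    }

    // allocate the graph on the scheduler, run it, then reset for the next graph
    if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
        return false; // compute buffer allocation failed
    }
    const bool ok = wsp_ggml_backend_sched_graph_compute(sched, gf) == WSP_GGML_STATUS_SUCCESS;
    wsp_ggml_backend_sched_reset(sched);
    return ok;
}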
package/cpp/whisper.cpp CHANGED
@@ -8,14 +8,30 @@
  #include "ggml-metal.h"
  #endif

- #ifdef WSP_GGML_USE_CUBLAS
+ #ifdef WSP_GGML_USE_CUDA
  #include "ggml-cuda.h"
  #endif

+ #ifdef WSP_GGML_USE_SYCL
+ #include "ggml-sycl.h"
+ #endif
+
+ #ifdef WSP_GGML_USE_VULKAN
+ #include "ggml-vulkan.h"
+ #endif
+
+ #ifdef WSP_GGML_USE_BLAS
+ #include "ggml-blas.h"
+ #endif
+
  #ifdef WHISPER_USE_OPENVINO
  #include "openvino/whisper-openvino-encoder.h"
  #endif

+ #ifdef WSP_GGML_USE_CANN
+ #include "ggml-cann.h"
+ #endif
+
  #include "ggml.h"
  #include "ggml-alloc.h"
  #include "ggml-backend.h"
@@ -37,6 +53,7 @@
  #include <regex>
  #include <random>
  #include <functional>
+ #include <codecvt>

  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -122,9 +139,18 @@ WHISPER_ATTRIBUTE_FORMAT(2, 3)
  static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...);
  static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);

- #define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
- #define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
  #define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+ #define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+ #define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+
+ // define this to enable verbose trace logging - useful for debugging purposes
+ //#define WHISPER_DEBUG
+
+ #if defined(WHISPER_DEBUG)
+ #define WHISPER_LOG_DEBUG(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+ #else
+ #define WHISPER_LOG_DEBUG(...)
+ #endif

  #define WHISPER_ASSERT(x) \
  do { \
@@ -134,20 +160,6 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
  } \
  } while (0)

- // define this to enable verbose trace logging - useful for debugging purposes
- //#define WHISPER_DEBUG
-
- #if defined(WHISPER_DEBUG)
- #define WHISPER_PRINT_DEBUG(...) \
- do { \
- fprintf(stderr, __VA_ARGS__); \
- } while (0)
- #else
- #define WHISPER_PRINT_DEBUG(...)
- #endif
-
- //#define WHISPER_USE_FLASH_ATTN
- //#define WHISPER_USE_FLASH_FF
  #define WHISPER_MAX_DECODERS 8
  #define WHISPER_MAX_NODES 4096

@@ -155,15 +167,15 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
  // ggml helpers
  //

- static void wsp_ggml_graph_compute_helper(
+ static bool wsp_ggml_graph_compute_helper(
  struct wsp_ggml_cgraph * graph,
  std::vector<uint8_t> & buf,
  int n_threads,
- whisper_abort_callback abort_callback,
+ wsp_ggml_abort_callback abort_callback,
  void * abort_callback_data) {
- struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
+ struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads, nullptr);

- plan.abort_callback = abort_callback;
+ plan.abort_callback = abort_callback;
  plan.abort_callback_data = abort_callback_data;

  if (plan.work_size > 0) {
@@ -171,22 +183,29 @@ static void wsp_ggml_graph_compute_helper(
  plan.work_data = buf.data();
  }

- wsp_ggml_graph_compute(graph, &plan);
+ return wsp_ggml_graph_compute(graph, &plan);
  }

- static void wsp_ggml_graph_compute_helper(
- struct wsp_ggml_backend * backend,
+ static bool wsp_ggml_graph_compute_helper(
+ wsp_ggml_backend_sched_t sched,
  struct wsp_ggml_cgraph * graph,
  int n_threads) {
- if (wsp_ggml_backend_is_cpu(backend)) {
- wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
- }
- #ifdef WSP_GGML_USE_METAL
- if (wsp_ggml_backend_is_metal(backend)) {
- wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
- }
+
+ for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
+ wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
+ if (wsp_ggml_backend_is_cpu(backend)) {
+ wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
+ }
+ #ifdef WSP_GGML_USE_BLAS
+ if (wsp_ggml_backend_is_blas(backend)) {
+ wsp_ggml_backend_blas_set_n_threads(backend, n_threads);
+ }
  #endif
- wsp_ggml_backend_graph_compute(backend, graph);
+ }
+
+ bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
+ wsp_ggml_backend_sched_reset(sched);
+ return t;
  }

  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -350,6 +369,37 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
  { "yue", { 99, "cantonese", } },
  };

+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
+ static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
+ static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
+ static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
+ static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
+ static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
+ static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
+ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
+ static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
+ static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
+ static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+ static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
+
+ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
+ { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
+ { WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
+ { WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
+ { WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
+ { WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
+ { WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
+ { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
+ { WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
+ { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
+ { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
+ { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
+ { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
+ };
+
+ static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
+
  struct whisper_mel {
  int n_len;
  int n_len_org;
@@ -412,7 +462,7 @@ struct whisper_batch {

  whisper_token * token;
  whisper_pos * pos;
- int32_t * n_seq_id;
+ int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
  whisper_seq_id ** seq_id; // null terminated
  int8_t * logits;
  };
@@ -472,54 +522,42 @@ struct whisper_pair {
  whisper_pair() : first(A()), second(B()) {}
  };

- // wsp_ggml_allocr wrapper for whisper usage
- struct whisper_allocr {
- wsp_ggml_allocr * alloc = nullptr;
+ // wsp_ggml_backend_sched wrapper for whisper usage
+ struct whisper_sched {
+ wsp_ggml_backend_sched_t sched = nullptr;

  std::vector<uint8_t> meta;
-
- wsp_ggml_backend_buffer_t buffer;
  };

- static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
- return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
+ static size_t whisper_sched_size(struct whisper_sched & allocr) {
+ size_t size = allocr.meta.size();
+ for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+ wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(allocr.sched, i);
+ size += wsp_ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+ }
+ return size;
  }

  // measure the memory usage of a graph and prepare the allocr's internal data buffer
- static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
- auto & alloc = allocr.alloc;
- auto & meta = allocr.meta;
+ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<wsp_ggml_backend_t> backends, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
+ auto & sched = allocr.sched;
+ auto & meta = allocr.meta;

- alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
+ sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);

  meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());

- wsp_ggml_allocr_alloc_graph(alloc, get_graph());
- }
-
- static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
- if (allocr.alloc == nullptr) {
- // this can be null if we use external encoder like CoreML or OpenVINO
- return;
+ // since there are dependencies between the different graphs,
+ // we need to allocate them instead of only reserving to get the correct compute buffer size
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, get_graph())) {
+ // failed to allocate the compute buffer
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
+ return false;
  }

- auto & alloc = allocr.alloc;
- auto & buffer = allocr.buffer;
-
- size_t size = wsp_ggml_allocr_max_size(alloc);
+ wsp_ggml_backend_sched_reset(sched);

- wsp_ggml_allocr_free(alloc);
-
- buffer = wsp_ggml_backend_alloc_buffer(backend, size);
- alloc = wsp_ggml_allocr_new_from_buffer(buffer);
- }
-
- static void whisper_allocr_free(struct whisper_allocr & allocr) {
- if (allocr.alloc) {
- wsp_ggml_allocr_free(allocr.alloc);
- wsp_ggml_backend_buffer_free(allocr.buffer);
- allocr.alloc = nullptr;
- }
+ return true;
  }

  // medium
@@ -661,9 +699,9 @@ struct whisper_kv_cache {
  struct wsp_ggml_tensor * k;
  struct wsp_ggml_tensor * v;

- struct wsp_ggml_context * ctx;
+ wsp_ggml_backend_buffer_t buffer = nullptr;

- wsp_ggml_backend_buffer_t buffer;
+ std::vector<uint8_t> ctx_buf;
  };

  struct whisper_model {
@@ -701,10 +739,10 @@ struct whisper_model {
  std::vector<whisper_layer_decoder> layers_decoder;

  // ggml context that contains all the meta information about the model tensors
- struct wsp_ggml_context * ctx;
+ struct wsp_ggml_context * ctx = nullptr;

  // the model backend data is read-only and can be shared between processors
- struct wsp_ggml_backend_buffer * buffer;
+ wsp_ggml_backend_buffer_t buffer = nullptr;

  // tensors
  int n_loaded;
@@ -769,6 +807,13 @@ struct whisper_decoder {
  mutable std::mt19937 rng; // used for sampling at t > 0.0
  };

+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ struct whisper_aheads_masks {
+ std::vector<struct wsp_ggml_tensor *> m; // One mask per text layer.
+ struct wsp_ggml_context * ctx = nullptr;
+ wsp_ggml_backend_buffer_t buffer = nullptr;
+ };
+
  struct whisper_state {
  int64_t t_sample_us = 0;
  int64_t t_encode_us = 0;
@@ -785,6 +830,9 @@
  int32_t n_fail_p = 0; // number of logprob threshold failures
  int32_t n_fail_h = 0; // number of entropy threshold failures

+ // number of decoders for which we have constructed the KV cache
+ int32_t kv_self_n_dec = 0;
+
  // unified self-attention KV cache for all decoders
  whisper_kv_cache kv_self;

@@ -792,21 +840,22 @@
  // shared between all decoders
  whisper_kv_cache kv_cross;

+ // padded buffer for flash-attention
+ whisper_kv_cache kv_pad;
+
  whisper_mel mel;

  whisper_batch batch;

  whisper_decoder decoders[WHISPER_MAX_DECODERS];

- wsp_ggml_backend_t backend = nullptr;
+ std::vector<wsp_ggml_backend_t> backends;

- // ggml-alloc:
  // - stores meta info about the intermediate tensors into the `meta` buffers
- // - stores the actual tensor data into the `data` buffers
- whisper_allocr alloc_conv;
- whisper_allocr alloc_encode;
- whisper_allocr alloc_cross;
- whisper_allocr alloc_decode;
+ whisper_sched sched_conv;
+ whisper_sched sched_encode;
+ whisper_sched sched_cross;
+ whisper_sched sched_decode;

  // result of the encoder
  struct wsp_ggml_tensor * embd_conv = nullptr;
@@ -842,6 +891,11 @@

  std::vector<float> energy; // PCM signal energy

+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ whisper_aheads_masks aheads_masks;
+ wsp_ggml_tensor * aheads_cross_QKs = nullptr;
+ std::vector<float> aheads_cross_QKs_data;
+
  // [EXPERIMENTAL] speed-up techniques
  int32_t exp_n_audio_ctx = 0; // 0 - use default
  };
@@ -860,8 +914,6 @@

  whisper_state * state = nullptr;

- wsp_ggml_backend_t backend = nullptr;
-
  std::string path_model; // populated by whisper_init_from_file_with_params()
  };

@@ -879,21 +931,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
  BYTESWAP_VALUE(dest);
  }

- static bool kv_cache_init(
- const struct whisper_hparams & hparams,
+ static bool whisper_kv_cache_init(
  struct whisper_kv_cache & cache,
  wsp_ggml_backend_t backend,
  wsp_ggml_type wtype,
+ int64_t n_text_state,
+ int64_t n_text_layer,
  int n_ctx) {
- const int64_t n_text_state = hparams.n_text_state;
- const int64_t n_text_layer = hparams.n_text_layer;
-
  const int64_t n_mem = n_text_layer*n_ctx;
  const int64_t n_elements = n_text_state*n_mem;

+ cache.ctx_buf.resize(2*wsp_ggml_tensor_overhead());
+
  struct wsp_ggml_init_params params = {
- /*.mem_size =*/ 2*wsp_ggml_tensor_overhead(),
- /*.mem_buffer =*/ nullptr,
+ /*.mem_size =*/ cache.ctx_buf.size(),
+ /*.mem_buffer =*/ cache.ctx_buf.data(),
  /*.no_alloc =*/ true,
  };

@@ -903,39 +955,31 @@ static bool kv_cache_init(
  cache.cells.clear();
  cache.cells.resize(n_ctx);

- cache.ctx = wsp_ggml_init(params);
+ struct wsp_ggml_context * ctx = wsp_ggml_init(params);

- if (!cache.ctx) {
- WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
+ if (!ctx) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
  return false;
  }

- cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
- cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-
- const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
+ cache.k = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
+ cache.v = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);

- cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
-
- // allocate the tensors into the backend buffer
- {
- wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
+ cache.buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, backend);
+ if (!cache.buffer) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
+ return false;
+ }

- wsp_ggml_allocr_alloc(alloc, cache.k);
- wsp_ggml_allocr_alloc(alloc, cache.v);
+ wsp_ggml_backend_buffer_clear(cache.buffer, 0);

- wsp_ggml_allocr_free(alloc);
- }
+ wsp_ggml_free(ctx);

  return true;
  }

- static void kv_cache_free(struct whisper_kv_cache & cache) {
- if (cache.ctx) {
- wsp_ggml_free(cache.ctx);
- wsp_ggml_backend_buffer_free(cache.buffer);
- cache.ctx = nullptr;
- }
+ static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
+ wsp_ggml_backend_buffer_free(cache.buffer);
  }

  static bool whisper_kv_cache_find_slot(
@@ -1006,6 +1050,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
  cache.cells[i].seq_id.clear();
  }
  cache.head = 0;
+
+ wsp_ggml_backend_buffer_clear(cache.buffer, 0);
  }

  static void whisper_kv_cache_seq_rm(
@@ -1056,15 +1102,167 @@ static void whisper_kv_cache_seq_cp(
  }
  }

- static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
- wsp_ggml_backend_t backend_gpu = NULL;
+ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+ if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
+ return 1u;
+ }
+
+ #ifdef WSP_GGML_USE_METAL
+ if (wctx.params.use_gpu) {
+ return 32u;
+ }
+ #endif
+
+ #ifdef WSP_GGML_USE_CUDA
+ if (wctx.params.use_gpu) {
+ return 256u;
+ }
+ #endif
+
+ return 1u;
+ }
+
+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ static bool aheads_masks_init(
+ const whisper_context_params & cparams,
+ const whisper_hparams & hparams,
+ struct whisper_aheads_masks & aheads_masks,
+ wsp_ggml_backend_t backend) {
+
+ const int32_t n_text_layer = hparams.n_text_layer;
+ const int32_t n_head = hparams.n_text_head;
+
+ // Sanity checks
+ if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+ WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
+ return false;
+ } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+ if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
+ WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
+ return false;
+ }
+ } else {
+ const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+ if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
+ if (aheads.n_heads == 0) {
+ WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
+ return false;
+ }
+ if (aheads.heads == NULL) {
+ WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
+ return false;
+ }
+ }
+ for (size_t i = 0; i < aheads.n_heads; ++i) {
+ if (aheads.heads[i].n_text_layer >= n_text_layer) {
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
+ return false;
+ }
+ if (aheads.heads[i].n_text_layer < 0) {
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
+ return false;
+ }
+ if (aheads.heads[i].n_head >= n_head) {
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
+ return false;
+ }
+ if (aheads.heads[i].n_head < 0) {
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
+ return false;
+ }
+ }
+ }
+
+ struct wsp_ggml_init_params params = {
+ /*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*wsp_ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+
+ aheads_masks.ctx = wsp_ggml_init(params);
+
+ if (!aheads_masks.ctx) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
+ return false;
+ }
+
+ for (int64_t il = 0; il < n_text_layer; ++il) {
+ auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+ if (!aheads.empty()) {
+ aheads_masks.m.push_back(wsp_ggml_new_tensor_2d(aheads_masks.ctx, WSP_GGML_TYPE_F32, n_head, aheads.size()));
+ } else {
+ aheads_masks.m.push_back(nullptr);
+ }
+ }
+
+ aheads_masks.buffer = wsp_ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
+ if (!aheads_masks.buffer) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
+ return false;
+ }
+
+ // Set data on mask tensors
+ // Since this must be backend agnostic, we write our desired values on mask_data,
+ // and send it to backend with wsp_ggml_backend_tensor_set.
+ // Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment
+ // heads. Each row of the mask "marks" one alignment head. E.g. if some text layer
+ // has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask
+ // should read:
+ // 1 0 0 0 0 0 0 0 0 0
+ // 0 0 0 0 0 1 0 0 0 0
+ // 0 0 0 0 0 0 1 0 0 0
+ std::vector<float> mask_data;
+ for (int64_t il = 0; il < n_text_layer; ++il) {
+ if (aheads_masks.m[il] != nullptr) {
+ auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+
+ size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1];
+ size_t data_size_bytes = data_size * sizeof(float);
+ mask_data.resize(data_size);
+
+ std::fill(mask_data.begin(), mask_data.end(), 0);
+ for (size_t ih = 0; ih < aheads.size(); ++ih) {
+ size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
+ mask_data[pos] = 1.0f;
+ }
+
+ wsp_ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes);
+ }
+ }
+
+ if (aheads_masks.m.empty()) {
+ WHISPER_LOG_ERROR("%s: \n", __func__);
+ return false;
+ }
+
+ return true;
+ }
+
+ static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
+ wsp_ggml_free(aheads_masks.ctx);
+ wsp_ggml_backend_buffer_free(aheads_masks.buffer);
+ aheads_masks.ctx = nullptr;
+ }

- // initialize the backends
- #ifdef WSP_GGML_USE_CUBLAS
- if (params.use_gpu && wsp_ggml_cublas_loaded()) {
+ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
+ size_t size = 0;
+ for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
+ if (aheads_masks.m[i] != nullptr)
+ size += wsp_ggml_nbytes(aheads_masks.m[i]);
+ }
+ return size;
+ }
+
+ static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+ wsp_ggml_backend_t result = NULL;
+
+ wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
+
+ #ifdef WSP_GGML_USE_CUDA
+ if (params.use_gpu) {
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
- backend_gpu = wsp_ggml_backend_cuda_init(0);
- if (!backend_gpu) {
+ result = wsp_ggml_backend_cuda_init(params.gpu_device);
+ if (!result) {
  WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
  }
  }
@@ -1073,22 +1271,108 @@ static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & pa
  #ifdef WSP_GGML_USE_METAL
  if (params.use_gpu) {
  WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
- wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
- backend_gpu = wsp_ggml_backend_metal_init();
- if (!backend_gpu) {
+ result = wsp_ggml_backend_metal_init();
+ if (!result) {
  WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
- } else if (!wsp_ggml_backend_metal_supports_family(backend_gpu, 7)) {
+ } else if (!wsp_ggml_backend_metal_supports_family(result, 7)) {
  WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
- wsp_ggml_backend_free(backend_gpu);
- backend_gpu = NULL;
+ wsp_ggml_backend_free(result);
+ result = NULL;
+ }
+ }
+ #endif
+
+ #ifdef WSP_GGML_USE_SYCL
+ if (params.use_gpu) {
+ WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
+ result = wsp_ggml_backend_sycl_init(params.gpu_device);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
  }
  }
  #endif

+ #ifdef WSP_GGML_USE_VULKAN
+ if (params.use_gpu) {
+ WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
+ result = wsp_ggml_backend_vk_init(params.gpu_device);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
+ }
+ }
+ #endif
+
+ #ifdef WSP_GGML_USE_CANN
+ if (params.use_gpu) {
+ WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
+ result = wsp_ggml_backend_cann_init(params.gpu_device);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
+ }
+ }
+ #endif
+
+ WSP_GGML_UNUSED(params);
+
+ return result;
+ }
+
+ static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+ std::vector<wsp_ggml_backend_t> result;
+
+ wsp_ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
+
  if (backend_gpu) {
- return backend_gpu;
+ result.push_back(backend_gpu);
  }
- return wsp_ggml_backend_cpu_init();
+
+ #ifdef WSP_GGML_USE_BLAS
+ {
+ WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
+ wsp_ggml_backend_t backend_blas = wsp_ggml_backend_blas_init();
+ if (!backend_blas) {
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_blas_init() failed\n", __func__);
+ } else {
+ result.push_back(backend_blas);
+ }
+ }
+ #endif
+
+ WSP_GGML_UNUSED(params);
+
+ result.push_back(wsp_ggml_backend_cpu_init());
+
+ return result;
+ }
+
+ static wsp_ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+ wsp_ggml_backend_buffer_type_t result = nullptr;
+
+ params.use_gpu || (result = wsp_ggml_backend_cpu_buffer_type());
+
+ #ifdef WSP_GGML_USE_CUDA
+ result || (result = wsp_ggml_backend_cuda_buffer_type(params.gpu_device));
+ #endif
+
+ #ifdef WSP_GGML_USE_METAL
+ result || (result = wsp_ggml_backend_metal_buffer_type());
+ #endif
+
+ #ifdef WSP_GGML_USE_SYCL
+ result || (result = wsp_ggml_backend_sycl_buffer_type(params.gpu_device));
+ #endif
+
+ #ifdef WSP_GGML_USE_VULKAN
+ result || (result = wsp_ggml_backend_vk_buffer_type(params.gpu_device));
+ #endif
+
+ #ifdef WSP_GGML_USE_CANN
+ result || (result == wsp_ggml_backend_cann_buffer_type(params.gpu_device));
+ #endif
+
+ result || (result = wsp_ggml_backend_cpu_buffer_type());
+
+ return result;
  }

  // load the model from a ggml file
@@ -1515,29 +1799,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  }
  }

- wctx.backend = whisper_backend_init(wctx.params);
-
- {
- size_t size_main = 0;
-
- for (const auto & t : model.tensors) {
- size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
- }
-
- model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
-
- WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
- }
-
- wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
-
  // allocate tensors in the backend buffers
- {
- for (const auto & t : model.tensors) {
- wsp_ggml_allocr_alloc(alloc, t.second);
- }
+ model.buffer = wsp_ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
+ if (!model.buffer) {
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
+ return false;
  }

+ size_t size_main = wsp_ggml_backend_buffer_get_size(model.buffer);
+ WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(model.buffer), size_main / 1e6);
+
  // load weights
  {
  size_t total_size = 0;
@@ -1599,15 +1870,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  return false;
  }

- wsp_ggml_backend_t backend = wctx.backend;
+ //wsp_ggml_backend_t backend = wctx.backend;

  //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());

- if ((wsp_ggml_backend_is_cpu(backend)
- #ifdef WSP_GGML_USE_METAL
- || wsp_ggml_backend_is_metal(backend)
- #endif
- )) {
+ if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
  // for the CPU and Metal backend, we can read directly into the tensor
  loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
  BYTESWAP_TENSOR(tensor);
@@ -1635,7 +1902,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  }
  }

- wsp_ggml_allocr_free(alloc);
+ wsp_ggml_backend_buffer_set_usage(model.buffer, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);

  wctx.t_load_us = wsp_ggml_time_us() - t_start_us;

@@ -1662,10 +1929,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {

  static struct wsp_ggml_cgraph * whisper_build_graph_conv(
  whisper_context & wctx,
- whisper_state & wstate,
- const int mel_offset) {
+ whisper_state & wstate) {
  const auto & model = wctx.model;
- const auto & mel_inp = wstate.mel;
  const auto & hparams = model.hparams;

  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
@@ -1674,8 +1939,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
  const int n_mels = hparams.n_mels;

  struct wsp_ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_conv.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+ /*.mem_size =*/ wstate.sched_conv.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
  /*.no_alloc =*/ true,
  };

@@ -1683,31 +1948,9 @@

  wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);

- wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
-
  struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
- wsp_ggml_allocr_alloc(alloc, mel);
-
- assert(mel->type == WSP_GGML_TYPE_F32);
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- assert(mel_inp.n_mel == n_mels);
-
- wstate.inp_mel.resize(wsp_ggml_nelements(mel));
-
- float * dst = wstate.inp_mel.data();
- memset(dst, 0, wsp_ggml_nbytes(mel));
-
- const int i0 = std::min(mel_offset, mel_inp.n_len);
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
- for (int j = 0; j < mel_inp.n_mel; ++j) {
- for (int i = i0; i < i1; ++i) {
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
- }
- }
-
- wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
- }
+ wsp_ggml_set_name(mel, "mel");
+ wsp_ggml_set_input(mel);

  struct wsp_ggml_tensor * cur = nullptr;

@@ -1728,27 +1971,17 @@
  wsp_ggml_set_name(cur, "embd_conv");
  wstate.embd_conv = cur;
  } else {
- #ifdef WHISPER_USE_COREML
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
- wsp_ggml_allocr_alloc(alloc, cur);
+ wsp_ggml_build_forward_expand(gf, mel);

- if (!wsp_ggml_allocr_is_measure(alloc)) {
- whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
- }
- #endif
- #ifdef WHISPER_USE_OPENVINO
  cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
- wsp_ggml_allocr_alloc(alloc, cur);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
- }
- #endif
+ wsp_ggml_set_input(cur); // the external encoder will write into this tensor

  wsp_ggml_set_name(cur, "embd_enc");
  wstate.embd_enc = cur;
  }

+ wsp_ggml_set_output(cur);
+
  wsp_ggml_build_forward_expand(gf, cur);

  wsp_ggml_free(ctx0);
@@ -1767,9 +2000,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
  const int n_head = hparams.n_audio_head;
  const int n_layer = hparams.n_audio_layer;

+ const int n_state_head = n_state/n_head;
+
+ auto & kv_pad = wstate.kv_pad;
+
+ WHISPER_ASSERT(!!kv_pad.buffer);
+
+ const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
+
  struct wsp_ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_encode.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+ /*.mem_size =*/ wstate.sched_encode.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
  /*.no_alloc =*/ true,
  };

@@ -1777,23 +2018,9 @@

  wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);

- wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
-
- //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
- //wsp_ggml_allocr_alloc(alloc, cur);
-
- //if (!wsp_ggml_allocr_is_measure(alloc)) {
- // wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
- //}
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);

- struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
- wsp_ggml_allocr_alloc(alloc, KQscale);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- const float val = 1.0f/sqrtf(float(n_state)/n_head);
- wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
- }
+ const float KQscale = 1.0f/sqrtf(float(n_state_head));

  // ===================================================================
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1843,14 +2070,14 @@

  Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);

- //Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+ //Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));

  // note: no bias for Key
  struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
  layer.attn_k_w,
  cur);

- //Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+ //Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));

  struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
  layer.attn_v_w,
@@ -1860,70 +2087,60 @@

  // ------

- #ifdef WHISPER_USE_FLASH_ATTN
  struct wsp_ggml_tensor * Q =
  wsp_ggml_permute(ctx0,
- wsp_ggml_cpy(ctx0,
- Qcur,
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
  0, 2, 1, 3);

- struct wsp_ggml_tensor * K =
- wsp_ggml_permute(ctx0,
- wsp_ggml_cpy(ctx0,
- Kcur,
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ if (wctx.params.flash_attn) {
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, wsp_ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, wsp_ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));

- struct wsp_ggml_tensor * V =
- wsp_ggml_cpy(ctx0,
- wsp_ggml_permute(ctx0,
- wsp_ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+ struct wsp_ggml_tensor * K =
+ wsp_ggml_view_3d(ctx0, kv_pad.k,
+ n_state_head, n_ctx_pad, n_head,
+ wsp_ggml_element_size(kv_pad.k)*n_state,
+ wsp_ggml_element_size(kv_pad.k)*n_state_head,
+ 0);

- struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
- #else
- struct wsp_ggml_tensor * Q =
- wsp_ggml_permute(ctx0,
- wsp_ggml_cpy(ctx0,
- Qcur,
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ struct wsp_ggml_tensor * V =
+ wsp_ggml_view_3d(ctx0, kv_pad.v,
+ n_state_head, n_ctx_pad, n_head,
+ wsp_ggml_element_size(kv_pad.v)*n_state,
+ wsp_ggml_element_size(kv_pad.v)*n_state_head,
+ 0);

- struct wsp_ggml_tensor * K =
- wsp_ggml_permute(ctx0,
- wsp_ggml_cpy(ctx0,
- Kcur,
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);

- // K * Q
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+ } else {
+ struct wsp_ggml_tensor * K =
+ wsp_ggml_permute(ctx0,
+ wsp_ggml_cast(ctx0,
+ wsp_ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+ wctx.itype),
+ 0, 2, 1, 3);

- struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
+ // K * Q
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);

- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);

- struct wsp_ggml_tensor * V =
- wsp_ggml_cpy(ctx0,
- wsp_ggml_permute(ctx0,
- wsp_ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
- );
+ struct wsp_ggml_tensor * V =
+ wsp_ggml_cast(ctx0,
+ wsp_ggml_permute(ctx0,
+ wsp_ggml_reshape_3d(ctx0,
+ Vcur,
+ n_state_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ wctx.itype);

- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
- #endif
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);

- cur = wsp_ggml_cpy(ctx0,
- KQV_merged,
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
+ }
  }

  // projection
@@ -1952,11 +2169,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
  layer.mlp_ln_b);
  }

- #ifdef WHISPER_USE_FLASH_FF
- cur = wsp_ggml_flash_ff(ctx0,
- wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
- #else
  // fully connected
  cur = wsp_ggml_mul_mat(ctx0,
  layer.mlp_0_w,
@@ -1973,7 +2185,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
  cur);

  cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
- #endif
  }

  inpL = wsp_ggml_add(ctx0, cur, inpFF);
@@ -2022,9 +2233,13 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
  const int n_state = hparams.n_audio_state;
  const int n_head = hparams.n_audio_head;

+ const int n_state_head = n_state/n_head;
+
+ const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
+
  struct wsp_ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_cross.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+ /*.mem_size =*/ wstate.sched_cross.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
  /*.no_alloc =*/ true,
  };

@@ -2032,34 +2247,20 @@

  wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);

- wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
-
- //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
- //wsp_ggml_allocr_alloc(alloc, cur);
-
- //if (!wsp_ggml_allocr_is_measure(alloc)) {
- // wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
- //}
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);

- struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
- wsp_ggml_allocr_alloc(alloc, Kscale);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- const float val = pow(float(n_state) / n_head, -0.25);
- wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
- }
+ const float Kscale = pow(float(n_state_head), -0.25);

  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
  auto & layer = model.layers_decoder[il];

- struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
+ struct wsp_ggml_tensor * Kcross = wsp_ggml_mul_mat(ctx0,
  layer.cross_attn_k_w,
  cur);

  Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);

- struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
+ struct wsp_ggml_tensor * Vcross = wsp_ggml_mul_mat(ctx0,
  layer.cross_attn_v_w,
  cur);

@@ -2067,15 +2268,25 @@
  Vcross,
  layer.cross_attn_v_b);

- Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+ struct wsp_ggml_tensor * k;
+ struct wsp_ggml_tensor * v;
+
+ if (wctx.params.flash_attn) {
+ k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));

- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
- n_state*n_ctx,
- (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+ v = wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+ (wsp_ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+ } else {
+ Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+ k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));

- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
- ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
- (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
+ v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+ ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
+ (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
+ }

  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
@@ -2103,49 +2314,91 @@ static bool whisper_encode_internal(
  whisper_state & wstate,
  const int mel_offset,
  const int n_threads,
- whisper_abort_callback abort_callback,
+ wsp_ggml_abort_callback abort_callback,
  void * abort_callback_data) {
  const int64_t t_start_us = wsp_ggml_time_us();

  // conv
  {
- auto & alloc = wstate.alloc_conv.alloc;
+ auto & sched = wstate.sched_conv.sched;
+
+ wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
+
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
+ // should never happen as we pre-allocate the memory
+ return false;
+ }
+
+ struct wsp_ggml_tensor * mel = wsp_ggml_graph_get_tensor(gf, "mel");
+
+ // set the input
+ {
+ const auto & mel_inp = wstate.mel;
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
+
+ assert(mel->type == WSP_GGML_TYPE_F32);
+ assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
+
+ wstate.inp_mel.resize(wsp_ggml_nelements(mel));

- wsp_ggml_allocr_reset(alloc);
+ float * dst = wstate.inp_mel.data();
+ memset(dst, 0, wsp_ggml_nbytes(mel));

- wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
+ for (int i = i0; i < i1; ++i) {
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+ }
+ }

- wsp_ggml_allocr_alloc_graph(alloc, gf);
+ wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
+ }

  if (!whisper_encode_external(wstate)) {
- wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
+ return false;
+ }
+ } else {
+ #if defined(WHISPER_USE_COREML)
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
+ #elif defined(WHISPER_USE_OPENVINO)
+ whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
+ #endif
  }
  }

  // encoder
  if (!whisper_encode_external(wstate)) {
- auto & alloc = wstate.alloc_encode.alloc;
-
- wsp_ggml_allocr_reset(alloc);
+ auto & sched = wstate.sched_encode.sched;

  wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);

- wsp_ggml_allocr_alloc_graph(alloc, gf);
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
+ // should never happen as we pre-allocate the memory
+ return false;
+ }

- wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
+ return false;
+ }
  }

  // cross
  {
- auto & alloc = wstate.alloc_cross.alloc;
-
- wsp_ggml_allocr_reset(alloc);
+ auto & sched = wstate.sched_cross.sched;

  wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);

- wsp_ggml_allocr_alloc_graph(alloc, gf);
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
+ // should never happen as we pre-allocate the memory
+ return false;
+ }

- wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
+ return false;
+ }
  }

  wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
@@ -2157,32 +2410,36 @@
  static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
  whisper_context & wctx,
  whisper_state & wstate,
- const whisper_batch & batch) {
+ const whisper_batch & batch,
+ bool save_alignment_heads_QKs,
+ bool worst_case) {
  const auto & model = wctx.model;
  const auto & hparams = model.hparams;

  auto & kv_self = wstate.kv_self;

- WHISPER_ASSERT(!!kv_self.ctx);
-
- wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
+ WHISPER_ASSERT(!!kv_self.buffer);

  const int n_ctx = kv_self.size;
  const int n_state = hparams.n_text_state;
  const int n_head = hparams.n_text_head;
  const int n_layer = hparams.n_text_layer;

+ const int n_state_head = n_state/n_head;
+
  const int n_tokens = batch.n_tokens;
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;

- const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
- const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
+ const int n_audio_ctx_pad = WSP_GGML_PAD(n_audio_ctx, 256);
+
+ const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
+ const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;

- //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
+ //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);

  struct wsp_ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_decode.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+ /*.mem_size =*/ wstate.sched_decode.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
  /*.no_alloc =*/ true,
  };

@@ -2190,55 +2447,21 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(

  wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);

- struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
- wsp_ggml_allocr_alloc(alloc, embd);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
- }
-
- struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
- wsp_ggml_allocr_alloc(alloc, position);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- for (int i = 0; i < n_tokens; ++i) {
- const int32_t val = batch.pos[i];
- wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
- }
- }
-
- struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
- wsp_ggml_allocr_alloc(alloc, KQscale);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- const float val = pow(float(n_state)/n_head, -0.25);
- wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
- }
-
- struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
- wsp_ggml_allocr_alloc(alloc, KQ_mask);
-
- if (!wsp_ggml_allocr_is_measure(alloc)) {
- wstate.inp_mask.resize(n_kv*n_tokens);
+ struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
+ wsp_ggml_set_name(embd, "embd");
+ wsp_ggml_set_input(embd);

- float * data = wstate.inp_mask.data();
- memset(data, 0, wsp_ggml_nbytes(KQ_mask));
+ struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
+ wsp_ggml_set_name(position, "position");
+ wsp_ggml_set_input(position);

- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- const whisper_pos pos = batch.pos[j];
- const whisper_seq_id seq_id = batch.seq_id[j][0];
+ const float KQscale = pow(float(n_state_head), -0.25);

- for (int i = 0; i < n_kv; ++i) {
- if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
- data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
- }
- }
- }
- }
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD), 1);
+ wsp_ggml_set_name(KQ_mask, "KQ_mask");
+ wsp_ggml_set_input(KQ_mask);

- wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
- }
+ struct wsp_ggml_tensor * KQ_mask_f16 = wsp_ggml_cast(ctx0, KQ_mask, WSP_GGML_TYPE_F16);

  // token encoding + position encoding
  struct wsp_ggml_tensor * cur =
@@ -2248,6 +2471,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(

  struct wsp_ggml_tensor * inpL = cur;

+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ struct wsp_ggml_tensor * aheads_cross_QKs = nullptr;
+
  for (int il = 0; il < n_layer; ++il) {
  const auto & layer = model.layers_decoder[il];

@@ -2292,12 +2518,25 @@
  Vcur,
  layer.attn_v_b);

- Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+ struct wsp_ggml_tensor * k;
+ struct wsp_ggml_tensor * v;

- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
- ( n_ctx)*wsp_ggml_element_size(kv_self.v),
- (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2524
+ if (wctx.params.flash_attn) {
2525
+ k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2526
+ (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2527
+
2528
+ v = wsp_ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
2529
+ (wsp_ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
2530
+ } else {
2531
+ Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2532
+
2533
+ k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2534
+ (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2535
+
2536
+ v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2537
+ ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2538
+ (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2539
+ }
2301
2540
 
2302
2541
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2303
2542
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
@@ -2307,40 +2546,46 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2307
2546
 
2308
2547
  struct wsp_ggml_tensor * Q =
2309
2548
  wsp_ggml_permute(ctx0,
2310
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2549
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2311
2550
  0, 2, 1, 3);
2312
2551
 
2313
2552
  struct wsp_ggml_tensor * K =
2314
2553
  wsp_ggml_view_3d(ctx0, kv_self.k,
2315
- n_state/n_head, n_kv, n_head,
2554
+ n_state_head, n_kv, n_head,
2316
2555
  wsp_ggml_element_size(kv_self.k)*n_state,
2317
- wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2556
+ wsp_ggml_element_size(kv_self.k)*n_state_head,
2318
2557
  wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2319
2558
 
2320
- // K * Q
2321
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2559
+ if (wctx.params.flash_attn) {
2560
+ struct wsp_ggml_tensor * V =
2561
+ wsp_ggml_view_3d(ctx0, kv_self.v,
2562
+ n_state_head, n_kv, n_head,
2563
+ wsp_ggml_element_size(kv_self.v)*n_state,
2564
+ wsp_ggml_element_size(kv_self.v)*n_state_head,
2565
+ wsp_ggml_element_size(kv_self.v)*n_state*n_ctx*il);
2322
2566
 
2323
- //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2567
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
2324
2568
 
2325
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2326
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
2569
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2570
+ } else {
2571
+ // K * Q
2572
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2327
2573
 
2328
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2574
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
2329
2575
 
2330
- struct wsp_ggml_tensor * V =
2331
- wsp_ggml_view_3d(ctx0, kv_self.v,
2332
- n_kv, n_state/n_head, n_head,
2333
- n_ctx*wsp_ggml_element_size(kv_self.v),
2334
- n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
2335
- n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2576
+ struct wsp_ggml_tensor * V =
2577
+ wsp_ggml_view_3d(ctx0, kv_self.v,
2578
+ n_kv, n_state_head, n_head,
2579
+ n_ctx*wsp_ggml_element_size(kv_self.v),
2580
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state_head,
2581
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2336
2582
 
2337
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2583
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2338
2584
 
2339
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2585
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2340
2586
 
2341
- cur = wsp_ggml_cpy(ctx0,
2342
- KQV_merged,
2343
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2587
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
2588
+ }
2344
2589
  }
2345
2590
 
2346
2591
  // projection
@@ -2379,62 +2624,75 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2379
2624
  Qcur,
2380
2625
  layer.cross_attn_q_b);
2381
2626
 
2382
- Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2383
-
2384
- // Kcross is already scaled
2385
- struct wsp_ggml_tensor * Kcross =
2386
- wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2387
- n_state/n_head, n_audio_ctx, n_head,
2388
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2389
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2390
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2391
-
2392
- //struct wsp_ggml_tensor * Vcross =
2393
- // wsp_ggml_reshape_3d(ctx0,
2394
- // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2395
- // n_state/n_head, n_head, n_audio_ctx);
2396
-
2397
- //struct wsp_ggml_tensor * V_trans =
2398
- // wsp_ggml_cpy(ctx0,
2399
- // wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2400
- // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2401
-
2402
- struct wsp_ggml_tensor * V =
2403
- wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2404
- n_audio_ctx, n_state/n_head, n_head,
2405
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2406
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2407
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2408
-
2409
- // ------
2410
-
2411
2627
  struct wsp_ggml_tensor * Q =
2412
2628
  wsp_ggml_permute(ctx0,
2413
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2629
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2414
2630
  0, 2, 1, 3);
2415
2631
 
2416
- // K * Q
2417
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2632
+ if (wctx.params.flash_attn) {
2633
+ struct wsp_ggml_tensor * Kcross =
2634
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2635
+ n_state_head, n_audio_ctx_pad, n_head,
2636
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2637
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
2638
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
2418
2639
 
2419
- //struct wsp_ggml_tensor * KQ_scaled =
2420
- // wsp_ggml_scale(ctx0,
2421
- // KQ,
2422
- // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2423
- // );
2640
+ struct wsp_ggml_tensor * Vcross =
2641
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2642
+ n_state_head, n_audio_ctx_pad, n_head,
2643
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state,
2644
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
2645
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
2424
2646
 
2425
- // no masking for cross-attention
2426
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2647
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
2427
2648
 
2428
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
2649
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2650
+ } else {
2651
+ struct wsp_ggml_tensor * Kcross =
2652
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2653
+ n_state_head, n_audio_ctx, n_head,
2654
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2655
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
2656
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2657
+
2658
+ struct wsp_ggml_tensor * Vcross =
2659
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2660
+ n_audio_ctx, n_state_head, n_head,
2661
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2662
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
2663
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2664
+
2665
+ // ------
2666
+
2667
+ // K * Q
2668
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2669
+
2670
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2671
+
2672
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2673
+ if (wctx.params.dtw_token_timestamps) {
2674
+ if (wstate.aheads_masks.m[il] != nullptr) {
2675
+ struct wsp_ggml_tensor * aheads_KQs = wsp_ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
2676
+ aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
2677
+ aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
2678
+ aheads_KQs = wsp_ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
2679
+ aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
2680
+ aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
2681
+ aheads_KQs = wsp_ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
2682
+ if (aheads_cross_QKs == NULL) {
2683
+ aheads_cross_QKs = aheads_KQs;
2684
+ } else {
2685
+ aheads_cross_QKs = wsp_ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
2686
+ }
2687
+ }
2688
+ }
2429
2689
 
2430
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2690
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
2431
2691
 
2432
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2692
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2433
2693
 
2434
- // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2435
- cur = wsp_ggml_cpy(ctx0,
2436
- KQV_merged,
2437
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2694
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
2695
+ }
2438
2696
  }
2439
2697
 
2440
2698
  // projection
@@ -2512,6 +2770,16 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2512
2770
 
2513
2771
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2514
2772
 
2773
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2774
+ if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
2775
+ aheads_cross_QKs = wsp_ggml_transpose(ctx0, aheads_cross_QKs);
2776
+ aheads_cross_QKs = wsp_ggml_cont(ctx0, aheads_cross_QKs);
2777
+ if (save_alignment_heads_QKs) {
2778
+ wsp_ggml_build_forward_expand(gf, aheads_cross_QKs);
2779
+ wstate.aheads_cross_QKs = aheads_cross_QKs;
2780
+ }
2781
+ }
2782
+
2515
2783
  wsp_ggml_build_forward_expand(gf, logits);
2516
2784
 
2517
2785
  wsp_ggml_free(ctx0);
@@ -2534,7 +2802,8 @@ static bool whisper_decode_internal(
2534
2802
  whisper_state & wstate,
2535
2803
  const whisper_batch & batch,
2536
2804
  const int n_threads,
2537
- whisper_abort_callback abort_callback,
2805
+ bool save_alignment_heads_QKs,
2806
+ wsp_ggml_abort_callback abort_callback,
2538
2807
  void * abort_callback_data) {
2539
2808
  const int64_t t_start_us = wsp_ggml_time_us();
2540
2809
 
@@ -2556,24 +2825,77 @@ static bool whisper_decode_internal(
2556
2825
  return false;
2557
2826
  }
2558
2827
 
2559
- kv_self.n = whisper_kv_cache_cell_max(kv_self);
2828
+ const uint32_t pad = whisper_kv_cache_get_padding(wctx);
2829
+ kv_self.n = std::min(kv_self.size, std::max(pad, WSP_GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
2830
+
2560
2831
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2561
2832
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2562
2833
  }
2563
2834
 
2564
2835
  // decoder
2565
2836
  {
2566
- auto & alloc = wstate.alloc_decode.alloc;
2837
+ auto & sched = wstate.sched_decode.sched;
2838
+
2839
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
2840
+
2841
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
2842
+ // should never happen as we pre-allocate the memory
2843
+ return false;
2844
+ }
2845
+
2846
+ // set the inputs
2847
+ {
2848
+ struct wsp_ggml_tensor * embd = wsp_ggml_graph_get_tensor(gf, "embd");
2849
+ wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
2850
+ }
2851
+
2852
+ {
2853
+ struct wsp_ggml_tensor * position = wsp_ggml_graph_get_tensor(gf, "position");
2854
+ for (int i = 0; i < n_tokens; ++i) {
2855
+ const int32_t val = batch.pos[i];
2856
+ wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2857
+ }
2858
+ }
2859
+
2860
+ {
2861
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_graph_get_tensor(gf, "KQ_mask");
2862
+
2863
+ auto & kv_self = wstate.kv_self;
2864
+
2865
+ const int32_t n_kv = kv_self.n;
2866
+
2867
+ wstate.inp_mask.resize(wsp_ggml_nelements(KQ_mask));
2868
+
2869
+ float * data = wstate.inp_mask.data();
2870
+ memset(data, 0, wsp_ggml_nbytes(KQ_mask));
2871
+
2872
+ for (int h = 0; h < 1; ++h) {
2873
+ for (int j = 0; j < n_tokens; ++j) {
2874
+ const whisper_pos pos = batch.pos[j];
2875
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
2567
2876
 
2568
- wsp_ggml_allocr_reset(alloc);
2877
+ for (int i = 0; i < n_kv; ++i) {
2878
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2879
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2880
+ }
2881
+ }
2882
+ }
2569
2883
 
2570
- wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
2884
+ for (int i = n_tokens; i < WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD); ++i) {
2885
+ for (int j = 0; j < n_kv; ++j) {
2886
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
2887
+ }
2888
+ }
2889
+ }
2571
2890
 
2572
- wsp_ggml_allocr_alloc_graph(alloc, gf);
2891
+ wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
2892
+ }
2573
2893
 
2574
- logits = gf->nodes[gf->n_nodes - 1];
2894
+ logits = wsp_ggml_graph_node(gf, -1);
2575
2895
 
2576
- wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2896
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
2897
+ return false;
2898
+ }
2577
2899
  }
2578
2900
 
2579
2901
  logits_out.resize(n_tokens*n_vocab);
@@ -2625,29 +2947,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2625
2947
  }
2626
2948
 
2627
2949
  #define SIN_COS_N_COUNT WHISPER_N_FFT
2628
- static float sin_vals[SIN_COS_N_COUNT];
2629
- static float cos_vals[SIN_COS_N_COUNT];
2950
+ namespace {
2951
+ struct whisper_global_cache {
2952
+ // In FFT, we frequently use sine and cosine operations with the same values.
2953
+ // We can use precalculated values to speed up the process.
2954
+ float sin_vals[SIN_COS_N_COUNT];
2955
+ float cos_vals[SIN_COS_N_COUNT];
2956
+
2957
+ // Hann window (Use cosf to eliminate difference)
2958
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2959
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2960
+ float hann_window[WHISPER_N_FFT];
2961
+
2962
+ whisper_global_cache() {
2963
+ fill_sin_cos_table();
2964
+ fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
2965
+ }
2966
+
2967
+ void fill_sin_cos_table() {
2968
+ for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2969
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
2970
+ sin_vals[i] = sinf(theta);
2971
+ cos_vals[i] = cosf(theta);
2972
+ }
2973
+ }
2630
2974
 
2631
- // In FFT, we frequently use sine and cosine operations with the same values.
2632
- // We can use precalculated values to speed up the process.
2633
- static void fill_sin_cos_table() {
2634
- static bool is_filled = false;
2635
- if (is_filled) return;
2636
- for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2637
- double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
2638
- sin_vals[i] = sinf(theta);
2639
- cos_vals[i] = cosf(theta);
2975
+ void fill_hann_window(int length, bool periodic, float * output) {
2976
+ int offset = -1;
2977
+ if (periodic) {
2978
+ offset = 0;
2979
+ }
2980
+ for (int i = 0; i < length; i++) {
2981
+ output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
2982
+ }
2640
2983
  }
2641
- is_filled = true;
2984
+ } global_cache;
2642
2985
  }
2643
2986
 
2644
2987
  // naive Discrete Fourier Transform
2645
2988
  // input is real-valued
2646
2989
  // output is complex-valued
2647
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
2648
- int N = in.size();
2649
-
2650
- out.resize(N*2);
2990
+ static void dft(const float* in, int N, float* out) {
2651
2991
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2652
2992
 
2653
2993
  for (int k = 0; k < N; k++) {
@@ -2656,8 +2996,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2656
2996
 
2657
2997
  for (int n = 0; n < N; n++) {
2658
2998
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2659
- re += in[n]*cos_vals[idx]; // cos(t)
2660
- im -= in[n]*sin_vals[idx]; // sin(t)
2999
+ re += in[n]*global_cache.cos_vals[idx]; // cos(t)
3000
+ im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
2661
3001
  }
2662
3002
 
2663
3003
  out[k*2 + 0] = re;
@@ -2669,47 +3009,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2669
3009
  // poor man's implementation - use something better
2670
3010
  // input is real-valued
2671
3011
  // output is complex-valued
2672
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
2673
- out.resize(in.size()*2);
2674
-
2675
- int N = in.size();
2676
-
3012
+ static void fft(float* in, int N, float* out) {
2677
3013
  if (N == 1) {
2678
3014
  out[0] = in[0];
2679
3015
  out[1] = 0;
2680
3016
  return;
2681
3017
  }
2682
3018
 
2683
- if (N%2 == 1) {
2684
- dft(in, out);
3019
+ const int half_N = N / 2;
3020
+ if (N - half_N*2 == 1) {
3021
+ dft(in, N, out);
2685
3022
  return;
2686
3023
  }
2687
3024
 
2688
- std::vector<float> even;
2689
- std::vector<float> odd;
2690
-
2691
- even.reserve(N/2);
2692
- odd.reserve(N/2);
2693
-
2694
- for (int i = 0; i < N; i++) {
2695
- if (i % 2 == 0) {
2696
- even.push_back(in[i]);
2697
- } else {
2698
- odd.push_back(in[i]);
2699
- }
3025
+ float* even = in + N;
3026
+ for (int i = 0; i < half_N; ++i) {
3027
+ even[i]= in[2*i];
2700
3028
  }
3029
+ float* even_fft = out + 2 * N;
3030
+ fft(even, half_N, even_fft);
2701
3031
 
2702
- std::vector<float> even_fft;
2703
- std::vector<float> odd_fft;
2704
-
2705
- fft(even, even_fft);
2706
- fft(odd, odd_fft);
3032
+ float* odd = even;
3033
+ for (int i = 0; i < half_N; ++i) {
3034
+ odd[i] = in[2*i + 1];
3035
+ }
3036
+ float* odd_fft = even_fft + N;
3037
+ fft(odd, half_N, odd_fft);
2707
3038
 
2708
3039
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2709
- for (int k = 0; k < N/2; k++) {
3040
+ for (int k = 0; k < half_N; k++) {
2710
3041
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2711
- float re = cos_vals[idx]; // cos(t)
2712
- float im = -sin_vals[idx]; // sin(t)
3042
+ float re = global_cache.cos_vals[idx]; // cos(t)
3043
+ float im = -global_cache.sin_vals[idx]; // sin(t)
2713
3044
 
2714
3045
  float re_odd = odd_fft[2*k + 0];
2715
3046
  float im_odd = odd_fft[2*k + 1];
@@ -2717,61 +3048,49 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2717
3048
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2718
3049
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2719
3050
 
2720
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2721
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2722
- }
2723
- }
2724
-
2725
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2726
- if (output.size() < static_cast<size_t>(length)) {
2727
- output.resize(length);
2728
- }
2729
- int offset = -1;
2730
- if (periodic) {
2731
- offset = 0;
2732
- }
2733
- for (int i = 0; i < length; i++) {
2734
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
3051
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3052
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2735
3053
  }
2736
-
2737
- return true;
2738
3054
  }
2739
3055
 
2740
- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
3056
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
2741
3057
  int n_samples, int frame_size, int frame_step, int n_threads,
2742
3058
  const whisper_filters & filters, whisper_mel & mel) {
2743
- std::vector<float> fft_in(frame_size, 0.0);
2744
- std::vector<float> fft_out(2 * frame_step);
2745
- // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2746
- int n_fft = 1 + (frame_size / 2);
3059
+ std::vector<float> fft_in(frame_size * 2, 0.0);
3060
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
3061
+
3062
+ int n_fft = filters.n_fft;
2747
3063
  int i = ith;
2748
3064
 
3065
+ // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
3066
+ assert(n_fft == 1 + (frame_size / 2));
3067
+
2749
3068
  // calculate FFT only when fft_in are not all zero
2750
3069
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2751
3070
  const int offset = i * frame_step;
2752
3071
 
2753
- // apply Hanning window (~10% faster)
3072
+ // apply Hann window (~10% faster)
2754
3073
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2755
3074
  fft_in[j] = hann[j] * samples[offset + j];
2756
3075
  }
3076
+
2757
3077
  // fill the rest with zeros
2758
3078
  if (n_samples - offset < frame_size) {
2759
3079
  std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
2760
3080
  }
2761
3081
 
2762
3082
  // FFT
2763
- fft(fft_in, fft_out);
3083
+ fft(fft_in.data(), frame_size, fft_out.data());
2764
3084
 
2765
3085
  // Calculate modulus^2 of complex numbers
2766
3086
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2767
- for (int j = 0; j < frame_size; j++) {
3087
+ for (int j = 0; j < n_fft; j++) {
2768
3088
  fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2769
3089
  }
2770
3090
 
2771
3091
  // mel spectrogram
2772
3092
  for (int j = 0; j < mel.n_mel; j++) {
2773
3093
  double sum = 0.0;
2774
-
2775
3094
  // unroll loop (suggested by GH user @lunixbochs)
2776
3095
  int k = 0;
2777
3096
  for (k = 0; k < n_fft - 3; k += 4) {
@@ -2781,14 +3100,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2781
3100
  fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
2782
3101
  fft_out[k + 3] * filters.data[j * n_fft + k + 3];
2783
3102
  }
2784
-
2785
3103
  // handle n_fft remainder
2786
3104
  for (; k < n_fft; k++) {
2787
3105
  sum += fft_out[k] * filters.data[j * n_fft + k];
2788
3106
  }
2789
-
2790
3107
  sum = log10(std::max(sum, 1e-10));
2791
-
2792
3108
  mel.data[j * mel.n_len + i] = sum;
2793
3109
  }
2794
3110
  }
@@ -2817,12 +3133,9 @@ static bool log_mel_spectrogram(
2817
3133
  whisper_mel & mel) {
2818
3134
  const int64_t t_start_us = wsp_ggml_time_us();
2819
3135
 
2820
- // Hanning window (Use cosf to eliminate difference)
2821
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2822
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2823
- std::vector<float> hann;
2824
- hann_window(frame_size, true, hann);
2825
-
3136
+ // Hann window
3137
+ WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
3138
+ const float * hann = global_cache.hann_window;
2826
3139
 
2827
3140
  // Calculate the length of padding
2828
3141
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -2847,12 +3160,11 @@ static bool log_mel_spectrogram(
2847
3160
  mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
2848
3161
  mel.data.resize(mel.n_mel * mel.n_len);
2849
3162
 
2850
-
2851
3163
  {
2852
3164
  std::vector<std::thread> workers(n_threads - 1);
2853
3165
  for (int iw = 0; iw < n_threads - 1; ++iw) {
2854
3166
  workers[iw] = std::thread(
2855
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
3167
+ log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
2856
3168
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
2857
3169
  std::cref(filters), std::ref(mel));
2858
3170
  }
@@ -3012,19 +3324,24 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
3012
3324
  #endif
3013
3325
 
3014
3326
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
3015
- fill_sin_cos_table();
3016
-
3017
3327
  whisper_state * state = new whisper_state;
3018
3328
 
3019
- state->backend = whisper_backend_init(ctx->params);
3020
-
3021
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3022
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
3023
- const int factor = 3;
3329
+ state->backends = whisper_backend_init(ctx->params);
3330
+ if (state->backends.empty()) {
3331
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
3332
+ whisper_free_state(state);
3333
+ return nullptr;
3334
+ }
3024
3335
 
3025
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
3026
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3027
- delete state;
3336
+ // at this point, we don't know yet how many decoders will be used
3337
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
3338
+ state->kv_self_n_dec = 1;
3339
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
3340
+ ctx->model.hparams.n_text_state,
3341
+ ctx->model.hparams.n_text_layer,
3342
+ WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
3343
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3344
+ whisper_free_state(state);
3028
3345
  return nullptr;
3029
3346
  }
3030
3347
 
@@ -3033,9 +3350,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3033
3350
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
3034
3351
  }
3035
3352
 
3036
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
3037
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3038
- delete state;
3353
+ if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
3354
+ ctx->model.hparams.n_text_state,
3355
+ ctx->model.hparams.n_text_layer,
3356
+ WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3357
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
3358
+ whisper_free_state(state);
3039
3359
  return nullptr;
3040
3360
  }
3041
3361
 
@@ -3044,6 +3364,31 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3044
3364
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
3045
3365
  }
3046
3366
 
3367
+ if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
3368
+ ctx->model.hparams.n_audio_state,
3369
+ 1,
3370
+ WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3371
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3372
+ whisper_free_state(state);
3373
+ return nullptr;
3374
+ }
3375
+
3376
+ {
3377
+ const size_t memory_size = wsp_ggml_nbytes(state->kv_pad.k) + wsp_ggml_nbytes(state->kv_pad.v);
3378
+ WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
3379
+ }
3380
+
3381
+ // [EXPERIMENTAL] Token-level timestamps with DTW
3382
+ if (ctx->params.dtw_token_timestamps) {
3383
+ if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
3384
+ WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
3385
+ whisper_free_state(state);
3386
+ return nullptr;
3387
+ }
3388
+ const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
3389
+ WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
3390
+ }
3391
+
3047
3392
 
3048
3393
  #ifdef WHISPER_USE_COREML
3049
3394
  if (ctx->params.use_coreml) {
@@ -3056,7 +3401,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3056
3401
  if (!state->ctx_coreml) {
3057
3402
  WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
3058
3403
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
3059
- delete state;
3404
+ whisper_free_state(state);
3060
3405
  return nullptr;
3061
3406
  #endif
3062
3407
  } else {
@@ -3081,37 +3426,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3081
3426
 
3082
3427
  // conv allocator
3083
3428
  {
3084
- whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
3429
+ bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
3085
3430
  [&]() {
3086
- return whisper_build_graph_conv(*ctx, *state, 0);
3431
+ return whisper_build_graph_conv(*ctx, *state);
3087
3432
  });
3088
3433
 
3089
- WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
3434
+ if (!ok) {
3435
+ WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__);
3436
+ whisper_free_state(state);
3437
+ return nullptr;
3438
+ }
3439
+
3440
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
3090
3441
  }
3091
3442
 
3092
3443
  // encoder allocator
3093
3444
  if (!whisper_encode_external(*state)) {
3094
- whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
3445
+ bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
3095
3446
  [&]() {
3096
3447
  return whisper_build_graph_encoder(*ctx, *state);
3097
3448
  });
3098
3449
 
3099
- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
3450
+ if (!ok) {
3451
+ WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__);
3452
+ whisper_free_state(state);
3453
+ return nullptr;
3454
+ }
3455
+
3456
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
3100
3457
  }
3101
3458
 
3102
3459
  // cross allocator
3103
3460
  {
3104
- whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
3461
+ bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
3105
3462
  [&]() {
3106
3463
  return whisper_build_graph_cross(*ctx, *state);
3107
3464
  });
3108
3465
 
3109
- WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
3466
+ if (!ok) {
3467
+ WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__);
3468
+ whisper_free_state(state);
3469
+ return nullptr;
3470
+ }
3471
+
3472
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
3110
3473
  }
3111
3474
 
3112
3475
  // decoder allocator
3113
3476
  {
3114
- whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
3477
+ bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
3115
3478
  [&]() {
3116
3479
  const auto & hparams = ctx->model.hparams;
3117
3480
 
@@ -3121,27 +3484,30 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3121
3484
 
3122
3485
  whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
3123
3486
 
3124
- return whisper_build_graph_decoder(*ctx, *state, state->batch);
3487
+ return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
3125
3488
  });
3126
3489
 
3127
- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
3128
- }
3490
+ if (!ok) {
3491
+ WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
3492
+ whisper_free_state(state);
3493
+ return nullptr;
3494
+ }
3129
3495
 
3130
- whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
3131
- whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
3132
- whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
3133
- whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
3496
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
3497
+ }
3134
3498
 
3135
3499
  return state;
3136
3500
  }
3137
3501
 
3138
- int whisper_ctx_init_openvino_encoder(
3502
+ int whisper_ctx_init_openvino_encoder_with_state(
3139
3503
  struct whisper_context * ctx,
3504
+ struct whisper_state * state,
3140
3505
  const char * model_path,
3141
3506
  const char * device,
3142
3507
  const char * cache_dir) {
3143
3508
  #ifndef WHISPER_USE_OPENVINO
3144
3509
  (void)(ctx);
3510
+ (void)(state);
3145
3511
  (void)(model_path);
3146
3512
  (void)(device);
3147
3513
  (void)(cache_dir);
@@ -3172,8 +3538,8 @@ int whisper_ctx_init_openvino_encoder(
3172
3538
  WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3173
3539
  WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
3174
3540
 
3175
- ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3176
- if (!ctx->state->ctx_openvino) {
3541
+ state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3542
+ if (!state->ctx_openvino) {
3177
3543
  WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3178
3544
  return 1;
3179
3545
  } else {
@@ -3184,18 +3550,43 @@ int whisper_ctx_init_openvino_encoder(
3184
3550
  #endif
3185
3551
  }
3186
3552
 
3553
+ int whisper_ctx_init_openvino_encoder(
3554
+ struct whisper_context * ctx,
3555
+ const char * model_path,
3556
+ const char * device,
3557
+ const char * cache_dir) {
3558
+ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
3559
+ }
3560
+
3187
3561
  struct whisper_context_params whisper_context_default_params() {
3188
3562
  struct whisper_context_params result = {
3189
- /*.use_gpu =*/ true,
3190
- /*.use_coreml =*/ false,
3563
+ /*.use_gpu =*/ true,
3564
+ /*.use_coreml =*/ false,
3565
+ /*.flash_attn =*/ false,
3566
+ /*.gpu_device =*/ 0,
3567
+
3568
+ /*.dtw_token_timestamps =*/ false,
3569
+ /*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
3570
+ /*.dtw_n_top =*/ -1,
3571
+ /*.dtw_aheads =*/ {
3572
+ /*.n_heads =*/ 0,
3573
+ /*.heads =*/ NULL,
3574
+ },
3575
+ /*.dtw_mem_size =*/ 1024*1024*128,
3191
3576
  };
3192
3577
  return result;
3193
3578
  }
3194
3579
 
3195
3580
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3196
3581
  WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3197
-
3582
+ #ifdef _MSC_VER
3583
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
3584
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
3585
+ std::wstring path_model_wide = converter.from_bytes(path_model);
3586
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
3587
+ #else
3198
3588
  auto fin = std::ifstream(path_model, std::ios::binary);
3589
+ #endif
3199
3590
  if (!fin) {
3200
3591
  WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3201
3592
  return nullptr;
@@ -3270,6 +3661,19 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
3270
3661
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
3271
3662
  wsp_ggml_time_init();
3272
3663
 
3664
+ if (params.flash_attn && params.dtw_token_timestamps) {
3665
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
3666
+ params.dtw_token_timestamps = false;
3667
+ }
3668
+
3669
+ WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
3670
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
3671
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
3672
+ WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3673
+
3674
+ // TODO: temporary call to force backend registry initialization
3675
+ WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
3676
+
3273
3677
  whisper_context * ctx = new whisper_context;
3274
3678
  ctx->params = params;
3275
3679
 
@@ -3354,11 +3758,11 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
3354
3758
  return whisper_init_with_params_no_state(loader, whisper_context_default_params());
3355
3759
  }
3356
3760
 
3357
- void whisper_free_state(struct whisper_state * state)
3358
- {
3761
+ void whisper_free_state(struct whisper_state * state) {
3359
3762
  if (state) {
3360
- kv_cache_free(state->kv_self);
3361
- kv_cache_free(state->kv_cross);
3763
+ whisper_kv_cache_free(state->kv_self);
3764
+ whisper_kv_cache_free(state->kv_cross);
3765
+ whisper_kv_cache_free(state->kv_pad);
3362
3766
 
3363
3767
  #ifdef WHISPER_USE_COREML
3364
3768
  if (state->ctx_coreml != nullptr) {
@@ -3376,12 +3780,17 @@ void whisper_free_state(struct whisper_state * state)
3376
3780
 
3377
3781
  whisper_batch_free(state->batch);
3378
3782
 
3379
- whisper_allocr_free(state->alloc_conv);
3380
- whisper_allocr_free(state->alloc_encode);
3381
- whisper_allocr_free(state->alloc_cross);
3382
- whisper_allocr_free(state->alloc_decode);
3783
+ wsp_ggml_backend_sched_free(state->sched_conv.sched);
3784
+ wsp_ggml_backend_sched_free(state->sched_encode.sched);
3785
+ wsp_ggml_backend_sched_free(state->sched_cross.sched);
3786
+ wsp_ggml_backend_sched_free(state->sched_decode.sched);
3787
+
3788
+ for (auto & backend : state->backends) {
3789
+ wsp_ggml_backend_free(backend);
3790
+ }
3383
3791
 
3384
- wsp_ggml_backend_free(state->backend);
3792
+ // [EXPERIMENTAL] Token-level timestamps with DTW
3793
+ aheads_masks_free(state->aheads_masks);
3385
3794
 
3386
3795
  delete state;
3387
3796
  }
@@ -3389,18 +3798,12 @@ void whisper_free_state(struct whisper_state * state)
3389
3798
 
3390
3799
  void whisper_free(struct whisper_context * ctx) {
3391
3800
  if (ctx) {
3392
- if (ctx->model.ctx) {
3393
- wsp_ggml_free(ctx->model.ctx);
3394
- }
3801
+ wsp_ggml_free(ctx->model.ctx);
3395
3802
 
3396
- if (ctx->model.buffer) {
3397
- wsp_ggml_backend_buffer_free(ctx->model.buffer);
3398
- }
3803
+ wsp_ggml_backend_buffer_free(ctx->model.buffer);
3399
3804
 
3400
3805
  whisper_free_state(ctx->state);
3401
3806
 
3402
- wsp_ggml_backend_free(ctx->backend);
3403
-
3404
3807
  delete ctx;
3405
3808
  }
3406
3809
  }
@@ -3430,30 +3833,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3430
3833
  return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
3431
3834
  }
3432
3835
 
3433
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3434
- int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3435
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3436
- WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3437
- return -1;
3438
- }
3439
-
3440
- return 0;
3441
- }
3442
-
3443
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3444
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
3445
- return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
3446
- }
3447
-
3448
- // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
3449
- // TODO
3450
-
3451
- // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
3452
- // TODO
3453
-
3454
- // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
3455
- // TODO
3456
-
3457
3836
  int whisper_set_mel_with_state(
3458
3837
  struct whisper_context * ctx,
3459
3838
  struct whisper_state * state,
@@ -3506,7 +3885,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3506
3885
 
3507
3886
  whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
3508
3887
 
3509
- if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
3888
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
3510
3889
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3511
3890
  return 1;
3512
3891
  }
@@ -3528,7 +3907,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3528
3907
 
3529
3908
  if (n_max_tokens < (int) res.size()) {
3530
3909
  WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3531
- return -1;
3910
+ return -(int) res.size();
3532
3911
  }
3533
3912
 
3534
3913
  for (int i = 0; i < (int) res.size(); i++) {
@@ -3538,7 +3917,11 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3538
3917
  return res.size();
3539
3918
  }
3540
3919
 
3541
- int whisper_lang_max_id() {
3920
+ int whisper_token_count(struct whisper_context * ctx, const char * text) {
3921
+ return -whisper_tokenize(ctx, text, NULL, 0);
3922
+ }
3923
+
3924
+ int whisper_lang_max_id(void) {
3542
3925
  auto max_id = 0;
3543
3926
  for (const auto & kv : g_lang) {
3544
3927
  max_id = std::max(max_id, kv.second.first);
@@ -3807,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3807
4190
  return ctx->vocab.token_transcribe;
3808
4191
  }
3809
4192
 
4193
+ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
4194
+ if (ctx->state == nullptr) {
4195
+ return nullptr;
4196
+ }
4197
+ return new whisper_timings {
4198
+ .load_us = ctx->t_load_us,
4199
+ .t_start_us = ctx->t_start_us,
4200
+ .fail_p = ctx->state->n_fail_p,
4201
+ .fail_h = ctx->state->n_fail_h,
4202
+ .t_mel_us = ctx->state->t_mel_us,
4203
+ .n_sample = ctx->state->n_sample,
4204
+ .n_encode = ctx->state->n_encode,
4205
+ .n_decode = ctx->state->n_decode,
4206
+ .n_batchd = ctx->state->n_batchd,
4207
+ .n_prompt = ctx->state->n_prompt,
4208
+ .t_sample_us = ctx->state->t_sample_us,
4209
+ .t_encode_us = ctx->state->t_encode_us,
4210
+ .t_decode_us = ctx->state->t_decode_us,
4211
+ .t_batchd_us = ctx->state->t_batchd_us,
4212
+ .t_prompt_us = ctx->state->t_prompt_us,
4213
+ };
4214
+ }
4215
+
3810
4216
  void whisper_print_timings(struct whisper_context * ctx) {
3811
4217
  const int64_t t_end_us = wsp_ggml_time_us();
4218
+ const struct whisper_timings * timings = whisper_get_timings(ctx);
3812
4219
 
3813
4220
  WHISPER_LOG_INFO("\n");
3814
- WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
4221
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
3815
4222
  if (ctx->state != nullptr) {
3816
-
3817
4223
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3818
4224
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3819
4225
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3820
4226
  const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
3821
4227
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3822
4228
 
3823
- WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3824
- WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3825
- WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3826
- WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3827
- WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3828
- WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
3829
- WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4229
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
4230
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
4231
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
4232
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
4233
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
4234
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
4235
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
3830
4236
  }
3831
- WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4237
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
3832
4238
  }
3833
4239
 
3834
4240
  void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3838,6 +4244,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3838
4244
  ctx->state->t_sample_us = 0;
3839
4245
  ctx->state->t_encode_us = 0;
3840
4246
  ctx->state->t_decode_us = 0;
4247
+ ctx->state->t_batchd_us = 0;
3841
4248
  ctx->state->t_prompt_us = 0;
3842
4249
  ctx->state->n_sample = 0;
3843
4250
  ctx->state->n_encode = 0;
@@ -3881,10 +4288,10 @@ const char * whisper_print_system_info(void) {
3881
4288
  s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
3882
4289
  s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
3883
4290
  s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
3884
- s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cublas()) + " | ";
4291
+ s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
3885
4292
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3886
4293
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3887
-
4294
+ s += "CANN = " + std::to_string(wsp_ggml_cpu_has_cann()) ;
3888
4295
  return s.c_str();
3889
4296
  }
3890
4297
 
@@ -3894,7 +4301,7 @@ const char * whisper_print_system_info(void) {
3894
4301
 
3895
4302
  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
3896
4303
  // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
3897
- std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4304
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
3898
4305
  const char * src,
3899
4306
  whisper_partial_utf8 partial_start) {
3900
4307
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4308,7 +4715,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
4308
4715
 
4309
4716
  ////////////////////////////////////////////////////////////////////////////
4310
4717
 
4311
- struct whisper_context_params * whisper_context_default_params_by_ref() {
4718
+ struct whisper_context_params * whisper_context_default_params_by_ref(void) {
4312
4719
  struct whisper_context_params params = whisper_context_default_params();
4313
4720
 
4314
4721
  struct whisper_context_params* result = new whisper_context_params();
@@ -4349,12 +4756,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4349
4756
  /*.split_on_word =*/ false,
4350
4757
  /*.max_tokens =*/ 0,
4351
4758
 
4352
- /*.speed_up =*/ false,
4353
4759
  /*.debug_mode =*/ false,
4354
4760
  /*.audio_ctx =*/ 0,
4355
4761
 
4356
4762
  /*.tdrz_enable =*/ false,
4357
4763
 
4764
+ /* suppress_regex =*/ nullptr,
4765
+
4358
4766
  /*.initial_prompt =*/ nullptr,
4359
4767
  /*.prompt_tokens =*/ nullptr,
4360
4768
  /*.prompt_n_tokens =*/ 0,
@@ -4440,6 +4848,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
4440
4848
  return txt[0] == ' ';
4441
4849
  }
4442
4850
 
4851
+ static void whisper_exp_compute_token_level_timestamps_dtw(
4852
+ struct whisper_context * ctx,
4853
+ struct whisper_state * state,
4854
+ struct whisper_full_params params,
4855
+ int i_segment,
4856
+ size_t n_segments,
4857
+ int seek,
4858
+ int n_frames,
4859
+ int medfilt_width,
4860
+ int n_threads);
4861
+
4443
4862
  // wrap the last segment to max_len characters
4444
4863
  // returns the number of new segments
4445
4864
  static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
@@ -4587,6 +5006,17 @@ static void whisper_process_logits(
4587
5006
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
4588
5007
  }
4589
5008
 
5009
+ // suppress any tokens matching a regular expression
5010
+ // ref: https://github.com/openai/whisper/discussions/1041
5011
+ if (params.suppress_regex != nullptr) {
5012
+ std::regex re(params.suppress_regex);
5013
+ for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
5014
+ if (std::regex_match(token_id.first, re)) {
5015
+ logits[token_id.second] = -INFINITY;
5016
+ }
5017
+ }
5018
+ }
5019
+
4590
5020
  // suppress non-speech tokens
4591
5021
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4592
5022
  if (params.suppress_non_speech_tokens) {
@@ -4790,12 +5220,25 @@ static void whisper_process_logits(
4790
5220
  #endif
4791
5221
  }
4792
5222
 
5223
+ static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
5224
+ if (a.tokens.size() != b.tokens.size()) {
5225
+ return false;
5226
+ }
5227
+ // sequences are more likely to diverge at the end
5228
+ for (int i = a.tokens.size() - 1; i >= 0; i--) {
5229
+ if (a.tokens[i].id != b.tokens[i].id) {
5230
+ return false;
5231
+ }
5232
+ }
5233
+ return true;
5234
+ }
5235
+
4793
5236
  static whisper_token_data whisper_sample_token(
4794
5237
  whisper_context & ctx,
4795
5238
  const whisper_decoder & decoder,
4796
5239
  bool best) {
4797
5240
  whisper_token_data result = {
4798
- 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
5241
+ 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
4799
5242
  };
4800
5243
 
4801
5244
  const auto & vocab = ctx.vocab;
@@ -4913,7 +5356,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4913
5356
  const auto id = dist(decoder.rng);
4914
5357
  //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
4915
5358
 
4916
- result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
5359
+ result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
4917
5360
 
4918
5361
  if (result[i].id >= vocab.token_beg) {
4919
5362
  result[i].tid = result[i].id;
@@ -4966,7 +5409,7 @@ static void whisper_sequence_score(
4966
5409
  const auto p = kv.second/(double)cnt;
4967
5410
  entropy -= p*log(p);
4968
5411
 
4969
- //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
5412
+ //WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
4970
5413
  }
4971
5414
 
4972
5415
  sequence.entropy = entropy;
@@ -4986,15 +5429,9 @@ int whisper_full_with_state(
 
     if (n_samples > 0) {
         // compute log mel spectrogram
-        if (params.speed_up) {
-            // TODO: Replace PV with more advanced algorithm
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
-        } else {
-            if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-                return -2;
-            }
+            return -2;
         }
     }
 
@@ -5031,8 +5468,8 @@ int whisper_full_with_state(
     // if length of spectrogram is less than 1.0s (100 frames), then return
     // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
-        WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
+    if (seek_end < seek_start + 100) {
+        WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
         return 0;
     }
 
@@ -5095,7 +5532,12 @@ int whisper_full_with_state(
     // initial prompt
     if (!params.prompt_tokens && params.initial_prompt) {
         prompt_tokens.resize(1024);
-        prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
+        int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+        if (n_needed < 0) {
+            prompt_tokens.resize(-n_needed);
+            n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+        }
+        prompt_tokens.resize(n_needed);
         params.prompt_tokens   = prompt_tokens.data();
         params.prompt_n_tokens = prompt_tokens.size();
     }
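The new code handles initial prompts longer than the 1024-token starting buffer: whisper_tokenize reports overflow by returning the negated number of tokens required, so the buffer is regrown once and the call retried. The same query-then-fill pattern in isolation, with a toy tokenizer standing in for the real one:

    #include <cstring>
    #include <vector>

    // toy stand-in with whisper_tokenize-like semantics: one token per character;
    // returns the token count, or -n_required when the buffer is too small
    static int toy_tokenize(const char * text, int * tokens, int capacity) {
        const int n = (int) std::strlen(text);
        if (n > capacity) {
            return -n;
        }
        for (int i = 0; i < n; ++i) {
            tokens[i] = text[i];
        }
        return n;
    }

    static std::vector<int> tokenize_all(const char * text) {
        std::vector<int> tokens(1024);
        int n = toy_tokenize(text, tokens.data(), (int) tokens.size());
        if (n < 0) {              // buffer too small: -n is the required size
            tokens.resize(-n);
            n = toy_tokenize(text, tokens.data(), (int) tokens.size());
        }
        tokens.resize(n);         // shrink to the actual count
        return tokens;
    }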
@@ -5131,11 +5573,11 @@ int whisper_full_with_state(
         }
     }
 
-    // distilled models require the "no_timestamps" token
+    // first release distilled models require the "no_timestamps" token
    {
-        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
         if (is_distil && !params.no_timestamps) {
-            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
+            WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
             params.no_timestamps = true;
         }
     }
@@ -5221,7 +5663,7 @@ int whisper_full_with_state(
 
         n_decoders_cur = std::max(1, n_decoders_cur);
 
-        WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
+        WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
 
         // TAGS: WHISPER_DECODER_INIT
         for (int j = 0; j < n_decoders_cur; ++j) {
@@ -5265,19 +5707,40 @@ int whisper_full_with_state(
             prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
             // print the prompt
-            WHISPER_PRINT_DEBUG("\n\n");
+            WHISPER_LOG_DEBUG("\n\n");
             for (int i = 0; i < (int) prompt.size(); i++) {
-                WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+                WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+            }
+            WHISPER_LOG_DEBUG("\n\n");
+
+            // recreate the KV cache if the number of decoders has changed
+            if (state->kv_self_n_dec < n_decoders_cur) {
+                WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+                whisper_kv_cache_free(state->kv_self);
+
+                // overallocate to workaround KV cache fragmentation issues
+                const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+                if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                            ctx->model.hparams.n_text_state,
+                            ctx->model.hparams.n_text_layer,
+                            WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                    WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                    whisper_free_state(state);
+                    return -7;
+                }
+
+                state->kv_self_n_dec = n_decoders_cur;
             }
-            WHISPER_PRINT_DEBUG("\n\n");
 
             whisper_kv_cache_clear(state->kv_self);
 
             whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
-            if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+            if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                return -7;
+                return -8;
             }
 
             {
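The self-attention KV cache is now rebuilt whenever the decoder count grows, and deliberately overallocated: the text context is padded up to a multiple of 256 cells and multiplied by n_decoders + 2 to leave slack against fragmentation. The size arithmetic, assuming the GGML-style round-up macro and example values:

    #include <cstdio>

    // same shape as GGML's padding macro (n must be a power of two)
    #define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const int n_text_ctx     = 448; // Whisper's text context length
        const int n_decoders_cur = 5;   // e.g. beam search with 5 beams

        const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
        printf("KV cells: %d\n", PAD(n_text_ctx, 256)*factor); // 512 * 7 = 3584
        return 0;
    }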
@@ -5388,7 +5851,10 @@ int whisper_full_with_state(
                     beam_candidates.begin(),
                     beam_candidates.end(),
                     [](const beam_candidate & a, const beam_candidate & b) {
-                        return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+                        if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) {
+                            return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+                        }
+                        return a.decoder_idx < b.decoder_idx;
                     });
 
                 uint32_t cur_c = 0;
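The comparator change makes the candidate ordering deterministic: when two beams have exactly equal cumulative log-probabilities, the tie is broken by decoder index rather than left to std::sort, which is not stable. A reduced sketch of the pattern:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct candidate {
        int    decoder_idx;
        double sum_logprobs;
    };

    int main() {
        std::vector<candidate> c = { {2, -1.5}, {0, -1.5}, {1, -0.5} };
        std::sort(c.begin(), c.end(), [](const candidate & a, const candidate & b) {
            if (a.sum_logprobs != b.sum_logprobs) {
                return a.sum_logprobs > b.sum_logprobs; // best score first
            }
            return a.decoder_idx < b.decoder_idx;       // deterministic tie-break
        });
        for (const auto & x : c) {
            printf("decoder %d: %.2f\n", x.decoder_idx, x.sum_logprobs); // 1, 0, 2
        }
        return 0;
    }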
@@ -5406,7 +5872,7 @@ int whisper_full_with_state(
 
                     auto & cur = beam_candidates[cur_c++];
 
-                    while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                    while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
                         ++cur_c;
                     }
 
@@ -5417,7 +5883,7 @@ int whisper_full_with_state(
 
                     whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
-                    WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
+                    WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                             __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                 }
 
@@ -5460,7 +5926,7 @@ int whisper_full_with_state(
 
                     // do not allow to go back in time
                     if (has_ts && seek_delta > seek_delta_new && result_len < i) {
-                        WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
+                        WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
                         failed = true; // TODO: maybe this is not a failure ?
                         continue;
                     }
@@ -5475,7 +5941,7 @@ int whisper_full_with_state(
 #ifdef WHISPER_DEBUG
                     {
                         const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
-                        WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                        WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
                                 __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
                     }
 #endif
@@ -5485,22 +5951,22 @@ int whisper_full_with_state(
                         (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
                         (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
                        ) {
-                        if (result_len == 0) {
+                        if (result_len == 0 && !params.no_timestamps) {
                             if (seek + seek_delta + 100 >= seek_end) {
                                 result_len = i + 1;
                             } else {
-                                WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
+                                WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
                                 failed = true;
                                 continue;
                             }
                         }
 
-                        if (params.single_segment) {
+                        if (params.single_segment || params.no_timestamps) {
                             result_len = i + 1;
                             seek_delta = 100*WHISPER_CHUNK_SIZE;
                         }
 
-                        WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
+                        WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j);
                         completed = true;
                         continue;
                     }
@@ -5516,7 +5982,7 @@ int whisper_full_with_state(
                     // sometimes, the decoding can get stuck in a repetition loop
                     // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                     if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
-                        WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
+                        WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
                         failed = true;
                         continue;
                     }
@@ -5558,7 +6024,7 @@ int whisper_full_with_state(
                     continue;
                 }
 
-                //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+                //WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
 
                 decoder.i_batch = batch.n_tokens;
 
@@ -5572,9 +6038,9 @@ int whisper_full_with_state(
 
                 assert(batch.n_tokens > 0);
 
-                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                    return -8;
+                    return -9;
                 }
 
                 const int64_t t_start_sample_us = wsp_ggml_time_us();
@@ -5638,11 +6104,11 @@ int whisper_full_with_state(
                     decoder.sequence.tokens.resize(decoder.sequence.result_len);
                     whisper_sequence_score(params, decoder.sequence);
 
-                    WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+                    WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
                             __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
 
                     if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
-                        WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+                        WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
                                 __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
                         decoder.failed = true;
@@ -5657,7 +6123,7 @@ int whisper_full_with_state(
                     }
                 }
 
-                WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
+                WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
             }
 
             bool success = true;
@@ -5669,7 +6135,7 @@ int whisper_full_with_state(
                 const auto & decoder = state->decoders[best_decoder_id];
 
                 if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
-                    WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
                     success = false;
                     state->n_fail_p++;
                 }
@@ -5677,13 +6143,13 @@ int whisper_full_with_state(
 
             if (success) {
                 //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
-                //    WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+                //    WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
                 //}
 
                 break;
             }
 
-            WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
+            WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
         }
 
         // output results through a user-provided callback
@@ -5695,7 +6161,10 @@ int whisper_full_with_state(
 
             const auto & tokens_cur = best_decoder.sequence.tokens;
 
-            //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            const auto n_segments_before = state->result_all.size();
+
+            //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
             prompt_past.clear();
@@ -5732,8 +6201,8 @@ int whisper_full_with_state(
                     const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
 
                     if (!text.empty()) {
-                        const auto tt0 = params.speed_up ? 2*t0 : t0;
-                        const auto tt1 = params.speed_up ? 2*t1 : t1;
+                        const auto tt0 = t0;
+                        const auto tt1 = t1;
 
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -5761,7 +6230,7 @@ int whisper_full_with_state(
                             n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                         }
                     }
-                    if (params.new_segment_callback) {
+                    if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                         params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                     }
                 }
@@ -5779,8 +6248,8 @@ int whisper_full_with_state(
                 if (!text.empty()) {
                     const auto t1 = seek + seek_delta;
 
-                    const auto tt0 = params.speed_up ? 2*t0 : t0;
-                    const auto tt1 = params.speed_up ? 2*t1 : t1;
+                    const auto tt0 = t0;
+                    const auto tt1 = t1;
 
                     if (params.print_realtime) {
                         if (params.print_timestamps) {
@@ -5806,16 +6275,32 @@ int whisper_full_with_state(
                         n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                     }
                 }
-                if (params.new_segment_callback) {
+                if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                     params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                 }
             }
         }
 
+        // FIXME: will timestamp offsets be correct?
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        {
+            const int n_segments = state->result_all.size() - n_segments_before;
+            if (ctx->params.dtw_token_timestamps && n_segments) {
+                const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
+                whisper_exp_compute_token_level_timestamps_dtw(
+                        ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                if (params.new_segment_callback) {
+                    for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+                        params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+                    }
+                }
+            }
+        }
+
         // update audio window
         seek += seek_delta;
 
-        WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
+        WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
     }
 }
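In the DTW block above, counts are in 10 ms mel frames: WHISPER_CHUNK_SIZE * 100 is the full 30-second window (3000 frames), and n_frames is clamped to both the distance the decoder actually advanced and the audio remaining. The clamping with example values invented for illustration:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int WHISPER_CHUNK_SIZE = 30;                          // seconds per window
        const int seek = 1000, seek_end = 3500, seek_delta = 2800;  // 10 ms frames

        const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta),
                                      seek_end - seek);
        printf("n_frames = %d (%d ms of audio)\n", n_frames, n_frames * 10); // 2500 -> 25000 ms
        return 0;
    }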
 
@@ -6132,7 +6617,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
 
     // multi-thread
 
-    for (uint32_t k = 1; k <= n_threads; k++) {
+    for (int32_t k = 1; k <= n_threads; k++) {
         char * src = (char *) malloc(size);
         char * dst = (char *) malloc(size);
 
@@ -6156,13 +6641,13 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
         const int64_t t0 = wsp_ggml_time_us();
 
         std::vector<std::thread> threads(k - 1);
-        for (uint32_t th = 0; th < k - 1; ++th) {
+        for (int32_t th = 0; th < k - 1; ++th) {
             threads[th] = std::thread(helper, th);
         }
 
         helper(k - 1);
 
-        for (uint32_t th = 0; th < k - 1; ++th) {
+        for (int32_t th = 0; th < k - 1; ++th) {
             threads[th].join();
         }
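The loop counters switch from uint32_t to int32_t, presumably so comparisons against the signed n_threads no longer mix signedness. Mixed comparisons convert the signed side to unsigned, which triggers -Wsign-compare and misbehaves on negative bounds, as this sketch shows:

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint32_t k = 1;
        int      n = -1;
        // the cast makes explicit what "k <= n" would do implicitly:
        // -1 becomes 4294967295u, so the comparison is true
        if (k <= (uint32_t) n) {
            printf("unsigned comparison: 1 <= -1 evaluates to true\n");
        }
        return 0;
    }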
 
@@ -6571,7 +7056,7 @@ static void whisper_exp_compute_token_level_timestamps(
                 k++;
             }
             tokens[j].t1 = sample_to_timestamp(k);
-            if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+            if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                 tokens[j].t1 = tokens[j + 1].t0;
             } else {
                 s1 = k;
@@ -6614,6 +7099,322 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
+//
+// token level timestamps - dtw version
+//
+
+// n_text_layer -> total text layers on model
+// n_head -> total heads per text layer on model
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
+    std::vector<uint32_t> ret;
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        return ret;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (il >= n_text_layer - cparams.dtw_n_top) {
+            for (int32_t i = 0; i < n_head; ++i) {
+                ret.push_back(i);
+            }
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer == il) {
+                ret.push_back(aheads.heads[i].n_head);
+            }
+        }
+    }
+    return ret;
+}
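get_alignment_heads_by_layer returns, for one text layer, the cross-attention heads whose QKs feed the DTW alignment. With the WHISPER_AHEADS_N_TOP_MOST preset every head of the top dtw_n_top layers qualifies; a sketch of that selection for a hypothetical 4-layer, 2-head model:

    #include <cstdio>

    int main() {
        const int n_text_layer = 4, n_head = 2, dtw_n_top = 2;
        for (int il = 0; il < n_text_layer; ++il) {
            for (int h = 0; h < n_head; ++h) {
                // N_TOP_MOST: all heads of the top dtw_n_top layers
                if (il >= n_text_layer - dtw_n_top) {
                    printf("layer %d head %d -> alignment head\n", il, h);
                }
            }
        }
        return 0; // layers 2 and 3 contribute; layers 0 and 1 do not
    }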
+
+// dtw + backtrace to return found path
+// based on
+// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
+static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tensor * x) {
+    WHISPER_ASSERT(wsp_ggml_n_dims(x) == 2);
+
+    int64_t N = x->ne[0];
+    int64_t M = x->ne[1];
+    struct wsp_ggml_tensor * cost  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
+    struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
+
+    cost  = wsp_ggml_set_f32(cost, INFINITY);
+    trace = wsp_ggml_set_f32(trace, -1);
+    wsp_ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+
+    // dtw
+    // supposedly can be optimized by computing diagonals in parallel?
+    // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
+    for (int64_t j = 1; j < M + 1; ++j) {
+        for (int64_t i = 1; i < N + 1; ++i) {
+            float c0 = wsp_ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = wsp_ggml_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = wsp_ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+
+            float c;
+            int32_t t;
+            if (c0 < c1 && c0 < c2) {
+                c = c0;
+                t = 0;
+            } else if (c1 < c0 && c1 < c2) {
+                c = c1;
+                t = 1;
+            } else {
+                c = c2;
+                t = 2;
+            }
+
+            c = wsp_ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            wsp_ggml_set_f32_nd(cost, i, j, 0, 0, c);
+            wsp_ggml_set_i32_nd(trace, i, j, 0, 0, t);
+        }
+    }
+
+    // Backtrace
+    const int64_t BT_MAX_ROWS = N + M - 1;
+    struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
+    // trace[0, :] = 2;
+    for (int64_t i = 0; i < M + 1; ++i)
+        wsp_ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+    //trace[:, 0] = 1;
+    for (int64_t i = 0; i < N + 1; ++i)
+        wsp_ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+    int bt_row_idx = BT_MAX_ROWS - 1;
+    int64_t i = N;
+    int64_t j = M;
+    while (i > 0 || j > 0) {
+        wsp_ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        wsp_ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+        --bt_row_idx;
+
+        int32_t t = wsp_ggml_get_i32_nd(trace, i, j, 0, 0);
+        if (t == 0) {
+            --i;
+            --j;
+        } else if (t == 1) {
+            --i;
+        } else if (t == 2) {
+            --j;
+        } else {
+            WHISPER_ASSERT(0);
+        }
+    }
+
+    // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
+    // Clip + transpose
+    // This might not be entirely necessary for our case, but leaving it for now so output matrix
+    // is identical to dtw on openAI timing.py
+    const int64_t result_n_cols = BT_MAX_ROWS - bt_row_idx - 1;
+    wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
+    for (int64_t i = 0; i < 2; ++i) {
+        for (int64_t j = 0; j < result_n_cols; ++j) {
+            int32_t v = wsp_ggml_get_i32_nd(bt, j + bt_row_idx + 1, i, 0, 0);
+            wsp_ggml_set_i32_nd(r, i, j, 0, 0, v);
+        }
+    }
+
+    return r;
+}
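dtw_and_backtrace ports the dynamic-time-warping step of OpenAI's timing.py: cell (i, j) accumulates x[i-1][j-1] plus the cheapest of its diagonal, upper, and left neighbours, while trace records which move won so the path can be walked back from (N, M). The same recurrence on plain arrays, stripped of the ggml tensor plumbing (illustrative sketch):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const int N = 3, M = 4; // e.g. 3 text tokens x 4 audio positions
        const float x[3][4] = { {1, 2, 3, 4}, {2, 1, 2, 3}, {3, 2, 1, 1} };

        std::vector<std::vector<float>> cost(N + 1, std::vector<float>(M + 1, 1e30f));
        cost[0][0] = 0.0f;

        for (int j = 1; j <= M; ++j) {
            for (int i = 1; i <= N; ++i) {
                // cheapest of diagonal, up, left - exactly the recurrence above
                const float best = std::min({cost[i-1][j-1], cost[i-1][j], cost[i][j-1]});
                cost[i][j] = x[i-1][j-1] + best;
            }
        }
        printf("alignment cost = %.1f\n", cost[N][M]); // 4.0 for this matrix
        return 0;
    }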
+
+struct median_filter_user_data {
+    int filter_width;
+};
+
+static void median_filter(struct wsp_ggml_tensor * dst, const struct wsp_ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+    if (ith != 0) {
+        return;
+    }
+    int filter_width = ((median_filter_user_data *) userdata)->filter_width;
+    WHISPER_ASSERT(filter_width < a->ne[2]);
+    WHISPER_ASSERT(filter_width % 2);
+    WHISPER_ASSERT(wsp_ggml_n_dims(a) == 3);
+    WHISPER_ASSERT(a->type == WSP_GGML_TYPE_F32);
+
+    std::vector<float> filter;
+    filter.reserve(filter_width);
+    for (int64_t i = 0; i < a->ne[0]; ++i) {
+        for (int64_t j = 0; j < a->ne[1]; ++j) {
+            for (int64_t k = 0; k < a->ne[2]; ++k) {
+                for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
+                    // "reflect" padding
+                    int64_t idx = k + off;
+                    if (idx < 0) {
+                        idx = -idx;
+                    } else if (idx >= a->ne[2]) {
+                        idx = 2*(a->ne[2] - 1) - idx;
+                    }
+
+                    filter.push_back(wsp_ggml_get_f32_nd(a, i, j, idx, 0));
+                }
+                std::sort(filter.begin(), filter.end());
+                const float v = filter[filter.size()/2];
+                wsp_ggml_set_f32_nd(dst, i, j, k, 0, v);
+                filter.clear();
+            }
+        }
+    }
+}
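The custom median_filter op pads by reflection so the window never leaves the tensor: index -1 maps back to 1 and index n maps to n - 2. The index mapping in isolation, for a length-5 axis:

    #include <cstdio>

    int main() {
        const int n = 5;
        for (int idx = -2; idx <= 6; ++idx) {
            int r = idx;
            if (r < 0) {
                r = -r;             // reflect around element 0
            } else if (r >= n) {
                r = 2*(n - 1) - r;  // reflect around element n-1
            }
            printf("%2d -> %d\n", idx, r); // -2->2, -1->1, 5->3, 6->2
        }
        return 0;
    }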
+
+static void whisper_exp_compute_token_level_timestamps_dtw(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+     struct whisper_full_params   params,
+                            int   i_segment,
+                         size_t   n_segments,
+                            int   seek,
+                            int   n_frames,
+                            int   medfilt_width,
+                            int   n_threads)
+{
+    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
+    WHISPER_ASSERT(medfilt_width % 2);
+    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
+    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
+
+    // FIXME: Allocating mem every time we call this func
+    // Our ggml buffer should be pre-allocated somewhere during init and reused
+    // when we call this function
+    struct wsp_ggml_init_params gparams = {
+        /*.mem_size   =*/ ctx->params.dtw_mem_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+    struct wsp_ggml_context * gctx = wsp_ggml_init(gparams);
+
+    // Build token sequence that will be passed to decoder
+    // sot + [lang] + text result + eot
+    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
+    if (whisper_is_multilingual(ctx)) {
+        const int lang_id = whisper_lang_id(params.language);
+        state->lang_id = lang_id;
+        tokens.push_back(whisper_token_lang(ctx, lang_id));
+    }
+    const size_t sot_sequence_length = tokens.size();
+    tokens.push_back(whisper_token_not(ctx));
+    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            // Only text tokens
+            if (t.id < whisper_token_eot(ctx)) {
+                tokens.push_back(t.id);
+            }
+        }
+    }
+    tokens.push_back(whisper_token_eot(ctx));
+
+    // Get result tokens, pass them along to decoder to get cross-attention QKs
+    // used in timestamping
+    // Decoder already returns only alignment head QKs, already concatenated in
+    // one tensor.
+    whisper_kv_cache_clear(state->kv_self);
+    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
+    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
+        WHISPER_LOG_INFO("DECODER FAILED\n");
+        WHISPER_ASSERT(0);
+    }
+    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
+
+    const auto n_audio_tokens = n_frames/2;
+    WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
+    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
+    const auto n_tokens = state->aheads_cross_QKs->ne[0];
+    const auto n_heads  = state->aheads_cross_QKs->ne[2];
+
+    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
+    // tokens (i.e. discarding rows at the end of tensor)
+    // IN:  Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    WHISPER_ASSERT(state->aheads_cross_QKs->type == WSP_GGML_TYPE_F32);
+    WHISPER_ASSERT(wsp_ggml_is_contiguous(state->aheads_cross_QKs));
+    wsp_ggml_tensor * w = wsp_ggml_new_tensor_3d(gctx, WSP_GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+    auto & data = state->aheads_cross_QKs_data;
+    data.resize(n_tokens * n_audio_ctx * n_heads);
+    wsp_ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
+    for (int k = 0; k < n_heads; ++k) {
+        for (int j = 0; j < n_audio_tokens; ++j) {
+            memcpy(
+                (char *) w->data + j * w->nb[1] + k * w->nb[2],
+                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
+                n_tokens * sizeof(float)
+            );
+        }
+    }
+
+    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
+    // we already permuted N_TOKENS dimension to columns on last loop, because wsp_ggml_norm
+    // operates over columns. Afterwards, permute to a shape that facilitates mean
+    // operation (after median filter)
+    // IN:  Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    w = wsp_ggml_norm(gctx, w, 1e-9f);
+    w = wsp_ggml_permute(gctx, wsp_ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);
+
+    // Pass median filter - this is done over AUDIO_TOKENS dimension.
+    // IN:  Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Same dims
+    median_filter_user_data mf_user_data = { medfilt_width };
+    w = wsp_ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
+
+    // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
+    // IN:  Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
+    w = wsp_ggml_mean(gctx, w);
+    w = wsp_ggml_scale(gctx, w, -1.0);
+    w = wsp_ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
+
+    // Remove SOT sequence and EOT
+    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
+    w = wsp_ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
+
+    // Compute
+    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
+    wsp_ggml_build_forward_expand(gf, w);
+    wsp_ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+    wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
+
+    // Place timestamps on segments
+    int32_t last_v = 0;
+    auto seg_i = state->result_all.begin() + i_segment;
+    auto tok_i = seg_i->tokens.begin();
+    for (int i = 0; i < alignment->ne[1]; ++i) {
+        int32_t v = wsp_ggml_get_i32_nd(alignment, 0, i, 0, 0);
+        if (v != last_v) {
+            int32_t time_index = wsp_ggml_get_i32_nd(alignment, 1, i, 0, 0);
+            int64_t timestamp = (time_index * 2) + seek; // Each index on the DTW result = 20 ms of audio
+            last_v = v;
+
+            // Skip non-text tokens
+            while (!(tok_i->id < whisper_token_eot(ctx))) {
+                ++tok_i;
+                if (tok_i == seg_i->tokens.end()) {
+                    ++seg_i;
+                    tok_i = seg_i->tokens.begin();
+                }
+            }
+
+            tok_i->t_dtw = timestamp;
+            ++tok_i;
+            if (tok_i == seg_i->tokens.end()) {
+                ++seg_i;
+                tok_i = seg_i->tokens.begin();
+            }
+        }
+    }
+
+    // Print DTW timestamps
+    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            const char * tok = whisper_token_to_str(ctx, t.id);
+            fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
+        }
+        fprintf(stderr, "\n");
+    }*/
+
+    wsp_ggml_free(gctx);
+}
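Callers opt in to this experimental path through the context parameters declared in whisper.h. A hedged usage sketch (model path, preset, and dtw_n_top value chosen purely for illustration):

    #include "whisper.h"

    struct whisper_context * init_with_dtw(const char * model_path) {
        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.dtw_token_timestamps = true;                      // enable the path above
        cparams.dtw_aheads_preset    = WHISPER_AHEADS_N_TOP_MOST; // or a per-model preset
        cparams.dtw_n_top            = 4;
        return whisper_init_from_file_with_params(model_path, cparams);
        // after whisper_full(), each token's t_dtw holds the DTW timestamp in 10 ms units
    }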
+
 void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;