whisper.rn 0.4.0-rc.8 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/ggml-aarch64.c +3209 -0
  9. package/cpp/ggml-aarch64.h +39 -0
  10. package/cpp/ggml-alloc.c +725 -517
  11. package/cpp/ggml-alloc.h +47 -65
  12. package/cpp/ggml-backend-impl.h +166 -55
  13. package/cpp/ggml-backend.cpp +2635 -0
  14. package/cpp/ggml-backend.h +202 -85
  15. package/cpp/ggml-common.h +1853 -0
  16. package/cpp/ggml-cpu-impl.h +614 -0
  17. package/cpp/ggml-impl.h +143 -180
  18. package/cpp/ggml-metal.h +13 -11
  19. package/cpp/ggml-metal.m +2955 -1632
  20. package/cpp/ggml-quants.c +9824 -3263
  21. package/cpp/ggml-quants.h +133 -248
  22. package/cpp/ggml-whisper.metallib +0 -0
  23. package/cpp/ggml.c +8482 -5142
  24. package/cpp/ggml.h +633 -349
  25. package/cpp/rn-whisper.cpp +91 -0
  26. package/cpp/rn-whisper.h +2 -0
  27. package/cpp/whisper.cpp +1427 -658
  28. package/cpp/whisper.h +84 -28
  29. package/ios/RNWhisper.mm +124 -37
  30. package/ios/RNWhisperAudioUtils.h +1 -0
  31. package/ios/RNWhisperAudioUtils.m +20 -13
  32. package/ios/RNWhisperContext.h +3 -2
  33. package/ios/RNWhisperContext.mm +39 -7
  34. package/jest/mock.js +9 -1
  35. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  36. package/lib/commonjs/index.js +48 -19
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/NativeRNWhisper.js.map +1 -1
  40. package/lib/module/index.js +48 -19
  41. package/lib/module/index.js.map +1 -1
  42. package/lib/module/version.json +1 -1
  43. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  44. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  45. package/lib/typescript/index.d.ts +25 -3
  46. package/lib/typescript/index.d.ts.map +1 -1
  47. package/package.json +6 -5
  48. package/src/NativeRNWhisper.ts +12 -3
  49. package/src/index.ts +63 -24
  50. package/src/version.json +1 -1
  51. package/whisper-rn.podspec +9 -2
  52. package/cpp/ggml-backend.c +0 -1718
  53. package/cpp/ggml-metal-whisper.metal +0 -5820
package/cpp/whisper.cpp CHANGED
@@ -8,14 +8,30 @@
8
8
  #include "ggml-metal.h"
9
9
  #endif
10
10
 
11
- #ifdef WSP_GGML_USE_CUBLAS
11
+ #ifdef WSP_GGML_USE_CUDA
12
12
  #include "ggml-cuda.h"
13
13
  #endif
14
14
 
15
+ #ifdef WSP_GGML_USE_SYCL
16
+ #include "ggml-sycl.h"
17
+ #endif
18
+
19
+ #ifdef WSP_GGML_USE_VULKAN
20
+ #include "ggml-vulkan.h"
21
+ #endif
22
+
23
+ #ifdef WSP_GGML_USE_BLAS
24
+ #include "ggml-blas.h"
25
+ #endif
26
+
15
27
  #ifdef WHISPER_USE_OPENVINO
16
28
  #include "openvino/whisper-openvino-encoder.h"
17
29
  #endif
18
30
 
31
+ #ifdef WSP_GGML_USE_CANN
32
+ #include "ggml-cann.h"
33
+ #endif
34
+
19
35
  #include "ggml.h"
20
36
  #include "ggml-alloc.h"
21
37
  #include "ggml-backend.h"
@@ -37,6 +53,7 @@
37
53
  #include <regex>
38
54
  #include <random>
39
55
  #include <functional>
56
+ #include <codecvt>
40
57
 
41
58
  #if defined(_MSC_VER)
42
59
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -143,8 +160,6 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
143
160
  } \
144
161
  } while (0)
145
162
 
146
- //#define WHISPER_USE_FLASH_ATTN
147
- //#define WHISPER_USE_FLASH_FF
148
163
  #define WHISPER_MAX_DECODERS 8
149
164
  #define WHISPER_MAX_NODES 4096
150
165
 
@@ -156,11 +171,11 @@ static bool wsp_ggml_graph_compute_helper(
156
171
  struct wsp_ggml_cgraph * graph,
157
172
  std::vector<uint8_t> & buf,
158
173
  int n_threads,
159
- whisper_abort_callback abort_callback,
174
+ wsp_ggml_abort_callback abort_callback,
160
175
  void * abort_callback_data) {
161
- struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
176
+ struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads, nullptr);
162
177
 
163
- plan.abort_callback = abort_callback;
178
+ plan.abort_callback = abort_callback;
164
179
  plan.abort_callback_data = abort_callback_data;
165
180
 
166
181
  if (plan.work_size > 0) {
@@ -172,18 +187,25 @@ static bool wsp_ggml_graph_compute_helper(
172
187
  }
173
188
 
174
189
  static bool wsp_ggml_graph_compute_helper(
175
- struct wsp_ggml_backend * backend,
190
+ wsp_ggml_backend_sched_t sched,
176
191
  struct wsp_ggml_cgraph * graph,
177
192
  int n_threads) {
178
- if (wsp_ggml_backend_is_cpu(backend)) {
179
- wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
180
- }
181
- #ifdef WSP_GGML_USE_METAL
182
- if (wsp_ggml_backend_is_metal(backend)) {
183
- wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
184
- }
193
+
194
+ for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
195
+ wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
196
+ if (wsp_ggml_backend_is_cpu(backend)) {
197
+ wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
198
+ }
199
+ #ifdef WSP_GGML_USE_BLAS
200
+ if (wsp_ggml_backend_is_blas(backend)) {
201
+ wsp_ggml_backend_blas_set_n_threads(backend, n_threads);
202
+ }
185
203
  #endif
186
- return wsp_ggml_backend_graph_compute(backend, graph);
204
+ }
205
+
206
+ bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
207
+ wsp_ggml_backend_sched_reset(sched);
208
+ return t;
187
209
  }
188
210
 
189
211
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -347,6 +369,37 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
347
369
  { "yue", { 99, "cantonese", } },
348
370
  };
349
371
 
372
+ // [EXPERIMENTAL] Token-level timestamps with DTW
373
+ static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
374
+ static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
375
+ static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
376
+ static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
377
+ static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
378
+ static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
379
+ static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
380
+ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
381
+ static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
382
+ static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
383
+ static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
384
+ static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
385
+
386
+ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
387
+ { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
388
+ { WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
389
+ { WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
390
+ { WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
391
+ { WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
392
+ { WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
393
+ { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
394
+ { WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
395
+ { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
396
+ { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
397
+ { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
398
+ { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
399
+ };
400
+
401
+ static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
402
+
350
403
  struct whisper_mel {
351
404
  int n_len;
352
405
  int n_len_org;
@@ -409,7 +462,7 @@ struct whisper_batch {
409
462
 
410
463
  whisper_token * token;
411
464
  whisper_pos * pos;
412
- int32_t * n_seq_id;
465
+ int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
413
466
  whisper_seq_id ** seq_id; // null terminated
414
467
  int8_t * logits;
415
468
  };
@@ -469,54 +522,42 @@ struct whisper_pair {
469
522
  whisper_pair() : first(A()), second(B()) {}
470
523
  };
471
524
 
472
- // wsp_ggml_allocr wrapper for whisper usage
473
- struct whisper_allocr {
474
- wsp_ggml_allocr * alloc = nullptr;
525
+ // wsp_ggml_backend_sched wrapper for whisper usage
526
+ struct whisper_sched {
527
+ wsp_ggml_backend_sched_t sched = nullptr;
475
528
 
476
529
  std::vector<uint8_t> meta;
477
-
478
- wsp_ggml_backend_buffer_t buffer;
479
530
  };
480
531
 
481
- static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
482
- return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
532
+ static size_t whisper_sched_size(struct whisper_sched & allocr) {
533
+ size_t size = allocr.meta.size();
534
+ for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
535
+ wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(allocr.sched, i);
536
+ size += wsp_ggml_backend_sched_get_buffer_size(allocr.sched, backend);
537
+ }
538
+ return size;
483
539
  }
484
540
 
485
541
  // measure the memory usage of a graph and prepare the allocr's internal data buffer
486
- static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
487
- auto & alloc = allocr.alloc;
542
+ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<wsp_ggml_backend_t> backends, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
543
+ auto & sched = allocr.sched;
488
544
  auto & meta = allocr.meta;
489
545
 
490
- alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
546
+ sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
491
547
 
492
548
  meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
493
549
 
494
- wsp_ggml_allocr_alloc_graph(alloc, get_graph());
495
- }
496
-
497
- static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
498
- if (allocr.alloc == nullptr) {
499
- // this can be null if we use external encoder like CoreML or OpenVINO
500
- return;
550
+ // since there are dependencies between the different graphs,
551
+ // we need to allocate them instead of only reserving to get the correct compute buffer size
552
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, get_graph())) {
553
+ // failed to allocate the compute buffer
554
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
555
+ return false;
501
556
  }
502
557
 
503
- auto & alloc = allocr.alloc;
504
- auto & buffer = allocr.buffer;
505
-
506
- size_t size = wsp_ggml_allocr_max_size(alloc);
507
-
508
- wsp_ggml_allocr_free(alloc);
558
+ wsp_ggml_backend_sched_reset(sched);
509
559
 
510
- buffer = wsp_ggml_backend_alloc_buffer(backend, size);
511
- alloc = wsp_ggml_allocr_new_from_buffer(buffer);
512
- }
513
-
514
- static void whisper_allocr_free(struct whisper_allocr & allocr) {
515
- if (allocr.alloc) {
516
- wsp_ggml_allocr_free(allocr.alloc);
517
- wsp_ggml_backend_buffer_free(allocr.buffer);
518
- allocr.alloc = nullptr;
519
- }
560
+ return true;
520
561
  }
521
562
 
522
563
  // medium
@@ -658,9 +699,9 @@ struct whisper_kv_cache {
658
699
  struct wsp_ggml_tensor * k;
659
700
  struct wsp_ggml_tensor * v;
660
701
 
661
- struct wsp_ggml_context * ctx;
702
+ wsp_ggml_backend_buffer_t buffer = nullptr;
662
703
 
663
- wsp_ggml_backend_buffer_t buffer;
704
+ std::vector<uint8_t> ctx_buf;
664
705
  };
665
706
 
666
707
  struct whisper_model {
@@ -698,10 +739,10 @@ struct whisper_model {
698
739
  std::vector<whisper_layer_decoder> layers_decoder;
699
740
 
700
741
  // ggml context that contains all the meta information about the model tensors
701
- struct wsp_ggml_context * ctx;
742
+ struct wsp_ggml_context * ctx = nullptr;
702
743
 
703
744
  // the model backend data is read-only and can be shared between processors
704
- std::vector<struct wsp_ggml_backend_buffer *> buffers;
745
+ wsp_ggml_backend_buffer_t buffer = nullptr;
705
746
 
706
747
  // tensors
707
748
  int n_loaded;
@@ -766,6 +807,13 @@ struct whisper_decoder {
766
807
  mutable std::mt19937 rng; // used for sampling at t > 0.0
767
808
  };
768
809
 
810
+ // [EXPERIMENTAL] Token-level timestamps with DTW
811
+ struct whisper_aheads_masks {
812
+ std::vector<struct wsp_ggml_tensor *> m; // One mask per text layer.
813
+ struct wsp_ggml_context * ctx = nullptr;
814
+ wsp_ggml_backend_buffer_t buffer = nullptr;
815
+ };
816
+
769
817
  struct whisper_state {
770
818
  int64_t t_sample_us = 0;
771
819
  int64_t t_encode_us = 0;
@@ -782,6 +830,9 @@ struct whisper_state {
782
830
  int32_t n_fail_p = 0; // number of logprob threshold failures
783
831
  int32_t n_fail_h = 0; // number of entropy threshold failures
784
832
 
833
+ // number of decoders for which we have constructed the KV cache
834
+ int32_t kv_self_n_dec = 0;
835
+
785
836
  // unified self-attention KV cache for all decoders
786
837
  whisper_kv_cache kv_self;
787
838
 
@@ -789,21 +840,22 @@ struct whisper_state {
789
840
  // shared between all decoders
790
841
  whisper_kv_cache kv_cross;
791
842
 
843
+ // padded buffer for flash-attention
844
+ whisper_kv_cache kv_pad;
845
+
792
846
  whisper_mel mel;
793
847
 
794
848
  whisper_batch batch;
795
849
 
796
850
  whisper_decoder decoders[WHISPER_MAX_DECODERS];
797
851
 
798
- wsp_ggml_backend_t backend = nullptr;
852
+ std::vector<wsp_ggml_backend_t> backends;
799
853
 
800
- // ggml-alloc:
801
854
  // - stores meta info about the intermediate tensors into the `meta` buffers
802
- // - stores the actual tensor data into the `data` buffers
803
- whisper_allocr alloc_conv;
804
- whisper_allocr alloc_encode;
805
- whisper_allocr alloc_cross;
806
- whisper_allocr alloc_decode;
855
+ whisper_sched sched_conv;
856
+ whisper_sched sched_encode;
857
+ whisper_sched sched_cross;
858
+ whisper_sched sched_decode;
807
859
 
808
860
  // result of the encoder
809
861
  struct wsp_ggml_tensor * embd_conv = nullptr;
@@ -839,6 +891,11 @@ struct whisper_state {
839
891
 
840
892
  std::vector<float> energy; // PCM signal energy
841
893
 
894
+ // [EXPERIMENTAL] Token-level timestamps with DTW
895
+ whisper_aheads_masks aheads_masks;
896
+ wsp_ggml_tensor * aheads_cross_QKs = nullptr;
897
+ std::vector<float> aheads_cross_QKs_data;
898
+
842
899
  // [EXPERIMENTAL] speed-up techniques
843
900
  int32_t exp_n_audio_ctx = 0; // 0 - use default
844
901
  };
@@ -857,8 +914,6 @@ struct whisper_context {
857
914
 
858
915
  whisper_state * state = nullptr;
859
916
 
860
- wsp_ggml_backend_t backend = nullptr;
861
-
862
917
  std::string path_model; // populated by whisper_init_from_file_with_params()
863
918
  };
864
919
 
@@ -876,21 +931,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
876
931
  BYTESWAP_VALUE(dest);
877
932
  }
878
933
 
879
- static bool kv_cache_init(
880
- const struct whisper_hparams & hparams,
934
+ static bool whisper_kv_cache_init(
881
935
  struct whisper_kv_cache & cache,
882
936
  wsp_ggml_backend_t backend,
883
937
  wsp_ggml_type wtype,
938
+ int64_t n_text_state,
939
+ int64_t n_text_layer,
884
940
  int n_ctx) {
885
- const int64_t n_text_state = hparams.n_text_state;
886
- const int64_t n_text_layer = hparams.n_text_layer;
887
-
888
941
  const int64_t n_mem = n_text_layer*n_ctx;
889
942
  const int64_t n_elements = n_text_state*n_mem;
890
943
 
944
+ cache.ctx_buf.resize(2*wsp_ggml_tensor_overhead());
945
+
891
946
  struct wsp_ggml_init_params params = {
892
- /*.mem_size =*/ 2*wsp_ggml_tensor_overhead(),
893
- /*.mem_buffer =*/ nullptr,
947
+ /*.mem_size =*/ cache.ctx_buf.size(),
948
+ /*.mem_buffer =*/ cache.ctx_buf.data(),
894
949
  /*.no_alloc =*/ true,
895
950
  };
896
951
 
@@ -900,39 +955,31 @@ static bool kv_cache_init(
900
955
  cache.cells.clear();
901
956
  cache.cells.resize(n_ctx);
902
957
 
903
- cache.ctx = wsp_ggml_init(params);
958
+ struct wsp_ggml_context * ctx = wsp_ggml_init(params);
904
959
 
905
- if (!cache.ctx) {
906
- WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
960
+ if (!ctx) {
961
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
907
962
  return false;
908
963
  }
909
964
 
910
- cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
911
- cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
965
+ cache.k = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
966
+ cache.v = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
912
967
 
913
- const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
914
-
915
- cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
916
-
917
- // allocate the tensors into the backend buffer
918
- {
919
- wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
968
+ cache.buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, backend);
969
+ if (!cache.buffer) {
970
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
971
+ return false;
972
+ }
920
973
 
921
- wsp_ggml_allocr_alloc(alloc, cache.k);
922
- wsp_ggml_allocr_alloc(alloc, cache.v);
974
+ wsp_ggml_backend_buffer_clear(cache.buffer, 0);
923
975
 
924
- wsp_ggml_allocr_free(alloc);
925
- }
976
+ wsp_ggml_free(ctx);
926
977
 
927
978
  return true;
928
979
  }
929
980
 
930
- static void kv_cache_free(struct whisper_kv_cache & cache) {
931
- if (cache.ctx) {
932
- wsp_ggml_free(cache.ctx);
933
- wsp_ggml_backend_buffer_free(cache.buffer);
934
- cache.ctx = nullptr;
935
- }
981
+ static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
982
+ wsp_ggml_backend_buffer_free(cache.buffer);
936
983
  }
937
984
 
938
985
  static bool whisper_kv_cache_find_slot(
@@ -1003,6 +1050,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
1003
1050
  cache.cells[i].seq_id.clear();
1004
1051
  }
1005
1052
  cache.head = 0;
1053
+
1054
+ wsp_ggml_backend_buffer_clear(cache.buffer, 0);
1006
1055
  }
1007
1056
 
1008
1057
  static void whisper_kv_cache_seq_rm(
@@ -1053,15 +1102,167 @@ static void whisper_kv_cache_seq_cp(
1053
1102
  }
1054
1103
  }
1055
1104
 
1056
- static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
1057
- wsp_ggml_backend_t backend_gpu = NULL;
1105
+ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
1106
+ if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
1107
+ return 1u;
1108
+ }
1109
+
1110
+ #ifdef WSP_GGML_USE_METAL
1111
+ if (wctx.params.use_gpu) {
1112
+ return 32u;
1113
+ }
1114
+ #endif
1058
1115
 
1059
- // initialize the backends
1060
- #ifdef WSP_GGML_USE_CUBLAS
1061
- if (params.use_gpu && wsp_ggml_cublas_loaded()) {
1116
+ #ifdef WSP_GGML_USE_CUDA
1117
+ if (wctx.params.use_gpu) {
1118
+ return 256u;
1119
+ }
1120
+ #endif
1121
+
1122
+ return 1u;
1123
+ }
1124
+
1125
+ // [EXPERIMENTAL] Token-level timestamps with DTW
1126
+ static bool aheads_masks_init(
1127
+ const whisper_context_params & cparams,
1128
+ const whisper_hparams & hparams,
1129
+ struct whisper_aheads_masks & aheads_masks,
1130
+ wsp_ggml_backend_t backend) {
1131
+
1132
+ const int32_t n_text_layer = hparams.n_text_layer;
1133
+ const int32_t n_head = hparams.n_text_head;
1134
+
1135
+ // Sanity checks
1136
+ if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
1137
+ WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
1138
+ return false;
1139
+ } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
1140
+ if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
1141
+ WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
1142
+ return false;
1143
+ }
1144
+ } else {
1145
+ const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
1146
+ if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
1147
+ if (aheads.n_heads == 0) {
1148
+ WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
1149
+ return false;
1150
+ }
1151
+ if (aheads.heads == NULL) {
1152
+ WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
1153
+ return false;
1154
+ }
1155
+ }
1156
+ for (size_t i = 0; i < aheads.n_heads; ++i) {
1157
+ if (aheads.heads[i].n_text_layer >= n_text_layer) {
1158
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
1159
+ return false;
1160
+ }
1161
+ if (aheads.heads[i].n_text_layer < 0) {
1162
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
1163
+ return false;
1164
+ }
1165
+ if (aheads.heads[i].n_head >= n_head) {
1166
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
1167
+ return false;
1168
+ }
1169
+ if (aheads.heads[i].n_head < 0) {
1170
+ WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
1171
+ return false;
1172
+ }
1173
+ }
1174
+ }
1175
+
1176
+ struct wsp_ggml_init_params params = {
1177
+ /*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*wsp_ggml_tensor_overhead(),
1178
+ /*.mem_buffer =*/ nullptr,
1179
+ /*.no_alloc =*/ true,
1180
+ };
1181
+
1182
+ aheads_masks.ctx = wsp_ggml_init(params);
1183
+
1184
+ if (!aheads_masks.ctx) {
1185
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
1186
+ return false;
1187
+ }
1188
+
1189
+ for (int64_t il = 0; il < n_text_layer; ++il) {
1190
+ auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
1191
+ if (!aheads.empty()) {
1192
+ aheads_masks.m.push_back(wsp_ggml_new_tensor_2d(aheads_masks.ctx, WSP_GGML_TYPE_F32, n_head, aheads.size()));
1193
+ } else {
1194
+ aheads_masks.m.push_back(nullptr);
1195
+ }
1196
+ }
1197
+
1198
+ aheads_masks.buffer = wsp_ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
1199
+ if (!aheads_masks.buffer) {
1200
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
1201
+ return false;
1202
+ }
1203
+
1204
+ // Set data on mask tensors
1205
+ // Since this must be backend agnostic, we write our desired values on mask_data,
1206
+ // and send it to backend with wsp_ggml_backend_tensor_set.
1207
+ // Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment
1208
+ // heads. Each row of the mask "marks" one alignment head. E.g. if some text layer
1209
+ // has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask
1210
+ // should read:
1211
+ // 1 0 0 0 0 0 0 0 0 0
1212
+ // 0 0 0 0 0 1 0 0 0 0
1213
+ // 0 0 0 0 0 0 1 0 0 0
1214
+ std::vector<float> mask_data;
1215
+ for (int64_t il = 0; il < n_text_layer; ++il) {
1216
+ if (aheads_masks.m[il] != nullptr) {
1217
+ auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
1218
+
1219
+ size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1];
1220
+ size_t data_size_bytes = data_size * sizeof(float);
1221
+ mask_data.resize(data_size);
1222
+
1223
+ std::fill(mask_data.begin(), mask_data.end(), 0);
1224
+ for (size_t ih = 0; ih < aheads.size(); ++ih) {
1225
+ size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
1226
+ mask_data[pos] = 1.0f;
1227
+ }
1228
+
1229
+ wsp_ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes);
1230
+ }
1231
+ }
1232
+
1233
+ if (aheads_masks.m.empty()) {
1234
+ WHISPER_LOG_ERROR("%s: \n", __func__);
1235
+ return false;
1236
+ }
1237
+
1238
+ return true;
1239
+ }
1240
+
1241
+ static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
1242
+ wsp_ggml_free(aheads_masks.ctx);
1243
+ wsp_ggml_backend_buffer_free(aheads_masks.buffer);
1244
+ aheads_masks.ctx = nullptr;
1245
+ }
1246
+
1247
+ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
1248
+ size_t size = 0;
1249
+ for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
1250
+ if (aheads_masks.m[i] != nullptr)
1251
+ size += wsp_ggml_nbytes(aheads_masks.m[i]);
1252
+ }
1253
+ return size;
1254
+ }
1255
+
1256
+ static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
1257
+ wsp_ggml_backend_t result = NULL;
1258
+
1259
+ wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
1260
+
1261
+ #ifdef WSP_GGML_USE_CUDA
1262
+ if (params.use_gpu) {
1062
1263
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1063
- backend_gpu = wsp_ggml_backend_cuda_init(0);
1064
- if (!backend_gpu) {
1264
+ result = wsp_ggml_backend_cuda_init(params.gpu_device);
1265
+ if (!result) {
1065
1266
  WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
1066
1267
  }
1067
1268
  }
@@ -1070,22 +1271,108 @@ static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & pa
1070
1271
  #ifdef WSP_GGML_USE_METAL
1071
1272
  if (params.use_gpu) {
1072
1273
  WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
1073
- wsp_ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
1074
- backend_gpu = wsp_ggml_backend_metal_init();
1075
- if (!backend_gpu) {
1274
+ result = wsp_ggml_backend_metal_init();
1275
+ if (!result) {
1076
1276
  WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
1077
- } else if (!wsp_ggml_backend_metal_supports_family(backend_gpu, 7)) {
1277
+ } else if (!wsp_ggml_backend_metal_supports_family(result, 7)) {
1078
1278
  WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
1079
- wsp_ggml_backend_free(backend_gpu);
1080
- backend_gpu = NULL;
1279
+ wsp_ggml_backend_free(result);
1280
+ result = NULL;
1081
1281
  }
1082
1282
  }
1083
1283
  #endif
1084
1284
 
1285
+ #ifdef WSP_GGML_USE_SYCL
1286
+ if (params.use_gpu) {
1287
+ WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
1288
+ result = wsp_ggml_backend_sycl_init(params.gpu_device);
1289
+ if (!result) {
1290
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
1291
+ }
1292
+ }
1293
+ #endif
1294
+
1295
+ #ifdef WSP_GGML_USE_VULKAN
1296
+ if (params.use_gpu) {
1297
+ WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
1298
+ result = wsp_ggml_backend_vk_init(params.gpu_device);
1299
+ if (!result) {
1300
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
1301
+ }
1302
+ }
1303
+ #endif
1304
+
1305
+ #ifdef WSP_GGML_USE_CANN
1306
+ if (params.use_gpu) {
1307
+ WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
1308
+ result = wsp_ggml_backend_cann_init(params.gpu_device);
1309
+ if (!result) {
1310
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
1311
+ }
1312
+ }
1313
+ #endif
1314
+
1315
+ WSP_GGML_UNUSED(params);
1316
+
1317
+ return result;
1318
+ }
1319
+
1320
+ static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
1321
+ std::vector<wsp_ggml_backend_t> result;
1322
+
1323
+ wsp_ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
1324
+
1085
1325
  if (backend_gpu) {
1086
- return backend_gpu;
1326
+ result.push_back(backend_gpu);
1327
+ }
1328
+
1329
+ #ifdef WSP_GGML_USE_BLAS
1330
+ {
1331
+ WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
1332
+ wsp_ggml_backend_t backend_blas = wsp_ggml_backend_blas_init();
1333
+ if (!backend_blas) {
1334
+ WHISPER_LOG_ERROR("%s: wsp_ggml_backend_blas_init() failed\n", __func__);
1335
+ } else {
1336
+ result.push_back(backend_blas);
1337
+ }
1087
1338
  }
1088
- return wsp_ggml_backend_cpu_init();
1339
+ #endif
1340
+
1341
+ WSP_GGML_UNUSED(params);
1342
+
1343
+ result.push_back(wsp_ggml_backend_cpu_init());
1344
+
1345
+ return result;
1346
+ }
1347
+
1348
+ static wsp_ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
1349
+ wsp_ggml_backend_buffer_type_t result = nullptr;
1350
+
1351
+ params.use_gpu || (result = wsp_ggml_backend_cpu_buffer_type());
1352
+
1353
+ #ifdef WSP_GGML_USE_CUDA
1354
+ result || (result = wsp_ggml_backend_cuda_buffer_type(params.gpu_device));
1355
+ #endif
1356
+
1357
+ #ifdef WSP_GGML_USE_METAL
1358
+ result || (result = wsp_ggml_backend_metal_buffer_type());
1359
+ #endif
1360
+
1361
+ #ifdef WSP_GGML_USE_SYCL
1362
+ result || (result = wsp_ggml_backend_sycl_buffer_type(params.gpu_device));
1363
+ #endif
1364
+
1365
+ #ifdef WSP_GGML_USE_VULKAN
1366
+ result || (result = wsp_ggml_backend_vk_buffer_type(params.gpu_device));
1367
+ #endif
1368
+
1369
+ #ifdef WSP_GGML_USE_CANN
1370
+ result || (result == wsp_ggml_backend_cann_buffer_type(params.gpu_device));
1371
+ #endif
1372
+
1373
+ result || (result = wsp_ggml_backend_cpu_buffer_type());
1374
+
1375
+ return result;
1089
1376
  }
1090
1377
 
1091
1378
  // load the model from a ggml file
@@ -1512,69 +1799,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1512
1799
  }
1513
1800
  }
1514
1801
 
1515
- wctx.backend = whisper_backend_init(wctx.params);
1516
-
1517
- // some devices have a limit on the maximum size of single memory buffer
1518
- // for example, iPhones are limited to 1GB per buffer
1519
- // to workaround this, we will allocate multiple buffers of smaller size and will split the tensors with the
1520
- // model weights between them
1521
- //
1522
- // the map_t2b maps tensor names to buffer indices
1523
- // as we iterate over the tensors, we will allocate new buffers when the current one is full
1524
- //
1525
- // finally, we create a separate allocator for each buffer and use it to allocate the tensors
1526
- // we keep the allocators alive until all the tensors are loaded
1527
-
1528
- WSP_GGML_ASSERT(model.buffers.empty());
1529
-
1530
- std::map<std::string, int> map_t2b;
1531
-
1532
- {
1533
- size_t size_main = 0;
1534
- size_t size_cur = 0;
1535
-
1536
- static const size_t GB = 1024ull*1024ull*1024ull;
1537
-
1538
- for (const auto & t : model.tensors) {
1539
- const size_t cur = wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
1540
-
1541
- // adding the tensor to the current buffer will exceed the limit, so we need to allocate a new buffer
1542
- if (size_cur + cur > GB) {
1543
- WSP_GGML_ASSERT(size_cur > 0 && "A tensor is too large to fit in a single buffer");
1544
-
1545
- model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur));
1546
-
1547
- size_cur = cur;
1548
- }
1549
-
1550
- map_t2b[t.first] = model.buffers.size();
1551
-
1552
- size_cur += cur;
1553
- size_main += cur;
1554
- }
1555
-
1556
- // allocate the last buffer if needed
1557
- if (size_cur > 0) {
1558
- model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur));
1559
- }
1560
-
1561
- WSP_GGML_ASSERT(model.buffers.size() > 0);
1562
-
1563
- WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB (%d buffers)\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6, (int) model.buffers.size());
1564
- }
1565
-
1566
- std::vector<wsp_ggml_allocr *> allocs(model.buffers.size());
1567
- for (size_t i = 0; i < allocs.size(); ++i) {
1568
- allocs[i] = wsp_ggml_allocr_new_from_buffer(model.buffers[i]);
1569
- }
1570
-
1571
1802
  // allocate tensors in the backend buffers
1572
- {
1573
- for (const auto & t : model.tensors) {
1574
- wsp_ggml_allocr_alloc(allocs[map_t2b[t.first]], t.second);
1575
- }
1803
+ model.buffer = wsp_ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
1804
+ if (!model.buffer) {
1805
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
1806
+ return false;
1576
1807
  }
1577
1808
 
1809
+ size_t size_main = wsp_ggml_backend_buffer_get_size(model.buffer);
1810
+ WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(model.buffer), size_main / 1e6);
1811
+
1578
1812
  // load weights
1579
1813
  {
1580
1814
  size_t total_size = 0;
@@ -1636,15 +1870,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1636
1870
  return false;
1637
1871
  }
1638
1872
 
1639
- wsp_ggml_backend_t backend = wctx.backend;
1873
+ //wsp_ggml_backend_t backend = wctx.backend;
1640
1874
 
1641
1875
  //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
1642
1876
 
1643
- if ((wsp_ggml_backend_is_cpu(backend)
1644
- #ifdef WSP_GGML_USE_METAL
1645
- || wsp_ggml_backend_is_metal(backend)
1646
- #endif
1647
- )) {
1877
+ if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
1648
1878
  // for the CPU and Metal backend, we can read directly into the tensor
1649
1879
  loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
1650
1880
  BYTESWAP_TENSOR(tensor);
@@ -1672,9 +1902,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1672
1902
  }
1673
1903
  }
1674
1904
 
1675
- for (auto & alloc : allocs) {
1676
- wsp_ggml_allocr_free(alloc);
1677
- }
1905
+ wsp_ggml_backend_buffer_set_usage(model.buffer, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1678
1906
 
1679
1907
  wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
1680
1908
 
@@ -1701,10 +1929,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
1701
1929
 
1702
1930
  static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1703
1931
  whisper_context & wctx,
1704
- whisper_state & wstate,
1705
- const int mel_offset) {
1932
+ whisper_state & wstate) {
1706
1933
  const auto & model = wctx.model;
1707
- const auto & mel_inp = wstate.mel;
1708
1934
  const auto & hparams = model.hparams;
1709
1935
 
1710
1936
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
@@ -1713,8 +1939,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1713
1939
  const int n_mels = hparams.n_mels;
1714
1940
 
1715
1941
  struct wsp_ggml_init_params params = {
1716
- /*.mem_size =*/ wstate.alloc_conv.meta.size(),
1717
- /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
1942
+ /*.mem_size =*/ wstate.sched_conv.meta.size(),
1943
+ /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
1718
1944
  /*.no_alloc =*/ true,
1719
1945
  };
1720
1946
 
@@ -1722,31 +1948,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1722
1948
 
1723
1949
  wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1724
1950
 
1725
- wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
1726
-
1727
1951
  struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
1728
- wsp_ggml_allocr_alloc(alloc, mel);
1729
-
1730
- assert(mel->type == WSP_GGML_TYPE_F32);
1731
- if (!wsp_ggml_allocr_is_measure(alloc)) {
1732
- assert(mel_inp.n_mel == n_mels);
1733
-
1734
- wstate.inp_mel.resize(wsp_ggml_nelements(mel));
1735
-
1736
- float * dst = wstate.inp_mel.data();
1737
- memset(dst, 0, wsp_ggml_nbytes(mel));
1738
-
1739
- const int i0 = std::min(mel_offset, mel_inp.n_len);
1740
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1741
-
1742
- for (int j = 0; j < mel_inp.n_mel; ++j) {
1743
- for (int i = i0; i < i1; ++i) {
1744
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1745
- }
1746
- }
1747
-
1748
- wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
1749
- }
1952
+ wsp_ggml_set_name(mel, "mel");
1953
+ wsp_ggml_set_input(mel);
1750
1954
 
1751
1955
  struct wsp_ggml_tensor * cur = nullptr;
1752
1956
 
@@ -1767,27 +1971,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1767
1971
  wsp_ggml_set_name(cur, "embd_conv");
1768
1972
  wstate.embd_conv = cur;
1769
1973
  } else {
1770
- #ifdef WHISPER_USE_COREML
1771
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1772
- wsp_ggml_allocr_alloc(alloc, cur);
1974
+ wsp_ggml_build_forward_expand(gf, mel);
1773
1975
 
1774
- if (!wsp_ggml_allocr_is_measure(alloc)) {
1775
- whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
1776
- }
1777
- #endif
1778
- #ifdef WHISPER_USE_OPENVINO
1779
1976
  cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1780
- wsp_ggml_allocr_alloc(alloc, cur);
1781
-
1782
- if (!wsp_ggml_allocr_is_measure(alloc)) {
1783
- whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
1784
- }
1785
- #endif
1977
+ wsp_ggml_set_input(cur); // the external encoder will write into this tensor
1786
1978
 
1787
1979
  wsp_ggml_set_name(cur, "embd_enc");
1788
1980
  wstate.embd_enc = cur;
1789
1981
  }
1790
1982
 
1983
+ wsp_ggml_set_output(cur);
1984
+
1791
1985
  wsp_ggml_build_forward_expand(gf, cur);
1792
1986
 
1793
1987
  wsp_ggml_free(ctx0);
@@ -1806,9 +2000,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1806
2000
  const int n_head = hparams.n_audio_head;
1807
2001
  const int n_layer = hparams.n_audio_layer;
1808
2002
 
2003
+ const int n_state_head = n_state/n_head;
2004
+
2005
+ auto & kv_pad = wstate.kv_pad;
2006
+
2007
+ WHISPER_ASSERT(!!kv_pad.buffer);
2008
+
2009
+ const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
2010
+
1809
2011
  struct wsp_ggml_init_params params = {
1810
- /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1811
- /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
2012
+ /*.mem_size =*/ wstate.sched_encode.meta.size(),
2013
+ /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
1812
2014
  /*.no_alloc =*/ true,
1813
2015
  };
1814
2016
 
@@ -1816,17 +2018,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1816
2018
 
1817
2019
  wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
1818
2020
 
1819
- //wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
1820
-
1821
- //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
1822
- //wsp_ggml_allocr_alloc(alloc, cur);
1823
-
1824
- //if (!wsp_ggml_allocr_is_measure(alloc)) {
1825
- // wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
1826
- //}
1827
2021
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1828
2022
 
1829
- const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
2023
+ const float KQscale = 1.0f/sqrtf(float(n_state_head));
1830
2024
 
1831
2025
  // ===================================================================
1832
2026
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1876,14 +2070,14 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1876
2070
 
1877
2071
  Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
1878
2072
 
1879
- //Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
2073
+ //Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
1880
2074
 
1881
2075
  // note: no bias for Key
1882
2076
  struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
1883
2077
  layer.attn_k_w,
1884
2078
  cur);
1885
2079
 
1886
- //Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
2080
+ //Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
1887
2081
 
1888
2082
  struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
1889
2083
  layer.attn_v_w,
@@ -1893,70 +2087,60 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1893
2087
 
1894
2088
  // ------
1895
2089
 
1896
- #ifdef WHISPER_USE_FLASH_ATTN
1897
2090
  struct wsp_ggml_tensor * Q =
1898
2091
  wsp_ggml_permute(ctx0,
1899
- wsp_ggml_cpy(ctx0,
1900
- Qcur,
1901
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
2092
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
1902
2093
  0, 2, 1, 3);
1903
2094
 
1904
- struct wsp_ggml_tensor * K =
1905
- wsp_ggml_permute(ctx0,
1906
- wsp_ggml_cpy(ctx0,
1907
- Kcur,
1908
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1909
- 0, 2, 1, 3);
2095
+ if (wctx.params.flash_attn) {
2096
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, wsp_ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
2097
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, wsp_ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
1910
2098
 
1911
- struct wsp_ggml_tensor * V =
1912
- wsp_ggml_cpy(ctx0,
1913
- wsp_ggml_permute(ctx0,
1914
- wsp_ggml_reshape_3d(ctx0,
1915
- Vcur,
1916
- n_state/n_head, n_head, n_ctx),
1917
- 1, 2, 0, 3),
1918
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
2099
+ struct wsp_ggml_tensor * K =
2100
+ wsp_ggml_view_3d(ctx0, kv_pad.k,
2101
+ n_state_head, n_ctx_pad, n_head,
2102
+ wsp_ggml_element_size(kv_pad.k)*n_state,
2103
+ wsp_ggml_element_size(kv_pad.k)*n_state_head,
2104
+ 0);
1919
2105
 
1920
- struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
1921
- #else
1922
- struct wsp_ggml_tensor * Q =
1923
- wsp_ggml_permute(ctx0,
1924
- wsp_ggml_cpy(ctx0,
1925
- Qcur,
1926
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1927
- 0, 2, 1, 3);
2106
+ struct wsp_ggml_tensor * V =
2107
+ wsp_ggml_view_3d(ctx0, kv_pad.v,
2108
+ n_state_head, n_ctx_pad, n_head,
2109
+ wsp_ggml_element_size(kv_pad.v)*n_state,
2110
+ wsp_ggml_element_size(kv_pad.v)*n_state_head,
2111
+ 0);
1928
2112
 
1929
- struct wsp_ggml_tensor * K =
1930
- wsp_ggml_permute(ctx0,
1931
- wsp_ggml_cpy(ctx0,
1932
- Kcur,
1933
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1934
- 0, 2, 1, 3);
2113
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
1935
2114
 
1936
- // K * Q
1937
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2115
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
2116
+ } else {
2117
+ struct wsp_ggml_tensor * K =
2118
+ wsp_ggml_permute(ctx0,
2119
+ wsp_ggml_cast(ctx0,
2120
+ wsp_ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
2121
+ wctx.itype),
2122
+ 0, 2, 1, 3);
1938
2123
 
1939
- struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
2124
+ // K * Q
2125
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
1940
2126
 
1941
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
2127
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
1942
2128
 
1943
- struct wsp_ggml_tensor * V =
1944
- wsp_ggml_cpy(ctx0,
1945
- wsp_ggml_permute(ctx0,
1946
- wsp_ggml_reshape_3d(ctx0,
1947
- Vcur,
1948
- n_state/n_head, n_head, n_ctx),
1949
- 1, 2, 0, 3),
1950
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1951
- );
2129
+ struct wsp_ggml_tensor * V =
2130
+ wsp_ggml_cast(ctx0,
2131
+ wsp_ggml_permute(ctx0,
2132
+ wsp_ggml_reshape_3d(ctx0,
2133
+ Vcur,
2134
+ n_state_head, n_head, n_ctx),
2135
+ 1, 2, 0, 3),
2136
+ wctx.itype);
1952
2137
 
1953
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1954
- #endif
1955
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2138
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1956
2139
 
1957
- cur = wsp_ggml_cpy(ctx0,
1958
- KQV_merged,
1959
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
2140
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2141
+
2142
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
2143
+ }
1960
2144
  }
1961
2145
 
1962
2146
  // projection
@@ -1985,11 +2169,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1985
2169
  layer.mlp_ln_b);
1986
2170
  }
1987
2171
 
1988
- #ifdef WHISPER_USE_FLASH_FF
1989
- cur = wsp_ggml_flash_ff(ctx0,
1990
- wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1991
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1992
- #else
1993
2172
  // fully connected
1994
2173
  cur = wsp_ggml_mul_mat(ctx0,
1995
2174
  layer.mlp_0_w,
@@ -2006,7 +2185,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
2006
2185
  cur);
2007
2186
 
2008
2187
  cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
2009
- #endif
2010
2188
  }
2011
2189
 
2012
2190
  inpL = wsp_ggml_add(ctx0, cur, inpFF);
@@ -2055,9 +2233,13 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
2055
2233
  const int n_state = hparams.n_audio_state;
2056
2234
  const int n_head = hparams.n_audio_head;
2057
2235
 
2236
+ const int n_state_head = n_state/n_head;
2237
+
2238
+ const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
2239
+
2058
2240
  struct wsp_ggml_init_params params = {
2059
- /*.mem_size =*/ wstate.alloc_cross.meta.size(),
2060
- /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
2241
+ /*.mem_size =*/ wstate.sched_cross.meta.size(),
2242
+ /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
2061
2243
  /*.no_alloc =*/ true,
2062
2244
  };
2063
2245
 
@@ -2065,28 +2247,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
2065
2247
 
2066
2248
  wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
2067
2249
 
2068
- //wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
2069
-
2070
- //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
2071
- //wsp_ggml_allocr_alloc(alloc, cur);
2072
-
2073
- //if (!wsp_ggml_allocr_is_measure(alloc)) {
2074
- // wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
2075
- //}
2076
2250
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
2077
2251
 
2078
- const float Kscale = pow(float(n_state) / n_head, -0.25);
2252
+ const float Kscale = pow(float(n_state_head), -0.25);
2079
2253
 
2080
2254
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
2081
2255
  auto & layer = model.layers_decoder[il];
2082
2256
 
2083
- struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
2257
+ struct wsp_ggml_tensor * Kcross = wsp_ggml_mul_mat(ctx0,
2084
2258
  layer.cross_attn_k_w,
2085
2259
  cur);
2086
2260
 
2087
2261
  Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
2088
2262
 
2089
- struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
2263
+ struct wsp_ggml_tensor * Vcross = wsp_ggml_mul_mat(ctx0,
2090
2264
  layer.cross_attn_v_w,
2091
2265
  cur);
2092
2266
 
@@ -2094,15 +2268,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
2094
2268
  Vcross,
2095
2269
  layer.cross_attn_v_b);
2096
2270
 
2097
- Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
2271
+ struct wsp_ggml_tensor * k;
2272
+ struct wsp_ggml_tensor * v;
2273
+
2274
+ if (wctx.params.flash_attn) {
2275
+ k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
2276
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
2098
2277
 
2099
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
2100
- n_state*n_ctx,
2101
- (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
2278
+ v = wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
2279
+ (wsp_ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
2280
+ } else {
2281
+ Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
2282
+
2283
+ k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
2284
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
2102
2285
 
2103
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
2104
- ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
2105
- (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
2286
+ v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
2287
+ ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
2288
+ (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
2289
+ }
2106
2290
 
2107
2291
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
2108
2292
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
@@ -2130,53 +2314,89 @@ static bool whisper_encode_internal(
2130
2314
  whisper_state & wstate,
2131
2315
  const int mel_offset,
2132
2316
  const int n_threads,
2133
- whisper_abort_callback abort_callback,
2317
+ wsp_ggml_abort_callback abort_callback,
2134
2318
  void * abort_callback_data) {
2135
2319
  const int64_t t_start_us = wsp_ggml_time_us();
2136
2320
 
2137
2321
  // conv
2138
2322
  {
2139
- auto & alloc = wstate.alloc_conv.alloc;
2323
+ auto & sched = wstate.sched_conv.sched;
2324
+
2325
+ wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
2140
2326
 
2141
- wsp_ggml_allocr_reset(alloc);
2327
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
2328
+ // should never happen as we pre-allocate the memory
2329
+ return false;
2330
+ }
2331
+
2332
+ struct wsp_ggml_tensor * mel = wsp_ggml_graph_get_tensor(gf, "mel");
2333
+
2334
+ // set the input
2335
+ {
2336
+ const auto & mel_inp = wstate.mel;
2337
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
2338
+
2339
+ assert(mel->type == WSP_GGML_TYPE_F32);
2340
+ assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
2341
+
2342
+ wstate.inp_mel.resize(wsp_ggml_nelements(mel));
2343
+
2344
+ float * dst = wstate.inp_mel.data();
2345
+ memset(dst, 0, wsp_ggml_nbytes(mel));
2346
+
2347
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
2348
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
2142
2349
 
2143
- wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
2350
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
2351
+ for (int i = i0; i < i1; ++i) {
2352
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
2353
+ }
2354
+ }
2144
2355
 
2145
- wsp_ggml_allocr_alloc_graph(alloc, gf);
2356
+ wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
2357
+ }
2146
2358
 
2147
2359
  if (!whisper_encode_external(wstate)) {
2148
- if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2360
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
2149
2361
  return false;
2150
2362
  }
2363
+ } else {
2364
+ #if defined(WHISPER_USE_COREML)
2365
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
2366
+ #elif defined(WHISPER_USE_OPENVINO)
2367
+ whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
2368
+ #endif
2151
2369
  }
2152
2370
  }
2153
2371
 
2154
2372
  // encoder
2155
2373
  if (!whisper_encode_external(wstate)) {
2156
- auto & alloc = wstate.alloc_encode.alloc;
2157
-
2158
- wsp_ggml_allocr_reset(alloc);
2374
+ auto & sched = wstate.sched_encode.sched;
2159
2375
 
2160
2376
  wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
2161
2377
 
2162
- wsp_ggml_allocr_alloc_graph(alloc, gf);
2378
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
2379
+ // should never happen as we pre-allocate the memory
2380
+ return false;
2381
+ }
2163
2382
 
2164
- if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2383
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
2165
2384
  return false;
2166
2385
  }
2167
2386
  }
2168
2387
 
2169
2388
  // cross
2170
2389
  {
2171
- auto & alloc = wstate.alloc_cross.alloc;
2172
-
2173
- wsp_ggml_allocr_reset(alloc);
2390
+ auto & sched = wstate.sched_cross.sched;
2174
2391
 
2175
2392
  wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
2176
2393
 
2177
- wsp_ggml_allocr_alloc_graph(alloc, gf);
2394
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
2395
+ // should never happen as we pre-allocate the memory
2396
+ return false;
2397
+ }
2178
2398
 
2179
- if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2399
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
2180
2400
  return false;
2181
2401
  }
2182
2402
  }
@@ -2190,82 +2410,58 @@ static bool whisper_encode_internal(
2190
2410
  static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2191
2411
  whisper_context & wctx,
2192
2412
  whisper_state & wstate,
2193
- const whisper_batch & batch) {
2413
+ const whisper_batch & batch,
2414
+ bool save_alignment_heads_QKs,
2415
+ bool worst_case) {
2194
2416
  const auto & model = wctx.model;
2195
2417
  const auto & hparams = model.hparams;
2196
2418
 
2197
2419
  auto & kv_self = wstate.kv_self;
2198
2420
 
2199
- WHISPER_ASSERT(!!kv_self.ctx);
2200
-
2201
- wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
2421
+ WHISPER_ASSERT(!!kv_self.buffer);
2202
2422
 
2203
2423
  const int n_ctx = kv_self.size;
2204
2424
  const int n_state = hparams.n_text_state;
2205
2425
  const int n_head = hparams.n_text_head;
2206
2426
  const int n_layer = hparams.n_text_layer;
2207
2427
 
2428
+ const int n_state_head = n_state/n_head;
2429
+
2208
2430
  const int n_tokens = batch.n_tokens;
2209
2431
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2210
2432
 
2211
- const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
2212
- const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
2433
+ const int n_audio_ctx_pad = WSP_GGML_PAD(n_audio_ctx, 256);
2434
+
2435
+ const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
2436
+ const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
2213
2437
 
2214
2438
  //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2215
2439
 
2216
2440
  struct wsp_ggml_init_params params = {
2217
- /*.mem_size =*/ wstate.alloc_decode.meta.size(),
2218
- /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
2441
+ /*.mem_size =*/ wstate.sched_decode.meta.size(),
2442
+ /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
2219
2443
  /*.no_alloc =*/ true,
2220
2444
  };
2221
2445
 
2222
2446
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
2223
2447
 
2224
- wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2225
-
2226
- struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2227
- wsp_ggml_allocr_alloc(alloc, embd);
2228
-
2229
- if (!wsp_ggml_allocr_is_measure(alloc)) {
2230
- wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
2231
- }
2232
-
2233
- struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2234
- wsp_ggml_allocr_alloc(alloc, position);
2235
-
2236
- if (!wsp_ggml_allocr_is_measure(alloc)) {
2237
- for (int i = 0; i < n_tokens; ++i) {
2238
- const int32_t val = batch.pos[i];
2239
- wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2240
- }
2241
- }
2242
-
2243
- const float KQscale = pow(float(n_state)/n_head, -0.25);
2244
-
2245
- struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
2246
- wsp_ggml_allocr_alloc(alloc, KQ_mask);
2448
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2247
2449
 
2248
- if (!wsp_ggml_allocr_is_measure(alloc)) {
2249
- wstate.inp_mask.resize(n_kv*n_tokens);
2450
+ struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2451
+ wsp_ggml_set_name(embd, "embd");
2452
+ wsp_ggml_set_input(embd);
2250
2453
 
2251
- float * data = wstate.inp_mask.data();
2252
- memset(data, 0, wsp_ggml_nbytes(KQ_mask));
2454
+ struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2455
+ wsp_ggml_set_name(position, "position");
2456
+ wsp_ggml_set_input(position);
2253
2457
 
2254
- for (int h = 0; h < 1; ++h) {
2255
- for (int j = 0; j < n_tokens; ++j) {
2256
- const whisper_pos pos = batch.pos[j];
2257
- const whisper_seq_id seq_id = batch.seq_id[j][0];
2458
+ const float KQscale = pow(float(n_state_head), -0.25);
2258
2459
 
2259
- for (int i = 0; i < n_kv; ++i) {
2260
- if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2261
- data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2262
- }
2263
- }
2264
- }
2265
- }
2460
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD), 1);
2461
+ wsp_ggml_set_name(KQ_mask, "KQ_mask");
2462
+ wsp_ggml_set_input(KQ_mask);
2266
2463
 
2267
- wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
2268
- }
2464
+ struct wsp_ggml_tensor * KQ_mask_f16 = wsp_ggml_cast(ctx0, KQ_mask, WSP_GGML_TYPE_F16);
2269
2465
 
2270
2466
  // token encoding + position encoding
2271
2467
  struct wsp_ggml_tensor * cur =
@@ -2275,6 +2471,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2275
2471
 
2276
2472
  struct wsp_ggml_tensor * inpL = cur;
2277
2473
 
2474
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2475
+ struct wsp_ggml_tensor * aheads_cross_QKs = nullptr;
2476
+
2278
2477
  for (int il = 0; il < n_layer; ++il) {
2279
2478
  const auto & layer = model.layers_decoder[il];
2280
2479
 
@@ -2319,12 +2518,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2319
2518
  Vcur,
2320
2519
  layer.attn_v_b);
2321
2520
 
2322
- Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2521
+ struct wsp_ggml_tensor * k;
2522
+ struct wsp_ggml_tensor * v;
2523
+
2524
+ if (wctx.params.flash_attn) {
2525
+ k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2526
+ (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2527
+
2528
+ v = wsp_ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
2529
+ (wsp_ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
2530
+ } else {
2531
+ Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2532
+
2533
+ k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2534
+ (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2323
2535
 
2324
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2325
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2326
- ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2327
- (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2536
+ v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2537
+ ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2538
+ (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2539
+ }
2328
2540
 
2329
2541
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2330
2542
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
@@ -2334,40 +2546,46 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2334
2546
 
2335
2547
  struct wsp_ggml_tensor * Q =
2336
2548
  wsp_ggml_permute(ctx0,
2337
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2549
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2338
2550
  0, 2, 1, 3);
2339
2551
 
2340
2552
  struct wsp_ggml_tensor * K =
2341
2553
  wsp_ggml_view_3d(ctx0, kv_self.k,
2342
- n_state/n_head, n_kv, n_head,
2554
+ n_state_head, n_kv, n_head,
2343
2555
  wsp_ggml_element_size(kv_self.k)*n_state,
2344
- wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2556
+ wsp_ggml_element_size(kv_self.k)*n_state_head,
2345
2557
  wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2346
2558
 
2347
- // K * Q
2348
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2559
+ if (wctx.params.flash_attn) {
2560
+ struct wsp_ggml_tensor * V =
2561
+ wsp_ggml_view_3d(ctx0, kv_self.v,
2562
+ n_state_head, n_kv, n_head,
2563
+ wsp_ggml_element_size(kv_self.v)*n_state,
2564
+ wsp_ggml_element_size(kv_self.v)*n_state_head,
2565
+ wsp_ggml_element_size(kv_self.v)*n_state*n_ctx*il);
2349
2566
 
2350
- //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2567
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
2351
2568
 
2352
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2353
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
2569
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2570
+ } else {
2571
+ // K * Q
2572
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2354
2573
 
2355
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2574
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
2356
2575
 
2357
- struct wsp_ggml_tensor * V =
2358
- wsp_ggml_view_3d(ctx0, kv_self.v,
2359
- n_kv, n_state/n_head, n_head,
2360
- n_ctx*wsp_ggml_element_size(kv_self.v),
2361
- n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
2362
- n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2576
+ struct wsp_ggml_tensor * V =
2577
+ wsp_ggml_view_3d(ctx0, kv_self.v,
2578
+ n_kv, n_state_head, n_head,
2579
+ n_ctx*wsp_ggml_element_size(kv_self.v),
2580
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state_head,
2581
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2363
2582
 
2364
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2583
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2365
2584
 
2366
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2585
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2367
2586
 
2368
- cur = wsp_ggml_cpy(ctx0,
2369
- KQV_merged,
2370
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2587
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
2588
+ }
2371
2589
  }
2372
2590
 
2373
2591
  // projection
@@ -2406,62 +2624,75 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2406
2624
  Qcur,
2407
2625
  layer.cross_attn_q_b);
2408
2626
 
2409
- Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2410
-
2411
- // Kcross is already scaled
2412
- struct wsp_ggml_tensor * Kcross =
2413
- wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2414
- n_state/n_head, n_audio_ctx, n_head,
2415
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2416
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2417
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2418
-
2419
- //struct wsp_ggml_tensor * Vcross =
2420
- // wsp_ggml_reshape_3d(ctx0,
2421
- // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2422
- // n_state/n_head, n_head, n_audio_ctx);
2423
-
2424
- //struct wsp_ggml_tensor * V_trans =
2425
- // wsp_ggml_cpy(ctx0,
2426
- // wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2427
- // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2428
-
2429
- struct wsp_ggml_tensor * V =
2430
- wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2431
- n_audio_ctx, n_state/n_head, n_head,
2432
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2433
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2434
- n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2435
-
2436
- // ------
2437
-
2438
2627
  struct wsp_ggml_tensor * Q =
2439
2628
  wsp_ggml_permute(ctx0,
2440
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2629
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2441
2630
  0, 2, 1, 3);
2442
2631
 
2443
- // K * Q
2444
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2632
+ if (wctx.params.flash_attn) {
2633
+ struct wsp_ggml_tensor * Kcross =
2634
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2635
+ n_state_head, n_audio_ctx_pad, n_head,
2636
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2637
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
2638
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
2445
2639
 
2446
- //struct wsp_ggml_tensor * KQ_scaled =
2447
- // wsp_ggml_scale(ctx0,
2448
- // KQ,
2449
- // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2450
- // );
2640
+ struct wsp_ggml_tensor * Vcross =
2641
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2642
+ n_state_head, n_audio_ctx_pad, n_head,
2643
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state,
2644
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
2645
+ wsp_ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
2451
2646
 
2452
- // no masking for cross-attention
2453
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2647
+ cur = wsp_ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
2454
2648
 
2455
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
2649
+ cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2650
+ } else {
2651
+ struct wsp_ggml_tensor * Kcross =
2652
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2653
+ n_state_head, n_audio_ctx, n_head,
2654
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2655
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
2656
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2657
+
2658
+ struct wsp_ggml_tensor * Vcross =
2659
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2660
+ n_audio_ctx, n_state_head, n_head,
2661
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2662
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
2663
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2664
+
2665
+ // ------
2666
+
2667
+ // K * Q
2668
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2669
+
2670
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2671
+
2672
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2673
+ if (wctx.params.dtw_token_timestamps) {
2674
+ if (wstate.aheads_masks.m[il] != nullptr) {
2675
+ struct wsp_ggml_tensor * aheads_KQs = wsp_ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
2676
+ aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
2677
+ aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
2678
+ aheads_KQs = wsp_ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
2679
+ aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
2680
+ aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
2681
+ aheads_KQs = wsp_ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
2682
+ if (aheads_cross_QKs == NULL) {
2683
+ aheads_cross_QKs = aheads_KQs;
2684
+ } else {
2685
+ aheads_cross_QKs = wsp_ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
2686
+ }
2687
+ }
2688
+ }
2456
2689
 
2457
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2690
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
2458
2691
 
2459
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2692
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2460
2693
 
2461
- // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2462
- cur = wsp_ggml_cpy(ctx0,
2463
- KQV_merged,
2464
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2694
+ cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
2695
+ }
2465
2696
  }
2466
2697
 
2467
2698
  // projection
@@ -2539,6 +2770,16 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2539
2770
 
2540
2771
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2541
2772
 
2773
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2774
+ if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
2775
+ aheads_cross_QKs = wsp_ggml_transpose(ctx0, aheads_cross_QKs);
2776
+ aheads_cross_QKs = wsp_ggml_cont(ctx0, aheads_cross_QKs);
2777
+ if (save_alignment_heads_QKs) {
2778
+ wsp_ggml_build_forward_expand(gf, aheads_cross_QKs);
2779
+ wstate.aheads_cross_QKs = aheads_cross_QKs;
2780
+ }
2781
+ }
2782
+
2542
2783
  wsp_ggml_build_forward_expand(gf, logits);
2543
2784
 
2544
2785
  wsp_ggml_free(ctx0);
@@ -2561,7 +2802,8 @@ static bool whisper_decode_internal(
2561
2802
  whisper_state & wstate,
2562
2803
  const whisper_batch & batch,
2563
2804
  const int n_threads,
2564
- whisper_abort_callback abort_callback,
2805
+ bool save_alignment_heads_QKs,
2806
+ wsp_ggml_abort_callback abort_callback,
2565
2807
  void * abort_callback_data) {
2566
2808
  const int64_t t_start_us = wsp_ggml_time_us();
2567
2809
 
@@ -2583,24 +2825,75 @@ static bool whisper_decode_internal(
2583
2825
  return false;
2584
2826
  }
2585
2827
 
2586
- kv_self.n = whisper_kv_cache_cell_max(kv_self);
2828
+ const uint32_t pad = whisper_kv_cache_get_padding(wctx);
2829
+ kv_self.n = std::min(kv_self.size, std::max(pad, WSP_GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
2830
+
2587
2831
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2588
2832
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2589
2833
  }
2590
2834
 
2591
2835
  // decoder
2592
2836
  {
2593
- auto & alloc = wstate.alloc_decode.alloc;
2837
+ auto & sched = wstate.sched_decode.sched;
2838
+
2839
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
2594
2840
 
2595
- wsp_ggml_allocr_reset(alloc);
2841
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
2842
+ // should never happen as we pre-allocate the memory
2843
+ return false;
2844
+ }
2845
+
2846
+ // set the inputs
2847
+ {
2848
+ struct wsp_ggml_tensor * embd = wsp_ggml_graph_get_tensor(gf, "embd");
2849
+ wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
2850
+ }
2851
+
2852
+ {
2853
+ struct wsp_ggml_tensor * position = wsp_ggml_graph_get_tensor(gf, "position");
2854
+ for (int i = 0; i < n_tokens; ++i) {
2855
+ const int32_t val = batch.pos[i];
2856
+ wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2857
+ }
2858
+ }
2859
+
2860
+ {
2861
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_graph_get_tensor(gf, "KQ_mask");
2862
+
2863
+ auto & kv_self = wstate.kv_self;
2864
+
2865
+ const int32_t n_kv = kv_self.n;
2866
+
2867
+ wstate.inp_mask.resize(wsp_ggml_nelements(KQ_mask));
2868
+
2869
+ float * data = wstate.inp_mask.data();
2870
+ memset(data, 0, wsp_ggml_nbytes(KQ_mask));
2871
+
2872
+ for (int h = 0; h < 1; ++h) {
2873
+ for (int j = 0; j < n_tokens; ++j) {
2874
+ const whisper_pos pos = batch.pos[j];
2875
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
2596
2876
 
2597
- wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
2877
+ for (int i = 0; i < n_kv; ++i) {
2878
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2879
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2880
+ }
2881
+ }
2882
+ }
2883
+
2884
+ for (int i = n_tokens; i < WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD); ++i) {
2885
+ for (int j = 0; j < n_kv; ++j) {
2886
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
2887
+ }
2888
+ }
2889
+ }
2598
2890
 
2599
- wsp_ggml_allocr_alloc_graph(alloc, gf);
2891
+ wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
2892
+ }
2600
2893
 
2601
- logits = gf->nodes[gf->n_nodes - 1];
2894
+ logits = wsp_ggml_graph_node(gf, -1);
2602
2895
 
2603
- if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2896
+ if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
2604
2897
  return false;
2605
2898
  }
2606
2899
  }
@@ -2654,29 +2947,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2654
2947
  }
2655
2948
 
2656
2949
  #define SIN_COS_N_COUNT WHISPER_N_FFT
2657
- static float sin_vals[SIN_COS_N_COUNT];
2658
- static float cos_vals[SIN_COS_N_COUNT];
2950
+ namespace {
2951
+ struct whisper_global_cache {
2952
+ // In FFT, we frequently use sine and cosine operations with the same values.
2953
+ // We can use precalculated values to speed up the process.
2954
+ float sin_vals[SIN_COS_N_COUNT];
2955
+ float cos_vals[SIN_COS_N_COUNT];
2956
+
2957
+ // Hann window (Use cosf to eliminate difference)
2958
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2959
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2960
+ float hann_window[WHISPER_N_FFT];
2961
+
2962
+ whisper_global_cache() {
2963
+ fill_sin_cos_table();
2964
+ fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
2965
+ }
2659
2966
 
2660
- // In FFT, we frequently use sine and cosine operations with the same values.
2661
- // We can use precalculated values to speed up the process.
2662
- static void fill_sin_cos_table() {
2663
- static bool is_filled = false;
2664
- if (is_filled) return;
2665
- for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2666
- double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
2667
- sin_vals[i] = sinf(theta);
2668
- cos_vals[i] = cosf(theta);
2967
+ void fill_sin_cos_table() {
2968
+ for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2969
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
2970
+ sin_vals[i] = sinf(theta);
2971
+ cos_vals[i] = cosf(theta);
2972
+ }
2973
+ }
2974
+
2975
+ void fill_hann_window(int length, bool periodic, float * output) {
2976
+ int offset = -1;
2977
+ if (periodic) {
2978
+ offset = 0;
2979
+ }
2980
+ for (int i = 0; i < length; i++) {
2981
+ output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
2982
+ }
2669
2983
  }
2670
- is_filled = true;
2984
+ } global_cache;
2671
2985
  }
2672
2986
 
2673
2987
  // naive Discrete Fourier Transform
2674
2988
  // input is real-valued
2675
2989
  // output is complex-valued
2676
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
2677
- int N = in.size();
2678
-
2679
- out.resize(N*2);
2990
+ static void dft(const float* in, int N, float* out) {
2680
2991
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2681
2992
 
2682
2993
  for (int k = 0; k < N; k++) {
@@ -2685,8 +2996,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2685
2996
 
2686
2997
  for (int n = 0; n < N; n++) {
2687
2998
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2688
- re += in[n]*cos_vals[idx]; // cos(t)
2689
- im -= in[n]*sin_vals[idx]; // sin(t)
2999
+ re += in[n]*global_cache.cos_vals[idx]; // cos(t)
3000
+ im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
2690
3001
  }
2691
3002
 
2692
3003
  out[k*2 + 0] = re;
@@ -2698,47 +3009,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2698
3009
  // poor man's implementation - use something better
2699
3010
  // input is real-valued
2700
3011
  // output is complex-valued
2701
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
2702
- out.resize(in.size()*2);
2703
-
2704
- int N = in.size();
2705
-
3012
+ static void fft(float* in, int N, float* out) {
2706
3013
  if (N == 1) {
2707
3014
  out[0] = in[0];
2708
3015
  out[1] = 0;
2709
3016
  return;
2710
3017
  }
2711
3018
 
2712
- if (N%2 == 1) {
2713
- dft(in, out);
3019
+ const int half_N = N / 2;
3020
+ if (N - half_N*2 == 1) {
3021
+ dft(in, N, out);
2714
3022
  return;
2715
3023
  }
2716
3024
 
2717
- std::vector<float> even;
2718
- std::vector<float> odd;
2719
-
2720
- even.reserve(N/2);
2721
- odd.reserve(N/2);
2722
-
2723
- for (int i = 0; i < N; i++) {
2724
- if (i % 2 == 0) {
2725
- even.push_back(in[i]);
2726
- } else {
2727
- odd.push_back(in[i]);
2728
- }
3025
+ float* even = in + N;
3026
+ for (int i = 0; i < half_N; ++i) {
3027
+ even[i]= in[2*i];
2729
3028
  }
3029
+ float* even_fft = out + 2 * N;
3030
+ fft(even, half_N, even_fft);
2730
3031
 
2731
- std::vector<float> even_fft;
2732
- std::vector<float> odd_fft;
2733
-
2734
- fft(even, even_fft);
2735
- fft(odd, odd_fft);
3032
+ float* odd = even;
3033
+ for (int i = 0; i < half_N; ++i) {
3034
+ odd[i] = in[2*i + 1];
3035
+ }
3036
+ float* odd_fft = even_fft + N;
3037
+ fft(odd, half_N, odd_fft);
2736
3038
 
2737
3039
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2738
- for (int k = 0; k < N/2; k++) {
3040
+ for (int k = 0; k < half_N; k++) {
2739
3041
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2740
- float re = cos_vals[idx]; // cos(t)
2741
- float im = -sin_vals[idx]; // sin(t)
3042
+ float re = global_cache.cos_vals[idx]; // cos(t)
3043
+ float im = -global_cache.sin_vals[idx]; // sin(t)
2742
3044
 
2743
3045
  float re_odd = odd_fft[2*k + 0];
2744
3046
  float im_odd = odd_fft[2*k + 1];
@@ -2746,61 +3048,49 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2746
3048
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2747
3049
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2748
3050
 
2749
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2750
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2751
- }
2752
- }
2753
-
2754
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2755
- if (output.size() < static_cast<size_t>(length)) {
2756
- output.resize(length);
2757
- }
2758
- int offset = -1;
2759
- if (periodic) {
2760
- offset = 0;
3051
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3052
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2761
3053
  }
2762
- for (int i = 0; i < length; i++) {
2763
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
2764
- }
2765
-
2766
- return true;
2767
3054
  }
2768
3055
 
2769
- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
3056
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
2770
3057
  int n_samples, int frame_size, int frame_step, int n_threads,
2771
3058
  const whisper_filters & filters, whisper_mel & mel) {
2772
- std::vector<float> fft_in(frame_size, 0.0);
2773
- std::vector<float> fft_out(2 * frame_step);
2774
- // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2775
- int n_fft = 1 + (frame_size / 2);
3059
+ std::vector<float> fft_in(frame_size * 2, 0.0);
3060
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
3061
+
3062
+ int n_fft = filters.n_fft;
2776
3063
  int i = ith;
2777
3064
 
3065
+ // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
3066
+ assert(n_fft == 1 + (frame_size / 2));
3067
+
2778
3068
  // calculate FFT only when fft_in are not all zero
2779
3069
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2780
3070
  const int offset = i * frame_step;
2781
3071
 
2782
- // apply Hanning window (~10% faster)
3072
+ // apply Hann window (~10% faster)
2783
3073
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2784
3074
  fft_in[j] = hann[j] * samples[offset + j];
2785
3075
  }
3076
+
2786
3077
  // fill the rest with zeros
2787
3078
  if (n_samples - offset < frame_size) {
2788
3079
  std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
2789
3080
  }
2790
3081
 
2791
3082
  // FFT
2792
- fft(fft_in, fft_out);
3083
+ fft(fft_in.data(), frame_size, fft_out.data());
2793
3084
 
2794
3085
  // Calculate modulus^2 of complex numbers
2795
3086
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2796
- for (int j = 0; j < frame_size; j++) {
3087
+ for (int j = 0; j < n_fft; j++) {
2797
3088
  fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2798
3089
  }
2799
3090
 
2800
3091
  // mel spectrogram
2801
3092
  for (int j = 0; j < mel.n_mel; j++) {
2802
3093
  double sum = 0.0;
2803
-
2804
3094
  // unroll loop (suggested by GH user @lunixbochs)
2805
3095
  int k = 0;
2806
3096
  for (k = 0; k < n_fft - 3; k += 4) {
@@ -2810,14 +3100,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2810
3100
  fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
2811
3101
  fft_out[k + 3] * filters.data[j * n_fft + k + 3];
2812
3102
  }
2813
-
2814
3103
  // handle n_fft remainder
2815
3104
  for (; k < n_fft; k++) {
2816
3105
  sum += fft_out[k] * filters.data[j * n_fft + k];
2817
3106
  }
2818
-
2819
3107
  sum = log10(std::max(sum, 1e-10));
2820
-
2821
3108
  mel.data[j * mel.n_len + i] = sum;
2822
3109
  }
2823
3110
  }
@@ -2846,12 +3133,9 @@ static bool log_mel_spectrogram(
2846
3133
  whisper_mel & mel) {
2847
3134
  const int64_t t_start_us = wsp_ggml_time_us();
2848
3135
 
2849
- // Hanning window (Use cosf to eliminate difference)
2850
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2851
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2852
- std::vector<float> hann;
2853
- hann_window(frame_size, true, hann);
2854
-
3136
+ // Hann window
3137
+ WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
3138
+ const float * hann = global_cache.hann_window;
2855
3139
 
2856
3140
  // Calculate the length of padding
2857
3141
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -2876,12 +3160,11 @@ static bool log_mel_spectrogram(
2876
3160
  mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
2877
3161
  mel.data.resize(mel.n_mel * mel.n_len);
2878
3162
 
2879
-
2880
3163
  {
2881
3164
  std::vector<std::thread> workers(n_threads - 1);
2882
3165
  for (int iw = 0; iw < n_threads - 1; ++iw) {
2883
3166
  workers[iw] = std::thread(
2884
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
3167
+ log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
2885
3168
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
2886
3169
  std::cref(filters), std::ref(mel));
2887
3170
  }
@@ -3041,19 +3324,24 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
3041
3324
  #endif
3042
3325
 
3043
3326
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
3044
- fill_sin_cos_table();
3045
-
3046
3327
  whisper_state * state = new whisper_state;
3047
3328
 
3048
- state->backend = whisper_backend_init(ctx->params);
3049
-
3050
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3051
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
3052
- const int factor = 3;
3329
+ state->backends = whisper_backend_init(ctx->params);
3330
+ if (state->backends.empty()) {
3331
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
3332
+ whisper_free_state(state);
3333
+ return nullptr;
3334
+ }
3053
3335
 
3054
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
3055
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3056
- delete state;
3336
+ // at this point, we don't know yet how many decoders will be used
3337
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
3338
+ state->kv_self_n_dec = 1;
3339
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
3340
+ ctx->model.hparams.n_text_state,
3341
+ ctx->model.hparams.n_text_layer,
3342
+ WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
3343
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3344
+ whisper_free_state(state);
3057
3345
  return nullptr;
3058
3346
  }
3059
3347
 
@@ -3062,9 +3350,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3062
3350
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
3063
3351
  }
3064
3352
 
3065
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
3066
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3067
- delete state;
3353
+ if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
3354
+ ctx->model.hparams.n_text_state,
3355
+ ctx->model.hparams.n_text_layer,
3356
+ WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3357
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
3358
+ whisper_free_state(state);
3068
3359
  return nullptr;
3069
3360
  }
3070
3361
 
@@ -3073,6 +3364,31 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3073
3364
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
3074
3365
  }
3075
3366
 
3367
+ if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
3368
+ ctx->model.hparams.n_audio_state,
3369
+ 1,
3370
+ WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3371
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3372
+ whisper_free_state(state);
3373
+ return nullptr;
3374
+ }
3375
+
3376
+ {
3377
+ const size_t memory_size = wsp_ggml_nbytes(state->kv_pad.k) + wsp_ggml_nbytes(state->kv_pad.v);
3378
+ WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
3379
+ }
3380
+
3381
+ // [EXPERIMENTAL] Token-level timestamps with DTW
3382
+ if (ctx->params.dtw_token_timestamps) {
3383
+ if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
3384
+ WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
3385
+ whisper_free_state(state);
3386
+ return nullptr;
3387
+ }
3388
+ const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
3389
+ WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
3390
+ }
3391
+
3076
3392
 
3077
3393
  #ifdef WHISPER_USE_COREML
3078
3394
  if (ctx->params.use_coreml) {
@@ -3085,7 +3401,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3085
3401
  if (!state->ctx_coreml) {
3086
3402
  WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
3087
3403
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
3088
- delete state;
3404
+ whisper_free_state(state);
3089
3405
  return nullptr;
3090
3406
  #endif
3091
3407
  } else {
@@ -3110,37 +3426,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3110
3426
 
3111
3427
  // conv allocator
3112
3428
  {
3113
- whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
3429
+ bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
3114
3430
  [&]() {
3115
- return whisper_build_graph_conv(*ctx, *state, 0);
3431
+ return whisper_build_graph_conv(*ctx, *state);
3116
3432
  });
3117
3433
 
3118
- WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
3434
+ if (!ok) {
3435
+ WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__);
3436
+ whisper_free_state(state);
3437
+ return nullptr;
3438
+ }
3439
+
3440
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
3119
3441
  }
3120
3442
 
3121
3443
  // encoder allocator
3122
3444
  if (!whisper_encode_external(*state)) {
3123
- whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
3445
+ bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
3124
3446
  [&]() {
3125
3447
  return whisper_build_graph_encoder(*ctx, *state);
3126
3448
  });
3127
3449
 
3128
- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
3450
+ if (!ok) {
3451
+ WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__);
3452
+ whisper_free_state(state);
3453
+ return nullptr;
3454
+ }
3455
+
3456
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
3129
3457
  }
3130
3458
 
3131
3459
  // cross allocator
3132
3460
  {
3133
- whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
3461
+ bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
3134
3462
  [&]() {
3135
3463
  return whisper_build_graph_cross(*ctx, *state);
3136
3464
  });
3137
3465
 
3138
- WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
3466
+ if (!ok) {
3467
+ WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__);
3468
+ whisper_free_state(state);
3469
+ return nullptr;
3470
+ }
3471
+
3472
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
3139
3473
  }
3140
3474
 
3141
3475
  // decoder allocator
3142
3476
  {
3143
- whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
3477
+ bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
3144
3478
  [&]() {
3145
3479
  const auto & hparams = ctx->model.hparams;
3146
3480
 
@@ -3150,27 +3484,30 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3150
3484
 
3151
3485
  whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
3152
3486
 
3153
- return whisper_build_graph_decoder(*ctx, *state, state->batch);
3487
+ return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
3154
3488
  });
3155
3489
 
3156
- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
3157
- }
3490
+ if (!ok) {
3491
+ WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
3492
+ whisper_free_state(state);
3493
+ return nullptr;
3494
+ }
3158
3495
 
3159
- whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
3160
- whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
3161
- whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
3162
- whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
3496
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
3497
+ }
3163
3498
 
3164
3499
  return state;
3165
3500
  }
3166
3501
 
3167
- int whisper_ctx_init_openvino_encoder(
3502
+ int whisper_ctx_init_openvino_encoder_with_state(
3168
3503
  struct whisper_context * ctx,
3504
+ struct whisper_state * state,
3169
3505
  const char * model_path,
3170
3506
  const char * device,
3171
3507
  const char * cache_dir) {
3172
3508
  #ifndef WHISPER_USE_OPENVINO
3173
3509
  (void)(ctx);
3510
+ (void)(state);
3174
3511
  (void)(model_path);
3175
3512
  (void)(device);
3176
3513
  (void)(cache_dir);
@@ -3201,8 +3538,8 @@ int whisper_ctx_init_openvino_encoder(
3201
3538
  WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3202
3539
  WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
3203
3540
 
3204
- ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3205
- if (!ctx->state->ctx_openvino) {
3541
+ state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3542
+ if (!state->ctx_openvino) {
3206
3543
  WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3207
3544
  return 1;
3208
3545
  } else {
@@ -3213,18 +3550,43 @@ int whisper_ctx_init_openvino_encoder(
3213
3550
  #endif
3214
3551
  }
3215
3552
 
3553
+ int whisper_ctx_init_openvino_encoder(
3554
+ struct whisper_context * ctx,
3555
+ const char * model_path,
3556
+ const char * device,
3557
+ const char * cache_dir) {
3558
+ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
3559
+ }
3560
+
3216
3561
  struct whisper_context_params whisper_context_default_params() {
3217
3562
  struct whisper_context_params result = {
3218
- /*.use_gpu =*/ true,
3219
- /*.use_coreml =*/ false,
3563
+ /*.use_gpu =*/ true,
3564
+ /*.use_coreml =*/ false,
3565
+ /*.flash_attn =*/ false,
3566
+ /*.gpu_device =*/ 0,
3567
+
3568
+ /*.dtw_token_timestamps =*/ false,
3569
+ /*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
3570
+ /*.dtw_n_top =*/ -1,
3571
+ /*.dtw_aheads =*/ {
3572
+ /*.n_heads =*/ 0,
3573
+ /*.heads =*/ NULL,
3574
+ },
3575
+ /*.dtw_mem_size =*/ 1024*1024*128,
3220
3576
  };
3221
3577
  return result;
3222
3578
  }
3223
3579
 
3224
3580
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3225
3581
  WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3226
-
3582
+ #ifdef _MSC_VER
3583
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
3584
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
3585
+ std::wstring path_model_wide = converter.from_bytes(path_model);
3586
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
3587
+ #else
3227
3588
  auto fin = std::ifstream(path_model, std::ios::binary);
3589
+ #endif
3228
3590
  if (!fin) {
3229
3591
  WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3230
3592
  return nullptr;
@@ -3299,6 +3661,19 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
3299
3661
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
3300
3662
  wsp_ggml_time_init();
3301
3663
 
3664
+ if (params.flash_attn && params.dtw_token_timestamps) {
3665
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
3666
+ params.dtw_token_timestamps = false;
3667
+ }
3668
+
3669
+ WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
3670
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
3671
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
3672
+ WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3673
+
3674
+ // TODO: temporary call to force backend registry initialization
3675
+ WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
3676
+
3302
3677
  whisper_context * ctx = new whisper_context;
3303
3678
  ctx->params = params;
3304
3679
 
@@ -3383,11 +3758,11 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
3383
3758
  return whisper_init_with_params_no_state(loader, whisper_context_default_params());
3384
3759
  }
3385
3760
 
3386
- void whisper_free_state(struct whisper_state * state)
3387
- {
3761
+ void whisper_free_state(struct whisper_state * state) {
3388
3762
  if (state) {
3389
- kv_cache_free(state->kv_self);
3390
- kv_cache_free(state->kv_cross);
3763
+ whisper_kv_cache_free(state->kv_self);
3764
+ whisper_kv_cache_free(state->kv_cross);
3765
+ whisper_kv_cache_free(state->kv_pad);
3391
3766
 
3392
3767
  #ifdef WHISPER_USE_COREML
3393
3768
  if (state->ctx_coreml != nullptr) {
@@ -3405,12 +3780,17 @@ void whisper_free_state(struct whisper_state * state)
3405
3780
 
3406
3781
  whisper_batch_free(state->batch);
3407
3782
 
3408
- whisper_allocr_free(state->alloc_conv);
3409
- whisper_allocr_free(state->alloc_encode);
3410
- whisper_allocr_free(state->alloc_cross);
3411
- whisper_allocr_free(state->alloc_decode);
3783
+ wsp_ggml_backend_sched_free(state->sched_conv.sched);
3784
+ wsp_ggml_backend_sched_free(state->sched_encode.sched);
3785
+ wsp_ggml_backend_sched_free(state->sched_cross.sched);
3786
+ wsp_ggml_backend_sched_free(state->sched_decode.sched);
3787
+
3788
+ for (auto & backend : state->backends) {
3789
+ wsp_ggml_backend_free(backend);
3790
+ }
3412
3791
 
3413
- wsp_ggml_backend_free(state->backend);
3792
+ // [EXPERIMENTAL] Token-level timestamps with DTW
3793
+ aheads_masks_free(state->aheads_masks);
3414
3794
 
3415
3795
  delete state;
3416
3796
  }
@@ -3418,20 +3798,12 @@ void whisper_free_state(struct whisper_state * state)
3418
3798
 
3419
3799
  void whisper_free(struct whisper_context * ctx) {
3420
3800
  if (ctx) {
3421
- if (ctx->model.ctx) {
3422
- wsp_ggml_free(ctx->model.ctx);
3423
- }
3801
+ wsp_ggml_free(ctx->model.ctx);
3424
3802
 
3425
- for (auto & buffer : ctx->model.buffers) {
3426
- if (buffer) {
3427
- wsp_ggml_backend_buffer_free(buffer);
3428
- }
3429
- }
3803
+ wsp_ggml_backend_buffer_free(ctx->model.buffer);
3430
3804
 
3431
3805
  whisper_free_state(ctx->state);
3432
3806
 
3433
- wsp_ggml_backend_free(ctx->backend);
3434
-
3435
3807
  delete ctx;
3436
3808
  }
3437
3809
  }
@@ -3461,30 +3833,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3461
3833
  return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
3462
3834
  }
3463
3835
 
3464
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3465
- int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3466
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3467
- WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3468
- return -1;
3469
- }
3470
-
3471
- return 0;
3472
- }
3473
-
3474
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3475
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
3476
- return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
3477
- }
3478
-
3479
- // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
3480
- // TODO
3481
-
3482
- // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
3483
- // TODO
3484
-
3485
- // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
3486
- // TODO
3487
-
3488
3836
  int whisper_set_mel_with_state(
3489
3837
  struct whisper_context * ctx,
3490
3838
  struct whisper_state * state,
@@ -3537,7 +3885,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3537
3885
 
3538
3886
  whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
3539
3887
 
3540
- if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
3888
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
3541
3889
  WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3542
3890
  return 1;
3543
3891
  }
@@ -3559,7 +3907,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3559
3907
 
3560
3908
  if (n_max_tokens < (int) res.size()) {
3561
3909
  WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3562
- return -1;
3910
+ return -(int) res.size();
3563
3911
  }
3564
3912
 
3565
3913
  for (int i = 0; i < (int) res.size(); i++) {
@@ -3569,7 +3917,11 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3569
3917
  return res.size();
3570
3918
  }
3571
3919
 
3572
- int whisper_lang_max_id() {
3920
+ int whisper_token_count(struct whisper_context * ctx, const char * text) {
3921
+ return -whisper_tokenize(ctx, text, NULL, 0);
3922
+ }
3923
+
3924
+ int whisper_lang_max_id(void) {
3573
3925
  auto max_id = 0;
3574
3926
  for (const auto & kv : g_lang) {
3575
3927
  max_id = std::max(max_id, kv.second.first);
@@ -3838,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3838
4190
  return ctx->vocab.token_transcribe;
3839
4191
  }
3840
4192
 
4193
+ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
4194
+ if (ctx->state == nullptr) {
4195
+ return nullptr;
4196
+ }
4197
+ return new whisper_timings {
4198
+ .load_us = ctx->t_load_us,
4199
+ .t_start_us = ctx->t_start_us,
4200
+ .fail_p = ctx->state->n_fail_p,
4201
+ .fail_h = ctx->state->n_fail_h,
4202
+ .t_mel_us = ctx->state->t_mel_us,
4203
+ .n_sample = ctx->state->n_sample,
4204
+ .n_encode = ctx->state->n_encode,
4205
+ .n_decode = ctx->state->n_decode,
4206
+ .n_batchd = ctx->state->n_batchd,
4207
+ .n_prompt = ctx->state->n_prompt,
4208
+ .t_sample_us = ctx->state->t_sample_us,
4209
+ .t_encode_us = ctx->state->t_encode_us,
4210
+ .t_decode_us = ctx->state->t_decode_us,
4211
+ .t_batchd_us = ctx->state->t_batchd_us,
4212
+ .t_prompt_us = ctx->state->t_prompt_us,
4213
+ };
4214
+ }
4215
+
3841
4216
  void whisper_print_timings(struct whisper_context * ctx) {
3842
4217
  const int64_t t_end_us = wsp_ggml_time_us();
4218
+ const struct whisper_timings * timings = whisper_get_timings(ctx);
3843
4219
 
3844
4220
  WHISPER_LOG_INFO("\n");
3845
- WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
4221
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
3846
4222
  if (ctx->state != nullptr) {
3847
-
3848
4223
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3849
4224
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3850
4225
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3851
4226
  const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
3852
4227
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3853
4228
 
3854
- WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3855
- WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3856
- WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3857
- WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3858
- WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3859
- WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
3860
- WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4229
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
4230
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
4231
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
4232
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
4233
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
4234
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
4235
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
3861
4236
  }
3862
- WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4237
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
3863
4238
  }
3864
4239
 
3865
4240
  void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3913,10 +4288,10 @@ const char * whisper_print_system_info(void) {
3913
4288
  s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
3914
4289
  s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
3915
4290
  s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
3916
- s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cublas()) + " | ";
4291
+ s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
3917
4292
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3918
4293
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3919
-
4294
+ s += "CANN = " + std::to_string(wsp_ggml_cpu_has_cann()) ;
3920
4295
  return s.c_str();
3921
4296
  }
3922
4297
 
@@ -3926,7 +4301,7 @@ const char * whisper_print_system_info(void) {
3926
4301
 
3927
4302
  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
3928
4303
  // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
3929
- std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4304
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
3930
4305
  const char * src,
3931
4306
  whisper_partial_utf8 partial_start) {
3932
4307
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4340,7 +4715,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
4340
4715
 
4341
4716
  ////////////////////////////////////////////////////////////////////////////
4342
4717
 
4343
- struct whisper_context_params * whisper_context_default_params_by_ref() {
4718
+ struct whisper_context_params * whisper_context_default_params_by_ref(void) {
4344
4719
  struct whisper_context_params params = whisper_context_default_params();
4345
4720
 
4346
4721
  struct whisper_context_params* result = new whisper_context_params();
@@ -4381,12 +4756,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4381
4756
  /*.split_on_word =*/ false,
4382
4757
  /*.max_tokens =*/ 0,
4383
4758
 
4384
- /*.speed_up =*/ false,
4385
4759
  /*.debug_mode =*/ false,
4386
4760
  /*.audio_ctx =*/ 0,
4387
4761
 
4388
4762
  /*.tdrz_enable =*/ false,
4389
4763
 
4764
+ /* suppress_regex =*/ nullptr,
4765
+
4390
4766
  /*.initial_prompt =*/ nullptr,
4391
4767
  /*.prompt_tokens =*/ nullptr,
4392
4768
  /*.prompt_n_tokens =*/ 0,
@@ -4472,6 +4848,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
4472
4848
  return txt[0] == ' ';
4473
4849
  }
4474
4850
 
4851
+ static void whisper_exp_compute_token_level_timestamps_dtw(
4852
+ struct whisper_context * ctx,
4853
+ struct whisper_state * state,
4854
+ struct whisper_full_params params,
4855
+ int i_segment,
4856
+ size_t n_segments,
4857
+ int seek,
4858
+ int n_frames,
4859
+ int medfilt_width,
4860
+ int n_threads);
4861
+
4475
4862
  // wrap the last segment to max_len characters
4476
4863
  // returns the number of new segments
4477
4864
  static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
@@ -4619,6 +5006,17 @@ static void whisper_process_logits(
4619
5006
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
4620
5007
  }
4621
5008
 
5009
+ // suppress any tokens matching a regular expression
5010
+ // ref: https://github.com/openai/whisper/discussions/1041
5011
+ if (params.suppress_regex != nullptr) {
5012
+ std::regex re(params.suppress_regex);
5013
+ for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
5014
+ if (std::regex_match(token_id.first, re)) {
5015
+ logits[token_id.second] = -INFINITY;
5016
+ }
5017
+ }
5018
+ }
5019
+
4622
5020
  // suppress non-speech tokens
4623
5021
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4624
5022
  if (params.suppress_non_speech_tokens) {
@@ -4822,12 +5220,25 @@ static void whisper_process_logits(
4822
5220
  #endif
4823
5221
  }
4824
5222
 
5223
+ static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
5224
+ if (a.tokens.size() != b.tokens.size()) {
5225
+ return false;
5226
+ }
5227
+ // sequences are more likely to diverge at the end
5228
+ for (int i = a.tokens.size() - 1; i >= 0; i--) {
5229
+ if (a.tokens[i].id != b.tokens[i].id) {
5230
+ return false;
5231
+ }
5232
+ }
5233
+ return true;
5234
+ }
5235
+
4825
5236
  static whisper_token_data whisper_sample_token(
4826
5237
  whisper_context & ctx,
4827
5238
  const whisper_decoder & decoder,
4828
5239
  bool best) {
4829
5240
  whisper_token_data result = {
4830
- 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
5241
+ 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
4831
5242
  };
4832
5243
 
4833
5244
  const auto & vocab = ctx.vocab;
@@ -4945,7 +5356,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4945
5356
  const auto id = dist(decoder.rng);
4946
5357
  //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
4947
5358
 
4948
- result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
5359
+ result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
4949
5360
 
4950
5361
  if (result[i].id >= vocab.token_beg) {
4951
5362
  result[i].tid = result[i].id;
@@ -5018,15 +5429,9 @@ int whisper_full_with_state(
5018
5429
 
5019
5430
  if (n_samples > 0) {
5020
5431
  // compute log mel spectrogram
5021
- if (params.speed_up) {
5022
- // TODO: Replace PV with more advanced algorithm
5432
+ if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
5023
5433
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
5024
- return -1;
5025
- } else {
5026
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
5027
- WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
5028
- return -2;
5029
- }
5434
+ return -2;
5030
5435
  }
5031
5436
  }
5032
5437
 
@@ -5063,8 +5468,8 @@ int whisper_full_with_state(
5063
5468
  // if length of spectrogram is less than 1.0s (100 frames), then return
5064
5469
  // basically don't process anything that is less than 1.0s
5065
5470
  // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
5066
- if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
5067
- WHISPER_LOG_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
5471
+ if (seek_end < seek_start + 100) {
5472
+ WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
5068
5473
  return 0;
5069
5474
  }
5070
5475
 
@@ -5127,7 +5532,12 @@ int whisper_full_with_state(
5127
5532
  // initial prompt
5128
5533
  if (!params.prompt_tokens && params.initial_prompt) {
5129
5534
  prompt_tokens.resize(1024);
5130
- prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
5535
+ int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5536
+ if (n_needed < 0) {
5537
+ prompt_tokens.resize(-n_needed);
5538
+ n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5539
+ }
5540
+ prompt_tokens.resize(n_needed);
5131
5541
  params.prompt_tokens = prompt_tokens.data();
5132
5542
  params.prompt_n_tokens = prompt_tokens.size();
5133
5543
  }
@@ -5163,11 +5573,11 @@ int whisper_full_with_state(
5163
5573
  }
5164
5574
  }
5165
5575
 
5166
- // distilled models require the "no_timestamps" token
5576
+ // first release distilled models require the "no_timestamps" token
5167
5577
  {
5168
- const bool is_distil = ctx->model.hparams.n_text_layer == 2;
5578
+ const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
5169
5579
  if (is_distil && !params.no_timestamps) {
5170
- WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
5580
+ WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
5171
5581
  params.no_timestamps = true;
5172
5582
  }
5173
5583
  }
@@ -5303,13 +5713,34 @@ int whisper_full_with_state(
5303
5713
  }
5304
5714
  WHISPER_LOG_DEBUG("\n\n");
5305
5715
 
5716
+ // recreate the KV cache if the number of decoders has changed
5717
+ if (state->kv_self_n_dec < n_decoders_cur) {
5718
+ WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
5719
+
5720
+ whisper_kv_cache_free(state->kv_self);
5721
+
5722
+ // overallocate to workaround KV cache fragmentation issues
5723
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
5724
+
5725
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
5726
+ ctx->model.hparams.n_text_state,
5727
+ ctx->model.hparams.n_text_layer,
5728
+ WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
5729
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
5730
+ whisper_free_state(state);
5731
+ return -7;
5732
+ }
5733
+
5734
+ state->kv_self_n_dec = n_decoders_cur;
5735
+ }
5736
+
5306
5737
  whisper_kv_cache_clear(state->kv_self);
5307
5738
 
5308
5739
  whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
5309
5740
 
5310
- if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
5741
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
5311
5742
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5312
- return -7;
5743
+ return -8;
5313
5744
  }
5314
5745
 
5315
5746
  {
@@ -5420,7 +5851,10 @@ int whisper_full_with_state(
5420
5851
  beam_candidates.begin(),
5421
5852
  beam_candidates.end(),
5422
5853
  [](const beam_candidate & a, const beam_candidate & b) {
5423
- return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
5854
+ if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) {
5855
+ return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
5856
+ }
5857
+ return a.decoder_idx < b.decoder_idx;
5424
5858
  });
5425
5859
 
5426
5860
  uint32_t cur_c = 0;
@@ -5438,7 +5872,7 @@ int whisper_full_with_state(
5438
5872
 
5439
5873
  auto & cur = beam_candidates[cur_c++];
5440
5874
 
5441
- while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
5875
+ while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
5442
5876
  ++cur_c;
5443
5877
  }
5444
5878
 
@@ -5604,9 +6038,9 @@ int whisper_full_with_state(
5604
6038
 
5605
6039
  assert(batch.n_tokens > 0);
5606
6040
 
5607
- if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
6041
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
5608
6042
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5609
- return -8;
6043
+ return -9;
5610
6044
  }
5611
6045
 
5612
6046
  const int64_t t_start_sample_us = wsp_ggml_time_us();
@@ -5727,6 +6161,9 @@ int whisper_full_with_state(
5727
6161
 
5728
6162
  const auto & tokens_cur = best_decoder.sequence.tokens;
5729
6163
 
6164
+ // [EXPERIMENTAL] Token-level timestamps with DTW
6165
+ const auto n_segments_before = state->result_all.size();
6166
+
5730
6167
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
5731
6168
 
5732
6169
  // update prompt_past
@@ -5764,8 +6201,8 @@ int whisper_full_with_state(
5764
6201
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
5765
6202
 
5766
6203
  if (!text.empty()) {
5767
- const auto tt0 = params.speed_up ? 2*t0 : t0;
5768
- const auto tt1 = params.speed_up ? 2*t1 : t1;
6204
+ const auto tt0 = t0;
6205
+ const auto tt1 = t1;
5769
6206
 
5770
6207
  if (params.print_realtime) {
5771
6208
  if (params.print_timestamps) {
@@ -5793,7 +6230,7 @@ int whisper_full_with_state(
5793
6230
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
5794
6231
  }
5795
6232
  }
5796
- if (params.new_segment_callback) {
6233
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
5797
6234
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
5798
6235
  }
5799
6236
  }
@@ -5811,8 +6248,8 @@ int whisper_full_with_state(
5811
6248
  if (!text.empty()) {
5812
6249
  const auto t1 = seek + seek_delta;
5813
6250
 
5814
- const auto tt0 = params.speed_up ? 2*t0 : t0;
5815
- const auto tt1 = params.speed_up ? 2*t1 : t1;
6251
+ const auto tt0 = t0;
6252
+ const auto tt1 = t1;
5816
6253
 
5817
6254
  if (params.print_realtime) {
5818
6255
  if (params.print_timestamps) {
@@ -5838,12 +6275,28 @@ int whisper_full_with_state(
5838
6275
  n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
5839
6276
  }
5840
6277
  }
5841
- if (params.new_segment_callback) {
6278
+ if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
5842
6279
  params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
5843
6280
  }
5844
6281
  }
5845
6282
  }
5846
6283
 
6284
+ // FIXME: will timestamp offsets be correct?
6285
+ // [EXPERIMENTAL] Token-level timestamps with DTW
6286
+ {
6287
+ const int n_segments = state->result_all.size() - n_segments_before;
6288
+ if (ctx->params.dtw_token_timestamps && n_segments) {
6289
+ const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
6290
+ whisper_exp_compute_token_level_timestamps_dtw(
6291
+ ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
6292
+ if (params.new_segment_callback) {
6293
+ for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
6294
+ params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
6295
+ }
6296
+ }
6297
+ }
6298
+ }
6299
+
5847
6300
  // update audio window
5848
6301
  seek += seek_delta;
5849
6302
 
@@ -6603,7 +7056,7 @@ static void whisper_exp_compute_token_level_timestamps(
6603
7056
  k++;
6604
7057
  }
6605
7058
  tokens[j].t1 = sample_to_timestamp(k);
6606
- if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
7059
+ if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
6607
7060
  tokens[j].t1 = tokens[j + 1].t0;
6608
7061
  } else {
6609
7062
  s1 = k;
@@ -6646,6 +7099,322 @@ static void whisper_exp_compute_token_level_timestamps(
6646
7099
  //}
6647
7100
  }
6648
7101
 
7102
+ //
7103
+ // token level timestamps - dtw version
7104
+ //
7105
+
7106
+ // n_text_layer -> total text layers on model
7107
+ // n_head -> total heads per text layer on model
7108
+ static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
7109
+ std::vector<uint32_t> ret;
7110
+ if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
7111
+ return ret;
7112
+ } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
7113
+ if (il >= n_text_layer - cparams.dtw_n_top) {
7114
+ for (int32_t i = 0; i < n_head; ++i) {
7115
+ ret.push_back(i);
7116
+ }
7117
+ }
7118
+ } else {
7119
+ const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
7120
+ for (size_t i = 0; i < aheads.n_heads; ++i) {
7121
+ if (aheads.heads[i].n_text_layer == il) {
7122
+ ret.push_back(aheads.heads[i].n_head);
7123
+ }
7124
+ }
7125
+ }
7126
+ return ret;
7127
+ }
7128
+
7129
+ // dtw + backtrace to return found path
7130
+ // based on
7131
+ // https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
7132
+ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tensor * x) {
7133
+ WHISPER_ASSERT(wsp_ggml_n_dims(x) == 2);
7134
+
7135
+ int64_t N = x->ne[0];
7136
+ int64_t M = x->ne[1];
7137
+ struct wsp_ggml_tensor * cost = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
7138
+ struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
7139
+
7140
+ cost = wsp_ggml_set_f32(cost, INFINITY);
7141
+ trace = wsp_ggml_set_f32(trace, -1);
7142
+ wsp_ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
7143
+
7144
+ // dtw
7145
+ // supposedly can be optmized by computing diagonals in parallel ?
7146
+ // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
7147
+ for (int64_t j = 1; j < M + 1; ++j) {
7148
+ for (int64_t i = 1; i < N + 1; ++i) {
7149
+ float c0 = wsp_ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
7150
+ float c1 = wsp_ggml_get_f32_nd(cost, i - 1, j, 0, 0);
7151
+ float c2 = wsp_ggml_get_f32_nd(cost, i, j - 1, 0, 0);
7152
+
7153
+ float c;
7154
+ int32_t t;
7155
+ if (c0 < c1 && c0 < c2) {
7156
+ c = c0;
7157
+ t = 0;
7158
+ } else if (c1 < c0 && c1 < c2) {
7159
+ c = c1;
7160
+ t = 1;
7161
+ } else {
7162
+ c = c2;
7163
+ t = 2;
7164
+ }
7165
+
7166
+ c = wsp_ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
7167
+ wsp_ggml_set_f32_nd(cost, i, j, 0, 0, c);
7168
+ wsp_ggml_set_i32_nd(trace, i, j, 0, 0, t);
7169
+ }
7170
+ }
7171
+
7172
+ // Backtrace
7173
+ const int64_t BT_MAX_ROWS = N + M - 1;
7174
+ struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
7175
+ // trace[0, :] = 2;
7176
+ for (int64_t i = 0; i < M + 1; ++i)
7177
+ wsp_ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
7178
+ //trace[:, 0] = 1;
7179
+ for (int64_t i = 0; i < N + 1; ++i)
7180
+ wsp_ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
7181
+ int bt_row_idx = BT_MAX_ROWS - 1;
7182
+ int64_t i = N;
7183
+ int64_t j = M;
7184
+ while (i > 0 || j > 0) {
7185
+ wsp_ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
7186
+ wsp_ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
7187
+ --bt_row_idx;
7188
+
7189
+ int32_t t = wsp_ggml_get_i32_nd(trace, i, j, 0, 0);
7190
+ if (t == 0) {
7191
+ --i;
7192
+ --j;
7193
+ } else if (t == 1) {
7194
+ --i;
7195
+ } else if (t == 2) {
7196
+ --j;
7197
+ } else {
7198
+ WHISPER_ASSERT(0);
7199
+ }
7200
+ }
7201
+
7202
+ // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
7203
+ // Clip + transpose
7204
+ // This might not be entirely necessary for our case, but leaving it for now so output matrix
7205
+ // is identical to dtw on openAI timing.py
7206
+ const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
7207
+ wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
7208
+ for (int64_t i = 0; i < 2; ++i) {
7209
+ for (int64_t j = 0; j < result_n_cols; ++j) {
7210
+ int32_t v = wsp_ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
7211
+ wsp_ggml_set_i32_nd(r, i, j, 0, 0, v);
7212
+ }
7213
+ }
7214
+
7215
+ return r;
7216
+ }
7217
+
7218
+ struct median_filter_user_data {
7219
+ int filter_width;
7220
+ };
7221
+
7222
+ static void median_filter(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
7223
+ if (ith != 0) {
7224
+ return;
7225
+ }
7226
+ int filter_width = ((median_filter_user_data *) userdata)->filter_width;
7227
+ WHISPER_ASSERT(filter_width < a->ne[2]);
7228
+ WHISPER_ASSERT(filter_width % 2);
7229
+ WHISPER_ASSERT(wsp_ggml_n_dims(a) == 3);
7230
+ WHISPER_ASSERT(a->type == WSP_GGML_TYPE_F32);
7231
+
7232
+ std::vector<float> filter;
7233
+ filter.reserve(filter_width);
7234
+ for (int64_t i = 0; i < a->ne[0]; ++i) {
7235
+ for (int64_t j = 0; j < a->ne[1]; ++j) {
7236
+ for (int64_t k = 0; k < a->ne[2]; ++k) {
7237
+ for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
7238
+ // "reflect" padding
7239
+ int64_t idx = k + off;
7240
+ if (idx < 0) {
7241
+ idx = -idx;
7242
+ } else if (idx >= a->ne[2]) {
7243
+ idx = 2*(a->ne[2] - 1) - idx;
7244
+ }
7245
+
7246
+ filter.push_back(wsp_ggml_get_f32_nd(a, i, j, idx, 0));
7247
+ }
7248
+ std::sort(filter.begin(), filter.end());
7249
+ const float v = filter[filter.size()/2];
7250
+ wsp_ggml_set_f32_nd(dst, i, j, k, 0, v);
7251
+ filter.clear();
7252
+ }
7253
+ }
7254
+ }
7255
+ }
7256
+
7257
+ static void whisper_exp_compute_token_level_timestamps_dtw(
7258
+ struct whisper_context * ctx,
7259
+ struct whisper_state * state,
7260
+ struct whisper_full_params params,
7261
+ int i_segment,
7262
+ size_t n_segments,
7263
+ int seek,
7264
+ int n_frames,
7265
+ int medfilt_width,
7266
+ int n_threads)
7267
+ {
7268
+ const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
7269
+ WHISPER_ASSERT(medfilt_width % 2);
7270
+ WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
7271
+ WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
7272
+
7273
+ // FIXME: Allocating mem everytime we call this func
7274
+ // Our ggml buffer should be pre-allocated somewhere during init and reused
7275
+ // when we call this function
7276
+ struct wsp_ggml_init_params gparams = {
7277
+ /*.mem_size =*/ ctx->params.dtw_mem_size,
7278
+ /*.mem_buffer =*/ NULL,
7279
+ /*.no_alloc =*/ false,
7280
+ };
7281
+ struct wsp_ggml_context * gctx = wsp_ggml_init(gparams);
7282
+
7283
+ // Build token sequence that will be passed to decoder
7284
+ // sot + [lang] + text result + eot
7285
+ std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
7286
+ if (whisper_is_multilingual(ctx)) {
7287
+ const int lang_id = whisper_lang_id(params.language);
7288
+ state->lang_id = lang_id;
7289
+ tokens.push_back(whisper_token_lang(ctx, lang_id));
7290
+ }
7291
+ const size_t sot_sequence_length = tokens.size();
7292
+ tokens.push_back(whisper_token_not(ctx));
7293
+ for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
7294
+ auto & segment = state->result_all[i];
7295
+ for (auto &t: segment.tokens) {
7296
+ // Only text tokens
7297
+ if (t.id < whisper_token_eot(ctx)) {
7298
+ tokens.push_back(t.id);
7299
+ }
7300
+ }
7301
+ }
7302
+ tokens.push_back(whisper_token_eot(ctx));
7303
+
7304
+ // Get result tokens, pass then along to decoder to get cross attention QKs
7305
+ // used in timestamping
7306
+ // Decoder already returns only alignment head QKs, already concatenated in
7307
+ // one tensor.
7308
+ whisper_kv_cache_clear(state->kv_self);
7309
+ whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
7310
+ whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
7311
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
7312
+ WHISPER_LOG_INFO("DECODER FAILED\n");
7313
+ WHISPER_ASSERT(0);
7314
+ }
7315
+ WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
7316
+
7317
+ const auto n_audio_tokens = n_frames/2;
7318
+ WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
7319
+ WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
7320
+ const auto n_tokens = state->aheads_cross_QKs->ne[0];
7321
+ const auto n_heads = state->aheads_cross_QKs->ne[2];
7322
+
7323
+ // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
7324
+ // tokens (i.e. discarding rows at the end of tensor)
7325
+ // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
7326
+ // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
7327
+ WHISPER_ASSERT(state->aheads_cross_QKs->type == WSP_GGML_TYPE_F32);
7328
+ WHISPER_ASSERT(wsp_ggml_is_contiguous(state->aheads_cross_QKs));
7329
+ wsp_ggml_tensor * w = wsp_ggml_new_tensor_3d(gctx, WSP_GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
7330
+ auto & data = state->aheads_cross_QKs_data;
7331
+ data.resize(n_tokens * n_audio_ctx * n_heads);
7332
+ wsp_ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
7333
+ for (int k = 0; k < n_heads; ++k) {
7334
+ for (int j = 0; j < n_audio_tokens; ++j) {
7335
+ memcpy(
7336
+ (char *) w->data + j * w->nb[1] + k * w->nb[2],
7337
+ data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
7338
+ n_tokens * sizeof(float)
7339
+ );
7340
+ }
7341
+ }
7342
+
7343
+ // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
7344
+ // we already permuted N_TOKENS dimension to columns on last loop, becase wsp_ggml_norm
7345
+ // operates over columns. Afterwards, permute to a shape that facilitates mean
7346
+ // operation (after median filter)
7347
+ // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
7348
+ // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7349
+ w = wsp_ggml_norm(gctx, w, 1e-9f);
7350
+ w = wsp_ggml_permute(gctx, wsp_ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
7351
+
7352
+ // Pass median filter - this is done over AUDIO_TOKENS dimension.
7353
+ // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7354
+ // OUT: Same dims
7355
+ median_filter_user_data mf_user_data = {medfilt_width};
7356
+ w = wsp_ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
7357
+
7358
+ // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
7359
+ // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7360
+ // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
7361
+ w = wsp_ggml_mean(gctx, w);
7362
+ w = wsp_ggml_scale(gctx, w, -1.0);
7363
+ w = wsp_ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
7364
+
7365
+ // Remove SOT sequence and EOT
7366
+ // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
7367
+ w = wsp_ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
7368
+
7369
+ // Compute
7370
+ struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
7371
+ wsp_ggml_build_forward_expand(gf, w);
7372
+ wsp_ggml_graph_compute_with_ctx(gctx, gf, n_threads);
7373
+
7374
+ wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
7375
+
7376
+ // Place timestamps on segments
7377
+ int32_t last_v = 0;
7378
+ auto seg_i = state->result_all.begin() + i_segment;
7379
+ auto tok_i = seg_i->tokens.begin();
7380
+ for (int i = 0; i < alignment->ne[1]; ++i) {
7381
+ int32_t v = wsp_ggml_get_i32_nd(alignment, 0, i, 0, 0);
7382
+ if (v != last_v) {
7383
+ int32_t time_index = wsp_ggml_get_i32_nd(alignment, 1, i, 0, 0);
7384
+ int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
7385
+ last_v = v;
7386
+
7387
+ // Skip non-text tokens
7388
+ while (!(tok_i->id < whisper_token_eot(ctx))) {
7389
+ ++tok_i;
7390
+ if (tok_i == seg_i->tokens.end()) {
7391
+ ++seg_i;
7392
+ tok_i = seg_i->tokens.begin();
7393
+ }
7394
+ }
7395
+
7396
+ tok_i->t_dtw = timestamp;
7397
+ ++tok_i;
7398
+ if (tok_i == seg_i->tokens.end()) {
7399
+ ++seg_i;
7400
+ tok_i = seg_i->tokens.begin();
7401
+ }
7402
+ }
7403
+ }
7404
+
7405
+ // Print DTW timestamps
7406
+ /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
7407
+ auto & segment = state->result_all[i];
7408
+ for (auto &t: segment.tokens) {
7409
+ const char * tok = whisper_token_to_str(ctx, t.id);
7410
+ fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
7411
+ }
7412
+ fprintf(stderr, "\n");
7413
+ }*/
7414
+
7415
+ wsp_ggml_free(gctx);
7416
+ }
7417
+
6649
7418
  void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
6650
7419
  g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
6651
7420
  g_state.log_callback_user_data = user_data;