whisper.rn 0.4.0-rc.8 → 0.4.0-rc.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -1
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
- package/android/src/main/jni.cpp +29 -1
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
- package/cpp/ggml-aarch64.c +3209 -0
- package/cpp/ggml-aarch64.h +39 -0
- package/cpp/ggml-alloc.c +725 -517
- package/cpp/ggml-alloc.h +47 -65
- package/cpp/ggml-backend-impl.h +166 -55
- package/cpp/ggml-backend.cpp +2635 -0
- package/cpp/ggml-backend.h +202 -85
- package/cpp/ggml-common.h +1853 -0
- package/cpp/ggml-cpu-impl.h +614 -0
- package/cpp/ggml-impl.h +143 -180
- package/cpp/ggml-metal.h +13 -11
- package/cpp/ggml-metal.m +2955 -1632
- package/cpp/ggml-quants.c +9824 -3263
- package/cpp/ggml-quants.h +133 -248
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +8482 -5142
- package/cpp/ggml.h +633 -349
- package/cpp/rn-whisper.cpp +91 -0
- package/cpp/rn-whisper.h +2 -0
- package/cpp/whisper.cpp +1427 -658
- package/cpp/whisper.h +84 -28
- package/ios/RNWhisper.mm +124 -37
- package/ios/RNWhisperAudioUtils.h +1 -0
- package/ios/RNWhisperAudioUtils.m +20 -13
- package/ios/RNWhisperContext.h +3 -2
- package/ios/RNWhisperContext.mm +39 -7
- package/jest/mock.js +9 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +48 -19
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +48 -19
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +6 -3
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +25 -3
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/NativeRNWhisper.ts +12 -3
- package/src/index.ts +63 -24
- package/src/version.json +1 -1
- package/whisper-rn.podspec +9 -2
- package/cpp/ggml-backend.c +0 -1718
- package/cpp/ggml-metal-whisper.metal +0 -5820
package/cpp/whisper.cpp
CHANGED
|
@@ -8,14 +8,30 @@
|
|
|
8
8
|
#include "ggml-metal.h"
|
|
9
9
|
#endif
|
|
10
10
|
|
|
11
|
-
#ifdef
|
|
11
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
12
12
|
#include "ggml-cuda.h"
|
|
13
13
|
#endif
|
|
14
14
|
|
|
15
|
+
#ifdef WSP_GGML_USE_SYCL
|
|
16
|
+
#include "ggml-sycl.h"
|
|
17
|
+
#endif
|
|
18
|
+
|
|
19
|
+
#ifdef WSP_GGML_USE_VULKAN
|
|
20
|
+
#include "ggml-vulkan.h"
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
24
|
+
#include "ggml-blas.h"
|
|
25
|
+
#endif
|
|
26
|
+
|
|
15
27
|
#ifdef WHISPER_USE_OPENVINO
|
|
16
28
|
#include "openvino/whisper-openvino-encoder.h"
|
|
17
29
|
#endif
|
|
18
30
|
|
|
31
|
+
#ifdef WSP_GGML_USE_CANN
|
|
32
|
+
#include "ggml-cann.h"
|
|
33
|
+
#endif
|
|
34
|
+
|
|
19
35
|
#include "ggml.h"
|
|
20
36
|
#include "ggml-alloc.h"
|
|
21
37
|
#include "ggml-backend.h"
|
|
@@ -37,6 +53,7 @@
|
|
|
37
53
|
#include <regex>
|
|
38
54
|
#include <random>
|
|
39
55
|
#include <functional>
|
|
56
|
+
#include <codecvt>
|
|
40
57
|
|
|
41
58
|
#if defined(_MSC_VER)
|
|
42
59
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
@@ -143,8 +160,6 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
|
|
|
143
160
|
} \
|
|
144
161
|
} while (0)
|
|
145
162
|
|
|
146
|
-
//#define WHISPER_USE_FLASH_ATTN
|
|
147
|
-
//#define WHISPER_USE_FLASH_FF
|
|
148
163
|
#define WHISPER_MAX_DECODERS 8
|
|
149
164
|
#define WHISPER_MAX_NODES 4096
|
|
150
165
|
|
|
@@ -156,11 +171,11 @@ static bool wsp_ggml_graph_compute_helper(
|
|
|
156
171
|
struct wsp_ggml_cgraph * graph,
|
|
157
172
|
std::vector<uint8_t> & buf,
|
|
158
173
|
int n_threads,
|
|
159
|
-
|
|
174
|
+
wsp_ggml_abort_callback abort_callback,
|
|
160
175
|
void * abort_callback_data) {
|
|
161
|
-
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
|
|
176
|
+
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads, nullptr);
|
|
162
177
|
|
|
163
|
-
plan.abort_callback
|
|
178
|
+
plan.abort_callback = abort_callback;
|
|
164
179
|
plan.abort_callback_data = abort_callback_data;
|
|
165
180
|
|
|
166
181
|
if (plan.work_size > 0) {
|
|
@@ -172,18 +187,25 @@ static bool wsp_ggml_graph_compute_helper(
|
|
|
172
187
|
}
|
|
173
188
|
|
|
174
189
|
static bool wsp_ggml_graph_compute_helper(
|
|
175
|
-
|
|
190
|
+
wsp_ggml_backend_sched_t sched,
|
|
176
191
|
struct wsp_ggml_cgraph * graph,
|
|
177
192
|
int n_threads) {
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
193
|
+
|
|
194
|
+
for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
|
|
195
|
+
wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
|
|
196
|
+
if (wsp_ggml_backend_is_cpu(backend)) {
|
|
197
|
+
wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
|
|
198
|
+
}
|
|
199
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
200
|
+
if (wsp_ggml_backend_is_blas(backend)) {
|
|
201
|
+
wsp_ggml_backend_blas_set_n_threads(backend, n_threads);
|
|
202
|
+
}
|
|
185
203
|
#endif
|
|
186
|
-
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
|
|
207
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
208
|
+
return t;
|
|
187
209
|
}
|
|
188
210
|
|
|
189
211
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
@@ -347,6 +369,37 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
347
369
|
{ "yue", { 99, "cantonese", } },
|
|
348
370
|
};
|
|
349
371
|
|
|
372
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
373
|
+
static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
|
|
374
|
+
static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
|
|
375
|
+
static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
|
|
376
|
+
static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
|
|
377
|
+
static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
|
|
378
|
+
static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
|
|
379
|
+
static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
|
|
380
|
+
static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
|
|
381
|
+
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
|
|
382
|
+
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
|
|
383
|
+
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
|
|
384
|
+
static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
|
|
385
|
+
|
|
386
|
+
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
|
387
|
+
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
|
|
388
|
+
{ WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
|
|
389
|
+
{ WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
|
|
390
|
+
{ WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
|
|
391
|
+
{ WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
|
|
392
|
+
{ WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
|
|
393
|
+
{ WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
|
|
394
|
+
{ WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
|
|
395
|
+
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
|
|
396
|
+
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
|
|
397
|
+
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
|
|
398
|
+
{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
|
|
399
|
+
};
|
|
400
|
+
|
|
401
|
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
|
402
|
+
|
|
350
403
|
struct whisper_mel {
|
|
351
404
|
int n_len;
|
|
352
405
|
int n_len_org;
|
|
@@ -409,7 +462,7 @@ struct whisper_batch {
|
|
|
409
462
|
|
|
410
463
|
whisper_token * token;
|
|
411
464
|
whisper_pos * pos;
|
|
412
|
-
int32_t * n_seq_id;
|
|
465
|
+
int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
|
|
413
466
|
whisper_seq_id ** seq_id; // null terminated
|
|
414
467
|
int8_t * logits;
|
|
415
468
|
};
|
|
@@ -469,54 +522,42 @@ struct whisper_pair {
|
|
|
469
522
|
whisper_pair() : first(A()), second(B()) {}
|
|
470
523
|
};
|
|
471
524
|
|
|
472
|
-
//
|
|
473
|
-
struct
|
|
474
|
-
|
|
525
|
+
// wsp_ggml_backend_sched wrapper for whisper usage
|
|
526
|
+
struct whisper_sched {
|
|
527
|
+
wsp_ggml_backend_sched_t sched = nullptr;
|
|
475
528
|
|
|
476
529
|
std::vector<uint8_t> meta;
|
|
477
|
-
|
|
478
|
-
wsp_ggml_backend_buffer_t buffer;
|
|
479
530
|
};
|
|
480
531
|
|
|
481
|
-
static size_t
|
|
482
|
-
|
|
532
|
+
static size_t whisper_sched_size(struct whisper_sched & allocr) {
|
|
533
|
+
size_t size = allocr.meta.size();
|
|
534
|
+
for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
|
|
535
|
+
wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(allocr.sched, i);
|
|
536
|
+
size += wsp_ggml_backend_sched_get_buffer_size(allocr.sched, backend);
|
|
537
|
+
}
|
|
538
|
+
return size;
|
|
483
539
|
}
|
|
484
540
|
|
|
485
541
|
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
486
|
-
static
|
|
487
|
-
auto &
|
|
542
|
+
static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<wsp_ggml_backend_t> backends, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
543
|
+
auto & sched = allocr.sched;
|
|
488
544
|
auto & meta = allocr.meta;
|
|
489
545
|
|
|
490
|
-
|
|
546
|
+
sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
|
491
547
|
|
|
492
548
|
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
493
549
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
return;
|
|
550
|
+
// since there are dependencies between the different graphs,
|
|
551
|
+
// we need to allocate them instead of only reserving to get the correct compute buffer size
|
|
552
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, get_graph())) {
|
|
553
|
+
// failed to allocate the compute buffer
|
|
554
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
|
555
|
+
return false;
|
|
501
556
|
}
|
|
502
557
|
|
|
503
|
-
|
|
504
|
-
auto & buffer = allocr.buffer;
|
|
505
|
-
|
|
506
|
-
size_t size = wsp_ggml_allocr_max_size(alloc);
|
|
507
|
-
|
|
508
|
-
wsp_ggml_allocr_free(alloc);
|
|
558
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
509
559
|
|
|
510
|
-
|
|
511
|
-
alloc = wsp_ggml_allocr_new_from_buffer(buffer);
|
|
512
|
-
}
|
|
513
|
-
|
|
514
|
-
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
515
|
-
if (allocr.alloc) {
|
|
516
|
-
wsp_ggml_allocr_free(allocr.alloc);
|
|
517
|
-
wsp_ggml_backend_buffer_free(allocr.buffer);
|
|
518
|
-
allocr.alloc = nullptr;
|
|
519
|
-
}
|
|
560
|
+
return true;
|
|
520
561
|
}
|
|
521
562
|
|
|
522
563
|
// medium
|
|
@@ -658,9 +699,9 @@ struct whisper_kv_cache {
|
|
|
658
699
|
struct wsp_ggml_tensor * k;
|
|
659
700
|
struct wsp_ggml_tensor * v;
|
|
660
701
|
|
|
661
|
-
|
|
702
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
662
703
|
|
|
663
|
-
|
|
704
|
+
std::vector<uint8_t> ctx_buf;
|
|
664
705
|
};
|
|
665
706
|
|
|
666
707
|
struct whisper_model {
|
|
@@ -698,10 +739,10 @@ struct whisper_model {
|
|
|
698
739
|
std::vector<whisper_layer_decoder> layers_decoder;
|
|
699
740
|
|
|
700
741
|
// ggml context that contains all the meta information about the model tensors
|
|
701
|
-
struct wsp_ggml_context * ctx;
|
|
742
|
+
struct wsp_ggml_context * ctx = nullptr;
|
|
702
743
|
|
|
703
744
|
// the model backend data is read-only and can be shared between processors
|
|
704
|
-
|
|
745
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
705
746
|
|
|
706
747
|
// tensors
|
|
707
748
|
int n_loaded;
|
|
@@ -766,6 +807,13 @@ struct whisper_decoder {
|
|
|
766
807
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
767
808
|
};
|
|
768
809
|
|
|
810
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
811
|
+
struct whisper_aheads_masks {
|
|
812
|
+
std::vector<struct wsp_ggml_tensor *> m; // One mask per text layer.
|
|
813
|
+
struct wsp_ggml_context * ctx = nullptr;
|
|
814
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
815
|
+
};
|
|
816
|
+
|
|
769
817
|
struct whisper_state {
|
|
770
818
|
int64_t t_sample_us = 0;
|
|
771
819
|
int64_t t_encode_us = 0;
|
|
@@ -782,6 +830,9 @@ struct whisper_state {
|
|
|
782
830
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
783
831
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
784
832
|
|
|
833
|
+
// number of decoders for which we have constructed the KV cache
|
|
834
|
+
int32_t kv_self_n_dec = 0;
|
|
835
|
+
|
|
785
836
|
// unified self-attention KV cache for all decoders
|
|
786
837
|
whisper_kv_cache kv_self;
|
|
787
838
|
|
|
@@ -789,21 +840,22 @@ struct whisper_state {
|
|
|
789
840
|
// shared between all decoders
|
|
790
841
|
whisper_kv_cache kv_cross;
|
|
791
842
|
|
|
843
|
+
// padded buffer for flash-attention
|
|
844
|
+
whisper_kv_cache kv_pad;
|
|
845
|
+
|
|
792
846
|
whisper_mel mel;
|
|
793
847
|
|
|
794
848
|
whisper_batch batch;
|
|
795
849
|
|
|
796
850
|
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
|
797
851
|
|
|
798
|
-
wsp_ggml_backend_t
|
|
852
|
+
std::vector<wsp_ggml_backend_t> backends;
|
|
799
853
|
|
|
800
|
-
// ggml-alloc:
|
|
801
854
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
whisper_allocr alloc_decode;
|
|
855
|
+
whisper_sched sched_conv;
|
|
856
|
+
whisper_sched sched_encode;
|
|
857
|
+
whisper_sched sched_cross;
|
|
858
|
+
whisper_sched sched_decode;
|
|
807
859
|
|
|
808
860
|
// result of the encoder
|
|
809
861
|
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
@@ -839,6 +891,11 @@ struct whisper_state {
|
|
|
839
891
|
|
|
840
892
|
std::vector<float> energy; // PCM signal energy
|
|
841
893
|
|
|
894
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
895
|
+
whisper_aheads_masks aheads_masks;
|
|
896
|
+
wsp_ggml_tensor * aheads_cross_QKs = nullptr;
|
|
897
|
+
std::vector<float> aheads_cross_QKs_data;
|
|
898
|
+
|
|
842
899
|
// [EXPERIMENTAL] speed-up techniques
|
|
843
900
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
|
844
901
|
};
|
|
@@ -857,8 +914,6 @@ struct whisper_context {
|
|
|
857
914
|
|
|
858
915
|
whisper_state * state = nullptr;
|
|
859
916
|
|
|
860
|
-
wsp_ggml_backend_t backend = nullptr;
|
|
861
|
-
|
|
862
917
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
863
918
|
};
|
|
864
919
|
|
|
@@ -876,21 +931,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
876
931
|
BYTESWAP_VALUE(dest);
|
|
877
932
|
}
|
|
878
933
|
|
|
879
|
-
static bool
|
|
880
|
-
const struct whisper_hparams & hparams,
|
|
934
|
+
static bool whisper_kv_cache_init(
|
|
881
935
|
struct whisper_kv_cache & cache,
|
|
882
936
|
wsp_ggml_backend_t backend,
|
|
883
937
|
wsp_ggml_type wtype,
|
|
938
|
+
int64_t n_text_state,
|
|
939
|
+
int64_t n_text_layer,
|
|
884
940
|
int n_ctx) {
|
|
885
|
-
const int64_t n_text_state = hparams.n_text_state;
|
|
886
|
-
const int64_t n_text_layer = hparams.n_text_layer;
|
|
887
|
-
|
|
888
941
|
const int64_t n_mem = n_text_layer*n_ctx;
|
|
889
942
|
const int64_t n_elements = n_text_state*n_mem;
|
|
890
943
|
|
|
944
|
+
cache.ctx_buf.resize(2*wsp_ggml_tensor_overhead());
|
|
945
|
+
|
|
891
946
|
struct wsp_ggml_init_params params = {
|
|
892
|
-
/*.mem_size =*/
|
|
893
|
-
/*.mem_buffer =*/
|
|
947
|
+
/*.mem_size =*/ cache.ctx_buf.size(),
|
|
948
|
+
/*.mem_buffer =*/ cache.ctx_buf.data(),
|
|
894
949
|
/*.no_alloc =*/ true,
|
|
895
950
|
};
|
|
896
951
|
|
|
@@ -900,39 +955,31 @@ static bool kv_cache_init(
|
|
|
900
955
|
cache.cells.clear();
|
|
901
956
|
cache.cells.resize(n_ctx);
|
|
902
957
|
|
|
903
|
-
|
|
958
|
+
struct wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
904
959
|
|
|
905
|
-
if (!
|
|
906
|
-
WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
|
|
960
|
+
if (!ctx) {
|
|
961
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
|
|
907
962
|
return false;
|
|
908
963
|
}
|
|
909
964
|
|
|
910
|
-
cache.k = wsp_ggml_new_tensor_1d(
|
|
911
|
-
cache.v = wsp_ggml_new_tensor_1d(
|
|
965
|
+
cache.k = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
|
|
966
|
+
cache.v = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
|
|
912
967
|
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
{
|
|
919
|
-
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
|
|
968
|
+
cache.buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, backend);
|
|
969
|
+
if (!cache.buffer) {
|
|
970
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
|
|
971
|
+
return false;
|
|
972
|
+
}
|
|
920
973
|
|
|
921
|
-
|
|
922
|
-
wsp_ggml_allocr_alloc(alloc, cache.v);
|
|
974
|
+
wsp_ggml_backend_buffer_clear(cache.buffer, 0);
|
|
923
975
|
|
|
924
|
-
|
|
925
|
-
}
|
|
976
|
+
wsp_ggml_free(ctx);
|
|
926
977
|
|
|
927
978
|
return true;
|
|
928
979
|
}
|
|
929
980
|
|
|
930
|
-
static void
|
|
931
|
-
|
|
932
|
-
wsp_ggml_free(cache.ctx);
|
|
933
|
-
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
934
|
-
cache.ctx = nullptr;
|
|
935
|
-
}
|
|
981
|
+
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
|
|
982
|
+
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
936
983
|
}
|
|
937
984
|
|
|
938
985
|
static bool whisper_kv_cache_find_slot(
|
|
@@ -1003,6 +1050,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
|
1003
1050
|
cache.cells[i].seq_id.clear();
|
|
1004
1051
|
}
|
|
1005
1052
|
cache.head = 0;
|
|
1053
|
+
|
|
1054
|
+
wsp_ggml_backend_buffer_clear(cache.buffer, 0);
|
|
1006
1055
|
}
|
|
1007
1056
|
|
|
1008
1057
|
static void whisper_kv_cache_seq_rm(
|
|
@@ -1053,15 +1102,167 @@ static void whisper_kv_cache_seq_cp(
|
|
|
1053
1102
|
}
|
|
1054
1103
|
}
|
|
1055
1104
|
|
|
1056
|
-
static
|
|
1057
|
-
|
|
1105
|
+
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
|
|
1106
|
+
if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
|
|
1107
|
+
return 1u;
|
|
1108
|
+
}
|
|
1109
|
+
|
|
1110
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1111
|
+
if (wctx.params.use_gpu) {
|
|
1112
|
+
return 32u;
|
|
1113
|
+
}
|
|
1114
|
+
#endif
|
|
1058
1115
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1116
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
1117
|
+
if (wctx.params.use_gpu) {
|
|
1118
|
+
return 256u;
|
|
1119
|
+
}
|
|
1120
|
+
#endif
|
|
1121
|
+
|
|
1122
|
+
return 1u;
|
|
1123
|
+
}
|
|
1124
|
+
|
|
1125
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
1126
|
+
static bool aheads_masks_init(
|
|
1127
|
+
const whisper_context_params & cparams,
|
|
1128
|
+
const whisper_hparams & hparams,
|
|
1129
|
+
struct whisper_aheads_masks & aheads_masks,
|
|
1130
|
+
wsp_ggml_backend_t backend) {
|
|
1131
|
+
|
|
1132
|
+
const int32_t n_text_layer = hparams.n_text_layer;
|
|
1133
|
+
const int32_t n_head = hparams.n_text_head;
|
|
1134
|
+
|
|
1135
|
+
// Sanity checks
|
|
1136
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
|
1137
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
|
|
1138
|
+
return false;
|
|
1139
|
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
|
1140
|
+
if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
|
|
1141
|
+
WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
|
|
1142
|
+
return false;
|
|
1143
|
+
}
|
|
1144
|
+
} else {
|
|
1145
|
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
|
1146
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
|
|
1147
|
+
if (aheads.n_heads == 0) {
|
|
1148
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
|
|
1149
|
+
return false;
|
|
1150
|
+
}
|
|
1151
|
+
if (aheads.heads == NULL) {
|
|
1152
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
|
|
1153
|
+
return false;
|
|
1154
|
+
}
|
|
1155
|
+
}
|
|
1156
|
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
|
1157
|
+
if (aheads.heads[i].n_text_layer >= n_text_layer) {
|
|
1158
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
|
|
1159
|
+
return false;
|
|
1160
|
+
}
|
|
1161
|
+
if (aheads.heads[i].n_text_layer < 0) {
|
|
1162
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
|
|
1163
|
+
return false;
|
|
1164
|
+
}
|
|
1165
|
+
if (aheads.heads[i].n_head >= n_head) {
|
|
1166
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
|
|
1167
|
+
return false;
|
|
1168
|
+
}
|
|
1169
|
+
if (aheads.heads[i].n_head < 0) {
|
|
1170
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
|
|
1171
|
+
return false;
|
|
1172
|
+
}
|
|
1173
|
+
}
|
|
1174
|
+
}
|
|
1175
|
+
|
|
1176
|
+
struct wsp_ggml_init_params params = {
|
|
1177
|
+
/*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*wsp_ggml_tensor_overhead(),
|
|
1178
|
+
/*.mem_buffer =*/ nullptr,
|
|
1179
|
+
/*.no_alloc =*/ true,
|
|
1180
|
+
};
|
|
1181
|
+
|
|
1182
|
+
aheads_masks.ctx = wsp_ggml_init(params);
|
|
1183
|
+
|
|
1184
|
+
if (!aheads_masks.ctx) {
|
|
1185
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
|
|
1186
|
+
return false;
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
|
1190
|
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
|
1191
|
+
if (!aheads.empty()) {
|
|
1192
|
+
aheads_masks.m.push_back(wsp_ggml_new_tensor_2d(aheads_masks.ctx, WSP_GGML_TYPE_F32, n_head, aheads.size()));
|
|
1193
|
+
} else {
|
|
1194
|
+
aheads_masks.m.push_back(nullptr);
|
|
1195
|
+
}
|
|
1196
|
+
}
|
|
1197
|
+
|
|
1198
|
+
aheads_masks.buffer = wsp_ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
|
|
1199
|
+
if (!aheads_masks.buffer) {
|
|
1200
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
|
|
1201
|
+
return false;
|
|
1202
|
+
}
|
|
1203
|
+
|
|
1204
|
+
// Set data on mask tensors
|
|
1205
|
+
// Since this must be backend agnostic, we write our desired values on mask_data,
|
|
1206
|
+
// and send it to backend with wsp_ggml_backend_tensor_set.
|
|
1207
|
+
// Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment
|
|
1208
|
+
// heads. Each row of the mask "marks" one alignment head. E.g. if some text layer
|
|
1209
|
+
// has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask
|
|
1210
|
+
// should read:
|
|
1211
|
+
// 1 0 0 0 0 0 0 0 0 0
|
|
1212
|
+
// 0 0 0 0 0 1 0 0 0 0
|
|
1213
|
+
// 0 0 0 0 0 0 1 0 0 0
|
|
1214
|
+
std::vector<float> mask_data;
|
|
1215
|
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
|
1216
|
+
if (aheads_masks.m[il] != nullptr) {
|
|
1217
|
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
|
1218
|
+
|
|
1219
|
+
size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1];
|
|
1220
|
+
size_t data_size_bytes = data_size * sizeof(float);
|
|
1221
|
+
mask_data.resize(data_size);
|
|
1222
|
+
|
|
1223
|
+
std::fill(mask_data.begin(), mask_data.end(), 0);
|
|
1224
|
+
for (size_t ih = 0; ih < aheads.size(); ++ih) {
|
|
1225
|
+
size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
|
|
1226
|
+
mask_data[pos] = 1.0f;
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
wsp_ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes);
|
|
1230
|
+
}
|
|
1231
|
+
}
|
|
1232
|
+
|
|
1233
|
+
if (aheads_masks.m.empty()) {
|
|
1234
|
+
WHISPER_LOG_ERROR("%s: \n", __func__);
|
|
1235
|
+
return false;
|
|
1236
|
+
}
|
|
1237
|
+
|
|
1238
|
+
return true;
|
|
1239
|
+
}
|
|
1240
|
+
|
|
1241
|
+
static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
|
|
1242
|
+
wsp_ggml_free(aheads_masks.ctx);
|
|
1243
|
+
wsp_ggml_backend_buffer_free(aheads_masks.buffer);
|
|
1244
|
+
aheads_masks.ctx = nullptr;
|
|
1245
|
+
}
|
|
1246
|
+
|
|
1247
|
+
static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
1248
|
+
size_t size = 0;
|
|
1249
|
+
for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
|
|
1250
|
+
if (aheads_masks.m[i] != nullptr)
|
|
1251
|
+
size += wsp_ggml_nbytes(aheads_masks.m[i]);
|
|
1252
|
+
}
|
|
1253
|
+
return size;
|
|
1254
|
+
}
|
|
1255
|
+
|
|
1256
|
+
static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
|
1257
|
+
wsp_ggml_backend_t result = NULL;
|
|
1258
|
+
|
|
1259
|
+
wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
|
1260
|
+
|
|
1261
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
1262
|
+
if (params.use_gpu) {
|
|
1062
1263
|
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
|
1063
|
-
|
|
1064
|
-
if (!
|
|
1264
|
+
result = wsp_ggml_backend_cuda_init(params.gpu_device);
|
|
1265
|
+
if (!result) {
|
|
1065
1266
|
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
|
|
1066
1267
|
}
|
|
1067
1268
|
}
|
|
@@ -1070,22 +1271,108 @@ static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & pa
|
|
|
1070
1271
|
#ifdef WSP_GGML_USE_METAL
|
|
1071
1272
|
if (params.use_gpu) {
|
|
1072
1273
|
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
if (!backend_gpu) {
|
|
1274
|
+
result = wsp_ggml_backend_metal_init();
|
|
1275
|
+
if (!result) {
|
|
1076
1276
|
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
|
|
1077
|
-
} else if (!wsp_ggml_backend_metal_supports_family(
|
|
1277
|
+
} else if (!wsp_ggml_backend_metal_supports_family(result, 7)) {
|
|
1078
1278
|
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
|
|
1079
|
-
wsp_ggml_backend_free(
|
|
1080
|
-
|
|
1279
|
+
wsp_ggml_backend_free(result);
|
|
1280
|
+
result = NULL;
|
|
1081
1281
|
}
|
|
1082
1282
|
}
|
|
1083
1283
|
#endif
|
|
1084
1284
|
|
|
1285
|
+
#ifdef WSP_GGML_USE_SYCL
|
|
1286
|
+
if (params.use_gpu) {
|
|
1287
|
+
WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
|
|
1288
|
+
result = wsp_ggml_backend_sycl_init(params.gpu_device);
|
|
1289
|
+
if (!result) {
|
|
1290
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
|
|
1291
|
+
}
|
|
1292
|
+
}
|
|
1293
|
+
#endif
|
|
1294
|
+
|
|
1295
|
+
#ifdef WSP_GGML_USE_VULKAN
|
|
1296
|
+
if (params.use_gpu) {
|
|
1297
|
+
WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
|
|
1298
|
+
result = wsp_ggml_backend_vk_init(params.gpu_device);
|
|
1299
|
+
if (!result) {
|
|
1300
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
|
|
1301
|
+
}
|
|
1302
|
+
}
|
|
1303
|
+
#endif
|
|
1304
|
+
|
|
1305
|
+
#ifdef WSP_GGML_USE_CANN
|
|
1306
|
+
if (params.use_gpu) {
|
|
1307
|
+
WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
|
|
1308
|
+
result = wsp_ggml_backend_cann_init(params.gpu_device);
|
|
1309
|
+
if (!result) {
|
|
1310
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
|
|
1311
|
+
}
|
|
1312
|
+
}
|
|
1313
|
+
#endif
|
|
1314
|
+
|
|
1315
|
+
WSP_GGML_UNUSED(params);
|
|
1316
|
+
|
|
1317
|
+
return result;
|
|
1318
|
+
}
|
|
1319
|
+
|
|
1320
|
+
static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
|
1321
|
+
std::vector<wsp_ggml_backend_t> result;
|
|
1322
|
+
|
|
1323
|
+
wsp_ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
|
|
1324
|
+
|
|
1085
1325
|
if (backend_gpu) {
|
|
1086
|
-
|
|
1326
|
+
result.push_back(backend_gpu);
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
1330
|
+
{
|
|
1331
|
+
WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
|
|
1332
|
+
wsp_ggml_backend_t backend_blas = wsp_ggml_backend_blas_init();
|
|
1333
|
+
if (!backend_blas) {
|
|
1334
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_blas_init() failed\n", __func__);
|
|
1335
|
+
} else {
|
|
1336
|
+
result.push_back(backend_blas);
|
|
1337
|
+
}
|
|
1087
1338
|
}
|
|
1088
|
-
|
|
1339
|
+
#endif
|
|
1340
|
+
|
|
1341
|
+
WSP_GGML_UNUSED(params);
|
|
1342
|
+
|
|
1343
|
+
result.push_back(wsp_ggml_backend_cpu_init());
|
|
1344
|
+
|
|
1345
|
+
return result;
|
|
1346
|
+
}
|
|
1347
|
+
|
|
1348
|
+
static wsp_ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
|
|
1349
|
+
wsp_ggml_backend_buffer_type_t result = nullptr;
|
|
1350
|
+
|
|
1351
|
+
params.use_gpu || (result = wsp_ggml_backend_cpu_buffer_type());
|
|
1352
|
+
|
|
1353
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
1354
|
+
result || (result = wsp_ggml_backend_cuda_buffer_type(params.gpu_device));
|
|
1355
|
+
#endif
|
|
1356
|
+
|
|
1357
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1358
|
+
result || (result = wsp_ggml_backend_metal_buffer_type());
|
|
1359
|
+
#endif
|
|
1360
|
+
|
|
1361
|
+
#ifdef WSP_GGML_USE_SYCL
|
|
1362
|
+
result || (result = wsp_ggml_backend_sycl_buffer_type(params.gpu_device));
|
|
1363
|
+
#endif
|
|
1364
|
+
|
|
1365
|
+
#ifdef WSP_GGML_USE_VULKAN
|
|
1366
|
+
result || (result = wsp_ggml_backend_vk_buffer_type(params.gpu_device));
|
|
1367
|
+
#endif
|
|
1368
|
+
|
|
1369
|
+
#ifdef WSP_GGML_USE_CANN
|
|
1370
|
+
result || (result == wsp_ggml_backend_cann_buffer_type(params.gpu_device));
|
|
1371
|
+
#endif
|
|
1372
|
+
|
|
1373
|
+
result || (result = wsp_ggml_backend_cpu_buffer_type());
|
|
1374
|
+
|
|
1375
|
+
return result;
|
|
1089
1376
|
}
|
|
1090
1377
|
|
|
1091
1378
|
// load the model from a ggml file
|
|
@@ -1512,69 +1799,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1512
1799
|
}
|
|
1513
1800
|
}
|
|
1514
1801
|
|
|
1515
|
-
wctx.backend = whisper_backend_init(wctx.params);
|
|
1516
|
-
|
|
1517
|
-
// some devices have a limit on the maximum size of single memory buffer
|
|
1518
|
-
// for example, iPhones are limited to 1GB per buffer
|
|
1519
|
-
// to workaround this, we will allocate multiple buffers of smaller size and will split the tensors with the
|
|
1520
|
-
// model weights between them
|
|
1521
|
-
//
|
|
1522
|
-
// the map_t2b maps tensor names to buffer indices
|
|
1523
|
-
// as we iterate over the tensors, we will allocate new buffers when the current one is full
|
|
1524
|
-
//
|
|
1525
|
-
// finally, we create a separate allocator for each buffer and use it to allocate the tensors
|
|
1526
|
-
// we keep the allocators alive until all the tensors are loaded
|
|
1527
|
-
|
|
1528
|
-
WSP_GGML_ASSERT(model.buffers.empty());
|
|
1529
|
-
|
|
1530
|
-
std::map<std::string, int> map_t2b;
|
|
1531
|
-
|
|
1532
|
-
{
|
|
1533
|
-
size_t size_main = 0;
|
|
1534
|
-
size_t size_cur = 0;
|
|
1535
|
-
|
|
1536
|
-
static const size_t GB = 1024ull*1024ull*1024ull;
|
|
1537
|
-
|
|
1538
|
-
for (const auto & t : model.tensors) {
|
|
1539
|
-
const size_t cur = wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
|
|
1540
|
-
|
|
1541
|
-
// adding the tensor to the current buffer will exceed the limit, so we need to allocate a new buffer
|
|
1542
|
-
if (size_cur + cur > GB) {
|
|
1543
|
-
WSP_GGML_ASSERT(size_cur > 0 && "A tensor is too large to fit in a single buffer");
|
|
1544
|
-
|
|
1545
|
-
model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur));
|
|
1546
|
-
|
|
1547
|
-
size_cur = cur;
|
|
1548
|
-
}
|
|
1549
|
-
|
|
1550
|
-
map_t2b[t.first] = model.buffers.size();
|
|
1551
|
-
|
|
1552
|
-
size_cur += cur;
|
|
1553
|
-
size_main += cur;
|
|
1554
|
-
}
|
|
1555
|
-
|
|
1556
|
-
// allocate the last buffer if needed
|
|
1557
|
-
if (size_cur > 0) {
|
|
1558
|
-
model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur));
|
|
1559
|
-
}
|
|
1560
|
-
|
|
1561
|
-
WSP_GGML_ASSERT(model.buffers.size() > 0);
|
|
1562
|
-
|
|
1563
|
-
WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB (%d buffers)\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6, (int) model.buffers.size());
|
|
1564
|
-
}
|
|
1565
|
-
|
|
1566
|
-
std::vector<wsp_ggml_allocr *> allocs(model.buffers.size());
|
|
1567
|
-
for (size_t i = 0; i < allocs.size(); ++i) {
|
|
1568
|
-
allocs[i] = wsp_ggml_allocr_new_from_buffer(model.buffers[i]);
|
|
1569
|
-
}
|
|
1570
|
-
|
|
1571
1802
|
// allocate tensors in the backend buffers
|
|
1572
|
-
|
|
1573
|
-
|
|
1574
|
-
|
|
1575
|
-
|
|
1803
|
+
model.buffer = wsp_ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
|
|
1804
|
+
if (!model.buffer) {
|
|
1805
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
|
|
1806
|
+
return false;
|
|
1576
1807
|
}
|
|
1577
1808
|
|
|
1809
|
+
size_t size_main = wsp_ggml_backend_buffer_get_size(model.buffer);
|
|
1810
|
+
WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(model.buffer), size_main / 1e6);
|
|
1811
|
+
|
|
1578
1812
|
// load weights
|
|
1579
1813
|
{
|
|
1580
1814
|
size_t total_size = 0;
|
|
@@ -1636,15 +1870,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1636
1870
|
return false;
|
|
1637
1871
|
}
|
|
1638
1872
|
|
|
1639
|
-
wsp_ggml_backend_t backend = wctx.backend;
|
|
1873
|
+
//wsp_ggml_backend_t backend = wctx.backend;
|
|
1640
1874
|
|
|
1641
1875
|
//printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
|
|
1642
1876
|
|
|
1643
|
-
if ((
|
|
1644
|
-
#ifdef WSP_GGML_USE_METAL
|
|
1645
|
-
|| wsp_ggml_backend_is_metal(backend)
|
|
1646
|
-
#endif
|
|
1647
|
-
)) {
|
|
1877
|
+
if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
|
|
1648
1878
|
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1649
1879
|
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
1650
1880
|
BYTESWAP_TENSOR(tensor);
|
|
@@ -1672,9 +1902,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1672
1902
|
}
|
|
1673
1903
|
}
|
|
1674
1904
|
|
|
1675
|
-
|
|
1676
|
-
wsp_ggml_allocr_free(alloc);
|
|
1677
|
-
}
|
|
1905
|
+
wsp_ggml_backend_buffer_set_usage(model.buffer, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
1678
1906
|
|
|
1679
1907
|
wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
|
|
1680
1908
|
|
|
@@ -1701,10 +1929,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
|
|
|
1701
1929
|
|
|
1702
1930
|
static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
1703
1931
|
whisper_context & wctx,
|
|
1704
|
-
whisper_state & wstate
|
|
1705
|
-
const int mel_offset) {
|
|
1932
|
+
whisper_state & wstate) {
|
|
1706
1933
|
const auto & model = wctx.model;
|
|
1707
|
-
const auto & mel_inp = wstate.mel;
|
|
1708
1934
|
const auto & hparams = model.hparams;
|
|
1709
1935
|
|
|
1710
1936
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
@@ -1713,8 +1939,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1713
1939
|
const int n_mels = hparams.n_mels;
|
|
1714
1940
|
|
|
1715
1941
|
struct wsp_ggml_init_params params = {
|
|
1716
|
-
/*.mem_size =*/ wstate.
|
|
1717
|
-
/*.mem_buffer =*/ wstate.
|
|
1942
|
+
/*.mem_size =*/ wstate.sched_conv.meta.size(),
|
|
1943
|
+
/*.mem_buffer =*/ wstate.sched_conv.meta.data(),
|
|
1718
1944
|
/*.no_alloc =*/ true,
|
|
1719
1945
|
};
|
|
1720
1946
|
|
|
@@ -1722,31 +1948,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1722
1948
|
|
|
1723
1949
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1724
1950
|
|
|
1725
|
-
wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
|
1726
|
-
|
|
1727
1951
|
struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
1731
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1732
|
-
assert(mel_inp.n_mel == n_mels);
|
|
1733
|
-
|
|
1734
|
-
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
1735
|
-
|
|
1736
|
-
float * dst = wstate.inp_mel.data();
|
|
1737
|
-
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1738
|
-
|
|
1739
|
-
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
1740
|
-
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
1741
|
-
|
|
1742
|
-
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
1743
|
-
for (int i = i0; i < i1; ++i) {
|
|
1744
|
-
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
1745
|
-
}
|
|
1746
|
-
}
|
|
1747
|
-
|
|
1748
|
-
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
1749
|
-
}
|
|
1952
|
+
wsp_ggml_set_name(mel, "mel");
|
|
1953
|
+
wsp_ggml_set_input(mel);
|
|
1750
1954
|
|
|
1751
1955
|
struct wsp_ggml_tensor * cur = nullptr;
|
|
1752
1956
|
|
|
@@ -1767,27 +1971,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1767
1971
|
wsp_ggml_set_name(cur, "embd_conv");
|
|
1768
1972
|
wstate.embd_conv = cur;
|
|
1769
1973
|
} else {
|
|
1770
|
-
|
|
1771
|
-
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1772
|
-
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1974
|
+
wsp_ggml_build_forward_expand(gf, mel);
|
|
1773
1975
|
|
|
1774
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1775
|
-
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
|
1776
|
-
}
|
|
1777
|
-
#endif
|
|
1778
|
-
#ifdef WHISPER_USE_OPENVINO
|
|
1779
1976
|
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1780
|
-
|
|
1781
|
-
|
|
1782
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1783
|
-
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
|
1784
|
-
}
|
|
1785
|
-
#endif
|
|
1977
|
+
wsp_ggml_set_input(cur); // the external encoder will write into this tensor
|
|
1786
1978
|
|
|
1787
1979
|
wsp_ggml_set_name(cur, "embd_enc");
|
|
1788
1980
|
wstate.embd_enc = cur;
|
|
1789
1981
|
}
|
|
1790
1982
|
|
|
1983
|
+
wsp_ggml_set_output(cur);
|
|
1984
|
+
|
|
1791
1985
|
wsp_ggml_build_forward_expand(gf, cur);
|
|
1792
1986
|
|
|
1793
1987
|
wsp_ggml_free(ctx0);
|
|
@@ -1806,9 +2000,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1806
2000
|
const int n_head = hparams.n_audio_head;
|
|
1807
2001
|
const int n_layer = hparams.n_audio_layer;
|
|
1808
2002
|
|
|
2003
|
+
const int n_state_head = n_state/n_head;
|
|
2004
|
+
|
|
2005
|
+
auto & kv_pad = wstate.kv_pad;
|
|
2006
|
+
|
|
2007
|
+
WHISPER_ASSERT(!!kv_pad.buffer);
|
|
2008
|
+
|
|
2009
|
+
const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
|
|
2010
|
+
|
|
1809
2011
|
struct wsp_ggml_init_params params = {
|
|
1810
|
-
/*.mem_size =*/ wstate.
|
|
1811
|
-
/*.mem_buffer =*/ wstate.
|
|
2012
|
+
/*.mem_size =*/ wstate.sched_encode.meta.size(),
|
|
2013
|
+
/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
|
|
1812
2014
|
/*.no_alloc =*/ true,
|
|
1813
2015
|
};
|
|
1814
2016
|
|
|
@@ -1816,17 +2018,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1816
2018
|
|
|
1817
2019
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
1818
2020
|
|
|
1819
|
-
//wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1820
|
-
|
|
1821
|
-
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
|
|
1822
|
-
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
1823
|
-
|
|
1824
|
-
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1825
|
-
// wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
|
|
1826
|
-
//}
|
|
1827
2021
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1828
2022
|
|
|
1829
|
-
const float KQscale = 1.0f/sqrtf(float(
|
|
2023
|
+
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
|
1830
2024
|
|
|
1831
2025
|
// ===================================================================
|
|
1832
2026
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
@@ -1876,14 +2070,14 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1876
2070
|
|
|
1877
2071
|
Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
|
|
1878
2072
|
|
|
1879
|
-
//Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(
|
|
2073
|
+
//Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
|
1880
2074
|
|
|
1881
2075
|
// note: no bias for Key
|
|
1882
2076
|
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1883
2077
|
layer.attn_k_w,
|
|
1884
2078
|
cur);
|
|
1885
2079
|
|
|
1886
|
-
//Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(
|
|
2080
|
+
//Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
|
1887
2081
|
|
|
1888
2082
|
struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
|
|
1889
2083
|
layer.attn_v_w,
|
|
@@ -1893,70 +2087,60 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1893
2087
|
|
|
1894
2088
|
// ------
|
|
1895
2089
|
|
|
1896
|
-
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1897
2090
|
struct wsp_ggml_tensor * Q =
|
|
1898
2091
|
wsp_ggml_permute(ctx0,
|
|
1899
|
-
|
|
1900
|
-
Qcur,
|
|
1901
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
2092
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
|
|
1902
2093
|
0, 2, 1, 3);
|
|
1903
2094
|
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
Kcur,
|
|
1908
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1909
|
-
0, 2, 1, 3);
|
|
2095
|
+
if (wctx.params.flash_attn) {
|
|
2096
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, wsp_ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
|
|
2097
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, wsp_ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
|
|
1910
2098
|
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
1, 2, 0, 3),
|
|
1918
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
2099
|
+
struct wsp_ggml_tensor * K =
|
|
2100
|
+
wsp_ggml_view_3d(ctx0, kv_pad.k,
|
|
2101
|
+
n_state_head, n_ctx_pad, n_head,
|
|
2102
|
+
wsp_ggml_element_size(kv_pad.k)*n_state,
|
|
2103
|
+
wsp_ggml_element_size(kv_pad.k)*n_state_head,
|
|
2104
|
+
0);
|
|
1919
2105
|
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1927
|
-
0, 2, 1, 3);
|
|
2106
|
+
struct wsp_ggml_tensor * V =
|
|
2107
|
+
wsp_ggml_view_3d(ctx0, kv_pad.v,
|
|
2108
|
+
n_state_head, n_ctx_pad, n_head,
|
|
2109
|
+
wsp_ggml_element_size(kv_pad.v)*n_state,
|
|
2110
|
+
wsp_ggml_element_size(kv_pad.v)*n_state_head,
|
|
2111
|
+
0);
|
|
1928
2112
|
|
|
1929
|
-
|
|
1930
|
-
wsp_ggml_permute(ctx0,
|
|
1931
|
-
wsp_ggml_cpy(ctx0,
|
|
1932
|
-
Kcur,
|
|
1933
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1934
|
-
0, 2, 1, 3);
|
|
2113
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
|
|
1935
2114
|
|
|
1936
|
-
|
|
1937
|
-
|
|
2115
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
|
2116
|
+
} else {
|
|
2117
|
+
struct wsp_ggml_tensor * K =
|
|
2118
|
+
wsp_ggml_permute(ctx0,
|
|
2119
|
+
wsp_ggml_cast(ctx0,
|
|
2120
|
+
wsp_ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
|
2121
|
+
wctx.itype),
|
|
2122
|
+
0, 2, 1, 3);
|
|
1938
2123
|
|
|
1939
|
-
|
|
2124
|
+
// K * Q
|
|
2125
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
1940
2126
|
|
|
1941
|
-
|
|
2127
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
|
1942
2128
|
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
);
|
|
2129
|
+
struct wsp_ggml_tensor * V =
|
|
2130
|
+
wsp_ggml_cast(ctx0,
|
|
2131
|
+
wsp_ggml_permute(ctx0,
|
|
2132
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
2133
|
+
Vcur,
|
|
2134
|
+
n_state_head, n_head, n_ctx),
|
|
2135
|
+
1, 2, 0, 3),
|
|
2136
|
+
wctx.itype);
|
|
1952
2137
|
|
|
1953
|
-
|
|
1954
|
-
#endif
|
|
1955
|
-
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2138
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1956
2139
|
|
|
1957
|
-
|
|
1958
|
-
|
|
1959
|
-
|
|
2140
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2141
|
+
|
|
2142
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
|
|
2143
|
+
}
|
|
1960
2144
|
}
|
|
1961
2145
|
|
|
1962
2146
|
// projection
|
|
@@ -1985,11 +2169,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1985
2169
|
layer.mlp_ln_b);
|
|
1986
2170
|
}
|
|
1987
2171
|
|
|
1988
|
-
#ifdef WHISPER_USE_FLASH_FF
|
|
1989
|
-
cur = wsp_ggml_flash_ff(ctx0,
|
|
1990
|
-
wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
1991
|
-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1992
|
-
#else
|
|
1993
2172
|
// fully connected
|
|
1994
2173
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
1995
2174
|
layer.mlp_0_w,
|
|
@@ -2006,7 +2185,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
2006
2185
|
cur);
|
|
2007
2186
|
|
|
2008
2187
|
cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
|
|
2009
|
-
#endif
|
|
2010
2188
|
}
|
|
2011
2189
|
|
|
2012
2190
|
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
@@ -2055,9 +2233,13 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2055
2233
|
const int n_state = hparams.n_audio_state;
|
|
2056
2234
|
const int n_head = hparams.n_audio_head;
|
|
2057
2235
|
|
|
2236
|
+
const int n_state_head = n_state/n_head;
|
|
2237
|
+
|
|
2238
|
+
const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
|
|
2239
|
+
|
|
2058
2240
|
struct wsp_ggml_init_params params = {
|
|
2059
|
-
/*.mem_size =*/ wstate.
|
|
2060
|
-
/*.mem_buffer =*/ wstate.
|
|
2241
|
+
/*.mem_size =*/ wstate.sched_cross.meta.size(),
|
|
2242
|
+
/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
|
|
2061
2243
|
/*.no_alloc =*/ true,
|
|
2062
2244
|
};
|
|
2063
2245
|
|
|
@@ -2065,28 +2247,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2065
2247
|
|
|
2066
2248
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
2067
2249
|
|
|
2068
|
-
//wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
2069
|
-
|
|
2070
|
-
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
2071
|
-
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
2072
|
-
|
|
2073
|
-
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2074
|
-
// wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
|
|
2075
|
-
//}
|
|
2076
2250
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
2077
2251
|
|
|
2078
|
-
const float Kscale = pow(float(
|
|
2252
|
+
const float Kscale = pow(float(n_state_head), -0.25);
|
|
2079
2253
|
|
|
2080
2254
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
2081
2255
|
auto & layer = model.layers_decoder[il];
|
|
2082
2256
|
|
|
2083
|
-
struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
|
|
2257
|
+
struct wsp_ggml_tensor * Kcross = wsp_ggml_mul_mat(ctx0,
|
|
2084
2258
|
layer.cross_attn_k_w,
|
|
2085
2259
|
cur);
|
|
2086
2260
|
|
|
2087
2261
|
Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
|
|
2088
2262
|
|
|
2089
|
-
struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
|
|
2263
|
+
struct wsp_ggml_tensor * Vcross = wsp_ggml_mul_mat(ctx0,
|
|
2090
2264
|
layer.cross_attn_v_w,
|
|
2091
2265
|
cur);
|
|
2092
2266
|
|
|
@@ -2094,15 +2268,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2094
2268
|
Vcross,
|
|
2095
2269
|
layer.cross_attn_v_b);
|
|
2096
2270
|
|
|
2097
|
-
|
|
2271
|
+
struct wsp_ggml_tensor * k;
|
|
2272
|
+
struct wsp_ggml_tensor * v;
|
|
2273
|
+
|
|
2274
|
+
if (wctx.params.flash_attn) {
|
|
2275
|
+
k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
|
2276
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
|
|
2098
2277
|
|
|
2099
|
-
|
|
2100
|
-
|
|
2101
|
-
|
|
2278
|
+
v = wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
|
|
2279
|
+
(wsp_ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
|
|
2280
|
+
} else {
|
|
2281
|
+
Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
2282
|
+
|
|
2283
|
+
k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
|
2284
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
2102
2285
|
|
|
2103
|
-
|
|
2104
|
-
|
|
2105
|
-
|
|
2286
|
+
v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
2287
|
+
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2288
|
+
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
2289
|
+
}
|
|
2106
2290
|
|
|
2107
2291
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
|
|
2108
2292
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
|
|
@@ -2130,53 +2314,89 @@ static bool whisper_encode_internal(
|
|
|
2130
2314
|
whisper_state & wstate,
|
|
2131
2315
|
const int mel_offset,
|
|
2132
2316
|
const int n_threads,
|
|
2133
|
-
|
|
2317
|
+
wsp_ggml_abort_callback abort_callback,
|
|
2134
2318
|
void * abort_callback_data) {
|
|
2135
2319
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2136
2320
|
|
|
2137
2321
|
// conv
|
|
2138
2322
|
{
|
|
2139
|
-
auto &
|
|
2323
|
+
auto & sched = wstate.sched_conv.sched;
|
|
2324
|
+
|
|
2325
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
|
|
2140
2326
|
|
|
2141
|
-
|
|
2327
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2328
|
+
// should never happen as we pre-allocate the memory
|
|
2329
|
+
return false;
|
|
2330
|
+
}
|
|
2331
|
+
|
|
2332
|
+
struct wsp_ggml_tensor * mel = wsp_ggml_graph_get_tensor(gf, "mel");
|
|
2333
|
+
|
|
2334
|
+
// set the input
|
|
2335
|
+
{
|
|
2336
|
+
const auto & mel_inp = wstate.mel;
|
|
2337
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
|
|
2338
|
+
|
|
2339
|
+
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
2340
|
+
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
|
|
2341
|
+
|
|
2342
|
+
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
2343
|
+
|
|
2344
|
+
float * dst = wstate.inp_mel.data();
|
|
2345
|
+
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
2346
|
+
|
|
2347
|
+
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
2348
|
+
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
2142
2349
|
|
|
2143
|
-
|
|
2350
|
+
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
2351
|
+
for (int i = i0; i < i1; ++i) {
|
|
2352
|
+
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
2353
|
+
}
|
|
2354
|
+
}
|
|
2144
2355
|
|
|
2145
|
-
|
|
2356
|
+
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
2357
|
+
}
|
|
2146
2358
|
|
|
2147
2359
|
if (!whisper_encode_external(wstate)) {
|
|
2148
|
-
if (!wsp_ggml_graph_compute_helper(
|
|
2360
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2149
2361
|
return false;
|
|
2150
2362
|
}
|
|
2363
|
+
} else {
|
|
2364
|
+
#if defined(WHISPER_USE_COREML)
|
|
2365
|
+
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
|
|
2366
|
+
#elif defined(WHISPER_USE_OPENVINO)
|
|
2367
|
+
whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
|
|
2368
|
+
#endif
|
|
2151
2369
|
}
|
|
2152
2370
|
}
|
|
2153
2371
|
|
|
2154
2372
|
// encoder
|
|
2155
2373
|
if (!whisper_encode_external(wstate)) {
|
|
2156
|
-
auto &
|
|
2157
|
-
|
|
2158
|
-
wsp_ggml_allocr_reset(alloc);
|
|
2374
|
+
auto & sched = wstate.sched_encode.sched;
|
|
2159
2375
|
|
|
2160
2376
|
wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
|
2161
2377
|
|
|
2162
|
-
|
|
2378
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2379
|
+
// should never happen as we pre-allocate the memory
|
|
2380
|
+
return false;
|
|
2381
|
+
}
|
|
2163
2382
|
|
|
2164
|
-
if (!wsp_ggml_graph_compute_helper(
|
|
2383
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2165
2384
|
return false;
|
|
2166
2385
|
}
|
|
2167
2386
|
}
|
|
2168
2387
|
|
|
2169
2388
|
// cross
|
|
2170
2389
|
{
|
|
2171
|
-
auto &
|
|
2172
|
-
|
|
2173
|
-
wsp_ggml_allocr_reset(alloc);
|
|
2390
|
+
auto & sched = wstate.sched_cross.sched;
|
|
2174
2391
|
|
|
2175
2392
|
wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
|
2176
2393
|
|
|
2177
|
-
|
|
2394
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2395
|
+
// should never happen as we pre-allocate the memory
|
|
2396
|
+
return false;
|
|
2397
|
+
}
|
|
2178
2398
|
|
|
2179
|
-
if (!wsp_ggml_graph_compute_helper(
|
|
2399
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2180
2400
|
return false;
|
|
2181
2401
|
}
|
|
2182
2402
|
}
|
|
@@ -2190,82 +2410,58 @@ static bool whisper_encode_internal(
|
|
|
2190
2410
|
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2191
2411
|
whisper_context & wctx,
|
|
2192
2412
|
whisper_state & wstate,
|
|
2193
|
-
const whisper_batch & batch
|
|
2413
|
+
const whisper_batch & batch,
|
|
2414
|
+
bool save_alignment_heads_QKs,
|
|
2415
|
+
bool worst_case) {
|
|
2194
2416
|
const auto & model = wctx.model;
|
|
2195
2417
|
const auto & hparams = model.hparams;
|
|
2196
2418
|
|
|
2197
2419
|
auto & kv_self = wstate.kv_self;
|
|
2198
2420
|
|
|
2199
|
-
WHISPER_ASSERT(!!kv_self.
|
|
2200
|
-
|
|
2201
|
-
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
2421
|
+
WHISPER_ASSERT(!!kv_self.buffer);
|
|
2202
2422
|
|
|
2203
2423
|
const int n_ctx = kv_self.size;
|
|
2204
2424
|
const int n_state = hparams.n_text_state;
|
|
2205
2425
|
const int n_head = hparams.n_text_head;
|
|
2206
2426
|
const int n_layer = hparams.n_text_layer;
|
|
2207
2427
|
|
|
2428
|
+
const int n_state_head = n_state/n_head;
|
|
2429
|
+
|
|
2208
2430
|
const int n_tokens = batch.n_tokens;
|
|
2209
2431
|
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
2210
2432
|
|
|
2211
|
-
const
|
|
2212
|
-
|
|
2433
|
+
const int n_audio_ctx_pad = WSP_GGML_PAD(n_audio_ctx, 256);
|
|
2434
|
+
|
|
2435
|
+
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
|
2436
|
+
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
|
2213
2437
|
|
|
2214
2438
|
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
|
2215
2439
|
|
|
2216
2440
|
struct wsp_ggml_init_params params = {
|
|
2217
|
-
/*.mem_size =*/ wstate.
|
|
2218
|
-
/*.mem_buffer =*/ wstate.
|
|
2441
|
+
/*.mem_size =*/ wstate.sched_decode.meta.size(),
|
|
2442
|
+
/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
|
|
2219
2443
|
/*.no_alloc =*/ true,
|
|
2220
2444
|
};
|
|
2221
2445
|
|
|
2222
2446
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
2223
2447
|
|
|
2224
|
-
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
2225
|
-
|
|
2226
|
-
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2227
|
-
wsp_ggml_allocr_alloc(alloc, embd);
|
|
2228
|
-
|
|
2229
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2230
|
-
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2231
|
-
}
|
|
2232
|
-
|
|
2233
|
-
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2234
|
-
wsp_ggml_allocr_alloc(alloc, position);
|
|
2235
|
-
|
|
2236
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2237
|
-
for (int i = 0; i < n_tokens; ++i) {
|
|
2238
|
-
const int32_t val = batch.pos[i];
|
|
2239
|
-
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2240
|
-
}
|
|
2241
|
-
}
|
|
2242
|
-
|
|
2243
|
-
const float KQscale = pow(float(n_state)/n_head, -0.25);
|
|
2244
|
-
|
|
2245
|
-
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
|
|
2246
|
-
wsp_ggml_allocr_alloc(alloc, KQ_mask);
|
|
2448
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
2247
2449
|
|
|
2248
|
-
|
|
2249
|
-
|
|
2450
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2451
|
+
wsp_ggml_set_name(embd, "embd");
|
|
2452
|
+
wsp_ggml_set_input(embd);
|
|
2250
2453
|
|
|
2251
|
-
|
|
2252
|
-
|
|
2454
|
+
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2455
|
+
wsp_ggml_set_name(position, "position");
|
|
2456
|
+
wsp_ggml_set_input(position);
|
|
2253
2457
|
|
|
2254
|
-
|
|
2255
|
-
for (int j = 0; j < n_tokens; ++j) {
|
|
2256
|
-
const whisper_pos pos = batch.pos[j];
|
|
2257
|
-
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2458
|
+
const float KQscale = pow(float(n_state_head), -0.25);
|
|
2258
2459
|
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
|
|
2262
|
-
}
|
|
2263
|
-
}
|
|
2264
|
-
}
|
|
2265
|
-
}
|
|
2460
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD), 1);
|
|
2461
|
+
wsp_ggml_set_name(KQ_mask, "KQ_mask");
|
|
2462
|
+
wsp_ggml_set_input(KQ_mask);
|
|
2266
2463
|
|
|
2267
|
-
|
|
2268
|
-
}
|
|
2464
|
+
struct wsp_ggml_tensor * KQ_mask_f16 = wsp_ggml_cast(ctx0, KQ_mask, WSP_GGML_TYPE_F16);
|
|
2269
2465
|
|
|
2270
2466
|
// token encoding + position encoding
|
|
2271
2467
|
struct wsp_ggml_tensor * cur =
|
|
@@ -2275,6 +2471,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2275
2471
|
|
|
2276
2472
|
struct wsp_ggml_tensor * inpL = cur;
|
|
2277
2473
|
|
|
2474
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2475
|
+
struct wsp_ggml_tensor * aheads_cross_QKs = nullptr;
|
|
2476
|
+
|
|
2278
2477
|
for (int il = 0; il < n_layer; ++il) {
|
|
2279
2478
|
const auto & layer = model.layers_decoder[il];
|
|
2280
2479
|
|
|
@@ -2319,12 +2518,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2319
2518
|
Vcur,
|
|
2320
2519
|
layer.attn_v_b);
|
|
2321
2520
|
|
|
2322
|
-
|
|
2521
|
+
struct wsp_ggml_tensor * k;
|
|
2522
|
+
struct wsp_ggml_tensor * v;
|
|
2523
|
+
|
|
2524
|
+
if (wctx.params.flash_attn) {
|
|
2525
|
+
k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
|
2526
|
+
(wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2527
|
+
|
|
2528
|
+
v = wsp_ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
|
|
2529
|
+
(wsp_ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
|
|
2530
|
+
} else {
|
|
2531
|
+
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
|
2532
|
+
|
|
2533
|
+
k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
|
2534
|
+
(wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2323
2535
|
|
|
2324
|
-
|
|
2325
|
-
|
|
2326
|
-
|
|
2327
|
-
|
|
2536
|
+
v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
|
2537
|
+
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2538
|
+
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
|
|
2539
|
+
}
|
|
2328
2540
|
|
|
2329
2541
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2330
2542
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
@@ -2334,40 +2546,46 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2334
2546
|
|
|
2335
2547
|
struct wsp_ggml_tensor * Q =
|
|
2336
2548
|
wsp_ggml_permute(ctx0,
|
|
2337
|
-
wsp_ggml_reshape_3d(ctx0, Qcur,
|
|
2549
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
|
2338
2550
|
0, 2, 1, 3);
|
|
2339
2551
|
|
|
2340
2552
|
struct wsp_ggml_tensor * K =
|
|
2341
2553
|
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2342
|
-
|
|
2554
|
+
n_state_head, n_kv, n_head,
|
|
2343
2555
|
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2344
|
-
wsp_ggml_element_size(kv_self.k)*
|
|
2556
|
+
wsp_ggml_element_size(kv_self.k)*n_state_head,
|
|
2345
2557
|
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
2346
2558
|
|
|
2347
|
-
|
|
2348
|
-
|
|
2559
|
+
if (wctx.params.flash_attn) {
|
|
2560
|
+
struct wsp_ggml_tensor * V =
|
|
2561
|
+
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2562
|
+
n_state_head, n_kv, n_head,
|
|
2563
|
+
wsp_ggml_element_size(kv_self.v)*n_state,
|
|
2564
|
+
wsp_ggml_element_size(kv_self.v)*n_state_head,
|
|
2565
|
+
wsp_ggml_element_size(kv_self.v)*n_state*n_ctx*il);
|
|
2349
2566
|
|
|
2350
|
-
|
|
2567
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
|
|
2351
2568
|
|
|
2352
|
-
|
|
2353
|
-
|
|
2569
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
|
2570
|
+
} else {
|
|
2571
|
+
// K * Q
|
|
2572
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
2354
2573
|
|
|
2355
|
-
|
|
2574
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
|
|
2356
2575
|
|
|
2357
|
-
|
|
2358
|
-
|
|
2359
|
-
|
|
2360
|
-
|
|
2361
|
-
|
|
2362
|
-
|
|
2576
|
+
struct wsp_ggml_tensor * V =
|
|
2577
|
+
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2578
|
+
n_kv, n_state_head, n_head,
|
|
2579
|
+
n_ctx*wsp_ggml_element_size(kv_self.v),
|
|
2580
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state_head,
|
|
2581
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
|
|
2363
2582
|
|
|
2364
|
-
|
|
2583
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2365
2584
|
|
|
2366
|
-
|
|
2585
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2367
2586
|
|
|
2368
|
-
|
|
2369
|
-
|
|
2370
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2587
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
|
2588
|
+
}
|
|
2371
2589
|
}
|
|
2372
2590
|
|
|
2373
2591
|
// projection
|
|
@@ -2406,62 +2624,75 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2406
2624
|
Qcur,
|
|
2407
2625
|
layer.cross_attn_q_b);
|
|
2408
2626
|
|
|
2409
|
-
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2410
|
-
|
|
2411
|
-
// Kcross is already scaled
|
|
2412
|
-
struct wsp_ggml_tensor * Kcross =
|
|
2413
|
-
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2414
|
-
n_state/n_head, n_audio_ctx, n_head,
|
|
2415
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2416
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2417
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2418
|
-
|
|
2419
|
-
//struct wsp_ggml_tensor * Vcross =
|
|
2420
|
-
// wsp_ggml_reshape_3d(ctx0,
|
|
2421
|
-
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2422
|
-
// n_state/n_head, n_head, n_audio_ctx);
|
|
2423
|
-
|
|
2424
|
-
//struct wsp_ggml_tensor * V_trans =
|
|
2425
|
-
// wsp_ggml_cpy(ctx0,
|
|
2426
|
-
// wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
|
2427
|
-
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
|
2428
|
-
|
|
2429
|
-
struct wsp_ggml_tensor * V =
|
|
2430
|
-
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2431
|
-
n_audio_ctx, n_state/n_head, n_head,
|
|
2432
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2433
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
|
2434
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2435
|
-
|
|
2436
|
-
// ------
|
|
2437
|
-
|
|
2438
2627
|
struct wsp_ggml_tensor * Q =
|
|
2439
2628
|
wsp_ggml_permute(ctx0,
|
|
2440
|
-
wsp_ggml_reshape_3d(ctx0, Qcur,
|
|
2629
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
|
2441
2630
|
0, 2, 1, 3);
|
|
2442
2631
|
|
|
2443
|
-
|
|
2444
|
-
|
|
2632
|
+
if (wctx.params.flash_attn) {
|
|
2633
|
+
struct wsp_ggml_tensor * Kcross =
|
|
2634
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2635
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
|
2636
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2637
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
|
2638
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
|
|
2445
2639
|
|
|
2446
|
-
|
|
2447
|
-
|
|
2448
|
-
|
|
2449
|
-
|
|
2450
|
-
|
|
2640
|
+
struct wsp_ggml_tensor * Vcross =
|
|
2641
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2642
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
|
2643
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state,
|
|
2644
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
|
2645
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
|
|
2451
2646
|
|
|
2452
|
-
|
|
2453
|
-
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2647
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
|
|
2454
2648
|
|
|
2455
|
-
|
|
2649
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
|
2650
|
+
} else {
|
|
2651
|
+
struct wsp_ggml_tensor * Kcross =
|
|
2652
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2653
|
+
n_state_head, n_audio_ctx, n_head,
|
|
2654
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2655
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
|
2656
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2657
|
+
|
|
2658
|
+
struct wsp_ggml_tensor * Vcross =
|
|
2659
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2660
|
+
n_audio_ctx, n_state_head, n_head,
|
|
2661
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2662
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
|
2663
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2664
|
+
|
|
2665
|
+
// ------
|
|
2666
|
+
|
|
2667
|
+
// K * Q
|
|
2668
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
|
|
2669
|
+
|
|
2670
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
|
2671
|
+
|
|
2672
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2673
|
+
if (wctx.params.dtw_token_timestamps) {
|
|
2674
|
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
|
2675
|
+
struct wsp_ggml_tensor * aheads_KQs = wsp_ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
|
2676
|
+
aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
|
|
2677
|
+
aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
|
|
2678
|
+
aheads_KQs = wsp_ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
|
2679
|
+
aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
|
|
2680
|
+
aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
|
|
2681
|
+
aheads_KQs = wsp_ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
|
2682
|
+
if (aheads_cross_QKs == NULL) {
|
|
2683
|
+
aheads_cross_QKs = aheads_KQs;
|
|
2684
|
+
} else {
|
|
2685
|
+
aheads_cross_QKs = wsp_ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
|
|
2686
|
+
}
|
|
2687
|
+
}
|
|
2688
|
+
}
|
|
2456
2689
|
|
|
2457
|
-
|
|
2690
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
|
|
2458
2691
|
|
|
2459
|
-
|
|
2692
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2460
2693
|
|
|
2461
|
-
|
|
2462
|
-
|
|
2463
|
-
KQV_merged,
|
|
2464
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2694
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
|
2695
|
+
}
|
|
2465
2696
|
}
|
|
2466
2697
|
|
|
2467
2698
|
// projection
|
|
@@ -2539,6 +2770,16 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2539
2770
|
|
|
2540
2771
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2541
2772
|
|
|
2773
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2774
|
+
if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
|
|
2775
|
+
aheads_cross_QKs = wsp_ggml_transpose(ctx0, aheads_cross_QKs);
|
|
2776
|
+
aheads_cross_QKs = wsp_ggml_cont(ctx0, aheads_cross_QKs);
|
|
2777
|
+
if (save_alignment_heads_QKs) {
|
|
2778
|
+
wsp_ggml_build_forward_expand(gf, aheads_cross_QKs);
|
|
2779
|
+
wstate.aheads_cross_QKs = aheads_cross_QKs;
|
|
2780
|
+
}
|
|
2781
|
+
}
|
|
2782
|
+
|
|
2542
2783
|
wsp_ggml_build_forward_expand(gf, logits);
|
|
2543
2784
|
|
|
2544
2785
|
wsp_ggml_free(ctx0);
|
|
@@ -2561,7 +2802,8 @@ static bool whisper_decode_internal(
|
|
|
2561
2802
|
whisper_state & wstate,
|
|
2562
2803
|
const whisper_batch & batch,
|
|
2563
2804
|
const int n_threads,
|
|
2564
|
-
|
|
2805
|
+
bool save_alignment_heads_QKs,
|
|
2806
|
+
wsp_ggml_abort_callback abort_callback,
|
|
2565
2807
|
void * abort_callback_data) {
|
|
2566
2808
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2567
2809
|
|
|
@@ -2583,24 +2825,75 @@ static bool whisper_decode_internal(
|
|
|
2583
2825
|
return false;
|
|
2584
2826
|
}
|
|
2585
2827
|
|
|
2586
|
-
|
|
2828
|
+
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
|
|
2829
|
+
kv_self.n = std::min(kv_self.size, std::max(pad, WSP_GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
|
|
2830
|
+
|
|
2587
2831
|
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
|
2588
2832
|
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
|
2589
2833
|
}
|
|
2590
2834
|
|
|
2591
2835
|
// decoder
|
|
2592
2836
|
{
|
|
2593
|
-
auto &
|
|
2837
|
+
auto & sched = wstate.sched_decode.sched;
|
|
2838
|
+
|
|
2839
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
|
|
2594
2840
|
|
|
2595
|
-
|
|
2841
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2842
|
+
// should never happen as we pre-allocate the memory
|
|
2843
|
+
return false;
|
|
2844
|
+
}
|
|
2845
|
+
|
|
2846
|
+
// set the inputs
|
|
2847
|
+
{
|
|
2848
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_graph_get_tensor(gf, "embd");
|
|
2849
|
+
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2850
|
+
}
|
|
2851
|
+
|
|
2852
|
+
{
|
|
2853
|
+
struct wsp_ggml_tensor * position = wsp_ggml_graph_get_tensor(gf, "position");
|
|
2854
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
2855
|
+
const int32_t val = batch.pos[i];
|
|
2856
|
+
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2857
|
+
}
|
|
2858
|
+
}
|
|
2859
|
+
|
|
2860
|
+
{
|
|
2861
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_graph_get_tensor(gf, "KQ_mask");
|
|
2862
|
+
|
|
2863
|
+
auto & kv_self = wstate.kv_self;
|
|
2864
|
+
|
|
2865
|
+
const int32_t n_kv = kv_self.n;
|
|
2866
|
+
|
|
2867
|
+
wstate.inp_mask.resize(wsp_ggml_nelements(KQ_mask));
|
|
2868
|
+
|
|
2869
|
+
float * data = wstate.inp_mask.data();
|
|
2870
|
+
memset(data, 0, wsp_ggml_nbytes(KQ_mask));
|
|
2871
|
+
|
|
2872
|
+
for (int h = 0; h < 1; ++h) {
|
|
2873
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
2874
|
+
const whisper_pos pos = batch.pos[j];
|
|
2875
|
+
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2596
2876
|
|
|
2597
|
-
|
|
2877
|
+
for (int i = 0; i < n_kv; ++i) {
|
|
2878
|
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
|
2879
|
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
|
2880
|
+
}
|
|
2881
|
+
}
|
|
2882
|
+
}
|
|
2883
|
+
|
|
2884
|
+
for (int i = n_tokens; i < WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD); ++i) {
|
|
2885
|
+
for (int j = 0; j < n_kv; ++j) {
|
|
2886
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
2887
|
+
}
|
|
2888
|
+
}
|
|
2889
|
+
}
|
|
2598
2890
|
|
|
2599
|
-
|
|
2891
|
+
wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
|
|
2892
|
+
}
|
|
2600
2893
|
|
|
2601
|
-
logits = gf
|
|
2894
|
+
logits = wsp_ggml_graph_node(gf, -1);
|
|
2602
2895
|
|
|
2603
|
-
if (!wsp_ggml_graph_compute_helper(
|
|
2896
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2604
2897
|
return false;
|
|
2605
2898
|
}
|
|
2606
2899
|
}
|
|
@@ -2654,29 +2947,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
2654
2947
|
}
|
|
2655
2948
|
|
|
2656
2949
|
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
|
2657
|
-
|
|
2658
|
-
|
|
2950
|
+
namespace {
|
|
2951
|
+
struct whisper_global_cache {
|
|
2952
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2953
|
+
// We can use precalculated values to speed up the process.
|
|
2954
|
+
float sin_vals[SIN_COS_N_COUNT];
|
|
2955
|
+
float cos_vals[SIN_COS_N_COUNT];
|
|
2956
|
+
|
|
2957
|
+
// Hann window (Use cosf to eliminate difference)
|
|
2958
|
+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
|
2959
|
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
|
2960
|
+
float hann_window[WHISPER_N_FFT];
|
|
2961
|
+
|
|
2962
|
+
whisper_global_cache() {
|
|
2963
|
+
fill_sin_cos_table();
|
|
2964
|
+
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
|
2965
|
+
}
|
|
2659
2966
|
|
|
2660
|
-
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
|
|
2666
|
-
|
|
2667
|
-
|
|
2668
|
-
|
|
2967
|
+
void fill_sin_cos_table() {
|
|
2968
|
+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
|
2969
|
+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
|
2970
|
+
sin_vals[i] = sinf(theta);
|
|
2971
|
+
cos_vals[i] = cosf(theta);
|
|
2972
|
+
}
|
|
2973
|
+
}
|
|
2974
|
+
|
|
2975
|
+
void fill_hann_window(int length, bool periodic, float * output) {
|
|
2976
|
+
int offset = -1;
|
|
2977
|
+
if (periodic) {
|
|
2978
|
+
offset = 0;
|
|
2979
|
+
}
|
|
2980
|
+
for (int i = 0; i < length; i++) {
|
|
2981
|
+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
|
2982
|
+
}
|
|
2669
2983
|
}
|
|
2670
|
-
|
|
2984
|
+
} global_cache;
|
|
2671
2985
|
}
|
|
2672
2986
|
|
|
2673
2987
|
// naive Discrete Fourier Transform
|
|
2674
2988
|
// input is real-valued
|
|
2675
2989
|
// output is complex-valued
|
|
2676
|
-
static void dft(const
|
|
2677
|
-
int N = in.size();
|
|
2678
|
-
|
|
2679
|
-
out.resize(N*2);
|
|
2990
|
+
static void dft(const float* in, int N, float* out) {
|
|
2680
2991
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
|
2681
2992
|
|
|
2682
2993
|
for (int k = 0; k < N; k++) {
|
|
@@ -2685,8 +2996,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2685
2996
|
|
|
2686
2997
|
for (int n = 0; n < N; n++) {
|
|
2687
2998
|
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
|
2688
|
-
re += in[n]*cos_vals[idx]; // cos(t)
|
|
2689
|
-
im -= in[n]*sin_vals[idx]; // sin(t)
|
|
2999
|
+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
|
3000
|
+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
|
2690
3001
|
}
|
|
2691
3002
|
|
|
2692
3003
|
out[k*2 + 0] = re;
|
|
@@ -2698,47 +3009,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2698
3009
|
// poor man's implementation - use something better
|
|
2699
3010
|
// input is real-valued
|
|
2700
3011
|
// output is complex-valued
|
|
2701
|
-
static void fft(
|
|
2702
|
-
out.resize(in.size()*2);
|
|
2703
|
-
|
|
2704
|
-
int N = in.size();
|
|
2705
|
-
|
|
3012
|
+
static void fft(float* in, int N, float* out) {
|
|
2706
3013
|
if (N == 1) {
|
|
2707
3014
|
out[0] = in[0];
|
|
2708
3015
|
out[1] = 0;
|
|
2709
3016
|
return;
|
|
2710
3017
|
}
|
|
2711
3018
|
|
|
2712
|
-
|
|
2713
|
-
|
|
3019
|
+
const int half_N = N / 2;
|
|
3020
|
+
if (N - half_N*2 == 1) {
|
|
3021
|
+
dft(in, N, out);
|
|
2714
3022
|
return;
|
|
2715
3023
|
}
|
|
2716
3024
|
|
|
2717
|
-
|
|
2718
|
-
|
|
2719
|
-
|
|
2720
|
-
even.reserve(N/2);
|
|
2721
|
-
odd.reserve(N/2);
|
|
2722
|
-
|
|
2723
|
-
for (int i = 0; i < N; i++) {
|
|
2724
|
-
if (i % 2 == 0) {
|
|
2725
|
-
even.push_back(in[i]);
|
|
2726
|
-
} else {
|
|
2727
|
-
odd.push_back(in[i]);
|
|
2728
|
-
}
|
|
3025
|
+
float* even = in + N;
|
|
3026
|
+
for (int i = 0; i < half_N; ++i) {
|
|
3027
|
+
even[i]= in[2*i];
|
|
2729
3028
|
}
|
|
3029
|
+
float* even_fft = out + 2 * N;
|
|
3030
|
+
fft(even, half_N, even_fft);
|
|
2730
3031
|
|
|
2731
|
-
|
|
2732
|
-
|
|
2733
|
-
|
|
2734
|
-
|
|
2735
|
-
|
|
3032
|
+
float* odd = even;
|
|
3033
|
+
for (int i = 0; i < half_N; ++i) {
|
|
3034
|
+
odd[i] = in[2*i + 1];
|
|
3035
|
+
}
|
|
3036
|
+
float* odd_fft = even_fft + N;
|
|
3037
|
+
fft(odd, half_N, odd_fft);
|
|
2736
3038
|
|
|
2737
3039
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
|
2738
|
-
for (int k = 0; k <
|
|
3040
|
+
for (int k = 0; k < half_N; k++) {
|
|
2739
3041
|
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
|
2740
|
-
float re = cos_vals[idx]; // cos(t)
|
|
2741
|
-
float im = -sin_vals[idx]; // sin(t)
|
|
3042
|
+
float re = global_cache.cos_vals[idx]; // cos(t)
|
|
3043
|
+
float im = -global_cache.sin_vals[idx]; // sin(t)
|
|
2742
3044
|
|
|
2743
3045
|
float re_odd = odd_fft[2*k + 0];
|
|
2744
3046
|
float im_odd = odd_fft[2*k + 1];
|
|
@@ -2746,61 +3048,49 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2746
3048
|
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
|
2747
3049
|
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
|
2748
3050
|
|
|
2749
|
-
out[2*(k +
|
|
2750
|
-
out[2*(k +
|
|
2751
|
-
}
|
|
2752
|
-
}
|
|
2753
|
-
|
|
2754
|
-
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
|
2755
|
-
if (output.size() < static_cast<size_t>(length)) {
|
|
2756
|
-
output.resize(length);
|
|
2757
|
-
}
|
|
2758
|
-
int offset = -1;
|
|
2759
|
-
if (periodic) {
|
|
2760
|
-
offset = 0;
|
|
3051
|
+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
|
3052
|
+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
|
2761
3053
|
}
|
|
2762
|
-
for (int i = 0; i < length; i++) {
|
|
2763
|
-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
|
2764
|
-
}
|
|
2765
|
-
|
|
2766
|
-
return true;
|
|
2767
3054
|
}
|
|
2768
3055
|
|
|
2769
|
-
static void log_mel_spectrogram_worker_thread(int ith, const
|
|
3056
|
+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
|
2770
3057
|
int n_samples, int frame_size, int frame_step, int n_threads,
|
|
2771
3058
|
const whisper_filters & filters, whisper_mel & mel) {
|
|
2772
|
-
std::vector<float> fft_in(frame_size, 0.0);
|
|
2773
|
-
std::vector<float> fft_out(2 *
|
|
2774
|
-
|
|
2775
|
-
int n_fft =
|
|
3059
|
+
std::vector<float> fft_in(frame_size * 2, 0.0);
|
|
3060
|
+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
|
3061
|
+
|
|
3062
|
+
int n_fft = filters.n_fft;
|
|
2776
3063
|
int i = ith;
|
|
2777
3064
|
|
|
3065
|
+
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
|
3066
|
+
assert(n_fft == 1 + (frame_size / 2));
|
|
3067
|
+
|
|
2778
3068
|
// calculate FFT only when fft_in are not all zero
|
|
2779
3069
|
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
|
2780
3070
|
const int offset = i * frame_step;
|
|
2781
3071
|
|
|
2782
|
-
// apply
|
|
3072
|
+
// apply Hann window (~10% faster)
|
|
2783
3073
|
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
|
2784
3074
|
fft_in[j] = hann[j] * samples[offset + j];
|
|
2785
3075
|
}
|
|
3076
|
+
|
|
2786
3077
|
// fill the rest with zeros
|
|
2787
3078
|
if (n_samples - offset < frame_size) {
|
|
2788
3079
|
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
|
|
2789
3080
|
}
|
|
2790
3081
|
|
|
2791
3082
|
// FFT
|
|
2792
|
-
fft(fft_in, fft_out);
|
|
3083
|
+
fft(fft_in.data(), frame_size, fft_out.data());
|
|
2793
3084
|
|
|
2794
3085
|
// Calculate modulus^2 of complex numbers
|
|
2795
3086
|
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
|
2796
|
-
for (int j = 0; j <
|
|
3087
|
+
for (int j = 0; j < n_fft; j++) {
|
|
2797
3088
|
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
|
2798
3089
|
}
|
|
2799
3090
|
|
|
2800
3091
|
// mel spectrogram
|
|
2801
3092
|
for (int j = 0; j < mel.n_mel; j++) {
|
|
2802
3093
|
double sum = 0.0;
|
|
2803
|
-
|
|
2804
3094
|
// unroll loop (suggested by GH user @lunixbochs)
|
|
2805
3095
|
int k = 0;
|
|
2806
3096
|
for (k = 0; k < n_fft - 3; k += 4) {
|
|
@@ -2810,14 +3100,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
|
2810
3100
|
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
|
|
2811
3101
|
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
|
|
2812
3102
|
}
|
|
2813
|
-
|
|
2814
3103
|
// handle n_fft remainder
|
|
2815
3104
|
for (; k < n_fft; k++) {
|
|
2816
3105
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
|
2817
3106
|
}
|
|
2818
|
-
|
|
2819
3107
|
sum = log10(std::max(sum, 1e-10));
|
|
2820
|
-
|
|
2821
3108
|
mel.data[j * mel.n_len + i] = sum;
|
|
2822
3109
|
}
|
|
2823
3110
|
}
|
|
@@ -2846,12 +3133,9 @@ static bool log_mel_spectrogram(
|
|
|
2846
3133
|
whisper_mel & mel) {
|
|
2847
3134
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2848
3135
|
|
|
2849
|
-
//
|
|
2850
|
-
|
|
2851
|
-
|
|
2852
|
-
std::vector<float> hann;
|
|
2853
|
-
hann_window(frame_size, true, hann);
|
|
2854
|
-
|
|
3136
|
+
// Hann window
|
|
3137
|
+
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
|
|
3138
|
+
const float * hann = global_cache.hann_window;
|
|
2855
3139
|
|
|
2856
3140
|
// Calculate the length of padding
|
|
2857
3141
|
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
|
@@ -2876,12 +3160,11 @@ static bool log_mel_spectrogram(
|
|
|
2876
3160
|
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
|
|
2877
3161
|
mel.data.resize(mel.n_mel * mel.n_len);
|
|
2878
3162
|
|
|
2879
|
-
|
|
2880
3163
|
{
|
|
2881
3164
|
std::vector<std::thread> workers(n_threads - 1);
|
|
2882
3165
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
|
2883
3166
|
workers[iw] = std::thread(
|
|
2884
|
-
log_mel_spectrogram_worker_thread, iw + 1,
|
|
3167
|
+
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
|
2885
3168
|
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
|
2886
3169
|
std::cref(filters), std::ref(mel));
|
|
2887
3170
|
}
|
|
@@ -3041,19 +3324,24 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
|
3041
3324
|
#endif
|
|
3042
3325
|
|
|
3043
3326
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3044
|
-
fill_sin_cos_table();
|
|
3045
|
-
|
|
3046
3327
|
whisper_state * state = new whisper_state;
|
|
3047
3328
|
|
|
3048
|
-
state->
|
|
3049
|
-
|
|
3050
|
-
|
|
3051
|
-
|
|
3052
|
-
|
|
3329
|
+
state->backends = whisper_backend_init(ctx->params);
|
|
3330
|
+
if (state->backends.empty()) {
|
|
3331
|
+
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
|
3332
|
+
whisper_free_state(state);
|
|
3333
|
+
return nullptr;
|
|
3334
|
+
}
|
|
3053
3335
|
|
|
3054
|
-
|
|
3055
|
-
|
|
3056
|
-
|
|
3336
|
+
// at this point, we don't know yet how many decoders will be used
|
|
3337
|
+
// later during decoding, if more decoders are used, we will recreate the KV cache respectively
|
|
3338
|
+
state->kv_self_n_dec = 1;
|
|
3339
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
|
3340
|
+
ctx->model.hparams.n_text_state,
|
|
3341
|
+
ctx->model.hparams.n_text_layer,
|
|
3342
|
+
WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
|
|
3343
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
3344
|
+
whisper_free_state(state);
|
|
3057
3345
|
return nullptr;
|
|
3058
3346
|
}
|
|
3059
3347
|
|
|
@@ -3062,9 +3350,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3062
3350
|
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3063
3351
|
}
|
|
3064
3352
|
|
|
3065
|
-
if (!
|
|
3066
|
-
|
|
3067
|
-
|
|
3353
|
+
if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
|
|
3354
|
+
ctx->model.hparams.n_text_state,
|
|
3355
|
+
ctx->model.hparams.n_text_layer,
|
|
3356
|
+
WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
|
3357
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
3358
|
+
whisper_free_state(state);
|
|
3068
3359
|
return nullptr;
|
|
3069
3360
|
}
|
|
3070
3361
|
|
|
@@ -3073,6 +3364,31 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3073
3364
|
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3074
3365
|
}
|
|
3075
3366
|
|
|
3367
|
+
if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
|
|
3368
|
+
ctx->model.hparams.n_audio_state,
|
|
3369
|
+
1,
|
|
3370
|
+
WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
|
3371
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
3372
|
+
whisper_free_state(state);
|
|
3373
|
+
return nullptr;
|
|
3374
|
+
}
|
|
3375
|
+
|
|
3376
|
+
{
|
|
3377
|
+
const size_t memory_size = wsp_ggml_nbytes(state->kv_pad.k) + wsp_ggml_nbytes(state->kv_pad.v);
|
|
3378
|
+
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3379
|
+
}
|
|
3380
|
+
|
|
3381
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
3382
|
+
if (ctx->params.dtw_token_timestamps) {
|
|
3383
|
+
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
|
|
3384
|
+
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
|
|
3385
|
+
whisper_free_state(state);
|
|
3386
|
+
return nullptr;
|
|
3387
|
+
}
|
|
3388
|
+
const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
|
|
3389
|
+
WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
|
|
3390
|
+
}
|
|
3391
|
+
|
|
3076
3392
|
|
|
3077
3393
|
#ifdef WHISPER_USE_COREML
|
|
3078
3394
|
if (ctx->params.use_coreml) {
|
|
@@ -3085,7 +3401,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3085
3401
|
if (!state->ctx_coreml) {
|
|
3086
3402
|
WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
3087
3403
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
3088
|
-
|
|
3404
|
+
whisper_free_state(state);
|
|
3089
3405
|
return nullptr;
|
|
3090
3406
|
#endif
|
|
3091
3407
|
} else {
|
|
@@ -3110,37 +3426,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3110
3426
|
|
|
3111
3427
|
// conv allocator
|
|
3112
3428
|
{
|
|
3113
|
-
|
|
3429
|
+
bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
|
|
3114
3430
|
[&]() {
|
|
3115
|
-
return whisper_build_graph_conv(*ctx, *state
|
|
3431
|
+
return whisper_build_graph_conv(*ctx, *state);
|
|
3116
3432
|
});
|
|
3117
3433
|
|
|
3118
|
-
|
|
3434
|
+
if (!ok) {
|
|
3435
|
+
WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__);
|
|
3436
|
+
whisper_free_state(state);
|
|
3437
|
+
return nullptr;
|
|
3438
|
+
}
|
|
3439
|
+
|
|
3440
|
+
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
|
|
3119
3441
|
}
|
|
3120
3442
|
|
|
3121
3443
|
// encoder allocator
|
|
3122
3444
|
if (!whisper_encode_external(*state)) {
|
|
3123
|
-
|
|
3445
|
+
bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
|
|
3124
3446
|
[&]() {
|
|
3125
3447
|
return whisper_build_graph_encoder(*ctx, *state);
|
|
3126
3448
|
});
|
|
3127
3449
|
|
|
3128
|
-
|
|
3450
|
+
if (!ok) {
|
|
3451
|
+
WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__);
|
|
3452
|
+
whisper_free_state(state);
|
|
3453
|
+
return nullptr;
|
|
3454
|
+
}
|
|
3455
|
+
|
|
3456
|
+
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
|
|
3129
3457
|
}
|
|
3130
3458
|
|
|
3131
3459
|
// cross allocator
|
|
3132
3460
|
{
|
|
3133
|
-
|
|
3461
|
+
bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
|
|
3134
3462
|
[&]() {
|
|
3135
3463
|
return whisper_build_graph_cross(*ctx, *state);
|
|
3136
3464
|
});
|
|
3137
3465
|
|
|
3138
|
-
|
|
3466
|
+
if (!ok) {
|
|
3467
|
+
WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__);
|
|
3468
|
+
whisper_free_state(state);
|
|
3469
|
+
return nullptr;
|
|
3470
|
+
}
|
|
3471
|
+
|
|
3472
|
+
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
|
|
3139
3473
|
}
|
|
3140
3474
|
|
|
3141
3475
|
// decoder allocator
|
|
3142
3476
|
{
|
|
3143
|
-
|
|
3477
|
+
bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
|
|
3144
3478
|
[&]() {
|
|
3145
3479
|
const auto & hparams = ctx->model.hparams;
|
|
3146
3480
|
|
|
@@ -3150,27 +3484,30 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3150
3484
|
|
|
3151
3485
|
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
|
3152
3486
|
|
|
3153
|
-
return whisper_build_graph_decoder(*ctx, *state, state->batch);
|
|
3487
|
+
return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
|
|
3154
3488
|
});
|
|
3155
3489
|
|
|
3156
|
-
|
|
3157
|
-
|
|
3490
|
+
if (!ok) {
|
|
3491
|
+
WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
|
|
3492
|
+
whisper_free_state(state);
|
|
3493
|
+
return nullptr;
|
|
3494
|
+
}
|
|
3158
3495
|
|
|
3159
|
-
|
|
3160
|
-
|
|
3161
|
-
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
|
3162
|
-
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
|
3496
|
+
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
|
|
3497
|
+
}
|
|
3163
3498
|
|
|
3164
3499
|
return state;
|
|
3165
3500
|
}
|
|
3166
3501
|
|
|
3167
|
-
int
|
|
3502
|
+
int whisper_ctx_init_openvino_encoder_with_state(
|
|
3168
3503
|
struct whisper_context * ctx,
|
|
3504
|
+
struct whisper_state * state,
|
|
3169
3505
|
const char * model_path,
|
|
3170
3506
|
const char * device,
|
|
3171
3507
|
const char * cache_dir) {
|
|
3172
3508
|
#ifndef WHISPER_USE_OPENVINO
|
|
3173
3509
|
(void)(ctx);
|
|
3510
|
+
(void)(state);
|
|
3174
3511
|
(void)(model_path);
|
|
3175
3512
|
(void)(device);
|
|
3176
3513
|
(void)(cache_dir);
|
|
@@ -3201,8 +3538,8 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3201
3538
|
WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
|
|
3202
3539
|
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
3203
3540
|
|
|
3204
|
-
|
|
3205
|
-
if (!
|
|
3541
|
+
state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
|
|
3542
|
+
if (!state->ctx_openvino) {
|
|
3206
3543
|
WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
|
|
3207
3544
|
return 1;
|
|
3208
3545
|
} else {
|
|
@@ -3213,18 +3550,43 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3213
3550
|
#endif
|
|
3214
3551
|
}
|
|
3215
3552
|
|
|
3553
|
+
int whisper_ctx_init_openvino_encoder(
|
|
3554
|
+
struct whisper_context * ctx,
|
|
3555
|
+
const char * model_path,
|
|
3556
|
+
const char * device,
|
|
3557
|
+
const char * cache_dir) {
|
|
3558
|
+
return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
|
|
3559
|
+
}
|
|
3560
|
+
|
|
3216
3561
|
struct whisper_context_params whisper_context_default_params() {
|
|
3217
3562
|
struct whisper_context_params result = {
|
|
3218
|
-
/*.use_gpu
|
|
3219
|
-
/*.use_coreml
|
|
3563
|
+
/*.use_gpu =*/ true,
|
|
3564
|
+
/*.use_coreml =*/ false,
|
|
3565
|
+
/*.flash_attn =*/ false,
|
|
3566
|
+
/*.gpu_device =*/ 0,
|
|
3567
|
+
|
|
3568
|
+
/*.dtw_token_timestamps =*/ false,
|
|
3569
|
+
/*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
|
|
3570
|
+
/*.dtw_n_top =*/ -1,
|
|
3571
|
+
/*.dtw_aheads =*/ {
|
|
3572
|
+
/*.n_heads =*/ 0,
|
|
3573
|
+
/*.heads =*/ NULL,
|
|
3574
|
+
},
|
|
3575
|
+
/*.dtw_mem_size =*/ 1024*1024*128,
|
|
3220
3576
|
};
|
|
3221
3577
|
return result;
|
|
3222
3578
|
}
|
|
3223
3579
|
|
|
3224
3580
|
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
|
|
3225
3581
|
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
|
3226
|
-
|
|
3582
|
+
#ifdef _MSC_VER
|
|
3583
|
+
// Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
|
|
3584
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
|
3585
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
|
3586
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
|
3587
|
+
#else
|
|
3227
3588
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
3589
|
+
#endif
|
|
3228
3590
|
if (!fin) {
|
|
3229
3591
|
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
|
3230
3592
|
return nullptr;
|
|
@@ -3299,6 +3661,19 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|
|
3299
3661
|
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
|
3300
3662
|
wsp_ggml_time_init();
|
|
3301
3663
|
|
|
3664
|
+
if (params.flash_attn && params.dtw_token_timestamps) {
|
|
3665
|
+
WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
|
|
3666
|
+
params.dtw_token_timestamps = false;
|
|
3667
|
+
}
|
|
3668
|
+
|
|
3669
|
+
WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
|
|
3670
|
+
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
|
|
3671
|
+
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
|
3672
|
+
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
|
|
3673
|
+
|
|
3674
|
+
// TODO: temporary call to force backend registry initialization
|
|
3675
|
+
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
|
|
3676
|
+
|
|
3302
3677
|
whisper_context * ctx = new whisper_context;
|
|
3303
3678
|
ctx->params = params;
|
|
3304
3679
|
|
|
@@ -3383,11 +3758,11 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
|
|
3383
3758
|
return whisper_init_with_params_no_state(loader, whisper_context_default_params());
|
|
3384
3759
|
}
|
|
3385
3760
|
|
|
3386
|
-
void whisper_free_state(struct whisper_state * state)
|
|
3387
|
-
{
|
|
3761
|
+
void whisper_free_state(struct whisper_state * state) {
|
|
3388
3762
|
if (state) {
|
|
3389
|
-
|
|
3390
|
-
|
|
3763
|
+
whisper_kv_cache_free(state->kv_self);
|
|
3764
|
+
whisper_kv_cache_free(state->kv_cross);
|
|
3765
|
+
whisper_kv_cache_free(state->kv_pad);
|
|
3391
3766
|
|
|
3392
3767
|
#ifdef WHISPER_USE_COREML
|
|
3393
3768
|
if (state->ctx_coreml != nullptr) {
|
|
@@ -3405,12 +3780,17 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3405
3780
|
|
|
3406
3781
|
whisper_batch_free(state->batch);
|
|
3407
3782
|
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
|
|
3411
|
-
|
|
3783
|
+
wsp_ggml_backend_sched_free(state->sched_conv.sched);
|
|
3784
|
+
wsp_ggml_backend_sched_free(state->sched_encode.sched);
|
|
3785
|
+
wsp_ggml_backend_sched_free(state->sched_cross.sched);
|
|
3786
|
+
wsp_ggml_backend_sched_free(state->sched_decode.sched);
|
|
3787
|
+
|
|
3788
|
+
for (auto & backend : state->backends) {
|
|
3789
|
+
wsp_ggml_backend_free(backend);
|
|
3790
|
+
}
|
|
3412
3791
|
|
|
3413
|
-
|
|
3792
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
3793
|
+
aheads_masks_free(state->aheads_masks);
|
|
3414
3794
|
|
|
3415
3795
|
delete state;
|
|
3416
3796
|
}
|
|
@@ -3418,20 +3798,12 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3418
3798
|
|
|
3419
3799
|
void whisper_free(struct whisper_context * ctx) {
|
|
3420
3800
|
if (ctx) {
|
|
3421
|
-
|
|
3422
|
-
wsp_ggml_free(ctx->model.ctx);
|
|
3423
|
-
}
|
|
3801
|
+
wsp_ggml_free(ctx->model.ctx);
|
|
3424
3802
|
|
|
3425
|
-
|
|
3426
|
-
if (buffer) {
|
|
3427
|
-
wsp_ggml_backend_buffer_free(buffer);
|
|
3428
|
-
}
|
|
3429
|
-
}
|
|
3803
|
+
wsp_ggml_backend_buffer_free(ctx->model.buffer);
|
|
3430
3804
|
|
|
3431
3805
|
whisper_free_state(ctx->state);
|
|
3432
3806
|
|
|
3433
|
-
wsp_ggml_backend_free(ctx->backend);
|
|
3434
|
-
|
|
3435
3807
|
delete ctx;
|
|
3436
3808
|
}
|
|
3437
3809
|
}
|
|
@@ -3461,30 +3833,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
|
3461
3833
|
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
3462
3834
|
}
|
|
3463
3835
|
|
|
3464
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3465
|
-
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3466
|
-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3467
|
-
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3468
|
-
return -1;
|
|
3469
|
-
}
|
|
3470
|
-
|
|
3471
|
-
return 0;
|
|
3472
|
-
}
|
|
3473
|
-
|
|
3474
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3475
|
-
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
|
3476
|
-
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
3477
|
-
}
|
|
3478
|
-
|
|
3479
|
-
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
|
|
3480
|
-
// TODO
|
|
3481
|
-
|
|
3482
|
-
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
|
|
3483
|
-
// TODO
|
|
3484
|
-
|
|
3485
|
-
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
|
|
3486
|
-
// TODO
|
|
3487
|
-
|
|
3488
3836
|
int whisper_set_mel_with_state(
|
|
3489
3837
|
struct whisper_context * ctx,
|
|
3490
3838
|
struct whisper_state * state,
|
|
@@ -3537,7 +3885,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3537
3885
|
|
|
3538
3886
|
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
|
3539
3887
|
|
|
3540
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
|
|
3888
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
|
|
3541
3889
|
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3542
3890
|
return 1;
|
|
3543
3891
|
}
|
|
@@ -3559,7 +3907,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
|
3559
3907
|
|
|
3560
3908
|
if (n_max_tokens < (int) res.size()) {
|
|
3561
3909
|
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
3562
|
-
return -
|
|
3910
|
+
return -(int) res.size();
|
|
3563
3911
|
}
|
|
3564
3912
|
|
|
3565
3913
|
for (int i = 0; i < (int) res.size(); i++) {
|
|
@@ -3569,7 +3917,11 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
|
3569
3917
|
return res.size();
|
|
3570
3918
|
}
|
|
3571
3919
|
|
|
3572
|
-
int
|
|
3920
|
+
int whisper_token_count(struct whisper_context * ctx, const char * text) {
|
|
3921
|
+
return -whisper_tokenize(ctx, text, NULL, 0);
|
|
3922
|
+
}
|
|
3923
|
+
|
|
3924
|
+
int whisper_lang_max_id(void) {
|
|
3573
3925
|
auto max_id = 0;
|
|
3574
3926
|
for (const auto & kv : g_lang) {
|
|
3575
3927
|
max_id = std::max(max_id, kv.second.first);
|
|
@@ -3838,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
|
|
|
3838
4190
|
return ctx->vocab.token_transcribe;
|
|
3839
4191
|
}
|
|
3840
4192
|
|
|
4193
|
+
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
|
|
4194
|
+
if (ctx->state == nullptr) {
|
|
4195
|
+
return nullptr;
|
|
4196
|
+
}
|
|
4197
|
+
return new whisper_timings {
|
|
4198
|
+
.load_us = ctx->t_load_us,
|
|
4199
|
+
.t_start_us = ctx->t_start_us,
|
|
4200
|
+
.fail_p = ctx->state->n_fail_p,
|
|
4201
|
+
.fail_h = ctx->state->n_fail_h,
|
|
4202
|
+
.t_mel_us = ctx->state->t_mel_us,
|
|
4203
|
+
.n_sample = ctx->state->n_sample,
|
|
4204
|
+
.n_encode = ctx->state->n_encode,
|
|
4205
|
+
.n_decode = ctx->state->n_decode,
|
|
4206
|
+
.n_batchd = ctx->state->n_batchd,
|
|
4207
|
+
.n_prompt = ctx->state->n_prompt,
|
|
4208
|
+
.t_sample_us = ctx->state->t_sample_us,
|
|
4209
|
+
.t_encode_us = ctx->state->t_encode_us,
|
|
4210
|
+
.t_decode_us = ctx->state->t_decode_us,
|
|
4211
|
+
.t_batchd_us = ctx->state->t_batchd_us,
|
|
4212
|
+
.t_prompt_us = ctx->state->t_prompt_us,
|
|
4213
|
+
};
|
|
4214
|
+
}
|
|
4215
|
+
|
|
3841
4216
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
3842
4217
|
const int64_t t_end_us = wsp_ggml_time_us();
|
|
4218
|
+
const struct whisper_timings * timings = whisper_get_timings(ctx);
|
|
3843
4219
|
|
|
3844
4220
|
WHISPER_LOG_INFO("\n");
|
|
3845
|
-
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__,
|
|
4221
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
|
|
3846
4222
|
if (ctx->state != nullptr) {
|
|
3847
|
-
|
|
3848
4223
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3849
4224
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3850
4225
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3851
4226
|
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
|
3852
4227
|
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3853
4228
|
|
|
3854
|
-
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__,
|
|
3855
|
-
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__,
|
|
3856
|
-
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3857
|
-
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3858
|
-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3859
|
-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3860
|
-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4229
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
|
|
4230
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
|
|
4231
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
|
|
4232
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
|
|
4233
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
|
|
4234
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
|
|
4235
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
|
|
3861
4236
|
}
|
|
3862
|
-
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us -
|
|
4237
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
|
|
3863
4238
|
}
|
|
3864
4239
|
|
|
3865
4240
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
@@ -3913,10 +4288,10 @@ const char * whisper_print_system_info(void) {
|
|
|
3913
4288
|
s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
|
|
3914
4289
|
s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
|
|
3915
4290
|
s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
|
|
3916
|
-
s += "CUDA = " + std::to_string(
|
|
4291
|
+
s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
|
|
3917
4292
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
3918
4293
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
|
3919
|
-
|
|
4294
|
+
s += "CANN = " + std::to_string(wsp_ggml_cpu_has_cann()) ;
|
|
3920
4295
|
return s.c_str();
|
|
3921
4296
|
}
|
|
3922
4297
|
|
|
@@ -3926,7 +4301,7 @@ const char * whisper_print_system_info(void) {
|
|
|
3926
4301
|
|
|
3927
4302
|
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
|
3928
4303
|
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
|
3929
|
-
std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
4304
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
3930
4305
|
const char * src,
|
|
3931
4306
|
whisper_partial_utf8 partial_start) {
|
|
3932
4307
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
|
@@ -4340,7 +4715,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
|
|
|
4340
4715
|
|
|
4341
4716
|
////////////////////////////////////////////////////////////////////////////
|
|
4342
4717
|
|
|
4343
|
-
struct whisper_context_params * whisper_context_default_params_by_ref() {
|
|
4718
|
+
struct whisper_context_params * whisper_context_default_params_by_ref(void) {
|
|
4344
4719
|
struct whisper_context_params params = whisper_context_default_params();
|
|
4345
4720
|
|
|
4346
4721
|
struct whisper_context_params* result = new whisper_context_params();
|
|
@@ -4381,12 +4756,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
4381
4756
|
/*.split_on_word =*/ false,
|
|
4382
4757
|
/*.max_tokens =*/ 0,
|
|
4383
4758
|
|
|
4384
|
-
/*.speed_up =*/ false,
|
|
4385
4759
|
/*.debug_mode =*/ false,
|
|
4386
4760
|
/*.audio_ctx =*/ 0,
|
|
4387
4761
|
|
|
4388
4762
|
/*.tdrz_enable =*/ false,
|
|
4389
4763
|
|
|
4764
|
+
/* suppress_regex =*/ nullptr,
|
|
4765
|
+
|
|
4390
4766
|
/*.initial_prompt =*/ nullptr,
|
|
4391
4767
|
/*.prompt_tokens =*/ nullptr,
|
|
4392
4768
|
/*.prompt_n_tokens =*/ 0,
|
|
@@ -4472,6 +4848,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
|
|
|
4472
4848
|
return txt[0] == ' ';
|
|
4473
4849
|
}
|
|
4474
4850
|
|
|
4851
|
+
static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
4852
|
+
struct whisper_context * ctx,
|
|
4853
|
+
struct whisper_state * state,
|
|
4854
|
+
struct whisper_full_params params,
|
|
4855
|
+
int i_segment,
|
|
4856
|
+
size_t n_segments,
|
|
4857
|
+
int seek,
|
|
4858
|
+
int n_frames,
|
|
4859
|
+
int medfilt_width,
|
|
4860
|
+
int n_threads);
|
|
4861
|
+
|
|
4475
4862
|
// wrap the last segment to max_len characters
|
|
4476
4863
|
// returns the number of new segments
|
|
4477
4864
|
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
|
@@ -4619,6 +5006,17 @@ static void whisper_process_logits(
|
|
|
4619
5006
|
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
4620
5007
|
}
|
|
4621
5008
|
|
|
5009
|
+
// suppress any tokens matching a regular expression
|
|
5010
|
+
// ref: https://github.com/openai/whisper/discussions/1041
|
|
5011
|
+
if (params.suppress_regex != nullptr) {
|
|
5012
|
+
std::regex re(params.suppress_regex);
|
|
5013
|
+
for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
|
|
5014
|
+
if (std::regex_match(token_id.first, re)) {
|
|
5015
|
+
logits[token_id.second] = -INFINITY;
|
|
5016
|
+
}
|
|
5017
|
+
}
|
|
5018
|
+
}
|
|
5019
|
+
|
|
4622
5020
|
// suppress non-speech tokens
|
|
4623
5021
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
4624
5022
|
if (params.suppress_non_speech_tokens) {
|
|
@@ -4822,12 +5220,25 @@ static void whisper_process_logits(
|
|
|
4822
5220
|
#endif
|
|
4823
5221
|
}
|
|
4824
5222
|
|
|
5223
|
+
static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
|
|
5224
|
+
if (a.tokens.size() != b.tokens.size()) {
|
|
5225
|
+
return false;
|
|
5226
|
+
}
|
|
5227
|
+
// sequences are more likely to diverge at the end
|
|
5228
|
+
for (int i = a.tokens.size() - 1; i >= 0; i--) {
|
|
5229
|
+
if (a.tokens[i].id != b.tokens[i].id) {
|
|
5230
|
+
return false;
|
|
5231
|
+
}
|
|
5232
|
+
}
|
|
5233
|
+
return true;
|
|
5234
|
+
}
|
|
5235
|
+
|
|
4825
5236
|
static whisper_token_data whisper_sample_token(
|
|
4826
5237
|
whisper_context & ctx,
|
|
4827
5238
|
const whisper_decoder & decoder,
|
|
4828
5239
|
bool best) {
|
|
4829
5240
|
whisper_token_data result = {
|
|
4830
|
-
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
5241
|
+
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
|
|
4831
5242
|
};
|
|
4832
5243
|
|
|
4833
5244
|
const auto & vocab = ctx.vocab;
|
|
@@ -4945,7 +5356,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4945
5356
|
const auto id = dist(decoder.rng);
|
|
4946
5357
|
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
|
4947
5358
|
|
|
4948
|
-
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
|
5359
|
+
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
|
|
4949
5360
|
|
|
4950
5361
|
if (result[i].id >= vocab.token_beg) {
|
|
4951
5362
|
result[i].tid = result[i].id;
|
|
@@ -5018,15 +5429,9 @@ int whisper_full_with_state(
|
|
|
5018
5429
|
|
|
5019
5430
|
if (n_samples > 0) {
|
|
5020
5431
|
// compute log mel spectrogram
|
|
5021
|
-
if (params.
|
|
5022
|
-
// TODO: Replace PV with more advanced algorithm
|
|
5432
|
+
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
5023
5433
|
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
5024
|
-
return -
|
|
5025
|
-
} else {
|
|
5026
|
-
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
5027
|
-
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
5028
|
-
return -2;
|
|
5029
|
-
}
|
|
5434
|
+
return -2;
|
|
5030
5435
|
}
|
|
5031
5436
|
}
|
|
5032
5437
|
|
|
@@ -5063,8 +5468,8 @@ int whisper_full_with_state(
|
|
|
5063
5468
|
// if length of spectrogram is less than 1.0s (100 frames), then return
|
|
5064
5469
|
// basically don't process anything that is less than 1.0s
|
|
5065
5470
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
|
5066
|
-
if (seek_end < seek_start +
|
|
5067
|
-
|
|
5471
|
+
if (seek_end < seek_start + 100) {
|
|
5472
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
|
5068
5473
|
return 0;
|
|
5069
5474
|
}
|
|
5070
5475
|
|
|
@@ -5127,7 +5532,12 @@ int whisper_full_with_state(
|
|
|
5127
5532
|
// initial prompt
|
|
5128
5533
|
if (!params.prompt_tokens && params.initial_prompt) {
|
|
5129
5534
|
prompt_tokens.resize(1024);
|
|
5130
|
-
|
|
5535
|
+
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
|
5536
|
+
if (n_needed < 0) {
|
|
5537
|
+
prompt_tokens.resize(-n_needed);
|
|
5538
|
+
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
|
5539
|
+
}
|
|
5540
|
+
prompt_tokens.resize(n_needed);
|
|
5131
5541
|
params.prompt_tokens = prompt_tokens.data();
|
|
5132
5542
|
params.prompt_n_tokens = prompt_tokens.size();
|
|
5133
5543
|
}
|
|
@@ -5163,11 +5573,11 @@ int whisper_full_with_state(
|
|
|
5163
5573
|
}
|
|
5164
5574
|
}
|
|
5165
5575
|
|
|
5166
|
-
// distilled models require the "no_timestamps" token
|
|
5576
|
+
// first release distilled models require the "no_timestamps" token
|
|
5167
5577
|
{
|
|
5168
|
-
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
|
|
5578
|
+
const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
|
|
5169
5579
|
if (is_distil && !params.no_timestamps) {
|
|
5170
|
-
WHISPER_LOG_WARN("%s: using distilled
|
|
5580
|
+
WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
|
|
5171
5581
|
params.no_timestamps = true;
|
|
5172
5582
|
}
|
|
5173
5583
|
}
|
|
@@ -5303,13 +5713,34 @@ int whisper_full_with_state(
|
|
|
5303
5713
|
}
|
|
5304
5714
|
WHISPER_LOG_DEBUG("\n\n");
|
|
5305
5715
|
|
|
5716
|
+
// recreate the KV cache if the number of decoders has changed
|
|
5717
|
+
if (state->kv_self_n_dec < n_decoders_cur) {
|
|
5718
|
+
WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
|
|
5719
|
+
|
|
5720
|
+
whisper_kv_cache_free(state->kv_self);
|
|
5721
|
+
|
|
5722
|
+
// overallocate to workaround KV cache fragmentation issues
|
|
5723
|
+
const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
|
|
5724
|
+
|
|
5725
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
|
5726
|
+
ctx->model.hparams.n_text_state,
|
|
5727
|
+
ctx->model.hparams.n_text_layer,
|
|
5728
|
+
WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
|
5729
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
5730
|
+
whisper_free_state(state);
|
|
5731
|
+
return -7;
|
|
5732
|
+
}
|
|
5733
|
+
|
|
5734
|
+
state->kv_self_n_dec = n_decoders_cur;
|
|
5735
|
+
}
|
|
5736
|
+
|
|
5306
5737
|
whisper_kv_cache_clear(state->kv_self);
|
|
5307
5738
|
|
|
5308
5739
|
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
|
5309
5740
|
|
|
5310
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5741
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
|
5311
5742
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
5312
|
-
return -
|
|
5743
|
+
return -8;
|
|
5313
5744
|
}
|
|
5314
5745
|
|
|
5315
5746
|
{
|
|
@@ -5420,7 +5851,10 @@ int whisper_full_with_state(
|
|
|
5420
5851
|
beam_candidates.begin(),
|
|
5421
5852
|
beam_candidates.end(),
|
|
5422
5853
|
[](const beam_candidate & a, const beam_candidate & b) {
|
|
5423
|
-
|
|
5854
|
+
if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) {
|
|
5855
|
+
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
|
5856
|
+
}
|
|
5857
|
+
return a.decoder_idx < b.decoder_idx;
|
|
5424
5858
|
});
|
|
5425
5859
|
|
|
5426
5860
|
uint32_t cur_c = 0;
|
|
@@ -5438,7 +5872,7 @@ int whisper_full_with_state(
|
|
|
5438
5872
|
|
|
5439
5873
|
auto & cur = beam_candidates[cur_c++];
|
|
5440
5874
|
|
|
5441
|
-
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence
|
|
5875
|
+
while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
|
|
5442
5876
|
++cur_c;
|
|
5443
5877
|
}
|
|
5444
5878
|
|
|
@@ -5604,9 +6038,9 @@ int whisper_full_with_state(
|
|
|
5604
6038
|
|
|
5605
6039
|
assert(batch.n_tokens > 0);
|
|
5606
6040
|
|
|
5607
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
6041
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
|
5608
6042
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
5609
|
-
return -
|
|
6043
|
+
return -9;
|
|
5610
6044
|
}
|
|
5611
6045
|
|
|
5612
6046
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
@@ -5727,6 +6161,9 @@ int whisper_full_with_state(
|
|
|
5727
6161
|
|
|
5728
6162
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
|
5729
6163
|
|
|
6164
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
6165
|
+
const auto n_segments_before = state->result_all.size();
|
|
6166
|
+
|
|
5730
6167
|
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
|
5731
6168
|
|
|
5732
6169
|
// update prompt_past
|
|
@@ -5764,8 +6201,8 @@ int whisper_full_with_state(
|
|
|
5764
6201
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
|
5765
6202
|
|
|
5766
6203
|
if (!text.empty()) {
|
|
5767
|
-
const auto tt0 =
|
|
5768
|
-
const auto tt1 =
|
|
6204
|
+
const auto tt0 = t0;
|
|
6205
|
+
const auto tt1 = t1;
|
|
5769
6206
|
|
|
5770
6207
|
if (params.print_realtime) {
|
|
5771
6208
|
if (params.print_timestamps) {
|
|
@@ -5793,7 +6230,7 @@ int whisper_full_with_state(
|
|
|
5793
6230
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
5794
6231
|
}
|
|
5795
6232
|
}
|
|
5796
|
-
if (params.new_segment_callback) {
|
|
6233
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
|
5797
6234
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
5798
6235
|
}
|
|
5799
6236
|
}
|
|
@@ -5811,8 +6248,8 @@ int whisper_full_with_state(
|
|
|
5811
6248
|
if (!text.empty()) {
|
|
5812
6249
|
const auto t1 = seek + seek_delta;
|
|
5813
6250
|
|
|
5814
|
-
const auto tt0 =
|
|
5815
|
-
const auto tt1 =
|
|
6251
|
+
const auto tt0 = t0;
|
|
6252
|
+
const auto tt1 = t1;
|
|
5816
6253
|
|
|
5817
6254
|
if (params.print_realtime) {
|
|
5818
6255
|
if (params.print_timestamps) {
|
|
@@ -5838,12 +6275,28 @@ int whisper_full_with_state(
|
|
|
5838
6275
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
5839
6276
|
}
|
|
5840
6277
|
}
|
|
5841
|
-
if (params.new_segment_callback) {
|
|
6278
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
|
5842
6279
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
5843
6280
|
}
|
|
5844
6281
|
}
|
|
5845
6282
|
}
|
|
5846
6283
|
|
|
6284
|
+
// FIXME: will timestamp offsets be correct?
|
|
6285
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
6286
|
+
{
|
|
6287
|
+
const int n_segments = state->result_all.size() - n_segments_before;
|
|
6288
|
+
if (ctx->params.dtw_token_timestamps && n_segments) {
|
|
6289
|
+
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
|
|
6290
|
+
whisper_exp_compute_token_level_timestamps_dtw(
|
|
6291
|
+
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
|
|
6292
|
+
if (params.new_segment_callback) {
|
|
6293
|
+
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
|
|
6294
|
+
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
|
|
6295
|
+
}
|
|
6296
|
+
}
|
|
6297
|
+
}
|
|
6298
|
+
}
|
|
6299
|
+
|
|
5847
6300
|
// update audio window
|
|
5848
6301
|
seek += seek_delta;
|
|
5849
6302
|
|
|
@@ -6603,7 +7056,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
6603
7056
|
k++;
|
|
6604
7057
|
}
|
|
6605
7058
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
6606
|
-
if (j <
|
|
7059
|
+
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
|
6607
7060
|
tokens[j].t1 = tokens[j + 1].t0;
|
|
6608
7061
|
} else {
|
|
6609
7062
|
s1 = k;
|
|
@@ -6646,6 +7099,322 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
6646
7099
|
//}
|
|
6647
7100
|
}
|
|
6648
7101
|
|
|
7102
|
+
//
|
|
7103
|
+
// token level timestamps - dtw version
|
|
7104
|
+
//
|
|
7105
|
+
|
|
7106
|
+
// n_text_layer -> total text layers on model
|
|
7107
|
+
// n_head -> total heads per text layer on model
|
|
7108
|
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
|
|
7109
|
+
std::vector<uint32_t> ret;
|
|
7110
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
|
7111
|
+
return ret;
|
|
7112
|
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
|
7113
|
+
if (il >= n_text_layer - cparams.dtw_n_top) {
|
|
7114
|
+
for (int32_t i = 0; i < n_head; ++i) {
|
|
7115
|
+
ret.push_back(i);
|
|
7116
|
+
}
|
|
7117
|
+
}
|
|
7118
|
+
} else {
|
|
7119
|
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
|
7120
|
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
|
7121
|
+
if (aheads.heads[i].n_text_layer == il) {
|
|
7122
|
+
ret.push_back(aheads.heads[i].n_head);
|
|
7123
|
+
}
|
|
7124
|
+
}
|
|
7125
|
+
}
|
|
7126
|
+
return ret;
|
|
7127
|
+
}
|
|
7128
|
+
|
|
7129
|
+
// dtw + backtrace to return found path
|
|
7130
|
+
// based on
|
|
7131
|
+
// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
|
|
7132
|
+
static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tensor * x) {
|
|
7133
|
+
WHISPER_ASSERT(wsp_ggml_n_dims(x) == 2);
|
|
7134
|
+
|
|
7135
|
+
int64_t N = x->ne[0];
|
|
7136
|
+
int64_t M = x->ne[1];
|
|
7137
|
+
struct wsp_ggml_tensor * cost = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
|
|
7138
|
+
struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
|
|
7139
|
+
|
|
7140
|
+
cost = wsp_ggml_set_f32(cost, INFINITY);
|
|
7141
|
+
trace = wsp_ggml_set_f32(trace, -1);
|
|
7142
|
+
wsp_ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
|
|
7143
|
+
|
|
7144
|
+
// dtw
|
|
7145
|
+
// supposedly can be optmized by computing diagonals in parallel ?
|
|
7146
|
+
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
|
|
7147
|
+
for (int64_t j = 1; j < M + 1; ++j) {
|
|
7148
|
+
for (int64_t i = 1; i < N + 1; ++i) {
|
|
7149
|
+
float c0 = wsp_ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
|
|
7150
|
+
float c1 = wsp_ggml_get_f32_nd(cost, i - 1, j, 0, 0);
|
|
7151
|
+
float c2 = wsp_ggml_get_f32_nd(cost, i, j - 1, 0, 0);
|
|
7152
|
+
|
|
7153
|
+
float c;
|
|
7154
|
+
int32_t t;
|
|
7155
|
+
if (c0 < c1 && c0 < c2) {
|
|
7156
|
+
c = c0;
|
|
7157
|
+
t = 0;
|
|
7158
|
+
} else if (c1 < c0 && c1 < c2) {
|
|
7159
|
+
c = c1;
|
|
7160
|
+
t = 1;
|
|
7161
|
+
} else {
|
|
7162
|
+
c = c2;
|
|
7163
|
+
t = 2;
|
|
7164
|
+
}
|
|
7165
|
+
|
|
7166
|
+
c = wsp_ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
|
|
7167
|
+
wsp_ggml_set_f32_nd(cost, i, j, 0, 0, c);
|
|
7168
|
+
wsp_ggml_set_i32_nd(trace, i, j, 0, 0, t);
|
|
7169
|
+
}
|
|
7170
|
+
}
|
|
7171
|
+
|
|
7172
|
+
// Backtrace
|
|
7173
|
+
const int64_t BT_MAX_ROWS = N + M - 1;
|
|
7174
|
+
struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
|
|
7175
|
+
// trace[0, :] = 2;
|
|
7176
|
+
for (int64_t i = 0; i < M + 1; ++i)
|
|
7177
|
+
wsp_ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
|
|
7178
|
+
//trace[:, 0] = 1;
|
|
7179
|
+
for (int64_t i = 0; i < N + 1; ++i)
|
|
7180
|
+
wsp_ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
|
|
7181
|
+
int bt_row_idx = BT_MAX_ROWS - 1;
|
|
7182
|
+
int64_t i = N;
|
|
7183
|
+
int64_t j = M;
|
|
7184
|
+
while (i > 0 || j > 0) {
|
|
7185
|
+
wsp_ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
|
|
7186
|
+
wsp_ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
|
|
7187
|
+
--bt_row_idx;
|
|
7188
|
+
|
|
7189
|
+
int32_t t = wsp_ggml_get_i32_nd(trace, i, j, 0, 0);
|
|
7190
|
+
if (t == 0) {
|
|
7191
|
+
--i;
|
|
7192
|
+
--j;
|
|
7193
|
+
} else if (t == 1) {
|
|
7194
|
+
--i;
|
|
7195
|
+
} else if (t == 2) {
|
|
7196
|
+
--j;
|
|
7197
|
+
} else {
|
|
7198
|
+
WHISPER_ASSERT(0);
|
|
7199
|
+
}
|
|
7200
|
+
}
|
|
7201
|
+
|
|
7202
|
+
// FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
|
|
7203
|
+
// Clip + transpose
|
|
7204
|
+
// This might not be entirely necessary for our case, but leaving it for now so output matrix
|
|
7205
|
+
// is identical to dtw on openAI timing.py
|
|
7206
|
+
const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
|
|
7207
|
+
wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
|
|
7208
|
+
for (int64_t i = 0; i < 2; ++i) {
|
|
7209
|
+
for (int64_t j = 0; j < result_n_cols; ++j) {
|
|
7210
|
+
int32_t v = wsp_ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
|
|
7211
|
+
wsp_ggml_set_i32_nd(r, i, j, 0, 0, v);
|
|
7212
|
+
}
|
|
7213
|
+
}
|
|
7214
|
+
|
|
7215
|
+
return r;
|
|
7216
|
+
}
|
|
7217
|
+
|
|
7218
|
+
// Parameters for the median_filter custom ggml op; a pointer to this struct
// is handed to wsp_ggml_map_custom1 as the opaque userdata argument.
struct median_filter_user_data {
    int filter_width; // filter window width; median_filter asserts this is odd and < ne[2]
};
|
|
7221
|
+
|
|
7222
|
+
static void median_filter(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
|
|
7223
|
+
if (ith != 0) {
|
|
7224
|
+
return;
|
|
7225
|
+
}
|
|
7226
|
+
int filter_width = ((median_filter_user_data *) userdata)->filter_width;
|
|
7227
|
+
WHISPER_ASSERT(filter_width < a->ne[2]);
|
|
7228
|
+
WHISPER_ASSERT(filter_width % 2);
|
|
7229
|
+
WHISPER_ASSERT(wsp_ggml_n_dims(a) == 3);
|
|
7230
|
+
WHISPER_ASSERT(a->type == WSP_GGML_TYPE_F32);
|
|
7231
|
+
|
|
7232
|
+
std::vector<float> filter;
|
|
7233
|
+
filter.reserve(filter_width);
|
|
7234
|
+
for (int64_t i = 0; i < a->ne[0]; ++i) {
|
|
7235
|
+
for (int64_t j = 0; j < a->ne[1]; ++j) {
|
|
7236
|
+
for (int64_t k = 0; k < a->ne[2]; ++k) {
|
|
7237
|
+
for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
|
|
7238
|
+
// "reflect" padding
|
|
7239
|
+
int64_t idx = k + off;
|
|
7240
|
+
if (idx < 0) {
|
|
7241
|
+
idx = -idx;
|
|
7242
|
+
} else if (idx >= a->ne[2]) {
|
|
7243
|
+
idx = 2*(a->ne[2] - 1) - idx;
|
|
7244
|
+
}
|
|
7245
|
+
|
|
7246
|
+
filter.push_back(wsp_ggml_get_f32_nd(a, i, j, idx, 0));
|
|
7247
|
+
}
|
|
7248
|
+
std::sort(filter.begin(), filter.end());
|
|
7249
|
+
const float v = filter[filter.size()/2];
|
|
7250
|
+
wsp_ggml_set_f32_nd(dst, i, j, k, 0, v);
|
|
7251
|
+
filter.clear();
|
|
7252
|
+
}
|
|
7253
|
+
}
|
|
7254
|
+
}
|
|
7255
|
+
}
|
|
7256
|
+
|
|
7257
|
+
// Experimental DTW-based token-level timestamps.
//
// Re-decodes the text of segments [i_segment, i_segment + n_segments) to
// collect cross-attention QKs from the alignment heads, processes them
// (normalize -> median filter -> mean over heads -> negate), runs DTW +
// backtrace on the resulting cost matrix, and writes the aligned timestamp
// into each text token's t_dtw field.
//
// Params:
//   ctx/state      - whisper context and state; state->result_all holds the segments
//   params         - full params (language is read for multilingual models)
//   i_segment      - index of the first segment to timestamp
//   n_segments     - number of consecutive segments to timestamp
//   seek           - offset (in 10 ms units) added to every produced timestamp
//   n_frames       - number of mel frames of audio actually decoded
//   medfilt_width  - median filter width; must be odd
//   n_threads      - threads for decode and graph compute
static void whisper_exp_compute_token_level_timestamps_dtw(
        struct whisper_context * ctx,
        struct whisper_state * state,
        struct whisper_full_params params,
        int i_segment,
        size_t n_segments,
        int seek,
        int n_frames,
        int medfilt_width,
        int n_threads)
{
    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
    WHISPER_ASSERT(medfilt_width % 2);
    // each audio token covers 2 mel frames, so n_frames may not exceed 2*n_audio_ctx
    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);

    // FIXME: Allocating mem everytime we call this func
    // Our ggml buffer should be pre-allocated somewhere during init and reused
    // when we call this function
    struct wsp_ggml_init_params gparams = {
        /*.mem_size =*/ ctx->params.dtw_mem_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc =*/ false,
    };
    struct wsp_ggml_context * gctx = wsp_ggml_init(gparams);

    // Build token sequence that will be passed to decoder:
    // sot + [lang] + no-timestamps + text result + eot
    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
    if (whisper_is_multilingual(ctx)) {
        const int lang_id = whisper_lang_id(params.language);
        state->lang_id = lang_id;
        tokens.push_back(whisper_token_lang(ctx, lang_id));
    }
    // length of the sot prefix (sot [+ lang]); used below to crop the
    // QK matrix so only text-token rows remain
    const size_t sot_sequence_length = tokens.size();
    tokens.push_back(whisper_token_not(ctx));
    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
        auto & segment = state->result_all[i];
        for (auto &t: segment.tokens) {
            // Only text tokens (special tokens have id >= eot)
            if (t.id < whisper_token_eot(ctx)) {
                tokens.push_back(t.id);
            }
        }
    }
    tokens.push_back(whisper_token_eot(ctx));

    // Get result tokens, pass them along to decoder to get cross attention QKs
    // used in timestamping.
    // Decoder already returns only alignment head QKs, already concatenated in
    // one tensor (state->aheads_cross_QKs); the save is enabled by the `true`
    // flag passed to whisper_decode_internal below.
    whisper_kv_cache_clear(state->kv_self);
    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
        WHISPER_LOG_INFO("DECODER FAILED\n");
        WHISPER_ASSERT(0);
    }
    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);

    const auto n_audio_tokens = n_frames/2; // one audio token per 2 mel frames (20 ms)
    // NOTE(review): duplicate of the nullptr assert right above — harmless, kept as-is
    WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
    const auto n_tokens = state->aheads_cross_QKs->ne[0];
    const auto n_heads = state->aheads_cross_QKs->ne[2];

    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
    // tokens (i.e. discarding rows at the end of tensor)
    // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
    WHISPER_ASSERT(state->aheads_cross_QKs->type == WSP_GGML_TYPE_F32);
    WHISPER_ASSERT(wsp_ggml_is_contiguous(state->aheads_cross_QKs));
    wsp_ggml_tensor * w = wsp_ggml_new_tensor_3d(gctx, WSP_GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
    auto & data = state->aheads_cross_QKs_data;
    data.resize(n_tokens * n_audio_ctx * n_heads);
    wsp_ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
    // row-by-row copy: source rows are strided by the full n_audio_ctx,
    // destination only keeps the first n_audio_tokens rows per head
    for (int k = 0; k < n_heads; ++k) {
        for (int j = 0; j < n_audio_tokens; ++j) {
            memcpy(
                (char *) w->data + j * w->nb[1] + k * w->nb[2],
                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
                n_tokens * sizeof(float)
            );
        }
    }

    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
    // we already permuted N_TOKENS dimension to columns on last loop, because wsp_ggml_norm
    // operates over columns. Afterwards, permute to a shape that facilitates mean
    // operation (after median filter)
    // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    w = wsp_ggml_norm(gctx, w, 1e-9f);
    w = wsp_ggml_permute(gctx, wsp_ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);

    // Pass median filter - this is done over AUDIO_TOKENS dimension.
    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    // OUT: Same dims
    median_filter_user_data mf_user_data = {medfilt_width};
    w = wsp_ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);

    // Take mean over columns, scale by -1, reshape to 2D tensor
    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
    w = wsp_ggml_mean(gctx, w);
    w = wsp_ggml_scale(gctx, w, -1.0);
    w = wsp_ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);

    // Remove SOT sequence and EOT
    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
    w = wsp_ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);

    // Compute the graph built above (all ops so far were only recorded)
    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
    wsp_ggml_build_forward_expand(gf, w);
    wsp_ggml_graph_compute_with_ctx(gctx, gf, n_threads);

    // alignment is a 2 x N tensor: row 0 = token index, row 1 = audio-time index
    wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);

    // Place timestamps on segments: walk the alignment path and, each time the
    // token index advances, stamp the next text token with the matched time.
    int32_t last_v = 0;
    auto seg_i = state->result_all.begin() + i_segment;
    auto tok_i = seg_i->tokens.begin();
    for (int i = 0; i < alignment->ne[1]; ++i) {
        int32_t v = wsp_ggml_get_i32_nd(alignment, 0, i, 0, 0);
        if (v != last_v) {
            int32_t time_index = wsp_ggml_get_i32_nd(alignment, 1, i, 0, 0);
            int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
            last_v = v;

            // Skip non-text tokens
            // NOTE(review): both while/if advance loops assume the alignment
            // never points past the last token of the last segment — if it did,
            // ++seg_i would run past result_all.end(); presumably guaranteed by
            // construction of the token sequence above — TODO confirm
            while (!(tok_i->id < whisper_token_eot(ctx))) {
                ++tok_i;
                if (tok_i == seg_i->tokens.end()) {
                    ++seg_i;
                    tok_i = seg_i->tokens.begin();
                }
            }

            tok_i->t_dtw = timestamp;
            ++tok_i;
            if (tok_i == seg_i->tokens.end()) {
                ++seg_i;
                tok_i = seg_i->tokens.begin();
            }
        }
    }

    // Print DTW timestamps (debug aid, intentionally left disabled)
    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
        auto & segment = state->result_all[i];
        for (auto &t: segment.tokens) {
            const char * tok = whisper_token_to_str(ctx, t.id);
            fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
        }
        fprintf(stderr, "\n");
    }*/

    wsp_ggml_free(gctx);
}
|
|
7417
|
+
|
|
6649
7418
|
void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
|
|
6650
7419
|
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
|
|
6651
7420
|
g_state.log_callback_user_data = user_data;
|