whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -1
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
- package/android/src/main/jni.cpp +29 -1
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
- package/cpp/coreml/whisper-encoder.mm +1 -1
- package/cpp/ggml-aarch64.c +3209 -0
- package/cpp/ggml-aarch64.h +39 -0
- package/cpp/ggml-alloc.c +732 -494
- package/cpp/ggml-alloc.h +47 -63
- package/cpp/ggml-backend-impl.h +162 -47
- package/cpp/ggml-backend.cpp +2635 -0
- package/cpp/ggml-backend.h +216 -71
- package/cpp/ggml-common.h +1853 -0
- package/cpp/ggml-cpu-impl.h +614 -0
- package/cpp/ggml-impl.h +144 -178
- package/cpp/ggml-metal.h +14 -60
- package/cpp/ggml-metal.m +3437 -2097
- package/cpp/ggml-quants.c +12559 -4189
- package/cpp/ggml-quants.h +135 -212
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +9029 -5219
- package/cpp/ggml.h +673 -338
- package/cpp/rn-whisper.cpp +91 -0
- package/cpp/rn-whisper.h +2 -0
- package/cpp/whisper.cpp +1476 -675
- package/cpp/whisper.h +84 -28
- package/ios/RNWhisper.mm +124 -37
- package/ios/RNWhisperAudioUtils.h +1 -0
- package/ios/RNWhisperAudioUtils.m +20 -13
- package/ios/RNWhisperContext.h +3 -2
- package/ios/RNWhisperContext.mm +41 -8
- package/jest/mock.js +9 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +48 -19
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +48 -19
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +6 -3
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +25 -3
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/NativeRNWhisper.ts +12 -3
- package/src/index.ts +63 -24
- package/src/version.json +1 -1
- package/whisper-rn.podspec +9 -2
- package/cpp/ggml-backend.c +0 -1357
- package/cpp/ggml-metal-whisper.metal +0 -4908
package/cpp/whisper.cpp
CHANGED
|
@@ -8,14 +8,30 @@
|
|
|
8
8
|
#include "ggml-metal.h"
|
|
9
9
|
#endif
|
|
10
10
|
|
|
11
|
-
#ifdef
|
|
11
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
12
12
|
#include "ggml-cuda.h"
|
|
13
13
|
#endif
|
|
14
14
|
|
|
15
|
+
#ifdef WSP_GGML_USE_SYCL
|
|
16
|
+
#include "ggml-sycl.h"
|
|
17
|
+
#endif
|
|
18
|
+
|
|
19
|
+
#ifdef WSP_GGML_USE_VULKAN
|
|
20
|
+
#include "ggml-vulkan.h"
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
24
|
+
#include "ggml-blas.h"
|
|
25
|
+
#endif
|
|
26
|
+
|
|
15
27
|
#ifdef WHISPER_USE_OPENVINO
|
|
16
28
|
#include "openvino/whisper-openvino-encoder.h"
|
|
17
29
|
#endif
|
|
18
30
|
|
|
31
|
+
#ifdef WSP_GGML_USE_CANN
|
|
32
|
+
#include "ggml-cann.h"
|
|
33
|
+
#endif
|
|
34
|
+
|
|
19
35
|
#include "ggml.h"
|
|
20
36
|
#include "ggml-alloc.h"
|
|
21
37
|
#include "ggml-backend.h"
|
|
@@ -37,6 +53,7 @@
|
|
|
37
53
|
#include <regex>
|
|
38
54
|
#include <random>
|
|
39
55
|
#include <functional>
|
|
56
|
+
#include <codecvt>
|
|
40
57
|
|
|
41
58
|
#if defined(_MSC_VER)
|
|
42
59
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
@@ -122,9 +139,18 @@ WHISPER_ATTRIBUTE_FORMAT(2, 3)
|
|
|
122
139
|
static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...);
|
|
123
140
|
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);
|
|
124
141
|
|
|
125
|
-
#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
|
126
|
-
#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
127
142
|
#define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
143
|
+
#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
144
|
+
#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
|
145
|
+
|
|
146
|
+
// define this to enable verbose trace logging - useful for debugging purposes
|
|
147
|
+
//#define WHISPER_DEBUG
|
|
148
|
+
|
|
149
|
+
#if defined(WHISPER_DEBUG)
|
|
150
|
+
#define WHISPER_LOG_DEBUG(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
|
151
|
+
#else
|
|
152
|
+
#define WHISPER_LOG_DEBUG(...)
|
|
153
|
+
#endif
|
|
128
154
|
|
|
129
155
|
#define WHISPER_ASSERT(x) \
|
|
130
156
|
do { \
|
|
@@ -134,20 +160,6 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
|
|
|
134
160
|
} \
|
|
135
161
|
} while (0)
|
|
136
162
|
|
|
137
|
-
// define this to enable verbose trace logging - useful for debugging purposes
|
|
138
|
-
//#define WHISPER_DEBUG
|
|
139
|
-
|
|
140
|
-
#if defined(WHISPER_DEBUG)
|
|
141
|
-
#define WHISPER_PRINT_DEBUG(...) \
|
|
142
|
-
do { \
|
|
143
|
-
fprintf(stderr, __VA_ARGS__); \
|
|
144
|
-
} while (0)
|
|
145
|
-
#else
|
|
146
|
-
#define WHISPER_PRINT_DEBUG(...)
|
|
147
|
-
#endif
|
|
148
|
-
|
|
149
|
-
//#define WHISPER_USE_FLASH_ATTN
|
|
150
|
-
//#define WHISPER_USE_FLASH_FF
|
|
151
163
|
#define WHISPER_MAX_DECODERS 8
|
|
152
164
|
#define WHISPER_MAX_NODES 4096
|
|
153
165
|
|
|
@@ -155,15 +167,15 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
|
|
|
155
167
|
// ggml helpers
|
|
156
168
|
//
|
|
157
169
|
|
|
158
|
-
static
|
|
170
|
+
static bool wsp_ggml_graph_compute_helper(
|
|
159
171
|
struct wsp_ggml_cgraph * graph,
|
|
160
172
|
std::vector<uint8_t> & buf,
|
|
161
173
|
int n_threads,
|
|
162
|
-
|
|
174
|
+
wsp_ggml_abort_callback abort_callback,
|
|
163
175
|
void * abort_callback_data) {
|
|
164
|
-
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
|
|
176
|
+
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads, nullptr);
|
|
165
177
|
|
|
166
|
-
plan.abort_callback
|
|
178
|
+
plan.abort_callback = abort_callback;
|
|
167
179
|
plan.abort_callback_data = abort_callback_data;
|
|
168
180
|
|
|
169
181
|
if (plan.work_size > 0) {
|
|
@@ -171,22 +183,29 @@ static void wsp_ggml_graph_compute_helper(
|
|
|
171
183
|
plan.work_data = buf.data();
|
|
172
184
|
}
|
|
173
185
|
|
|
174
|
-
wsp_ggml_graph_compute(graph, &plan);
|
|
186
|
+
return wsp_ggml_graph_compute(graph, &plan);
|
|
175
187
|
}
|
|
176
188
|
|
|
177
|
-
static
|
|
178
|
-
|
|
189
|
+
static bool wsp_ggml_graph_compute_helper(
|
|
190
|
+
wsp_ggml_backend_sched_t sched,
|
|
179
191
|
struct wsp_ggml_cgraph * graph,
|
|
180
192
|
int n_threads) {
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
193
|
+
|
|
194
|
+
for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
|
|
195
|
+
wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
|
|
196
|
+
if (wsp_ggml_backend_is_cpu(backend)) {
|
|
197
|
+
wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
|
|
198
|
+
}
|
|
199
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
200
|
+
if (wsp_ggml_backend_is_blas(backend)) {
|
|
201
|
+
wsp_ggml_backend_blas_set_n_threads(backend, n_threads);
|
|
202
|
+
}
|
|
188
203
|
#endif
|
|
189
|
-
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
|
|
207
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
208
|
+
return t;
|
|
190
209
|
}
|
|
191
210
|
|
|
192
211
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
@@ -350,6 +369,37 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
350
369
|
{ "yue", { 99, "cantonese", } },
|
|
351
370
|
};
|
|
352
371
|
|
|
372
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
373
|
+
static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
|
|
374
|
+
static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
|
|
375
|
+
static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
|
|
376
|
+
static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
|
|
377
|
+
static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
|
|
378
|
+
static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
|
|
379
|
+
static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
|
|
380
|
+
static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
|
|
381
|
+
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
|
|
382
|
+
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
|
|
383
|
+
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
|
|
384
|
+
static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
|
|
385
|
+
|
|
386
|
+
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
|
387
|
+
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
|
|
388
|
+
{ WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
|
|
389
|
+
{ WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
|
|
390
|
+
{ WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
|
|
391
|
+
{ WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
|
|
392
|
+
{ WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
|
|
393
|
+
{ WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
|
|
394
|
+
{ WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
|
|
395
|
+
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
|
|
396
|
+
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
|
|
397
|
+
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
|
|
398
|
+
{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
|
|
399
|
+
};
|
|
400
|
+
|
|
401
|
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
|
402
|
+
|
|
353
403
|
struct whisper_mel {
|
|
354
404
|
int n_len;
|
|
355
405
|
int n_len_org;
|
|
@@ -412,7 +462,7 @@ struct whisper_batch {
|
|
|
412
462
|
|
|
413
463
|
whisper_token * token;
|
|
414
464
|
whisper_pos * pos;
|
|
415
|
-
int32_t * n_seq_id;
|
|
465
|
+
int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
|
|
416
466
|
whisper_seq_id ** seq_id; // null terminated
|
|
417
467
|
int8_t * logits;
|
|
418
468
|
};
|
|
@@ -472,54 +522,42 @@ struct whisper_pair {
|
|
|
472
522
|
whisper_pair() : first(A()), second(B()) {}
|
|
473
523
|
};
|
|
474
524
|
|
|
475
|
-
//
|
|
476
|
-
struct
|
|
477
|
-
|
|
525
|
+
// wsp_ggml_backend_sched wrapper for whisper usage
|
|
526
|
+
struct whisper_sched {
|
|
527
|
+
wsp_ggml_backend_sched_t sched = nullptr;
|
|
478
528
|
|
|
479
529
|
std::vector<uint8_t> meta;
|
|
480
|
-
|
|
481
|
-
wsp_ggml_backend_buffer_t buffer;
|
|
482
530
|
};
|
|
483
531
|
|
|
484
|
-
static size_t
|
|
485
|
-
|
|
532
|
+
static size_t whisper_sched_size(struct whisper_sched & allocr) {
|
|
533
|
+
size_t size = allocr.meta.size();
|
|
534
|
+
for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
|
|
535
|
+
wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(allocr.sched, i);
|
|
536
|
+
size += wsp_ggml_backend_sched_get_buffer_size(allocr.sched, backend);
|
|
537
|
+
}
|
|
538
|
+
return size;
|
|
486
539
|
}
|
|
487
540
|
|
|
488
541
|
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
489
|
-
static
|
|
490
|
-
auto &
|
|
491
|
-
auto & meta
|
|
542
|
+
static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<wsp_ggml_backend_t> backends, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
543
|
+
auto & sched = allocr.sched;
|
|
544
|
+
auto & meta = allocr.meta;
|
|
492
545
|
|
|
493
|
-
|
|
546
|
+
sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
|
494
547
|
|
|
495
548
|
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
496
549
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
return;
|
|
550
|
+
// since there are dependencies between the different graphs,
|
|
551
|
+
// we need to allocate them instead of only reserving to get the correct compute buffer size
|
|
552
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, get_graph())) {
|
|
553
|
+
// failed to allocate the compute buffer
|
|
554
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
|
555
|
+
return false;
|
|
504
556
|
}
|
|
505
557
|
|
|
506
|
-
|
|
507
|
-
auto & buffer = allocr.buffer;
|
|
508
|
-
|
|
509
|
-
size_t size = wsp_ggml_allocr_max_size(alloc);
|
|
558
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
510
559
|
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
buffer = wsp_ggml_backend_alloc_buffer(backend, size);
|
|
514
|
-
alloc = wsp_ggml_allocr_new_from_buffer(buffer);
|
|
515
|
-
}
|
|
516
|
-
|
|
517
|
-
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
518
|
-
if (allocr.alloc) {
|
|
519
|
-
wsp_ggml_allocr_free(allocr.alloc);
|
|
520
|
-
wsp_ggml_backend_buffer_free(allocr.buffer);
|
|
521
|
-
allocr.alloc = nullptr;
|
|
522
|
-
}
|
|
560
|
+
return true;
|
|
523
561
|
}
|
|
524
562
|
|
|
525
563
|
// medium
|
|
@@ -661,9 +699,9 @@ struct whisper_kv_cache {
|
|
|
661
699
|
struct wsp_ggml_tensor * k;
|
|
662
700
|
struct wsp_ggml_tensor * v;
|
|
663
701
|
|
|
664
|
-
|
|
702
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
665
703
|
|
|
666
|
-
|
|
704
|
+
std::vector<uint8_t> ctx_buf;
|
|
667
705
|
};
|
|
668
706
|
|
|
669
707
|
struct whisper_model {
|
|
@@ -701,10 +739,10 @@ struct whisper_model {
|
|
|
701
739
|
std::vector<whisper_layer_decoder> layers_decoder;
|
|
702
740
|
|
|
703
741
|
// ggml context that contains all the meta information about the model tensors
|
|
704
|
-
struct wsp_ggml_context * ctx;
|
|
742
|
+
struct wsp_ggml_context * ctx = nullptr;
|
|
705
743
|
|
|
706
744
|
// the model backend data is read-only and can be shared between processors
|
|
707
|
-
|
|
745
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
708
746
|
|
|
709
747
|
// tensors
|
|
710
748
|
int n_loaded;
|
|
@@ -769,6 +807,13 @@ struct whisper_decoder {
|
|
|
769
807
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
770
808
|
};
|
|
771
809
|
|
|
810
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
811
|
+
struct whisper_aheads_masks {
|
|
812
|
+
std::vector<struct wsp_ggml_tensor *> m; // One mask per text layer.
|
|
813
|
+
struct wsp_ggml_context * ctx = nullptr;
|
|
814
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
815
|
+
};
|
|
816
|
+
|
|
772
817
|
struct whisper_state {
|
|
773
818
|
int64_t t_sample_us = 0;
|
|
774
819
|
int64_t t_encode_us = 0;
|
|
@@ -785,6 +830,9 @@ struct whisper_state {
|
|
|
785
830
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
786
831
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
787
832
|
|
|
833
|
+
// number of decoders for which we have constructed the KV cache
|
|
834
|
+
int32_t kv_self_n_dec = 0;
|
|
835
|
+
|
|
788
836
|
// unified self-attention KV cache for all decoders
|
|
789
837
|
whisper_kv_cache kv_self;
|
|
790
838
|
|
|
@@ -792,21 +840,22 @@ struct whisper_state {
|
|
|
792
840
|
// shared between all decoders
|
|
793
841
|
whisper_kv_cache kv_cross;
|
|
794
842
|
|
|
843
|
+
// padded buffer for flash-attention
|
|
844
|
+
whisper_kv_cache kv_pad;
|
|
845
|
+
|
|
795
846
|
whisper_mel mel;
|
|
796
847
|
|
|
797
848
|
whisper_batch batch;
|
|
798
849
|
|
|
799
850
|
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
|
800
851
|
|
|
801
|
-
wsp_ggml_backend_t
|
|
852
|
+
std::vector<wsp_ggml_backend_t> backends;
|
|
802
853
|
|
|
803
|
-
// ggml-alloc:
|
|
804
854
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
whisper_allocr alloc_decode;
|
|
855
|
+
whisper_sched sched_conv;
|
|
856
|
+
whisper_sched sched_encode;
|
|
857
|
+
whisper_sched sched_cross;
|
|
858
|
+
whisper_sched sched_decode;
|
|
810
859
|
|
|
811
860
|
// result of the encoder
|
|
812
861
|
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
@@ -842,6 +891,11 @@ struct whisper_state {
|
|
|
842
891
|
|
|
843
892
|
std::vector<float> energy; // PCM signal energy
|
|
844
893
|
|
|
894
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
895
|
+
whisper_aheads_masks aheads_masks;
|
|
896
|
+
wsp_ggml_tensor * aheads_cross_QKs = nullptr;
|
|
897
|
+
std::vector<float> aheads_cross_QKs_data;
|
|
898
|
+
|
|
845
899
|
// [EXPERIMENTAL] speed-up techniques
|
|
846
900
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
|
847
901
|
};
|
|
@@ -860,8 +914,6 @@ struct whisper_context {
|
|
|
860
914
|
|
|
861
915
|
whisper_state * state = nullptr;
|
|
862
916
|
|
|
863
|
-
wsp_ggml_backend_t backend = nullptr;
|
|
864
|
-
|
|
865
917
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
866
918
|
};
|
|
867
919
|
|
|
@@ -879,21 +931,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
879
931
|
BYTESWAP_VALUE(dest);
|
|
880
932
|
}
|
|
881
933
|
|
|
882
|
-
static bool
|
|
883
|
-
const struct whisper_hparams & hparams,
|
|
934
|
+
static bool whisper_kv_cache_init(
|
|
884
935
|
struct whisper_kv_cache & cache,
|
|
885
936
|
wsp_ggml_backend_t backend,
|
|
886
937
|
wsp_ggml_type wtype,
|
|
938
|
+
int64_t n_text_state,
|
|
939
|
+
int64_t n_text_layer,
|
|
887
940
|
int n_ctx) {
|
|
888
|
-
const int64_t n_text_state = hparams.n_text_state;
|
|
889
|
-
const int64_t n_text_layer = hparams.n_text_layer;
|
|
890
|
-
|
|
891
941
|
const int64_t n_mem = n_text_layer*n_ctx;
|
|
892
942
|
const int64_t n_elements = n_text_state*n_mem;
|
|
893
943
|
|
|
944
|
+
cache.ctx_buf.resize(2*wsp_ggml_tensor_overhead());
|
|
945
|
+
|
|
894
946
|
struct wsp_ggml_init_params params = {
|
|
895
|
-
/*.mem_size =*/
|
|
896
|
-
/*.mem_buffer =*/
|
|
947
|
+
/*.mem_size =*/ cache.ctx_buf.size(),
|
|
948
|
+
/*.mem_buffer =*/ cache.ctx_buf.data(),
|
|
897
949
|
/*.no_alloc =*/ true,
|
|
898
950
|
};
|
|
899
951
|
|
|
@@ -903,39 +955,31 @@ static bool kv_cache_init(
|
|
|
903
955
|
cache.cells.clear();
|
|
904
956
|
cache.cells.resize(n_ctx);
|
|
905
957
|
|
|
906
|
-
|
|
958
|
+
struct wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
907
959
|
|
|
908
|
-
if (!
|
|
909
|
-
WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
|
|
960
|
+
if (!ctx) {
|
|
961
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
|
|
910
962
|
return false;
|
|
911
963
|
}
|
|
912
964
|
|
|
913
|
-
cache.k = wsp_ggml_new_tensor_1d(
|
|
914
|
-
cache.v = wsp_ggml_new_tensor_1d(
|
|
915
|
-
|
|
916
|
-
const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
|
|
965
|
+
cache.k = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
|
|
966
|
+
cache.v = wsp_ggml_new_tensor_1d(ctx, wtype, n_elements);
|
|
917
967
|
|
|
918
|
-
cache.buffer =
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
968
|
+
cache.buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, backend);
|
|
969
|
+
if (!cache.buffer) {
|
|
970
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
|
|
971
|
+
return false;
|
|
972
|
+
}
|
|
923
973
|
|
|
924
|
-
|
|
925
|
-
wsp_ggml_allocr_alloc(alloc, cache.v);
|
|
974
|
+
wsp_ggml_backend_buffer_clear(cache.buffer, 0);
|
|
926
975
|
|
|
927
|
-
|
|
928
|
-
}
|
|
976
|
+
wsp_ggml_free(ctx);
|
|
929
977
|
|
|
930
978
|
return true;
|
|
931
979
|
}
|
|
932
980
|
|
|
933
|
-
static void
|
|
934
|
-
|
|
935
|
-
wsp_ggml_free(cache.ctx);
|
|
936
|
-
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
937
|
-
cache.ctx = nullptr;
|
|
938
|
-
}
|
|
981
|
+
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
|
|
982
|
+
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
939
983
|
}
|
|
940
984
|
|
|
941
985
|
static bool whisper_kv_cache_find_slot(
|
|
@@ -1006,6 +1050,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
|
1006
1050
|
cache.cells[i].seq_id.clear();
|
|
1007
1051
|
}
|
|
1008
1052
|
cache.head = 0;
|
|
1053
|
+
|
|
1054
|
+
wsp_ggml_backend_buffer_clear(cache.buffer, 0);
|
|
1009
1055
|
}
|
|
1010
1056
|
|
|
1011
1057
|
static void whisper_kv_cache_seq_rm(
|
|
@@ -1056,15 +1102,167 @@ static void whisper_kv_cache_seq_cp(
|
|
|
1056
1102
|
}
|
|
1057
1103
|
}
|
|
1058
1104
|
|
|
1059
|
-
static
|
|
1060
|
-
|
|
1105
|
+
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
|
|
1106
|
+
if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
|
|
1107
|
+
return 1u;
|
|
1108
|
+
}
|
|
1109
|
+
|
|
1110
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1111
|
+
if (wctx.params.use_gpu) {
|
|
1112
|
+
return 32u;
|
|
1113
|
+
}
|
|
1114
|
+
#endif
|
|
1115
|
+
|
|
1116
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
1117
|
+
if (wctx.params.use_gpu) {
|
|
1118
|
+
return 256u;
|
|
1119
|
+
}
|
|
1120
|
+
#endif
|
|
1121
|
+
|
|
1122
|
+
return 1u;
|
|
1123
|
+
}
|
|
1124
|
+
|
|
1125
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
1126
|
+
static bool aheads_masks_init(
|
|
1127
|
+
const whisper_context_params & cparams,
|
|
1128
|
+
const whisper_hparams & hparams,
|
|
1129
|
+
struct whisper_aheads_masks & aheads_masks,
|
|
1130
|
+
wsp_ggml_backend_t backend) {
|
|
1131
|
+
|
|
1132
|
+
const int32_t n_text_layer = hparams.n_text_layer;
|
|
1133
|
+
const int32_t n_head = hparams.n_text_head;
|
|
1134
|
+
|
|
1135
|
+
// Sanity checks
|
|
1136
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
|
1137
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
|
|
1138
|
+
return false;
|
|
1139
|
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
|
1140
|
+
if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
|
|
1141
|
+
WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
|
|
1142
|
+
return false;
|
|
1143
|
+
}
|
|
1144
|
+
} else {
|
|
1145
|
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
|
1146
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
|
|
1147
|
+
if (aheads.n_heads == 0) {
|
|
1148
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
|
|
1149
|
+
return false;
|
|
1150
|
+
}
|
|
1151
|
+
if (aheads.heads == NULL) {
|
|
1152
|
+
WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
|
|
1153
|
+
return false;
|
|
1154
|
+
}
|
|
1155
|
+
}
|
|
1156
|
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
|
1157
|
+
if (aheads.heads[i].n_text_layer >= n_text_layer) {
|
|
1158
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
|
|
1159
|
+
return false;
|
|
1160
|
+
}
|
|
1161
|
+
if (aheads.heads[i].n_text_layer < 0) {
|
|
1162
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
|
|
1163
|
+
return false;
|
|
1164
|
+
}
|
|
1165
|
+
if (aheads.heads[i].n_head >= n_head) {
|
|
1166
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
|
|
1167
|
+
return false;
|
|
1168
|
+
}
|
|
1169
|
+
if (aheads.heads[i].n_head < 0) {
|
|
1170
|
+
WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
|
|
1171
|
+
return false;
|
|
1172
|
+
}
|
|
1173
|
+
}
|
|
1174
|
+
}
|
|
1175
|
+
|
|
1176
|
+
struct wsp_ggml_init_params params = {
|
|
1177
|
+
/*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*wsp_ggml_tensor_overhead(),
|
|
1178
|
+
/*.mem_buffer =*/ nullptr,
|
|
1179
|
+
/*.no_alloc =*/ true,
|
|
1180
|
+
};
|
|
1181
|
+
|
|
1182
|
+
aheads_masks.ctx = wsp_ggml_init(params);
|
|
1183
|
+
|
|
1184
|
+
if (!aheads_masks.ctx) {
|
|
1185
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
|
|
1186
|
+
return false;
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
|
1190
|
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
|
1191
|
+
if (!aheads.empty()) {
|
|
1192
|
+
aheads_masks.m.push_back(wsp_ggml_new_tensor_2d(aheads_masks.ctx, WSP_GGML_TYPE_F32, n_head, aheads.size()));
|
|
1193
|
+
} else {
|
|
1194
|
+
aheads_masks.m.push_back(nullptr);
|
|
1195
|
+
}
|
|
1196
|
+
}
|
|
1197
|
+
|
|
1198
|
+
aheads_masks.buffer = wsp_ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
|
|
1199
|
+
if (!aheads_masks.buffer) {
|
|
1200
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
|
|
1201
|
+
return false;
|
|
1202
|
+
}
|
|
1203
|
+
|
|
1204
|
+
// Set data on mask tensors
|
|
1205
|
+
// Since this must be backend agnostic, we write our desired values on mask_data,
|
|
1206
|
+
// and send it to backend with wsp_ggml_backend_tensor_set.
|
|
1207
|
+
// Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment
|
|
1208
|
+
// heads. Each row of the mask "marks" one alignment head. E.g. if some text layer
|
|
1209
|
+
// has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask
|
|
1210
|
+
// should read:
|
|
1211
|
+
// 1 0 0 0 0 0 0 0 0 0
|
|
1212
|
+
// 0 0 0 0 0 1 0 0 0 0
|
|
1213
|
+
// 0 0 0 0 0 0 1 0 0 0
|
|
1214
|
+
std::vector<float> mask_data;
|
|
1215
|
+
for (int64_t il = 0; il < n_text_layer; ++il) {
|
|
1216
|
+
if (aheads_masks.m[il] != nullptr) {
|
|
1217
|
+
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
|
|
1218
|
+
|
|
1219
|
+
size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1];
|
|
1220
|
+
size_t data_size_bytes = data_size * sizeof(float);
|
|
1221
|
+
mask_data.resize(data_size);
|
|
1222
|
+
|
|
1223
|
+
std::fill(mask_data.begin(), mask_data.end(), 0);
|
|
1224
|
+
for (size_t ih = 0; ih < aheads.size(); ++ih) {
|
|
1225
|
+
size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
|
|
1226
|
+
mask_data[pos] = 1.0f;
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
wsp_ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes);
|
|
1230
|
+
}
|
|
1231
|
+
}
|
|
1232
|
+
|
|
1233
|
+
if (aheads_masks.m.empty()) {
|
|
1234
|
+
WHISPER_LOG_ERROR("%s: \n", __func__);
|
|
1235
|
+
return false;
|
|
1236
|
+
}
|
|
1237
|
+
|
|
1238
|
+
return true;
|
|
1239
|
+
}
|
|
1240
|
+
|
|
1241
|
+
// Release the ggml context and the backend buffer held by `aheads_masks` and reset the struct.
//
// All members are cleared afterwards: the tensors referenced from `m` live in `ctx`/`buffer`,
// so keeping those pointers (or the stale `buffer` handle) around would leave them dangling and
// would make a second call free the same buffer twice.
// NOTE(review): relies on wsp_ggml_free / wsp_ggml_backend_buffer_free accepting nullptr, as the
// default-initialized members are already passed to them unchanged - confirm against ggml.
static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
    wsp_ggml_free(aheads_masks.ctx);
    wsp_ggml_backend_buffer_free(aheads_masks.buffer);

    aheads_masks.ctx    = nullptr;
    aheads_masks.buffer = nullptr;
    aheads_masks.m.clear();
}
|
|
1061
1246
|
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1247
|
+
static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
1248
|
+
size_t size = 0;
|
|
1249
|
+
for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
|
|
1250
|
+
if (aheads_masks.m[i] != nullptr)
|
|
1251
|
+
size += wsp_ggml_nbytes(aheads_masks.m[i]);
|
|
1252
|
+
}
|
|
1253
|
+
return size;
|
|
1254
|
+
}
|
|
1255
|
+
|
|
1256
|
+
static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
|
1257
|
+
wsp_ggml_backend_t result = NULL;
|
|
1258
|
+
|
|
1259
|
+
wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
|
1260
|
+
|
|
1261
|
+
#ifdef WSP_GGML_USE_CUDA
|
|
1262
|
+
if (params.use_gpu) {
|
|
1065
1263
|
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
|
1066
|
-
|
|
1067
|
-
if (!
|
|
1264
|
+
result = wsp_ggml_backend_cuda_init(params.gpu_device);
|
|
1265
|
+
if (!result) {
|
|
1068
1266
|
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
|
|
1069
1267
|
}
|
|
1070
1268
|
}
|
|
@@ -1073,22 +1271,108 @@ static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & pa
|
|
|
1073
1271
|
#ifdef WSP_GGML_USE_METAL
|
|
1074
1272
|
if (params.use_gpu) {
|
|
1075
1273
|
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
if (!backend_gpu) {
|
|
1274
|
+
result = wsp_ggml_backend_metal_init();
|
|
1275
|
+
if (!result) {
|
|
1079
1276
|
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
|
|
1080
|
-
} else if (!wsp_ggml_backend_metal_supports_family(
|
|
1277
|
+
} else if (!wsp_ggml_backend_metal_supports_family(result, 7)) {
|
|
1081
1278
|
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
|
|
1082
|
-
wsp_ggml_backend_free(
|
|
1083
|
-
|
|
1279
|
+
wsp_ggml_backend_free(result);
|
|
1280
|
+
result = NULL;
|
|
1281
|
+
}
|
|
1282
|
+
}
|
|
1283
|
+
#endif
|
|
1284
|
+
|
|
1285
|
+
#ifdef WSP_GGML_USE_SYCL
|
|
1286
|
+
if (params.use_gpu) {
|
|
1287
|
+
WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
|
|
1288
|
+
result = wsp_ggml_backend_sycl_init(params.gpu_device);
|
|
1289
|
+
if (!result) {
|
|
1290
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
|
|
1084
1291
|
}
|
|
1085
1292
|
}
|
|
1086
1293
|
#endif
|
|
1087
1294
|
|
|
1295
|
+
#ifdef WSP_GGML_USE_VULKAN
|
|
1296
|
+
if (params.use_gpu) {
|
|
1297
|
+
WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
|
|
1298
|
+
result = wsp_ggml_backend_vk_init(params.gpu_device);
|
|
1299
|
+
if (!result) {
|
|
1300
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
|
|
1301
|
+
}
|
|
1302
|
+
}
|
|
1303
|
+
#endif
|
|
1304
|
+
|
|
1305
|
+
#ifdef WSP_GGML_USE_CANN
|
|
1306
|
+
if (params.use_gpu) {
|
|
1307
|
+
WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
|
|
1308
|
+
result = wsp_ggml_backend_cann_init(params.gpu_device);
|
|
1309
|
+
if (!result) {
|
|
1310
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
|
|
1311
|
+
}
|
|
1312
|
+
}
|
|
1313
|
+
#endif
|
|
1314
|
+
|
|
1315
|
+
WSP_GGML_UNUSED(params);
|
|
1316
|
+
|
|
1317
|
+
return result;
|
|
1318
|
+
}
|
|
1319
|
+
|
|
1320
|
+
static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
|
1321
|
+
std::vector<wsp_ggml_backend_t> result;
|
|
1322
|
+
|
|
1323
|
+
wsp_ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
|
|
1324
|
+
|
|
1088
1325
|
if (backend_gpu) {
|
|
1089
|
-
|
|
1326
|
+
result.push_back(backend_gpu);
|
|
1090
1327
|
}
|
|
1091
|
-
|
|
1328
|
+
|
|
1329
|
+
#ifdef WSP_GGML_USE_BLAS
|
|
1330
|
+
{
|
|
1331
|
+
WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
|
|
1332
|
+
wsp_ggml_backend_t backend_blas = wsp_ggml_backend_blas_init();
|
|
1333
|
+
if (!backend_blas) {
|
|
1334
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_blas_init() failed\n", __func__);
|
|
1335
|
+
} else {
|
|
1336
|
+
result.push_back(backend_blas);
|
|
1337
|
+
}
|
|
1338
|
+
}
|
|
1339
|
+
#endif
|
|
1340
|
+
|
|
1341
|
+
WSP_GGML_UNUSED(params);
|
|
1342
|
+
|
|
1343
|
+
result.push_back(wsp_ggml_backend_cpu_init());
|
|
1344
|
+
|
|
1345
|
+
return result;
|
|
1346
|
+
}
|
|
1347
|
+
|
|
1348
|
+
// Choose the backend buffer type for the given context parameters.
//
// When params.use_gpu is false the CPU buffer type is chosen immediately.
// Otherwise the first compiled-in GPU backend that yields a non-null buffer
// type wins, checked in this order: CUDA, Metal, SYCL, Vulkan, CANN.
// If none of them produces a buffer type, the CPU buffer type is used.
//
// Fix: the CANN branch previously used `==` instead of `=`, so its buffer
// type was compared and discarded and CANN builds fell through to CPU.
static wsp_ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
    wsp_ggml_backend_buffer_type_t result = nullptr;

    if (!params.use_gpu) {
        result = wsp_ggml_backend_cpu_buffer_type();
    }

#ifdef WSP_GGML_USE_CUDA
    if (!result) {
        result = wsp_ggml_backend_cuda_buffer_type(params.gpu_device);
    }
#endif

#ifdef WSP_GGML_USE_METAL
    if (!result) {
        result = wsp_ggml_backend_metal_buffer_type();
    }
#endif

#ifdef WSP_GGML_USE_SYCL
    if (!result) {
        result = wsp_ggml_backend_sycl_buffer_type(params.gpu_device);
    }
#endif

#ifdef WSP_GGML_USE_VULKAN
    if (!result) {
        result = wsp_ggml_backend_vk_buffer_type(params.gpu_device);
    }
#endif

#ifdef WSP_GGML_USE_CANN
    if (!result) {
        // assignment, not comparison
        result = wsp_ggml_backend_cann_buffer_type(params.gpu_device);
    }
#endif

    // final fallback: host memory
    if (!result) {
        result = wsp_ggml_backend_cpu_buffer_type();
    }

    return result;
}
|
|
1093
1377
|
|
|
1094
1378
|
// load the model from a ggml file
|
|
@@ -1515,29 +1799,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1515
1799
|
}
|
|
1516
1800
|
}
|
|
1517
1801
|
|
|
1518
|
-
wctx.backend = whisper_backend_init(wctx.params);
|
|
1519
|
-
|
|
1520
|
-
{
|
|
1521
|
-
size_t size_main = 0;
|
|
1522
|
-
|
|
1523
|
-
for (const auto & t : model.tensors) {
|
|
1524
|
-
size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
|
|
1525
|
-
}
|
|
1526
|
-
|
|
1527
|
-
model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
|
|
1528
|
-
|
|
1529
|
-
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
|
|
1530
|
-
}
|
|
1531
|
-
|
|
1532
|
-
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
|
|
1533
|
-
|
|
1534
1802
|
// allocate tensors in the backend buffers
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1803
|
+
model.buffer = wsp_ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
|
|
1804
|
+
if (!model.buffer) {
|
|
1805
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
|
|
1806
|
+
return false;
|
|
1539
1807
|
}
|
|
1540
1808
|
|
|
1809
|
+
size_t size_main = wsp_ggml_backend_buffer_get_size(model.buffer);
|
|
1810
|
+
WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(model.buffer), size_main / 1e6);
|
|
1811
|
+
|
|
1541
1812
|
// load weights
|
|
1542
1813
|
{
|
|
1543
1814
|
size_t total_size = 0;
|
|
@@ -1599,15 +1870,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1599
1870
|
return false;
|
|
1600
1871
|
}
|
|
1601
1872
|
|
|
1602
|
-
wsp_ggml_backend_t backend = wctx.backend;
|
|
1873
|
+
//wsp_ggml_backend_t backend = wctx.backend;
|
|
1603
1874
|
|
|
1604
1875
|
//printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
|
|
1605
1876
|
|
|
1606
|
-
if ((
|
|
1607
|
-
#ifdef WSP_GGML_USE_METAL
|
|
1608
|
-
|| wsp_ggml_backend_is_metal(backend)
|
|
1609
|
-
#endif
|
|
1610
|
-
)) {
|
|
1877
|
+
if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
|
|
1611
1878
|
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1612
1879
|
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
1613
1880
|
BYTESWAP_TENSOR(tensor);
|
|
@@ -1635,7 +1902,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1635
1902
|
}
|
|
1636
1903
|
}
|
|
1637
1904
|
|
|
1638
|
-
|
|
1905
|
+
wsp_ggml_backend_buffer_set_usage(model.buffer, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
1639
1906
|
|
|
1640
1907
|
wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
|
|
1641
1908
|
|
|
@@ -1662,10 +1929,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
|
|
|
1662
1929
|
|
|
1663
1930
|
static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
1664
1931
|
whisper_context & wctx,
|
|
1665
|
-
whisper_state & wstate
|
|
1666
|
-
const int mel_offset) {
|
|
1932
|
+
whisper_state & wstate) {
|
|
1667
1933
|
const auto & model = wctx.model;
|
|
1668
|
-
const auto & mel_inp = wstate.mel;
|
|
1669
1934
|
const auto & hparams = model.hparams;
|
|
1670
1935
|
|
|
1671
1936
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
@@ -1674,8 +1939,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1674
1939
|
const int n_mels = hparams.n_mels;
|
|
1675
1940
|
|
|
1676
1941
|
struct wsp_ggml_init_params params = {
|
|
1677
|
-
/*.mem_size =*/ wstate.
|
|
1678
|
-
/*.mem_buffer =*/ wstate.
|
|
1942
|
+
/*.mem_size =*/ wstate.sched_conv.meta.size(),
|
|
1943
|
+
/*.mem_buffer =*/ wstate.sched_conv.meta.data(),
|
|
1679
1944
|
/*.no_alloc =*/ true,
|
|
1680
1945
|
};
|
|
1681
1946
|
|
|
@@ -1683,31 +1948,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1683
1948
|
|
|
1684
1949
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1685
1950
|
|
|
1686
|
-
wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
|
1687
|
-
|
|
1688
1951
|
struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
|
|
1689
|
-
|
|
1690
|
-
|
|
1691
|
-
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
1692
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1693
|
-
assert(mel_inp.n_mel == n_mels);
|
|
1694
|
-
|
|
1695
|
-
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
1696
|
-
|
|
1697
|
-
float * dst = wstate.inp_mel.data();
|
|
1698
|
-
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1699
|
-
|
|
1700
|
-
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
1701
|
-
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
1702
|
-
|
|
1703
|
-
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
1704
|
-
for (int i = i0; i < i1; ++i) {
|
|
1705
|
-
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
1706
|
-
}
|
|
1707
|
-
}
|
|
1708
|
-
|
|
1709
|
-
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
1710
|
-
}
|
|
1952
|
+
wsp_ggml_set_name(mel, "mel");
|
|
1953
|
+
wsp_ggml_set_input(mel);
|
|
1711
1954
|
|
|
1712
1955
|
struct wsp_ggml_tensor * cur = nullptr;
|
|
1713
1956
|
|
|
@@ -1728,27 +1971,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1728
1971
|
wsp_ggml_set_name(cur, "embd_conv");
|
|
1729
1972
|
wstate.embd_conv = cur;
|
|
1730
1973
|
} else {
|
|
1731
|
-
|
|
1732
|
-
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1733
|
-
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1974
|
+
wsp_ggml_build_forward_expand(gf, mel);
|
|
1734
1975
|
|
|
1735
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1736
|
-
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
|
1737
|
-
}
|
|
1738
|
-
#endif
|
|
1739
|
-
#ifdef WHISPER_USE_OPENVINO
|
|
1740
1976
|
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1744
|
-
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
|
1745
|
-
}
|
|
1746
|
-
#endif
|
|
1977
|
+
wsp_ggml_set_input(cur); // the external encoder will write into this tensor
|
|
1747
1978
|
|
|
1748
1979
|
wsp_ggml_set_name(cur, "embd_enc");
|
|
1749
1980
|
wstate.embd_enc = cur;
|
|
1750
1981
|
}
|
|
1751
1982
|
|
|
1983
|
+
wsp_ggml_set_output(cur);
|
|
1984
|
+
|
|
1752
1985
|
wsp_ggml_build_forward_expand(gf, cur);
|
|
1753
1986
|
|
|
1754
1987
|
wsp_ggml_free(ctx0);
|
|
@@ -1767,9 +2000,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1767
2000
|
const int n_head = hparams.n_audio_head;
|
|
1768
2001
|
const int n_layer = hparams.n_audio_layer;
|
|
1769
2002
|
|
|
2003
|
+
const int n_state_head = n_state/n_head;
|
|
2004
|
+
|
|
2005
|
+
auto & kv_pad = wstate.kv_pad;
|
|
2006
|
+
|
|
2007
|
+
WHISPER_ASSERT(!!kv_pad.buffer);
|
|
2008
|
+
|
|
2009
|
+
const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
|
|
2010
|
+
|
|
1770
2011
|
struct wsp_ggml_init_params params = {
|
|
1771
|
-
/*.mem_size =*/ wstate.
|
|
1772
|
-
/*.mem_buffer =*/ wstate.
|
|
2012
|
+
/*.mem_size =*/ wstate.sched_encode.meta.size(),
|
|
2013
|
+
/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
|
|
1773
2014
|
/*.no_alloc =*/ true,
|
|
1774
2015
|
};
|
|
1775
2016
|
|
|
@@ -1777,23 +2018,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1777
2018
|
|
|
1778
2019
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
1779
2020
|
|
|
1780
|
-
wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1781
|
-
|
|
1782
|
-
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
|
|
1783
|
-
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
1784
|
-
|
|
1785
|
-
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1786
|
-
// wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
|
|
1787
|
-
//}
|
|
1788
2021
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1789
2022
|
|
|
1790
|
-
|
|
1791
|
-
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
1792
|
-
|
|
1793
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1794
|
-
const float val = 1.0f/sqrtf(float(n_state)/n_head);
|
|
1795
|
-
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
1796
|
-
}
|
|
2023
|
+
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
|
1797
2024
|
|
|
1798
2025
|
// ===================================================================
|
|
1799
2026
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
@@ -1843,14 +2070,14 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1843
2070
|
|
|
1844
2071
|
Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
|
|
1845
2072
|
|
|
1846
|
-
//Qcur = wsp_ggml_scale(ctx0, Qcur,
|
|
2073
|
+
//Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
|
1847
2074
|
|
|
1848
2075
|
// note: no bias for Key
|
|
1849
2076
|
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1850
2077
|
layer.attn_k_w,
|
|
1851
2078
|
cur);
|
|
1852
2079
|
|
|
1853
|
-
//Kcur = wsp_ggml_scale(ctx0, Kcur,
|
|
2080
|
+
//Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
|
1854
2081
|
|
|
1855
2082
|
struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
|
|
1856
2083
|
layer.attn_v_w,
|
|
@@ -1860,70 +2087,60 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1860
2087
|
|
|
1861
2088
|
// ------
|
|
1862
2089
|
|
|
1863
|
-
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1864
2090
|
struct wsp_ggml_tensor * Q =
|
|
1865
2091
|
wsp_ggml_permute(ctx0,
|
|
1866
|
-
|
|
1867
|
-
Qcur,
|
|
1868
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
2092
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
|
|
1869
2093
|
0, 2, 1, 3);
|
|
1870
2094
|
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
Kcur,
|
|
1875
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1876
|
-
0, 2, 1, 3);
|
|
2095
|
+
if (wctx.params.flash_attn) {
|
|
2096
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, wsp_ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
|
|
2097
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, wsp_ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
|
|
1877
2098
|
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
1, 2, 0, 3),
|
|
1885
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
2099
|
+
struct wsp_ggml_tensor * K =
|
|
2100
|
+
wsp_ggml_view_3d(ctx0, kv_pad.k,
|
|
2101
|
+
n_state_head, n_ctx_pad, n_head,
|
|
2102
|
+
wsp_ggml_element_size(kv_pad.k)*n_state,
|
|
2103
|
+
wsp_ggml_element_size(kv_pad.k)*n_state_head,
|
|
2104
|
+
0);
|
|
1886
2105
|
|
|
1887
|
-
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
1892
|
-
|
|
1893
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1894
|
-
0, 2, 1, 3);
|
|
2106
|
+
struct wsp_ggml_tensor * V =
|
|
2107
|
+
wsp_ggml_view_3d(ctx0, kv_pad.v,
|
|
2108
|
+
n_state_head, n_ctx_pad, n_head,
|
|
2109
|
+
wsp_ggml_element_size(kv_pad.v)*n_state,
|
|
2110
|
+
wsp_ggml_element_size(kv_pad.v)*n_state_head,
|
|
2111
|
+
0);
|
|
1895
2112
|
|
|
1896
|
-
|
|
1897
|
-
wsp_ggml_permute(ctx0,
|
|
1898
|
-
wsp_ggml_cpy(ctx0,
|
|
1899
|
-
Kcur,
|
|
1900
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1901
|
-
0, 2, 1, 3);
|
|
2113
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
|
|
1902
2114
|
|
|
1903
|
-
|
|
1904
|
-
|
|
2115
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
|
2116
|
+
} else {
|
|
2117
|
+
struct wsp_ggml_tensor * K =
|
|
2118
|
+
wsp_ggml_permute(ctx0,
|
|
2119
|
+
wsp_ggml_cast(ctx0,
|
|
2120
|
+
wsp_ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
|
2121
|
+
wctx.itype),
|
|
2122
|
+
0, 2, 1, 3);
|
|
1905
2123
|
|
|
1906
|
-
|
|
2124
|
+
// K * Q
|
|
2125
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
1907
2126
|
|
|
1908
|
-
|
|
2127
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
|
1909
2128
|
|
|
1910
|
-
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
);
|
|
2129
|
+
struct wsp_ggml_tensor * V =
|
|
2130
|
+
wsp_ggml_cast(ctx0,
|
|
2131
|
+
wsp_ggml_permute(ctx0,
|
|
2132
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
2133
|
+
Vcur,
|
|
2134
|
+
n_state_head, n_head, n_ctx),
|
|
2135
|
+
1, 2, 0, 3),
|
|
2136
|
+
wctx.itype);
|
|
1919
2137
|
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
2138
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2139
|
+
|
|
2140
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1923
2141
|
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
|
|
2142
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
|
|
2143
|
+
}
|
|
1927
2144
|
}
|
|
1928
2145
|
|
|
1929
2146
|
// projection
|
|
@@ -1952,11 +2169,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1952
2169
|
layer.mlp_ln_b);
|
|
1953
2170
|
}
|
|
1954
2171
|
|
|
1955
|
-
#ifdef WHISPER_USE_FLASH_FF
|
|
1956
|
-
cur = wsp_ggml_flash_ff(ctx0,
|
|
1957
|
-
wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
1958
|
-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1959
|
-
#else
|
|
1960
2172
|
// fully connected
|
|
1961
2173
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
1962
2174
|
layer.mlp_0_w,
|
|
@@ -1973,7 +2185,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1973
2185
|
cur);
|
|
1974
2186
|
|
|
1975
2187
|
cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
|
|
1976
|
-
#endif
|
|
1977
2188
|
}
|
|
1978
2189
|
|
|
1979
2190
|
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
@@ -2022,9 +2233,13 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2022
2233
|
const int n_state = hparams.n_audio_state;
|
|
2023
2234
|
const int n_head = hparams.n_audio_head;
|
|
2024
2235
|
|
|
2236
|
+
const int n_state_head = n_state/n_head;
|
|
2237
|
+
|
|
2238
|
+
const int n_ctx_pad = WSP_GGML_PAD(n_ctx, 256);
|
|
2239
|
+
|
|
2025
2240
|
struct wsp_ggml_init_params params = {
|
|
2026
|
-
/*.mem_size =*/ wstate.
|
|
2027
|
-
/*.mem_buffer =*/ wstate.
|
|
2241
|
+
/*.mem_size =*/ wstate.sched_cross.meta.size(),
|
|
2242
|
+
/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
|
|
2028
2243
|
/*.no_alloc =*/ true,
|
|
2029
2244
|
};
|
|
2030
2245
|
|
|
@@ -2032,34 +2247,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2032
2247
|
|
|
2033
2248
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
2034
2249
|
|
|
2035
|
-
wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
2036
|
-
|
|
2037
|
-
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
2038
|
-
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
2039
|
-
|
|
2040
|
-
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2041
|
-
// wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
|
|
2042
|
-
//}
|
|
2043
2250
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
2044
2251
|
|
|
2045
|
-
|
|
2046
|
-
wsp_ggml_allocr_alloc(alloc, Kscale);
|
|
2047
|
-
|
|
2048
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2049
|
-
const float val = pow(float(n_state) / n_head, -0.25);
|
|
2050
|
-
wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
|
|
2051
|
-
}
|
|
2252
|
+
const float Kscale = pow(float(n_state_head), -0.25);
|
|
2052
2253
|
|
|
2053
2254
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
2054
2255
|
auto & layer = model.layers_decoder[il];
|
|
2055
2256
|
|
|
2056
|
-
struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
|
|
2257
|
+
struct wsp_ggml_tensor * Kcross = wsp_ggml_mul_mat(ctx0,
|
|
2057
2258
|
layer.cross_attn_k_w,
|
|
2058
2259
|
cur);
|
|
2059
2260
|
|
|
2060
2261
|
Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
|
|
2061
2262
|
|
|
2062
|
-
struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
|
|
2263
|
+
struct wsp_ggml_tensor * Vcross = wsp_ggml_mul_mat(ctx0,
|
|
2063
2264
|
layer.cross_attn_v_w,
|
|
2064
2265
|
cur);
|
|
2065
2266
|
|
|
@@ -2067,15 +2268,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
2067
2268
|
Vcross,
|
|
2068
2269
|
layer.cross_attn_v_b);
|
|
2069
2270
|
|
|
2070
|
-
|
|
2271
|
+
struct wsp_ggml_tensor * k;
|
|
2272
|
+
struct wsp_ggml_tensor * v;
|
|
2273
|
+
|
|
2274
|
+
if (wctx.params.flash_attn) {
|
|
2275
|
+
k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
|
2276
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
|
|
2071
2277
|
|
|
2072
|
-
|
|
2073
|
-
|
|
2074
|
-
|
|
2278
|
+
v = wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
|
|
2279
|
+
(wsp_ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
|
|
2280
|
+
} else {
|
|
2281
|
+
Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
2282
|
+
|
|
2283
|
+
k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
|
2284
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
2075
2285
|
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2286
|
+
v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
2287
|
+
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2288
|
+
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
2289
|
+
}
|
|
2079
2290
|
|
|
2080
2291
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
|
|
2081
2292
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
|
|
@@ -2103,49 +2314,91 @@ static bool whisper_encode_internal(
|
|
|
2103
2314
|
whisper_state & wstate,
|
|
2104
2315
|
const int mel_offset,
|
|
2105
2316
|
const int n_threads,
|
|
2106
|
-
|
|
2317
|
+
wsp_ggml_abort_callback abort_callback,
|
|
2107
2318
|
void * abort_callback_data) {
|
|
2108
2319
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2109
2320
|
|
|
2110
2321
|
// conv
|
|
2111
2322
|
{
|
|
2112
|
-
auto &
|
|
2323
|
+
auto & sched = wstate.sched_conv.sched;
|
|
2324
|
+
|
|
2325
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
|
|
2326
|
+
|
|
2327
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2328
|
+
// should never happen as we pre-allocate the memory
|
|
2329
|
+
return false;
|
|
2330
|
+
}
|
|
2331
|
+
|
|
2332
|
+
struct wsp_ggml_tensor * mel = wsp_ggml_graph_get_tensor(gf, "mel");
|
|
2333
|
+
|
|
2334
|
+
// set the input
|
|
2335
|
+
{
|
|
2336
|
+
const auto & mel_inp = wstate.mel;
|
|
2337
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
|
|
2338
|
+
|
|
2339
|
+
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
2340
|
+
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
|
|
2341
|
+
|
|
2342
|
+
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
2113
2343
|
|
|
2114
|
-
|
|
2344
|
+
float * dst = wstate.inp_mel.data();
|
|
2345
|
+
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
2115
2346
|
|
|
2116
|
-
|
|
2347
|
+
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
2348
|
+
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
2349
|
+
|
|
2350
|
+
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
2351
|
+
for (int i = i0; i < i1; ++i) {
|
|
2352
|
+
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
2353
|
+
}
|
|
2354
|
+
}
|
|
2117
2355
|
|
|
2118
|
-
|
|
2356
|
+
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
2357
|
+
}
|
|
2119
2358
|
|
|
2120
2359
|
if (!whisper_encode_external(wstate)) {
|
|
2121
|
-
wsp_ggml_graph_compute_helper(
|
|
2360
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2361
|
+
return false;
|
|
2362
|
+
}
|
|
2363
|
+
} else {
|
|
2364
|
+
#if defined(WHISPER_USE_COREML)
|
|
2365
|
+
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
|
|
2366
|
+
#elif defined(WHISPER_USE_OPENVINO)
|
|
2367
|
+
whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
|
|
2368
|
+
#endif
|
|
2122
2369
|
}
|
|
2123
2370
|
}
|
|
2124
2371
|
|
|
2125
2372
|
// encoder
|
|
2126
2373
|
if (!whisper_encode_external(wstate)) {
|
|
2127
|
-
auto &
|
|
2128
|
-
|
|
2129
|
-
wsp_ggml_allocr_reset(alloc);
|
|
2374
|
+
auto & sched = wstate.sched_encode.sched;
|
|
2130
2375
|
|
|
2131
2376
|
wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
|
2132
2377
|
|
|
2133
|
-
|
|
2378
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2379
|
+
// should never happen as we pre-allocate the memory
|
|
2380
|
+
return false;
|
|
2381
|
+
}
|
|
2134
2382
|
|
|
2135
|
-
wsp_ggml_graph_compute_helper(
|
|
2383
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2384
|
+
return false;
|
|
2385
|
+
}
|
|
2136
2386
|
}
|
|
2137
2387
|
|
|
2138
2388
|
// cross
|
|
2139
2389
|
{
|
|
2140
|
-
auto &
|
|
2141
|
-
|
|
2142
|
-
wsp_ggml_allocr_reset(alloc);
|
|
2390
|
+
auto & sched = wstate.sched_cross.sched;
|
|
2143
2391
|
|
|
2144
2392
|
wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
|
2145
2393
|
|
|
2146
|
-
|
|
2394
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2395
|
+
// should never happen as we pre-allocate the memory
|
|
2396
|
+
return false;
|
|
2397
|
+
}
|
|
2147
2398
|
|
|
2148
|
-
wsp_ggml_graph_compute_helper(
|
|
2399
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2400
|
+
return false;
|
|
2401
|
+
}
|
|
2149
2402
|
}
|
|
2150
2403
|
|
|
2151
2404
|
wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
|
|
@@ -2157,32 +2410,36 @@ static bool whisper_encode_internal(
|
|
|
2157
2410
|
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2158
2411
|
whisper_context & wctx,
|
|
2159
2412
|
whisper_state & wstate,
|
|
2160
|
-
const whisper_batch & batch
|
|
2413
|
+
const whisper_batch & batch,
|
|
2414
|
+
bool save_alignment_heads_QKs,
|
|
2415
|
+
bool worst_case) {
|
|
2161
2416
|
const auto & model = wctx.model;
|
|
2162
2417
|
const auto & hparams = model.hparams;
|
|
2163
2418
|
|
|
2164
2419
|
auto & kv_self = wstate.kv_self;
|
|
2165
2420
|
|
|
2166
|
-
WHISPER_ASSERT(!!kv_self.
|
|
2167
|
-
|
|
2168
|
-
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
2421
|
+
WHISPER_ASSERT(!!kv_self.buffer);
|
|
2169
2422
|
|
|
2170
2423
|
const int n_ctx = kv_self.size;
|
|
2171
2424
|
const int n_state = hparams.n_text_state;
|
|
2172
2425
|
const int n_head = hparams.n_text_head;
|
|
2173
2426
|
const int n_layer = hparams.n_text_layer;
|
|
2174
2427
|
|
|
2428
|
+
const int n_state_head = n_state/n_head;
|
|
2429
|
+
|
|
2175
2430
|
const int n_tokens = batch.n_tokens;
|
|
2176
2431
|
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
2177
2432
|
|
|
2178
|
-
const
|
|
2179
|
-
|
|
2433
|
+
const int n_audio_ctx_pad = WSP_GGML_PAD(n_audio_ctx, 256);
|
|
2434
|
+
|
|
2435
|
+
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
|
2436
|
+
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
|
2180
2437
|
|
|
2181
|
-
//
|
|
2438
|
+
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
|
2182
2439
|
|
|
2183
2440
|
struct wsp_ggml_init_params params = {
|
|
2184
|
-
/*.mem_size =*/ wstate.
|
|
2185
|
-
/*.mem_buffer =*/ wstate.
|
|
2441
|
+
/*.mem_size =*/ wstate.sched_decode.meta.size(),
|
|
2442
|
+
/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
|
|
2186
2443
|
/*.no_alloc =*/ true,
|
|
2187
2444
|
};
|
|
2188
2445
|
|
|
@@ -2190,55 +2447,21 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2190
2447
|
|
|
2191
2448
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
2192
2449
|
|
|
2193
|
-
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2194
|
-
|
|
2195
|
-
|
|
2196
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2197
|
-
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2198
|
-
}
|
|
2199
|
-
|
|
2200
|
-
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2201
|
-
wsp_ggml_allocr_alloc(alloc, position);
|
|
2202
|
-
|
|
2203
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2204
|
-
for (int i = 0; i < n_tokens; ++i) {
|
|
2205
|
-
const int32_t val = batch.pos[i];
|
|
2206
|
-
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2207
|
-
}
|
|
2208
|
-
}
|
|
2209
|
-
|
|
2210
|
-
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
2211
|
-
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
2212
|
-
|
|
2213
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2214
|
-
const float val = pow(float(n_state)/n_head, -0.25);
|
|
2215
|
-
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
2216
|
-
}
|
|
2217
|
-
|
|
2218
|
-
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
|
|
2219
|
-
wsp_ggml_allocr_alloc(alloc, KQ_mask);
|
|
2220
|
-
|
|
2221
|
-
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2222
|
-
wstate.inp_mask.resize(n_kv*n_tokens);
|
|
2450
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2451
|
+
wsp_ggml_set_name(embd, "embd");
|
|
2452
|
+
wsp_ggml_set_input(embd);
|
|
2223
2453
|
|
|
2224
|
-
|
|
2225
|
-
|
|
2454
|
+
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2455
|
+
wsp_ggml_set_name(position, "position");
|
|
2456
|
+
wsp_ggml_set_input(position);
|
|
2226
2457
|
|
|
2227
|
-
|
|
2228
|
-
for (int j = 0; j < n_tokens; ++j) {
|
|
2229
|
-
const whisper_pos pos = batch.pos[j];
|
|
2230
|
-
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2458
|
+
const float KQscale = pow(float(n_state_head), -0.25);
|
|
2231
2459
|
|
|
2232
|
-
|
|
2233
|
-
|
|
2234
|
-
|
|
2235
|
-
}
|
|
2236
|
-
}
|
|
2237
|
-
}
|
|
2238
|
-
}
|
|
2460
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD), 1);
|
|
2461
|
+
wsp_ggml_set_name(KQ_mask, "KQ_mask");
|
|
2462
|
+
wsp_ggml_set_input(KQ_mask);
|
|
2239
2463
|
|
|
2240
|
-
|
|
2241
|
-
}
|
|
2464
|
+
struct wsp_ggml_tensor * KQ_mask_f16 = wsp_ggml_cast(ctx0, KQ_mask, WSP_GGML_TYPE_F16);
|
|
2242
2465
|
|
|
2243
2466
|
// token encoding + position encoding
|
|
2244
2467
|
struct wsp_ggml_tensor * cur =
|
|
@@ -2248,6 +2471,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2248
2471
|
|
|
2249
2472
|
struct wsp_ggml_tensor * inpL = cur;
|
|
2250
2473
|
|
|
2474
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2475
|
+
struct wsp_ggml_tensor * aheads_cross_QKs = nullptr;
|
|
2476
|
+
|
|
2251
2477
|
for (int il = 0; il < n_layer; ++il) {
|
|
2252
2478
|
const auto & layer = model.layers_decoder[il];
|
|
2253
2479
|
|
|
@@ -2292,12 +2518,25 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2292
2518
|
Vcur,
|
|
2293
2519
|
layer.attn_v_b);
|
|
2294
2520
|
|
|
2295
|
-
|
|
2521
|
+
struct wsp_ggml_tensor * k;
|
|
2522
|
+
struct wsp_ggml_tensor * v;
|
|
2296
2523
|
|
|
2297
|
-
|
|
2298
|
-
|
|
2299
|
-
|
|
2300
|
-
|
|
2524
|
+
if (wctx.params.flash_attn) {
|
|
2525
|
+
k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
|
2526
|
+
(wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2527
|
+
|
|
2528
|
+
v = wsp_ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
|
|
2529
|
+
(wsp_ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
|
|
2530
|
+
} else {
|
|
2531
|
+
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
|
2532
|
+
|
|
2533
|
+
k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
|
2534
|
+
(wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2535
|
+
|
|
2536
|
+
v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
|
2537
|
+
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2538
|
+
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
|
|
2539
|
+
}
|
|
2301
2540
|
|
|
2302
2541
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2303
2542
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
@@ -2307,40 +2546,46 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2307
2546
|
|
|
2308
2547
|
struct wsp_ggml_tensor * Q =
|
|
2309
2548
|
wsp_ggml_permute(ctx0,
|
|
2310
|
-
wsp_ggml_reshape_3d(ctx0, Qcur,
|
|
2549
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
|
2311
2550
|
0, 2, 1, 3);
|
|
2312
2551
|
|
|
2313
2552
|
struct wsp_ggml_tensor * K =
|
|
2314
2553
|
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2315
|
-
|
|
2554
|
+
n_state_head, n_kv, n_head,
|
|
2316
2555
|
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2317
|
-
wsp_ggml_element_size(kv_self.k)*
|
|
2556
|
+
wsp_ggml_element_size(kv_self.k)*n_state_head,
|
|
2318
2557
|
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
2319
2558
|
|
|
2320
|
-
|
|
2321
|
-
|
|
2559
|
+
if (wctx.params.flash_attn) {
|
|
2560
|
+
struct wsp_ggml_tensor * V =
|
|
2561
|
+
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2562
|
+
n_state_head, n_kv, n_head,
|
|
2563
|
+
wsp_ggml_element_size(kv_self.v)*n_state,
|
|
2564
|
+
wsp_ggml_element_size(kv_self.v)*n_state_head,
|
|
2565
|
+
wsp_ggml_element_size(kv_self.v)*n_state*n_ctx*il);
|
|
2322
2566
|
|
|
2323
|
-
|
|
2567
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
|
|
2324
2568
|
|
|
2325
|
-
|
|
2326
|
-
|
|
2569
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
|
2570
|
+
} else {
|
|
2571
|
+
// K * Q
|
|
2572
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
2327
2573
|
|
|
2328
|
-
|
|
2574
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
|
|
2329
2575
|
|
|
2330
|
-
|
|
2331
|
-
|
|
2332
|
-
|
|
2333
|
-
|
|
2334
|
-
|
|
2335
|
-
|
|
2576
|
+
struct wsp_ggml_tensor * V =
|
|
2577
|
+
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2578
|
+
n_kv, n_state_head, n_head,
|
|
2579
|
+
n_ctx*wsp_ggml_element_size(kv_self.v),
|
|
2580
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state_head,
|
|
2581
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
|
|
2336
2582
|
|
|
2337
|
-
|
|
2583
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2338
2584
|
|
|
2339
|
-
|
|
2585
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2340
2586
|
|
|
2341
|
-
|
|
2342
|
-
|
|
2343
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2587
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
|
2588
|
+
}
|
|
2344
2589
|
}
|
|
2345
2590
|
|
|
2346
2591
|
// projection
|
|
@@ -2379,62 +2624,75 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2379
2624
|
Qcur,
|
|
2380
2625
|
layer.cross_attn_q_b);
|
|
2381
2626
|
|
|
2382
|
-
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2383
|
-
|
|
2384
|
-
// Kcross is already scaled
|
|
2385
|
-
struct wsp_ggml_tensor * Kcross =
|
|
2386
|
-
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2387
|
-
n_state/n_head, n_audio_ctx, n_head,
|
|
2388
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2389
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2390
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2391
|
-
|
|
2392
|
-
//struct wsp_ggml_tensor * Vcross =
|
|
2393
|
-
// wsp_ggml_reshape_3d(ctx0,
|
|
2394
|
-
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2395
|
-
// n_state/n_head, n_head, n_audio_ctx);
|
|
2396
|
-
|
|
2397
|
-
//struct wsp_ggml_tensor * V_trans =
|
|
2398
|
-
// wsp_ggml_cpy(ctx0,
|
|
2399
|
-
// wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
|
2400
|
-
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
|
2401
|
-
|
|
2402
|
-
struct wsp_ggml_tensor * V =
|
|
2403
|
-
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2404
|
-
n_audio_ctx, n_state/n_head, n_head,
|
|
2405
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2406
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
|
2407
|
-
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2408
|
-
|
|
2409
|
-
// ------
|
|
2410
|
-
|
|
2411
2627
|
struct wsp_ggml_tensor * Q =
|
|
2412
2628
|
wsp_ggml_permute(ctx0,
|
|
2413
|
-
wsp_ggml_reshape_3d(ctx0, Qcur,
|
|
2629
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
|
2414
2630
|
0, 2, 1, 3);
|
|
2415
2631
|
|
|
2416
|
-
|
|
2417
|
-
|
|
2632
|
+
if (wctx.params.flash_attn) {
|
|
2633
|
+
struct wsp_ggml_tensor * Kcross =
|
|
2634
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2635
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
|
2636
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2637
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
|
2638
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
|
|
2418
2639
|
|
|
2419
|
-
|
|
2420
|
-
|
|
2421
|
-
|
|
2422
|
-
|
|
2423
|
-
|
|
2640
|
+
struct wsp_ggml_tensor * Vcross =
|
|
2641
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2642
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
|
2643
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state,
|
|
2644
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
|
2645
|
+
wsp_ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
|
|
2424
2646
|
|
|
2425
|
-
|
|
2426
|
-
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2647
|
+
cur = wsp_ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
|
|
2427
2648
|
|
|
2428
|
-
|
|
2649
|
+
cur = wsp_ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
|
2650
|
+
} else {
|
|
2651
|
+
struct wsp_ggml_tensor * Kcross =
|
|
2652
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2653
|
+
n_state_head, n_audio_ctx, n_head,
|
|
2654
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2655
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
|
2656
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2657
|
+
|
|
2658
|
+
struct wsp_ggml_tensor * Vcross =
|
|
2659
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2660
|
+
n_audio_ctx, n_state_head, n_head,
|
|
2661
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2662
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
|
2663
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2664
|
+
|
|
2665
|
+
// ------
|
|
2666
|
+
|
|
2667
|
+
// K * Q
|
|
2668
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
|
|
2669
|
+
|
|
2670
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
|
2671
|
+
|
|
2672
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2673
|
+
if (wctx.params.dtw_token_timestamps) {
|
|
2674
|
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
|
2675
|
+
struct wsp_ggml_tensor * aheads_KQs = wsp_ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
|
2676
|
+
aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
|
|
2677
|
+
aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
|
|
2678
|
+
aheads_KQs = wsp_ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
|
2679
|
+
aheads_KQs = wsp_ggml_transpose(ctx0, aheads_KQs);
|
|
2680
|
+
aheads_KQs = wsp_ggml_cont(ctx0, aheads_KQs);
|
|
2681
|
+
aheads_KQs = wsp_ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
|
2682
|
+
if (aheads_cross_QKs == NULL) {
|
|
2683
|
+
aheads_cross_QKs = aheads_KQs;
|
|
2684
|
+
} else {
|
|
2685
|
+
aheads_cross_QKs = wsp_ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
|
|
2686
|
+
}
|
|
2687
|
+
}
|
|
2688
|
+
}
|
|
2429
2689
|
|
|
2430
|
-
|
|
2690
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
|
|
2431
2691
|
|
|
2432
|
-
|
|
2692
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2433
2693
|
|
|
2434
|
-
|
|
2435
|
-
|
|
2436
|
-
KQV_merged,
|
|
2437
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2694
|
+
cur = wsp_ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
|
2695
|
+
}
|
|
2438
2696
|
}
|
|
2439
2697
|
|
|
2440
2698
|
// projection
|
|
@@ -2512,6 +2770,16 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2512
2770
|
|
|
2513
2771
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2514
2772
|
|
|
2773
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
2774
|
+
if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
|
|
2775
|
+
aheads_cross_QKs = wsp_ggml_transpose(ctx0, aheads_cross_QKs);
|
|
2776
|
+
aheads_cross_QKs = wsp_ggml_cont(ctx0, aheads_cross_QKs);
|
|
2777
|
+
if (save_alignment_heads_QKs) {
|
|
2778
|
+
wsp_ggml_build_forward_expand(gf, aheads_cross_QKs);
|
|
2779
|
+
wstate.aheads_cross_QKs = aheads_cross_QKs;
|
|
2780
|
+
}
|
|
2781
|
+
}
|
|
2782
|
+
|
|
2515
2783
|
wsp_ggml_build_forward_expand(gf, logits);
|
|
2516
2784
|
|
|
2517
2785
|
wsp_ggml_free(ctx0);
|
|
@@ -2534,7 +2802,8 @@ static bool whisper_decode_internal(
|
|
|
2534
2802
|
whisper_state & wstate,
|
|
2535
2803
|
const whisper_batch & batch,
|
|
2536
2804
|
const int n_threads,
|
|
2537
|
-
|
|
2805
|
+
bool save_alignment_heads_QKs,
|
|
2806
|
+
wsp_ggml_abort_callback abort_callback,
|
|
2538
2807
|
void * abort_callback_data) {
|
|
2539
2808
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2540
2809
|
|
|
@@ -2556,24 +2825,77 @@ static bool whisper_decode_internal(
|
|
|
2556
2825
|
return false;
|
|
2557
2826
|
}
|
|
2558
2827
|
|
|
2559
|
-
|
|
2828
|
+
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
|
|
2829
|
+
kv_self.n = std::min(kv_self.size, std::max(pad, WSP_GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
|
|
2830
|
+
|
|
2560
2831
|
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
|
2561
2832
|
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
|
2562
2833
|
}
|
|
2563
2834
|
|
|
2564
2835
|
// decoder
|
|
2565
2836
|
{
|
|
2566
|
-
auto &
|
|
2837
|
+
auto & sched = wstate.sched_decode.sched;
|
|
2838
|
+
|
|
2839
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
|
|
2840
|
+
|
|
2841
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
2842
|
+
// should never happen as we pre-allocate the memory
|
|
2843
|
+
return false;
|
|
2844
|
+
}
|
|
2845
|
+
|
|
2846
|
+
// set the inputs
|
|
2847
|
+
{
|
|
2848
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_graph_get_tensor(gf, "embd");
|
|
2849
|
+
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2850
|
+
}
|
|
2851
|
+
|
|
2852
|
+
{
|
|
2853
|
+
struct wsp_ggml_tensor * position = wsp_ggml_graph_get_tensor(gf, "position");
|
|
2854
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
2855
|
+
const int32_t val = batch.pos[i];
|
|
2856
|
+
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2857
|
+
}
|
|
2858
|
+
}
|
|
2859
|
+
|
|
2860
|
+
{
|
|
2861
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_graph_get_tensor(gf, "KQ_mask");
|
|
2862
|
+
|
|
2863
|
+
auto & kv_self = wstate.kv_self;
|
|
2864
|
+
|
|
2865
|
+
const int32_t n_kv = kv_self.n;
|
|
2866
|
+
|
|
2867
|
+
wstate.inp_mask.resize(wsp_ggml_nelements(KQ_mask));
|
|
2868
|
+
|
|
2869
|
+
float * data = wstate.inp_mask.data();
|
|
2870
|
+
memset(data, 0, wsp_ggml_nbytes(KQ_mask));
|
|
2871
|
+
|
|
2872
|
+
for (int h = 0; h < 1; ++h) {
|
|
2873
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
2874
|
+
const whisper_pos pos = batch.pos[j];
|
|
2875
|
+
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2567
2876
|
|
|
2568
|
-
|
|
2877
|
+
for (int i = 0; i < n_kv; ++i) {
|
|
2878
|
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
|
2879
|
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
|
2880
|
+
}
|
|
2881
|
+
}
|
|
2882
|
+
}
|
|
2569
2883
|
|
|
2570
|
-
|
|
2884
|
+
for (int i = n_tokens; i < WSP_GGML_PAD(n_tokens, WSP_GGML_KQ_MASK_PAD); ++i) {
|
|
2885
|
+
for (int j = 0; j < n_kv; ++j) {
|
|
2886
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
2887
|
+
}
|
|
2888
|
+
}
|
|
2889
|
+
}
|
|
2571
2890
|
|
|
2572
|
-
|
|
2891
|
+
wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
|
|
2892
|
+
}
|
|
2573
2893
|
|
|
2574
|
-
logits = gf
|
|
2894
|
+
logits = wsp_ggml_graph_node(gf, -1);
|
|
2575
2895
|
|
|
2576
|
-
wsp_ggml_graph_compute_helper(
|
|
2896
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, n_threads)) {
|
|
2897
|
+
return false;
|
|
2898
|
+
}
|
|
2577
2899
|
}
|
|
2578
2900
|
|
|
2579
2901
|
logits_out.resize(n_tokens*n_vocab);
|
|
@@ -2625,29 +2947,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
2625
2947
|
}
|
|
2626
2948
|
|
|
2627
2949
|
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
|
2628
|
-
|
|
2629
|
-
|
|
2950
|
+
namespace {
|
|
2951
|
+
struct whisper_global_cache {
|
|
2952
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2953
|
+
// We can use precalculated values to speed up the process.
|
|
2954
|
+
float sin_vals[SIN_COS_N_COUNT];
|
|
2955
|
+
float cos_vals[SIN_COS_N_COUNT];
|
|
2956
|
+
|
|
2957
|
+
// Hann window (Use cosf to eliminate difference)
|
|
2958
|
+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
|
2959
|
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
|
2960
|
+
float hann_window[WHISPER_N_FFT];
|
|
2961
|
+
|
|
2962
|
+
whisper_global_cache() {
|
|
2963
|
+
fill_sin_cos_table();
|
|
2964
|
+
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
|
2965
|
+
}
|
|
2966
|
+
|
|
2967
|
+
void fill_sin_cos_table() {
|
|
2968
|
+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
|
2969
|
+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
|
2970
|
+
sin_vals[i] = sinf(theta);
|
|
2971
|
+
cos_vals[i] = cosf(theta);
|
|
2972
|
+
}
|
|
2973
|
+
}
|
|
2630
2974
|
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2634
|
-
|
|
2635
|
-
|
|
2636
|
-
|
|
2637
|
-
|
|
2638
|
-
|
|
2639
|
-
cos_vals[i] = cosf(theta);
|
|
2975
|
+
void fill_hann_window(int length, bool periodic, float * output) {
|
|
2976
|
+
int offset = -1;
|
|
2977
|
+
if (periodic) {
|
|
2978
|
+
offset = 0;
|
|
2979
|
+
}
|
|
2980
|
+
for (int i = 0; i < length; i++) {
|
|
2981
|
+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
|
2982
|
+
}
|
|
2640
2983
|
}
|
|
2641
|
-
|
|
2984
|
+
} global_cache;
|
|
2642
2985
|
}
|
|
2643
2986
|
|
|
2644
2987
|
// naive Discrete Fourier Transform
|
|
2645
2988
|
// input is real-valued
|
|
2646
2989
|
// output is complex-valued
|
|
2647
|
-
static void dft(const
|
|
2648
|
-
int N = in.size();
|
|
2649
|
-
|
|
2650
|
-
out.resize(N*2);
|
|
2990
|
+
static void dft(const float* in, int N, float* out) {
|
|
2651
2991
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
|
2652
2992
|
|
|
2653
2993
|
for (int k = 0; k < N; k++) {
|
|
@@ -2656,8 +2996,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2656
2996
|
|
|
2657
2997
|
for (int n = 0; n < N; n++) {
|
|
2658
2998
|
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
|
2659
|
-
re += in[n]*cos_vals[idx]; // cos(t)
|
|
2660
|
-
im -= in[n]*sin_vals[idx]; // sin(t)
|
|
2999
|
+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
|
3000
|
+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
|
2661
3001
|
}
|
|
2662
3002
|
|
|
2663
3003
|
out[k*2 + 0] = re;
|
|
@@ -2669,47 +3009,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2669
3009
|
// poor man's implementation - use something better
|
|
2670
3010
|
// input is real-valued
|
|
2671
3011
|
// output is complex-valued
|
|
2672
|
-
static void fft(
|
|
2673
|
-
out.resize(in.size()*2);
|
|
2674
|
-
|
|
2675
|
-
int N = in.size();
|
|
2676
|
-
|
|
3012
|
+
static void fft(float* in, int N, float* out) {
|
|
2677
3013
|
if (N == 1) {
|
|
2678
3014
|
out[0] = in[0];
|
|
2679
3015
|
out[1] = 0;
|
|
2680
3016
|
return;
|
|
2681
3017
|
}
|
|
2682
3018
|
|
|
2683
|
-
|
|
2684
|
-
|
|
3019
|
+
const int half_N = N / 2;
|
|
3020
|
+
if (N - half_N*2 == 1) {
|
|
3021
|
+
dft(in, N, out);
|
|
2685
3022
|
return;
|
|
2686
3023
|
}
|
|
2687
3024
|
|
|
2688
|
-
|
|
2689
|
-
|
|
2690
|
-
|
|
2691
|
-
even.reserve(N/2);
|
|
2692
|
-
odd.reserve(N/2);
|
|
2693
|
-
|
|
2694
|
-
for (int i = 0; i < N; i++) {
|
|
2695
|
-
if (i % 2 == 0) {
|
|
2696
|
-
even.push_back(in[i]);
|
|
2697
|
-
} else {
|
|
2698
|
-
odd.push_back(in[i]);
|
|
2699
|
-
}
|
|
3025
|
+
float* even = in + N;
|
|
3026
|
+
for (int i = 0; i < half_N; ++i) {
|
|
3027
|
+
even[i]= in[2*i];
|
|
2700
3028
|
}
|
|
3029
|
+
float* even_fft = out + 2 * N;
|
|
3030
|
+
fft(even, half_N, even_fft);
|
|
2701
3031
|
|
|
2702
|
-
|
|
2703
|
-
|
|
2704
|
-
|
|
2705
|
-
|
|
2706
|
-
|
|
3032
|
+
float* odd = even;
|
|
3033
|
+
for (int i = 0; i < half_N; ++i) {
|
|
3034
|
+
odd[i] = in[2*i + 1];
|
|
3035
|
+
}
|
|
3036
|
+
float* odd_fft = even_fft + N;
|
|
3037
|
+
fft(odd, half_N, odd_fft);
|
|
2707
3038
|
|
|
2708
3039
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
|
2709
|
-
for (int k = 0; k <
|
|
3040
|
+
for (int k = 0; k < half_N; k++) {
|
|
2710
3041
|
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
|
2711
|
-
float re = cos_vals[idx]; // cos(t)
|
|
2712
|
-
float im = -sin_vals[idx]; // sin(t)
|
|
3042
|
+
float re = global_cache.cos_vals[idx]; // cos(t)
|
|
3043
|
+
float im = -global_cache.sin_vals[idx]; // sin(t)
|
|
2713
3044
|
|
|
2714
3045
|
float re_odd = odd_fft[2*k + 0];
|
|
2715
3046
|
float im_odd = odd_fft[2*k + 1];
|
|
@@ -2717,61 +3048,49 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2717
3048
|
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
|
2718
3049
|
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
|
2719
3050
|
|
|
2720
|
-
out[2*(k +
|
|
2721
|
-
out[2*(k +
|
|
2722
|
-
}
|
|
2723
|
-
}
|
|
2724
|
-
|
|
2725
|
-
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
|
2726
|
-
if (output.size() < static_cast<size_t>(length)) {
|
|
2727
|
-
output.resize(length);
|
|
2728
|
-
}
|
|
2729
|
-
int offset = -1;
|
|
2730
|
-
if (periodic) {
|
|
2731
|
-
offset = 0;
|
|
2732
|
-
}
|
|
2733
|
-
for (int i = 0; i < length; i++) {
|
|
2734
|
-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
|
3051
|
+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
|
3052
|
+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
|
2735
3053
|
}
|
|
2736
|
-
|
|
2737
|
-
return true;
|
|
2738
3054
|
}
|
|
2739
3055
|
|
|
2740
|
-
static void log_mel_spectrogram_worker_thread(int ith, const
|
|
3056
|
+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
|
2741
3057
|
int n_samples, int frame_size, int frame_step, int n_threads,
|
|
2742
3058
|
const whisper_filters & filters, whisper_mel & mel) {
|
|
2743
|
-
std::vector<float> fft_in(frame_size, 0.0);
|
|
2744
|
-
std::vector<float> fft_out(2 *
|
|
2745
|
-
|
|
2746
|
-
int n_fft =
|
|
3059
|
+
std::vector<float> fft_in(frame_size * 2, 0.0);
|
|
3060
|
+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
|
3061
|
+
|
|
3062
|
+
int n_fft = filters.n_fft;
|
|
2747
3063
|
int i = ith;
|
|
2748
3064
|
|
|
3065
|
+
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
|
3066
|
+
assert(n_fft == 1 + (frame_size / 2));
|
|
3067
|
+
|
|
2749
3068
|
// calculate FFT only when fft_in are not all zero
|
|
2750
3069
|
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
|
2751
3070
|
const int offset = i * frame_step;
|
|
2752
3071
|
|
|
2753
|
-
// apply
|
|
3072
|
+
// apply Hann window (~10% faster)
|
|
2754
3073
|
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
|
2755
3074
|
fft_in[j] = hann[j] * samples[offset + j];
|
|
2756
3075
|
}
|
|
3076
|
+
|
|
2757
3077
|
// fill the rest with zeros
|
|
2758
3078
|
if (n_samples - offset < frame_size) {
|
|
2759
3079
|
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
|
|
2760
3080
|
}
|
|
2761
3081
|
|
|
2762
3082
|
// FFT
|
|
2763
|
-
fft(fft_in, fft_out);
|
|
3083
|
+
fft(fft_in.data(), frame_size, fft_out.data());
|
|
2764
3084
|
|
|
2765
3085
|
// Calculate modulus^2 of complex numbers
|
|
2766
3086
|
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
|
2767
|
-
for (int j = 0; j <
|
|
3087
|
+
for (int j = 0; j < n_fft; j++) {
|
|
2768
3088
|
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
|
2769
3089
|
}
|
|
2770
3090
|
|
|
2771
3091
|
// mel spectrogram
|
|
2772
3092
|
for (int j = 0; j < mel.n_mel; j++) {
|
|
2773
3093
|
double sum = 0.0;
|
|
2774
|
-
|
|
2775
3094
|
// unroll loop (suggested by GH user @lunixbochs)
|
|
2776
3095
|
int k = 0;
|
|
2777
3096
|
for (k = 0; k < n_fft - 3; k += 4) {
|
|
@@ -2781,14 +3100,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
|
2781
3100
|
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
|
|
2782
3101
|
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
|
|
2783
3102
|
}
|
|
2784
|
-
|
|
2785
3103
|
// handle n_fft remainder
|
|
2786
3104
|
for (; k < n_fft; k++) {
|
|
2787
3105
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
|
2788
3106
|
}
|
|
2789
|
-
|
|
2790
3107
|
sum = log10(std::max(sum, 1e-10));
|
|
2791
|
-
|
|
2792
3108
|
mel.data[j * mel.n_len + i] = sum;
|
|
2793
3109
|
}
|
|
2794
3110
|
}
|
|
@@ -2817,12 +3133,9 @@ static bool log_mel_spectrogram(
|
|
|
2817
3133
|
whisper_mel & mel) {
|
|
2818
3134
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2819
3135
|
|
|
2820
|
-
//
|
|
2821
|
-
|
|
2822
|
-
|
|
2823
|
-
std::vector<float> hann;
|
|
2824
|
-
hann_window(frame_size, true, hann);
|
|
2825
|
-
|
|
3136
|
+
// Hann window
|
|
3137
|
+
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
|
|
3138
|
+
const float * hann = global_cache.hann_window;
|
|
2826
3139
|
|
|
2827
3140
|
// Calculate the length of padding
|
|
2828
3141
|
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
|
@@ -2847,12 +3160,11 @@ static bool log_mel_spectrogram(
|
|
|
2847
3160
|
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
|
|
2848
3161
|
mel.data.resize(mel.n_mel * mel.n_len);
|
|
2849
3162
|
|
|
2850
|
-
|
|
2851
3163
|
{
|
|
2852
3164
|
std::vector<std::thread> workers(n_threads - 1);
|
|
2853
3165
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
|
2854
3166
|
workers[iw] = std::thread(
|
|
2855
|
-
log_mel_spectrogram_worker_thread, iw + 1,
|
|
3167
|
+
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
|
2856
3168
|
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
|
2857
3169
|
std::cref(filters), std::ref(mel));
|
|
2858
3170
|
}
|
|
@@ -3012,19 +3324,24 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
|
3012
3324
|
#endif
|
|
3013
3325
|
|
|
3014
3326
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3015
|
-
fill_sin_cos_table();
|
|
3016
|
-
|
|
3017
3327
|
whisper_state * state = new whisper_state;
|
|
3018
3328
|
|
|
3019
|
-
state->
|
|
3020
|
-
|
|
3021
|
-
|
|
3022
|
-
|
|
3023
|
-
|
|
3329
|
+
state->backends = whisper_backend_init(ctx->params);
|
|
3330
|
+
if (state->backends.empty()) {
|
|
3331
|
+
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
|
3332
|
+
whisper_free_state(state);
|
|
3333
|
+
return nullptr;
|
|
3334
|
+
}
|
|
3024
3335
|
|
|
3025
|
-
|
|
3026
|
-
|
|
3027
|
-
|
|
3336
|
+
// at this point, we don't know yet how many decoders will be used
|
|
3337
|
+
// later during decoding, if more decoders are used, we will recreate the KV cache respectively
|
|
3338
|
+
state->kv_self_n_dec = 1;
|
|
3339
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
|
3340
|
+
ctx->model.hparams.n_text_state,
|
|
3341
|
+
ctx->model.hparams.n_text_layer,
|
|
3342
|
+
WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
|
|
3343
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
3344
|
+
whisper_free_state(state);
|
|
3028
3345
|
return nullptr;
|
|
3029
3346
|
}
|
|
3030
3347
|
|
|
@@ -3033,9 +3350,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3033
3350
|
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3034
3351
|
}
|
|
3035
3352
|
|
|
3036
|
-
if (!
|
|
3037
|
-
|
|
3038
|
-
|
|
3353
|
+
if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
|
|
3354
|
+
ctx->model.hparams.n_text_state,
|
|
3355
|
+
ctx->model.hparams.n_text_layer,
|
|
3356
|
+
WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
|
3357
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
3358
|
+
whisper_free_state(state);
|
|
3039
3359
|
return nullptr;
|
|
3040
3360
|
}
|
|
3041
3361
|
|
|
@@ -3044,6 +3364,31 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3044
3364
|
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3045
3365
|
}
|
|
3046
3366
|
|
|
3367
|
+
if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
|
|
3368
|
+
ctx->model.hparams.n_audio_state,
|
|
3369
|
+
1,
|
|
3370
|
+
WSP_GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
|
3371
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
3372
|
+
whisper_free_state(state);
|
|
3373
|
+
return nullptr;
|
|
3374
|
+
}
|
|
3375
|
+
|
|
3376
|
+
{
|
|
3377
|
+
const size_t memory_size = wsp_ggml_nbytes(state->kv_pad.k) + wsp_ggml_nbytes(state->kv_pad.v);
|
|
3378
|
+
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
3379
|
+
}
|
|
3380
|
+
|
|
3381
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
3382
|
+
if (ctx->params.dtw_token_timestamps) {
|
|
3383
|
+
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
|
|
3384
|
+
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
|
|
3385
|
+
whisper_free_state(state);
|
|
3386
|
+
return nullptr;
|
|
3387
|
+
}
|
|
3388
|
+
const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
|
|
3389
|
+
WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
|
|
3390
|
+
}
|
|
3391
|
+
|
|
3047
3392
|
|
|
3048
3393
|
#ifdef WHISPER_USE_COREML
|
|
3049
3394
|
if (ctx->params.use_coreml) {
|
|
@@ -3056,7 +3401,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3056
3401
|
if (!state->ctx_coreml) {
|
|
3057
3402
|
WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
3058
3403
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
3059
|
-
|
|
3404
|
+
whisper_free_state(state);
|
|
3060
3405
|
return nullptr;
|
|
3061
3406
|
#endif
|
|
3062
3407
|
} else {
|
|
@@ -3081,37 +3426,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3081
3426
|
|
|
3082
3427
|
// conv allocator
|
|
3083
3428
|
{
|
|
3084
|
-
|
|
3429
|
+
bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
|
|
3085
3430
|
[&]() {
|
|
3086
|
-
return whisper_build_graph_conv(*ctx, *state
|
|
3431
|
+
return whisper_build_graph_conv(*ctx, *state);
|
|
3087
3432
|
});
|
|
3088
3433
|
|
|
3089
|
-
|
|
3434
|
+
if (!ok) {
|
|
3435
|
+
WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__);
|
|
3436
|
+
whisper_free_state(state);
|
|
3437
|
+
return nullptr;
|
|
3438
|
+
}
|
|
3439
|
+
|
|
3440
|
+
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
|
|
3090
3441
|
}
|
|
3091
3442
|
|
|
3092
3443
|
// encoder allocator
|
|
3093
3444
|
if (!whisper_encode_external(*state)) {
|
|
3094
|
-
|
|
3445
|
+
bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
|
|
3095
3446
|
[&]() {
|
|
3096
3447
|
return whisper_build_graph_encoder(*ctx, *state);
|
|
3097
3448
|
});
|
|
3098
3449
|
|
|
3099
|
-
|
|
3450
|
+
if (!ok) {
|
|
3451
|
+
WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__);
|
|
3452
|
+
whisper_free_state(state);
|
|
3453
|
+
return nullptr;
|
|
3454
|
+
}
|
|
3455
|
+
|
|
3456
|
+
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
|
|
3100
3457
|
}
|
|
3101
3458
|
|
|
3102
3459
|
// cross allocator
|
|
3103
3460
|
{
|
|
3104
|
-
|
|
3461
|
+
bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
|
|
3105
3462
|
[&]() {
|
|
3106
3463
|
return whisper_build_graph_cross(*ctx, *state);
|
|
3107
3464
|
});
|
|
3108
3465
|
|
|
3109
|
-
|
|
3466
|
+
if (!ok) {
|
|
3467
|
+
WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__);
|
|
3468
|
+
whisper_free_state(state);
|
|
3469
|
+
return nullptr;
|
|
3470
|
+
}
|
|
3471
|
+
|
|
3472
|
+
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
|
|
3110
3473
|
}
|
|
3111
3474
|
|
|
3112
3475
|
// decoder allocator
|
|
3113
3476
|
{
|
|
3114
|
-
|
|
3477
|
+
bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
|
|
3115
3478
|
[&]() {
|
|
3116
3479
|
const auto & hparams = ctx->model.hparams;
|
|
3117
3480
|
|
|
@@ -3121,27 +3484,30 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
3121
3484
|
|
|
3122
3485
|
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
|
3123
3486
|
|
|
3124
|
-
return whisper_build_graph_decoder(*ctx, *state, state->batch);
|
|
3487
|
+
return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
|
|
3125
3488
|
});
|
|
3126
3489
|
|
|
3127
|
-
|
|
3128
|
-
|
|
3490
|
+
if (!ok) {
|
|
3491
|
+
WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
|
|
3492
|
+
whisper_free_state(state);
|
|
3493
|
+
return nullptr;
|
|
3494
|
+
}
|
|
3129
3495
|
|
|
3130
|
-
|
|
3131
|
-
|
|
3132
|
-
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
|
3133
|
-
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
|
3496
|
+
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
|
|
3497
|
+
}
|
|
3134
3498
|
|
|
3135
3499
|
return state;
|
|
3136
3500
|
}
|
|
3137
3501
|
|
|
3138
|
-
int
|
|
3502
|
+
int whisper_ctx_init_openvino_encoder_with_state(
|
|
3139
3503
|
struct whisper_context * ctx,
|
|
3504
|
+
struct whisper_state * state,
|
|
3140
3505
|
const char * model_path,
|
|
3141
3506
|
const char * device,
|
|
3142
3507
|
const char * cache_dir) {
|
|
3143
3508
|
#ifndef WHISPER_USE_OPENVINO
|
|
3144
3509
|
(void)(ctx);
|
|
3510
|
+
(void)(state);
|
|
3145
3511
|
(void)(model_path);
|
|
3146
3512
|
(void)(device);
|
|
3147
3513
|
(void)(cache_dir);
|
|
@@ -3172,8 +3538,8 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3172
3538
|
WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
|
|
3173
3539
|
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
3174
3540
|
|
|
3175
|
-
|
|
3176
|
-
if (!
|
|
3541
|
+
state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
|
|
3542
|
+
if (!state->ctx_openvino) {
|
|
3177
3543
|
WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
|
|
3178
3544
|
return 1;
|
|
3179
3545
|
} else {
|
|
@@ -3184,18 +3550,43 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3184
3550
|
#endif
|
|
3185
3551
|
}
|
|
3186
3552
|
|
|
3553
|
+
int whisper_ctx_init_openvino_encoder(
|
|
3554
|
+
struct whisper_context * ctx,
|
|
3555
|
+
const char * model_path,
|
|
3556
|
+
const char * device,
|
|
3557
|
+
const char * cache_dir) {
|
|
3558
|
+
return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
|
|
3559
|
+
}
|
|
3560
|
+
|
|
3187
3561
|
struct whisper_context_params whisper_context_default_params() {
|
|
3188
3562
|
struct whisper_context_params result = {
|
|
3189
|
-
/*.use_gpu
|
|
3190
|
-
/*.use_coreml
|
|
3563
|
+
/*.use_gpu =*/ true,
|
|
3564
|
+
/*.use_coreml =*/ false,
|
|
3565
|
+
/*.flash_attn =*/ false,
|
|
3566
|
+
/*.gpu_device =*/ 0,
|
|
3567
|
+
|
|
3568
|
+
/*.dtw_token_timestamps =*/ false,
|
|
3569
|
+
/*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
|
|
3570
|
+
/*.dtw_n_top =*/ -1,
|
|
3571
|
+
/*.dtw_aheads =*/ {
|
|
3572
|
+
/*.n_heads =*/ 0,
|
|
3573
|
+
/*.heads =*/ NULL,
|
|
3574
|
+
},
|
|
3575
|
+
/*.dtw_mem_size =*/ 1024*1024*128,
|
|
3191
3576
|
};
|
|
3192
3577
|
return result;
|
|
3193
3578
|
}
|
|
3194
3579
|
|
|
3195
3580
|
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
|
|
3196
3581
|
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
|
3197
|
-
|
|
3582
|
+
#ifdef _MSC_VER
|
|
3583
|
+
// Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
|
|
3584
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
|
3585
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
|
3586
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
|
3587
|
+
#else
|
|
3198
3588
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
3589
|
+
#endif
|
|
3199
3590
|
if (!fin) {
|
|
3200
3591
|
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
|
3201
3592
|
return nullptr;
|
|
@@ -3270,6 +3661,19 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|
|
3270
3661
|
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
|
3271
3662
|
wsp_ggml_time_init();
|
|
3272
3663
|
|
|
3664
|
+
if (params.flash_attn && params.dtw_token_timestamps) {
|
|
3665
|
+
WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
|
|
3666
|
+
params.dtw_token_timestamps = false;
|
|
3667
|
+
}
|
|
3668
|
+
|
|
3669
|
+
WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
|
|
3670
|
+
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
|
|
3671
|
+
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
|
3672
|
+
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
|
|
3673
|
+
|
|
3674
|
+
// TODO: temporary call to force backend registry initialization
|
|
3675
|
+
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
|
|
3676
|
+
|
|
3273
3677
|
whisper_context * ctx = new whisper_context;
|
|
3274
3678
|
ctx->params = params;
|
|
3275
3679
|
|
|
@@ -3354,11 +3758,11 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
|
|
3354
3758
|
return whisper_init_with_params_no_state(loader, whisper_context_default_params());
|
|
3355
3759
|
}
|
|
3356
3760
|
|
|
3357
|
-
void whisper_free_state(struct whisper_state * state)
|
|
3358
|
-
{
|
|
3761
|
+
void whisper_free_state(struct whisper_state * state) {
|
|
3359
3762
|
if (state) {
|
|
3360
|
-
|
|
3361
|
-
|
|
3763
|
+
whisper_kv_cache_free(state->kv_self);
|
|
3764
|
+
whisper_kv_cache_free(state->kv_cross);
|
|
3765
|
+
whisper_kv_cache_free(state->kv_pad);
|
|
3362
3766
|
|
|
3363
3767
|
#ifdef WHISPER_USE_COREML
|
|
3364
3768
|
if (state->ctx_coreml != nullptr) {
|
|
@@ -3376,12 +3780,17 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3376
3780
|
|
|
3377
3781
|
whisper_batch_free(state->batch);
|
|
3378
3782
|
|
|
3379
|
-
|
|
3380
|
-
|
|
3381
|
-
|
|
3382
|
-
|
|
3783
|
+
wsp_ggml_backend_sched_free(state->sched_conv.sched);
|
|
3784
|
+
wsp_ggml_backend_sched_free(state->sched_encode.sched);
|
|
3785
|
+
wsp_ggml_backend_sched_free(state->sched_cross.sched);
|
|
3786
|
+
wsp_ggml_backend_sched_free(state->sched_decode.sched);
|
|
3787
|
+
|
|
3788
|
+
for (auto & backend : state->backends) {
|
|
3789
|
+
wsp_ggml_backend_free(backend);
|
|
3790
|
+
}
|
|
3383
3791
|
|
|
3384
|
-
|
|
3792
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
3793
|
+
aheads_masks_free(state->aheads_masks);
|
|
3385
3794
|
|
|
3386
3795
|
delete state;
|
|
3387
3796
|
}
|
|
@@ -3389,18 +3798,12 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3389
3798
|
|
|
3390
3799
|
void whisper_free(struct whisper_context * ctx) {
|
|
3391
3800
|
if (ctx) {
|
|
3392
|
-
|
|
3393
|
-
wsp_ggml_free(ctx->model.ctx);
|
|
3394
|
-
}
|
|
3801
|
+
wsp_ggml_free(ctx->model.ctx);
|
|
3395
3802
|
|
|
3396
|
-
|
|
3397
|
-
wsp_ggml_backend_buffer_free(ctx->model.buffer);
|
|
3398
|
-
}
|
|
3803
|
+
wsp_ggml_backend_buffer_free(ctx->model.buffer);
|
|
3399
3804
|
|
|
3400
3805
|
whisper_free_state(ctx->state);
|
|
3401
3806
|
|
|
3402
|
-
wsp_ggml_backend_free(ctx->backend);
|
|
3403
|
-
|
|
3404
3807
|
delete ctx;
|
|
3405
3808
|
}
|
|
3406
3809
|
}
|
|
@@ -3430,30 +3833,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
|
3430
3833
|
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
3431
3834
|
}
|
|
3432
3835
|
|
|
3433
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3434
|
-
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3435
|
-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3436
|
-
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3437
|
-
return -1;
|
|
3438
|
-
}
|
|
3439
|
-
|
|
3440
|
-
return 0;
|
|
3441
|
-
}
|
|
3442
|
-
|
|
3443
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3444
|
-
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
|
3445
|
-
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
3446
|
-
}
|
|
3447
|
-
|
|
3448
|
-
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
|
|
3449
|
-
// TODO
|
|
3450
|
-
|
|
3451
|
-
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
|
|
3452
|
-
// TODO
|
|
3453
|
-
|
|
3454
|
-
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
|
|
3455
|
-
// TODO
|
|
3456
|
-
|
|
3457
3836
|
int whisper_set_mel_with_state(
|
|
3458
3837
|
struct whisper_context * ctx,
|
|
3459
3838
|
struct whisper_state * state,
|
|
@@ -3506,7 +3885,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3506
3885
|
|
|
3507
3886
|
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
|
3508
3887
|
|
|
3509
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
|
|
3888
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
|
|
3510
3889
|
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3511
3890
|
return 1;
|
|
3512
3891
|
}
|
|
@@ -3528,7 +3907,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
|
3528
3907
|
|
|
3529
3908
|
if (n_max_tokens < (int) res.size()) {
|
|
3530
3909
|
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
3531
|
-
return -
|
|
3910
|
+
return -(int) res.size();
|
|
3532
3911
|
}
|
|
3533
3912
|
|
|
3534
3913
|
for (int i = 0; i < (int) res.size(); i++) {
|
|
@@ -3538,7 +3917,11 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
|
|
|
3538
3917
|
return res.size();
|
|
3539
3918
|
}
|
|
3540
3919
|
|
|
3541
|
-
int
|
|
3920
|
+
int whisper_token_count(struct whisper_context * ctx, const char * text) {
|
|
3921
|
+
return -whisper_tokenize(ctx, text, NULL, 0);
|
|
3922
|
+
}
|
|
3923
|
+
|
|
3924
|
+
int whisper_lang_max_id(void) {
|
|
3542
3925
|
auto max_id = 0;
|
|
3543
3926
|
for (const auto & kv : g_lang) {
|
|
3544
3927
|
max_id = std::max(max_id, kv.second.first);
|
|
@@ -3807,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
|
|
|
3807
4190
|
return ctx->vocab.token_transcribe;
|
|
3808
4191
|
}
|
|
3809
4192
|
|
|
4193
|
+
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
|
|
4194
|
+
if (ctx->state == nullptr) {
|
|
4195
|
+
return nullptr;
|
|
4196
|
+
}
|
|
4197
|
+
return new whisper_timings {
|
|
4198
|
+
.load_us = ctx->t_load_us,
|
|
4199
|
+
.t_start_us = ctx->t_start_us,
|
|
4200
|
+
.fail_p = ctx->state->n_fail_p,
|
|
4201
|
+
.fail_h = ctx->state->n_fail_h,
|
|
4202
|
+
.t_mel_us = ctx->state->t_mel_us,
|
|
4203
|
+
.n_sample = ctx->state->n_sample,
|
|
4204
|
+
.n_encode = ctx->state->n_encode,
|
|
4205
|
+
.n_decode = ctx->state->n_decode,
|
|
4206
|
+
.n_batchd = ctx->state->n_batchd,
|
|
4207
|
+
.n_prompt = ctx->state->n_prompt,
|
|
4208
|
+
.t_sample_us = ctx->state->t_sample_us,
|
|
4209
|
+
.t_encode_us = ctx->state->t_encode_us,
|
|
4210
|
+
.t_decode_us = ctx->state->t_decode_us,
|
|
4211
|
+
.t_batchd_us = ctx->state->t_batchd_us,
|
|
4212
|
+
.t_prompt_us = ctx->state->t_prompt_us,
|
|
4213
|
+
};
|
|
4214
|
+
}
|
|
4215
|
+
|
|
3810
4216
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
3811
4217
|
const int64_t t_end_us = wsp_ggml_time_us();
|
|
4218
|
+
const struct whisper_timings * timings = whisper_get_timings(ctx);
|
|
3812
4219
|
|
|
3813
4220
|
WHISPER_LOG_INFO("\n");
|
|
3814
|
-
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__,
|
|
4221
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
|
|
3815
4222
|
if (ctx->state != nullptr) {
|
|
3816
|
-
|
|
3817
4223
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3818
4224
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3819
4225
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3820
4226
|
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
|
3821
4227
|
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3822
4228
|
|
|
3823
|
-
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__,
|
|
3824
|
-
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__,
|
|
3825
|
-
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3826
|
-
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3827
|
-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3828
|
-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
3829
|
-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4229
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
|
|
4230
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
|
|
4231
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
|
|
4232
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
|
|
4233
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
|
|
4234
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
|
|
4235
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
|
|
3830
4236
|
}
|
|
3831
|
-
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us -
|
|
4237
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
|
|
3832
4238
|
}
|
|
3833
4239
|
|
|
3834
4240
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
@@ -3838,6 +4244,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3838
4244
|
ctx->state->t_sample_us = 0;
|
|
3839
4245
|
ctx->state->t_encode_us = 0;
|
|
3840
4246
|
ctx->state->t_decode_us = 0;
|
|
4247
|
+
ctx->state->t_batchd_us = 0;
|
|
3841
4248
|
ctx->state->t_prompt_us = 0;
|
|
3842
4249
|
ctx->state->n_sample = 0;
|
|
3843
4250
|
ctx->state->n_encode = 0;
|
|
@@ -3881,10 +4288,10 @@ const char * whisper_print_system_info(void) {
|
|
|
3881
4288
|
s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
|
|
3882
4289
|
s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
|
|
3883
4290
|
s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
|
|
3884
|
-
s += "CUDA = " + std::to_string(
|
|
4291
|
+
s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
|
|
3885
4292
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
3886
4293
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
|
3887
|
-
|
|
4294
|
+
s += "CANN = " + std::to_string(wsp_ggml_cpu_has_cann()) ;
|
|
3888
4295
|
return s.c_str();
|
|
3889
4296
|
}
|
|
3890
4297
|
|
|
@@ -3894,7 +4301,7 @@ const char * whisper_print_system_info(void) {
|
|
|
3894
4301
|
|
|
3895
4302
|
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
|
3896
4303
|
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
|
3897
|
-
std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
4304
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
3898
4305
|
const char * src,
|
|
3899
4306
|
whisper_partial_utf8 partial_start) {
|
|
3900
4307
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
|
@@ -4308,7 +4715,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
|
|
|
4308
4715
|
|
|
4309
4716
|
////////////////////////////////////////////////////////////////////////////
|
|
4310
4717
|
|
|
4311
|
-
struct whisper_context_params * whisper_context_default_params_by_ref() {
|
|
4718
|
+
struct whisper_context_params * whisper_context_default_params_by_ref(void) {
|
|
4312
4719
|
struct whisper_context_params params = whisper_context_default_params();
|
|
4313
4720
|
|
|
4314
4721
|
struct whisper_context_params* result = new whisper_context_params();
|
|
@@ -4349,12 +4756,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
4349
4756
|
/*.split_on_word =*/ false,
|
|
4350
4757
|
/*.max_tokens =*/ 0,
|
|
4351
4758
|
|
|
4352
|
-
/*.speed_up =*/ false,
|
|
4353
4759
|
/*.debug_mode =*/ false,
|
|
4354
4760
|
/*.audio_ctx =*/ 0,
|
|
4355
4761
|
|
|
4356
4762
|
/*.tdrz_enable =*/ false,
|
|
4357
4763
|
|
|
4764
|
+
/* suppress_regex =*/ nullptr,
|
|
4765
|
+
|
|
4358
4766
|
/*.initial_prompt =*/ nullptr,
|
|
4359
4767
|
/*.prompt_tokens =*/ nullptr,
|
|
4360
4768
|
/*.prompt_n_tokens =*/ 0,
|
|
@@ -4440,6 +4848,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
|
|
|
4440
4848
|
return txt[0] == ' ';
|
|
4441
4849
|
}
|
|
4442
4850
|
|
|
4851
|
+
static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
4852
|
+
struct whisper_context * ctx,
|
|
4853
|
+
struct whisper_state * state,
|
|
4854
|
+
struct whisper_full_params params,
|
|
4855
|
+
int i_segment,
|
|
4856
|
+
size_t n_segments,
|
|
4857
|
+
int seek,
|
|
4858
|
+
int n_frames,
|
|
4859
|
+
int medfilt_width,
|
|
4860
|
+
int n_threads);
|
|
4861
|
+
|
|
4443
4862
|
// wrap the last segment to max_len characters
|
|
4444
4863
|
// returns the number of new segments
|
|
4445
4864
|
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
|
@@ -4587,6 +5006,17 @@ static void whisper_process_logits(
|
|
|
4587
5006
|
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
4588
5007
|
}
|
|
4589
5008
|
|
|
5009
|
+
// suppress any tokens matching a regular expression
|
|
5010
|
+
// ref: https://github.com/openai/whisper/discussions/1041
|
|
5011
|
+
if (params.suppress_regex != nullptr) {
|
|
5012
|
+
std::regex re(params.suppress_regex);
|
|
5013
|
+
for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
|
|
5014
|
+
if (std::regex_match(token_id.first, re)) {
|
|
5015
|
+
logits[token_id.second] = -INFINITY;
|
|
5016
|
+
}
|
|
5017
|
+
}
|
|
5018
|
+
}
|
|
5019
|
+
|
|
4590
5020
|
// suppress non-speech tokens
|
|
4591
5021
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
4592
5022
|
if (params.suppress_non_speech_tokens) {
|
|
@@ -4790,12 +5220,25 @@ static void whisper_process_logits(
|
|
|
4790
5220
|
#endif
|
|
4791
5221
|
}
|
|
4792
5222
|
|
|
5223
|
+
static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
|
|
5224
|
+
if (a.tokens.size() != b.tokens.size()) {
|
|
5225
|
+
return false;
|
|
5226
|
+
}
|
|
5227
|
+
// sequences are more likely to diverge at the end
|
|
5228
|
+
for (int i = a.tokens.size() - 1; i >= 0; i--) {
|
|
5229
|
+
if (a.tokens[i].id != b.tokens[i].id) {
|
|
5230
|
+
return false;
|
|
5231
|
+
}
|
|
5232
|
+
}
|
|
5233
|
+
return true;
|
|
5234
|
+
}
|
|
5235
|
+
|
|
4793
5236
|
static whisper_token_data whisper_sample_token(
|
|
4794
5237
|
whisper_context & ctx,
|
|
4795
5238
|
const whisper_decoder & decoder,
|
|
4796
5239
|
bool best) {
|
|
4797
5240
|
whisper_token_data result = {
|
|
4798
|
-
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
|
5241
|
+
0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
|
|
4799
5242
|
};
|
|
4800
5243
|
|
|
4801
5244
|
const auto & vocab = ctx.vocab;
|
|
@@ -4913,7 +5356,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4913
5356
|
const auto id = dist(decoder.rng);
|
|
4914
5357
|
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
|
4915
5358
|
|
|
4916
|
-
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
|
5359
|
+
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
|
|
4917
5360
|
|
|
4918
5361
|
if (result[i].id >= vocab.token_beg) {
|
|
4919
5362
|
result[i].tid = result[i].id;
|
|
@@ -4966,7 +5409,7 @@ static void whisper_sequence_score(
|
|
|
4966
5409
|
const auto p = kv.second/(double)cnt;
|
|
4967
5410
|
entropy -= p*log(p);
|
|
4968
5411
|
|
|
4969
|
-
//
|
|
5412
|
+
//WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
|
|
4970
5413
|
}
|
|
4971
5414
|
|
|
4972
5415
|
sequence.entropy = entropy;
|
|
@@ -4986,15 +5429,9 @@ int whisper_full_with_state(
|
|
|
4986
5429
|
|
|
4987
5430
|
if (n_samples > 0) {
|
|
4988
5431
|
// compute log mel spectrogram
|
|
4989
|
-
if (params.
|
|
4990
|
-
// TODO: Replace PV with more advanced algorithm
|
|
5432
|
+
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
4991
5433
|
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4992
|
-
return -
|
|
4993
|
-
} else {
|
|
4994
|
-
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
4995
|
-
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4996
|
-
return -2;
|
|
4997
|
-
}
|
|
5434
|
+
return -2;
|
|
4998
5435
|
}
|
|
4999
5436
|
}
|
|
5000
5437
|
|
|
@@ -5031,8 +5468,8 @@ int whisper_full_with_state(
|
|
|
5031
5468
|
// if length of spectrogram is less than 1.0s (100 frames), then return
|
|
5032
5469
|
// basically don't process anything that is less than 1.0s
|
|
5033
5470
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
|
5034
|
-
if (seek_end < seek_start +
|
|
5035
|
-
|
|
5471
|
+
if (seek_end < seek_start + 100) {
|
|
5472
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
|
5036
5473
|
return 0;
|
|
5037
5474
|
}
|
|
5038
5475
|
|
|
@@ -5095,7 +5532,12 @@ int whisper_full_with_state(
|
|
|
5095
5532
|
// initial prompt
|
|
5096
5533
|
if (!params.prompt_tokens && params.initial_prompt) {
|
|
5097
5534
|
prompt_tokens.resize(1024);
|
|
5098
|
-
|
|
5535
|
+
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
|
5536
|
+
if (n_needed < 0) {
|
|
5537
|
+
prompt_tokens.resize(-n_needed);
|
|
5538
|
+
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
|
5539
|
+
}
|
|
5540
|
+
prompt_tokens.resize(n_needed);
|
|
5099
5541
|
params.prompt_tokens = prompt_tokens.data();
|
|
5100
5542
|
params.prompt_n_tokens = prompt_tokens.size();
|
|
5101
5543
|
}
|
|
@@ -5131,11 +5573,11 @@ int whisper_full_with_state(
|
|
|
5131
5573
|
}
|
|
5132
5574
|
}
|
|
5133
5575
|
|
|
5134
|
-
// distilled models require the "no_timestamps" token
|
|
5576
|
+
// first release distilled models require the "no_timestamps" token
|
|
5135
5577
|
{
|
|
5136
|
-
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
|
|
5578
|
+
const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
|
|
5137
5579
|
if (is_distil && !params.no_timestamps) {
|
|
5138
|
-
WHISPER_LOG_WARN("%s: using distilled
|
|
5580
|
+
WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
|
|
5139
5581
|
params.no_timestamps = true;
|
|
5140
5582
|
}
|
|
5141
5583
|
}
|
|
@@ -5221,7 +5663,7 @@ int whisper_full_with_state(
|
|
|
5221
5663
|
|
|
5222
5664
|
n_decoders_cur = std::max(1, n_decoders_cur);
|
|
5223
5665
|
|
|
5224
|
-
|
|
5666
|
+
WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
|
5225
5667
|
|
|
5226
5668
|
// TAGS: WHISPER_DECODER_INIT
|
|
5227
5669
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
@@ -5265,19 +5707,40 @@ int whisper_full_with_state(
|
|
|
5265
5707
|
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
|
5266
5708
|
|
|
5267
5709
|
// print the prompt
|
|
5268
|
-
|
|
5710
|
+
WHISPER_LOG_DEBUG("\n\n");
|
|
5269
5711
|
for (int i = 0; i < (int) prompt.size(); i++) {
|
|
5270
|
-
|
|
5712
|
+
WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
|
|
5713
|
+
}
|
|
5714
|
+
WHISPER_LOG_DEBUG("\n\n");
|
|
5715
|
+
|
|
5716
|
+
// recreate the KV cache if the number of decoders has changed
|
|
5717
|
+
if (state->kv_self_n_dec < n_decoders_cur) {
|
|
5718
|
+
WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
|
|
5719
|
+
|
|
5720
|
+
whisper_kv_cache_free(state->kv_self);
|
|
5721
|
+
|
|
5722
|
+
// overallocate to workaround KV cache fragmentation issues
|
|
5723
|
+
const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
|
|
5724
|
+
|
|
5725
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
|
5726
|
+
ctx->model.hparams.n_text_state,
|
|
5727
|
+
ctx->model.hparams.n_text_layer,
|
|
5728
|
+
WSP_GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
|
5729
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
|
5730
|
+
whisper_free_state(state);
|
|
5731
|
+
return -7;
|
|
5732
|
+
}
|
|
5733
|
+
|
|
5734
|
+
state->kv_self_n_dec = n_decoders_cur;
|
|
5271
5735
|
}
|
|
5272
|
-
WHISPER_PRINT_DEBUG("\n\n");
|
|
5273
5736
|
|
|
5274
5737
|
whisper_kv_cache_clear(state->kv_self);
|
|
5275
5738
|
|
|
5276
5739
|
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
|
5277
5740
|
|
|
5278
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5741
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
|
5279
5742
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
5280
|
-
return -
|
|
5743
|
+
return -8;
|
|
5281
5744
|
}
|
|
5282
5745
|
|
|
5283
5746
|
{
|
|
@@ -5388,7 +5851,10 @@ int whisper_full_with_state(
|
|
|
5388
5851
|
beam_candidates.begin(),
|
|
5389
5852
|
beam_candidates.end(),
|
|
5390
5853
|
[](const beam_candidate & a, const beam_candidate & b) {
|
|
5391
|
-
|
|
5854
|
+
if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) {
|
|
5855
|
+
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
|
5856
|
+
}
|
|
5857
|
+
return a.decoder_idx < b.decoder_idx;
|
|
5392
5858
|
});
|
|
5393
5859
|
|
|
5394
5860
|
uint32_t cur_c = 0;
|
|
@@ -5406,7 +5872,7 @@ int whisper_full_with_state(
|
|
|
5406
5872
|
|
|
5407
5873
|
auto & cur = beam_candidates[cur_c++];
|
|
5408
5874
|
|
|
5409
|
-
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence
|
|
5875
|
+
while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
|
|
5410
5876
|
++cur_c;
|
|
5411
5877
|
}
|
|
5412
5878
|
|
|
@@ -5417,7 +5883,7 @@ int whisper_full_with_state(
|
|
|
5417
5883
|
|
|
5418
5884
|
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
|
5419
5885
|
|
|
5420
|
-
|
|
5886
|
+
WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
5421
5887
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
5422
5888
|
}
|
|
5423
5889
|
|
|
@@ -5460,7 +5926,7 @@ int whisper_full_with_state(
|
|
|
5460
5926
|
|
|
5461
5927
|
// do not allow to go back in time
|
|
5462
5928
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
5463
|
-
|
|
5929
|
+
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
|
|
5464
5930
|
failed = true; // TODO: maybe this is not a failure ?
|
|
5465
5931
|
continue;
|
|
5466
5932
|
}
|
|
@@ -5475,7 +5941,7 @@ int whisper_full_with_state(
|
|
|
5475
5941
|
#ifdef WHISPER_DEBUG
|
|
5476
5942
|
{
|
|
5477
5943
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
5478
|
-
|
|
5944
|
+
WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
|
5479
5945
|
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
|
5480
5946
|
}
|
|
5481
5947
|
#endif
|
|
@@ -5485,22 +5951,22 @@ int whisper_full_with_state(
|
|
|
5485
5951
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
5486
5952
|
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
|
5487
5953
|
) {
|
|
5488
|
-
if (result_len == 0) {
|
|
5954
|
+
if (result_len == 0 && !params.no_timestamps) {
|
|
5489
5955
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
5490
5956
|
result_len = i + 1;
|
|
5491
5957
|
} else {
|
|
5492
|
-
|
|
5958
|
+
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
|
5493
5959
|
failed = true;
|
|
5494
5960
|
continue;
|
|
5495
5961
|
}
|
|
5496
5962
|
}
|
|
5497
5963
|
|
|
5498
|
-
if (params.single_segment) {
|
|
5964
|
+
if (params.single_segment || params.no_timestamps) {
|
|
5499
5965
|
result_len = i + 1;
|
|
5500
5966
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
5501
5967
|
}
|
|
5502
5968
|
|
|
5503
|
-
|
|
5969
|
+
WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j);
|
|
5504
5970
|
completed = true;
|
|
5505
5971
|
continue;
|
|
5506
5972
|
}
|
|
@@ -5516,7 +5982,7 @@ int whisper_full_with_state(
|
|
|
5516
5982
|
// sometimes, the decoding can get stuck in a repetition loop
|
|
5517
5983
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
5518
5984
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
5519
|
-
|
|
5985
|
+
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
|
|
5520
5986
|
failed = true;
|
|
5521
5987
|
continue;
|
|
5522
5988
|
}
|
|
@@ -5558,7 +6024,7 @@ int whisper_full_with_state(
|
|
|
5558
6024
|
continue;
|
|
5559
6025
|
}
|
|
5560
6026
|
|
|
5561
|
-
//
|
|
6027
|
+
//WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
|
5562
6028
|
|
|
5563
6029
|
decoder.i_batch = batch.n_tokens;
|
|
5564
6030
|
|
|
@@ -5572,9 +6038,9 @@ int whisper_full_with_state(
|
|
|
5572
6038
|
|
|
5573
6039
|
assert(batch.n_tokens > 0);
|
|
5574
6040
|
|
|
5575
|
-
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
6041
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
|
5576
6042
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
5577
|
-
return -
|
|
6043
|
+
return -9;
|
|
5578
6044
|
}
|
|
5579
6045
|
|
|
5580
6046
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
@@ -5638,11 +6104,11 @@ int whisper_full_with_state(
|
|
|
5638
6104
|
decoder.sequence.tokens.resize(decoder.sequence.result_len);
|
|
5639
6105
|
whisper_sequence_score(params, decoder.sequence);
|
|
5640
6106
|
|
|
5641
|
-
|
|
6107
|
+
WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
|
5642
6108
|
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
|
5643
6109
|
|
|
5644
6110
|
if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
|
|
5645
|
-
|
|
6111
|
+
WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
|
5646
6112
|
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
|
5647
6113
|
|
|
5648
6114
|
decoder.failed = true;
|
|
@@ -5657,7 +6123,7 @@ int whisper_full_with_state(
|
|
|
5657
6123
|
}
|
|
5658
6124
|
}
|
|
5659
6125
|
|
|
5660
|
-
|
|
6126
|
+
WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
5661
6127
|
}
|
|
5662
6128
|
|
|
5663
6129
|
bool success = true;
|
|
@@ -5669,7 +6135,7 @@ int whisper_full_with_state(
|
|
|
5669
6135
|
const auto & decoder = state->decoders[best_decoder_id];
|
|
5670
6136
|
|
|
5671
6137
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
|
5672
|
-
|
|
6138
|
+
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
|
5673
6139
|
success = false;
|
|
5674
6140
|
state->n_fail_p++;
|
|
5675
6141
|
}
|
|
@@ -5677,13 +6143,13 @@ int whisper_full_with_state(
|
|
|
5677
6143
|
|
|
5678
6144
|
if (success) {
|
|
5679
6145
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
5680
|
-
//
|
|
6146
|
+
// WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
5681
6147
|
//}
|
|
5682
6148
|
|
|
5683
6149
|
break;
|
|
5684
6150
|
}
|
|
5685
6151
|
|
|
5686
|
-
|
|
6152
|
+
WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
5687
6153
|
}
|
|
5688
6154
|
|
|
5689
6155
|
// output results through a user-provided callback
|
|
@@ -5695,7 +6161,10 @@ int whisper_full_with_state(
|
|
|
5695
6161
|
|
|
5696
6162
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
|
5697
6163
|
|
|
5698
|
-
//
|
|
6164
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
6165
|
+
const auto n_segments_before = state->result_all.size();
|
|
6166
|
+
|
|
6167
|
+
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
|
5699
6168
|
|
|
5700
6169
|
// update prompt_past
|
|
5701
6170
|
prompt_past.clear();
|
|
@@ -5732,8 +6201,8 @@ int whisper_full_with_state(
|
|
|
5732
6201
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
|
5733
6202
|
|
|
5734
6203
|
if (!text.empty()) {
|
|
5735
|
-
const auto tt0 =
|
|
5736
|
-
const auto tt1 =
|
|
6204
|
+
const auto tt0 = t0;
|
|
6205
|
+
const auto tt1 = t1;
|
|
5737
6206
|
|
|
5738
6207
|
if (params.print_realtime) {
|
|
5739
6208
|
if (params.print_timestamps) {
|
|
@@ -5761,7 +6230,7 @@ int whisper_full_with_state(
|
|
|
5761
6230
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
5762
6231
|
}
|
|
5763
6232
|
}
|
|
5764
|
-
if (params.new_segment_callback) {
|
|
6233
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
|
5765
6234
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
5766
6235
|
}
|
|
5767
6236
|
}
|
|
@@ -5779,8 +6248,8 @@ int whisper_full_with_state(
|
|
|
5779
6248
|
if (!text.empty()) {
|
|
5780
6249
|
const auto t1 = seek + seek_delta;
|
|
5781
6250
|
|
|
5782
|
-
const auto tt0 =
|
|
5783
|
-
const auto tt1 =
|
|
6251
|
+
const auto tt0 = t0;
|
|
6252
|
+
const auto tt1 = t1;
|
|
5784
6253
|
|
|
5785
6254
|
if (params.print_realtime) {
|
|
5786
6255
|
if (params.print_timestamps) {
|
|
@@ -5806,16 +6275,32 @@ int whisper_full_with_state(
|
|
|
5806
6275
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
5807
6276
|
}
|
|
5808
6277
|
}
|
|
5809
|
-
if (params.new_segment_callback) {
|
|
6278
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
|
5810
6279
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
5811
6280
|
}
|
|
5812
6281
|
}
|
|
5813
6282
|
}
|
|
5814
6283
|
|
|
6284
|
+
// FIXME: will timestamp offsets be correct?
|
|
6285
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
6286
|
+
{
|
|
6287
|
+
const int n_segments = state->result_all.size() - n_segments_before;
|
|
6288
|
+
if (ctx->params.dtw_token_timestamps && n_segments) {
|
|
6289
|
+
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
|
|
6290
|
+
whisper_exp_compute_token_level_timestamps_dtw(
|
|
6291
|
+
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
|
|
6292
|
+
if (params.new_segment_callback) {
|
|
6293
|
+
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
|
|
6294
|
+
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
|
|
6295
|
+
}
|
|
6296
|
+
}
|
|
6297
|
+
}
|
|
6298
|
+
}
|
|
6299
|
+
|
|
5815
6300
|
// update audio window
|
|
5816
6301
|
seek += seek_delta;
|
|
5817
6302
|
|
|
5818
|
-
|
|
6303
|
+
WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
|
|
5819
6304
|
}
|
|
5820
6305
|
}
|
|
5821
6306
|
|
|
@@ -6132,7 +6617,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
6132
6617
|
|
|
6133
6618
|
// multi-thread
|
|
6134
6619
|
|
|
6135
|
-
for (
|
|
6620
|
+
for (int32_t k = 1; k <= n_threads; k++) {
|
|
6136
6621
|
char * src = (char *) malloc(size);
|
|
6137
6622
|
char * dst = (char *) malloc(size);
|
|
6138
6623
|
|
|
@@ -6156,13 +6641,13 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
6156
6641
|
const int64_t t0 = wsp_ggml_time_us();
|
|
6157
6642
|
|
|
6158
6643
|
std::vector<std::thread> threads(k - 1);
|
|
6159
|
-
for (
|
|
6644
|
+
for (int32_t th = 0; th < k - 1; ++th) {
|
|
6160
6645
|
threads[th] = std::thread(helper, th);
|
|
6161
6646
|
}
|
|
6162
6647
|
|
|
6163
6648
|
helper(k - 1);
|
|
6164
6649
|
|
|
6165
|
-
for (
|
|
6650
|
+
for (int32_t th = 0; th < k - 1; ++th) {
|
|
6166
6651
|
threads[th].join();
|
|
6167
6652
|
}
|
|
6168
6653
|
|
|
@@ -6571,7 +7056,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
6571
7056
|
k++;
|
|
6572
7057
|
}
|
|
6573
7058
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
6574
|
-
if (j <
|
|
7059
|
+
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
|
6575
7060
|
tokens[j].t1 = tokens[j + 1].t0;
|
|
6576
7061
|
} else {
|
|
6577
7062
|
s1 = k;
|
|
@@ -6614,6 +7099,322 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
6614
7099
|
//}
|
|
6615
7100
|
}
|
|
6616
7101
|
|
|
7102
|
+
//
|
|
7103
|
+
// token level timestamps - dtw version
|
|
7104
|
+
//
|
|
7105
|
+
|
|
7106
|
+
// n_text_layer -> total text layers on model
|
|
7107
|
+
// n_head -> total heads per text layer on model
|
|
7108
|
+
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
|
|
7109
|
+
std::vector<uint32_t> ret;
|
|
7110
|
+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
|
7111
|
+
return ret;
|
|
7112
|
+
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
|
|
7113
|
+
if (il >= n_text_layer - cparams.dtw_n_top) {
|
|
7114
|
+
for (int32_t i = 0; i < n_head; ++i) {
|
|
7115
|
+
ret.push_back(i);
|
|
7116
|
+
}
|
|
7117
|
+
}
|
|
7118
|
+
} else {
|
|
7119
|
+
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
|
|
7120
|
+
for (size_t i = 0; i < aheads.n_heads; ++i) {
|
|
7121
|
+
if (aheads.heads[i].n_text_layer == il) {
|
|
7122
|
+
ret.push_back(aheads.heads[i].n_head);
|
|
7123
|
+
}
|
|
7124
|
+
}
|
|
7125
|
+
}
|
|
7126
|
+
return ret;
|
|
7127
|
+
}
|
|
7128
|
+
|
|
7129
|
+
// dtw + backtrace to return found path
|
|
7130
|
+
// based on
|
|
7131
|
+
// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
|
|
7132
|
+
static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tensor * x) {
|
|
7133
|
+
WHISPER_ASSERT(wsp_ggml_n_dims(x) == 2);
|
|
7134
|
+
|
|
7135
|
+
int64_t N = x->ne[0];
|
|
7136
|
+
int64_t M = x->ne[1];
|
|
7137
|
+
struct wsp_ggml_tensor * cost = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
|
|
7138
|
+
struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
|
|
7139
|
+
|
|
7140
|
+
cost = wsp_ggml_set_f32(cost, INFINITY);
|
|
7141
|
+
trace = wsp_ggml_set_f32(trace, -1);
|
|
7142
|
+
wsp_ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
|
|
7143
|
+
|
|
7144
|
+
// dtw
|
|
7145
|
+
// supposedly can be optmized by computing diagonals in parallel ?
|
|
7146
|
+
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
|
|
7147
|
+
for (int64_t j = 1; j < M + 1; ++j) {
|
|
7148
|
+
for (int64_t i = 1; i < N + 1; ++i) {
|
|
7149
|
+
float c0 = wsp_ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
|
|
7150
|
+
float c1 = wsp_ggml_get_f32_nd(cost, i - 1, j, 0, 0);
|
|
7151
|
+
float c2 = wsp_ggml_get_f32_nd(cost, i, j - 1, 0, 0);
|
|
7152
|
+
|
|
7153
|
+
float c;
|
|
7154
|
+
int32_t t;
|
|
7155
|
+
if (c0 < c1 && c0 < c2) {
|
|
7156
|
+
c = c0;
|
|
7157
|
+
t = 0;
|
|
7158
|
+
} else if (c1 < c0 && c1 < c2) {
|
|
7159
|
+
c = c1;
|
|
7160
|
+
t = 1;
|
|
7161
|
+
} else {
|
|
7162
|
+
c = c2;
|
|
7163
|
+
t = 2;
|
|
7164
|
+
}
|
|
7165
|
+
|
|
7166
|
+
c = wsp_ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
|
|
7167
|
+
wsp_ggml_set_f32_nd(cost, i, j, 0, 0, c);
|
|
7168
|
+
wsp_ggml_set_i32_nd(trace, i, j, 0, 0, t);
|
|
7169
|
+
}
|
|
7170
|
+
}
|
|
7171
|
+
|
|
7172
|
+
// Backtrace
|
|
7173
|
+
const int64_t BT_MAX_ROWS = N + M - 1;
|
|
7174
|
+
struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
|
|
7175
|
+
// trace[0, :] = 2;
|
|
7176
|
+
for (int64_t i = 0; i < M + 1; ++i)
|
|
7177
|
+
wsp_ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
|
|
7178
|
+
//trace[:, 0] = 1;
|
|
7179
|
+
for (int64_t i = 0; i < N + 1; ++i)
|
|
7180
|
+
wsp_ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
|
|
7181
|
+
int bt_row_idx = BT_MAX_ROWS - 1;
|
|
7182
|
+
int64_t i = N;
|
|
7183
|
+
int64_t j = M;
|
|
7184
|
+
while (i > 0 || j > 0) {
|
|
7185
|
+
wsp_ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
|
|
7186
|
+
wsp_ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
|
|
7187
|
+
--bt_row_idx;
|
|
7188
|
+
|
|
7189
|
+
int32_t t = wsp_ggml_get_i32_nd(trace, i, j, 0, 0);
|
|
7190
|
+
if (t == 0) {
|
|
7191
|
+
--i;
|
|
7192
|
+
--j;
|
|
7193
|
+
} else if (t == 1) {
|
|
7194
|
+
--i;
|
|
7195
|
+
} else if (t == 2) {
|
|
7196
|
+
--j;
|
|
7197
|
+
} else {
|
|
7198
|
+
WHISPER_ASSERT(0);
|
|
7199
|
+
}
|
|
7200
|
+
}
|
|
7201
|
+
|
|
7202
|
+
// FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
|
|
7203
|
+
// Clip + transpose
|
|
7204
|
+
// This might not be entirely necessary for our case, but leaving it for now so output matrix
|
|
7205
|
+
// is identical to dtw on openAI timing.py
|
|
7206
|
+
const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
|
|
7207
|
+
wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
|
|
7208
|
+
for (int64_t i = 0; i < 2; ++i) {
|
|
7209
|
+
for (int64_t j = 0; j < result_n_cols; ++j) {
|
|
7210
|
+
int32_t v = wsp_ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
|
|
7211
|
+
wsp_ggml_set_i32_nd(r, i, j, 0, 0, v);
|
|
7212
|
+
}
|
|
7213
|
+
}
|
|
7214
|
+
|
|
7215
|
+
return r;
|
|
7216
|
+
}
|
|
7217
|
+
|
|
7218
|
+
struct median_filter_user_data {
|
|
7219
|
+
int filter_width;
|
|
7220
|
+
};
|
|
7221
|
+
|
|
7222
|
+
static void median_filter(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
|
|
7223
|
+
if (ith != 0) {
|
|
7224
|
+
return;
|
|
7225
|
+
}
|
|
7226
|
+
int filter_width = ((median_filter_user_data *) userdata)->filter_width;
|
|
7227
|
+
WHISPER_ASSERT(filter_width < a->ne[2]);
|
|
7228
|
+
WHISPER_ASSERT(filter_width % 2);
|
|
7229
|
+
WHISPER_ASSERT(wsp_ggml_n_dims(a) == 3);
|
|
7230
|
+
WHISPER_ASSERT(a->type == WSP_GGML_TYPE_F32);
|
|
7231
|
+
|
|
7232
|
+
std::vector<float> filter;
|
|
7233
|
+
filter.reserve(filter_width);
|
|
7234
|
+
for (int64_t i = 0; i < a->ne[0]; ++i) {
|
|
7235
|
+
for (int64_t j = 0; j < a->ne[1]; ++j) {
|
|
7236
|
+
for (int64_t k = 0; k < a->ne[2]; ++k) {
|
|
7237
|
+
for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
|
|
7238
|
+
// "reflect" padding
|
|
7239
|
+
int64_t idx = k + off;
|
|
7240
|
+
if (idx < 0) {
|
|
7241
|
+
idx = -idx;
|
|
7242
|
+
} else if (idx >= a->ne[2]) {
|
|
7243
|
+
idx = 2*(a->ne[2] - 1) - idx;
|
|
7244
|
+
}
|
|
7245
|
+
|
|
7246
|
+
filter.push_back(wsp_ggml_get_f32_nd(a, i, j, idx, 0));
|
|
7247
|
+
}
|
|
7248
|
+
std::sort(filter.begin(), filter.end());
|
|
7249
|
+
const float v = filter[filter.size()/2];
|
|
7250
|
+
wsp_ggml_set_f32_nd(dst, i, j, k, 0, v);
|
|
7251
|
+
filter.clear();
|
|
7252
|
+
}
|
|
7253
|
+
}
|
|
7254
|
+
}
|
|
7255
|
+
}
|
|
7256
|
+
|
|
7257
|
+
static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7258
|
+
struct whisper_context * ctx,
|
|
7259
|
+
struct whisper_state * state,
|
|
7260
|
+
struct whisper_full_params params,
|
|
7261
|
+
int i_segment,
|
|
7262
|
+
size_t n_segments,
|
|
7263
|
+
int seek,
|
|
7264
|
+
int n_frames,
|
|
7265
|
+
int medfilt_width,
|
|
7266
|
+
int n_threads)
|
|
7267
|
+
{
|
|
7268
|
+
const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
|
|
7269
|
+
WHISPER_ASSERT(medfilt_width % 2);
|
|
7270
|
+
WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
|
|
7271
|
+
WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
|
|
7272
|
+
|
|
7273
|
+
// FIXME: Allocating mem everytime we call this func
|
|
7274
|
+
// Our ggml buffer should be pre-allocated somewhere during init and reused
|
|
7275
|
+
// when we call this function
|
|
7276
|
+
struct wsp_ggml_init_params gparams = {
|
|
7277
|
+
/*.mem_size =*/ ctx->params.dtw_mem_size,
|
|
7278
|
+
/*.mem_buffer =*/ NULL,
|
|
7279
|
+
/*.no_alloc =*/ false,
|
|
7280
|
+
};
|
|
7281
|
+
struct wsp_ggml_context * gctx = wsp_ggml_init(gparams);
|
|
7282
|
+
|
|
7283
|
+
// Build token sequence that will be passed to decoder
|
|
7284
|
+
// sot + [lang] + text result + eot
|
|
7285
|
+
std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
|
|
7286
|
+
if (whisper_is_multilingual(ctx)) {
|
|
7287
|
+
const int lang_id = whisper_lang_id(params.language);
|
|
7288
|
+
state->lang_id = lang_id;
|
|
7289
|
+
tokens.push_back(whisper_token_lang(ctx, lang_id));
|
|
7290
|
+
}
|
|
7291
|
+
const size_t sot_sequence_length = tokens.size();
|
|
7292
|
+
tokens.push_back(whisper_token_not(ctx));
|
|
7293
|
+
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
|
|
7294
|
+
auto & segment = state->result_all[i];
|
|
7295
|
+
for (auto &t: segment.tokens) {
|
|
7296
|
+
// Only text tokens
|
|
7297
|
+
if (t.id < whisper_token_eot(ctx)) {
|
|
7298
|
+
tokens.push_back(t.id);
|
|
7299
|
+
}
|
|
7300
|
+
}
|
|
7301
|
+
}
|
|
7302
|
+
tokens.push_back(whisper_token_eot(ctx));
|
|
7303
|
+
|
|
7304
|
+
// Get result tokens, pass then along to decoder to get cross attention QKs
|
|
7305
|
+
// used in timestamping
|
|
7306
|
+
// Decoder already returns only alignment head QKs, already concatenated in
|
|
7307
|
+
// one tensor.
|
|
7308
|
+
whisper_kv_cache_clear(state->kv_self);
|
|
7309
|
+
whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
|
|
7310
|
+
whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
|
|
7311
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
|
|
7312
|
+
WHISPER_LOG_INFO("DECODER FAILED\n");
|
|
7313
|
+
WHISPER_ASSERT(0);
|
|
7314
|
+
}
|
|
7315
|
+
WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
|
|
7316
|
+
|
|
7317
|
+
const auto n_audio_tokens = n_frames/2;
|
|
7318
|
+
WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
|
|
7319
|
+
WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
|
|
7320
|
+
const auto n_tokens = state->aheads_cross_QKs->ne[0];
|
|
7321
|
+
const auto n_heads = state->aheads_cross_QKs->ne[2];
|
|
7322
|
+
|
|
7323
|
+
// Copy data from decoder buffer to a local CPU tensor, discarding unused audio
|
|
7324
|
+
// tokens (i.e. discarding rows at the end of tensor)
|
|
7325
|
+
// IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
|
|
7326
|
+
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
|
7327
|
+
WHISPER_ASSERT(state->aheads_cross_QKs->type == WSP_GGML_TYPE_F32);
|
|
7328
|
+
WHISPER_ASSERT(wsp_ggml_is_contiguous(state->aheads_cross_QKs));
|
|
7329
|
+
wsp_ggml_tensor * w = wsp_ggml_new_tensor_3d(gctx, WSP_GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
|
|
7330
|
+
auto & data = state->aheads_cross_QKs_data;
|
|
7331
|
+
data.resize(n_tokens * n_audio_ctx * n_heads);
|
|
7332
|
+
wsp_ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
|
|
7333
|
+
for (int k = 0; k < n_heads; ++k) {
|
|
7334
|
+
for (int j = 0; j < n_audio_tokens; ++j) {
|
|
7335
|
+
memcpy(
|
|
7336
|
+
(char *) w->data + j * w->nb[1] + k * w->nb[2],
|
|
7337
|
+
data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
|
|
7338
|
+
n_tokens * sizeof(float)
|
|
7339
|
+
);
|
|
7340
|
+
}
|
|
7341
|
+
}
|
|
7342
|
+
|
|
7343
|
+
// Normalize - in original OpenAI code, this is done over dim=-2. In this case,
|
|
7344
|
+
// we already permuted N_TOKENS dimension to columns on last loop, becase wsp_ggml_norm
|
|
7345
|
+
// operates over columns. Afterwards, permute to a shape that facilitates mean
|
|
7346
|
+
// operation (after median filter)
|
|
7347
|
+
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
|
7348
|
+
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
|
7349
|
+
w = wsp_ggml_norm(gctx, w, 1e-9f);
|
|
7350
|
+
w = wsp_ggml_permute(gctx, wsp_ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
|
|
7351
|
+
|
|
7352
|
+
// Pass median filter - this is done over AUDIO_TOKENS dimension.
|
|
7353
|
+
// IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
|
7354
|
+
// OUT: Same dims
|
|
7355
|
+
median_filter_user_data mf_user_data = {medfilt_width};
|
|
7356
|
+
w = wsp_ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
|
|
7357
|
+
|
|
7358
|
+
// Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
|
|
7359
|
+
// IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
|
7360
|
+
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
|
|
7361
|
+
w = wsp_ggml_mean(gctx, w);
|
|
7362
|
+
w = wsp_ggml_scale(gctx, w, -1.0);
|
|
7363
|
+
w = wsp_ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
|
|
7364
|
+
|
|
7365
|
+
// Remove SOT sequence and EOT
|
|
7366
|
+
// Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
|
|
7367
|
+
w = wsp_ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
|
|
7368
|
+
|
|
7369
|
+
// Compute
|
|
7370
|
+
struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
|
|
7371
|
+
wsp_ggml_build_forward_expand(gf, w);
|
|
7372
|
+
wsp_ggml_graph_compute_with_ctx(gctx, gf, n_threads);
|
|
7373
|
+
|
|
7374
|
+
wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
|
|
7375
|
+
|
|
7376
|
+
// Place timestamps on segments
|
|
7377
|
+
int32_t last_v = 0;
|
|
7378
|
+
auto seg_i = state->result_all.begin() + i_segment;
|
|
7379
|
+
auto tok_i = seg_i->tokens.begin();
|
|
7380
|
+
for (int i = 0; i < alignment->ne[1]; ++i) {
|
|
7381
|
+
int32_t v = wsp_ggml_get_i32_nd(alignment, 0, i, 0, 0);
|
|
7382
|
+
if (v != last_v) {
|
|
7383
|
+
int32_t time_index = wsp_ggml_get_i32_nd(alignment, 1, i, 0, 0);
|
|
7384
|
+
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
|
|
7385
|
+
last_v = v;
|
|
7386
|
+
|
|
7387
|
+
// Skip non-text tokens
|
|
7388
|
+
while (!(tok_i->id < whisper_token_eot(ctx))) {
|
|
7389
|
+
++tok_i;
|
|
7390
|
+
if (tok_i == seg_i->tokens.end()) {
|
|
7391
|
+
++seg_i;
|
|
7392
|
+
tok_i = seg_i->tokens.begin();
|
|
7393
|
+
}
|
|
7394
|
+
}
|
|
7395
|
+
|
|
7396
|
+
tok_i->t_dtw = timestamp;
|
|
7397
|
+
++tok_i;
|
|
7398
|
+
if (tok_i == seg_i->tokens.end()) {
|
|
7399
|
+
++seg_i;
|
|
7400
|
+
tok_i = seg_i->tokens.begin();
|
|
7401
|
+
}
|
|
7402
|
+
}
|
|
7403
|
+
}
|
|
7404
|
+
|
|
7405
|
+
// Print DTW timestamps
|
|
7406
|
+
/*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
|
|
7407
|
+
auto & segment = state->result_all[i];
|
|
7408
|
+
for (auto &t: segment.tokens) {
|
|
7409
|
+
const char * tok = whisper_token_to_str(ctx, t.id);
|
|
7410
|
+
fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
|
|
7411
|
+
}
|
|
7412
|
+
fprintf(stderr, "\n");
|
|
7413
|
+
}*/
|
|
7414
|
+
|
|
7415
|
+
wsp_ggml_free(gctx);
|
|
7416
|
+
}
|
|
7417
|
+
|
|
6617
7418
|
void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
|
|
6618
7419
|
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
|
|
6619
7420
|
g_state.log_callback_user_data = user_data;
|