whispercpp 1.3.0 → 1.3.1
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -1,29 +1,19 @@
|
|
1
1
|
#include "whisper.h"
|
2
2
|
|
3
|
-
#
|
4
|
-
#include "coreml/whisper-encoder.h"
|
5
|
-
#endif
|
3
|
+
#include "ggml-cpu.h"
|
6
4
|
|
7
|
-
#
|
8
|
-
#include "ggml-
|
9
|
-
#
|
10
|
-
|
11
|
-
#ifdef GGML_USE_CUDA
|
12
|
-
#include "ggml-cuda.h"
|
13
|
-
#endif
|
5
|
+
#include "ggml.h"
|
6
|
+
#include "ggml-alloc.h"
|
7
|
+
#include "ggml-backend.h"
|
14
8
|
|
15
|
-
#ifdef
|
16
|
-
#include "
|
9
|
+
#ifdef WHISPER_USE_COREML
|
10
|
+
#include "coreml/whisper-encoder.h"
|
17
11
|
#endif
|
18
12
|
|
19
13
|
#ifdef WHISPER_USE_OPENVINO
|
20
14
|
#include "openvino/whisper-openvino-encoder.h"
|
21
15
|
#endif
|
22
16
|
|
23
|
-
#include "ggml.h"
|
24
|
-
#include "ggml-alloc.h"
|
25
|
-
#include "ggml-backend.h"
|
26
|
-
|
27
17
|
#include <atomic>
|
28
18
|
#include <algorithm>
|
29
19
|
#include <cassert>
|
@@ -41,6 +31,7 @@
|
|
41
31
|
#include <regex>
|
42
32
|
#include <random>
|
43
33
|
#include <functional>
|
34
|
+
#include <codecvt>
|
44
35
|
|
45
36
|
#if defined(_MSC_VER)
|
46
37
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
@@ -147,8 +138,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|
147
138
|
} \
|
148
139
|
} while (0)
|
149
140
|
|
150
|
-
//#define WHISPER_USE_FLASH_ATTN
|
151
|
-
//#define WHISPER_USE_FLASH_FF
|
152
141
|
#define WHISPER_MAX_DECODERS 8
|
153
142
|
#define WHISPER_MAX_NODES 4096
|
154
143
|
|
@@ -162,7 +151,7 @@ static bool ggml_graph_compute_helper(
|
|
162
151
|
int n_threads,
|
163
152
|
ggml_abort_callback abort_callback,
|
164
153
|
void * abort_callback_data) {
|
165
|
-
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
154
|
+
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
|
166
155
|
|
167
156
|
plan.abort_callback = abort_callback;
|
168
157
|
plan.abort_callback_data = abort_callback_data;
|
@@ -176,18 +165,24 @@ static bool ggml_graph_compute_helper(
|
|
176
165
|
}
|
177
166
|
|
178
167
|
static bool ggml_graph_compute_helper(
|
179
|
-
|
168
|
+
ggml_backend_sched_t sched,
|
180
169
|
struct ggml_cgraph * graph,
|
181
170
|
int n_threads) {
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
171
|
+
|
172
|
+
for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
|
173
|
+
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
|
174
|
+
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
175
|
+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
176
|
+
|
177
|
+
auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
178
|
+
if (fn_set_n_threads) {
|
179
|
+
fn_set_n_threads(backend, n_threads);
|
180
|
+
}
|
188
181
|
}
|
189
|
-
|
190
|
-
|
182
|
+
|
183
|
+
bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
|
184
|
+
ggml_backend_sched_reset(sched);
|
185
|
+
return t;
|
191
186
|
}
|
192
187
|
|
193
188
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
@@ -363,6 +358,7 @@ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15},
|
|
363
358
|
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
|
364
359
|
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
|
365
360
|
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
|
361
|
+
static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
|
366
362
|
|
367
363
|
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
368
364
|
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
|
@@ -376,6 +372,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
|
376
372
|
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
|
377
373
|
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
|
378
374
|
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
|
375
|
+
{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
|
379
376
|
};
|
380
377
|
|
381
378
|
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
@@ -502,33 +499,41 @@ struct whisper_pair {
|
|
502
499
|
whisper_pair() : first(A()), second(B()) {}
|
503
500
|
};
|
504
501
|
|
505
|
-
//
|
506
|
-
struct
|
507
|
-
|
502
|
+
// ggml_backend_sched wrapper for whisper usage
|
503
|
+
struct whisper_sched {
|
504
|
+
ggml_backend_sched_t sched = nullptr;
|
508
505
|
|
509
506
|
std::vector<uint8_t> meta;
|
510
507
|
};
|
511
508
|
|
512
|
-
static size_t
|
513
|
-
|
509
|
+
static size_t whisper_sched_size(struct whisper_sched & allocr) {
|
510
|
+
size_t size = allocr.meta.size();
|
511
|
+
for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
|
512
|
+
ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
|
513
|
+
size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
|
514
|
+
}
|
515
|
+
return size;
|
514
516
|
}
|
515
517
|
|
516
518
|
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
517
|
-
static bool
|
518
|
-
auto &
|
519
|
+
static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
|
520
|
+
auto & sched = allocr.sched;
|
519
521
|
auto & meta = allocr.meta;
|
520
522
|
|
521
|
-
|
523
|
+
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
522
524
|
|
523
525
|
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
524
526
|
|
525
527
|
// since there are dependencies between the different graphs,
|
526
528
|
// we need to allocate them instead of only reserving to get the correct compute buffer size
|
527
|
-
if (!
|
529
|
+
if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
|
528
530
|
// failed to allocate the compute buffer
|
529
531
|
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
530
532
|
return false;
|
531
533
|
}
|
534
|
+
|
535
|
+
ggml_backend_sched_reset(sched);
|
536
|
+
|
532
537
|
return true;
|
533
538
|
}
|
534
539
|
|
@@ -671,9 +676,9 @@ struct whisper_kv_cache {
|
|
671
676
|
struct ggml_tensor * k;
|
672
677
|
struct ggml_tensor * v;
|
673
678
|
|
674
|
-
struct ggml_context * ctx = nullptr;
|
675
|
-
|
676
679
|
ggml_backend_buffer_t buffer = nullptr;
|
680
|
+
|
681
|
+
std::vector<uint8_t> ctx_buf;
|
677
682
|
};
|
678
683
|
|
679
684
|
struct whisper_model {
|
@@ -802,6 +807,9 @@ struct whisper_state {
|
|
802
807
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
803
808
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
804
809
|
|
810
|
+
// number of decoders for which we have constructed the KV cache
|
811
|
+
int32_t kv_self_n_dec = 0;
|
812
|
+
|
805
813
|
// unified self-attention KV cache for all decoders
|
806
814
|
whisper_kv_cache kv_self;
|
807
815
|
|
@@ -809,21 +817,22 @@ struct whisper_state {
|
|
809
817
|
// shared between all decoders
|
810
818
|
whisper_kv_cache kv_cross;
|
811
819
|
|
820
|
+
// padded buffer for flash-attention
|
821
|
+
whisper_kv_cache kv_pad;
|
822
|
+
|
812
823
|
whisper_mel mel;
|
813
824
|
|
814
825
|
whisper_batch batch;
|
815
826
|
|
816
827
|
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
817
828
|
|
818
|
-
ggml_backend_t
|
829
|
+
std::vector<ggml_backend_t> backends;
|
819
830
|
|
820
|
-
// ggml-alloc:
|
821
831
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
whisper_allocr alloc_decode;
|
832
|
+
whisper_sched sched_conv;
|
833
|
+
whisper_sched sched_encode;
|
834
|
+
whisper_sched sched_cross;
|
835
|
+
whisper_sched sched_decode;
|
827
836
|
|
828
837
|
// result of the encoder
|
829
838
|
struct ggml_tensor * embd_conv = nullptr;
|
@@ -858,6 +867,7 @@ struct whisper_state {
|
|
858
867
|
whisper_token tid_last;
|
859
868
|
|
860
869
|
std::vector<float> energy; // PCM signal energy
|
870
|
+
float no_speech_prob = 0.0f;
|
861
871
|
|
862
872
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
863
873
|
whisper_aheads_masks aheads_masks;
|
@@ -882,8 +892,6 @@ struct whisper_context {
|
|
882
892
|
|
883
893
|
whisper_state * state = nullptr;
|
884
894
|
|
885
|
-
ggml_backend_t backend = nullptr;
|
886
|
-
|
887
895
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
888
896
|
};
|
889
897
|
|
@@ -901,21 +909,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
901
909
|
BYTESWAP_VALUE(dest);
|
902
910
|
}
|
903
911
|
|
904
|
-
static bool
|
905
|
-
const struct whisper_hparams & hparams,
|
912
|
+
static bool whisper_kv_cache_init(
|
906
913
|
struct whisper_kv_cache & cache,
|
907
914
|
ggml_backend_t backend,
|
908
915
|
ggml_type wtype,
|
916
|
+
int64_t n_text_state,
|
917
|
+
int64_t n_text_layer,
|
909
918
|
int n_ctx) {
|
910
|
-
const int64_t n_text_state = hparams.n_text_state;
|
911
|
-
const int64_t n_text_layer = hparams.n_text_layer;
|
912
|
-
|
913
919
|
const int64_t n_mem = n_text_layer*n_ctx;
|
914
920
|
const int64_t n_elements = n_text_state*n_mem;
|
915
921
|
|
922
|
+
cache.ctx_buf.resize(2*ggml_tensor_overhead());
|
923
|
+
|
916
924
|
struct ggml_init_params params = {
|
917
|
-
/*.mem_size =*/
|
918
|
-
/*.mem_buffer =*/
|
925
|
+
/*.mem_size =*/ cache.ctx_buf.size(),
|
926
|
+
/*.mem_buffer =*/ cache.ctx_buf.data(),
|
919
927
|
/*.no_alloc =*/ true,
|
920
928
|
};
|
921
929
|
|
@@ -925,29 +933,31 @@ static bool kv_cache_init(
|
|
925
933
|
cache.cells.clear();
|
926
934
|
cache.cells.resize(n_ctx);
|
927
935
|
|
928
|
-
|
936
|
+
struct ggml_context * ctx = ggml_init(params);
|
929
937
|
|
930
|
-
if (!
|
938
|
+
if (!ctx) {
|
931
939
|
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
|
932
940
|
return false;
|
933
941
|
}
|
934
942
|
|
935
|
-
cache.k = ggml_new_tensor_1d(
|
936
|
-
cache.v = ggml_new_tensor_1d(
|
943
|
+
cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
944
|
+
cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
937
945
|
|
938
|
-
cache.buffer = ggml_backend_alloc_ctx_tensors(
|
946
|
+
cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
939
947
|
if (!cache.buffer) {
|
940
948
|
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
|
941
949
|
return false;
|
942
950
|
}
|
943
951
|
|
952
|
+
ggml_backend_buffer_clear(cache.buffer, 0);
|
953
|
+
|
954
|
+
ggml_free(ctx);
|
955
|
+
|
944
956
|
return true;
|
945
957
|
}
|
946
958
|
|
947
|
-
static void
|
948
|
-
ggml_free(cache.ctx);
|
959
|
+
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
|
949
960
|
ggml_backend_buffer_free(cache.buffer);
|
950
|
-
cache.ctx = nullptr;
|
951
961
|
}
|
952
962
|
|
953
963
|
static bool whisper_kv_cache_find_slot(
|
@@ -1018,6 +1028,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
1018
1028
|
cache.cells[i].seq_id.clear();
|
1019
1029
|
}
|
1020
1030
|
cache.head = 0;
|
1031
|
+
|
1032
|
+
ggml_backend_buffer_clear(cache.buffer, 0);
|
1021
1033
|
}
|
1022
1034
|
|
1023
1035
|
static void whisper_kv_cache_seq_rm(
|
@@ -1068,6 +1080,26 @@ static void whisper_kv_cache_seq_cp(
|
|
1068
1080
|
}
|
1069
1081
|
}
|
1070
1082
|
|
1083
|
+
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
|
1084
|
+
if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
|
1085
|
+
return 1u;
|
1086
|
+
}
|
1087
|
+
|
1088
|
+
#ifdef GGML_USE_METAL
|
1089
|
+
if (wctx.params.use_gpu) {
|
1090
|
+
return 32u;
|
1091
|
+
}
|
1092
|
+
#endif
|
1093
|
+
|
1094
|
+
#ifdef GGML_USE_CUDA
|
1095
|
+
if (wctx.params.use_gpu) {
|
1096
|
+
return 256u;
|
1097
|
+
}
|
1098
|
+
#endif
|
1099
|
+
|
1100
|
+
return 1u;
|
1101
|
+
}
|
1102
|
+
|
1071
1103
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
1072
1104
|
static bool aheads_masks_init(
|
1073
1105
|
const whisper_context_params & cparams,
|
@@ -1199,49 +1231,71 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
1199
1231
|
return size;
|
1200
1232
|
}
|
1201
1233
|
|
1202
|
-
static ggml_backend_t
|
1203
|
-
|
1234
|
+
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
1235
|
+
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
1204
1236
|
|
1205
|
-
// initialize the backends
|
1206
|
-
#ifdef GGML_USE_CUDA
|
1207
1237
|
if (params.use_gpu) {
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1238
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1239
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1240
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1241
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1242
|
+
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
|
1243
|
+
if (!result) {
|
1244
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1245
|
+
}
|
1246
|
+
return result;
|
1247
|
+
}
|
1212
1248
|
}
|
1213
1249
|
}
|
1214
|
-
#endif
|
1215
1250
|
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
backend_gpu = NULL;
|
1227
|
-
}
|
1251
|
+
return nullptr;
|
1252
|
+
}
|
1253
|
+
|
1254
|
+
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
1255
|
+
std::vector<ggml_backend_t> result;
|
1256
|
+
|
1257
|
+
ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
|
1258
|
+
|
1259
|
+
if (backend_gpu) {
|
1260
|
+
result.push_back(backend_gpu);
|
1228
1261
|
}
|
1229
|
-
#endif
|
1230
1262
|
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1263
|
+
// ACCEL backends
|
1264
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1265
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1266
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
1267
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1268
|
+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
1269
|
+
if (!backend) {
|
1270
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1271
|
+
continue;
|
1272
|
+
}
|
1273
|
+
result.push_back(backend);
|
1237
1274
|
}
|
1238
1275
|
}
|
1239
|
-
#endif
|
1240
1276
|
|
1241
|
-
|
1242
|
-
|
1277
|
+
GGML_UNUSED(params);
|
1278
|
+
|
1279
|
+
result.push_back(ggml_backend_cpu_init());
|
1280
|
+
|
1281
|
+
return result;
|
1282
|
+
}
|
1283
|
+
|
1284
|
+
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
|
1285
|
+
if (!params.use_gpu) {
|
1286
|
+
return ggml_backend_cpu_buffer_type();
|
1243
1287
|
}
|
1244
|
-
|
1288
|
+
|
1289
|
+
// if we have a GPU device - use it
|
1290
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1291
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1292
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1293
|
+
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
|
1294
|
+
return ggml_backend_dev_buffer_type(dev);
|
1295
|
+
}
|
1296
|
+
}
|
1297
|
+
|
1298
|
+
return ggml_backend_cpu_buffer_type();
|
1245
1299
|
}
|
1246
1300
|
|
1247
1301
|
// load the model from a ggml file
|
@@ -1668,21 +1722,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1668
1722
|
}
|
1669
1723
|
}
|
1670
1724
|
|
1671
|
-
wctx.backend = whisper_backend_init(wctx.params);
|
1672
|
-
if (!wctx.backend) {
|
1673
|
-
WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
|
1674
|
-
return false;
|
1675
|
-
}
|
1676
|
-
|
1677
1725
|
// allocate tensors in the backend buffers
|
1678
|
-
model.buffer =
|
1726
|
+
model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
|
1679
1727
|
if (!model.buffer) {
|
1680
1728
|
WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
|
1681
1729
|
return false;
|
1682
1730
|
}
|
1683
1731
|
|
1684
1732
|
size_t size_main = ggml_backend_buffer_get_size(model.buffer);
|
1685
|
-
WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__,
|
1733
|
+
WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
|
1686
1734
|
|
1687
1735
|
// load weights
|
1688
1736
|
{
|
@@ -1777,6 +1825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1777
1825
|
}
|
1778
1826
|
}
|
1779
1827
|
|
1828
|
+
ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
1829
|
+
|
1780
1830
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
1781
1831
|
|
1782
1832
|
return true;
|
@@ -1812,8 +1862,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|
1812
1862
|
const int n_mels = hparams.n_mels;
|
1813
1863
|
|
1814
1864
|
struct ggml_init_params params = {
|
1815
|
-
/*.mem_size =*/ wstate.
|
1816
|
-
/*.mem_buffer =*/ wstate.
|
1865
|
+
/*.mem_size =*/ wstate.sched_conv.meta.size(),
|
1866
|
+
/*.mem_buffer =*/ wstate.sched_conv.meta.data(),
|
1817
1867
|
/*.no_alloc =*/ true,
|
1818
1868
|
};
|
1819
1869
|
|
@@ -1847,6 +1897,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|
1847
1897
|
ggml_build_forward_expand(gf, mel);
|
1848
1898
|
|
1849
1899
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
1900
|
+
ggml_set_input(cur); // the external encoder will write into this tensor
|
1850
1901
|
|
1851
1902
|
ggml_set_name(cur, "embd_enc");
|
1852
1903
|
wstate.embd_enc = cur;
|
@@ -1872,9 +1923,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1872
1923
|
const int n_head = hparams.n_audio_head;
|
1873
1924
|
const int n_layer = hparams.n_audio_layer;
|
1874
1925
|
|
1926
|
+
const int n_state_head = n_state/n_head;
|
1927
|
+
|
1928
|
+
auto & kv_pad = wstate.kv_pad;
|
1929
|
+
|
1930
|
+
WHISPER_ASSERT(!!kv_pad.buffer);
|
1931
|
+
|
1932
|
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
1933
|
+
|
1875
1934
|
struct ggml_init_params params = {
|
1876
|
-
/*.mem_size =*/ wstate.
|
1877
|
-
/*.mem_buffer =*/ wstate.
|
1935
|
+
/*.mem_size =*/ wstate.sched_encode.meta.size(),
|
1936
|
+
/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
|
1878
1937
|
/*.no_alloc =*/ true,
|
1879
1938
|
};
|
1880
1939
|
|
@@ -1884,7 +1943,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1884
1943
|
|
1885
1944
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
1886
1945
|
|
1887
|
-
const float KQscale = 1.0f/sqrtf(float(
|
1946
|
+
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
1888
1947
|
|
1889
1948
|
// ===================================================================
|
1890
1949
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
@@ -1934,14 +1993,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1934
1993
|
|
1935
1994
|
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
1936
1995
|
|
1937
|
-
//Qcur = ggml_scale(ctx0, Qcur, pow(float(
|
1996
|
+
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
1938
1997
|
|
1939
1998
|
// note: no bias for Key
|
1940
1999
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
1941
2000
|
layer.attn_k_w,
|
1942
2001
|
cur);
|
1943
2002
|
|
1944
|
-
//Kcur = ggml_scale(ctx0, Kcur, pow(float(
|
2003
|
+
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
1945
2004
|
|
1946
2005
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
1947
2006
|
layer.attn_v_w,
|
@@ -1951,70 +2010,60 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1951
2010
|
|
1952
2011
|
// ------
|
1953
2012
|
|
1954
|
-
#ifdef WHISPER_USE_FLASH_ATTN
|
1955
2013
|
struct ggml_tensor * Q =
|
1956
2014
|
ggml_permute(ctx0,
|
1957
|
-
|
1958
|
-
Qcur,
|
1959
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
2015
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
|
1960
2016
|
0, 2, 1, 3);
|
1961
2017
|
|
1962
|
-
|
1963
|
-
|
1964
|
-
|
1965
|
-
Kcur,
|
1966
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
1967
|
-
0, 2, 1, 3);
|
2018
|
+
if (wctx.params.flash_attn) {
|
2019
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
|
2020
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
|
1968
2021
|
|
1969
|
-
|
1970
|
-
|
1971
|
-
|
1972
|
-
|
1973
|
-
|
1974
|
-
|
1975
|
-
1, 2, 0, 3),
|
1976
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
2022
|
+
struct ggml_tensor * K =
|
2023
|
+
ggml_view_3d(ctx0, kv_pad.k,
|
2024
|
+
n_state_head, n_ctx_pad, n_head,
|
2025
|
+
ggml_element_size(kv_pad.k)*n_state,
|
2026
|
+
ggml_element_size(kv_pad.k)*n_state_head,
|
2027
|
+
0);
|
1977
2028
|
|
1978
|
-
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
1985
|
-
0, 2, 1, 3);
|
2029
|
+
struct ggml_tensor * V =
|
2030
|
+
ggml_view_3d(ctx0, kv_pad.v,
|
2031
|
+
n_state_head, n_ctx_pad, n_head,
|
2032
|
+
ggml_element_size(kv_pad.v)*n_state,
|
2033
|
+
ggml_element_size(kv_pad.v)*n_state_head,
|
2034
|
+
0);
|
1986
2035
|
|
1987
|
-
|
1988
|
-
ggml_permute(ctx0,
|
1989
|
-
ggml_cpy(ctx0,
|
1990
|
-
Kcur,
|
1991
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
1992
|
-
0, 2, 1, 3);
|
2036
|
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
|
1993
2037
|
|
1994
|
-
|
1995
|
-
|
2038
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
2039
|
+
} else {
|
2040
|
+
struct ggml_tensor * K =
|
2041
|
+
ggml_permute(ctx0,
|
2042
|
+
ggml_cast(ctx0,
|
2043
|
+
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
2044
|
+
wctx.itype),
|
2045
|
+
0, 2, 1, 3);
|
1996
2046
|
|
1997
|
-
|
2047
|
+
// K * Q
|
2048
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
1998
2049
|
|
1999
|
-
|
2050
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
2000
2051
|
|
2001
|
-
|
2002
|
-
|
2003
|
-
|
2004
|
-
|
2005
|
-
|
2006
|
-
|
2007
|
-
|
2008
|
-
|
2009
|
-
);
|
2052
|
+
struct ggml_tensor * V =
|
2053
|
+
ggml_cast(ctx0,
|
2054
|
+
ggml_permute(ctx0,
|
2055
|
+
ggml_reshape_3d(ctx0,
|
2056
|
+
Vcur,
|
2057
|
+
n_state_head, n_head, n_ctx),
|
2058
|
+
1, 2, 0, 3),
|
2059
|
+
wctx.itype);
|
2010
2060
|
|
2011
|
-
|
2012
|
-
#endif
|
2013
|
-
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2061
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
2014
2062
|
|
2015
|
-
|
2016
|
-
|
2017
|
-
|
2063
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2064
|
+
|
2065
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
|
2066
|
+
}
|
2018
2067
|
}
|
2019
2068
|
|
2020
2069
|
// projection
|
@@ -2043,11 +2092,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
2043
2092
|
layer.mlp_ln_b);
|
2044
2093
|
}
|
2045
2094
|
|
2046
|
-
#ifdef WHISPER_USE_FLASH_FF
|
2047
|
-
cur = ggml_flash_ff(ctx0,
|
2048
|
-
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
2049
|
-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
2050
|
-
#else
|
2051
2095
|
// fully connected
|
2052
2096
|
cur = ggml_mul_mat(ctx0,
|
2053
2097
|
layer.mlp_0_w,
|
@@ -2064,7 +2108,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
2064
2108
|
cur);
|
2065
2109
|
|
2066
2110
|
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
|
2067
|
-
#endif
|
2068
2111
|
}
|
2069
2112
|
|
2070
2113
|
inpL = ggml_add(ctx0, cur, inpFF);
|
@@ -2113,9 +2156,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2113
2156
|
const int n_state = hparams.n_audio_state;
|
2114
2157
|
const int n_head = hparams.n_audio_head;
|
2115
2158
|
|
2159
|
+
const int n_state_head = n_state/n_head;
|
2160
|
+
|
2161
|
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
2162
|
+
|
2116
2163
|
struct ggml_init_params params = {
|
2117
|
-
/*.mem_size =*/ wstate.
|
2118
|
-
/*.mem_buffer =*/ wstate.
|
2164
|
+
/*.mem_size =*/ wstate.sched_cross.meta.size(),
|
2165
|
+
/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
|
2119
2166
|
/*.no_alloc =*/ true,
|
2120
2167
|
};
|
2121
2168
|
|
@@ -2125,18 +2172,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2125
2172
|
|
2126
2173
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
2127
2174
|
|
2128
|
-
const float Kscale = pow(float(
|
2175
|
+
const float Kscale = pow(float(n_state_head), -0.25);
|
2129
2176
|
|
2130
2177
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
2131
2178
|
auto & layer = model.layers_decoder[il];
|
2132
2179
|
|
2133
|
-
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
2180
|
+
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
|
2134
2181
|
layer.cross_attn_k_w,
|
2135
2182
|
cur);
|
2136
2183
|
|
2137
2184
|
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
2138
2185
|
|
2139
|
-
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
2186
|
+
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
2140
2187
|
layer.cross_attn_v_w,
|
2141
2188
|
cur);
|
2142
2189
|
|
@@ -2144,15 +2191,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2144
2191
|
Vcross,
|
2145
2192
|
layer.cross_attn_v_b);
|
2146
2193
|
|
2147
|
-
|
2194
|
+
struct ggml_tensor * k;
|
2195
|
+
struct ggml_tensor * v;
|
2148
2196
|
|
2149
|
-
|
2150
|
-
|
2151
|
-
|
2197
|
+
if (wctx.params.flash_attn) {
|
2198
|
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
2199
|
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
|
2152
2200
|
|
2153
|
-
|
2154
|
-
|
2155
|
-
|
2201
|
+
v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
|
2202
|
+
(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
|
2203
|
+
} else {
|
2204
|
+
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
2205
|
+
|
2206
|
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
2207
|
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
2208
|
+
|
2209
|
+
v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
2210
|
+
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
2211
|
+
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
2212
|
+
}
|
2156
2213
|
|
2157
2214
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
2158
2215
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
@@ -2186,11 +2243,11 @@ static bool whisper_encode_internal(
|
|
2186
2243
|
|
2187
2244
|
// conv
|
2188
2245
|
{
|
2189
|
-
auto &
|
2246
|
+
auto & sched = wstate.sched_conv.sched;
|
2190
2247
|
|
2191
2248
|
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
|
2192
2249
|
|
2193
|
-
if (!
|
2250
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2194
2251
|
// should never happen as we pre-allocate the memory
|
2195
2252
|
return false;
|
2196
2253
|
}
|
@@ -2223,7 +2280,7 @@ static bool whisper_encode_internal(
|
|
2223
2280
|
}
|
2224
2281
|
|
2225
2282
|
if (!whisper_encode_external(wstate)) {
|
2226
|
-
if (!ggml_graph_compute_helper(
|
2283
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2227
2284
|
return false;
|
2228
2285
|
}
|
2229
2286
|
} else {
|
@@ -2237,32 +2294,32 @@ static bool whisper_encode_internal(
|
|
2237
2294
|
|
2238
2295
|
// encoder
|
2239
2296
|
if (!whisper_encode_external(wstate)) {
|
2240
|
-
auto &
|
2297
|
+
auto & sched = wstate.sched_encode.sched;
|
2241
2298
|
|
2242
2299
|
ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
2243
2300
|
|
2244
|
-
if (!
|
2301
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2245
2302
|
// should never happen as we pre-allocate the memory
|
2246
2303
|
return false;
|
2247
2304
|
}
|
2248
2305
|
|
2249
|
-
if (!ggml_graph_compute_helper(
|
2306
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2250
2307
|
return false;
|
2251
2308
|
}
|
2252
2309
|
}
|
2253
2310
|
|
2254
2311
|
// cross
|
2255
2312
|
{
|
2256
|
-
auto &
|
2313
|
+
auto & sched = wstate.sched_cross.sched;
|
2257
2314
|
|
2258
2315
|
ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
2259
2316
|
|
2260
|
-
if (!
|
2317
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2261
2318
|
// should never happen as we pre-allocate the memory
|
2262
2319
|
return false;
|
2263
2320
|
}
|
2264
2321
|
|
2265
|
-
if (!ggml_graph_compute_helper(
|
2322
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2266
2323
|
return false;
|
2267
2324
|
}
|
2268
2325
|
}
|
@@ -2284,24 +2341,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2284
2341
|
|
2285
2342
|
auto & kv_self = wstate.kv_self;
|
2286
2343
|
|
2287
|
-
WHISPER_ASSERT(!!kv_self.
|
2344
|
+
WHISPER_ASSERT(!!kv_self.buffer);
|
2288
2345
|
|
2289
2346
|
const int n_ctx = kv_self.size;
|
2290
2347
|
const int n_state = hparams.n_text_state;
|
2291
2348
|
const int n_head = hparams.n_text_head;
|
2292
2349
|
const int n_layer = hparams.n_text_layer;
|
2293
2350
|
|
2351
|
+
const int n_state_head = n_state/n_head;
|
2352
|
+
|
2294
2353
|
const int n_tokens = batch.n_tokens;
|
2295
2354
|
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
2296
2355
|
|
2297
|
-
const
|
2298
|
-
|
2356
|
+
const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
|
2357
|
+
|
2358
|
+
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
2359
|
+
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
2299
2360
|
|
2300
2361
|
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
2301
2362
|
|
2302
2363
|
struct ggml_init_params params = {
|
2303
|
-
/*.mem_size =*/ wstate.
|
2304
|
-
/*.mem_buffer =*/ wstate.
|
2364
|
+
/*.mem_size =*/ wstate.sched_decode.meta.size(),
|
2365
|
+
/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
|
2305
2366
|
/*.no_alloc =*/ true,
|
2306
2367
|
};
|
2307
2368
|
|
@@ -2317,12 +2378,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2317
2378
|
ggml_set_name(position, "position");
|
2318
2379
|
ggml_set_input(position);
|
2319
2380
|
|
2320
|
-
const float KQscale = pow(float(
|
2381
|
+
const float KQscale = pow(float(n_state_head), -0.25);
|
2321
2382
|
|
2322
|
-
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
2383
|
+
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
|
2323
2384
|
ggml_set_name(KQ_mask, "KQ_mask");
|
2324
2385
|
ggml_set_input(KQ_mask);
|
2325
2386
|
|
2387
|
+
struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
|
2388
|
+
|
2326
2389
|
// token encoding + position encoding
|
2327
2390
|
struct ggml_tensor * cur =
|
2328
2391
|
ggml_add(ctx0,
|
@@ -2378,12 +2441,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2378
2441
|
Vcur,
|
2379
2442
|
layer.attn_v_b);
|
2380
2443
|
|
2381
|
-
|
2444
|
+
struct ggml_tensor * k;
|
2445
|
+
struct ggml_tensor * v;
|
2446
|
+
|
2447
|
+
if (wctx.params.flash_attn) {
|
2448
|
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
2449
|
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
2450
|
+
|
2451
|
+
v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
|
2452
|
+
(ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
|
2453
|
+
} else {
|
2454
|
+
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
2455
|
+
|
2456
|
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
2457
|
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
2382
2458
|
|
2383
|
-
|
2384
|
-
|
2385
|
-
|
2386
|
-
|
2459
|
+
v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
2460
|
+
( n_ctx)*ggml_element_size(kv_self.v),
|
2461
|
+
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
|
2462
|
+
}
|
2387
2463
|
|
2388
2464
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
2389
2465
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
@@ -2393,40 +2469,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2393
2469
|
|
2394
2470
|
struct ggml_tensor * Q =
|
2395
2471
|
ggml_permute(ctx0,
|
2396
|
-
ggml_reshape_3d(ctx0, Qcur,
|
2472
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
2397
2473
|
0, 2, 1, 3);
|
2398
2474
|
|
2399
2475
|
struct ggml_tensor * K =
|
2400
2476
|
ggml_view_3d(ctx0, kv_self.k,
|
2401
|
-
|
2477
|
+
n_state_head, n_kv, n_head,
|
2402
2478
|
ggml_element_size(kv_self.k)*n_state,
|
2403
|
-
ggml_element_size(kv_self.k)*
|
2479
|
+
ggml_element_size(kv_self.k)*n_state_head,
|
2404
2480
|
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
2405
2481
|
|
2406
|
-
|
2407
|
-
|
2482
|
+
if (wctx.params.flash_attn) {
|
2483
|
+
struct ggml_tensor * V =
|
2484
|
+
ggml_view_3d(ctx0, kv_self.v,
|
2485
|
+
n_state_head, n_kv, n_head,
|
2486
|
+
ggml_element_size(kv_self.v)*n_state,
|
2487
|
+
ggml_element_size(kv_self.v)*n_state_head,
|
2488
|
+
ggml_element_size(kv_self.v)*n_state*n_ctx*il);
|
2408
2489
|
|
2409
|
-
|
2490
|
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
|
2410
2491
|
|
2411
|
-
|
2412
|
-
|
2492
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
2493
|
+
} else {
|
2494
|
+
// K * Q
|
2495
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
2413
2496
|
|
2414
|
-
|
2497
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
|
2415
2498
|
|
2416
|
-
|
2417
|
-
|
2418
|
-
|
2419
|
-
|
2420
|
-
|
2421
|
-
|
2499
|
+
struct ggml_tensor * V =
|
2500
|
+
ggml_view_3d(ctx0, kv_self.v,
|
2501
|
+
n_kv, n_state_head, n_head,
|
2502
|
+
n_ctx*ggml_element_size(kv_self.v),
|
2503
|
+
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
|
2504
|
+
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
|
2422
2505
|
|
2423
|
-
|
2506
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
2424
2507
|
|
2425
|
-
|
2508
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2426
2509
|
|
2427
|
-
|
2428
|
-
|
2429
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
2510
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
2511
|
+
}
|
2430
2512
|
}
|
2431
2513
|
|
2432
2514
|
// projection
|
@@ -2465,80 +2547,75 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2465
2547
|
Qcur,
|
2466
2548
|
layer.cross_attn_q_b);
|
2467
2549
|
|
2468
|
-
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
2469
|
-
|
2470
|
-
// Kcross is already scaled
|
2471
|
-
struct ggml_tensor * Kcross =
|
2472
|
-
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2473
|
-
n_state/n_head, n_audio_ctx, n_head,
|
2474
|
-
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2475
|
-
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
2476
|
-
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
2477
|
-
|
2478
|
-
//struct ggml_tensor * Vcross =
|
2479
|
-
// ggml_reshape_3d(ctx0,
|
2480
|
-
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
2481
|
-
// n_state/n_head, n_head, n_audio_ctx);
|
2482
|
-
|
2483
|
-
//struct ggml_tensor * V_trans =
|
2484
|
-
// ggml_cpy(ctx0,
|
2485
|
-
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
2486
|
-
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
2487
|
-
|
2488
|
-
struct ggml_tensor * V =
|
2489
|
-
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2490
|
-
n_audio_ctx, n_state/n_head, n_head,
|
2491
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
2492
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
2493
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
2494
|
-
|
2495
|
-
// ------
|
2496
|
-
|
2497
2550
|
struct ggml_tensor * Q =
|
2498
2551
|
ggml_permute(ctx0,
|
2499
|
-
ggml_reshape_3d(ctx0, Qcur,
|
2552
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
2500
2553
|
0, 2, 1, 3);
|
2501
2554
|
|
2502
|
-
|
2503
|
-
|
2504
|
-
|
2505
|
-
|
2506
|
-
|
2507
|
-
|
2508
|
-
|
2509
|
-
// );
|
2555
|
+
if (wctx.params.flash_attn) {
|
2556
|
+
struct ggml_tensor * Kcross =
|
2557
|
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2558
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
2559
|
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2560
|
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
2561
|
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
|
2510
2562
|
|
2511
|
-
|
2512
|
-
|
2563
|
+
struct ggml_tensor * Vcross =
|
2564
|
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2565
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
2566
|
+
ggml_element_size(wstate.kv_cross.v)*n_state,
|
2567
|
+
ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
2568
|
+
ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
|
2513
2569
|
|
2514
|
-
|
2570
|
+
cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
|
2515
2571
|
|
2516
|
-
|
2517
|
-
|
2518
|
-
|
2519
|
-
|
2520
|
-
|
2521
|
-
|
2522
|
-
|
2523
|
-
|
2524
|
-
|
2525
|
-
|
2526
|
-
|
2527
|
-
|
2528
|
-
|
2529
|
-
|
2572
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
2573
|
+
} else {
|
2574
|
+
struct ggml_tensor * Kcross =
|
2575
|
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2576
|
+
n_state_head, n_audio_ctx, n_head,
|
2577
|
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2578
|
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
2579
|
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
2580
|
+
|
2581
|
+
struct ggml_tensor * Vcross =
|
2582
|
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2583
|
+
n_audio_ctx, n_state_head, n_head,
|
2584
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
2585
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
2586
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
2587
|
+
|
2588
|
+
// ------
|
2589
|
+
|
2590
|
+
// K * Q
|
2591
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
|
2592
|
+
|
2593
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
2594
|
+
|
2595
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
2596
|
+
if (wctx.params.dtw_token_timestamps) {
|
2597
|
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
2598
|
+
struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
2599
|
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
2600
|
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
2601
|
+
aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
2602
|
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
2603
|
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
2604
|
+
aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
2605
|
+
if (aheads_cross_QKs == NULL) {
|
2606
|
+
aheads_cross_QKs = aheads_KQs;
|
2607
|
+
} else {
|
2608
|
+
aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
|
2609
|
+
}
|
2530
2610
|
}
|
2531
2611
|
}
|
2532
|
-
}
|
2533
2612
|
|
2534
|
-
|
2613
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
|
2535
2614
|
|
2536
|
-
|
2615
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2537
2616
|
|
2538
|
-
|
2539
|
-
|
2540
|
-
KQV_merged,
|
2541
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
2617
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
2618
|
+
}
|
2542
2619
|
}
|
2543
2620
|
|
2544
2621
|
// projection
|
@@ -2671,18 +2748,20 @@ static bool whisper_decode_internal(
|
|
2671
2748
|
return false;
|
2672
2749
|
}
|
2673
2750
|
|
2674
|
-
|
2751
|
+
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
|
2752
|
+
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
|
2753
|
+
|
2675
2754
|
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
2676
2755
|
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
2677
2756
|
}
|
2678
2757
|
|
2679
2758
|
// decoder
|
2680
2759
|
{
|
2681
|
-
auto &
|
2760
|
+
auto & sched = wstate.sched_decode.sched;
|
2682
2761
|
|
2683
2762
|
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
|
2684
2763
|
|
2685
|
-
if (!
|
2764
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2686
2765
|
// should never happen as we pre-allocate the memory
|
2687
2766
|
return false;
|
2688
2767
|
}
|
@@ -2705,9 +2784,10 @@ static bool whisper_decode_internal(
|
|
2705
2784
|
struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
|
2706
2785
|
|
2707
2786
|
auto & kv_self = wstate.kv_self;
|
2708
|
-
const int32_t n_kv = kv_self.n;
|
2709
2787
|
|
2710
|
-
|
2788
|
+
const int32_t n_kv = kv_self.n;
|
2789
|
+
|
2790
|
+
wstate.inp_mask.resize(ggml_nelements(KQ_mask));
|
2711
2791
|
|
2712
2792
|
float * data = wstate.inp_mask.data();
|
2713
2793
|
memset(data, 0, ggml_nbytes(KQ_mask));
|
@@ -2723,14 +2803,20 @@ static bool whisper_decode_internal(
|
|
2723
2803
|
}
|
2724
2804
|
}
|
2725
2805
|
}
|
2806
|
+
|
2807
|
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
2808
|
+
for (int j = 0; j < n_kv; ++j) {
|
2809
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
2810
|
+
}
|
2811
|
+
}
|
2726
2812
|
}
|
2727
2813
|
|
2728
2814
|
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
|
2729
2815
|
}
|
2730
2816
|
|
2731
|
-
logits = gf
|
2817
|
+
logits = ggml_graph_node(gf, -1);
|
2732
2818
|
|
2733
|
-
if (!ggml_graph_compute_helper(
|
2819
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2734
2820
|
return false;
|
2735
2821
|
}
|
2736
2822
|
}
|
@@ -2784,29 +2870,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
2784
2870
|
}
|
2785
2871
|
|
2786
2872
|
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
2787
|
-
|
2788
|
-
|
2873
|
+
namespace {
|
2874
|
+
struct whisper_global_cache {
|
2875
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
2876
|
+
// We can use precalculated values to speed up the process.
|
2877
|
+
float sin_vals[SIN_COS_N_COUNT];
|
2878
|
+
float cos_vals[SIN_COS_N_COUNT];
|
2879
|
+
|
2880
|
+
// Hann window (Use cosf to eliminate difference)
|
2881
|
+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
2882
|
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
2883
|
+
float hann_window[WHISPER_N_FFT];
|
2884
|
+
|
2885
|
+
whisper_global_cache() {
|
2886
|
+
fill_sin_cos_table();
|
2887
|
+
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
2888
|
+
}
|
2889
|
+
|
2890
|
+
void fill_sin_cos_table() {
|
2891
|
+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
2892
|
+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
2893
|
+
sin_vals[i] = sinf(theta);
|
2894
|
+
cos_vals[i] = cosf(theta);
|
2895
|
+
}
|
2896
|
+
}
|
2789
2897
|
|
2790
|
-
|
2791
|
-
|
2792
|
-
|
2793
|
-
|
2794
|
-
|
2795
|
-
|
2796
|
-
|
2797
|
-
|
2798
|
-
cos_vals[i] = cosf(theta);
|
2898
|
+
void fill_hann_window(int length, bool periodic, float * output) {
|
2899
|
+
int offset = -1;
|
2900
|
+
if (periodic) {
|
2901
|
+
offset = 0;
|
2902
|
+
}
|
2903
|
+
for (int i = 0; i < length; i++) {
|
2904
|
+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
2905
|
+
}
|
2799
2906
|
}
|
2800
|
-
|
2907
|
+
} global_cache;
|
2801
2908
|
}
|
2802
2909
|
|
2803
2910
|
// naive Discrete Fourier Transform
|
2804
2911
|
// input is real-valued
|
2805
2912
|
// output is complex-valued
|
2806
|
-
static void dft(const
|
2807
|
-
int N = in.size();
|
2808
|
-
|
2809
|
-
out.resize(N*2);
|
2913
|
+
static void dft(const float* in, int N, float* out) {
|
2810
2914
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
2811
2915
|
|
2812
2916
|
for (int k = 0; k < N; k++) {
|
@@ -2815,8 +2919,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2815
2919
|
|
2816
2920
|
for (int n = 0; n < N; n++) {
|
2817
2921
|
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
2818
|
-
re += in[n]*cos_vals[idx]; // cos(t)
|
2819
|
-
im -= in[n]*sin_vals[idx]; // sin(t)
|
2922
|
+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
2923
|
+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
2820
2924
|
}
|
2821
2925
|
|
2822
2926
|
out[k*2 + 0] = re;
|
@@ -2828,47 +2932,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2828
2932
|
// poor man's implementation - use something better
|
2829
2933
|
// input is real-valued
|
2830
2934
|
// output is complex-valued
|
2831
|
-
static void fft(
|
2832
|
-
out.resize(in.size()*2);
|
2833
|
-
|
2834
|
-
int N = in.size();
|
2835
|
-
|
2935
|
+
static void fft(float* in, int N, float* out) {
|
2836
2936
|
if (N == 1) {
|
2837
2937
|
out[0] = in[0];
|
2838
2938
|
out[1] = 0;
|
2839
2939
|
return;
|
2840
2940
|
}
|
2841
2941
|
|
2842
|
-
|
2843
|
-
|
2942
|
+
const int half_N = N / 2;
|
2943
|
+
if (N - half_N*2 == 1) {
|
2944
|
+
dft(in, N, out);
|
2844
2945
|
return;
|
2845
2946
|
}
|
2846
2947
|
|
2847
|
-
|
2848
|
-
|
2849
|
-
|
2850
|
-
even.reserve(N/2);
|
2851
|
-
odd.reserve(N/2);
|
2852
|
-
|
2853
|
-
for (int i = 0; i < N; i++) {
|
2854
|
-
if (i % 2 == 0) {
|
2855
|
-
even.push_back(in[i]);
|
2856
|
-
} else {
|
2857
|
-
odd.push_back(in[i]);
|
2858
|
-
}
|
2948
|
+
float* even = in + N;
|
2949
|
+
for (int i = 0; i < half_N; ++i) {
|
2950
|
+
even[i]= in[2*i];
|
2859
2951
|
}
|
2952
|
+
float* even_fft = out + 2 * N;
|
2953
|
+
fft(even, half_N, even_fft);
|
2860
2954
|
|
2861
|
-
|
2862
|
-
|
2863
|
-
|
2864
|
-
|
2865
|
-
|
2955
|
+
float* odd = even;
|
2956
|
+
for (int i = 0; i < half_N; ++i) {
|
2957
|
+
odd[i] = in[2*i + 1];
|
2958
|
+
}
|
2959
|
+
float* odd_fft = even_fft + N;
|
2960
|
+
fft(odd, half_N, odd_fft);
|
2866
2961
|
|
2867
2962
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
2868
|
-
for (int k = 0; k <
|
2963
|
+
for (int k = 0; k < half_N; k++) {
|
2869
2964
|
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
2870
|
-
float re = cos_vals[idx]; // cos(t)
|
2871
|
-
float im = -sin_vals[idx]; // sin(t)
|
2965
|
+
float re = global_cache.cos_vals[idx]; // cos(t)
|
2966
|
+
float im = -global_cache.sin_vals[idx]; // sin(t)
|
2872
2967
|
|
2873
2968
|
float re_odd = odd_fft[2*k + 0];
|
2874
2969
|
float im_odd = odd_fft[2*k + 1];
|
@@ -2876,52 +2971,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2876
2971
|
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
2877
2972
|
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
2878
2973
|
|
2879
|
-
out[2*(k +
|
2880
|
-
out[2*(k +
|
2881
|
-
}
|
2882
|
-
}
|
2883
|
-
|
2884
|
-
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
2885
|
-
if (output.size() < static_cast<size_t>(length)) {
|
2886
|
-
output.resize(length);
|
2974
|
+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
2975
|
+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
2887
2976
|
}
|
2888
|
-
int offset = -1;
|
2889
|
-
if (periodic) {
|
2890
|
-
offset = 0;
|
2891
|
-
}
|
2892
|
-
for (int i = 0; i < length; i++) {
|
2893
|
-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
2894
|
-
}
|
2895
|
-
|
2896
|
-
return true;
|
2897
2977
|
}
|
2898
2978
|
|
2899
|
-
static void log_mel_spectrogram_worker_thread(int ith, const
|
2979
|
+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
2900
2980
|
int n_samples, int frame_size, int frame_step, int n_threads,
|
2901
2981
|
const whisper_filters & filters, whisper_mel & mel) {
|
2902
|
-
std::vector<float> fft_in(frame_size, 0.0);
|
2903
|
-
std::vector<float> fft_out(2 *
|
2982
|
+
std::vector<float> fft_in(frame_size * 2, 0.0);
|
2983
|
+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
2984
|
+
|
2904
2985
|
int n_fft = filters.n_fft;
|
2905
2986
|
int i = ith;
|
2906
2987
|
|
2907
2988
|
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
2908
|
-
assert(
|
2909
|
-
|
2989
|
+
assert(n_fft == 1 + (frame_size / 2));
|
2990
|
+
|
2910
2991
|
// calculate FFT only when fft_in are not all zero
|
2911
2992
|
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
2912
2993
|
const int offset = i * frame_step;
|
2913
2994
|
|
2914
|
-
// apply
|
2995
|
+
// apply Hann window (~10% faster)
|
2915
2996
|
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
2916
2997
|
fft_in[j] = hann[j] * samples[offset + j];
|
2917
2998
|
}
|
2999
|
+
|
2918
3000
|
// fill the rest with zeros
|
2919
3001
|
if (n_samples - offset < frame_size) {
|
2920
3002
|
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
|
2921
3003
|
}
|
2922
3004
|
|
2923
3005
|
// FFT
|
2924
|
-
fft(fft_in, fft_out);
|
3006
|
+
fft(fft_in.data(), frame_size, fft_out.data());
|
2925
3007
|
|
2926
3008
|
// Calculate modulus^2 of complex numbers
|
2927
3009
|
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
@@ -2932,7 +3014,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
         // mel spectrogram
         for (int j = 0; j < mel.n_mel; j++) {
             double sum = 0.0;
-
             // unroll loop (suggested by GH user @lunixbochs)
             int k = 0;
             for (k = 0; k < n_fft - 3; k += 4) {
@@ -2942,14 +3023,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
             }
-
             // handle n_fft remainder
             for (; k < n_fft; k++) {
                 sum += fft_out[k] * filters.data[j * n_fft + k];
             }
-
             sum = log10(std::max(sum, 1e-10));
-
             mel.data[j * mel.n_len + i] = sum;
         }
     }
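For reference, each iteration of the loop above computes one mel bin as a dot product of the power spectrum with the corresponding filter row, floored and converted to log scale:

    mel.data[j*n_len + i] = log10(max(sum_k fft_out[k] * filters.data[j*n_fft + k], 1e-10))

The 4-way unrolled loop plus the remainder loop only change how the sum is accumulated, not its value.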
@@ -2978,12 +3056,9 @@ static bool log_mel_spectrogram(
               whisper_mel & mel) {
     const int64_t t_start_us = ggml_time_us();

-    //
-
-
-    std::vector<float> hann;
-    hann_window(frame_size, true, hann);
-
+    // Hann window
+    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+    const float * hann = global_cache.hann_window;

     // Calculate the length of padding
     int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3008,12 +3083,11 @@ static bool log_mel_spectrogram(
     mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
     mel.data.resize(mel.n_mel * mel.n_len);

-
     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
                     n_samples + stage_2_pad, frame_size, frame_step, n_threads,
                     std::cref(filters), std::ref(mel));
         }
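In the 1.3.0 sources every call to log_mel_spectrogram() rebuilt its Hann window with hann_window(), and whisper_init_state() filled the FFT sin/cos tables via fill_sin_cos_table() (both removals are visible in this diff). In 1.3.1 these constants come from a process-wide global_cache filled once. The cache definition itself is outside this excerpt; the sketch below only illustrates the idea, with the table sizes (WHISPER_N_FFT = 400) assumed rather than taken from the diff:

```cpp
#include <cmath>

#define SIN_COS_N_COUNT 400   // assumed: whisper.cpp ties this to WHISPER_N_FFT
#define N_FFT           400   // assumed: WHISPER_N_FFT

// Conceptual sketch of a precomputed-table cache (illustrative, not the actual definition).
struct mel_cache {
    float sin_vals[SIN_COS_N_COUNT];
    float cos_vals[SIN_COS_N_COUNT];
    float hann_window[N_FFT];

    mel_cache() {
        const float pi = 3.14159265358979f;
        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
            const float theta = (2.0f*pi*i)/SIN_COS_N_COUNT;
            sin_vals[i] = sinf(theta);
            cos_vals[i] = cosf(theta);
        }
        for (int i = 0; i < N_FFT; i++) {
            // periodic Hann window, same formula the removed hann_window() used
            hann_window[i] = 0.5f*(1.0f - cosf((2.0f*pi*i)/N_FFT));
        }
    }
};

static mel_cache global_cache_sketch; // constructed once, read by all worker threads
```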
@@ -3173,23 +3247,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 #endif

 struct whisper_state * whisper_init_state(whisper_context * ctx) {
-    fill_sin_cos_table();
-
     whisper_state * state = new whisper_state;

-    state->
-    if (
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty()) {
         WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }

-    // at this point, we don't know yet how many decoders will be used
-    //
-
-
-
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
+    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3199,8 +3273,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
     }

-    if (!
-
+    if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3210,9 +3287,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }

+    if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_audio_state,
+                1,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+        WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
+    }
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks,
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
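whisper_init_state() now allocates three KV caches (kv_self, kv_cross and the new kv_pad) on state->backends[0], and each context length is rounded up with GGML_PAD(x, 256). ggml's GGML_PAD is a round-up-to-multiple macro, so with the usual Whisper hyperparameters n_text_ctx = 448 and n_audio_ctx = 1500 the padded cache sizes are 512 and 1536 entries (the concrete hparam values are stated here from the upstream models, not from this diff):

```cpp
// Round x up to the next multiple of n (n a power of two), as in ggml.h.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

static_assert(GGML_PAD(448,  256) == 512,  "n_text_ctx padded to a multiple of 256");
static_assert(GGML_PAD(1500, 256) == 1536, "n_audio_ctx padded to a multiple of 256");
```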
@@ -3255,7 +3346,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

     // conv allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state);
                 });
@@ -3266,12 +3357,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }

-        WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
     }

     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3282,12 +3373,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }

-        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
     }

     // cross allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3298,12 +3389,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }

-        WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
     }

     // decoder allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;

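All four allocators (conv, encode, cross, decode) are now set up through whisper_sched_graph_init(), which receives the backend list and a lambda that builds the worst-case graph, and the logged compute-buffer sizes come from whisper_sched_size(). Neither helper's body is part of this excerpt; the sketch below shows the reserve-and-measure pattern they imply, using the ggml-backend scheduler API bundled with this gem (the holder struct, MAX_NODES value and exact wiring are assumptions):

```cpp
#include <cstdint>
#include <functional>
#include <vector>

#include "ggml.h"
#include "ggml-backend.h"

constexpr int MAX_NODES = 4096; // assumed graph-size limit

struct sched_holder {
    ggml_backend_sched_t sched = nullptr;
    std::vector<uint8_t>  meta;  // backing storage for the measurement graph's metadata
};

static bool sched_graph_init_sketch(sched_holder & allocr,
                                    std::vector<ggml_backend_t> backends,
                                    std::function<ggml_cgraph * ()> get_graph) {
    allocr.meta.resize(ggml_tensor_overhead()*MAX_NODES + ggml_graph_overhead());

    allocr.sched = ggml_backend_sched_new(backends.data(), nullptr,
                                          (int) backends.size(), MAX_NODES, false);

    // build the worst-case graph once; the scheduler reserves its compute buffers for it
    return ggml_backend_sched_reserve(allocr.sched, get_graph());
}

static size_t sched_size_sketch(sched_holder & allocr, const std::vector<ggml_backend_t> & backends) {
    size_t size = 0;
    for (ggml_backend_t backend : backends) {
        size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
    }
    return size; // the "compute buffer (...)" log lines print a value like this divided by 1e6
}
```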
@@ -3322,19 +3413,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }

-        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
     }

     return state;
 }

-int
+int whisper_ctx_init_openvino_encoder_with_state(
         struct whisper_context * ctx,
+        struct whisper_state * state,
         const char * model_path,
         const char * device,
         const char * cache_dir) {
 #ifndef WHISPER_USE_OPENVINO
     (void)(ctx);
+    (void)(state);
     (void)(model_path);
     (void)(device);
     (void)(cache_dir);
@@ -3365,8 +3458,8 @@ int whisper_ctx_init_openvino_encoder(
     WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
     WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

-
-    if (!
+    state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+    if (!state->ctx_openvino) {
         WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
@@ -3377,9 +3470,18 @@ int whisper_ctx_init_openvino_encoder(
 #endif
 }

+int whisper_ctx_init_openvino_encoder(
+        struct whisper_context * ctx,
+        const char * model_path,
+        const char * device,
+        const char * cache_dir) {
+    return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
+}
+
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu =*/ true,
+        /*.flash_attn =*/ false,
         /*.gpu_device =*/ 0,

         /*.dtw_token_timestamps =*/ false,
@@ -3396,8 +3498,14 @@ struct whisper_context_params whisper_context_default_params() {

 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
     WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
-
+#ifdef _MSC_VER
+    // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring path_model_wide = converter.from_bytes(path_model);
+    auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
     auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
     if (!fin) {
         WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
@@ -3472,6 +3580,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();

+    if (params.flash_attn && params.dtw_token_timestamps) {
+        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+        params.dtw_token_timestamps = false;
+    }
+
+    WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
+    WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+    WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;

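whisper_init_with_params_no_state() now reconciles the new flash_attn flag with dtw_token_timestamps and logs the effective configuration. From the caller's perspective flash_attn is just another field of whisper_context_params; a hedged usage sketch (the model filename is an example, and whisper_init_from_file_with_params is the usual whisper.h entry point assumed here):

```cpp
#include "whisper.h"

int main() {
    whisper_context_params cparams = whisper_context_default_params();
    cparams.flash_attn           = true;   // new in this release
    cparams.dtw_token_timestamps = false;  // would be force-disabled anyway when flash_attn is on

    whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
    if (ctx == nullptr) {
        return 1;
    }
    // ... whisper_full(...), read segments, etc.
    whisper_free(ctx);
    return 0;
}
```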
@@ -3558,8 +3678,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa

 void whisper_free_state(struct whisper_state * state) {
     if (state) {
-
-
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);

 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3577,12 +3698,14 @@ void whisper_free_state(struct whisper_state * state) {

         whisper_batch_free(state->batch);

-
-
-
-
+        ggml_backend_sched_free(state->sched_conv.sched);
+        ggml_backend_sched_free(state->sched_encode.sched);
+        ggml_backend_sched_free(state->sched_cross.sched);
+        ggml_backend_sched_free(state->sched_decode.sched);

-
+        for (auto & backend : state->backends) {
+            ggml_backend_free(backend);
+        }

         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
@@ -3599,8 +3722,6 @@ void whisper_free(struct whisper_context * ctx) {

         whisper_free_state(ctx->state);

-        ggml_backend_free(ctx->backend);
-
         delete ctx;
     }
 }
@@ -3630,30 +3751,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
     return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }

-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
-int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
-        return -1;
-    }
-
-    return 0;
-}
-
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
-int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
-}
-
-// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
-// TODO
-
-// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
-// TODO
-
-// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
-// TODO
-
 int whisper_set_mel_with_state(
         struct whisper_context * ctx,
         struct whisper_state * state,
@@ -3742,7 +3839,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
     return -whisper_tokenize(ctx, text, NULL, 0);
 }

-int whisper_lang_max_id() {
+int whisper_lang_max_id(void) {
     auto max_id = 0;
     for (const auto & kv : g_lang) {
         max_id = std::max(max_id, kv.second.first);
@@ -4011,6 +4108,19 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
     return ctx->vocab.token_transcribe;
 }

+struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
+    if (ctx->state == nullptr) {
+        return nullptr;
+    }
+    whisper_timings * timings = new whisper_timings;
+    timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
+    timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
+    timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
+    timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
+    timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
+    return timings;
+}
+
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();

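whisper_get_timings() is new in this version: it returns per-stage averages in milliseconds computed from the state counters, or nullptr when no state exists. The struct is allocated with new and no matching free helper appears in this diff, so in this sketch the caller releases it directly (print_timings_sketch is a hypothetical helper, not part of the library):

```cpp
#include <cstdio>
#include "whisper.h"

void print_timings_sketch(whisper_context * ctx) {
    whisper_timings * t = whisper_get_timings(ctx);
    if (t == nullptr) {
        return; // context has no state yet
    }
    printf("sample: %.2f ms, encode: %.2f ms, decode: %.2f ms, batchd: %.2f ms, prompt: %.2f ms\n",
           t->sample_ms, t->encode_ms, t->decode_ms, t->batchd_ms, t->prompt_ms);
    delete t; // allocated with new in whisper_get_timings(); released by the caller here
}
```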
@@ -4078,17 +4188,14 @@ const char * whisper_print_system_info(void) {
     s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
     s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
     s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
     s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
     s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
     s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
     s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
     s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-    s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
     s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
-    s += "OPENVINO = " + std::to_string(whisper_has_openvino())
+    s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";

     return s.c_str();
 }
@@ -4099,7 +4206,7 @@ const char * whisper_print_system_info(void) {

 // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
 // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
         const char * src,
         whisper_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4513,7 +4620,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar

 ////////////////////////////////////////////////////////////////////////////

-struct whisper_context_params * whisper_context_default_params_by_ref() {
+struct whisper_context_params * whisper_context_default_params_by_ref(void) {
     struct whisper_context_params params = whisper_context_default_params();

     struct whisper_context_params* result = new whisper_context_params();
@@ -4554,7 +4661,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.split_on_word =*/ false,
         /*.max_tokens =*/ 0,

-        /*.speed_up =*/ false,
         /*.debug_mode =*/ false,
         /*.audio_ctx =*/ 0,

@@ -4720,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
     "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
 };

+static void whisper_compute_logprobs(
+        const std::vector<float> & logits,
+        const int n_logits,
+        std::vector<float> & logprobs) {
+    const float logit_max = *std::max_element(logits.begin(), logits.end());
+    float logsumexp = 0.0f;
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logsumexp += expf(logits[i] - logit_max);
+        }
+    }
+    logsumexp = logf(logsumexp) + logit_max;
+
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logprobs[i] = logits[i] - logsumexp;
+        } else {
+            logprobs[i] = -INFINITY;
+        }
+    }
+}
+
+static void whisper_compute_probs(
+        const std::vector<float> & logits,
+        const int n_logits,
+        const std::vector<float> & logprobs,
+        std::vector<float> & probs) {
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] == -INFINITY) {
+            probs[i] = 0.0f;
+        } else {
+            probs[i] = expf(logprobs[i]);
+        }
+    }
+}
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
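whisper_compute_logprobs() above is the usual numerically stable log-softmax: it subtracts the maximum logit m = max_j logit[j] before exponentiating, so that

    logprob[i] = logit[i] - (log(sum_j exp(logit[j] - m)) + m)

and whisper_compute_probs() then recovers prob[i] = exp(logprob[i]), mapping masked (-INFINITY) logits to probability 0. These are the same loops that previously lived inline in whisper_process_logits() (removed in the hunks below); factoring them out lets the new no-speech-probability code later in this diff reuse them.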
@@ -4781,7 +4923,7 @@ static void whisper_process_logits(

     // suppress sot and nosp tokens
     logits[vocab.token_sot] = -INFINITY;
-    logits[vocab.token_nosp] = -INFINITY;
+    logits[vocab.token_nosp] = -INFINITY;

     // [TDRZ] when tinydiarize is disabled, suppress solm token
     if (params.tdrz_enable == false) {
@@ -4880,24 +5022,7 @@ static void whisper_process_logits(
         }

         // populate the logprobs array (log_softmax)
-
-            const float logit_max = *std::max_element(logits.begin(), logits.end());
-            float logsumexp = 0.0f;
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logsumexp += expf(logits[i] - logit_max);
-                }
-            }
-            logsumexp = logf(logsumexp) + logit_max;
-
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] > -INFINITY) {
-                    logprobs[i] = logits[i] - logsumexp;
-                } else {
-                    logprobs[i] = -INFINITY;
-                }
-            }
-        }
+        whisper_compute_logprobs(logits, n_logits, logprobs);

         // if sum of probability over timestamps is above any other token, sample timestamp
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -4955,15 +5080,7 @@ static void whisper_process_logits(
         }

         // compute probs
-
-            for (int i = 0; i < n_logits; ++i) {
-                if (logits[i] == -INFINITY) {
-                    probs[i] = 0.0f;
-                } else {
-                    probs[i] = expf(logprobs[i]);
-                }
-            }
-        }
+        whisper_compute_probs(logits, n_logits, logprobs, probs);

 #if 0
         // print first 100 logits - token string : logit
@@ -5228,15 +5345,9 @@ int whisper_full_with_state(

     if (n_samples > 0) {
         // compute log mel spectrogram
-        if (params.
-            // TODO: Replace PV with more advanced algorithm
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-            return -
-        } else {
-            if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-                return -2;
-            }
+            return -2;
         }
     }

@@ -5273,7 +5384,7 @@ int whisper_full_with_state(
     // if length of spectrogram is less than 1.0s (100 frames), then return
     // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (seek_end < seek_start +
+    if (seek_end < seek_start + 100) {
         WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
         return 0;
     }
@@ -5518,13 +5629,46 @@ int whisper_full_with_state(
         }
         WHISPER_LOG_DEBUG("\n\n");

+        // recreate the KV cache if the number of decoders has changed
+        if (state->kv_self_n_dec < n_decoders_cur) {
+            WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+            whisper_kv_cache_free(state->kv_self);
+
+            // overallocate to workaround KV cache fragmentation issues
+            const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+            if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                        ctx->model.hparams.n_text_state,
+                        ctx->model.hparams.n_text_layer,
+                        GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                whisper_free_state(state);
+                return -7;
+            }
+
+            state->kv_self_n_dec = n_decoders_cur;
+        }
+
         whisper_kv_cache_clear(state->kv_self);

         whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);

         if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
             WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-            return -
+            return -8;
+        }
+
+        // Calculate no_speech probability after first decode.
+        // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+        {
+            const int n_logits = ctx->vocab.id_to_token.size();
+            std::vector<float> logprobs(n_logits);
+            std::vector<float> probs(n_logits);
+
+            whisper_compute_logprobs(state->logits, n_logits, logprobs);
+            whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+            state->no_speech_prob = probs[whisper_token_nosp(ctx)];
         }

         {
@@ -5824,7 +5968,7 @@ int whisper_full_with_state(

                 if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                    return -
+                    return -9;
                 }

                 const int64_t t_start_sample_us = ggml_time_us();
@@ -5918,8 +6062,9 @@ int whisper_full_with_state(
             if (it != (int) temperatures.size() - 1) {
                 const auto & decoder = state->decoders[best_decoder_id];

-                if (decoder.failed ||
-
+                if (decoder.failed ||
+                    (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
                     success = false;
                     state->n_fail_p++;
                 }
@@ -5940,7 +6085,7 @@ int whisper_full_with_state(
         {
             const auto & best_decoder = state->decoders[best_decoder_id];

-
+            auto seek_delta = best_decoder.seek_delta;
             const auto result_len = best_decoder.sequence.result_len;

             const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -5948,6 +6093,9 @@ int whisper_full_with_state(
             // [EXPERIMENTAL] Token-level timestamps with DTW
             const auto n_segments_before = state->result_all.size();

+            const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+                                       best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);

             // update prompt_past
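The new is_no_speech flag combines the two thresholds from whisper_full_params: a window is treated as silence only when the probability of the no-speech token exceeds no_speech_thold and the decoded sequence's average log-probability is below logprob_thold. When it is set, the loops that follow skip updating prompt_past and skip emitting segments for the window. A worked example with the library's default thresholds (values taken from whisper_full_default_params, not from this excerpt):

```cpp
// Same gating logic as the diff above, with defaults spelled out for illustration.
bool is_no_speech_example(float no_speech_prob, float avg_logprobs) {
    const float no_speech_thold = 0.6f;  // assumed default
    const float logprob_thold   = -1.0f; // assumed default
    // e.g. no_speech_prob = 0.82 and avg_logprobs = -1.4 -> true: the window yields no segments
    return no_speech_prob > no_speech_thold && avg_logprobs < logprob_thold;
}
```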
@@ -5956,11 +6104,11 @@ int whisper_full_with_state(
                 prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }

-            for (int i = 0; i < result_len; ++i) {
+            for (int i = 0; i < result_len && !is_no_speech; ++i) {
                 prompt_past.push_back(tokens_cur[i].id);
             }

-            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
                 int i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));

@@ -5985,8 +6133,8 @@ int whisper_full_with_state(
                     const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));

                     if (!text.empty()) {
-                        const auto tt0 =
-                        const auto tt1 =
+                        const auto tt0 = t0;
+                        const auto tt1 = t1;

                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -6014,7 +6162,7 @@ int whisper_full_with_state(
                                 n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                             }
                         }
-                        if (params.new_segment_callback) {
+                        if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                             params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                         }
                     }
@@ -6032,8 +6180,8 @@ int whisper_full_with_state(
                 if (!text.empty()) {
                     const auto t1 = seek + seek_delta;

-                    const auto tt0 =
-                    const auto tt1 =
+                    const auto tt0 = t0;
+                    const auto tt1 = t1;

                     if (params.print_realtime) {
                         if (params.print_timestamps) {
@@ -6059,7 +6207,7 @@ int whisper_full_with_state(
                             n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                         }
                     }
-                    if (params.new_segment_callback) {
+                    if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                         params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                     }
                 }
@@ -6068,14 +6216,28 @@ int whisper_full_with_state(
             // FIXME: will timestamp offsets be correct?
             // [EXPERIMENTAL] Token-level timestamps with DTW
             {
-                const
+                const int n_segments = state->result_all.size() - n_segments_before;
                 if (ctx->params.dtw_token_timestamps && n_segments) {
                     const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
                     whisper_exp_compute_token_level_timestamps_dtw(
                             ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                    if (params.new_segment_callback) {
+                        for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+                            params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+                        }
+                    }
                 }
             }

+            // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
+            const bool single_timestamp_ending = tokens_cur.size() > 1 &&
+                                                 tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
+                                                 tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
+            if (single_timestamp_ending) {
+                WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
+                seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
+            }
+
             // update audio window
             seek += seek_delta;

@@ -6835,7 +6997,7 @@ static void whisper_exp_compute_token_level_timestamps(
                 k++;
             }
             tokens[j].t1 = sample_to_timestamp(k);
-            if (j <
+            if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                 tokens[j].t1 = tokens[j + 1].t0;
             } else {
                 s1 = k;
@@ -6998,10 +7160,11 @@ struct median_filter_user_data {
     int filter_width;
 };

-static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth
+static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+    if (ith != 0) {
+        return;
+    }
     int filter_width = ((median_filter_user_data *) userdata)->filter_width;
-    WHISPER_ASSERT(nth == 1);
-    WHISPER_ASSERT(ith == 0);
     WHISPER_ASSERT(filter_width < a->ne[2]);
     WHISPER_ASSERT(filter_width % 2);
     WHISPER_ASSERT(ggml_n_dims(a) == 3);
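median_filter() now simply returns on every thread other than the first instead of asserting nth == 1, since the custom ggml op may be scheduled on several threads while only thread 0 does the work. For reference, a freestanding sketch of a median filter over a 1-D array with edge clamping (illustrative only; the whisper.cpp version operates on a 3-D ggml tensor):

```cpp
#include <algorithm>
#include <vector>

// Median-filter a 1-D signal with an odd window width, clamping indices at the edges.
std::vector<float> median_filter_1d(const std::vector<float> & x, int width) {
    const int n    = (int) x.size();
    const int half = width / 2;
    std::vector<float> out(n);
    std::vector<float> window(width);

    for (int i = 0; i < n; i++) {
        for (int k = 0; k < width; k++) {
            const int idx = std::min(std::max(i + k - half, 0), n - 1);
            window[k] = x[idx];
        }
        std::nth_element(window.begin(), window.begin() + half, window.end());
        out[i] = window[half];
    }
    return out;
}
```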
@@ -7124,7 +7287,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // operation (after median filter)
     // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
-    w = ggml_norm(gctx, w, 1e-
+    w = ggml_norm(gctx, w, 1e-9f);
     w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);

     // Pass median filter - this is done over AUDIO_TOKENS dimension.
@@ -7196,6 +7359,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
+    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 }

 GGML_ATTRIBUTE_FORMAT(2, 3)
@@ -7219,6 +7383,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
 static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
+#ifndef WHISPER_DEBUG
+    if (level == GGML_LOG_LEVEL_DEBUG) {
+        return;
+    }
+#endif
     fputs(text, stderr);
     fflush(stderr);
 }