whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
--- data/ext/whisper.cpp
+++ data/ext/src/whisper.cpp
@@ -1,29 +1,19 @@
 #include "whisper.h"

-#
-#include "coreml/whisper-encoder.h"
-#endif
+#include "ggml-cpu.h"

-#
-#include "ggml-
-#
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"

-#ifdef
-#include "
+#ifdef WHISPER_USE_COREML
+#include "coreml/whisper-encoder.h"
 #endif

 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif

-#include "ggml.h"
-#include "ggml-alloc.h"
-#include "ggml-backend.h"
-
 #include <atomic>
 #include <algorithm>
 #include <cassert>
@@ -41,6 +31,7 @@
 #include <regex>
 #include <random>
 #include <functional>
+#include <codecvt>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -147,8 +138,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 } \
 } while (0)

-//#define WHISPER_USE_FLASH_ATTN
-//#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096

@@ -162,7 +151,7 @@ static bool ggml_graph_compute_helper(
 int n_threads,
 ggml_abort_callback abort_callback,
 void * abort_callback_data) {
-struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);

 plan.abort_callback = abort_callback;
 plan.abort_callback_data = abort_callback_data;
@@ -176,18 +165,24 @@ static bool ggml_graph_compute_helper(
 }

 static bool ggml_graph_compute_helper(
-
+ggml_backend_sched_t sched,
 struct ggml_cgraph * graph,
 int n_threads) {
-
-
-
-
-
-
+
+for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+if (fn_set_n_threads) {
+fn_set_n_threads(backend, n_threads);
+}
 }
-
-
+
+bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+ggml_backend_sched_reset(sched);
+return t;
 }

 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -363,6 +358,7 @@ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15},
 static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
 static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
 static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };

 static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
 { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
@@ -376,6 +372,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
 { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
 { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
 { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
+{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
 };

 static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
@@ -502,33 +499,41 @@ struct whisper_pair {
 whisper_pair() : first(A()), second(B()) {}
 };

-//
-struct
-
+// ggml_backend_sched wrapper for whisper usage
+struct whisper_sched {
+ggml_backend_sched_t sched = nullptr;

 std::vector<uint8_t> meta;
 };

-static size_t
-
+static size_t whisper_sched_size(struct whisper_sched & allocr) {
+size_t size = allocr.meta.size();
+for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+}
+return size;
 }

 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static bool
-auto &
+static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+auto & sched = allocr.sched;
 auto & meta = allocr.meta;

-
+sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);

 meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());

 // since there are dependencies between the different graphs,
 // we need to allocate them instead of only reserving to get the correct compute buffer size
-if (!
+if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
 // failed to allocate the compute buffer
 WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
 return false;
 }
+
+ggml_backend_sched_reset(sched);
+
 return true;
 }

@@ -671,9 +676,9 @@ struct whisper_kv_cache {
 struct ggml_tensor * k;
 struct ggml_tensor * v;

-struct ggml_context * ctx = nullptr;
-
 ggml_backend_buffer_t buffer = nullptr;
+
+std::vector<uint8_t> ctx_buf;
 };

 struct whisper_model {
@@ -802,6 +807,9 @@ struct whisper_state {
 int32_t n_fail_p = 0; // number of logprob threshold failures
 int32_t n_fail_h = 0; // number of entropy threshold failures

+// number of decoders for which we have constructed the KV cache
+int32_t kv_self_n_dec = 0;
+
 // unified self-attention KV cache for all decoders
 whisper_kv_cache kv_self;

@@ -809,21 +817,22 @@ struct whisper_state {
 // shared between all decoders
 whisper_kv_cache kv_cross;

+// padded buffer for flash-attention
+whisper_kv_cache kv_pad;
+
 whisper_mel mel;

 whisper_batch batch;

 whisper_decoder decoders[WHISPER_MAX_DECODERS];

-ggml_backend_t
+std::vector<ggml_backend_t> backends;

-// ggml-alloc:
 // - stores meta info about the intermediate tensors into the `meta` buffers
-
-
-
-
-whisper_allocr alloc_decode;
+whisper_sched sched_conv;
+whisper_sched sched_encode;
+whisper_sched sched_cross;
+whisper_sched sched_decode;

 // result of the encoder
 struct ggml_tensor * embd_conv = nullptr;
@@ -858,6 +867,7 @@ struct whisper_state {
 whisper_token tid_last;

 std::vector<float> energy; // PCM signal energy
+float no_speech_prob = 0.0f;

 // [EXPERIMENTAL] Token-level timestamps with DTW
 whisper_aheads_masks aheads_masks;
@@ -882,8 +892,6 @@ struct whisper_context {

 whisper_state * state = nullptr;

-ggml_backend_t backend = nullptr;
-
 std::string path_model; // populated by whisper_init_from_file_with_params()
 };

@@ -901,21 +909,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 BYTESWAP_VALUE(dest);
 }

-static bool
-const struct whisper_hparams & hparams,
+static bool whisper_kv_cache_init(
 struct whisper_kv_cache & cache,
 ggml_backend_t backend,
 ggml_type wtype,
+int64_t n_text_state,
+int64_t n_text_layer,
 int n_ctx) {
-const int64_t n_text_state = hparams.n_text_state;
-const int64_t n_text_layer = hparams.n_text_layer;
-
 const int64_t n_mem = n_text_layer*n_ctx;
 const int64_t n_elements = n_text_state*n_mem;

+cache.ctx_buf.resize(2*ggml_tensor_overhead());
+
 struct ggml_init_params params = {
-/*.mem_size =*/
-/*.mem_buffer =*/
+/*.mem_size =*/ cache.ctx_buf.size(),
+/*.mem_buffer =*/ cache.ctx_buf.data(),
 /*.no_alloc =*/ true,
 };

@@ -925,29 +933,31 @@ static bool kv_cache_init(
 cache.cells.clear();
 cache.cells.resize(n_ctx);

-
+struct ggml_context * ctx = ggml_init(params);

-if (!
+if (!ctx) {
 WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
 return false;
 }

-cache.k = ggml_new_tensor_1d(
-cache.v = ggml_new_tensor_1d(
+cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);

-cache.buffer = ggml_backend_alloc_ctx_tensors(
+cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
 if (!cache.buffer) {
 WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
 return false;
 }

+ggml_backend_buffer_clear(cache.buffer, 0);
+
+ggml_free(ctx);
+
 return true;
 }

-static void
-ggml_free(cache.ctx);
+static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
 ggml_backend_buffer_free(cache.buffer);
-cache.ctx = nullptr;
 }

 static bool whisper_kv_cache_find_slot(
@@ -1018,6 +1028,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
 cache.cells[i].seq_id.clear();
 }
 cache.head = 0;
+
+ggml_backend_buffer_clear(cache.buffer, 0);
 }

 static void whisper_kv_cache_seq_rm(
@@ -1068,6 +1080,26 @@ static void whisper_kv_cache_seq_cp(
 }
 }

+static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
+return 1u;
+}
+
+#ifdef GGML_USE_METAL
+if (wctx.params.use_gpu) {
+return 32u;
+}
+#endif
+
+#ifdef GGML_USE_CUDA
+if (wctx.params.use_gpu) {
+return 256u;
+}
+#endif
+
+return 1u;
+}
+
 // [EXPERIMENTAL] Token-level timestamps with DTW
 static bool aheads_masks_init(
 const whisper_context_params & cparams,
@@ -1199,49 +1231,71 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
 return size;
 }

-static ggml_backend_t
-
+static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

-// initialize the backends
-#ifdef GGML_USE_CUDA
 if (params.use_gpu) {
-
-
-
-
+for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+if (!result) {
+WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+}
+return result;
+}
 }
 }
-#endif

-
-
-
-
-
-
-
-
-
-
-backend_gpu = NULL;
-}
+return nullptr;
+}
+
+static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+std::vector<ggml_backend_t> result;
+
+ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
+
+if (backend_gpu) {
+result.push_back(backend_gpu);
 }
-#endif

-
-
-
-
-
-
+// ACCEL backends
+for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+if (!backend) {
+WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+continue;
+}
+result.push_back(backend);
 }
 }
-#endif

-
-
+GGML_UNUSED(params);
+
+result.push_back(ggml_backend_cpu_init());
+
+return result;
+}
+
+static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+if (!params.use_gpu) {
+return ggml_backend_cpu_buffer_type();
 }
-
+
+// if we have a GPU device - use it
+for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+return ggml_backend_dev_buffer_type(dev);
+}
+}
+
+return ggml_backend_cpu_buffer_type();
 }

 // load the model from a ggml file
@@ -1668,21 +1722,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 }
 }

-wctx.backend = whisper_backend_init(wctx.params);
-if (!wctx.backend) {
-WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
-return false;
-}
-
 // allocate tensors in the backend buffers
-model.buffer =
+model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
 if (!model.buffer) {
 WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
 return false;
 }

 size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__,
+WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);

 // load weights
 {
@@ -1777,6 +1825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 }
 }

+ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
 wctx.t_load_us = ggml_time_us() - t_start_us;

 return true;
@@ -1812,8 +1862,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 const int n_mels = hparams.n_mels;

 struct ggml_init_params params = {
-/*.mem_size =*/ wstate.
-/*.mem_buffer =*/ wstate.
+/*.mem_size =*/ wstate.sched_conv.meta.size(),
+/*.mem_buffer =*/ wstate.sched_conv.meta.data(),
 /*.no_alloc =*/ true,
 };

@@ -1847,6 +1897,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 ggml_build_forward_expand(gf, mel);

 cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+ggml_set_input(cur); // the external encoder will write into this tensor

 ggml_set_name(cur, "embd_enc");
 wstate.embd_enc = cur;
@@ -1872,9 +1923,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 const int n_head = hparams.n_audio_head;
 const int n_layer = hparams.n_audio_layer;

+const int n_state_head = n_state/n_head;
+
+auto & kv_pad = wstate.kv_pad;
+
+WHISPER_ASSERT(!!kv_pad.buffer);
+
+const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
 struct ggml_init_params params = {
-/*.mem_size =*/ wstate.
-/*.mem_buffer =*/ wstate.
+/*.mem_size =*/ wstate.sched_encode.meta.size(),
+/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
 /*.no_alloc =*/ true,
 };

@@ -1884,7 +1943,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

 struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);

-const float KQscale = 1.0f/sqrtf(float(
+const float KQscale = 1.0f/sqrtf(float(n_state_head));

 // ===================================================================
 // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1993,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

 Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);

-//Qcur = ggml_scale(ctx0, Qcur, pow(float(
+//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));

 // note: no bias for Key
 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
 layer.attn_k_w,
 cur);

-//Kcur = ggml_scale(ctx0, Kcur, pow(float(
+//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));

 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
 layer.attn_v_w,
@@ -1951,70 +2010,60 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

 // ------

-#ifdef WHISPER_USE_FLASH_ATTN
 struct ggml_tensor * Q =
 ggml_permute(ctx0,
-
-Qcur,
-ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
 0, 2, 1, 3);

-
-
-
-Kcur,
-ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-0, 2, 1, 3);
+if (wctx.params.flash_attn) {
+ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));

-
-
-
-
-
-
-1, 2, 0, 3),
-ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+struct ggml_tensor * K =
+ggml_view_3d(ctx0, kv_pad.k,
+n_state_head, n_ctx_pad, n_head,
+ggml_element_size(kv_pad.k)*n_state,
+ggml_element_size(kv_pad.k)*n_state_head,
+0);

-
-
-
-
-
-
-ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
-0, 2, 1, 3);
+struct ggml_tensor * V =
+ggml_view_3d(ctx0, kv_pad.v,
+n_state_head, n_ctx_pad, n_head,
+ggml_element_size(kv_pad.v)*n_state,
+ggml_element_size(kv_pad.v)*n_state_head,
+0);

-
-ggml_permute(ctx0,
-ggml_cpy(ctx0,
-Kcur,
-ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-0, 2, 1, 3);
+cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);

-
-
+cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+} else {
+struct ggml_tensor * K =
+ggml_permute(ctx0,
+ggml_cast(ctx0,
+ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+wctx.itype),
+0, 2, 1, 3);

-
+// K * Q
+struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

-
+struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);

-
-
-
-
-
-
-
-
-);
+struct ggml_tensor * V =
+ggml_cast(ctx0,
+ggml_permute(ctx0,
+ggml_reshape_3d(ctx0,
+Vcur,
+n_state_head, n_head, n_ctx),
+1, 2, 0, 3),
+wctx.itype);

-
-#endif
-struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

-
-
-
+struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
+}
 }

 // projection
@@ -2043,11 +2092,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 layer.mlp_ln_b);
 }

-#ifdef WHISPER_USE_FLASH_FF
-cur = ggml_flash_ff(ctx0,
-ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
 // fully connected
 cur = ggml_mul_mat(ctx0,
 layer.mlp_0_w,
@@ -2064,7 +2108,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 cur);

 cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
 }

 inpL = ggml_add(ctx0, cur, inpFF);
@@ -2113,9 +2156,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 const int n_state = hparams.n_audio_state;
 const int n_head = hparams.n_audio_head;

+const int n_state_head = n_state/n_head;
+
+const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
 struct ggml_init_params params = {
-/*.mem_size =*/ wstate.
-/*.mem_buffer =*/ wstate.
+/*.mem_size =*/ wstate.sched_cross.meta.size(),
+/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
 /*.no_alloc =*/ true,
 };

@@ -2125,18 +2172,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(

 struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);

-const float Kscale = pow(float(
+const float Kscale = pow(float(n_state_head), -0.25);

 for (int il = 0; il < model.hparams.n_text_layer; ++il) {
 auto & layer = model.layers_decoder[il];

-struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
 layer.cross_attn_k_w,
 cur);

 Kcross = ggml_scale(ctx0, Kcross, Kscale);

-struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
 layer.cross_attn_v_w,
 cur);

@@ -2144,15 +2191,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 Vcross,
 layer.cross_attn_v_b);

-
+struct ggml_tensor * k;
+struct ggml_tensor * v;

-
-
-
+if (wctx.params.flash_attn) {
+k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));

-
-
-
+v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+} else {
+Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+( n_ctx)*ggml_element_size(wstate.kv_cross.v),
+(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+}

 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2186,11 +2243,11 @@ static bool whisper_encode_internal(

 // conv
 {
-auto &
+auto & sched = wstate.sched_conv.sched;

 ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);

-if (!
+if (!ggml_backend_sched_alloc_graph(sched, gf)) {
 // should never happen as we pre-allocate the memory
 return false;
 }
@@ -2223,7 +2280,7 @@ static bool whisper_encode_internal(
 }

 if (!whisper_encode_external(wstate)) {
-if (!ggml_graph_compute_helper(
+if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
 return false;
 }
 } else {
@@ -2237,32 +2294,32 @@ static bool whisper_encode_internal(

 // encoder
 if (!whisper_encode_external(wstate)) {
-auto &
+auto & sched = wstate.sched_encode.sched;

 ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);

-if (!
+if (!ggml_backend_sched_alloc_graph(sched, gf)) {
 // should never happen as we pre-allocate the memory
 return false;
 }

-if (!ggml_graph_compute_helper(
+if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
 return false;
 }
 }

 // cross
 {
-auto &
+auto & sched = wstate.sched_cross.sched;

 ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);

-if (!
+if (!ggml_backend_sched_alloc_graph(sched, gf)) {
 // should never happen as we pre-allocate the memory
 return false;
 }

-if (!ggml_graph_compute_helper(
+if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
 return false;
 }
 }
@@ -2284,24 +2341,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

 auto & kv_self = wstate.kv_self;

-WHISPER_ASSERT(!!kv_self.
+WHISPER_ASSERT(!!kv_self.buffer);

 const int n_ctx = kv_self.size;
 const int n_state = hparams.n_text_state;
 const int n_head = hparams.n_text_head;
 const int n_layer = hparams.n_text_layer;

+const int n_state_head = n_state/n_head;
+
 const int n_tokens = batch.n_tokens;
 const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;

-const
-
+const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
+const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;

 //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);

 struct ggml_init_params params = {
-/*.mem_size =*/ wstate.
-/*.mem_buffer =*/ wstate.
+/*.mem_size =*/ wstate.sched_decode.meta.size(),
+/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
 /*.no_alloc =*/ true,
 };

@@ -2317,12 +2378,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 ggml_set_name(position, "position");
 ggml_set_input(position);

-const float KQscale = pow(float(
+const float KQscale = pow(float(n_state_head), -0.25);

-struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
 ggml_set_name(KQ_mask, "KQ_mask");
 ggml_set_input(KQ_mask);

+struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
 // token encoding + position encoding
 struct ggml_tensor * cur =
 ggml_add(ctx0,
@@ -2378,12 +2441,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 Vcur,
 layer.attn_v_b);

-
+struct ggml_tensor * k;
+struct ggml_tensor * v;
+
+if (wctx.params.flash_attn) {
+k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+(ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+} else {
+Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+
+k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));

-
-
-
-
+v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+( n_ctx)*ggml_element_size(kv_self.v),
+(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+}

 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2393,40 +2469,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

 struct ggml_tensor * Q =
 ggml_permute(ctx0,
-ggml_reshape_3d(ctx0, Qcur,
+ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
 0, 2, 1, 3);

 struct ggml_tensor * K =
 ggml_view_3d(ctx0, kv_self.k,
-
+n_state_head, n_kv, n_head,
 ggml_element_size(kv_self.k)*n_state,
-ggml_element_size(kv_self.k)*
+ggml_element_size(kv_self.k)*n_state_head,
 ggml_element_size(kv_self.k)*n_state*n_ctx*il);

-
-
+if (wctx.params.flash_attn) {
+struct ggml_tensor * V =
+ggml_view_3d(ctx0, kv_self.v,
+n_state_head, n_kv, n_head,
+ggml_element_size(kv_self.v)*n_state,
+ggml_element_size(kv_self.v)*n_state_head,
+ggml_element_size(kv_self.v)*n_state*n_ctx*il);

-
+cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);

-
-
+cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+} else {
+// K * Q
+struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

-
+struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);

-
-
-
-
-
-
+struct ggml_tensor * V =
+ggml_view_3d(ctx0, kv_self.v,
+n_kv, n_state_head, n_head,
+n_ctx*ggml_element_size(kv_self.v),
+n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+n_ctx*ggml_element_size(kv_self.v)*n_state*il);

-
+struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

-
+struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

-
-
-ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+}
 }
 }
 // projection
@@ -2465,80 +2547,75 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 Qcur,
 layer.cross_attn_q_b);

-Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
-// Kcross is already scaled
-struct ggml_tensor * Kcross =
-ggml_view_3d(ctx0, wstate.kv_cross.k,
-n_state/n_head, n_audio_ctx, n_head,
-ggml_element_size(wstate.kv_cross.k)*n_state,
-ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
-//struct ggml_tensor * Vcross =
-// ggml_reshape_3d(ctx0,
-// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
-// n_state/n_head, n_head, n_audio_ctx);
-
-//struct ggml_tensor * V_trans =
-// ggml_cpy(ctx0,
-// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
-struct ggml_tensor * V =
-ggml_view_3d(ctx0, wstate.kv_cross.v,
-n_audio_ctx, n_state/n_head, n_head,
-n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
-n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
-// ------
-
 struct ggml_tensor * Q =
 ggml_permute(ctx0,
-ggml_reshape_3d(ctx0, Qcur,
+ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
 0, 2, 1, 3);

-
-
-
-
-
-
-
-// );
+if (wctx.params.flash_attn) {
+struct ggml_tensor * Kcross =
+ggml_view_3d(ctx0, wstate.kv_cross.k,
+n_state_head, n_audio_ctx_pad, n_head,
+ggml_element_size(wstate.kv_cross.k)*n_state,
+ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);

-
-
+struct ggml_tensor * Vcross =
+ggml_view_3d(ctx0, wstate.kv_cross.v,
+n_state_head, n_audio_ctx_pad, n_head,
+ggml_element_size(wstate.kv_cross.v)*n_state,
+ggml_element_size(wstate.kv_cross.v)*n_state_head,
+ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);

-
+cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+} else {
+struct ggml_tensor * Kcross =
+ggml_view_3d(ctx0, wstate.kv_cross.k,
+n_state_head, n_audio_ctx, n_head,
+ggml_element_size(wstate.kv_cross.k)*n_state,
+ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+struct ggml_tensor * Vcross =
+ggml_view_3d(ctx0, wstate.kv_cross.v,
+n_audio_ctx, n_state_head, n_head,
+n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+// ------
+
+// K * Q
+struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+// [EXPERIMENTAL] Token-level timestamps with DTW
+if (wctx.params.dtw_token_timestamps) {
+if (wstate.aheads_masks.m[il] != nullptr) {
+struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+if (aheads_cross_QKs == NULL) {
+aheads_cross_QKs = aheads_KQs;
+} else {
+aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
+}
 }
 }
-}

-
+struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);

-
+struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

-
-
-KQV_merged,
-ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+}
 }

 // projection
@@ -2671,18 +2748,20 @@ static bool whisper_decode_internal(
 return false;
 }

-
+const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
 //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
 //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
 }

 // decoder
 {
-auto &
+auto & sched = wstate.sched_decode.sched;

 ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);

-if (!
+if (!ggml_backend_sched_alloc_graph(sched, gf)) {
 // should never happen as we pre-allocate the memory
 return false;
 }
@@ -2705,9 +2784,10 @@ static bool whisper_decode_internal(
 struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");

 auto & kv_self = wstate.kv_self;
-const int32_t n_kv = kv_self.n;

-
+const int32_t n_kv = kv_self.n;
+
+wstate.inp_mask.resize(ggml_nelements(KQ_mask));

 float * data = wstate.inp_mask.data();
 memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2723,14 +2803,20 @@ static bool whisper_decode_internal(
 }
 }
 }
+
+for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+for (int j = 0; j < n_kv; ++j) {
+data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+}
+}
 }

 ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
 }

-logits = gf
+logits = ggml_graph_node(gf, -1);

-if (!ggml_graph_compute_helper(
+if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
 return false;
 }
 }
@@ -2784,29 +2870,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 }

 #define SIN_COS_N_COUNT WHISPER_N_FFT
-
-
+namespace {
+struct whisper_global_cache {
+// In FFT, we frequently use sine and cosine operations with the same values.
+// We can use precalculated values to speed up the process.
+float sin_vals[SIN_COS_N_COUNT];
+float cos_vals[SIN_COS_N_COUNT];
+
+// Hann window (Use cosf to eliminate difference)
+// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+float hann_window[WHISPER_N_FFT];
+
+whisper_global_cache() {
+fill_sin_cos_table();
+fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+}
+
+void fill_sin_cos_table() {
+for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+sin_vals[i] = sinf(theta);
+cos_vals[i] = cosf(theta);
+}
+}

-
-
-
-
-
-
-
-
-cos_vals[i] = cosf(theta);
+void fill_hann_window(int length, bool periodic, float * output) {
+int offset = -1;
+if (periodic) {
+offset = 0;
+}
+for (int i = 0; i < length; i++) {
+output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+}
 }
-
+} global_cache;
 }

 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
-static void dft(const
-int N = in.size();
-
-out.resize(N*2);
+static void dft(const float* in, int N, float* out) {
 const int sin_cos_step = SIN_COS_N_COUNT / N;

 for (int k = 0; k < N; k++) {
@@ -2815,8 +2919,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {

 for (int n = 0; n < N; n++) {
 int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
-re += in[n]*cos_vals[idx]; // cos(t)
-im -= in[n]*sin_vals[idx]; // sin(t)
+re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
 }

 out[k*2 + 0] = re;
@@ -2828,47 +2932,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
 // poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
-static void fft(
-out.resize(in.size()*2);
-
-int N = in.size();
-
+static void fft(float* in, int N, float* out) {
 if (N == 1) {
 out[0] = in[0];
 out[1] = 0;
 return;
 }

-
-
+const int half_N = N / 2;
+if (N - half_N*2 == 1) {
+dft(in, N, out);
 return;
 }

-
-
-
-even.reserve(N/2);
-odd.reserve(N/2);
-
-for (int i = 0; i < N; i++) {
-if (i % 2 == 0) {
-even.push_back(in[i]);
-} else {
-odd.push_back(in[i]);
-}
+float* even = in + N;
+for (int i = 0; i < half_N; ++i) {
+even[i]= in[2*i];
 }
+float* even_fft = out + 2 * N;
+fft(even, half_N, even_fft);

-
-
-
-
-
+float* odd = even;
+for (int i = 0; i < half_N; ++i) {
+odd[i] = in[2*i + 1];
+}
+float* odd_fft = even_fft + N;
+fft(odd, half_N, odd_fft);

 const int sin_cos_step = SIN_COS_N_COUNT / N;
-for (int k = 0; k <
+for (int k = 0; k < half_N; k++) {
 int idx = k * sin_cos_step; // t = 2*M_PI*k/N
-float re = cos_vals[idx]; // cos(t)
-float im = -sin_vals[idx]; // sin(t)
+float re = global_cache.cos_vals[idx]; // cos(t)
+float im = -global_cache.sin_vals[idx]; // sin(t)

 float re_odd = odd_fft[2*k + 0];
 float im_odd = odd_fft[2*k + 1];
@@ -2876,52 +2971,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
 out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

-out[2*(k +
-out[2*(k +
-}
-}
-
-static bool hann_window(int length, bool periodic, std::vector<float> & output) {
-if (output.size() < static_cast<size_t>(length)) {
-output.resize(length);
+out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
 }
-int offset = -1;
-if (periodic) {
-offset = 0;
-}
-for (int i = 0; i < length; i++) {
-output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
-}
-
-return true;
 }

-static void log_mel_spectrogram_worker_thread(int ith, const
+static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
 int n_samples, int frame_size, int frame_step, int n_threads,
 const whisper_filters & filters, whisper_mel & mel) {
-std::vector<float> fft_in(frame_size, 0.0);
-std::vector<float> fft_out(2 *
+std::vector<float> fft_in(frame_size * 2, 0.0);
+std::vector<float> fft_out(frame_size * 2 * 2 * 2);
+
 int n_fft = filters.n_fft;
 int i = ith;

 // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
-assert(
-
+assert(n_fft == 1 + (frame_size / 2));
+
 // calculate FFT only when fft_in are not all zero
 for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
 const int offset = i * frame_step;

-// apply
+// apply Hann window (~10% faster)
 for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
 fft_in[j] = hann[j] * samples[offset + j];
 }
+
 // fill the rest with zeros
 if (n_samples - offset < frame_size) {
 std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
 }

 // FFT
-fft(fft_in, fft_out);
+fft(fft_in.data(), frame_size, fft_out.data());

 // Calculate modulus^2 of complex numbers
 // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@@ -2932,7 +3014,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
2932
3014
|
// mel spectrogram
|
2933
3015
|
for (int j = 0; j < mel.n_mel; j++) {
|
2934
3016
|
double sum = 0.0;
|
2935
|
-
|
2936
3017
|
// unroll loop (suggested by GH user @lunixbochs)
|
2937
3018
|
int k = 0;
|
2938
3019
|
for (k = 0; k < n_fft - 3; k += 4) {
|
@@ -2942,14 +3023,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
2942
3023
|
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
|
2943
3024
|
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
|
2944
3025
|
}
|
2945
|
-
|
2946
3026
|
// handle n_fft remainder
|
2947
3027
|
for (; k < n_fft; k++) {
|
2948
3028
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
2949
3029
|
}
|
2950
|
-
|
2951
3030
|
sum = log10(std::max(sum, 1e-10));
|
2952
|
-
|
2953
3031
|
mel.data[j * mel.n_len + i] = sum;
|
2954
3032
|
}
|
2955
3033
|
}
|
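The worker thread above computes each mel bin as a dot product between the power spectrum and one row of the filter matrix, then clamps and takes log10. A compact sketch of just that projection step, with the row-major layout assumed by the loop above (n_mel rows by n_fft columns):

```cpp
// Sketch of the mel projection: power spectrum (n_fft bins) -> log-mel (n_mel bins).
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> log_mel_from_power(const std::vector<float> & power,   // size n_fft
                                      const std::vector<float> & filters, // size n_mel * n_fft, row-major
                                      int n_mel, int n_fft) {
    std::vector<float> mel(n_mel);
    for (int j = 0; j < n_mel; ++j) {
        double sum = 0.0;
        for (int k = 0; k < n_fft; ++k) {
            sum += power[k] * filters[j * n_fft + k];
        }
        // clamp before the log so near-silent bins do not produce -inf
        mel[j] = (float) std::log10(std::max(sum, 1e-10));
    }
    return mel;
}
```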
@@ -2978,12 +3056,9 @@ static bool log_mel_spectrogram(
               whisper_mel & mel) {
     const int64_t t_start_us = ggml_time_us();
 
-    //
-
-
-    std::vector<float> hann;
-    hann_window(frame_size, true, hann);
-
+    // Hann window
+    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+    const float * hann = global_cache.hann_window;
 
     // Calculate the length of padding
     int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
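This hunk drops the per-call hann_window() helper and reads the window from a precomputed global cache. A minimal sketch of precomputing a periodic Hann window once at startup; the formula mirrors the removed helper, while the window_cache struct and the 400-sample frame size are assumptions for illustration, not the whisper.cpp cache type.

```cpp
#include <cmath>
#include <vector>

// Periodic Hann window of a fixed frame size, filled once and reused per frame.
struct window_cache {
    std::vector<float> hann;

    explicit window_cache(int length) : hann(length) {
        for (int i = 0; i < length; ++i) {
            // periodic variant: divide by `length`, not `length - 1`
            hann[i] = 0.5f * (1.0f - std::cos((2.0f * 3.14159265358979f * i) / length));
        }
    }
};

// usage: construct once, then index hann[j] when windowing each frame
// static const window_cache g_window(400); // 400 samples = 25 ms at 16 kHz (assumed frame size)
```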
@@ -3008,12 +3083,11 @@ static bool log_mel_spectrogram(
     mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
     mel.data.resize(mel.n_mel * mel.n_len);
 
-
     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
                     n_samples + stage_2_pad, frame_size, frame_step, n_threads,
                     std::cref(filters), std::ref(mel));
         }
@@ -3173,23 +3247,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 #endif
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
-    fill_sin_cos_table();
-
     whisper_state * state = new whisper_state;
 
-    state->
-    if (
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty()) {
         WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
 
-    // at this point, we don't know yet how many decoders will be used
-    //
-
-
-
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
+    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3199,8 +3273,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!
-
+    if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3210,9 +3287,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
+                ctx->model.hparams.n_audio_state,
+                1,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+        WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
+    }
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks,
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
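The new KV-cache initializations above pad the context length to a multiple of 256 with GGML_PAD. GGML_PAD(x, n) rounds x up to the next multiple of n; the macro body below mirrors the common ggml definition and is reproduced here only to illustrate the padding arithmetic.

```cpp
#include <cstdio>

// round x up to the next multiple of n (n is a power of two in practice)
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    std::printf("%d\n", GGML_PAD(448,  256)); // 512  (e.g. a text context of 448)
    std::printf("%d\n", GGML_PAD(1500, 256)); // 1536 (e.g. an audio context of 1500)
}
```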
@@ -3255,7 +3346,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state);
                 });
@@ -3266,12 +3357,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3282,12 +3373,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
     }
 
     // cross allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3298,12 +3389,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
     }
 
     // decoder allocator
     {
-        bool ok =
+        bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -3322,19 +3413,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__,
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
     }
 
     return state;
 }
 
-int
+int whisper_ctx_init_openvino_encoder_with_state(
         struct whisper_context * ctx,
+        struct whisper_state * state,
         const char * model_path,
         const char * device,
         const char * cache_dir) {
 #ifndef WHISPER_USE_OPENVINO
     (void)(ctx);
+    (void)(state);
     (void)(model_path);
     (void)(device);
     (void)(cache_dir);
@@ -3365,8 +3458,8 @@ int whisper_ctx_init_openvino_encoder(
     WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
     WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
-
-    if (!
+    state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+    if (!state->ctx_openvino) {
         WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
@@ -3377,9 +3470,18 @@ int whisper_ctx_init_openvino_encoder(
 #endif
 }
 
+int whisper_ctx_init_openvino_encoder(
+        struct whisper_context * ctx,
+        const char * model_path,
+        const char * device,
+        const char * cache_dir) {
+    return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
+}
+
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
         /*.use_gpu =*/ true,
+        /*.flash_attn =*/ false,
         /*.gpu_device =*/ 0,
 
         /*.dtw_token_timestamps =*/ false,
@@ -3396,8 +3498,14 @@ struct whisper_context_params whisper_context_default_params() {
 
 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
     WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
-
+#ifdef _MSC_VER
+    // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring path_model_wide = converter.from_bytes(path_model);
+    auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
     auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
     if (!fin) {
         WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
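The model-loading hunk above opens the file through a UTF-16 path on Windows so non-ASCII model paths work. A minimal sketch of the same idea in isolation; it mirrors the std::wstring_convert/std::codecvt_utf8 combination used in the diff (deprecated since C++17 but still available), with error handling omitted and the function name chosen here for illustration.

```cpp
#include <codecvt>
#include <fstream>
#include <locale>
#include <string>

// Open a file whose path is given as UTF-8, converting to a wide path on MSVC.
std::ifstream open_binary_utf8(const std::string & path_utf8) {
#ifdef _MSC_VER
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    const std::wstring wide = converter.from_bytes(path_utf8);
    return std::ifstream(wide, std::ios::binary);
#else
    return std::ifstream(path_utf8, std::ios::binary);
#endif
}
```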
@@ -3472,6 +3580,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();
 
+    if (params.flash_attn && params.dtw_token_timestamps) {
+        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+        params.dtw_token_timestamps = false;
+    }
+
+    WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
+    WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
+    WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;
 
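With the new flash_attn field and the warning above, context creation now typically looks roughly like the sketch below. The model path is a placeholder and the field names follow the defaults shown in this diff; treat it as a usage sketch rather than canonical example code.

```cpp
#include "whisper.h"

int main() {
    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu    = true;
    cparams.flash_attn = true;   // incompatible with dtw_token_timestamps, which would be auto-disabled
    cparams.gpu_device = 0;

    struct whisper_context * ctx =
        whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams); // placeholder path

    if (ctx == nullptr) {
        return 1;
    }

    // ... run whisper_full() / whisper_full_with_state() here ...

    whisper_free(ctx);
    return 0;
}
```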
@@ -3558,8 +3678,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
 void whisper_free_state(struct whisper_state * state) {
     if (state) {
-
-
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3577,12 +3698,14 @@ void whisper_free_state(struct whisper_state * state) {
 
         whisper_batch_free(state->batch);
 
-
-
-
-
+        ggml_backend_sched_free(state->sched_conv.sched);
+        ggml_backend_sched_free(state->sched_encode.sched);
+        ggml_backend_sched_free(state->sched_cross.sched);
+        ggml_backend_sched_free(state->sched_decode.sched);
 
-
+        for (auto & backend : state->backends) {
+            ggml_backend_free(backend);
+        }
 
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
@@ -3599,8 +3722,6 @@ void whisper_free(struct whisper_context * ctx) {
 
         whisper_free_state(ctx->state);
 
-        ggml_backend_free(ctx->backend);
-
         delete ctx;
     }
 }
@@ -3630,30 +3751,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
     return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }
 
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
-int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
-        return -1;
-    }
-
-    return 0;
-}
-
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
-int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
-}
-
-// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
-// TODO
-
-// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
-// TODO
-
-// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
-// TODO
-
 int whisper_set_mel_with_state(
         struct whisper_context * ctx,
         struct whisper_state * state,
@@ -3742,7 +3839,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
     return -whisper_tokenize(ctx, text, NULL, 0);
 }
 
-int whisper_lang_max_id() {
+int whisper_lang_max_id(void) {
     auto max_id = 0;
     for (const auto & kv : g_lang) {
         max_id = std::max(max_id, kv.second.first);
@@ -4011,6 +4108,19 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
     return ctx->vocab.token_transcribe;
 }
 
+struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
+    if (ctx->state == nullptr) {
+        return nullptr;
+    }
+    whisper_timings * timings = new whisper_timings;
+    timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
+    timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
+    timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
+    timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
+    timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
+    return timings;
+}
+
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
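The new whisper_get_timings() above returns averaged per-stage timings. A small usage sketch; the field names follow the assignments shown in the hunk, and since the struct is heap-allocated with new in this diff, the caller releases it with delete here (an assumption: no dedicated free helper is visible in this diff).

```cpp
#include <cstdio>
#include "whisper.h"

void print_stage_timings(struct whisper_context * ctx) {
    struct whisper_timings * t = whisper_get_timings(ctx);
    if (t == nullptr) {
        return; // no state attached to this context yet
    }
    std::printf("sample: %.2f ms, encode: %.2f ms, decode: %.2f ms, batchd: %.2f ms, prompt: %.2f ms\n",
                t->sample_ms, t->encode_ms, t->decode_ms, t->batchd_ms, t->prompt_ms);
    delete t; // allocated with `new` in this version
}
```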
@@ -4078,17 +4188,14 @@ const char * whisper_print_system_info(void) {
     s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
     s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
     s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
     s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
     s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
     s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
     s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
     s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-    s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
     s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
-    s += "OPENVINO = " + std::to_string(whisper_has_openvino())
+    s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
 
     return s.c_str();
 }
@@ -4099,7 +4206,7 @@ const char * whisper_print_system_info(void) {
 
 // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
 // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
         const char * src,
         whisper_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4513,7 +4620,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
 
 ////////////////////////////////////////////////////////////////////////////
 
-struct whisper_context_params * whisper_context_default_params_by_ref() {
+struct whisper_context_params * whisper_context_default_params_by_ref(void) {
     struct whisper_context_params params = whisper_context_default_params();
 
     struct whisper_context_params* result = new whisper_context_params();
@@ -4554,7 +4661,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.split_on_word =*/ false,
         /*.max_tokens =*/ 0,
 
-        /*.speed_up =*/ false,
         /*.debug_mode =*/ false,
         /*.audio_ctx =*/ 0,
 
@@ -4720,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
     "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
 };
 
+static void whisper_compute_logprobs(
+                const std::vector<float> & logits,
+                const int n_logits,
+                std::vector<float> & logprobs) {
+    const float logit_max = *std::max_element(logits.begin(), logits.end());
+    float logsumexp = 0.0f;
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logsumexp += expf(logits[i] - logit_max);
+        }
+    }
+    logsumexp = logf(logsumexp) + logit_max;
+
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] > -INFINITY) {
+            logprobs[i] = logits[i] - logsumexp;
+        } else {
+            logprobs[i] = -INFINITY;
+        }
+    }
+}
+
+static void whisper_compute_probs(
+                const std::vector<float> & logits,
+                const int n_logits,
+                const std::vector<float> & logprobs,
+                std::vector<float> & probs) {
+    for (int i = 0; i < n_logits; ++i) {
+        if (logits[i] == -INFINITY) {
+            probs[i] = 0.0f;
+        } else {
+            probs[i] = expf(logprobs[i]);
+        }
+    }
+}
+
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
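whisper_compute_logprobs above is the usual max-shifted log-softmax: with m = max_j x_j, it computes log p_i = x_i - (log Σ_j exp(x_j - m) + m), which keeps expf from overflowing on large logits. A standalone check of the identity (not library code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// numerically stable log-softmax: subtract the max before exponentiating
std::vector<float> log_softmax(const std::vector<float> & x) {
    const float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for (float v : x) {
        sum += std::exp(v - m);
    }
    const float lse = std::log(sum) + m; // log-sum-exp
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] - lse;
    }
    return out;
}

int main() {
    const auto lp = log_softmax({1000.0f, 1001.0f, 1002.0f}); // naive expf would overflow here
    float total = 0.0f;
    for (float v : lp) total += std::exp(v);
    std::printf("sum of probs = %f\n", total); // ~1.0
}
```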
@@ -4781,7 +4923,7 @@ static void whisper_process_logits(
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot] = -INFINITY;
-        logits[vocab.token_nosp] = -INFINITY;
+        logits[vocab.token_nosp] = -INFINITY;
 
         // [TDRZ] when tinydiarize is disabled, suppress solm token
         if (params.tdrz_enable == false) {
@@ -4880,24 +5022,7 @@ static void whisper_process_logits(
         }
 
         // populate the logprobs array (log_softmax)
-
-        const float logit_max = *std::max_element(logits.begin(), logits.end());
-        float logsumexp = 0.0f;
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -INFINITY) {
-                logsumexp += expf(logits[i] - logit_max);
-            }
-        }
-        logsumexp = logf(logsumexp) + logit_max;
-
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -INFINITY) {
-                logprobs[i] = logits[i] - logsumexp;
-            } else {
-                logprobs[i] = -INFINITY;
-            }
-        }
-    }
+        whisper_compute_logprobs(logits, n_logits, logprobs);
 
         // if sum of probability over timestamps is above any other token, sample timestamp
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -4955,15 +5080,7 @@ static void whisper_process_logits(
         }
 
         // compute probs
-
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] == -INFINITY) {
-                probs[i] = 0.0f;
-            } else {
-                probs[i] = expf(logprobs[i]);
-            }
-        }
-    }
+        whisper_compute_probs(logits, n_logits, logprobs, probs);
 
 #if 0
         // print first 100 logits - token string : logit
@@ -5228,15 +5345,9 @@ int whisper_full_with_state(
 
     if (n_samples > 0) {
         // compute log mel spectrogram
-        if (params.
-            // TODO: Replace PV with more advanced algorithm
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-            return -
-        } else {
-            if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-                return -2;
-            }
+            return -2;
         }
     }
 
@@ -5273,7 +5384,7 @@ int whisper_full_with_state(
     // if length of spectrogram is less than 1.0s (100 frames), then return
     // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (seek_end < seek_start +
+    if (seek_end < seek_start + 100) {
         WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
         return 0;
     }
@@ -5518,13 +5629,46 @@ int whisper_full_with_state(
             }
             WHISPER_LOG_DEBUG("\n\n");
 
+            // recreate the KV cache if the number of decoders has changed
+            if (state->kv_self_n_dec < n_decoders_cur) {
+                WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+                whisper_kv_cache_free(state->kv_self);
+
+                // overallocate to workaround KV cache fragmentation issues
+                const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+                if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                            ctx->model.hparams.n_text_state,
+                            ctx->model.hparams.n_text_layer,
+                            GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                    WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                    whisper_free_state(state);
+                    return -7;
+                }
+
+                state->kv_self_n_dec = n_decoders_cur;
+            }
+
             whisper_kv_cache_clear(state->kv_self);
 
             whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
             if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                return -
+                return -8;
+            }
+
+            // Calculate no_speech probability after first decode.
+            // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
+            {
+                const int n_logits = ctx->vocab.id_to_token.size();
+                std::vector<float> logprobs(n_logits);
+                std::vector<float> probs(n_logits);
+
+                whisper_compute_logprobs(state->logits, n_logits, logprobs);
+                whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+                state->no_speech_prob = probs[whisper_token_nosp(ctx)];
             }
 
             {
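The hunk above stores the probability of the "no speech" token from the first decode, before any logit filtering; the hunks that follow use it together with the average log-probability both to decide when to retry at a higher temperature and to drop a chunk entirely. The two gates in isolation, with the existing logprob_thold and no_speech_thold parameters (the struct below is only for illustration):

```cpp
// Illustration of the two decisions built on no_speech_prob (not library code).
struct chunk_scores {
    float no_speech_prob; // probability of the no-speech token after the first decode
    float avg_logprobs;   // average log-probability of the decoded sequence
};

// retry at a higher temperature: the sequence looks weak *and* it is not clearly silence
bool should_retry(const chunk_scores & s, float logprob_thold, float no_speech_thold) {
    return s.avg_logprobs < logprob_thold && s.no_speech_prob < no_speech_thold;
}

// treat the chunk as no-speech: likely silence and the decoded text is unreliable
bool is_no_speech(const chunk_scores & s, float logprob_thold, float no_speech_thold) {
    return s.no_speech_prob > no_speech_thold && s.avg_logprobs < logprob_thold;
}
```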
@@ -5824,7 +5968,7 @@ int whisper_full_with_state(
 
                 if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                    return -
+                    return -9;
                 }
 
                 const int64_t t_start_sample_us = ggml_time_us();
@@ -5918,8 +6062,9 @@ int whisper_full_with_state(
             if (it != (int) temperatures.size() - 1) {
                 const auto & decoder = state->decoders[best_decoder_id];
 
-                if (decoder.failed ||
-
+                if (decoder.failed ||
+                    (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
                     success = false;
                     state->n_fail_p++;
                 }
@@ -5940,7 +6085,7 @@ int whisper_full_with_state(
         {
             const auto & best_decoder = state->decoders[best_decoder_id];
 
-
+            auto seek_delta = best_decoder.seek_delta;
             const auto result_len = best_decoder.sequence.result_len;
 
             const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -5948,6 +6093,9 @@ int whisper_full_with_state(
             // [EXPERIMENTAL] Token-level timestamps with DTW
             const auto n_segments_before = state->result_all.size();
 
+            const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+                                       best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
@@ -5956,11 +6104,11 @@ int whisper_full_with_state(
                 prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }
 
-            for (int i = 0; i < result_len; ++i) {
+            for (int i = 0; i < result_len && !is_no_speech; ++i) {
                 prompt_past.push_back(tokens_cur[i].id);
             }
 
-            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
                 int i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
@@ -5985,8 +6133,8 @@ int whisper_full_with_state(
                         const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
 
                         if (!text.empty()) {
-                            const auto tt0 =
-                            const auto tt1 =
+                            const auto tt0 = t0;
+                            const auto tt1 = t1;
 
                             if (params.print_realtime) {
                                 if (params.print_timestamps) {
@@ -6014,7 +6162,7 @@ int whisper_full_with_state(
                                     n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                                 }
                             }
-                            if (params.new_segment_callback) {
+                            if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                                 params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                             }
                         }
@@ -6032,8 +6180,8 @@ int whisper_full_with_state(
                 if (!text.empty()) {
                     const auto t1 = seek + seek_delta;
 
-                    const auto tt0 =
-                    const auto tt1 =
+                    const auto tt0 = t0;
+                    const auto tt1 = t1;
 
                     if (params.print_realtime) {
                         if (params.print_timestamps) {
@@ -6059,7 +6207,7 @@ int whisper_full_with_state(
                             n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                         }
                     }
-                    if (params.new_segment_callback) {
+                    if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                         params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                     }
                 }
@@ -6068,14 +6216,28 @@ int whisper_full_with_state(
             // FIXME: will timestamp offsets be correct?
             // [EXPERIMENTAL] Token-level timestamps with DTW
             {
-                const
+                const int n_segments = state->result_all.size() - n_segments_before;
                 if (ctx->params.dtw_token_timestamps && n_segments) {
                     const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
                     whisper_exp_compute_token_level_timestamps_dtw(
                             ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                    if (params.new_segment_callback) {
+                        for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+                            params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+                        }
+                    }
                 }
             }
 
+            // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
+            const bool single_timestamp_ending = tokens_cur.size() > 1 &&
+                                                 tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
+                                                 tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
+            if (single_timestamp_ending) {
+                WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
+                seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
+            }
+
             // update audio window
             seek += seek_delta;
 
@@ -6835,7 +6997,7 @@ static void whisper_exp_compute_token_level_timestamps(
                 k++;
             }
             tokens[j].t1 = sample_to_timestamp(k);
-            if (j <
+            if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                 tokens[j].t1 = tokens[j + 1].t0;
             } else {
                 s1 = k;
@@ -6998,10 +7160,11 @@ struct median_filter_user_data {
     int filter_width;
 };
 
-static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth
+static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+    if (ith != 0) {
+        return;
+    }
     int filter_width = ((median_filter_user_data *) userdata)->filter_width;
-    WHISPER_ASSERT(nth == 1);
-    WHISPER_ASSERT(ith == 0);
     WHISPER_ASSERT(filter_width < a->ne[2]);
     WHISPER_ASSERT(filter_width % 2);
     WHISPER_ASSERT(ggml_n_dims(a) == 3);
@@ -7124,7 +7287,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // operation (after median filter)
     // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
-    w = ggml_norm(gctx, w, 1e-
+    w = ggml_norm(gctx, w, 1e-9f);
     w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
 
     // Pass median filter - this is done over AUDIO_TOKENS dimension.
@@ -7196,6 +7359,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
+    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 }
 
 GGML_ATTRIBUTE_FORMAT(2, 3)
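whisper_log_set now forwards the callback to ggml_log_set as well, so a single callback covers both whisper and ggml log output. A typical way to hook it up; the callback signature follows ggml_log_callback, and the debug-level filtering below is just an example mirroring the default callback.

```cpp
#include <cstdio>
#include "whisper.h"

static void my_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    if (level == GGML_LOG_LEVEL_DEBUG) {
        return; // drop debug noise, like the default callback does outside WHISPER_DEBUG builds
    }
    std::fputs(text, stderr);
}

int main() {
    whisper_log_set(my_log_cb, nullptr); // also routes ggml's own log output here
    // ... load a model and run inference ...
    return 0;
}
```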
@@ -7219,6 +7383,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
 static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
+#ifndef WHISPER_DEBUG
+    if (level == GGML_LOG_LEVEL_DEBUG) {
+        return;
+    }
+#endif
     fputs(text, stderr);
     fflush(stderr);
 }
|