whisper.rn 0.3.9 → 0.4.0-rc.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -1
- package/android/src/main/jni.cpp +7 -1
- package/cpp/coreml/whisper-encoder.mm +7 -1
- package/cpp/ggml-alloc.c +633 -0
- package/cpp/ggml-alloc.h +26 -0
- package/cpp/ggml-metal.h +85 -0
- package/cpp/ggml-metal.m +1283 -0
- package/cpp/ggml-metal.metal +2353 -0
- package/cpp/ggml.c +5024 -2924
- package/cpp/ggml.h +569 -95
- package/cpp/whisper.cpp +1014 -667
- package/cpp/whisper.h +13 -0
- package/ios/RNWhisper.mm +2 -0
- package/ios/RNWhisperContext.h +1 -1
- package/ios/RNWhisperContext.mm +18 -4
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +3 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +3 -1
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +3 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/index.ts +4 -0
- package/whisper-rn.podspec +8 -2
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/whisper.cpp
CHANGED
|
@@ -3,11 +3,16 @@
|
|
|
3
3
|
#include "coreml/whisper-encoder.h"
|
|
4
4
|
#endif
|
|
5
5
|
|
|
6
|
-
#
|
|
6
|
+
#ifdef WSP_GGML_USE_METAL
|
|
7
|
+
# include "ggml-metal.h"
|
|
8
|
+
#endif
|
|
9
|
+
|
|
10
|
+
#ifdef WHISPER_USE_OPENVINO
|
|
7
11
|
#include "openvino/whisper-openvino-encoder.h"
|
|
8
12
|
#endif
|
|
9
13
|
|
|
10
14
|
#include "ggml.h"
|
|
15
|
+
#include "ggml-alloc.h"
|
|
11
16
|
|
|
12
17
|
#include <algorithm>
|
|
13
18
|
#include <cassert>
|
|
@@ -18,11 +23,13 @@
|
|
|
18
23
|
#include <cstring>
|
|
19
24
|
#include <fstream>
|
|
20
25
|
#include <map>
|
|
26
|
+
#include <set>
|
|
21
27
|
#include <string>
|
|
22
28
|
#include <thread>
|
|
23
29
|
#include <vector>
|
|
24
30
|
#include <regex>
|
|
25
31
|
#include <random>
|
|
32
|
+
#include <functional>
|
|
26
33
|
|
|
27
34
|
#if defined(_MSC_VER)
|
|
28
35
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
@@ -114,8 +121,66 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
114
121
|
//#define WHISPER_USE_FLASH_FF
|
|
115
122
|
#define WHISPER_MAX_DECODERS 16
|
|
116
123
|
|
|
117
|
-
|
|
118
|
-
|
|
124
|
+
//
|
|
125
|
+
// ggml helpers
|
|
126
|
+
//
|
|
127
|
+
|
|
128
|
+
static void wsp_ggml_graph_compute_helper(
|
|
129
|
+
std::vector<uint8_t> & buf,
|
|
130
|
+
wsp_ggml_cgraph * graph,
|
|
131
|
+
int n_threads,
|
|
132
|
+
whisper_abort_callback abort_callback,
|
|
133
|
+
void * abort_callback_data) {
|
|
134
|
+
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
|
|
135
|
+
|
|
136
|
+
plan.abort_callback = abort_callback;
|
|
137
|
+
plan.abort_callback_data = abort_callback_data;
|
|
138
|
+
|
|
139
|
+
if (plan.work_size > 0) {
|
|
140
|
+
buf.resize(plan.work_size);
|
|
141
|
+
plan.work_data = buf.data();
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
wsp_ggml_graph_compute(graph, &plan);
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
148
|
+
// the idea is to represent the original matrix multiplication:
|
|
149
|
+
//
|
|
150
|
+
// Z = X @ Y
|
|
151
|
+
//
|
|
152
|
+
// with the sum of two matrix multiplications:
|
|
153
|
+
//
|
|
154
|
+
// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
|
|
155
|
+
//
|
|
156
|
+
// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
|
|
157
|
+
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
|
|
158
|
+
// general-purpose kernels
|
|
159
|
+
//
|
|
160
|
+
static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * x, struct wsp_ggml_tensor * y, int pad = 32) {
|
|
161
|
+
// use padding only if dimension 0 is at least 8 times larger than the padding
|
|
162
|
+
// else we won't get much benefit from the optimization
|
|
163
|
+
const int n_pad_req = 8;
|
|
164
|
+
|
|
165
|
+
if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
|
|
166
|
+
return wsp_ggml_mul_mat(ctx, x, y);
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
struct wsp_ggml_tensor * x_0 = wsp_ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
|
|
170
|
+
struct wsp_ggml_tensor * x_1 = wsp_ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
|
|
171
|
+
|
|
172
|
+
struct wsp_ggml_tensor * y_0 = wsp_ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
|
|
173
|
+
struct wsp_ggml_tensor * y_1 = wsp_ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
|
|
174
|
+
|
|
175
|
+
return wsp_ggml_add(ctx,
|
|
176
|
+
wsp_ggml_mul_mat(ctx, x_0, y_0),
|
|
177
|
+
wsp_ggml_mul_mat(ctx, x_1, y_1));
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
// TODO: check if other platforms can benefit from this optimization
|
|
181
|
+
#if defined(WSP_GGML_USE_METAL)
|
|
182
|
+
#define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
|
|
183
|
+
#endif
|
|
119
184
|
|
|
120
185
|
// available whisper models
|
|
121
186
|
enum e_model {
|
|
@@ -231,38 +296,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
231
296
|
|
|
232
297
|
static const size_t MB = 1ull*1024*1024;
|
|
233
298
|
|
|
234
|
-
|
|
235
|
-
{ MODEL_TINY, 62ull*MB },
|
|
236
|
-
{ MODEL_BASE, 80ull*MB },
|
|
237
|
-
{ MODEL_SMALL, 120ull*MB },
|
|
238
|
-
{ MODEL_MEDIUM, 158ull*MB },
|
|
239
|
-
{ MODEL_LARGE, 198ull*MB },
|
|
240
|
-
};
|
|
241
|
-
|
|
242
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
|
243
|
-
{ MODEL_TINY, 18ull*MB },
|
|
244
|
-
{ MODEL_BASE, 24ull*MB },
|
|
245
|
-
{ MODEL_SMALL, 36ull*MB },
|
|
246
|
-
{ MODEL_MEDIUM, 48ull*MB },
|
|
247
|
-
{ MODEL_LARGE, 60ull*MB },
|
|
248
|
-
};
|
|
249
|
-
|
|
250
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
|
|
251
|
-
{ MODEL_TINY, 4ull*MB },
|
|
252
|
-
{ MODEL_BASE, 4ull*MB },
|
|
253
|
-
{ MODEL_SMALL, 6ull*MB },
|
|
254
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
|
255
|
-
{ MODEL_LARGE, 9ull*MB },
|
|
256
|
-
};
|
|
257
|
-
|
|
258
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
|
259
|
-
{ MODEL_TINY, 4ull*MB },
|
|
260
|
-
{ MODEL_BASE, 4ull*MB },
|
|
261
|
-
{ MODEL_SMALL, 6ull*MB },
|
|
262
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
|
263
|
-
{ MODEL_LARGE, 9ull*MB },
|
|
264
|
-
};
|
|
265
|
-
|
|
299
|
+
// TODO: avoid using GGUF
|
|
266
300
|
static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
267
301
|
{ WSP_GGML_TYPE_F32,
|
|
268
302
|
{
|
|
@@ -329,38 +363,6 @@ static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL =
|
|
|
329
363
|
},
|
|
330
364
|
};
|
|
331
365
|
|
|
332
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
|
333
|
-
{ MODEL_TINY, 3ull*MB },
|
|
334
|
-
{ MODEL_BASE, 6ull*MB },
|
|
335
|
-
{ MODEL_SMALL, 16ull*MB },
|
|
336
|
-
{ MODEL_MEDIUM, 43ull*MB },
|
|
337
|
-
{ MODEL_LARGE, 71ull*MB },
|
|
338
|
-
};
|
|
339
|
-
|
|
340
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
|
341
|
-
{ MODEL_TINY, 9ull*MB },
|
|
342
|
-
{ MODEL_BASE, 18ull*MB },
|
|
343
|
-
{ MODEL_SMALL, 53ull*MB },
|
|
344
|
-
{ MODEL_MEDIUM, 141ull*MB },
|
|
345
|
-
{ MODEL_LARGE, 235ull*MB },
|
|
346
|
-
};
|
|
347
|
-
|
|
348
|
-
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
349
|
-
{ MODEL_TINY, 30ull*MB },
|
|
350
|
-
{ MODEL_BASE, 38ull*MB },
|
|
351
|
-
{ MODEL_SMALL, 56ull*MB },
|
|
352
|
-
{ MODEL_MEDIUM, 74ull*MB },
|
|
353
|
-
{ MODEL_LARGE, 94ull*MB },
|
|
354
|
-
};
|
|
355
|
-
|
|
356
|
-
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
357
|
-
{ MODEL_TINY, 3ull*MB },
|
|
358
|
-
{ MODEL_BASE, 5ull*MB },
|
|
359
|
-
{ MODEL_SMALL, 10ull*MB },
|
|
360
|
-
{ MODEL_MEDIUM, 18ull*MB },
|
|
361
|
-
{ MODEL_LARGE, 27ull*MB },
|
|
362
|
-
};
|
|
363
|
-
|
|
364
366
|
struct whisper_mel {
|
|
365
367
|
int n_len;
|
|
366
368
|
int n_len_org;
|
|
@@ -441,6 +443,7 @@ struct whisper_hparams {
|
|
|
441
443
|
int32_t n_text_layer = 4;
|
|
442
444
|
int32_t n_mels = 80;
|
|
443
445
|
int32_t ftype = 1;
|
|
446
|
+
float eps = 1e-5f;
|
|
444
447
|
};
|
|
445
448
|
|
|
446
449
|
// audio encoding layer
|
|
@@ -536,6 +539,7 @@ struct whisper_kv_cache {
|
|
|
536
539
|
|
|
537
540
|
struct wsp_ggml_context * ctx;
|
|
538
541
|
|
|
542
|
+
// buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
|
|
539
543
|
std::vector<uint8_t> buf;
|
|
540
544
|
|
|
541
545
|
int n; // number of tokens currently in the cache
|
|
@@ -601,7 +605,7 @@ struct whisper_sequence {
|
|
|
601
605
|
|
|
602
606
|
// TAGS: WHISPER_DECODER_INIT
|
|
603
607
|
struct whisper_decoder {
|
|
604
|
-
// each
|
|
608
|
+
// each decoder keeps its own KV-cache
|
|
605
609
|
whisper_kv_cache kv_self;
|
|
606
610
|
|
|
607
611
|
// the currently generated sequence of tokens
|
|
@@ -621,15 +625,75 @@ struct whisper_decoder {
|
|
|
621
625
|
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
|
622
626
|
};
|
|
623
627
|
|
|
628
|
+
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
629
|
+
template<typename A, typename B>
|
|
630
|
+
struct whisper_pair {
|
|
631
|
+
A first;
|
|
632
|
+
B second;
|
|
633
|
+
|
|
634
|
+
// Define a constructor that takes two arguments.
|
|
635
|
+
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
636
|
+
// Define a constructor that takes no argument.
|
|
637
|
+
whisper_pair() : first(A()), second(B()) {}
|
|
638
|
+
};
|
|
639
|
+
|
|
640
|
+
// beam-search helpers
|
|
641
|
+
struct kv_buf {
|
|
642
|
+
std::vector<uint8_t> k;
|
|
643
|
+
std::vector<uint8_t> v;
|
|
644
|
+
};
|
|
645
|
+
|
|
646
|
+
// wsp_ggml_allocr wrapper for whisper usage
|
|
647
|
+
struct whisper_allocr {
|
|
648
|
+
wsp_ggml_allocr * alloc = nullptr;
|
|
649
|
+
|
|
650
|
+
std::vector<uint8_t> meta;
|
|
651
|
+
std::vector<uint8_t> data;
|
|
652
|
+
};
|
|
653
|
+
|
|
654
|
+
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
655
|
+
return allocr.meta.size() + allocr.data.size();
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
659
|
+
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
660
|
+
const int tensor_alignment = 32;
|
|
661
|
+
|
|
662
|
+
auto & alloc = allocr.alloc;
|
|
663
|
+
auto & meta = allocr.meta;
|
|
664
|
+
auto & data = allocr.data;
|
|
665
|
+
|
|
666
|
+
meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
|
|
667
|
+
|
|
668
|
+
alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
|
|
669
|
+
|
|
670
|
+
const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
|
|
671
|
+
|
|
672
|
+
wsp_ggml_allocr_free(alloc);
|
|
673
|
+
|
|
674
|
+
data.resize(alloc_size);
|
|
675
|
+
|
|
676
|
+
alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
680
|
+
if (allocr.alloc) {
|
|
681
|
+
wsp_ggml_allocr_free(allocr.alloc);
|
|
682
|
+
allocr.alloc = nullptr;
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
|
|
624
686
|
struct whisper_state {
|
|
625
687
|
int64_t t_sample_us = 0;
|
|
626
688
|
int64_t t_encode_us = 0;
|
|
627
689
|
int64_t t_decode_us = 0;
|
|
690
|
+
int64_t t_prompt_us = 0;
|
|
628
691
|
int64_t t_mel_us = 0;
|
|
629
692
|
|
|
630
693
|
int32_t n_sample = 0; // number of tokens sampled
|
|
631
694
|
int32_t n_encode = 0; // number of encoder calls
|
|
632
|
-
int32_t n_decode = 0; // number of decoder calls
|
|
695
|
+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
|
696
|
+
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
|
633
697
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
634
698
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
635
699
|
|
|
@@ -640,12 +704,23 @@ struct whisper_state {
|
|
|
640
704
|
|
|
641
705
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
642
706
|
|
|
643
|
-
//
|
|
644
|
-
std::vector<
|
|
645
|
-
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
|
707
|
+
// buffer for swapping KV caches between decoders during beam-search
|
|
708
|
+
std::vector<kv_buf> kv_swap_bufs;
|
|
646
709
|
|
|
647
|
-
|
|
648
|
-
|
|
710
|
+
// reusable buffer for `struct wsp_ggml_graph_plan.work_data`
|
|
711
|
+
std::vector<uint8_t> work_buffer;
|
|
712
|
+
|
|
713
|
+
// ggml-alloc:
|
|
714
|
+
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
715
|
+
// - stores the actual tensor data into the `data` buffers
|
|
716
|
+
whisper_allocr alloc_conv;
|
|
717
|
+
whisper_allocr alloc_encode;
|
|
718
|
+
whisper_allocr alloc_cross;
|
|
719
|
+
whisper_allocr alloc_decode;
|
|
720
|
+
|
|
721
|
+
// result of the encoder
|
|
722
|
+
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
723
|
+
struct wsp_ggml_tensor * embd_enc = nullptr;
|
|
649
724
|
|
|
650
725
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
651
726
|
std::vector<float> logits;
|
|
@@ -654,7 +729,7 @@ struct whisper_state {
|
|
|
654
729
|
std::vector<whisper_token> prompt_past;
|
|
655
730
|
|
|
656
731
|
// work container used to avoid memory allocations
|
|
657
|
-
std::vector<
|
|
732
|
+
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
658
733
|
|
|
659
734
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
660
735
|
|
|
@@ -665,6 +740,10 @@ struct whisper_state {
|
|
|
665
740
|
whisper_coreml_context * ctx_coreml = nullptr;
|
|
666
741
|
#endif
|
|
667
742
|
|
|
743
|
+
#ifdef WSP_GGML_USE_METAL
|
|
744
|
+
wsp_ggml_metal_context * ctx_metal = nullptr;
|
|
745
|
+
#endif
|
|
746
|
+
|
|
668
747
|
#ifdef WHISPER_USE_OPENVINO
|
|
669
748
|
whisper_openvino_context * ctx_openvino = nullptr;
|
|
670
749
|
#endif
|
|
@@ -677,37 +756,6 @@ struct whisper_state {
|
|
|
677
756
|
|
|
678
757
|
// [EXPERIMENTAL] speed-up techniques
|
|
679
758
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
|
680
|
-
|
|
681
|
-
void use_buf(struct wsp_ggml_context * ctx, int i) {
|
|
682
|
-
#if defined(WHISPER_USE_SCRATCH)
|
|
683
|
-
size_t last_size = 0;
|
|
684
|
-
|
|
685
|
-
if (i == -1) {
|
|
686
|
-
last_size = wsp_ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
687
|
-
} else {
|
|
688
|
-
auto & buf = buf_scratch[i];
|
|
689
|
-
last_size = wsp_ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
|
690
|
-
}
|
|
691
|
-
|
|
692
|
-
if (buf_last >= 0) {
|
|
693
|
-
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
|
694
|
-
}
|
|
695
|
-
|
|
696
|
-
buf_last = i;
|
|
697
|
-
#else
|
|
698
|
-
(void) i;
|
|
699
|
-
(void) ctx;
|
|
700
|
-
#endif
|
|
701
|
-
}
|
|
702
|
-
|
|
703
|
-
size_t get_buf_max_mem(int i) const {
|
|
704
|
-
#if defined(WHISPER_USE_SCRATCH)
|
|
705
|
-
return buf_max_size[i];
|
|
706
|
-
#else
|
|
707
|
-
(void) i;
|
|
708
|
-
return 0;
|
|
709
|
-
#endif
|
|
710
|
-
}
|
|
711
759
|
};
|
|
712
760
|
|
|
713
761
|
struct whisper_context {
|
|
@@ -722,6 +770,9 @@ struct whisper_context {
|
|
|
722
770
|
whisper_state * state = nullptr;
|
|
723
771
|
|
|
724
772
|
std::string path_model; // populated by whisper_init_from_file()
|
|
773
|
+
#ifdef WHISPER_USE_COREML
|
|
774
|
+
bool load_coreml = true;
|
|
775
|
+
#endif
|
|
725
776
|
};
|
|
726
777
|
|
|
727
778
|
static void whisper_default_log(const char * text) {
|
|
@@ -730,6 +781,13 @@ static void whisper_default_log(const char * text) {
|
|
|
730
781
|
|
|
731
782
|
static whisper_log_callback whisper_log = whisper_default_log;
|
|
732
783
|
|
|
784
|
+
#ifdef __GNUC__
|
|
785
|
+
#ifdef __MINGW32__
|
|
786
|
+
__attribute__((gnu_format(printf, 1, 2)))
|
|
787
|
+
#else
|
|
788
|
+
__attribute__((format(printf, 1, 2)))
|
|
789
|
+
#endif
|
|
790
|
+
#endif
|
|
733
791
|
static void log(const char * fmt, ...) {
|
|
734
792
|
if (!whisper_log) return;
|
|
735
793
|
char buf[1024];
|
|
@@ -747,10 +805,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
747
805
|
|
|
748
806
|
static bool kv_cache_init(
|
|
749
807
|
const struct whisper_hparams & hparams,
|
|
750
|
-
const size_t mem_bytes,
|
|
751
808
|
struct whisper_kv_cache & cache,
|
|
752
809
|
wsp_ggml_type wtype,
|
|
753
810
|
int n_ctx) {
|
|
811
|
+
const int64_t n_text_state = hparams.n_text_state;
|
|
812
|
+
const int64_t n_text_layer = hparams.n_text_layer;
|
|
813
|
+
|
|
814
|
+
const int64_t n_mem = n_text_layer*n_ctx;
|
|
815
|
+
const int64_t n_elements = n_text_state*n_mem;
|
|
816
|
+
|
|
817
|
+
const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
|
|
818
|
+
|
|
754
819
|
cache.buf.resize(mem_bytes);
|
|
755
820
|
|
|
756
821
|
struct wsp_ggml_init_params params = {
|
|
@@ -766,12 +831,6 @@ static bool kv_cache_init(
|
|
|
766
831
|
return false;
|
|
767
832
|
}
|
|
768
833
|
|
|
769
|
-
const int n_text_state = hparams.n_text_state;
|
|
770
|
-
const int n_text_layer = hparams.n_text_layer;
|
|
771
|
-
|
|
772
|
-
const int n_mem = n_text_layer*n_ctx;
|
|
773
|
-
const int n_elements = n_text_state*n_mem;
|
|
774
|
-
|
|
775
834
|
cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
776
835
|
cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
777
836
|
|
|
@@ -914,22 +973,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
914
973
|
|
|
915
974
|
// print memory requirements
|
|
916
975
|
{
|
|
917
|
-
//
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
MEM_REQ_SCRATCH1.at(model.type) +
|
|
921
|
-
MEM_REQ_SCRATCH2.at(model.type) +
|
|
922
|
-
MEM_REQ_SCRATCH3.at(model.type) +
|
|
923
|
-
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
|
924
|
-
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
|
925
|
-
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
|
926
|
-
|
|
927
|
-
// this is the memory required by one decoder
|
|
928
|
-
const size_t mem_required_decoder =
|
|
929
|
-
scale*MEM_REQ_KV_SELF.at(model.type);
|
|
930
|
-
|
|
931
|
-
log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
932
|
-
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
976
|
+
// TODO
|
|
977
|
+
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
978
|
+
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
933
979
|
}
|
|
934
980
|
|
|
935
981
|
// initialize all memory buffers
|
|
@@ -1438,49 +1484,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1438
1484
|
return true;
|
|
1439
1485
|
}
|
|
1440
1486
|
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
1444
|
-
// part of the transformer model and returns the encoded features
|
|
1445
|
-
//
|
|
1446
|
-
// - wctx: the model
|
|
1447
|
-
// - wstate: the state of the encoder
|
|
1448
|
-
// - n_threads: number of threads to use
|
|
1449
|
-
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
1450
|
-
//
|
|
1451
|
-
static bool whisper_encode_internal(
|
|
1452
|
-
whisper_context & wctx,
|
|
1453
|
-
whisper_state & wstate,
|
|
1454
|
-
const int mel_offset,
|
|
1455
|
-
const int n_threads){
|
|
1487
|
+
static bool whisper_encode_external(const whisper_state & wstate) {
|
|
1488
|
+
WSP_GGML_UNUSED(wstate);
|
|
1456
1489
|
|
|
1457
|
-
|
|
1490
|
+
#ifndef WHISPER_USE_COREML
|
|
1491
|
+
const bool use_coreml = false;
|
|
1492
|
+
#else
|
|
1493
|
+
const bool use_coreml = wstate.ctx_coreml != nullptr;
|
|
1494
|
+
#endif
|
|
1495
|
+
|
|
1496
|
+
#ifndef WHISPER_USE_OPENVINO
|
|
1497
|
+
const bool use_openvino = false;
|
|
1498
|
+
#else
|
|
1499
|
+
const bool use_openvino = wstate.ctx_openvino != nullptr;
|
|
1500
|
+
#endif
|
|
1501
|
+
|
|
1502
|
+
return use_coreml || use_openvino;
|
|
1503
|
+
}
|
|
1458
1504
|
|
|
1505
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
1506
|
+
whisper_context & wctx,
|
|
1507
|
+
whisper_state & wstate,
|
|
1508
|
+
const int mel_offset) {
|
|
1459
1509
|
const auto & model = wctx.model;
|
|
1460
1510
|
const auto & mel_inp = wstate.mel;
|
|
1461
1511
|
const auto & hparams = model.hparams;
|
|
1462
1512
|
|
|
1463
1513
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1464
|
-
const int n_state = hparams.n_audio_state;
|
|
1465
|
-
const int n_head = hparams.n_audio_head;
|
|
1466
|
-
const int n_layer = hparams.n_audio_layer;
|
|
1514
|
+
const int n_state = hparams.n_audio_state; WSP_GGML_UNUSED(n_state);
|
|
1467
1515
|
|
|
1468
1516
|
const int n_mels = hparams.n_mels;
|
|
1469
|
-
assert(mel_inp.n_mel == n_mels);
|
|
1470
1517
|
|
|
1471
1518
|
struct wsp_ggml_init_params params = {
|
|
1472
|
-
/*.mem_size =*/ wstate.
|
|
1473
|
-
/*.mem_buffer =*/ wstate.
|
|
1474
|
-
/*.no_alloc =*/
|
|
1519
|
+
/*.mem_size =*/ wstate.alloc_conv.meta.size(),
|
|
1520
|
+
/*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
|
|
1521
|
+
/*.no_alloc =*/ true,
|
|
1475
1522
|
};
|
|
1476
1523
|
|
|
1477
1524
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1478
1525
|
|
|
1479
|
-
|
|
1526
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1527
|
+
|
|
1528
|
+
wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
|
1480
1529
|
|
|
1481
1530
|
struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
|
|
1531
|
+
wsp_ggml_allocr_alloc(alloc, mel);
|
|
1532
|
+
|
|
1482
1533
|
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
1483
|
-
{
|
|
1534
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1535
|
+
assert(mel_inp.n_mel == n_mels);
|
|
1536
|
+
|
|
1484
1537
|
float * dst = (float *) mel->data;
|
|
1485
1538
|
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1486
1539
|
|
|
@@ -1494,25 +1547,11 @@ static bool whisper_encode_internal(
|
|
|
1494
1547
|
}
|
|
1495
1548
|
}
|
|
1496
1549
|
|
|
1497
|
-
struct wsp_ggml_tensor * cur;
|
|
1498
|
-
|
|
1499
|
-
#ifndef WHISPER_USE_COREML
|
|
1500
|
-
const bool use_coreml = false;
|
|
1501
|
-
#else
|
|
1502
|
-
const bool use_coreml = wstate.ctx_coreml != nullptr;
|
|
1503
|
-
#endif
|
|
1504
|
-
|
|
1505
|
-
#ifndef WHISPER_USE_OPENVINO
|
|
1506
|
-
const bool use_openvino = false;
|
|
1507
|
-
#else
|
|
1508
|
-
const bool use_openvino = wstate.ctx_openvino != nullptr;
|
|
1509
|
-
#endif
|
|
1550
|
+
struct wsp_ggml_tensor * cur = nullptr;
|
|
1510
1551
|
|
|
1511
|
-
if (!
|
|
1552
|
+
if (!whisper_encode_external(wstate)) {
|
|
1512
1553
|
// convolution + gelu
|
|
1513
1554
|
{
|
|
1514
|
-
wstate.use_buf(ctx0, 1);
|
|
1515
|
-
|
|
1516
1555
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
|
1517
1556
|
cur = wsp_ggml_add(ctx0,
|
|
1518
1557
|
wsp_ggml_repeat(ctx0,
|
|
@@ -1522,8 +1561,6 @@ static bool whisper_encode_internal(
|
|
|
1522
1561
|
|
|
1523
1562
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1524
1563
|
|
|
1525
|
-
wstate.use_buf(ctx0, 0);
|
|
1526
|
-
|
|
1527
1564
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
|
1528
1565
|
cur = wsp_ggml_add(ctx0,
|
|
1529
1566
|
wsp_ggml_repeat(ctx0,
|
|
@@ -1534,373 +1571,433 @@ static bool whisper_encode_internal(
|
|
|
1534
1571
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1535
1572
|
}
|
|
1536
1573
|
|
|
1537
|
-
wstate.
|
|
1574
|
+
wstate.embd_conv = cur;
|
|
1575
|
+
} else {
|
|
1576
|
+
#ifdef WHISPER_USE_COREML
|
|
1577
|
+
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1578
|
+
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1538
1579
|
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1580
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1581
|
+
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
|
1582
|
+
}
|
|
1583
|
+
#endif
|
|
1584
|
+
#ifdef WHISPER_USE_OPENVINO
|
|
1585
|
+
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1586
|
+
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1543
1587
|
|
|
1544
|
-
|
|
1588
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1589
|
+
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
|
1590
|
+
}
|
|
1591
|
+
#endif
|
|
1545
1592
|
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
// memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
|
|
1549
|
-
//}
|
|
1593
|
+
wstate.embd_enc = cur;
|
|
1594
|
+
}
|
|
1550
1595
|
|
|
1551
|
-
|
|
1596
|
+
wsp_ggml_build_forward_expand(gf, cur);
|
|
1552
1597
|
|
|
1553
|
-
|
|
1554
|
-
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1598
|
+
wsp_ggml_free(ctx0);
|
|
1555
1599
|
|
|
1556
|
-
|
|
1600
|
+
return gf;
|
|
1601
|
+
}
|
|
1557
1602
|
|
|
1558
|
-
|
|
1603
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
1604
|
+
whisper_context & wctx,
|
|
1605
|
+
whisper_state & wstate) {
|
|
1606
|
+
const auto & model = wctx.model;
|
|
1607
|
+
const auto & hparams = model.hparams;
|
|
1559
1608
|
|
|
1560
|
-
|
|
1609
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1610
|
+
const int n_state = hparams.n_audio_state;
|
|
1611
|
+
const int n_head = hparams.n_audio_head;
|
|
1612
|
+
const int n_layer = hparams.n_audio_layer;
|
|
1561
1613
|
|
|
1562
|
-
|
|
1563
|
-
|
|
1614
|
+
struct wsp_ggml_init_params params = {
|
|
1615
|
+
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
|
1616
|
+
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
|
1617
|
+
/*.no_alloc =*/ true,
|
|
1618
|
+
};
|
|
1564
1619
|
|
|
1565
|
-
|
|
1620
|
+
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1566
1621
|
|
|
1567
|
-
|
|
1568
|
-
const auto & layer = model.layers_encoder[il];
|
|
1622
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1569
1623
|
|
|
1570
|
-
|
|
1571
|
-
{
|
|
1572
|
-
wstate.use_buf(ctx0, 0);
|
|
1624
|
+
wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1573
1625
|
|
|
1574
|
-
|
|
1626
|
+
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1627
|
+
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
1575
1628
|
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
|
1580
|
-
cur),
|
|
1581
|
-
wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
|
1582
|
-
}
|
|
1629
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1630
|
+
wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
|
|
1631
|
+
}
|
|
1583
1632
|
|
|
1584
|
-
|
|
1585
|
-
{
|
|
1586
|
-
wstate.use_buf(ctx0, 1);
|
|
1633
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1587
1634
|
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1635
|
+
// ===================================================================
|
|
1636
|
+
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1637
|
+
//static int iter = -1;
|
|
1638
|
+
//const int n_iter = 1500/n_ctx;
|
|
1591
1639
|
|
|
1592
|
-
|
|
1593
|
-
wsp_ggml_repeat(ctx0,
|
|
1594
|
-
layer.attn_q_b,
|
|
1595
|
-
Qcur),
|
|
1596
|
-
Qcur);
|
|
1640
|
+
//iter = (iter + 1) % n_iter;
|
|
1597
1641
|
|
|
1598
|
-
|
|
1642
|
+
//if (iter == 0) {
|
|
1643
|
+
// memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
|
|
1644
|
+
// memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
|
|
1645
|
+
//}
|
|
1599
1646
|
|
|
1600
|
-
|
|
1601
|
-
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1602
|
-
layer.attn_k_w,
|
|
1603
|
-
cur);
|
|
1647
|
+
static int iter = 0;
|
|
1604
1648
|
|
|
1605
|
-
|
|
1649
|
+
const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
|
|
1650
|
+
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1606
1651
|
|
|
1607
|
-
|
|
1608
|
-
layer.attn_v_w,
|
|
1609
|
-
cur);
|
|
1652
|
+
struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
1610
1653
|
|
|
1611
|
-
|
|
1612
|
-
wsp_ggml_repeat(ctx0,
|
|
1613
|
-
layer.attn_v_b,
|
|
1614
|
-
Vcur),
|
|
1615
|
-
Vcur);
|
|
1616
|
-
|
|
1617
|
-
// ------
|
|
1654
|
+
cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
|
|
1618
1655
|
|
|
1619
|
-
|
|
1656
|
+
// ===================================================================
|
|
1620
1657
|
|
|
1621
|
-
|
|
1622
|
-
|
|
1623
|
-
wsp_ggml_permute(ctx0,
|
|
1624
|
-
wsp_ggml_cpy(ctx0,
|
|
1625
|
-
Qcur,
|
|
1626
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1627
|
-
0, 2, 1, 3);
|
|
1628
|
-
|
|
1629
|
-
struct wsp_ggml_tensor * K =
|
|
1630
|
-
wsp_ggml_permute(ctx0,
|
|
1631
|
-
wsp_ggml_cpy(ctx0,
|
|
1632
|
-
Kcur,
|
|
1633
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1634
|
-
0, 2, 1, 3);
|
|
1635
|
-
|
|
1636
|
-
struct wsp_ggml_tensor * V =
|
|
1637
|
-
wsp_ggml_cpy(ctx0,
|
|
1638
|
-
wsp_ggml_permute(ctx0,
|
|
1639
|
-
wsp_ggml_reshape_3d(ctx0,
|
|
1640
|
-
Vcur,
|
|
1641
|
-
n_state/n_head, n_head, n_ctx),
|
|
1642
|
-
1, 2, 0, 3),
|
|
1643
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
1644
|
-
|
|
1645
|
-
struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
|
|
1646
|
-
#else
|
|
1647
|
-
struct wsp_ggml_tensor * Q =
|
|
1648
|
-
wsp_ggml_permute(ctx0,
|
|
1649
|
-
wsp_ggml_cpy(ctx0,
|
|
1650
|
-
Qcur,
|
|
1651
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1652
|
-
0, 2, 1, 3);
|
|
1653
|
-
|
|
1654
|
-
struct wsp_ggml_tensor * K =
|
|
1655
|
-
wsp_ggml_permute(ctx0,
|
|
1656
|
-
wsp_ggml_cpy(ctx0,
|
|
1657
|
-
Kcur,
|
|
1658
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1659
|
-
0, 2, 1, 3);
|
|
1660
|
-
|
|
1661
|
-
// K * Q
|
|
1662
|
-
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
1663
|
-
|
|
1664
|
-
struct wsp_ggml_tensor * KQ_scaled =
|
|
1665
|
-
wsp_ggml_scale_inplace(ctx0,
|
|
1666
|
-
KQ,
|
|
1667
|
-
wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
1668
|
-
);
|
|
1669
|
-
|
|
1670
|
-
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_scaled);
|
|
1671
|
-
|
|
1672
|
-
struct wsp_ggml_tensor * V =
|
|
1673
|
-
wsp_ggml_cpy(ctx0,
|
|
1674
|
-
wsp_ggml_permute(ctx0,
|
|
1675
|
-
wsp_ggml_reshape_3d(ctx0,
|
|
1676
|
-
Vcur,
|
|
1677
|
-
n_state/n_head, n_head, n_ctx),
|
|
1678
|
-
1, 2, 0, 3),
|
|
1679
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
|
1680
|
-
);
|
|
1681
|
-
|
|
1682
|
-
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1683
|
-
#endif
|
|
1684
|
-
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1658
|
+
// original:
|
|
1659
|
+
//cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
|
|
1685
1660
|
|
|
1686
|
-
|
|
1661
|
+
struct wsp_ggml_tensor * inpL = cur;
|
|
1687
1662
|
|
|
1688
|
-
|
|
1689
|
-
|
|
1690
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
|
|
1691
|
-
}
|
|
1663
|
+
for (int il = 0; il < n_layer; ++il) {
|
|
1664
|
+
const auto & layer = model.layers_encoder[il];
|
|
1692
1665
|
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
|
|
1666
|
+
// norm
|
|
1667
|
+
{
|
|
1668
|
+
cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
|
|
1696
1669
|
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1670
|
+
// cur = ln_0_w*cur + ln_0_b
|
|
1671
|
+
cur = wsp_ggml_add(ctx0,
|
|
1672
|
+
wsp_ggml_mul(ctx0, cur, layer.attn_ln_0_w),
|
|
1673
|
+
layer.attn_ln_0_b);
|
|
1674
|
+
}
|
|
1700
1675
|
|
|
1701
|
-
|
|
1676
|
+
// self-attention
|
|
1677
|
+
{
|
|
1678
|
+
struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
|
|
1679
|
+
layer.attn_q_w,
|
|
1680
|
+
cur);
|
|
1702
1681
|
|
|
1703
|
-
|
|
1704
|
-
wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
|
1705
|
-
cur);
|
|
1706
|
-
}
|
|
1682
|
+
Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
|
|
1707
1683
|
|
|
1708
|
-
|
|
1684
|
+
//Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1709
1685
|
|
|
1710
|
-
//
|
|
1711
|
-
|
|
1686
|
+
// note: no bias for Key
|
|
1687
|
+
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1688
|
+
layer.attn_k_w,
|
|
1689
|
+
cur);
|
|
1712
1690
|
|
|
1713
|
-
|
|
1691
|
+
//Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1714
1692
|
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
{
|
|
1719
|
-
wstate.use_buf(ctx0, 0);
|
|
1693
|
+
struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
|
|
1694
|
+
layer.attn_v_w,
|
|
1695
|
+
cur);
|
|
1720
1696
|
|
|
1721
|
-
|
|
1697
|
+
Vcur = wsp_ggml_add(ctx0, Vcur, layer.attn_v_b);
|
|
1722
1698
|
|
|
1723
|
-
|
|
1699
|
+
// ------
|
|
1724
1700
|
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
|
|
1701
|
+
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1702
|
+
struct wsp_ggml_tensor * Q =
|
|
1703
|
+
wsp_ggml_permute(ctx0,
|
|
1704
|
+
wsp_ggml_cpy(ctx0,
|
|
1705
|
+
Qcur,
|
|
1706
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1707
|
+
0, 2, 1, 3);
|
|
1732
1708
|
|
|
1733
|
-
|
|
1734
|
-
|
|
1709
|
+
struct wsp_ggml_tensor * K =
|
|
1710
|
+
wsp_ggml_permute(ctx0,
|
|
1711
|
+
wsp_ggml_cpy(ctx0,
|
|
1712
|
+
Kcur,
|
|
1713
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1714
|
+
0, 2, 1, 3);
|
|
1735
1715
|
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1716
|
+
struct wsp_ggml_tensor * V =
|
|
1717
|
+
wsp_ggml_cpy(ctx0,
|
|
1718
|
+
wsp_ggml_permute(ctx0,
|
|
1719
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
1720
|
+
Vcur,
|
|
1721
|
+
n_state/n_head, n_head, n_ctx),
|
|
1722
|
+
1, 2, 0, 3),
|
|
1723
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
1724
|
+
|
|
1725
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
|
|
1739
1726
|
#else
|
|
1740
|
-
|
|
1727
|
+
struct wsp_ggml_tensor * Q =
|
|
1728
|
+
wsp_ggml_permute(ctx0,
|
|
1729
|
+
wsp_ggml_cpy(ctx0,
|
|
1730
|
+
Qcur,
|
|
1731
|
+
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1732
|
+
0, 2, 1, 3);
|
|
1741
1733
|
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1734
|
+
struct wsp_ggml_tensor * K =
|
|
1735
|
+
wsp_ggml_permute(ctx0,
|
|
1736
|
+
wsp_ggml_cpy(ctx0,
|
|
1737
|
+
Kcur,
|
|
1738
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1739
|
+
0, 2, 1, 3);
|
|
1746
1740
|
|
|
1747
|
-
|
|
1741
|
+
// K * Q
|
|
1742
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
1748
1743
|
|
|
1749
|
-
|
|
1750
|
-
wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
1751
|
-
cur);
|
|
1744
|
+
struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
|
|
1752
1745
|
|
|
1753
|
-
|
|
1746
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
|
|
1754
1747
|
|
|
1755
|
-
|
|
1756
|
-
|
|
1748
|
+
struct wsp_ggml_tensor * V =
|
|
1749
|
+
wsp_ggml_cpy(ctx0,
|
|
1750
|
+
wsp_ggml_permute(ctx0,
|
|
1751
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
1752
|
+
Vcur,
|
|
1753
|
+
n_state/n_head, n_head, n_ctx),
|
|
1754
|
+
1, 2, 0, 3),
|
|
1755
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
|
1756
|
+
);
|
|
1757
1757
|
|
|
1758
|
-
|
|
1758
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1759
|
+
#endif
|
|
1760
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1759
1761
|
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1762
|
+
cur = wsp_ggml_cpy(ctx0,
|
|
1763
|
+
KQV_merged,
|
|
1764
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
|
|
1765
|
+
}
|
|
1766
|
+
|
|
1767
|
+
// projection
|
|
1768
|
+
{
|
|
1769
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1770
|
+
layer.attn_ln_1_w,
|
|
1771
|
+
cur);
|
|
1772
|
+
|
|
1773
|
+
cur = wsp_ggml_add(ctx0, cur, layer.attn_ln_1_b);
|
|
1774
|
+
}
|
|
1775
|
+
|
|
1776
|
+
// add the input
|
|
1777
|
+
cur = wsp_ggml_add(ctx0, cur, inpL);
|
|
1764
1778
|
|
|
1765
|
-
|
|
1779
|
+
struct wsp_ggml_tensor * inpFF = cur;
|
|
1780
|
+
|
|
1781
|
+
// feed-forward network
|
|
1782
|
+
{
|
|
1783
|
+
// norm
|
|
1784
|
+
{
|
|
1785
|
+
cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
|
|
1766
1786
|
|
|
1787
|
+
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
1767
1788
|
cur = wsp_ggml_add(ctx0,
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
#endif
|
|
1789
|
+
wsp_ggml_mul(ctx0, cur, layer.mlp_ln_w),
|
|
1790
|
+
layer.mlp_ln_b);
|
|
1771
1791
|
}
|
|
1772
1792
|
|
|
1773
|
-
|
|
1793
|
+
#ifdef WHISPER_USE_FLASH_FF
|
|
1794
|
+
cur = wsp_ggml_flash_ff(ctx0,
|
|
1795
|
+
wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
1796
|
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1797
|
+
#else
|
|
1798
|
+
// fully connected
|
|
1799
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1800
|
+
layer.mlp_0_w,
|
|
1801
|
+
cur);
|
|
1802
|
+
|
|
1803
|
+
cur = wsp_ggml_add(ctx0, cur, layer.mlp_0_b);
|
|
1804
|
+
|
|
1805
|
+
// GELU activation
|
|
1806
|
+
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1807
|
+
|
|
1808
|
+
// projection
|
|
1809
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1810
|
+
layer.mlp_1_w,
|
|
1811
|
+
cur);
|
|
1774
1812
|
|
|
1775
|
-
|
|
1813
|
+
cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
|
|
1814
|
+
#endif
|
|
1776
1815
|
}
|
|
1777
1816
|
|
|
1778
|
-
|
|
1817
|
+
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
1818
|
+
}
|
|
1779
1819
|
|
|
1780
|
-
|
|
1781
|
-
{
|
|
1782
|
-
wstate.use_buf(ctx0, 0);
|
|
1820
|
+
cur = inpL;
|
|
1783
1821
|
|
|
1784
|
-
|
|
1822
|
+
// norm
|
|
1823
|
+
{
|
|
1824
|
+
cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
|
|
1785
1825
|
|
|
1786
|
-
|
|
1826
|
+
// cur = ln_f_g*cur + ln_f_b
|
|
1827
|
+
cur = wsp_ggml_add(ctx0,
|
|
1828
|
+
wsp_ggml_mul(ctx0, cur, model.e_ln_w),
|
|
1829
|
+
model.e_ln_b);
|
|
1830
|
+
}
|
|
1787
1831
|
|
|
1788
|
-
|
|
1789
|
-
cur = wsp_ggml_add(ctx0,
|
|
1790
|
-
wsp_ggml_mul(ctx0,
|
|
1791
|
-
wsp_ggml_repeat(ctx0, model.e_ln_w, cur),
|
|
1792
|
-
cur),
|
|
1793
|
-
wsp_ggml_repeat(ctx0, model.e_ln_b, cur));
|
|
1794
|
-
}
|
|
1832
|
+
wsp_ggml_build_forward_expand(gf, cur);
|
|
1795
1833
|
|
|
1796
|
-
|
|
1834
|
+
wstate.embd_enc = cur;
|
|
1797
1835
|
|
|
1798
|
-
|
|
1799
|
-
{
|
|
1800
|
-
struct wsp_ggml_cgraph gf = {};
|
|
1801
|
-
gf.n_threads = n_threads;
|
|
1836
|
+
//wsp_ggml_graph_print(gf);
|
|
1802
1837
|
|
|
1803
|
-
|
|
1804
|
-
wsp_ggml_graph_compute(ctx0, &gf);
|
|
1838
|
+
////////////////////////////////////////////////////////////////////////////
|
|
1805
1839
|
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
wstate.
|
|
1840
|
+
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
1841
|
+
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
1842
|
+
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
1843
|
+
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
1844
|
+
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
1845
|
+
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
1812
1846
|
|
|
1813
|
-
|
|
1847
|
+
wsp_ggml_free(ctx0);
|
|
1814
1848
|
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
#endif
|
|
1818
|
-
#ifdef WHISPER_USE_OPENVINO
|
|
1819
|
-
else if (use_openvino) {
|
|
1820
|
-
wstate.use_buf(ctx0, -1);
|
|
1849
|
+
return gf;
|
|
1850
|
+
}
|
|
1821
1851
|
|
|
1822
|
-
|
|
1852
|
+
// pre-compute cross-attention memory
|
|
1853
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
1854
|
+
whisper_context & wctx,
|
|
1855
|
+
whisper_state & wstate) {
|
|
1856
|
+
const auto & model = wctx.model;
|
|
1857
|
+
const auto & hparams = model.hparams;
|
|
1823
1858
|
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
}
|
|
1828
|
-
#endif
|
|
1859
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1860
|
+
const int n_state = hparams.n_audio_state;
|
|
1861
|
+
const int n_head = hparams.n_audio_head;
|
|
1829
1862
|
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
1836
|
-
// }
|
|
1837
|
-
// printf("... ");
|
|
1838
|
-
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
|
1839
|
-
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
1840
|
-
// }
|
|
1841
|
-
// printf("\n");
|
|
1842
|
-
//}
|
|
1863
|
+
struct wsp_ggml_init_params params = {
|
|
1864
|
+
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
|
1865
|
+
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
|
1866
|
+
/*.no_alloc =*/ true,
|
|
1867
|
+
};
|
|
1843
1868
|
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
gf.n_threads = n_threads;
|
|
1869
|
+
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1870
|
+
|
|
1871
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1848
1872
|
|
|
1849
|
-
|
|
1850
|
-
cur->op = WSP_GGML_OP_NONE;
|
|
1851
|
-
cur->src0 = nullptr;
|
|
1852
|
-
cur->src1 = nullptr;
|
|
1873
|
+
wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
1853
1874
|
|
|
1854
|
-
|
|
1855
|
-
|
|
1875
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
1876
|
+
|
|
1877
|
+
struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1878
|
+
wsp_ggml_allocr_alloc(alloc, Kscale);
|
|
1879
|
+
|
|
1880
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1881
|
+
wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
|
|
1882
|
+
}
|
|
1856
1883
|
|
|
1857
|
-
|
|
1884
|
+
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
1885
|
+
auto & layer = model.layers_decoder[il];
|
|
1858
1886
|
|
|
1859
|
-
|
|
1887
|
+
struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
|
|
1860
1888
|
layer.cross_attn_k_w,
|
|
1861
1889
|
cur);
|
|
1862
1890
|
|
|
1863
|
-
|
|
1891
|
+
Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
|
|
1864
1892
|
|
|
1865
|
-
|
|
1866
|
-
|
|
1867
|
-
struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
|
|
1893
|
+
struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
|
|
1868
1894
|
layer.cross_attn_v_w,
|
|
1869
1895
|
cur);
|
|
1870
1896
|
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
layer.cross_attn_v_b
|
|
1874
|
-
|
|
1875
|
-
|
|
1897
|
+
Vcross = wsp_ggml_add(ctx0,
|
|
1898
|
+
Vcross,
|
|
1899
|
+
layer.cross_attn_v_b);
|
|
1900
|
+
|
|
1901
|
+
Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
1902
|
+
|
|
1903
|
+
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
|
|
1904
|
+
n_state*n_ctx,
|
|
1905
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
1906
|
+
|
|
1907
|
+
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
1908
|
+
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
1909
|
+
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
1910
|
+
|
|
1911
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
|
|
1912
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
|
|
1913
|
+
}
|
|
1876
1914
|
|
|
1877
|
-
|
|
1915
|
+
//wsp_ggml_graph_print(gf);
|
|
1878
1916
|
|
|
1879
|
-
|
|
1917
|
+
wsp_ggml_free(ctx0);
|
|
1880
1918
|
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
1884
|
-
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
1919
|
+
return gf;
|
|
1920
|
+
}
|
|
1885
1921
|
|
|
1886
|
-
|
|
1887
|
-
|
|
1922
|
+
// evaluate the encoder with the given state
|
|
1923
|
+
//
|
|
1924
|
+
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
1925
|
+
// part of the transformer model and returns the encoded features
|
|
1926
|
+
//
|
|
1927
|
+
// - wctx: the model
|
|
1928
|
+
// - wstate: the state of the encoder
|
|
1929
|
+
// - n_threads: number of threads to use
|
|
1930
|
+
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
1931
|
+
//
|
|
1932
|
+
static bool whisper_encode_internal(
|
|
1933
|
+
whisper_context & wctx,
|
|
1934
|
+
whisper_state & wstate,
|
|
1935
|
+
const int mel_offset,
|
|
1936
|
+
const int n_threads,
|
|
1937
|
+
whisper_abort_callback abort_callback,
|
|
1938
|
+
void * abort_callback_data) {
|
|
1939
|
+
const int64_t t_start_us = wsp_ggml_time_us();
|
|
1940
|
+
|
|
1941
|
+
// conv
|
|
1942
|
+
{
|
|
1943
|
+
auto & alloc = wstate.alloc_conv.alloc;
|
|
1944
|
+
|
|
1945
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1946
|
+
|
|
1947
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
|
|
1948
|
+
|
|
1949
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1950
|
+
|
|
1951
|
+
if (!whisper_encode_external(wstate)) {
|
|
1952
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1888
1953
|
}
|
|
1954
|
+
}
|
|
1955
|
+
|
|
1956
|
+
// encoder
|
|
1957
|
+
if (!whisper_encode_external(wstate)) {
|
|
1958
|
+
auto & alloc = wstate.alloc_encode.alloc;
|
|
1959
|
+
|
|
1960
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1889
1961
|
|
|
1890
|
-
|
|
1891
|
-
|
|
1962
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
|
1963
|
+
|
|
1964
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1965
|
+
|
|
1966
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1967
|
+
if (wstate.ctx_metal) {
|
|
1968
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1969
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1970
|
+
} else {
|
|
1971
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1972
|
+
}
|
|
1973
|
+
#else
|
|
1974
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1975
|
+
#endif
|
|
1892
1976
|
}
|
|
1893
1977
|
|
|
1894
|
-
|
|
1978
|
+
// cross
|
|
1979
|
+
{
|
|
1980
|
+
auto & alloc = wstate.alloc_cross.alloc;
|
|
1895
1981
|
|
|
1896
|
-
|
|
1897
|
-
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
1898
|
-
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
1899
|
-
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
1900
|
-
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
1901
|
-
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
1982
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1902
1983
|
|
|
1903
|
-
|
|
1984
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
|
1985
|
+
|
|
1986
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1987
|
+
|
|
1988
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1989
|
+
if (wstate.ctx_metal) {
|
|
1990
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1991
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1992
|
+
} else {
|
|
1993
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1994
|
+
}
|
|
1995
|
+
#else
|
|
1996
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1997
|
+
#endif
|
|
1998
|
+
}
|
|
1999
|
+
|
|
2000
|
+
// wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
1904
2001
|
|
|
1905
2002
|
wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
|
|
1906
2003
|
wstate.n_encode++;
|
|
@@ -1908,26 +2005,13 @@ static bool whisper_encode_internal(
|
|
|
1908
2005
|
return true;
|
|
1909
2006
|
}
|
|
1910
2007
|
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
// - n_tokens: number of tokens in the prompt
|
|
1919
|
-
// - n_past: number of past tokens to prefix the prompt with
|
|
1920
|
-
//
|
|
1921
|
-
static bool whisper_decode_internal(
|
|
1922
|
-
whisper_context & wctx,
|
|
1923
|
-
whisper_state & wstate,
|
|
1924
|
-
whisper_decoder & decoder,
|
|
1925
|
-
const whisper_token * tokens,
|
|
1926
|
-
const int n_tokens,
|
|
1927
|
-
const int n_past,
|
|
1928
|
-
const int n_threads) {
|
|
1929
|
-
const int64_t t_start_us = wsp_ggml_time_us();
|
|
1930
|
-
|
|
2008
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2009
|
+
whisper_context & wctx,
|
|
2010
|
+
whisper_state & wstate,
|
|
2011
|
+
whisper_decoder & decoder,
|
|
2012
|
+
const whisper_token * tokens,
|
|
2013
|
+
int n_tokens,
|
|
2014
|
+
int n_past) {
|
|
1931
2015
|
const auto & model = wctx.model;
|
|
1932
2016
|
const auto & hparams = model.hparams;
|
|
1933
2017
|
|
|
@@ -1935,10 +2019,6 @@ static bool whisper_decode_internal(
|
|
|
1935
2019
|
|
|
1936
2020
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
1937
2021
|
|
|
1938
|
-
auto & logits_out = wstate.logits;
|
|
1939
|
-
|
|
1940
|
-
const int n_vocab = hparams.n_vocab;
|
|
1941
|
-
|
|
1942
2022
|
const int n_ctx = hparams.n_text_ctx;
|
|
1943
2023
|
const int n_state = hparams.n_text_state;
|
|
1944
2024
|
const int n_head = hparams.n_text_head;
|
|
@@ -1950,25 +2030,39 @@ static bool whisper_decode_internal(
|
|
|
1950
2030
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
1951
2031
|
|
|
1952
2032
|
struct wsp_ggml_init_params params = {
|
|
1953
|
-
/*.mem_size =*/ wstate.
|
|
1954
|
-
/*.mem_buffer =*/ wstate.
|
|
1955
|
-
/*.no_alloc =*/
|
|
2033
|
+
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
|
2034
|
+
/*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
|
|
2035
|
+
/*.no_alloc =*/ true,
|
|
1956
2036
|
};
|
|
1957
2037
|
|
|
1958
2038
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1959
2039
|
|
|
1960
|
-
|
|
1961
|
-
|
|
2040
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
2041
|
+
|
|
2042
|
+
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
1962
2043
|
|
|
1963
2044
|
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
1964
|
-
|
|
2045
|
+
wsp_ggml_allocr_alloc(alloc, embd);
|
|
2046
|
+
|
|
2047
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2048
|
+
memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
|
|
2049
|
+
}
|
|
1965
2050
|
|
|
1966
2051
|
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
1967
|
-
|
|
1968
|
-
|
|
2052
|
+
wsp_ggml_allocr_alloc(alloc, position);
|
|
2053
|
+
|
|
2054
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2055
|
+
for (int i = 0; i < N; ++i) {
|
|
2056
|
+
((int32_t *) position->data)[i] = n_past + i;
|
|
2057
|
+
}
|
|
1969
2058
|
}
|
|
1970
2059
|
|
|
1971
|
-
|
|
2060
|
+
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
2061
|
+
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
2062
|
+
|
|
2063
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2064
|
+
wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
|
|
2065
|
+
}
|
|
1972
2066
|
|
|
1973
2067
|
// token encoding + position encoding
|
|
1974
2068
|
struct wsp_ggml_tensor * cur =
|
|
@@ -1983,16 +2077,14 @@ static bool whisper_decode_internal(
|
|
|
1983
2077
|
|
|
1984
2078
|
// norm
|
|
1985
2079
|
{
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
cur = wsp_ggml_norm(ctx0, inpL);
|
|
2080
|
+
cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
|
|
1989
2081
|
|
|
1990
2082
|
// cur = ln_0_w*cur + ln_0_b
|
|
1991
2083
|
cur = wsp_ggml_add(ctx0,
|
|
1992
2084
|
wsp_ggml_mul(ctx0,
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
2085
|
+
cur,
|
|
2086
|
+
layer.attn_ln_0_w),
|
|
2087
|
+
layer.attn_ln_0_b);
|
|
1996
2088
|
}
|
|
1997
2089
|
|
|
1998
2090
|
// self-attention
|
|
@@ -2002,19 +2094,17 @@ static bool whisper_decode_internal(
|
|
|
2002
2094
|
cur);
|
|
2003
2095
|
|
|
2004
2096
|
Qcur = wsp_ggml_add(ctx0,
|
|
2005
|
-
|
|
2006
|
-
layer.attn_q_b
|
|
2007
|
-
Qcur),
|
|
2008
|
-
Qcur);
|
|
2097
|
+
Qcur,
|
|
2098
|
+
layer.attn_q_b);
|
|
2009
2099
|
|
|
2010
|
-
Qcur =
|
|
2100
|
+
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2011
2101
|
|
|
2012
2102
|
// note: no bias for Key
|
|
2013
2103
|
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
2014
2104
|
layer.attn_k_w,
|
|
2015
2105
|
cur);
|
|
2016
2106
|
|
|
2017
|
-
Kcur =
|
|
2107
|
+
Kcur = wsp_ggml_scale(ctx0, Kcur, KQscale);
|
|
2018
2108
|
|
|
2019
2109
|
// store key and value to memory
|
|
2020
2110
|
{
|
|
@@ -2023,10 +2113,8 @@ static bool whisper_decode_internal(
|
|
|
2023
2113
|
cur);
|
|
2024
2114
|
|
|
2025
2115
|
Vcur = wsp_ggml_add(ctx0,
|
|
2026
|
-
|
|
2027
|
-
layer.attn_v_b
|
|
2028
|
-
Vcur),
|
|
2029
|
-
Vcur);
|
|
2116
|
+
Vcur,
|
|
2117
|
+
layer.attn_v_b);
|
|
2030
2118
|
|
|
2031
2119
|
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
|
2032
2120
|
|
|
@@ -2035,42 +2123,32 @@ static bool whisper_decode_internal(
|
|
|
2035
2123
|
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2036
2124
|
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
|
|
2037
2125
|
|
|
2038
|
-
wsp_ggml_build_forward_expand(
|
|
2039
|
-
wsp_ggml_build_forward_expand(
|
|
2126
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2127
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
2040
2128
|
}
|
|
2041
2129
|
|
|
2042
2130
|
// ------
|
|
2043
2131
|
|
|
2044
|
-
wstate.use_buf(ctx0, 0);
|
|
2045
|
-
|
|
2046
2132
|
struct wsp_ggml_tensor * Q =
|
|
2047
2133
|
wsp_ggml_permute(ctx0,
|
|
2048
|
-
|
|
2049
|
-
Qcur,
|
|
2050
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
|
|
2134
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
|
2051
2135
|
0, 2, 1, 3);
|
|
2052
2136
|
|
|
2053
2137
|
struct wsp_ggml_tensor * K =
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
wstate.use_buf(ctx0, 1);
|
|
2138
|
+
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2139
|
+
n_state/n_head, n_past + N, n_head,
|
|
2140
|
+
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2141
|
+
wsp_ggml_element_size(kv_self.k)*n_state/n_head,
|
|
2142
|
+
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
2061
2143
|
|
|
2062
2144
|
// K * Q
|
|
2063
2145
|
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
2064
2146
|
|
|
2065
|
-
//struct wsp_ggml_tensor * KQ_scaled =
|
|
2066
|
-
// wsp_ggml_scale_inplace(ctx0,
|
|
2067
|
-
// KQ,
|
|
2068
|
-
// wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
2069
|
-
// );
|
|
2147
|
+
//struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
|
|
2070
2148
|
|
|
2071
|
-
struct wsp_ggml_tensor * KQ_masked =
|
|
2149
|
+
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2072
2150
|
|
|
2073
|
-
struct wsp_ggml_tensor * KQ_soft_max =
|
|
2151
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
|
|
2074
2152
|
|
|
2075
2153
|
struct wsp_ggml_tensor * V =
|
|
2076
2154
|
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
@@ -2090,36 +2168,28 @@ static bool whisper_decode_internal(
|
|
|
2090
2168
|
|
|
2091
2169
|
// projection
|
|
2092
2170
|
{
|
|
2093
|
-
wstate.use_buf(ctx0, 0);
|
|
2094
|
-
|
|
2095
2171
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2096
2172
|
layer.attn_ln_1_w,
|
|
2097
2173
|
cur);
|
|
2098
2174
|
|
|
2099
|
-
wstate.use_buf(ctx0, 1);
|
|
2100
|
-
|
|
2101
2175
|
cur = wsp_ggml_add(ctx0,
|
|
2102
|
-
|
|
2103
|
-
|
|
2176
|
+
cur,
|
|
2177
|
+
layer.attn_ln_1_b);
|
|
2104
2178
|
}
|
|
2105
2179
|
|
|
2106
|
-
wstate.use_buf(ctx0, 2);
|
|
2107
|
-
|
|
2108
2180
|
// add the input
|
|
2109
2181
|
struct wsp_ggml_tensor * inpCA = wsp_ggml_add(ctx0, cur, inpL);
|
|
2110
2182
|
|
|
2111
2183
|
// norm
|
|
2112
2184
|
{
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
cur = wsp_ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
|
2185
|
+
cur = wsp_ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
|
|
2116
2186
|
|
|
2117
2187
|
// cur = ln_0_w*cur + ln_0_b
|
|
2118
2188
|
cur = wsp_ggml_add(ctx0,
|
|
2119
2189
|
wsp_ggml_mul(ctx0,
|
|
2120
|
-
|
|
2121
|
-
|
|
2122
|
-
|
|
2190
|
+
cur,
|
|
2191
|
+
layer.cross_attn_ln_0_w),
|
|
2192
|
+
layer.cross_attn_ln_0_b);
|
|
2123
2193
|
}
|
|
2124
2194
|
|
|
2125
2195
|
// cross-attention
|
|
@@ -2129,18 +2199,18 @@ static bool whisper_decode_internal(
|
|
|
2129
2199
|
cur);
|
|
2130
2200
|
|
|
2131
2201
|
Qcur = wsp_ggml_add(ctx0,
|
|
2132
|
-
|
|
2133
|
-
layer.cross_attn_q_b
|
|
2134
|
-
Qcur),
|
|
2135
|
-
Qcur);
|
|
2202
|
+
Qcur,
|
|
2203
|
+
layer.cross_attn_q_b);
|
|
2136
2204
|
|
|
2137
|
-
Qcur =
|
|
2205
|
+
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2138
2206
|
|
|
2139
2207
|
// Kcross is already scaled
|
|
2140
2208
|
struct wsp_ggml_tensor * Kcross =
|
|
2141
|
-
|
|
2142
|
-
|
|
2143
|
-
n_state
|
|
2209
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2210
|
+
n_state/n_head, M, n_head,
|
|
2211
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2212
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2213
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
|
|
2144
2214
|
|
|
2145
2215
|
//struct wsp_ggml_tensor * Vcross =
|
|
2146
2216
|
// wsp_ggml_reshape_3d(ctx0,
|
|
@@ -2163,26 +2233,22 @@ static bool whisper_decode_internal(
|
|
|
2163
2233
|
|
|
2164
2234
|
struct wsp_ggml_tensor * Q =
|
|
2165
2235
|
wsp_ggml_permute(ctx0,
|
|
2166
|
-
|
|
2167
|
-
Qcur,
|
|
2168
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
|
|
2236
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
|
2169
2237
|
0, 2, 1, 3);
|
|
2170
2238
|
|
|
2171
|
-
struct wsp_ggml_tensor * K = wsp_ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
|
2172
|
-
|
|
2173
2239
|
// K * Q
|
|
2174
|
-
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0,
|
|
2240
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
|
|
2175
2241
|
|
|
2176
2242
|
//struct wsp_ggml_tensor * KQ_scaled =
|
|
2177
|
-
//
|
|
2243
|
+
// wsp_ggml_scale(ctx0,
|
|
2178
2244
|
// KQ,
|
|
2179
2245
|
// wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
2180
2246
|
// );
|
|
2181
2247
|
|
|
2182
2248
|
// no masking for cross-attention
|
|
2183
|
-
//struct wsp_ggml_tensor * KQ_masked =
|
|
2249
|
+
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2184
2250
|
|
|
2185
|
-
struct wsp_ggml_tensor * KQ_soft_max =
|
|
2251
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
|
|
2186
2252
|
|
|
2187
2253
|
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2188
2254
|
|
|
@@ -2196,21 +2262,15 @@ static bool whisper_decode_internal(
|
|
|
2196
2262
|
|
|
2197
2263
|
// projection
|
|
2198
2264
|
{
|
|
2199
|
-
wstate.use_buf(ctx0, 0);
|
|
2200
|
-
|
|
2201
2265
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2202
2266
|
layer.cross_attn_ln_1_w,
|
|
2203
2267
|
cur);
|
|
2204
2268
|
|
|
2205
|
-
wstate.use_buf(ctx0, 1);
|
|
2206
|
-
|
|
2207
2269
|
cur = wsp_ggml_add(ctx0,
|
|
2208
|
-
|
|
2209
|
-
|
|
2270
|
+
cur,
|
|
2271
|
+
layer.cross_attn_ln_1_b);
|
|
2210
2272
|
}
|
|
2211
2273
|
|
|
2212
|
-
wstate.use_buf(ctx0, 2);
|
|
2213
|
-
|
|
2214
2274
|
// add the input
|
|
2215
2275
|
cur = wsp_ggml_add(ctx0, cur, inpCA);
|
|
2216
2276
|
|
|
@@ -2220,54 +2280,38 @@ static bool whisper_decode_internal(
|
|
|
2220
2280
|
{
|
|
2221
2281
|
// norm
|
|
2222
2282
|
{
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
cur = wsp_ggml_norm(ctx0, inpFF);
|
|
2226
|
-
|
|
2227
|
-
wstate.use_buf(ctx0, 1);
|
|
2283
|
+
cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
|
|
2228
2284
|
|
|
2229
2285
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
2230
2286
|
cur = wsp_ggml_add(ctx0,
|
|
2231
2287
|
wsp_ggml_mul(ctx0,
|
|
2232
|
-
|
|
2233
|
-
|
|
2234
|
-
|
|
2288
|
+
cur,
|
|
2289
|
+
layer.mlp_ln_w),
|
|
2290
|
+
layer.mlp_ln_b);
|
|
2235
2291
|
}
|
|
2236
2292
|
|
|
2237
|
-
wstate.use_buf(ctx0, 0);
|
|
2238
|
-
|
|
2239
2293
|
// fully connected
|
|
2240
2294
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2241
2295
|
layer.mlp_0_w,
|
|
2242
2296
|
cur);
|
|
2243
2297
|
|
|
2244
|
-
wstate.use_buf(ctx0, 1);
|
|
2245
|
-
|
|
2246
2298
|
cur = wsp_ggml_add(ctx0,
|
|
2247
|
-
|
|
2248
|
-
|
|
2249
|
-
|
|
2250
|
-
wstate.use_buf(ctx0, 0);
|
|
2299
|
+
cur,
|
|
2300
|
+
layer.mlp_0_b);
|
|
2251
2301
|
|
|
2252
2302
|
// GELU activation
|
|
2253
2303
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
2254
2304
|
|
|
2255
|
-
wstate.use_buf(ctx0, 1);
|
|
2256
|
-
|
|
2257
2305
|
// projection
|
|
2258
2306
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2259
2307
|
layer.mlp_1_w,
|
|
2260
2308
|
cur);
|
|
2261
2309
|
|
|
2262
|
-
wstate.use_buf(ctx0, 0);
|
|
2263
|
-
|
|
2264
2310
|
cur = wsp_ggml_add(ctx0,
|
|
2265
|
-
|
|
2266
|
-
|
|
2311
|
+
cur,
|
|
2312
|
+
layer.mlp_1_b);
|
|
2267
2313
|
}
|
|
2268
2314
|
|
|
2269
|
-
wstate.use_buf(ctx0, 3);
|
|
2270
|
-
|
|
2271
2315
|
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
2272
2316
|
}
|
|
2273
2317
|
|
|
@@ -2275,21 +2319,15 @@ static bool whisper_decode_internal(
|
|
|
2275
2319
|
|
|
2276
2320
|
// norm
|
|
2277
2321
|
{
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
cur = wsp_ggml_norm(ctx0, cur);
|
|
2281
|
-
|
|
2282
|
-
wstate.use_buf(ctx0, 1);
|
|
2322
|
+
cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
|
|
2283
2323
|
|
|
2284
2324
|
cur = wsp_ggml_add(ctx0,
|
|
2285
2325
|
wsp_ggml_mul(ctx0,
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2326
|
+
cur,
|
|
2327
|
+
model.d_ln_w),
|
|
2328
|
+
model.d_ln_b);
|
|
2289
2329
|
}
|
|
2290
2330
|
|
|
2291
|
-
wstate.use_buf(ctx0, 0);
|
|
2292
|
-
|
|
2293
2331
|
// compute logits only for the last token
|
|
2294
2332
|
// comment this line to compute logits for all N tokens
|
|
2295
2333
|
// might be useful in the future
|
|
@@ -2297,23 +2335,77 @@ static bool whisper_decode_internal(
|
|
|
2297
2335
|
|
|
2298
2336
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2299
2337
|
|
|
2300
|
-
|
|
2338
|
+
wsp_ggml_build_forward_expand(gf, logits);
|
|
2339
|
+
|
|
2340
|
+
wsp_ggml_free(ctx0);
|
|
2341
|
+
|
|
2342
|
+
return gf;
|
|
2343
|
+
}
|
|
2344
|
+
|
|
2345
|
+
// evaluate the decoder
|
|
2346
|
+
//
|
|
2347
|
+
// given text prompt + audio features -> computes the logits for the next token
|
|
2348
|
+
//
|
|
2349
|
+
// - model: the model
|
|
2350
|
+
// - n_threads: number of threads to use
|
|
2351
|
+
// - tokens: text prompt
|
|
2352
|
+
// - n_tokens: number of tokens in the prompt
|
|
2353
|
+
// - n_past: number of past tokens to prefix the prompt with
|
|
2354
|
+
//
|
|
2355
|
+
static bool whisper_decode_internal(
|
|
2356
|
+
whisper_context & wctx,
|
|
2357
|
+
whisper_state & wstate,
|
|
2358
|
+
whisper_decoder & decoder,
|
|
2359
|
+
const whisper_token * tokens,
|
|
2360
|
+
const int n_tokens,
|
|
2361
|
+
const int n_past,
|
|
2362
|
+
const int n_threads,
|
|
2363
|
+
whisper_abort_callback abort_callback,
|
|
2364
|
+
void * abort_callback_data) {
|
|
2365
|
+
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2366
|
+
|
|
2367
|
+
const auto & model = wctx.model;
|
|
2368
|
+
const auto & hparams = model.hparams;
|
|
2369
|
+
|
|
2370
|
+
const int n_vocab = hparams.n_vocab;
|
|
2371
|
+
|
|
2372
|
+
auto & logits_out = wstate.logits;
|
|
2373
|
+
|
|
2374
|
+
struct wsp_ggml_tensor * logits;
|
|
2301
2375
|
|
|
2302
|
-
//
|
|
2376
|
+
// decoder
|
|
2303
2377
|
{
|
|
2304
|
-
|
|
2305
|
-
|
|
2378
|
+
auto & alloc = wstate.alloc_decode.alloc;
|
|
2379
|
+
|
|
2380
|
+
wsp_ggml_allocr_reset(alloc);
|
|
2381
|
+
|
|
2382
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
|
|
2383
|
+
|
|
2384
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
2385
|
+
|
|
2386
|
+
logits = gf->nodes[gf->n_nodes - 1];
|
|
2387
|
+
|
|
2388
|
+
#ifdef WSP_GGML_USE_METAL
|
|
2389
|
+
if (wstate.ctx_metal) {
|
|
2390
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
2391
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
2392
|
+
} else {
|
|
2393
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2394
|
+
}
|
|
2395
|
+
#else
|
|
2396
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2397
|
+
#endif
|
|
2306
2398
|
}
|
|
2307
2399
|
|
|
2308
2400
|
// extract logits for all N tokens
|
|
2309
|
-
//logits_out.resize(
|
|
2310
|
-
//memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*
|
|
2401
|
+
//logits_out.resize(n_tokens*n_vocab);
|
|
2402
|
+
//memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
|
|
2311
2403
|
|
|
2312
2404
|
// extract logits only for the last token
|
|
2313
2405
|
logits_out.resize(n_vocab);
|
|
2314
2406
|
memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
|
|
2315
2407
|
|
|
2316
|
-
if (
|
|
2408
|
+
if (n_tokens > 1) {
|
|
2317
2409
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
2318
2410
|
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
2319
2411
|
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
@@ -2322,14 +2414,18 @@ static bool whisper_decode_internal(
|
|
|
2322
2414
|
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
2323
2415
|
}
|
|
2324
2416
|
|
|
2325
|
-
|
|
2326
|
-
|
|
2327
|
-
|
|
2328
|
-
|
|
2417
|
+
if (n_tokens == 1) {
|
|
2418
|
+
wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
|
|
2419
|
+
wstate.n_decode++;
|
|
2420
|
+
} else {
|
|
2421
|
+
wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
|
|
2422
|
+
wstate.n_prompt++;
|
|
2423
|
+
}
|
|
2329
2424
|
|
|
2330
2425
|
return true;
|
|
2331
2426
|
}
|
|
2332
2427
|
|
|
2428
|
+
|
|
2333
2429
|
// 500 -> 00:05.000
|
|
2334
2430
|
// 6000 -> 01:00.000
|
|
2335
2431
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
@@ -2351,7 +2447,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
2351
2447
|
static float sin_vals[SIN_COS_N_COUNT];
|
|
2352
2448
|
static float cos_vals[SIN_COS_N_COUNT];
|
|
2353
2449
|
|
|
2354
|
-
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2450
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2355
2451
|
// We can use precalculated values to speed up the process.
|
|
2356
2452
|
static void fill_sin_cos_table() {
|
|
2357
2453
|
static bool is_filled = false;
|
|
@@ -2446,7 +2542,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2446
2542
|
}
|
|
2447
2543
|
|
|
2448
2544
|
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
|
2449
|
-
if (output.size() < length) {
|
|
2545
|
+
if (output.size() < static_cast<size_t>(length)) {
|
|
2450
2546
|
output.resize(length);
|
|
2451
2547
|
}
|
|
2452
2548
|
int offset = -1;
|
|
@@ -2738,9 +2834,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2738
2834
|
fill_sin_cos_table();
|
|
2739
2835
|
whisper_state * state = new whisper_state;
|
|
2740
2836
|
|
|
2741
|
-
|
|
2742
|
-
|
|
2743
|
-
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
2837
|
+
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
2744
2838
|
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
2745
2839
|
delete state;
|
|
2746
2840
|
return nullptr;
|
|
@@ -2751,7 +2845,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2751
2845
|
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
|
2752
2846
|
}
|
|
2753
2847
|
|
|
2754
|
-
if (!kv_cache_init(ctx->model.hparams,
|
|
2848
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
2755
2849
|
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
2756
2850
|
delete state;
|
|
2757
2851
|
return nullptr;
|
|
@@ -2763,6 +2857,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2763
2857
|
}
|
|
2764
2858
|
|
|
2765
2859
|
#ifdef WHISPER_USE_COREML
|
|
2860
|
+
if (ctx->load_coreml) { // Not in correct layer for easy patch
|
|
2766
2861
|
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
|
2767
2862
|
|
|
2768
2863
|
log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
@@ -2772,11 +2867,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2772
2867
|
if (!state->ctx_coreml) {
|
|
2773
2868
|
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2774
2869
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
2870
|
+
delete state;
|
|
2775
2871
|
return nullptr;
|
|
2776
2872
|
#endif
|
|
2777
2873
|
} else {
|
|
2778
2874
|
log("%s: Core ML model loaded\n", __func__);
|
|
2779
2875
|
}
|
|
2876
|
+
}
|
|
2780
2877
|
#endif
|
|
2781
2878
|
|
|
2782
2879
|
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
|
@@ -2786,21 +2883,134 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2786
2883
|
// TAGS: WHISPER_DECODER_INIT
|
|
2787
2884
|
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
|
2788
2885
|
|
|
2789
|
-
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
|
2790
|
-
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
|
2886
|
+
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
|
2887
|
+
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
|
2791
2888
|
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
|
2792
|
-
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
|
|
2793
2889
|
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
|
|
2890
|
+
// conv allocator
|
|
2891
|
+
{
|
|
2892
|
+
whisper_allocr_graph_init(state->alloc_conv,
|
|
2893
|
+
[&]() {
|
|
2894
|
+
return whisper_build_graph_conv(*ctx, *state, 0);
|
|
2895
|
+
});
|
|
2896
|
+
|
|
2897
|
+
log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
|
|
2898
|
+
}
|
|
2899
|
+
|
|
2900
|
+
// encoder allocator
|
|
2901
|
+
if (!whisper_encode_external(*state)) {
|
|
2902
|
+
whisper_allocr_graph_init(state->alloc_encode,
|
|
2903
|
+
[&]() {
|
|
2904
|
+
return whisper_build_graph_encoder(*ctx, *state);
|
|
2905
|
+
});
|
|
2906
|
+
|
|
2907
|
+
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
|
|
2908
|
+
}
|
|
2909
|
+
|
|
2910
|
+
// cross allocator
|
|
2911
|
+
{
|
|
2912
|
+
whisper_allocr_graph_init(state->alloc_cross,
|
|
2913
|
+
[&]() {
|
|
2914
|
+
return whisper_build_graph_cross(*ctx, *state);
|
|
2915
|
+
});
|
|
2916
|
+
|
|
2917
|
+
log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
|
|
2918
|
+
}
|
|
2919
|
+
|
|
2920
|
+
// decoder allocator
|
|
2921
|
+
{
|
|
2922
|
+
whisper_allocr_graph_init(state->alloc_decode,
|
|
2923
|
+
[&]() {
|
|
2924
|
+
const auto & hparams = ctx->model.hparams;
|
|
2925
|
+
|
|
2926
|
+
// TODO: make sure this is the worst-case scenario
|
|
2927
|
+
const int n_tokens = hparams.n_text_ctx;
|
|
2928
|
+
const int n_past = 0;
|
|
2929
|
+
|
|
2930
|
+
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
|
|
2931
|
+
});
|
|
2932
|
+
|
|
2933
|
+
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
|
2934
|
+
}
|
|
2935
|
+
|
|
2936
|
+
#ifdef WSP_GGML_USE_METAL
|
|
2937
|
+
state->ctx_metal = wsp_ggml_metal_init(1);
|
|
2938
|
+
if (!state->ctx_metal) {
|
|
2939
|
+
log("%s: wsp_ggml_metal_init() failed\n", __func__);
|
|
2940
|
+
delete state;
|
|
2941
|
+
return nullptr;
|
|
2942
|
+
}
|
|
2943
|
+
|
|
2944
|
+
log("%s: Metal context initialized\n", __func__);
|
|
2945
|
+
|
|
2946
|
+
// this allocates all Metal resources and memory buffers
|
|
2947
|
+
|
|
2948
|
+
void * data_ptr = NULL;
|
|
2949
|
+
size_t data_size = 0;
|
|
2950
|
+
|
|
2951
|
+
// TODO: add mmap support
|
|
2952
|
+
//if (params.use_mmap) {
|
|
2953
|
+
// data_ptr = ctx->model.mapping->addr;
|
|
2954
|
+
// data_size = ctx->model.mapping->size;
|
|
2955
|
+
//} else {
|
|
2956
|
+
// data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2957
|
+
// data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2958
|
+
//}
|
|
2959
|
+
|
|
2960
|
+
data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2961
|
+
data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2962
|
+
|
|
2963
|
+
const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
|
|
2964
|
+
|
|
2965
|
+
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
|
2966
|
+
|
|
2967
|
+
#define WHISPER_METAL_CHECK_BUF(result) \
|
|
2968
|
+
if (!(result)) { \
|
|
2969
|
+
log("%s: failed to add metal buffer\n", __func__); \
|
|
2970
|
+
delete state; \
|
|
2971
|
+
return nullptr; \
|
|
2972
|
+
}
|
|
2973
|
+
|
|
2974
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
|
2975
|
+
|
|
2976
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
|
|
2977
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
|
|
2978
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
|
|
2979
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
|
|
2980
|
+
|
|
2981
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
|
|
2982
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
|
|
2983
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
|
|
2984
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
|
|
2985
|
+
|
|
2986
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
|
2987
|
+
|
|
2988
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
|
2989
|
+
#undef WHISPER_METAL_CHECK_BUF
|
|
2990
|
+
#endif
|
|
2798
2991
|
|
|
2799
2992
|
state->rng = std::mt19937(0);
|
|
2800
2993
|
|
|
2801
2994
|
return state;
|
|
2802
2995
|
}
|
|
2803
2996
|
|
|
2997
|
+
#ifdef WHISPER_USE_COREML
|
|
2998
|
+
struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
|
|
2999
|
+
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
|
|
3000
|
+
if (!ctx) {
|
|
3001
|
+
return nullptr;
|
|
3002
|
+
}
|
|
3003
|
+
ctx->load_coreml = false;
|
|
3004
|
+
ctx->state = whisper_init_state(ctx);
|
|
3005
|
+
if (!ctx->state) {
|
|
3006
|
+
whisper_free(ctx);
|
|
3007
|
+
return nullptr;
|
|
3008
|
+
}
|
|
3009
|
+
|
|
3010
|
+
return ctx;
|
|
3011
|
+
}
|
|
3012
|
+
#endif
|
|
3013
|
+
|
|
2804
3014
|
int whisper_ctx_init_openvino_encoder(
|
|
2805
3015
|
struct whisper_context * ctx,
|
|
2806
3016
|
const char * model_path,
|
|
@@ -2851,7 +3061,6 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
2851
3061
|
}
|
|
2852
3062
|
|
|
2853
3063
|
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
|
2854
|
-
|
|
2855
3064
|
log("%s: loading model from '%s'\n", __func__, path_model);
|
|
2856
3065
|
|
|
2857
3066
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
@@ -3004,6 +3213,13 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3004
3213
|
}
|
|
3005
3214
|
#endif
|
|
3006
3215
|
|
|
3216
|
+
#ifdef WSP_GGML_USE_METAL
|
|
3217
|
+
if (state->ctx_metal) {
|
|
3218
|
+
wsp_ggml_metal_free(state->ctx_metal);
|
|
3219
|
+
state->ctx_metal = nullptr;
|
|
3220
|
+
}
|
|
3221
|
+
#endif
|
|
3222
|
+
|
|
3007
3223
|
#ifdef WHISPER_USE_OPENVINO
|
|
3008
3224
|
if (state->ctx_openvino != nullptr) {
|
|
3009
3225
|
whisper_openvino_free(state->ctx_openvino);
|
|
@@ -3011,6 +3227,11 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3011
3227
|
}
|
|
3012
3228
|
#endif
|
|
3013
3229
|
|
|
3230
|
+
whisper_allocr_free(state->alloc_conv);
|
|
3231
|
+
whisper_allocr_free(state->alloc_decode);
|
|
3232
|
+
whisper_allocr_free(state->alloc_cross);
|
|
3233
|
+
whisper_allocr_free(state->alloc_encode);
|
|
3234
|
+
|
|
3014
3235
|
delete state;
|
|
3015
3236
|
}
|
|
3016
3237
|
}
|
|
@@ -3103,7 +3324,7 @@ int whisper_set_mel(
|
|
|
3103
3324
|
}
|
|
3104
3325
|
|
|
3105
3326
|
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
|
3106
|
-
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
|
3327
|
+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
|
3107
3328
|
log("%s: failed to eval\n", __func__);
|
|
3108
3329
|
return -1;
|
|
3109
3330
|
}
|
|
@@ -3112,7 +3333,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3112
3333
|
}
|
|
3113
3334
|
|
|
3114
3335
|
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
3115
|
-
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
|
3336
|
+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
|
3116
3337
|
log("%s: failed to eval\n", __func__);
|
|
3117
3338
|
return -1;
|
|
3118
3339
|
}
|
|
@@ -3123,7 +3344,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
3123
3344
|
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3124
3345
|
const int selected_decoder_id = 0;
|
|
3125
3346
|
|
|
3126
|
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
3347
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3127
3348
|
log("%s: failed to eval\n", __func__);
|
|
3128
3349
|
return 1;
|
|
3129
3350
|
}
|
|
@@ -3140,7 +3361,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
|
3140
3361
|
return false;
|
|
3141
3362
|
}
|
|
3142
3363
|
|
|
3143
|
-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
3364
|
+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3144
3365
|
log("%s: failed to eval\n", __func__);
|
|
3145
3366
|
return 1;
|
|
3146
3367
|
}
|
|
@@ -3431,12 +3652,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
|
|
|
3431
3652
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3432
3653
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3433
3654
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3655
|
+
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3434
3656
|
|
|
3435
3657
|
log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3436
3658
|
log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3437
3659
|
log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3438
3660
|
log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3439
3661
|
log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3662
|
+
log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
|
3440
3663
|
}
|
|
3441
3664
|
log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
3442
3665
|
}
|
|
@@ -3446,6 +3669,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3446
3669
|
ctx->state->t_sample_us = 0;
|
|
3447
3670
|
ctx->state->t_encode_us = 0;
|
|
3448
3671
|
ctx->state->t_decode_us = 0;
|
|
3672
|
+
ctx->state->t_prompt_us = 0;
|
|
3673
|
+
ctx->state->n_sample = 0;
|
|
3674
|
+
ctx->state->n_encode = 0;
|
|
3675
|
+
ctx->state->n_decode = 0;
|
|
3676
|
+
ctx->state->n_prompt = 0;
|
|
3449
3677
|
}
|
|
3450
3678
|
}
|
|
3451
3679
|
|
|
@@ -3475,6 +3703,7 @@ const char * whisper_print_system_info(void) {
|
|
|
3475
3703
|
s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
|
|
3476
3704
|
s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
|
|
3477
3705
|
s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
|
|
3706
|
+
s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
|
|
3478
3707
|
s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
|
|
3479
3708
|
s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
|
|
3480
3709
|
s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
|
|
@@ -3566,6 +3795,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3566
3795
|
/*.encoder_begin_callback =*/ nullptr,
|
|
3567
3796
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
3568
3797
|
|
|
3798
|
+
/*.abort_callback =*/ nullptr,
|
|
3799
|
+
/*.abort_callback_user_data =*/ nullptr,
|
|
3800
|
+
|
|
3569
3801
|
/*.logits_filter_callback =*/ nullptr,
|
|
3570
3802
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
|
3571
3803
|
};
|
|
@@ -3970,17 +4202,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
3970
4202
|
|
|
3971
4203
|
auto & logits_id = state.logits_id;
|
|
3972
4204
|
|
|
3973
|
-
logits_id.
|
|
4205
|
+
logits_id.resize(n_logits);
|
|
3974
4206
|
for (int i = 0; i < n_logits; ++i) {
|
|
3975
|
-
logits_id.
|
|
4207
|
+
logits_id[i].first = logits[i];
|
|
4208
|
+
logits_id[i].second = i;
|
|
3976
4209
|
}
|
|
3977
4210
|
|
|
3978
|
-
|
|
3979
|
-
|
|
3980
|
-
|
|
3981
|
-
|
|
3982
|
-
|
|
3983
|
-
|
|
4211
|
+
{
|
|
4212
|
+
using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
|
|
4213
|
+
std::partial_sort(
|
|
4214
|
+
logits_id.begin(),
|
|
4215
|
+
logits_id.begin() + k, logits_id.end(),
|
|
4216
|
+
[](const pair_type & a, const pair_type & b) {
|
|
4217
|
+
return a.first > b.first;
|
|
4218
|
+
});
|
|
4219
|
+
}
|
|
3984
4220
|
|
|
3985
4221
|
std::vector<whisper_token_data> result;
|
|
3986
4222
|
result.reserve(k);
|
|
@@ -4075,6 +4311,115 @@ static void whisper_sequence_score(
|
|
|
4075
4311
|
}
|
|
4076
4312
|
}
|
|
4077
4313
|
|
|
4314
|
+
static bool whisper_kv_swap_fast(
|
|
4315
|
+
std::vector<int> & view,
|
|
4316
|
+
whisper_decoder src[],
|
|
4317
|
+
std::vector<kv_buf> & kv_swap_bufs,
|
|
4318
|
+
const int & n_decoders) {
|
|
4319
|
+
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
|
|
4320
|
+
|
|
4321
|
+
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
|
|
4322
|
+
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
|
|
4323
|
+
|
|
4324
|
+
// (buffer->decoder or decoder->decoder)
|
|
4325
|
+
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
|
|
4326
|
+
|
|
4327
|
+
// (decoder<->decoder)
|
|
4328
|
+
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
|
|
4329
|
+
std::vector<whisper_pair<int, int>> p_swap_vec;
|
|
4330
|
+
p_swap_vec.reserve(n_decoders);
|
|
4331
|
+
|
|
4332
|
+
// see https://github.com/ggerganov/whisper.cpp/wiki
|
|
4333
|
+
for (int i = 0; i < n_decoders; i++) {
|
|
4334
|
+
// zero-copy (no modification)
|
|
4335
|
+
if (i == view[i] || view[i] < 0) {
|
|
4336
|
+
continue;
|
|
4337
|
+
}
|
|
4338
|
+
|
|
4339
|
+
bool is_one_copy = true;
|
|
4340
|
+
// since we modify data sequentially, we only consider decoder indices after current index
|
|
4341
|
+
for (int j = i + 1; j < n_decoders; j++) {
|
|
4342
|
+
if (i == view[j]) {
|
|
4343
|
+
// detect symmetric diagram
|
|
4344
|
+
if (j == view[i]) {
|
|
4345
|
+
p_swap_set.insert(i);
|
|
4346
|
+
p_swap_set.insert(j);
|
|
4347
|
+
p_swap_vec.emplace_back(i, j);
|
|
4348
|
+
} else {
|
|
4349
|
+
two_copy.insert(i);
|
|
4350
|
+
is_one_copy = false;
|
|
4351
|
+
}
|
|
4352
|
+
break;
|
|
4353
|
+
}
|
|
4354
|
+
}
|
|
4355
|
+
if (is_one_copy) {
|
|
4356
|
+
one_copy.insert(i);
|
|
4357
|
+
}
|
|
4358
|
+
}
|
|
4359
|
+
|
|
4360
|
+
kv_swap_bufs.resize(n_decoders);
|
|
4361
|
+
|
|
4362
|
+
for (int i = 0; i < n_decoders; i++) {
|
|
4363
|
+
kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
|
|
4364
|
+
kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
|
|
4365
|
+
}
|
|
4366
|
+
|
|
4367
|
+
for (auto & i : two_copy) {
|
|
4368
|
+
// make a copy of KV caches
|
|
4369
|
+
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
|
|
4370
|
+
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
|
|
4371
|
+
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
|
|
4372
|
+
}
|
|
4373
|
+
|
|
4374
|
+
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
|
|
4375
|
+
for (auto & i : two_copy) {
|
|
4376
|
+
// skip the decoder indices that require pointer swapping
|
|
4377
|
+
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4378
|
+
continue;
|
|
4379
|
+
}
|
|
4380
|
+
|
|
4381
|
+
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4382
|
+
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4383
|
+
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4384
|
+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4385
|
+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4386
|
+
} else {
|
|
4387
|
+
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4388
|
+
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4389
|
+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4390
|
+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4391
|
+
}
|
|
4392
|
+
}
|
|
4393
|
+
|
|
4394
|
+
// then modify one-copy decoder KV caches
|
|
4395
|
+
for (auto & i : one_copy) {
|
|
4396
|
+
// skip the decoder indices that require pointer swapping
|
|
4397
|
+
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4398
|
+
continue;
|
|
4399
|
+
}
|
|
4400
|
+
|
|
4401
|
+
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4402
|
+
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4403
|
+
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4404
|
+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4405
|
+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4406
|
+
} else {
|
|
4407
|
+
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4408
|
+
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4409
|
+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4410
|
+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4411
|
+
}
|
|
4412
|
+
}
|
|
4413
|
+
|
|
4414
|
+
// swap the pointers
|
|
4415
|
+
for (auto & i : p_swap_vec) {
|
|
4416
|
+
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
|
|
4417
|
+
std::swap(src[i.first].kv_self, src[i.second].kv_self);
|
|
4418
|
+
}
|
|
4419
|
+
|
|
4420
|
+
return true;
|
|
4421
|
+
}
|
|
4422
|
+
|
|
4078
4423
|
int whisper_full_with_state(
|
|
4079
4424
|
struct whisper_context * ctx,
|
|
4080
4425
|
struct whisper_state * state,
|
|
@@ -4182,6 +4527,21 @@ int whisper_full_with_state(
|
|
|
4182
4527
|
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
4183
4528
|
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
4184
4529
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
4530
|
+
|
|
4531
|
+
// TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
|
|
4532
|
+
#ifdef WSP_GGML_USE_METAL
|
|
4533
|
+
#define WHISPER_METAL_CHECK_BUF(result) \
|
|
4534
|
+
if (!(result)) { \
|
|
4535
|
+
log("%s: failed to add metal buffer\n", __func__); \
|
|
4536
|
+
return 0; \
|
|
4537
|
+
}
|
|
4538
|
+
|
|
4539
|
+
const std::string kv_name = "kv_self_" + std::to_string(j);
|
|
4540
|
+
auto & kv_self = decoder.kv_self;
|
|
4541
|
+
|
|
4542
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
|
|
4543
|
+
#undef WHISPER_METAL_CHECK_BUF
|
|
4544
|
+
#endif
|
|
4185
4545
|
}
|
|
4186
4546
|
}
|
|
4187
4547
|
|
|
@@ -4197,7 +4557,7 @@ int whisper_full_with_state(
|
|
|
4197
4557
|
|
|
4198
4558
|
// initial prompt
|
|
4199
4559
|
if (!params.prompt_tokens && params.initial_prompt) {
|
|
4200
|
-
prompt_tokens.resize(
|
|
4560
|
+
prompt_tokens.resize(2048);
|
|
4201
4561
|
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
|
4202
4562
|
params.prompt_tokens = prompt_tokens.data();
|
|
4203
4563
|
params.prompt_n_tokens = prompt_tokens.size();
|
|
@@ -4238,14 +4598,6 @@ int whisper_full_with_state(
|
|
|
4238
4598
|
std::vector<whisper_token> prompt;
|
|
4239
4599
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
|
4240
4600
|
|
|
4241
|
-
// beam-search helpers
|
|
4242
|
-
struct kv_buf {
|
|
4243
|
-
std::vector<uint8_t> k;
|
|
4244
|
-
std::vector<uint8_t> v;
|
|
4245
|
-
};
|
|
4246
|
-
|
|
4247
|
-
std::vector<kv_buf> kv_bufs;
|
|
4248
|
-
|
|
4249
4601
|
struct beam_candidate {
|
|
4250
4602
|
int decoder_idx;
|
|
4251
4603
|
int seek_delta;
|
|
@@ -4279,7 +4631,7 @@ int whisper_full_with_state(
|
|
|
4279
4631
|
}
|
|
4280
4632
|
|
|
4281
4633
|
// encode audio features starting at offset seek
|
|
4282
|
-
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
|
4634
|
+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4283
4635
|
log("%s: failed to encode\n", __func__);
|
|
4284
4636
|
return -6;
|
|
4285
4637
|
}
|
|
@@ -4362,7 +4714,7 @@ int whisper_full_with_state(
|
|
|
4362
4714
|
}
|
|
4363
4715
|
WHISPER_PRINT_DEBUG("\n\n");
|
|
4364
4716
|
|
|
4365
|
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
|
4717
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4366
4718
|
log("%s: failed to decode\n", __func__);
|
|
4367
4719
|
return -7;
|
|
4368
4720
|
}
|
|
@@ -4382,8 +4734,8 @@ int whisper_full_with_state(
|
|
|
4382
4734
|
|
|
4383
4735
|
decoder.kv_self.n += prompt.size();
|
|
4384
4736
|
|
|
4385
|
-
memcpy(decoder.probs.data(),
|
|
4386
|
-
memcpy(decoder.logits.data(),
|
|
4737
|
+
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
4738
|
+
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
4387
4739
|
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
4388
4740
|
}
|
|
4389
4741
|
|
|
@@ -4394,23 +4746,7 @@ int whisper_full_with_state(
|
|
|
4394
4746
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
4395
4747
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4396
4748
|
|
|
4397
|
-
// store the KV caches of all decoders when doing beam-search
|
|
4398
4749
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
4399
|
-
kv_bufs.resize(n_decoders_cur);
|
|
4400
|
-
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4401
|
-
auto & decoder = state->decoders[j];
|
|
4402
|
-
|
|
4403
|
-
if (decoder.completed || decoder.failed) {
|
|
4404
|
-
continue;
|
|
4405
|
-
}
|
|
4406
|
-
|
|
4407
|
-
kv_bufs[j].k.resize(wsp_ggml_nbytes(decoder.kv_self.k));
|
|
4408
|
-
kv_bufs[j].v.resize(wsp_ggml_nbytes(decoder.kv_self.v));
|
|
4409
|
-
|
|
4410
|
-
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
|
|
4411
|
-
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
|
|
4412
|
-
}
|
|
4413
|
-
|
|
4414
4750
|
beam_candidates.clear();
|
|
4415
4751
|
}
|
|
4416
4752
|
|
|
@@ -4458,6 +4794,7 @@ int whisper_full_with_state(
|
|
|
4458
4794
|
});
|
|
4459
4795
|
|
|
4460
4796
|
uint32_t cur_c = 0;
|
|
4797
|
+
std::vector<int> decoder_idx(n_decoders_cur, -1);
|
|
4461
4798
|
|
|
4462
4799
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4463
4800
|
auto & decoder = state->decoders[j];
|
|
@@ -4476,12 +4813,13 @@ int whisper_full_with_state(
|
|
|
4476
4813
|
decoder.seek_delta = cur.seek_delta;
|
|
4477
4814
|
decoder.has_ts = cur.has_ts;
|
|
4478
4815
|
|
|
4479
|
-
|
|
4480
|
-
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
|
|
4481
|
-
|
|
4816
|
+
decoder_idx[j] = cur.decoder_idx;
|
|
4482
4817
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
4483
4818
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
4484
4819
|
}
|
|
4820
|
+
|
|
4821
|
+
// update KV caches
|
|
4822
|
+
whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
|
|
4485
4823
|
}
|
|
4486
4824
|
|
|
4487
4825
|
// update the decoder state
|
|
@@ -4600,7 +4938,7 @@ int whisper_full_with_state(
|
|
|
4600
4938
|
|
|
4601
4939
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
|
4602
4940
|
|
|
4603
|
-
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
|
4941
|
+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4604
4942
|
log("%s: failed to decode\n", __func__);
|
|
4605
4943
|
return -8;
|
|
4606
4944
|
}
|
|
@@ -4910,6 +5248,12 @@ int whisper_full_parallel(
|
|
|
4910
5248
|
ctx->state->t_sample_us += states[i]->t_sample_us;
|
|
4911
5249
|
ctx->state->t_encode_us += states[i]->t_encode_us;
|
|
4912
5250
|
ctx->state->t_decode_us += states[i]->t_decode_us;
|
|
5251
|
+
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
|
5252
|
+
|
|
5253
|
+
ctx->state->n_sample += states[i]->n_sample;
|
|
5254
|
+
ctx->state->n_encode += states[i]->n_encode;
|
|
5255
|
+
ctx->state->n_decode += states[i]->n_decode;
|
|
5256
|
+
ctx->state->n_prompt += states[i]->n_prompt;
|
|
4913
5257
|
|
|
4914
5258
|
whisper_free_state(states[i]);
|
|
4915
5259
|
}
|
|
@@ -4963,6 +5307,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
|
|
|
4963
5307
|
return ctx->state->result_all[i_segment].t1;
|
|
4964
5308
|
}
|
|
4965
5309
|
|
|
5310
|
+
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
|
5311
|
+
return state->result_all[i_segment].speaker_turn_next;
|
|
5312
|
+
}
|
|
5313
|
+
|
|
4966
5314
|
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
|
|
4967
5315
|
return ctx->state->result_all[i_segment].speaker_turn_next;
|
|
4968
5316
|
}
|
|
@@ -5106,7 +5454,8 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5106
5454
|
// b: N*N*sizeof(float)
|
|
5107
5455
|
// c: N*N*sizeof(float)
|
|
5108
5456
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
|
5109
|
-
std::vector<
|
|
5457
|
+
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
|
|
5458
|
+
std::vector<uint8_t> work;
|
|
5110
5459
|
|
|
5111
5460
|
// put a bunch of random data in the buffer
|
|
5112
5461
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
|
@@ -5158,17 +5507,15 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5158
5507
|
|
|
5159
5508
|
struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
|
|
5160
5509
|
|
|
5161
|
-
gf.n_threads = n_threads;
|
|
5162
|
-
|
|
5163
5510
|
double tsum = 0.0;
|
|
5164
5511
|
|
|
5165
5512
|
// heat-up
|
|
5166
|
-
|
|
5513
|
+
wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
|
|
5167
5514
|
|
|
5168
5515
|
for (int i = 0; i < n_max; ++i) {
|
|
5169
5516
|
const int64_t t0 = wsp_ggml_time_us();
|
|
5170
5517
|
|
|
5171
|
-
|
|
5518
|
+
wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
|
|
5172
5519
|
|
|
5173
5520
|
const int64_t t1 = wsp_ggml_time_us();
|
|
5174
5521
|
|