whisper.rn 0.3.8 → 0.4.0-rc.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -1
- package/android/src/main/jni.cpp +7 -1
- package/cpp/coreml/whisper-encoder.mm +7 -1
- package/cpp/ggml-alloc.c +633 -0
- package/cpp/ggml-alloc.h +26 -0
- package/cpp/ggml-metal.h +85 -0
- package/cpp/ggml-metal.m +1283 -0
- package/cpp/ggml-metal.metal +2353 -0
- package/cpp/ggml.c +5024 -2924
- package/cpp/ggml.h +569 -95
- package/cpp/whisper.cpp +993 -668
- package/cpp/whisper.h +10 -0
- package/ios/RNWhisperAudioSessionUtils.m +7 -1
- package/ios/RNWhisperContext.mm +9 -3
- package/jest/mock.js +10 -0
- package/package.json +1 -1
- package/whisper-rn.podspec +8 -2
package/cpp/whisper.cpp
CHANGED
|
@@ -3,11 +3,16 @@
|
|
|
3
3
|
#include "coreml/whisper-encoder.h"
|
|
4
4
|
#endif
|
|
5
5
|
|
|
6
|
-
#
|
|
6
|
+
#ifdef WSP_GGML_USE_METAL
|
|
7
|
+
# include "ggml-metal.h"
|
|
8
|
+
#endif
|
|
9
|
+
|
|
10
|
+
#ifdef WHISPER_USE_OPENVINO
|
|
7
11
|
#include "openvino/whisper-openvino-encoder.h"
|
|
8
12
|
#endif
|
|
9
13
|
|
|
10
14
|
#include "ggml.h"
|
|
15
|
+
#include "ggml-alloc.h"
|
|
11
16
|
|
|
12
17
|
#include <algorithm>
|
|
13
18
|
#include <cassert>
|
|
@@ -18,11 +23,13 @@
|
|
|
18
23
|
#include <cstring>
|
|
19
24
|
#include <fstream>
|
|
20
25
|
#include <map>
|
|
26
|
+
#include <set>
|
|
21
27
|
#include <string>
|
|
22
28
|
#include <thread>
|
|
23
29
|
#include <vector>
|
|
24
30
|
#include <regex>
|
|
25
31
|
#include <random>
|
|
32
|
+
#include <functional>
|
|
26
33
|
|
|
27
34
|
#if defined(_MSC_VER)
|
|
28
35
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
@@ -114,8 +121,66 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
114
121
|
//#define WHISPER_USE_FLASH_FF
|
|
115
122
|
#define WHISPER_MAX_DECODERS 16
|
|
116
123
|
|
|
117
|
-
|
|
118
|
-
|
|
124
|
+
//
|
|
125
|
+
// ggml helpers
|
|
126
|
+
//
|
|
127
|
+
|
|
128
|
+
static void wsp_ggml_graph_compute_helper(
|
|
129
|
+
std::vector<uint8_t> & buf,
|
|
130
|
+
wsp_ggml_cgraph * graph,
|
|
131
|
+
int n_threads,
|
|
132
|
+
whisper_abort_callback abort_callback,
|
|
133
|
+
void * abort_callback_data) {
|
|
134
|
+
struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
|
|
135
|
+
|
|
136
|
+
plan.abort_callback = abort_callback;
|
|
137
|
+
plan.abort_callback_data = abort_callback_data;
|
|
138
|
+
|
|
139
|
+
if (plan.work_size > 0) {
|
|
140
|
+
buf.resize(plan.work_size);
|
|
141
|
+
plan.work_data = buf.data();
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
wsp_ggml_graph_compute(graph, &plan);
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
148
|
+
// the idea is to represent the original matrix multiplication:
|
|
149
|
+
//
|
|
150
|
+
// Z = X @ Y
|
|
151
|
+
//
|
|
152
|
+
// with the sum of two matrix multiplications:
|
|
153
|
+
//
|
|
154
|
+
// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
|
|
155
|
+
//
|
|
156
|
+
// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
|
|
157
|
+
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
|
|
158
|
+
// general-purpose kernels
|
|
159
|
+
//
|
|
160
|
+
static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * x, struct wsp_ggml_tensor * y, int pad = 32) {
|
|
161
|
+
// use padding only if dimension 0 is at least 8 times larger than the padding
|
|
162
|
+
// else we won't get much benefit from the optimization
|
|
163
|
+
const int n_pad_req = 8;
|
|
164
|
+
|
|
165
|
+
if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
|
|
166
|
+
return wsp_ggml_mul_mat(ctx, x, y);
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
struct wsp_ggml_tensor * x_0 = wsp_ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
|
|
170
|
+
struct wsp_ggml_tensor * x_1 = wsp_ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
|
|
171
|
+
|
|
172
|
+
struct wsp_ggml_tensor * y_0 = wsp_ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
|
|
173
|
+
struct wsp_ggml_tensor * y_1 = wsp_ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
|
|
174
|
+
|
|
175
|
+
return wsp_ggml_add(ctx,
|
|
176
|
+
wsp_ggml_mul_mat(ctx, x_0, y_0),
|
|
177
|
+
wsp_ggml_mul_mat(ctx, x_1, y_1));
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
// TODO: check if other platforms can benefit from this optimization
|
|
181
|
+
#if defined(WSP_GGML_USE_METAL)
|
|
182
|
+
#define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
|
|
183
|
+
#endif
|
|
119
184
|
|
|
120
185
|
// available whisper models
|
|
121
186
|
enum e_model {
|
|
@@ -231,38 +296,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
231
296
|
|
|
232
297
|
static const size_t MB = 1ull*1024*1024;
|
|
233
298
|
|
|
234
|
-
|
|
235
|
-
{ MODEL_TINY, 62ull*MB },
|
|
236
|
-
{ MODEL_BASE, 80ull*MB },
|
|
237
|
-
{ MODEL_SMALL, 120ull*MB },
|
|
238
|
-
{ MODEL_MEDIUM, 158ull*MB },
|
|
239
|
-
{ MODEL_LARGE, 198ull*MB },
|
|
240
|
-
};
|
|
241
|
-
|
|
242
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
|
243
|
-
{ MODEL_TINY, 18ull*MB },
|
|
244
|
-
{ MODEL_BASE, 24ull*MB },
|
|
245
|
-
{ MODEL_SMALL, 36ull*MB },
|
|
246
|
-
{ MODEL_MEDIUM, 48ull*MB },
|
|
247
|
-
{ MODEL_LARGE, 60ull*MB },
|
|
248
|
-
};
|
|
249
|
-
|
|
250
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
|
|
251
|
-
{ MODEL_TINY, 4ull*MB },
|
|
252
|
-
{ MODEL_BASE, 4ull*MB },
|
|
253
|
-
{ MODEL_SMALL, 6ull*MB },
|
|
254
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
|
255
|
-
{ MODEL_LARGE, 9ull*MB },
|
|
256
|
-
};
|
|
257
|
-
|
|
258
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
|
259
|
-
{ MODEL_TINY, 4ull*MB },
|
|
260
|
-
{ MODEL_BASE, 4ull*MB },
|
|
261
|
-
{ MODEL_SMALL, 6ull*MB },
|
|
262
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
|
263
|
-
{ MODEL_LARGE, 9ull*MB },
|
|
264
|
-
};
|
|
265
|
-
|
|
299
|
+
// TODO: avoid using GGUF
|
|
266
300
|
static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
267
301
|
{ WSP_GGML_TYPE_F32,
|
|
268
302
|
{
|
|
@@ -329,38 +363,6 @@ static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL =
|
|
|
329
363
|
},
|
|
330
364
|
};
|
|
331
365
|
|
|
332
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
|
333
|
-
{ MODEL_TINY, 3ull*MB },
|
|
334
|
-
{ MODEL_BASE, 6ull*MB },
|
|
335
|
-
{ MODEL_SMALL, 16ull*MB },
|
|
336
|
-
{ MODEL_MEDIUM, 43ull*MB },
|
|
337
|
-
{ MODEL_LARGE, 71ull*MB },
|
|
338
|
-
};
|
|
339
|
-
|
|
340
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
|
341
|
-
{ MODEL_TINY, 9ull*MB },
|
|
342
|
-
{ MODEL_BASE, 18ull*MB },
|
|
343
|
-
{ MODEL_SMALL, 53ull*MB },
|
|
344
|
-
{ MODEL_MEDIUM, 141ull*MB },
|
|
345
|
-
{ MODEL_LARGE, 235ull*MB },
|
|
346
|
-
};
|
|
347
|
-
|
|
348
|
-
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
349
|
-
{ MODEL_TINY, 30ull*MB },
|
|
350
|
-
{ MODEL_BASE, 38ull*MB },
|
|
351
|
-
{ MODEL_SMALL, 56ull*MB },
|
|
352
|
-
{ MODEL_MEDIUM, 74ull*MB },
|
|
353
|
-
{ MODEL_LARGE, 94ull*MB },
|
|
354
|
-
};
|
|
355
|
-
|
|
356
|
-
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
357
|
-
{ MODEL_TINY, 3ull*MB },
|
|
358
|
-
{ MODEL_BASE, 5ull*MB },
|
|
359
|
-
{ MODEL_SMALL, 10ull*MB },
|
|
360
|
-
{ MODEL_MEDIUM, 18ull*MB },
|
|
361
|
-
{ MODEL_LARGE, 27ull*MB },
|
|
362
|
-
};
|
|
363
|
-
|
|
364
366
|
struct whisper_mel {
|
|
365
367
|
int n_len;
|
|
366
368
|
int n_len_org;
|
|
@@ -441,6 +443,7 @@ struct whisper_hparams {
|
|
|
441
443
|
int32_t n_text_layer = 4;
|
|
442
444
|
int32_t n_mels = 80;
|
|
443
445
|
int32_t ftype = 1;
|
|
446
|
+
float eps = 1e-5f;
|
|
444
447
|
};
|
|
445
448
|
|
|
446
449
|
// audio encoding layer
|
|
@@ -536,6 +539,7 @@ struct whisper_kv_cache {
|
|
|
536
539
|
|
|
537
540
|
struct wsp_ggml_context * ctx;
|
|
538
541
|
|
|
542
|
+
// buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
|
|
539
543
|
std::vector<uint8_t> buf;
|
|
540
544
|
|
|
541
545
|
int n; // number of tokens currently in the cache
|
|
@@ -601,7 +605,7 @@ struct whisper_sequence {
|
|
|
601
605
|
|
|
602
606
|
// TAGS: WHISPER_DECODER_INIT
|
|
603
607
|
struct whisper_decoder {
|
|
604
|
-
// each
|
|
608
|
+
// each decoder keeps its own KV-cache
|
|
605
609
|
whisper_kv_cache kv_self;
|
|
606
610
|
|
|
607
611
|
// the currently generated sequence of tokens
|
|
@@ -621,15 +625,75 @@ struct whisper_decoder {
|
|
|
621
625
|
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
|
622
626
|
};
|
|
623
627
|
|
|
628
|
+
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
629
|
+
template<typename A, typename B>
|
|
630
|
+
struct whisper_pair {
|
|
631
|
+
A first;
|
|
632
|
+
B second;
|
|
633
|
+
|
|
634
|
+
// Define a constructor that takes two arguments.
|
|
635
|
+
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
636
|
+
// Define a constructor that takes no argument.
|
|
637
|
+
whisper_pair() : first(A()), second(B()) {}
|
|
638
|
+
};
|
|
639
|
+
|
|
640
|
+
// beam-search helpers
|
|
641
|
+
struct kv_buf {
|
|
642
|
+
std::vector<uint8_t> k;
|
|
643
|
+
std::vector<uint8_t> v;
|
|
644
|
+
};
|
|
645
|
+
|
|
646
|
+
// wsp_ggml_allocr wrapper for whisper usage
|
|
647
|
+
struct whisper_allocr {
|
|
648
|
+
wsp_ggml_allocr * alloc = nullptr;
|
|
649
|
+
|
|
650
|
+
std::vector<uint8_t> meta;
|
|
651
|
+
std::vector<uint8_t> data;
|
|
652
|
+
};
|
|
653
|
+
|
|
654
|
+
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
655
|
+
return allocr.meta.size() + allocr.data.size();
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
659
|
+
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
660
|
+
const int tensor_alignment = 32;
|
|
661
|
+
|
|
662
|
+
auto & alloc = allocr.alloc;
|
|
663
|
+
auto & meta = allocr.meta;
|
|
664
|
+
auto & data = allocr.data;
|
|
665
|
+
|
|
666
|
+
meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
|
|
667
|
+
|
|
668
|
+
alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
|
|
669
|
+
|
|
670
|
+
const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
|
|
671
|
+
|
|
672
|
+
wsp_ggml_allocr_free(alloc);
|
|
673
|
+
|
|
674
|
+
data.resize(alloc_size);
|
|
675
|
+
|
|
676
|
+
alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
680
|
+
if (allocr.alloc) {
|
|
681
|
+
wsp_ggml_allocr_free(allocr.alloc);
|
|
682
|
+
allocr.alloc = nullptr;
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
|
|
624
686
|
struct whisper_state {
|
|
625
687
|
int64_t t_sample_us = 0;
|
|
626
688
|
int64_t t_encode_us = 0;
|
|
627
689
|
int64_t t_decode_us = 0;
|
|
690
|
+
int64_t t_prompt_us = 0;
|
|
628
691
|
int64_t t_mel_us = 0;
|
|
629
692
|
|
|
630
693
|
int32_t n_sample = 0; // number of tokens sampled
|
|
631
694
|
int32_t n_encode = 0; // number of encoder calls
|
|
632
|
-
int32_t n_decode = 0; // number of decoder calls
|
|
695
|
+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
|
696
|
+
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
|
633
697
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
634
698
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
635
699
|
|
|
@@ -640,12 +704,23 @@ struct whisper_state {
|
|
|
640
704
|
|
|
641
705
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
642
706
|
|
|
643
|
-
//
|
|
644
|
-
std::vector<
|
|
645
|
-
|
|
707
|
+
// buffer for swapping KV caches between decoders during beam-search
|
|
708
|
+
std::vector<kv_buf> kv_swap_bufs;
|
|
709
|
+
|
|
710
|
+
// reusable buffer for `struct wsp_ggml_graph_plan.work_data`
|
|
711
|
+
std::vector<uint8_t> work_buffer;
|
|
646
712
|
|
|
647
|
-
|
|
648
|
-
|
|
713
|
+
// ggml-alloc:
|
|
714
|
+
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
715
|
+
// - stores the actual tensor data into the `data` buffers
|
|
716
|
+
whisper_allocr alloc_conv;
|
|
717
|
+
whisper_allocr alloc_encode;
|
|
718
|
+
whisper_allocr alloc_cross;
|
|
719
|
+
whisper_allocr alloc_decode;
|
|
720
|
+
|
|
721
|
+
// result of the encoder
|
|
722
|
+
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
723
|
+
struct wsp_ggml_tensor * embd_enc = nullptr;
|
|
649
724
|
|
|
650
725
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
651
726
|
std::vector<float> logits;
|
|
@@ -654,7 +729,7 @@ struct whisper_state {
|
|
|
654
729
|
std::vector<whisper_token> prompt_past;
|
|
655
730
|
|
|
656
731
|
// work container used to avoid memory allocations
|
|
657
|
-
std::vector<
|
|
732
|
+
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
658
733
|
|
|
659
734
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
660
735
|
|
|
@@ -665,6 +740,10 @@ struct whisper_state {
|
|
|
665
740
|
whisper_coreml_context * ctx_coreml = nullptr;
|
|
666
741
|
#endif
|
|
667
742
|
|
|
743
|
+
#ifdef WSP_GGML_USE_METAL
|
|
744
|
+
wsp_ggml_metal_context * ctx_metal = nullptr;
|
|
745
|
+
#endif
|
|
746
|
+
|
|
668
747
|
#ifdef WHISPER_USE_OPENVINO
|
|
669
748
|
whisper_openvino_context * ctx_openvino = nullptr;
|
|
670
749
|
#endif
|
|
@@ -677,37 +756,6 @@ struct whisper_state {
|
|
|
677
756
|
|
|
678
757
|
// [EXPERIMENTAL] speed-up techniques
|
|
679
758
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
|
680
|
-
|
|
681
|
-
void use_buf(struct wsp_ggml_context * ctx, int i) {
|
|
682
|
-
#if defined(WHISPER_USE_SCRATCH)
|
|
683
|
-
size_t last_size = 0;
|
|
684
|
-
|
|
685
|
-
if (i == -1) {
|
|
686
|
-
last_size = wsp_ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
687
|
-
} else {
|
|
688
|
-
auto & buf = buf_scratch[i];
|
|
689
|
-
last_size = wsp_ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
|
690
|
-
}
|
|
691
|
-
|
|
692
|
-
if (buf_last >= 0) {
|
|
693
|
-
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
|
694
|
-
}
|
|
695
|
-
|
|
696
|
-
buf_last = i;
|
|
697
|
-
#else
|
|
698
|
-
(void) i;
|
|
699
|
-
(void) ctx;
|
|
700
|
-
#endif
|
|
701
|
-
}
|
|
702
|
-
|
|
703
|
-
size_t get_buf_max_mem(int i) const {
|
|
704
|
-
#if defined(WHISPER_USE_SCRATCH)
|
|
705
|
-
return buf_max_size[i];
|
|
706
|
-
#else
|
|
707
|
-
(void) i;
|
|
708
|
-
return 0;
|
|
709
|
-
#endif
|
|
710
|
-
}
|
|
711
759
|
};
|
|
712
760
|
|
|
713
761
|
struct whisper_context {
|
|
@@ -730,6 +778,13 @@ static void whisper_default_log(const char * text) {
|
|
|
730
778
|
|
|
731
779
|
static whisper_log_callback whisper_log = whisper_default_log;
|
|
732
780
|
|
|
781
|
+
#ifdef __GNUC__
|
|
782
|
+
#ifdef __MINGW32__
|
|
783
|
+
__attribute__((gnu_format(printf, 1, 2)))
|
|
784
|
+
#else
|
|
785
|
+
__attribute__((format(printf, 1, 2)))
|
|
786
|
+
#endif
|
|
787
|
+
#endif
|
|
733
788
|
static void log(const char * fmt, ...) {
|
|
734
789
|
if (!whisper_log) return;
|
|
735
790
|
char buf[1024];
|
|
@@ -747,10 +802,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
747
802
|
|
|
748
803
|
static bool kv_cache_init(
|
|
749
804
|
const struct whisper_hparams & hparams,
|
|
750
|
-
const size_t mem_bytes,
|
|
751
805
|
struct whisper_kv_cache & cache,
|
|
752
806
|
wsp_ggml_type wtype,
|
|
753
807
|
int n_ctx) {
|
|
808
|
+
const int64_t n_text_state = hparams.n_text_state;
|
|
809
|
+
const int64_t n_text_layer = hparams.n_text_layer;
|
|
810
|
+
|
|
811
|
+
const int64_t n_mem = n_text_layer*n_ctx;
|
|
812
|
+
const int64_t n_elements = n_text_state*n_mem;
|
|
813
|
+
|
|
814
|
+
const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
|
|
815
|
+
|
|
754
816
|
cache.buf.resize(mem_bytes);
|
|
755
817
|
|
|
756
818
|
struct wsp_ggml_init_params params = {
|
|
@@ -766,12 +828,6 @@ static bool kv_cache_init(
|
|
|
766
828
|
return false;
|
|
767
829
|
}
|
|
768
830
|
|
|
769
|
-
const int n_text_state = hparams.n_text_state;
|
|
770
|
-
const int n_text_layer = hparams.n_text_layer;
|
|
771
|
-
|
|
772
|
-
const int n_mem = n_text_layer*n_ctx;
|
|
773
|
-
const int n_elements = n_text_state*n_mem;
|
|
774
|
-
|
|
775
831
|
cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
776
832
|
cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
777
833
|
|
|
@@ -914,22 +970,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
914
970
|
|
|
915
971
|
// print memory requirements
|
|
916
972
|
{
|
|
917
|
-
//
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
MEM_REQ_SCRATCH1.at(model.type) +
|
|
921
|
-
MEM_REQ_SCRATCH2.at(model.type) +
|
|
922
|
-
MEM_REQ_SCRATCH3.at(model.type) +
|
|
923
|
-
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
|
924
|
-
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
|
925
|
-
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
|
926
|
-
|
|
927
|
-
// this is the memory required by one decoder
|
|
928
|
-
const size_t mem_required_decoder =
|
|
929
|
-
scale*MEM_REQ_KV_SELF.at(model.type);
|
|
930
|
-
|
|
931
|
-
log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
932
|
-
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
973
|
+
// TODO
|
|
974
|
+
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
975
|
+
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
933
976
|
}
|
|
934
977
|
|
|
935
978
|
// initialize all memory buffers
|
|
@@ -1438,49 +1481,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1438
1481
|
return true;
|
|
1439
1482
|
}
|
|
1440
1483
|
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
1444
|
-
// part of the transformer model and returns the encoded features
|
|
1445
|
-
//
|
|
1446
|
-
// - wctx: the model
|
|
1447
|
-
// - wstate: the state of the encoder
|
|
1448
|
-
// - n_threads: number of threads to use
|
|
1449
|
-
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
1450
|
-
//
|
|
1451
|
-
static bool whisper_encode_internal(
|
|
1452
|
-
whisper_context & wctx,
|
|
1453
|
-
whisper_state & wstate,
|
|
1454
|
-
const int mel_offset,
|
|
1455
|
-
const int n_threads){
|
|
1484
|
+
static bool whisper_encode_external(const whisper_state & wstate) {
|
|
1485
|
+
WSP_GGML_UNUSED(wstate);
|
|
1456
1486
|
|
|
1457
|
-
|
|
1487
|
+
#ifndef WHISPER_USE_COREML
|
|
1488
|
+
const bool use_coreml = false;
|
|
1489
|
+
#else
|
|
1490
|
+
const bool use_coreml = wstate.ctx_coreml != nullptr;
|
|
1491
|
+
#endif
|
|
1492
|
+
|
|
1493
|
+
#ifndef WHISPER_USE_OPENVINO
|
|
1494
|
+
const bool use_openvino = false;
|
|
1495
|
+
#else
|
|
1496
|
+
const bool use_openvino = wstate.ctx_openvino != nullptr;
|
|
1497
|
+
#endif
|
|
1498
|
+
|
|
1499
|
+
return use_coreml || use_openvino;
|
|
1500
|
+
}
|
|
1458
1501
|
|
|
1502
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
1503
|
+
whisper_context & wctx,
|
|
1504
|
+
whisper_state & wstate,
|
|
1505
|
+
const int mel_offset) {
|
|
1459
1506
|
const auto & model = wctx.model;
|
|
1460
1507
|
const auto & mel_inp = wstate.mel;
|
|
1461
1508
|
const auto & hparams = model.hparams;
|
|
1462
1509
|
|
|
1463
1510
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1464
|
-
const int n_state = hparams.n_audio_state;
|
|
1465
|
-
const int n_head = hparams.n_audio_head;
|
|
1466
|
-
const int n_layer = hparams.n_audio_layer;
|
|
1511
|
+
const int n_state = hparams.n_audio_state; WSP_GGML_UNUSED(n_state);
|
|
1467
1512
|
|
|
1468
1513
|
const int n_mels = hparams.n_mels;
|
|
1469
|
-
assert(mel_inp.n_mel == n_mels);
|
|
1470
1514
|
|
|
1471
1515
|
struct wsp_ggml_init_params params = {
|
|
1472
|
-
/*.mem_size =*/ wstate.
|
|
1473
|
-
/*.mem_buffer =*/ wstate.
|
|
1474
|
-
/*.no_alloc =*/
|
|
1516
|
+
/*.mem_size =*/ wstate.alloc_conv.meta.size(),
|
|
1517
|
+
/*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
|
|
1518
|
+
/*.no_alloc =*/ true,
|
|
1475
1519
|
};
|
|
1476
1520
|
|
|
1477
1521
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1478
1522
|
|
|
1479
|
-
|
|
1523
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1524
|
+
|
|
1525
|
+
wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
|
1480
1526
|
|
|
1481
1527
|
struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
|
|
1528
|
+
wsp_ggml_allocr_alloc(alloc, mel);
|
|
1529
|
+
|
|
1482
1530
|
assert(mel->type == WSP_GGML_TYPE_F32);
|
|
1483
|
-
{
|
|
1531
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1532
|
+
assert(mel_inp.n_mel == n_mels);
|
|
1533
|
+
|
|
1484
1534
|
float * dst = (float *) mel->data;
|
|
1485
1535
|
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1486
1536
|
|
|
@@ -1494,25 +1544,11 @@ static bool whisper_encode_internal(
|
|
|
1494
1544
|
}
|
|
1495
1545
|
}
|
|
1496
1546
|
|
|
1497
|
-
struct wsp_ggml_tensor * cur;
|
|
1547
|
+
struct wsp_ggml_tensor * cur = nullptr;
|
|
1498
1548
|
|
|
1499
|
-
|
|
1500
|
-
const bool use_coreml = false;
|
|
1501
|
-
#else
|
|
1502
|
-
const bool use_coreml = wstate.ctx_coreml != nullptr;
|
|
1503
|
-
#endif
|
|
1504
|
-
|
|
1505
|
-
#ifndef WHISPER_USE_OPENVINO
|
|
1506
|
-
const bool use_openvino = false;
|
|
1507
|
-
#else
|
|
1508
|
-
const bool use_openvino = wstate.ctx_openvino != nullptr;
|
|
1509
|
-
#endif
|
|
1510
|
-
|
|
1511
|
-
if (!use_coreml && !use_openvino) {
|
|
1549
|
+
if (!whisper_encode_external(wstate)) {
|
|
1512
1550
|
// convolution + gelu
|
|
1513
1551
|
{
|
|
1514
|
-
wstate.use_buf(ctx0, 1);
|
|
1515
|
-
|
|
1516
1552
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
|
1517
1553
|
cur = wsp_ggml_add(ctx0,
|
|
1518
1554
|
wsp_ggml_repeat(ctx0,
|
|
@@ -1522,8 +1558,6 @@ static bool whisper_encode_internal(
|
|
|
1522
1558
|
|
|
1523
1559
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1524
1560
|
|
|
1525
|
-
wstate.use_buf(ctx0, 0);
|
|
1526
|
-
|
|
1527
1561
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
|
1528
1562
|
cur = wsp_ggml_add(ctx0,
|
|
1529
1563
|
wsp_ggml_repeat(ctx0,
|
|
@@ -1534,373 +1568,433 @@ static bool whisper_encode_internal(
|
|
|
1534
1568
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1535
1569
|
}
|
|
1536
1570
|
|
|
1537
|
-
wstate.
|
|
1571
|
+
wstate.embd_conv = cur;
|
|
1572
|
+
} else {
|
|
1573
|
+
#ifdef WHISPER_USE_COREML
|
|
1574
|
+
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1575
|
+
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1538
1576
|
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1577
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1578
|
+
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
|
1579
|
+
}
|
|
1580
|
+
#endif
|
|
1581
|
+
#ifdef WHISPER_USE_OPENVINO
|
|
1582
|
+
cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
1583
|
+
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1543
1584
|
|
|
1544
|
-
|
|
1585
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1586
|
+
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
|
1587
|
+
}
|
|
1588
|
+
#endif
|
|
1545
1589
|
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
// memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
|
|
1549
|
-
//}
|
|
1590
|
+
wstate.embd_enc = cur;
|
|
1591
|
+
}
|
|
1550
1592
|
|
|
1551
|
-
|
|
1593
|
+
wsp_ggml_build_forward_expand(gf, cur);
|
|
1552
1594
|
|
|
1553
|
-
|
|
1554
|
-
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1595
|
+
wsp_ggml_free(ctx0);
|
|
1555
1596
|
|
|
1556
|
-
|
|
1597
|
+
return gf;
|
|
1598
|
+
}
|
|
1557
1599
|
|
|
1558
|
-
|
|
1600
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
1601
|
+
whisper_context & wctx,
|
|
1602
|
+
whisper_state & wstate) {
|
|
1603
|
+
const auto & model = wctx.model;
|
|
1604
|
+
const auto & hparams = model.hparams;
|
|
1559
1605
|
|
|
1560
|
-
|
|
1606
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1607
|
+
const int n_state = hparams.n_audio_state;
|
|
1608
|
+
const int n_head = hparams.n_audio_head;
|
|
1609
|
+
const int n_layer = hparams.n_audio_layer;
|
|
1561
1610
|
|
|
1562
|
-
|
|
1563
|
-
|
|
1611
|
+
struct wsp_ggml_init_params params = {
|
|
1612
|
+
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
|
1613
|
+
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
|
1614
|
+
/*.no_alloc =*/ true,
|
|
1615
|
+
};
|
|
1564
1616
|
|
|
1565
|
-
|
|
1617
|
+
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1566
1618
|
|
|
1567
|
-
|
|
1568
|
-
const auto & layer = model.layers_encoder[il];
|
|
1619
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1569
1620
|
|
|
1570
|
-
|
|
1571
|
-
{
|
|
1572
|
-
wstate.use_buf(ctx0, 0);
|
|
1621
|
+
wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1573
1622
|
|
|
1574
|
-
|
|
1623
|
+
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1624
|
+
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
1575
1625
|
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
|
1580
|
-
cur),
|
|
1581
|
-
wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
|
1582
|
-
}
|
|
1626
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1627
|
+
wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
|
|
1628
|
+
}
|
|
1583
1629
|
|
|
1584
|
-
|
|
1585
|
-
{
|
|
1586
|
-
wstate.use_buf(ctx0, 1);
|
|
1630
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1587
1631
|
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1632
|
+
// ===================================================================
|
|
1633
|
+
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1634
|
+
//static int iter = -1;
|
|
1635
|
+
//const int n_iter = 1500/n_ctx;
|
|
1591
1636
|
|
|
1592
|
-
|
|
1593
|
-
wsp_ggml_repeat(ctx0,
|
|
1594
|
-
layer.attn_q_b,
|
|
1595
|
-
Qcur),
|
|
1596
|
-
Qcur);
|
|
1637
|
+
//iter = (iter + 1) % n_iter;
|
|
1597
1638
|
|
|
1598
|
-
|
|
1639
|
+
//if (iter == 0) {
|
|
1640
|
+
// memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
|
|
1641
|
+
// memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
|
|
1642
|
+
//}
|
|
1599
1643
|
|
|
1600
|
-
|
|
1601
|
-
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1602
|
-
layer.attn_k_w,
|
|
1603
|
-
cur);
|
|
1644
|
+
static int iter = 0;
|
|
1604
1645
|
|
|
1605
|
-
|
|
1646
|
+
const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
|
|
1647
|
+
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1606
1648
|
|
|
1607
|
-
|
|
1608
|
-
layer.attn_v_w,
|
|
1609
|
-
cur);
|
|
1649
|
+
struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
1610
1650
|
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
|
|
1651
|
+
cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
|
|
1652
|
+
|
|
1653
|
+
// ===================================================================
|
|
1654
|
+
|
|
1655
|
+
// original:
|
|
1656
|
+
//cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
|
|
1657
|
+
|
|
1658
|
+
struct wsp_ggml_tensor * inpL = cur;
|
|
1616
1659
|
|
|
1617
|
-
|
|
1660
|
+
for (int il = 0; il < n_layer; ++il) {
|
|
1661
|
+
const auto & layer = model.layers_encoder[il];
|
|
1662
|
+
|
|
1663
|
+
// norm
|
|
1664
|
+
{
|
|
1665
|
+
cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
|
|
1666
|
+
|
|
1667
|
+
// cur = ln_0_w*cur + ln_0_b
|
|
1668
|
+
cur = wsp_ggml_add(ctx0,
|
|
1669
|
+
wsp_ggml_mul(ctx0, cur, layer.attn_ln_0_w),
|
|
1670
|
+
layer.attn_ln_0_b);
|
|
1671
|
+
}
|
|
1672
|
+
|
|
1673
|
+
// self-attention
|
|
1674
|
+
{
|
|
1675
|
+
struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
|
|
1676
|
+
layer.attn_q_w,
|
|
1677
|
+
cur);
|
|
1678
|
+
|
|
1679
|
+
Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
|
|
1680
|
+
|
|
1681
|
+
//Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1682
|
+
|
|
1683
|
+
// note: no bias for Key
|
|
1684
|
+
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
1685
|
+
layer.attn_k_w,
|
|
1686
|
+
cur);
|
|
1687
|
+
|
|
1688
|
+
//Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1689
|
+
|
|
1690
|
+
struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
|
|
1691
|
+
layer.attn_v_w,
|
|
1692
|
+
cur);
|
|
1693
|
+
|
|
1694
|
+
Vcur = wsp_ggml_add(ctx0, Vcur, layer.attn_v_b);
|
|
1618
1695
|
|
|
1619
|
-
|
|
1696
|
+
// ------
|
|
1620
1697
|
|
|
1621
1698
|
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1626
|
-
|
|
1627
|
-
|
|
1628
|
-
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
|
|
1632
|
-
|
|
1633
|
-
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
|
|
1699
|
+
struct wsp_ggml_tensor * Q =
|
|
1700
|
+
wsp_ggml_permute(ctx0,
|
|
1701
|
+
wsp_ggml_cpy(ctx0,
|
|
1702
|
+
Qcur,
|
|
1703
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1704
|
+
0, 2, 1, 3);
|
|
1705
|
+
|
|
1706
|
+
struct wsp_ggml_tensor * K =
|
|
1707
|
+
wsp_ggml_permute(ctx0,
|
|
1708
|
+
wsp_ggml_cpy(ctx0,
|
|
1709
|
+
Kcur,
|
|
1710
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1711
|
+
0, 2, 1, 3);
|
|
1712
|
+
|
|
1713
|
+
struct wsp_ggml_tensor * V =
|
|
1714
|
+
wsp_ggml_cpy(ctx0,
|
|
1715
|
+
wsp_ggml_permute(ctx0,
|
|
1716
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
1717
|
+
Vcur,
|
|
1718
|
+
n_state/n_head, n_head, n_ctx),
|
|
1719
|
+
1, 2, 0, 3),
|
|
1720
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
1721
|
+
|
|
1722
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
|
|
1646
1723
|
#else
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
|
|
1656
|
-
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
|
|
1663
|
-
|
|
1664
|
-
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
|
|
1670
|
-
|
|
1671
|
-
|
|
1672
|
-
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
|
1680
|
-
);
|
|
1681
|
-
|
|
1682
|
-
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1724
|
+
struct wsp_ggml_tensor * Q =
|
|
1725
|
+
wsp_ggml_permute(ctx0,
|
|
1726
|
+
wsp_ggml_cpy(ctx0,
|
|
1727
|
+
Qcur,
|
|
1728
|
+
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1729
|
+
0, 2, 1, 3);
|
|
1730
|
+
|
|
1731
|
+
struct wsp_ggml_tensor * K =
|
|
1732
|
+
wsp_ggml_permute(ctx0,
|
|
1733
|
+
wsp_ggml_cpy(ctx0,
|
|
1734
|
+
Kcur,
|
|
1735
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1736
|
+
0, 2, 1, 3);
|
|
1737
|
+
|
|
1738
|
+
// K * Q
|
|
1739
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
1740
|
+
|
|
1741
|
+
struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
|
|
1742
|
+
|
|
1743
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
|
|
1744
|
+
|
|
1745
|
+
struct wsp_ggml_tensor * V =
|
|
1746
|
+
wsp_ggml_cpy(ctx0,
|
|
1747
|
+
wsp_ggml_permute(ctx0,
|
|
1748
|
+
wsp_ggml_reshape_3d(ctx0,
|
|
1749
|
+
Vcur,
|
|
1750
|
+
n_state/n_head, n_head, n_ctx),
|
|
1751
|
+
1, 2, 0, 3),
|
|
1752
|
+
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
|
1753
|
+
);
|
|
1754
|
+
|
|
1755
|
+
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1683
1756
|
#endif
|
|
1684
|
-
|
|
1757
|
+
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1685
1758
|
|
|
1686
|
-
|
|
1759
|
+
cur = wsp_ggml_cpy(ctx0,
|
|
1760
|
+
KQV_merged,
|
|
1761
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
|
|
1762
|
+
}
|
|
1687
1763
|
|
|
1688
|
-
|
|
1689
|
-
|
|
1690
|
-
|
|
1691
|
-
|
|
1764
|
+
// projection
|
|
1765
|
+
{
|
|
1766
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1767
|
+
layer.attn_ln_1_w,
|
|
1768
|
+
cur);
|
|
1692
1769
|
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
wstate.use_buf(ctx0, 0);
|
|
1770
|
+
cur = wsp_ggml_add(ctx0, cur, layer.attn_ln_1_b);
|
|
1771
|
+
}
|
|
1696
1772
|
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
cur);
|
|
1773
|
+
// add the input
|
|
1774
|
+
cur = wsp_ggml_add(ctx0, cur, inpL);
|
|
1700
1775
|
|
|
1701
|
-
|
|
1776
|
+
struct wsp_ggml_tensor * inpFF = cur;
|
|
1777
|
+
|
|
1778
|
+
// feed-forward network
|
|
1779
|
+
{
|
|
1780
|
+
// norm
|
|
1781
|
+
{
|
|
1782
|
+
cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
|
|
1702
1783
|
|
|
1784
|
+
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
1703
1785
|
cur = wsp_ggml_add(ctx0,
|
|
1704
|
-
|
|
1705
|
-
|
|
1786
|
+
wsp_ggml_mul(ctx0, cur, layer.mlp_ln_w),
|
|
1787
|
+
layer.mlp_ln_b);
|
|
1706
1788
|
}
|
|
1707
1789
|
|
|
1708
|
-
|
|
1790
|
+
#ifdef WHISPER_USE_FLASH_FF
|
|
1791
|
+
cur = wsp_ggml_flash_ff(ctx0,
|
|
1792
|
+
wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
1793
|
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1794
|
+
#else
|
|
1795
|
+
// fully connected
|
|
1796
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1797
|
+
layer.mlp_0_w,
|
|
1798
|
+
cur);
|
|
1709
1799
|
|
|
1710
|
-
|
|
1711
|
-
cur = wsp_ggml_add(ctx0, cur, inpL);
|
|
1800
|
+
cur = wsp_ggml_add(ctx0, cur, layer.mlp_0_b);
|
|
1712
1801
|
|
|
1713
|
-
|
|
1802
|
+
// GELU activation
|
|
1803
|
+
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1714
1804
|
|
|
1715
|
-
//
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
wstate.use_buf(ctx0, 0);
|
|
1805
|
+
// projection
|
|
1806
|
+
cur = wsp_ggml_mul_mat(ctx0,
|
|
1807
|
+
layer.mlp_1_w,
|
|
1808
|
+
cur);
|
|
1720
1809
|
|
|
1721
|
-
|
|
1810
|
+
cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
|
|
1811
|
+
#endif
|
|
1812
|
+
}
|
|
1722
1813
|
|
|
1723
|
-
|
|
1814
|
+
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
1815
|
+
}
|
|
1724
1816
|
|
|
1725
|
-
|
|
1726
|
-
cur = wsp_ggml_add(ctx0,
|
|
1727
|
-
wsp_ggml_mul(ctx0,
|
|
1728
|
-
wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
|
1729
|
-
cur),
|
|
1730
|
-
wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
|
1731
|
-
}
|
|
1817
|
+
cur = inpL;
|
|
1732
1818
|
|
|
1733
|
-
|
|
1734
|
-
|
|
1819
|
+
// norm
|
|
1820
|
+
{
|
|
1821
|
+
cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
|
|
1735
1822
|
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1739
|
-
|
|
1740
|
-
|
|
1823
|
+
// cur = ln_f_g*cur + ln_f_b
|
|
1824
|
+
cur = wsp_ggml_add(ctx0,
|
|
1825
|
+
wsp_ggml_mul(ctx0, cur, model.e_ln_w),
|
|
1826
|
+
model.e_ln_b);
|
|
1827
|
+
}
|
|
1741
1828
|
|
|
1742
|
-
|
|
1743
|
-
cur = wsp_ggml_mul_mat(ctx0,
|
|
1744
|
-
layer.mlp_0_w,
|
|
1745
|
-
cur);
|
|
1829
|
+
wsp_ggml_build_forward_expand(gf, cur);
|
|
1746
1830
|
|
|
1747
|
-
|
|
1831
|
+
wstate.embd_enc = cur;
|
|
1748
1832
|
|
|
1749
|
-
|
|
1750
|
-
wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
1751
|
-
cur);
|
|
1833
|
+
//wsp_ggml_graph_print(gf);
|
|
1752
1834
|
|
|
1753
|
-
|
|
1835
|
+
////////////////////////////////////////////////////////////////////////////
|
|
1754
1836
|
|
|
1755
|
-
|
|
1756
|
-
|
|
1837
|
+
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
1838
|
+
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
1839
|
+
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
1840
|
+
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
1841
|
+
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
1842
|
+
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
1757
1843
|
|
|
1758
|
-
|
|
1844
|
+
wsp_ggml_free(ctx0);
|
|
1759
1845
|
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
layer.mlp_1_w,
|
|
1763
|
-
cur);
|
|
1846
|
+
return gf;
|
|
1847
|
+
}
|
|
1764
1848
|
|
|
1765
|
-
|
|
1849
|
+
// pre-compute cross-attention memory
|
|
1850
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
1851
|
+
whisper_context & wctx,
|
|
1852
|
+
whisper_state & wstate) {
|
|
1853
|
+
const auto & model = wctx.model;
|
|
1854
|
+
const auto & hparams = model.hparams;
|
|
1766
1855
|
|
|
1767
|
-
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
#endif
|
|
1771
|
-
}
|
|
1856
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1857
|
+
const int n_state = hparams.n_audio_state;
|
|
1858
|
+
const int n_head = hparams.n_audio_head;
|
|
1772
1859
|
|
|
1773
|
-
|
|
1860
|
+
struct wsp_ggml_init_params params = {
|
|
1861
|
+
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
|
1862
|
+
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
|
1863
|
+
/*.no_alloc =*/ true,
|
|
1864
|
+
};
|
|
1774
1865
|
|
|
1775
|
-
|
|
1776
|
-
}
|
|
1866
|
+
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1777
1867
|
|
|
1778
|
-
|
|
1868
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
1779
1869
|
|
|
1780
|
-
|
|
1781
|
-
{
|
|
1782
|
-
wstate.use_buf(ctx0, 0);
|
|
1870
|
+
wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
1783
1871
|
|
|
1784
|
-
|
|
1872
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
1785
1873
|
|
|
1786
|
-
|
|
1874
|
+
struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1875
|
+
wsp_ggml_allocr_alloc(alloc, Kscale);
|
|
1787
1876
|
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
wsp_ggml_repeat(ctx0, model.e_ln_w, cur),
|
|
1792
|
-
cur),
|
|
1793
|
-
wsp_ggml_repeat(ctx0, model.e_ln_b, cur));
|
|
1794
|
-
}
|
|
1877
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1878
|
+
wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
|
|
1879
|
+
}
|
|
1795
1880
|
|
|
1796
|
-
|
|
1881
|
+
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
1882
|
+
auto & layer = model.layers_decoder[il];
|
|
1797
1883
|
|
|
1798
|
-
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
gf.n_threads = n_threads;
|
|
1884
|
+
struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
|
|
1885
|
+
layer.cross_attn_k_w,
|
|
1886
|
+
cur);
|
|
1802
1887
|
|
|
1803
|
-
|
|
1804
|
-
wsp_ggml_graph_compute(ctx0, &gf);
|
|
1888
|
+
Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
|
|
1805
1889
|
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
#ifdef WHISPER_USE_COREML
|
|
1810
|
-
else if (use_coreml) {
|
|
1811
|
-
wstate.use_buf(ctx0, -1);
|
|
1890
|
+
struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
|
|
1891
|
+
layer.cross_attn_v_w,
|
|
1892
|
+
cur);
|
|
1812
1893
|
|
|
1813
|
-
|
|
1894
|
+
Vcross = wsp_ggml_add(ctx0,
|
|
1895
|
+
Vcross,
|
|
1896
|
+
layer.cross_attn_v_b);
|
|
1814
1897
|
|
|
1815
|
-
|
|
1816
|
-
}
|
|
1817
|
-
#endif
|
|
1818
|
-
#ifdef WHISPER_USE_OPENVINO
|
|
1819
|
-
else if (use_openvino) {
|
|
1820
|
-
wstate.use_buf(ctx0, -1);
|
|
1898
|
+
Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
1821
1899
|
|
|
1822
|
-
|
|
1900
|
+
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
|
|
1901
|
+
n_state*n_ctx,
|
|
1902
|
+
(wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
1823
1903
|
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1904
|
+
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
1905
|
+
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
1906
|
+
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
1907
|
+
|
|
1908
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
|
|
1909
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
|
|
1827
1910
|
}
|
|
1828
|
-
#endif
|
|
1829
1911
|
|
|
1830
|
-
//
|
|
1831
|
-
//{
|
|
1832
|
-
// printf("ne0 = %d\n", cur->ne[0]);
|
|
1833
|
-
// printf("ne1 = %d\n", cur->ne[1]);
|
|
1834
|
-
// for (int i = 0; i < 10; ++i) {
|
|
1835
|
-
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
1836
|
-
// }
|
|
1837
|
-
// printf("... ");
|
|
1838
|
-
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
|
1839
|
-
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
1840
|
-
// }
|
|
1841
|
-
// printf("\n");
|
|
1842
|
-
//}
|
|
1912
|
+
//wsp_ggml_graph_print(gf);
|
|
1843
1913
|
|
|
1844
|
-
|
|
1845
|
-
{
|
|
1846
|
-
struct wsp_ggml_cgraph gf = {};
|
|
1847
|
-
gf.n_threads = n_threads;
|
|
1914
|
+
wsp_ggml_free(ctx0);
|
|
1848
1915
|
|
|
1849
|
-
|
|
1850
|
-
|
|
1851
|
-
cur->src0 = nullptr;
|
|
1852
|
-
cur->src1 = nullptr;
|
|
1916
|
+
return gf;
|
|
1917
|
+
}
|
|
1853
1918
|
|
|
1854
|
-
|
|
1855
|
-
|
|
1919
|
+
// evaluate the encoder with the given state
|
|
1920
|
+
//
|
|
1921
|
+
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
1922
|
+
// part of the transformer model and returns the encoded features
|
|
1923
|
+
//
|
|
1924
|
+
// - wctx: the model
|
|
1925
|
+
// - wstate: the state of the encoder
|
|
1926
|
+
// - n_threads: number of threads to use
|
|
1927
|
+
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
1928
|
+
//
|
|
1929
|
+
static bool whisper_encode_internal(
|
|
1930
|
+
whisper_context & wctx,
|
|
1931
|
+
whisper_state & wstate,
|
|
1932
|
+
const int mel_offset,
|
|
1933
|
+
const int n_threads,
|
|
1934
|
+
whisper_abort_callback abort_callback,
|
|
1935
|
+
void * abort_callback_data) {
|
|
1936
|
+
const int64_t t_start_us = wsp_ggml_time_us();
|
|
1856
1937
|
|
|
1857
|
-
|
|
1938
|
+
// conv
|
|
1939
|
+
{
|
|
1940
|
+
auto & alloc = wstate.alloc_conv.alloc;
|
|
1858
1941
|
|
|
1859
|
-
|
|
1860
|
-
layer.cross_attn_k_w,
|
|
1861
|
-
cur);
|
|
1942
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1862
1943
|
|
|
1863
|
-
|
|
1944
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
|
|
1864
1945
|
|
|
1865
|
-
|
|
1946
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1866
1947
|
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
|
|
1948
|
+
if (!whisper_encode_external(wstate)) {
|
|
1949
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1950
|
+
}
|
|
1951
|
+
}
|
|
1870
1952
|
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
Vcross),
|
|
1875
|
-
Vcross);
|
|
1953
|
+
// encoder
|
|
1954
|
+
if (!whisper_encode_external(wstate)) {
|
|
1955
|
+
auto & alloc = wstate.alloc_encode.alloc;
|
|
1876
1956
|
|
|
1877
|
-
|
|
1957
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1878
1958
|
|
|
1879
|
-
|
|
1959
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
|
1880
1960
|
|
|
1881
|
-
|
|
1882
|
-
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
1883
|
-
( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
1884
|
-
(il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
1961
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1885
1962
|
|
|
1886
|
-
|
|
1887
|
-
|
|
1963
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1964
|
+
if (wstate.ctx_metal) {
|
|
1965
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1966
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1967
|
+
} else {
|
|
1968
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1888
1969
|
}
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
1970
|
+
#else
|
|
1971
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1972
|
+
#endif
|
|
1892
1973
|
}
|
|
1893
1974
|
|
|
1894
|
-
|
|
1975
|
+
// cross
|
|
1976
|
+
{
|
|
1977
|
+
auto & alloc = wstate.alloc_cross.alloc;
|
|
1895
1978
|
|
|
1896
|
-
|
|
1897
|
-
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
1898
|
-
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
1899
|
-
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
1900
|
-
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
1901
|
-
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
1979
|
+
wsp_ggml_allocr_reset(alloc);
|
|
1902
1980
|
|
|
1903
|
-
|
|
1981
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
|
1982
|
+
|
|
1983
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1984
|
+
|
|
1985
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1986
|
+
if (wstate.ctx_metal) {
|
|
1987
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1988
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1989
|
+
} else {
|
|
1990
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1991
|
+
}
|
|
1992
|
+
#else
|
|
1993
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1994
|
+
#endif
|
|
1995
|
+
}
|
|
1996
|
+
|
|
1997
|
+
// wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
1904
1998
|
|
|
1905
1999
|
wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
|
|
1906
2000
|
wstate.n_encode++;
|
|
@@ -1908,26 +2002,13 @@ static bool whisper_encode_internal(
|
|
|
1908
2002
|
return true;
|
|
1909
2003
|
}
|
|
1910
2004
|
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
// - n_tokens: number of tokens in the prompt
|
|
1919
|
-
// - n_past: number of past tokens to prefix the prompt with
|
|
1920
|
-
//
|
|
1921
|
-
static bool whisper_decode_internal(
|
|
1922
|
-
whisper_context & wctx,
|
|
1923
|
-
whisper_state & wstate,
|
|
1924
|
-
whisper_decoder & decoder,
|
|
1925
|
-
const whisper_token * tokens,
|
|
1926
|
-
const int n_tokens,
|
|
1927
|
-
const int n_past,
|
|
1928
|
-
const int n_threads) {
|
|
1929
|
-
const int64_t t_start_us = wsp_ggml_time_us();
|
|
1930
|
-
|
|
2005
|
+
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2006
|
+
whisper_context & wctx,
|
|
2007
|
+
whisper_state & wstate,
|
|
2008
|
+
whisper_decoder & decoder,
|
|
2009
|
+
const whisper_token * tokens,
|
|
2010
|
+
int n_tokens,
|
|
2011
|
+
int n_past) {
|
|
1931
2012
|
const auto & model = wctx.model;
|
|
1932
2013
|
const auto & hparams = model.hparams;
|
|
1933
2014
|
|
|
@@ -1935,10 +2016,6 @@ static bool whisper_decode_internal(
|
|
|
1935
2016
|
|
|
1936
2017
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
1937
2018
|
|
|
1938
|
-
auto & logits_out = wstate.logits;
|
|
1939
|
-
|
|
1940
|
-
const int n_vocab = hparams.n_vocab;
|
|
1941
|
-
|
|
1942
2019
|
const int n_ctx = hparams.n_text_ctx;
|
|
1943
2020
|
const int n_state = hparams.n_text_state;
|
|
1944
2021
|
const int n_head = hparams.n_text_head;
|
|
@@ -1950,25 +2027,39 @@ static bool whisper_decode_internal(
|
|
|
1950
2027
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
1951
2028
|
|
|
1952
2029
|
struct wsp_ggml_init_params params = {
|
|
1953
|
-
/*.mem_size =*/ wstate.
|
|
1954
|
-
/*.mem_buffer =*/ wstate.
|
|
1955
|
-
/*.no_alloc =*/
|
|
2030
|
+
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
|
2031
|
+
/*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
|
|
2032
|
+
/*.no_alloc =*/ true,
|
|
1956
2033
|
};
|
|
1957
2034
|
|
|
1958
2035
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1959
2036
|
|
|
1960
|
-
|
|
1961
|
-
|
|
2037
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
2038
|
+
|
|
2039
|
+
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
1962
2040
|
|
|
1963
2041
|
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
1964
|
-
|
|
2042
|
+
wsp_ggml_allocr_alloc(alloc, embd);
|
|
2043
|
+
|
|
2044
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2045
|
+
memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
|
|
2046
|
+
}
|
|
1965
2047
|
|
|
1966
2048
|
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
1967
|
-
|
|
1968
|
-
|
|
2049
|
+
wsp_ggml_allocr_alloc(alloc, position);
|
|
2050
|
+
|
|
2051
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2052
|
+
for (int i = 0; i < N; ++i) {
|
|
2053
|
+
((int32_t *) position->data)[i] = n_past + i;
|
|
2054
|
+
}
|
|
1969
2055
|
}
|
|
1970
2056
|
|
|
1971
|
-
|
|
2057
|
+
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
2058
|
+
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
2059
|
+
|
|
2060
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2061
|
+
wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
|
|
2062
|
+
}
|
|
1972
2063
|
|
|
1973
2064
|
// token encoding + position encoding
|
|
1974
2065
|
struct wsp_ggml_tensor * cur =
|
|
@@ -1983,16 +2074,14 @@ static bool whisper_decode_internal(
|
|
|
1983
2074
|
|
|
1984
2075
|
// norm
|
|
1985
2076
|
{
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
cur = wsp_ggml_norm(ctx0, inpL);
|
|
2077
|
+
cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
|
|
1989
2078
|
|
|
1990
2079
|
// cur = ln_0_w*cur + ln_0_b
|
|
1991
2080
|
cur = wsp_ggml_add(ctx0,
|
|
1992
2081
|
wsp_ggml_mul(ctx0,
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
2082
|
+
cur,
|
|
2083
|
+
layer.attn_ln_0_w),
|
|
2084
|
+
layer.attn_ln_0_b);
|
|
1996
2085
|
}
|
|
1997
2086
|
|
|
1998
2087
|
// self-attention
|
|
@@ -2002,19 +2091,17 @@ static bool whisper_decode_internal(
|
|
|
2002
2091
|
cur);
|
|
2003
2092
|
|
|
2004
2093
|
Qcur = wsp_ggml_add(ctx0,
|
|
2005
|
-
|
|
2006
|
-
layer.attn_q_b
|
|
2007
|
-
Qcur),
|
|
2008
|
-
Qcur);
|
|
2094
|
+
Qcur,
|
|
2095
|
+
layer.attn_q_b);
|
|
2009
2096
|
|
|
2010
|
-
Qcur =
|
|
2097
|
+
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2011
2098
|
|
|
2012
2099
|
// note: no bias for Key
|
|
2013
2100
|
struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
|
|
2014
2101
|
layer.attn_k_w,
|
|
2015
2102
|
cur);
|
|
2016
2103
|
|
|
2017
|
-
Kcur =
|
|
2104
|
+
Kcur = wsp_ggml_scale(ctx0, Kcur, KQscale);
|
|
2018
2105
|
|
|
2019
2106
|
// store key and value to memory
|
|
2020
2107
|
{
|
|
@@ -2023,10 +2110,8 @@ static bool whisper_decode_internal(
|
|
|
2023
2110
|
cur);
|
|
2024
2111
|
|
|
2025
2112
|
Vcur = wsp_ggml_add(ctx0,
|
|
2026
|
-
|
|
2027
|
-
layer.attn_v_b
|
|
2028
|
-
Vcur),
|
|
2029
|
-
Vcur);
|
|
2113
|
+
Vcur,
|
|
2114
|
+
layer.attn_v_b);
|
|
2030
2115
|
|
|
2031
2116
|
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
|
2032
2117
|
|
|
@@ -2035,42 +2120,32 @@ static bool whisper_decode_internal(
|
|
|
2035
2120
|
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2036
2121
|
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
|
|
2037
2122
|
|
|
2038
|
-
wsp_ggml_build_forward_expand(
|
|
2039
|
-
wsp_ggml_build_forward_expand(
|
|
2123
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2124
|
+
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
2040
2125
|
}
|
|
2041
2126
|
|
|
2042
2127
|
// ------
|
|
2043
2128
|
|
|
2044
|
-
wstate.use_buf(ctx0, 0);
|
|
2045
|
-
|
|
2046
2129
|
struct wsp_ggml_tensor * Q =
|
|
2047
2130
|
wsp_ggml_permute(ctx0,
|
|
2048
|
-
|
|
2049
|
-
Qcur,
|
|
2050
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
|
|
2131
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
|
2051
2132
|
0, 2, 1, 3);
|
|
2052
2133
|
|
|
2053
2134
|
struct wsp_ggml_tensor * K =
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
wstate.use_buf(ctx0, 1);
|
|
2135
|
+
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2136
|
+
n_state/n_head, n_past + N, n_head,
|
|
2137
|
+
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2138
|
+
wsp_ggml_element_size(kv_self.k)*n_state/n_head,
|
|
2139
|
+
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
2061
2140
|
|
|
2062
2141
|
// K * Q
|
|
2063
2142
|
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
|
|
2064
2143
|
|
|
2065
|
-
//struct wsp_ggml_tensor * KQ_scaled =
|
|
2066
|
-
// wsp_ggml_scale_inplace(ctx0,
|
|
2067
|
-
// KQ,
|
|
2068
|
-
// wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
2069
|
-
// );
|
|
2144
|
+
//struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
|
|
2070
2145
|
|
|
2071
|
-
struct wsp_ggml_tensor * KQ_masked =
|
|
2146
|
+
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2072
2147
|
|
|
2073
|
-
struct wsp_ggml_tensor * KQ_soft_max =
|
|
2148
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
|
|
2074
2149
|
|
|
2075
2150
|
struct wsp_ggml_tensor * V =
|
|
2076
2151
|
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
@@ -2090,36 +2165,28 @@ static bool whisper_decode_internal(
|
|
|
2090
2165
|
|
|
2091
2166
|
// projection
|
|
2092
2167
|
{
|
|
2093
|
-
wstate.use_buf(ctx0, 0);
|
|
2094
|
-
|
|
2095
2168
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2096
2169
|
layer.attn_ln_1_w,
|
|
2097
2170
|
cur);
|
|
2098
2171
|
|
|
2099
|
-
wstate.use_buf(ctx0, 1);
|
|
2100
|
-
|
|
2101
2172
|
cur = wsp_ggml_add(ctx0,
|
|
2102
|
-
|
|
2103
|
-
|
|
2173
|
+
cur,
|
|
2174
|
+
layer.attn_ln_1_b);
|
|
2104
2175
|
}
|
|
2105
2176
|
|
|
2106
|
-
wstate.use_buf(ctx0, 2);
|
|
2107
|
-
|
|
2108
2177
|
// add the input
|
|
2109
2178
|
struct wsp_ggml_tensor * inpCA = wsp_ggml_add(ctx0, cur, inpL);
|
|
2110
2179
|
|
|
2111
2180
|
// norm
|
|
2112
2181
|
{
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
cur = wsp_ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
|
2182
|
+
cur = wsp_ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
|
|
2116
2183
|
|
|
2117
2184
|
// cur = ln_0_w*cur + ln_0_b
|
|
2118
2185
|
cur = wsp_ggml_add(ctx0,
|
|
2119
2186
|
wsp_ggml_mul(ctx0,
|
|
2120
|
-
|
|
2121
|
-
|
|
2122
|
-
|
|
2187
|
+
cur,
|
|
2188
|
+
layer.cross_attn_ln_0_w),
|
|
2189
|
+
layer.cross_attn_ln_0_b);
|
|
2123
2190
|
}
|
|
2124
2191
|
|
|
2125
2192
|
// cross-attention
|
|
@@ -2129,18 +2196,18 @@ static bool whisper_decode_internal(
|
|
|
2129
2196
|
cur);
|
|
2130
2197
|
|
|
2131
2198
|
Qcur = wsp_ggml_add(ctx0,
|
|
2132
|
-
|
|
2133
|
-
layer.cross_attn_q_b
|
|
2134
|
-
Qcur),
|
|
2135
|
-
Qcur);
|
|
2199
|
+
Qcur,
|
|
2200
|
+
layer.cross_attn_q_b);
|
|
2136
2201
|
|
|
2137
|
-
Qcur =
|
|
2202
|
+
Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
|
|
2138
2203
|
|
|
2139
2204
|
// Kcross is already scaled
|
|
2140
2205
|
struct wsp_ggml_tensor * Kcross =
|
|
2141
|
-
|
|
2142
|
-
|
|
2143
|
-
n_state
|
|
2206
|
+
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2207
|
+
n_state/n_head, M, n_head,
|
|
2208
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2209
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2210
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
|
|
2144
2211
|
|
|
2145
2212
|
//struct wsp_ggml_tensor * Vcross =
|
|
2146
2213
|
// wsp_ggml_reshape_3d(ctx0,
|
|
@@ -2163,26 +2230,22 @@ static bool whisper_decode_internal(
|
|
|
2163
2230
|
|
|
2164
2231
|
struct wsp_ggml_tensor * Q =
|
|
2165
2232
|
wsp_ggml_permute(ctx0,
|
|
2166
|
-
|
|
2167
|
-
Qcur,
|
|
2168
|
-
wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
|
|
2233
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
|
2169
2234
|
0, 2, 1, 3);
|
|
2170
2235
|
|
|
2171
|
-
struct wsp_ggml_tensor * K = wsp_ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
|
2172
|
-
|
|
2173
2236
|
// K * Q
|
|
2174
|
-
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0,
|
|
2237
|
+
struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
|
|
2175
2238
|
|
|
2176
2239
|
//struct wsp_ggml_tensor * KQ_scaled =
|
|
2177
|
-
//
|
|
2240
|
+
// wsp_ggml_scale(ctx0,
|
|
2178
2241
|
// KQ,
|
|
2179
2242
|
// wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
2180
2243
|
// );
|
|
2181
2244
|
|
|
2182
2245
|
// no masking for cross-attention
|
|
2183
|
-
//struct wsp_ggml_tensor * KQ_masked =
|
|
2246
|
+
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2184
2247
|
|
|
2185
|
-
struct wsp_ggml_tensor * KQ_soft_max =
|
|
2248
|
+
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
|
|
2186
2249
|
|
|
2187
2250
|
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2188
2251
|
|
|
@@ -2196,21 +2259,15 @@ static bool whisper_decode_internal(
|
|
|
2196
2259
|
|
|
2197
2260
|
// projection
|
|
2198
2261
|
{
|
|
2199
|
-
wstate.use_buf(ctx0, 0);
|
|
2200
|
-
|
|
2201
2262
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2202
2263
|
layer.cross_attn_ln_1_w,
|
|
2203
2264
|
cur);
|
|
2204
2265
|
|
|
2205
|
-
wstate.use_buf(ctx0, 1);
|
|
2206
|
-
|
|
2207
2266
|
cur = wsp_ggml_add(ctx0,
|
|
2208
|
-
|
|
2209
|
-
|
|
2267
|
+
cur,
|
|
2268
|
+
layer.cross_attn_ln_1_b);
|
|
2210
2269
|
}
|
|
2211
2270
|
|
|
2212
|
-
wstate.use_buf(ctx0, 2);
|
|
2213
|
-
|
|
2214
2271
|
// add the input
|
|
2215
2272
|
cur = wsp_ggml_add(ctx0, cur, inpCA);
|
|
2216
2273
|
|
|
@@ -2220,54 +2277,38 @@ static bool whisper_decode_internal(
|
|
|
2220
2277
|
{
|
|
2221
2278
|
// norm
|
|
2222
2279
|
{
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
cur = wsp_ggml_norm(ctx0, inpFF);
|
|
2226
|
-
|
|
2227
|
-
wstate.use_buf(ctx0, 1);
|
|
2280
|
+
cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
|
|
2228
2281
|
|
|
2229
2282
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
2230
2283
|
cur = wsp_ggml_add(ctx0,
|
|
2231
2284
|
wsp_ggml_mul(ctx0,
|
|
2232
|
-
|
|
2233
|
-
|
|
2234
|
-
|
|
2285
|
+
cur,
|
|
2286
|
+
layer.mlp_ln_w),
|
|
2287
|
+
layer.mlp_ln_b);
|
|
2235
2288
|
}
|
|
2236
2289
|
|
|
2237
|
-
wstate.use_buf(ctx0, 0);
|
|
2238
|
-
|
|
2239
2290
|
// fully connected
|
|
2240
2291
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2241
2292
|
layer.mlp_0_w,
|
|
2242
2293
|
cur);
|
|
2243
2294
|
|
|
2244
|
-
wstate.use_buf(ctx0, 1);
|
|
2245
|
-
|
|
2246
2295
|
cur = wsp_ggml_add(ctx0,
|
|
2247
|
-
|
|
2248
|
-
|
|
2249
|
-
|
|
2250
|
-
wstate.use_buf(ctx0, 0);
|
|
2296
|
+
cur,
|
|
2297
|
+
layer.mlp_0_b);
|
|
2251
2298
|
|
|
2252
2299
|
// GELU activation
|
|
2253
2300
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
2254
2301
|
|
|
2255
|
-
wstate.use_buf(ctx0, 1);
|
|
2256
|
-
|
|
2257
2302
|
// projection
|
|
2258
2303
|
cur = wsp_ggml_mul_mat(ctx0,
|
|
2259
2304
|
layer.mlp_1_w,
|
|
2260
2305
|
cur);
|
|
2261
2306
|
|
|
2262
|
-
wstate.use_buf(ctx0, 0);
|
|
2263
|
-
|
|
2264
2307
|
cur = wsp_ggml_add(ctx0,
|
|
2265
|
-
|
|
2266
|
-
|
|
2308
|
+
cur,
|
|
2309
|
+
layer.mlp_1_b);
|
|
2267
2310
|
}
|
|
2268
2311
|
|
|
2269
|
-
wstate.use_buf(ctx0, 3);
|
|
2270
|
-
|
|
2271
2312
|
inpL = wsp_ggml_add(ctx0, cur, inpFF);
|
|
2272
2313
|
}
|
|
2273
2314
|
|
|
@@ -2275,21 +2316,15 @@ static bool whisper_decode_internal(
|
|
|
2275
2316
|
|
|
2276
2317
|
// norm
|
|
2277
2318
|
{
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
cur = wsp_ggml_norm(ctx0, cur);
|
|
2281
|
-
|
|
2282
|
-
wstate.use_buf(ctx0, 1);
|
|
2319
|
+
cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
|
|
2283
2320
|
|
|
2284
2321
|
cur = wsp_ggml_add(ctx0,
|
|
2285
2322
|
wsp_ggml_mul(ctx0,
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2323
|
+
cur,
|
|
2324
|
+
model.d_ln_w),
|
|
2325
|
+
model.d_ln_b);
|
|
2289
2326
|
}
|
|
2290
2327
|
|
|
2291
|
-
wstate.use_buf(ctx0, 0);
|
|
2292
|
-
|
|
2293
2328
|
// compute logits only for the last token
|
|
2294
2329
|
// comment this line to compute logits for all N tokens
|
|
2295
2330
|
// might be useful in the future
|
|
@@ -2297,23 +2332,77 @@ static bool whisper_decode_internal(
|
|
|
2297
2332
|
|
|
2298
2333
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2299
2334
|
|
|
2300
|
-
|
|
2335
|
+
wsp_ggml_build_forward_expand(gf, logits);
|
|
2336
|
+
|
|
2337
|
+
wsp_ggml_free(ctx0);
|
|
2338
|
+
|
|
2339
|
+
return gf;
|
|
2340
|
+
}
|
|
2341
|
+
|
|
2342
|
+
// evaluate the decoder
|
|
2343
|
+
//
|
|
2344
|
+
// given text prompt + audio features -> computes the logits for the next token
|
|
2345
|
+
//
|
|
2346
|
+
// - model: the model
|
|
2347
|
+
// - n_threads: number of threads to use
|
|
2348
|
+
// - tokens: text prompt
|
|
2349
|
+
// - n_tokens: number of tokens in the prompt
|
|
2350
|
+
// - n_past: number of past tokens to prefix the prompt with
|
|
2351
|
+
//
|
|
2352
|
+
static bool whisper_decode_internal(
|
|
2353
|
+
whisper_context & wctx,
|
|
2354
|
+
whisper_state & wstate,
|
|
2355
|
+
whisper_decoder & decoder,
|
|
2356
|
+
const whisper_token * tokens,
|
|
2357
|
+
const int n_tokens,
|
|
2358
|
+
const int n_past,
|
|
2359
|
+
const int n_threads,
|
|
2360
|
+
whisper_abort_callback abort_callback,
|
|
2361
|
+
void * abort_callback_data) {
|
|
2362
|
+
const int64_t t_start_us = wsp_ggml_time_us();
|
|
2363
|
+
|
|
2364
|
+
const auto & model = wctx.model;
|
|
2365
|
+
const auto & hparams = model.hparams;
|
|
2366
|
+
|
|
2367
|
+
const int n_vocab = hparams.n_vocab;
|
|
2368
|
+
|
|
2369
|
+
auto & logits_out = wstate.logits;
|
|
2370
|
+
|
|
2371
|
+
struct wsp_ggml_tensor * logits;
|
|
2301
2372
|
|
|
2302
|
-
//
|
|
2373
|
+
// decoder
|
|
2303
2374
|
{
|
|
2304
|
-
|
|
2305
|
-
|
|
2375
|
+
auto & alloc = wstate.alloc_decode.alloc;
|
|
2376
|
+
|
|
2377
|
+
wsp_ggml_allocr_reset(alloc);
|
|
2378
|
+
|
|
2379
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
|
|
2380
|
+
|
|
2381
|
+
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
2382
|
+
|
|
2383
|
+
logits = gf->nodes[gf->n_nodes - 1];
|
|
2384
|
+
|
|
2385
|
+
#ifdef WSP_GGML_USE_METAL
|
|
2386
|
+
if (wstate.ctx_metal) {
|
|
2387
|
+
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
2388
|
+
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
2389
|
+
} else {
|
|
2390
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2391
|
+
}
|
|
2392
|
+
#else
|
|
2393
|
+
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2394
|
+
#endif
|
|
2306
2395
|
}
|
|
2307
2396
|
|
|
2308
2397
|
// extract logits for all N tokens
|
|
2309
|
-
//logits_out.resize(
|
|
2310
|
-
//memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*
|
|
2398
|
+
//logits_out.resize(n_tokens*n_vocab);
|
|
2399
|
+
//memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
|
|
2311
2400
|
|
|
2312
2401
|
// extract logits only for the last token
|
|
2313
2402
|
logits_out.resize(n_vocab);
|
|
2314
2403
|
memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
|
|
2315
2404
|
|
|
2316
|
-
if (
|
|
2405
|
+
if (n_tokens > 1) {
|
|
2317
2406
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
2318
2407
|
// wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
2319
2408
|
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
@@ -2322,14 +2411,18 @@ static bool whisper_decode_internal(
|
|
|
2322
2411
|
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
2323
2412
|
}
|
|
2324
2413
|
|
|
2325
|
-
|
|
2326
|
-
|
|
2327
|
-
|
|
2328
|
-
|
|
2414
|
+
if (n_tokens == 1) {
|
|
2415
|
+
wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
|
|
2416
|
+
wstate.n_decode++;
|
|
2417
|
+
} else {
|
|
2418
|
+
wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
|
|
2419
|
+
wstate.n_prompt++;
|
|
2420
|
+
}
|
|
2329
2421
|
|
|
2330
2422
|
return true;
|
|
2331
2423
|
}
|
|
2332
2424
|
|
|
2425
|
+
|
|
2333
2426
|
// 500 -> 00:05.000
|
|
2334
2427
|
// 6000 -> 01:00.000
|
|
2335
2428
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
@@ -2351,7 +2444,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
|
2351
2444
|
static float sin_vals[SIN_COS_N_COUNT];
|
|
2352
2445
|
static float cos_vals[SIN_COS_N_COUNT];
|
|
2353
2446
|
|
|
2354
|
-
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2447
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
|
2355
2448
|
// We can use precalculated values to speed up the process.
|
|
2356
2449
|
static void fill_sin_cos_table() {
|
|
2357
2450
|
static bool is_filled = false;
|
|
@@ -2446,7 +2539,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2446
2539
|
}
|
|
2447
2540
|
|
|
2448
2541
|
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
|
2449
|
-
if (output.size() < length) {
|
|
2542
|
+
if (output.size() < static_cast<size_t>(length)) {
|
|
2450
2543
|
output.resize(length);
|
|
2451
2544
|
}
|
|
2452
2545
|
int offset = -1;
|
|
@@ -2738,9 +2831,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2738
2831
|
fill_sin_cos_table();
|
|
2739
2832
|
whisper_state * state = new whisper_state;
|
|
2740
2833
|
|
|
2741
|
-
|
|
2742
|
-
|
|
2743
|
-
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
2834
|
+
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
2744
2835
|
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
2745
2836
|
delete state;
|
|
2746
2837
|
return nullptr;
|
|
@@ -2751,7 +2842,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2751
2842
|
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
|
2752
2843
|
}
|
|
2753
2844
|
|
|
2754
|
-
if (!kv_cache_init(ctx->model.hparams,
|
|
2845
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
2755
2846
|
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
2756
2847
|
delete state;
|
|
2757
2848
|
return nullptr;
|
|
@@ -2772,6 +2863,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2772
2863
|
if (!state->ctx_coreml) {
|
|
2773
2864
|
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2774
2865
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
2866
|
+
delete state;
|
|
2775
2867
|
return nullptr;
|
|
2776
2868
|
#endif
|
|
2777
2869
|
} else {
|
|
@@ -2786,15 +2878,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2786
2878
|
// TAGS: WHISPER_DECODER_INIT
|
|
2787
2879
|
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
|
2788
2880
|
|
|
2789
|
-
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
|
2790
|
-
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
|
2881
|
+
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
|
2882
|
+
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
|
2791
2883
|
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
|
2792
|
-
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
|
|
2793
2884
|
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
|
|
2885
|
+
// conv allocator
|
|
2886
|
+
{
|
|
2887
|
+
whisper_allocr_graph_init(state->alloc_conv,
|
|
2888
|
+
[&]() {
|
|
2889
|
+
return whisper_build_graph_conv(*ctx, *state, 0);
|
|
2890
|
+
});
|
|
2891
|
+
|
|
2892
|
+
log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
|
|
2893
|
+
}
|
|
2894
|
+
|
|
2895
|
+
// encoder allocator
|
|
2896
|
+
if (!whisper_encode_external(*state)) {
|
|
2897
|
+
whisper_allocr_graph_init(state->alloc_encode,
|
|
2898
|
+
[&]() {
|
|
2899
|
+
return whisper_build_graph_encoder(*ctx, *state);
|
|
2900
|
+
});
|
|
2901
|
+
|
|
2902
|
+
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
|
|
2903
|
+
}
|
|
2904
|
+
|
|
2905
|
+
// cross allocator
|
|
2906
|
+
{
|
|
2907
|
+
whisper_allocr_graph_init(state->alloc_cross,
|
|
2908
|
+
[&]() {
|
|
2909
|
+
return whisper_build_graph_cross(*ctx, *state);
|
|
2910
|
+
});
|
|
2911
|
+
|
|
2912
|
+
log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
|
|
2913
|
+
}
|
|
2914
|
+
|
|
2915
|
+
// decoder allocator
|
|
2916
|
+
{
|
|
2917
|
+
whisper_allocr_graph_init(state->alloc_decode,
|
|
2918
|
+
[&]() {
|
|
2919
|
+
const auto & hparams = ctx->model.hparams;
|
|
2920
|
+
|
|
2921
|
+
// TODO: make sure this is the worst-case scenario
|
|
2922
|
+
const int n_tokens = hparams.n_text_ctx;
|
|
2923
|
+
const int n_past = 0;
|
|
2924
|
+
|
|
2925
|
+
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
|
|
2926
|
+
});
|
|
2927
|
+
|
|
2928
|
+
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
|
2929
|
+
}
|
|
2930
|
+
|
|
2931
|
+
#ifdef WSP_GGML_USE_METAL
|
|
2932
|
+
state->ctx_metal = wsp_ggml_metal_init(1);
|
|
2933
|
+
if (!state->ctx_metal) {
|
|
2934
|
+
log("%s: wsp_ggml_metal_init() failed\n", __func__);
|
|
2935
|
+
delete state;
|
|
2936
|
+
return nullptr;
|
|
2937
|
+
}
|
|
2938
|
+
|
|
2939
|
+
log("%s: Metal context initialized\n", __func__);
|
|
2940
|
+
|
|
2941
|
+
// this allocates all Metal resources and memory buffers
|
|
2942
|
+
|
|
2943
|
+
void * data_ptr = NULL;
|
|
2944
|
+
size_t data_size = 0;
|
|
2945
|
+
|
|
2946
|
+
// TODO: add mmap support
|
|
2947
|
+
//if (params.use_mmap) {
|
|
2948
|
+
// data_ptr = ctx->model.mapping->addr;
|
|
2949
|
+
// data_size = ctx->model.mapping->size;
|
|
2950
|
+
//} else {
|
|
2951
|
+
// data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2952
|
+
// data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2953
|
+
//}
|
|
2954
|
+
|
|
2955
|
+
data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2956
|
+
data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2957
|
+
|
|
2958
|
+
const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
|
|
2959
|
+
|
|
2960
|
+
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
|
2961
|
+
|
|
2962
|
+
#define WHISPER_METAL_CHECK_BUF(result) \
|
|
2963
|
+
if (!(result)) { \
|
|
2964
|
+
log("%s: failed to add metal buffer\n", __func__); \
|
|
2965
|
+
delete state; \
|
|
2966
|
+
return nullptr; \
|
|
2967
|
+
}
|
|
2968
|
+
|
|
2969
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
|
2970
|
+
|
|
2971
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
|
|
2972
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
|
|
2973
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
|
|
2974
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
|
|
2975
|
+
|
|
2976
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
|
|
2977
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
|
|
2978
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
|
|
2979
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
|
|
2980
|
+
|
|
2981
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
|
2982
|
+
|
|
2983
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
|
2984
|
+
#undef WHISPER_METAL_CHECK_BUF
|
|
2985
|
+
#endif
|
|
2798
2986
|
|
|
2799
2987
|
state->rng = std::mt19937(0);
|
|
2800
2988
|
|
|
@@ -2851,7 +3039,6 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
2851
3039
|
}
|
|
2852
3040
|
|
|
2853
3041
|
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
|
2854
|
-
|
|
2855
3042
|
log("%s: loading model from '%s'\n", __func__, path_model);
|
|
2856
3043
|
|
|
2857
3044
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
@@ -3004,6 +3191,13 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3004
3191
|
}
|
|
3005
3192
|
#endif
|
|
3006
3193
|
|
|
3194
|
+
#ifdef WSP_GGML_USE_METAL
|
|
3195
|
+
if (state->ctx_metal) {
|
|
3196
|
+
wsp_ggml_metal_free(state->ctx_metal);
|
|
3197
|
+
state->ctx_metal = nullptr;
|
|
3198
|
+
}
|
|
3199
|
+
#endif
|
|
3200
|
+
|
|
3007
3201
|
#ifdef WHISPER_USE_OPENVINO
|
|
3008
3202
|
if (state->ctx_openvino != nullptr) {
|
|
3009
3203
|
whisper_openvino_free(state->ctx_openvino);
|
|
@@ -3011,6 +3205,11 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3011
3205
|
}
|
|
3012
3206
|
#endif
|
|
3013
3207
|
|
|
3208
|
+
whisper_allocr_free(state->alloc_conv);
|
|
3209
|
+
whisper_allocr_free(state->alloc_decode);
|
|
3210
|
+
whisper_allocr_free(state->alloc_cross);
|
|
3211
|
+
whisper_allocr_free(state->alloc_encode);
|
|
3212
|
+
|
|
3014
3213
|
delete state;
|
|
3015
3214
|
}
|
|
3016
3215
|
}
|
|
@@ -3103,7 +3302,7 @@ int whisper_set_mel(
|
|
|
3103
3302
|
}
|
|
3104
3303
|
|
|
3105
3304
|
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
|
3106
|
-
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
|
3305
|
+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
|
3107
3306
|
log("%s: failed to eval\n", __func__);
|
|
3108
3307
|
return -1;
|
|
3109
3308
|
}
|
|
@@ -3112,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3112
3311
|
}
|
|
3113
3312
|
|
|
3114
3313
|
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
3115
|
-
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
|
3314
|
+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
|
3116
3315
|
log("%s: failed to eval\n", __func__);
|
|
3117
3316
|
return -1;
|
|
3118
3317
|
}
|
|
@@ -3123,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
3123
3322
|
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3124
3323
|
const int selected_decoder_id = 0;
|
|
3125
3324
|
|
|
3126
|
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
3325
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3127
3326
|
log("%s: failed to eval\n", __func__);
|
|
3128
3327
|
return 1;
|
|
3129
3328
|
}
|
|
@@ -3140,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
|
3140
3339
|
return false;
|
|
3141
3340
|
}
|
|
3142
3341
|
|
|
3143
|
-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
3342
|
+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3144
3343
|
log("%s: failed to eval\n", __func__);
|
|
3145
3344
|
return 1;
|
|
3146
3345
|
}
|
|
@@ -3431,12 +3630,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
|
|
|
3431
3630
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3432
3631
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3433
3632
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3633
|
+
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3434
3634
|
|
|
3435
3635
|
log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3436
3636
|
log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3437
3637
|
log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3438
3638
|
log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3439
3639
|
log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3640
|
+
log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
|
3440
3641
|
}
|
|
3441
3642
|
log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
3442
3643
|
}
|
|
@@ -3446,6 +3647,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3446
3647
|
ctx->state->t_sample_us = 0;
|
|
3447
3648
|
ctx->state->t_encode_us = 0;
|
|
3448
3649
|
ctx->state->t_decode_us = 0;
|
|
3650
|
+
ctx->state->t_prompt_us = 0;
|
|
3651
|
+
ctx->state->n_sample = 0;
|
|
3652
|
+
ctx->state->n_encode = 0;
|
|
3653
|
+
ctx->state->n_decode = 0;
|
|
3654
|
+
ctx->state->n_prompt = 0;
|
|
3449
3655
|
}
|
|
3450
3656
|
}
|
|
3451
3657
|
|
|
@@ -3475,6 +3681,7 @@ const char * whisper_print_system_info(void) {
|
|
|
3475
3681
|
s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
|
|
3476
3682
|
s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
|
|
3477
3683
|
s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
|
|
3684
|
+
s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
|
|
3478
3685
|
s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
|
|
3479
3686
|
s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
|
|
3480
3687
|
s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
|
|
@@ -3566,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3566
3773
|
/*.encoder_begin_callback =*/ nullptr,
|
|
3567
3774
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
3568
3775
|
|
|
3776
|
+
/*.abort_callback =*/ nullptr,
|
|
3777
|
+
/*.abort_callback_user_data =*/ nullptr,
|
|
3778
|
+
|
|
3569
3779
|
/*.logits_filter_callback =*/ nullptr,
|
|
3570
3780
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
|
3571
3781
|
};
|
|
@@ -3970,17 +4180,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
3970
4180
|
|
|
3971
4181
|
auto & logits_id = state.logits_id;
|
|
3972
4182
|
|
|
3973
|
-
logits_id.
|
|
4183
|
+
logits_id.resize(n_logits);
|
|
3974
4184
|
for (int i = 0; i < n_logits; ++i) {
|
|
3975
|
-
logits_id.
|
|
4185
|
+
logits_id[i].first = logits[i];
|
|
4186
|
+
logits_id[i].second = i;
|
|
3976
4187
|
}
|
|
3977
4188
|
|
|
3978
|
-
|
|
3979
|
-
|
|
3980
|
-
|
|
3981
|
-
|
|
3982
|
-
|
|
3983
|
-
|
|
4189
|
+
{
|
|
4190
|
+
using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
|
|
4191
|
+
std::partial_sort(
|
|
4192
|
+
logits_id.begin(),
|
|
4193
|
+
logits_id.begin() + k, logits_id.end(),
|
|
4194
|
+
[](const pair_type & a, const pair_type & b) {
|
|
4195
|
+
return a.first > b.first;
|
|
4196
|
+
});
|
|
4197
|
+
}
|
|
3984
4198
|
|
|
3985
4199
|
std::vector<whisper_token_data> result;
|
|
3986
4200
|
result.reserve(k);
|
|
@@ -4075,6 +4289,115 @@ static void whisper_sequence_score(
|
|
|
4075
4289
|
}
|
|
4076
4290
|
}
|
|
4077
4291
|
|
|
4292
|
+
static bool whisper_kv_swap_fast(
|
|
4293
|
+
std::vector<int> & view,
|
|
4294
|
+
whisper_decoder src[],
|
|
4295
|
+
std::vector<kv_buf> & kv_swap_bufs,
|
|
4296
|
+
const int & n_decoders) {
|
|
4297
|
+
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
|
|
4298
|
+
|
|
4299
|
+
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
|
|
4300
|
+
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
|
|
4301
|
+
|
|
4302
|
+
// (buffer->decoder or decoder->decoder)
|
|
4303
|
+
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
|
|
4304
|
+
|
|
4305
|
+
// (decoder<->decoder)
|
|
4306
|
+
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
|
|
4307
|
+
std::vector<whisper_pair<int, int>> p_swap_vec;
|
|
4308
|
+
p_swap_vec.reserve(n_decoders);
|
|
4309
|
+
|
|
4310
|
+
// see https://github.com/ggerganov/whisper.cpp/wiki
|
|
4311
|
+
for (int i = 0; i < n_decoders; i++) {
|
|
4312
|
+
// zero-copy (no modification)
|
|
4313
|
+
if (i == view[i] || view[i] < 0) {
|
|
4314
|
+
continue;
|
|
4315
|
+
}
|
|
4316
|
+
|
|
4317
|
+
bool is_one_copy = true;
|
|
4318
|
+
// since we modify data sequentially, we only consider decoder indices after current index
|
|
4319
|
+
for (int j = i + 1; j < n_decoders; j++) {
|
|
4320
|
+
if (i == view[j]) {
|
|
4321
|
+
// detect symmetric diagram
|
|
4322
|
+
if (j == view[i]) {
|
|
4323
|
+
p_swap_set.insert(i);
|
|
4324
|
+
p_swap_set.insert(j);
|
|
4325
|
+
p_swap_vec.emplace_back(i, j);
|
|
4326
|
+
} else {
|
|
4327
|
+
two_copy.insert(i);
|
|
4328
|
+
is_one_copy = false;
|
|
4329
|
+
}
|
|
4330
|
+
break;
|
|
4331
|
+
}
|
|
4332
|
+
}
|
|
4333
|
+
if (is_one_copy) {
|
|
4334
|
+
one_copy.insert(i);
|
|
4335
|
+
}
|
|
4336
|
+
}
|
|
4337
|
+
|
|
4338
|
+
kv_swap_bufs.resize(n_decoders);
|
|
4339
|
+
|
|
4340
|
+
for (int i = 0; i < n_decoders; i++) {
|
|
4341
|
+
kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
|
|
4342
|
+
kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
|
|
4343
|
+
}
|
|
4344
|
+
|
|
4345
|
+
for (auto & i : two_copy) {
|
|
4346
|
+
// make a copy of KV caches
|
|
4347
|
+
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
|
|
4348
|
+
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
|
|
4349
|
+
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
|
|
4350
|
+
}
|
|
4351
|
+
|
|
4352
|
+
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
|
|
4353
|
+
for (auto & i : two_copy) {
|
|
4354
|
+
// skip the decoder indices that require pointer swapping
|
|
4355
|
+
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4356
|
+
continue;
|
|
4357
|
+
}
|
|
4358
|
+
|
|
4359
|
+
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4360
|
+
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4361
|
+
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4362
|
+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4363
|
+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4364
|
+
} else {
|
|
4365
|
+
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4366
|
+
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4367
|
+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4368
|
+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4369
|
+
}
|
|
4370
|
+
}
|
|
4371
|
+
|
|
4372
|
+
// then modify one-copy decoder KV caches
|
|
4373
|
+
for (auto & i : one_copy) {
|
|
4374
|
+
// skip the decoder indices that require pointer swapping
|
|
4375
|
+
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4376
|
+
continue;
|
|
4377
|
+
}
|
|
4378
|
+
|
|
4379
|
+
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4380
|
+
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4381
|
+
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4382
|
+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4383
|
+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4384
|
+
} else {
|
|
4385
|
+
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4386
|
+
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4387
|
+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4388
|
+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4389
|
+
}
|
|
4390
|
+
}
|
|
4391
|
+
|
|
4392
|
+
// swap the pointers
|
|
4393
|
+
for (auto & i : p_swap_vec) {
|
|
4394
|
+
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
|
|
4395
|
+
std::swap(src[i.first].kv_self, src[i.second].kv_self);
|
|
4396
|
+
}
|
|
4397
|
+
|
|
4398
|
+
return true;
|
|
4399
|
+
}
|
|
4400
|
+
|
|
4078
4401
|
int whisper_full_with_state(
|
|
4079
4402
|
struct whisper_context * ctx,
|
|
4080
4403
|
struct whisper_state * state,
|
|
@@ -4182,6 +4505,21 @@ int whisper_full_with_state(
|
|
|
4182
4505
|
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
4183
4506
|
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
4184
4507
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
4508
|
+
|
|
4509
|
+
// TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
|
|
4510
|
+
#ifdef WSP_GGML_USE_METAL
|
|
4511
|
+
#define WHISPER_METAL_CHECK_BUF(result) \
|
|
4512
|
+
if (!(result)) { \
|
|
4513
|
+
log("%s: failed to add metal buffer\n", __func__); \
|
|
4514
|
+
return 0; \
|
|
4515
|
+
}
|
|
4516
|
+
|
|
4517
|
+
const std::string kv_name = "kv_self_" + std::to_string(j);
|
|
4518
|
+
auto & kv_self = decoder.kv_self;
|
|
4519
|
+
|
|
4520
|
+
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
|
|
4521
|
+
#undef WHISPER_METAL_CHECK_BUF
|
|
4522
|
+
#endif
|
|
4185
4523
|
}
|
|
4186
4524
|
}
|
|
4187
4525
|
|
|
@@ -4197,7 +4535,7 @@ int whisper_full_with_state(
|
|
|
4197
4535
|
|
|
4198
4536
|
// initial prompt
|
|
4199
4537
|
if (!params.prompt_tokens && params.initial_prompt) {
|
|
4200
|
-
prompt_tokens.resize(
|
|
4538
|
+
prompt_tokens.resize(2048);
|
|
4201
4539
|
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
|
4202
4540
|
params.prompt_tokens = prompt_tokens.data();
|
|
4203
4541
|
params.prompt_n_tokens = prompt_tokens.size();
|
|
@@ -4238,14 +4576,6 @@ int whisper_full_with_state(
|
|
|
4238
4576
|
std::vector<whisper_token> prompt;
|
|
4239
4577
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
|
4240
4578
|
|
|
4241
|
-
// beam-search helpers
|
|
4242
|
-
struct kv_buf {
|
|
4243
|
-
std::vector<uint8_t> k;
|
|
4244
|
-
std::vector<uint8_t> v;
|
|
4245
|
-
};
|
|
4246
|
-
|
|
4247
|
-
std::vector<kv_buf> kv_bufs;
|
|
4248
|
-
|
|
4249
4579
|
struct beam_candidate {
|
|
4250
4580
|
int decoder_idx;
|
|
4251
4581
|
int seek_delta;
|
|
@@ -4279,7 +4609,7 @@ int whisper_full_with_state(
|
|
|
4279
4609
|
}
|
|
4280
4610
|
|
|
4281
4611
|
// encode audio features starting at offset seek
|
|
4282
|
-
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
|
4612
|
+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4283
4613
|
log("%s: failed to encode\n", __func__);
|
|
4284
4614
|
return -6;
|
|
4285
4615
|
}
|
|
@@ -4362,7 +4692,7 @@ int whisper_full_with_state(
|
|
|
4362
4692
|
}
|
|
4363
4693
|
WHISPER_PRINT_DEBUG("\n\n");
|
|
4364
4694
|
|
|
4365
|
-
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
|
4695
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4366
4696
|
log("%s: failed to decode\n", __func__);
|
|
4367
4697
|
return -7;
|
|
4368
4698
|
}
|
|
@@ -4382,8 +4712,8 @@ int whisper_full_with_state(
|
|
|
4382
4712
|
|
|
4383
4713
|
decoder.kv_self.n += prompt.size();
|
|
4384
4714
|
|
|
4385
|
-
memcpy(decoder.probs.data(),
|
|
4386
|
-
memcpy(decoder.logits.data(),
|
|
4715
|
+
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
4716
|
+
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
4387
4717
|
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
4388
4718
|
}
|
|
4389
4719
|
|
|
@@ -4394,23 +4724,7 @@ int whisper_full_with_state(
|
|
|
4394
4724
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
|
4395
4725
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4396
4726
|
|
|
4397
|
-
// store the KV caches of all decoders when doing beam-search
|
|
4398
4727
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
4399
|
-
kv_bufs.resize(n_decoders_cur);
|
|
4400
|
-
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4401
|
-
auto & decoder = state->decoders[j];
|
|
4402
|
-
|
|
4403
|
-
if (decoder.completed || decoder.failed) {
|
|
4404
|
-
continue;
|
|
4405
|
-
}
|
|
4406
|
-
|
|
4407
|
-
kv_bufs[j].k.resize(wsp_ggml_nbytes(decoder.kv_self.k));
|
|
4408
|
-
kv_bufs[j].v.resize(wsp_ggml_nbytes(decoder.kv_self.v));
|
|
4409
|
-
|
|
4410
|
-
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
|
|
4411
|
-
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
|
|
4412
|
-
}
|
|
4413
|
-
|
|
4414
4728
|
beam_candidates.clear();
|
|
4415
4729
|
}
|
|
4416
4730
|
|
|
@@ -4458,6 +4772,7 @@ int whisper_full_with_state(
|
|
|
4458
4772
|
});
|
|
4459
4773
|
|
|
4460
4774
|
uint32_t cur_c = 0;
|
|
4775
|
+
std::vector<int> decoder_idx(n_decoders_cur, -1);
|
|
4461
4776
|
|
|
4462
4777
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4463
4778
|
auto & decoder = state->decoders[j];
|
|
@@ -4476,12 +4791,13 @@ int whisper_full_with_state(
|
|
|
4476
4791
|
decoder.seek_delta = cur.seek_delta;
|
|
4477
4792
|
decoder.has_ts = cur.has_ts;
|
|
4478
4793
|
|
|
4479
|
-
|
|
4480
|
-
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
|
|
4481
|
-
|
|
4794
|
+
decoder_idx[j] = cur.decoder_idx;
|
|
4482
4795
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
4483
4796
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
4484
4797
|
}
|
|
4798
|
+
|
|
4799
|
+
// update KV caches
|
|
4800
|
+
whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
|
|
4485
4801
|
}
|
|
4486
4802
|
|
|
4487
4803
|
// update the decoder state
|
|
@@ -4600,7 +4916,7 @@ int whisper_full_with_state(
|
|
|
4600
4916
|
|
|
4601
4917
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
|
4602
4918
|
|
|
4603
|
-
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
|
4919
|
+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4604
4920
|
log("%s: failed to decode\n", __func__);
|
|
4605
4921
|
return -8;
|
|
4606
4922
|
}
|
|
@@ -4910,6 +5226,12 @@ int whisper_full_parallel(
|
|
|
4910
5226
|
ctx->state->t_sample_us += states[i]->t_sample_us;
|
|
4911
5227
|
ctx->state->t_encode_us += states[i]->t_encode_us;
|
|
4912
5228
|
ctx->state->t_decode_us += states[i]->t_decode_us;
|
|
5229
|
+
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
|
5230
|
+
|
|
5231
|
+
ctx->state->n_sample += states[i]->n_sample;
|
|
5232
|
+
ctx->state->n_encode += states[i]->n_encode;
|
|
5233
|
+
ctx->state->n_decode += states[i]->n_decode;
|
|
5234
|
+
ctx->state->n_prompt += states[i]->n_prompt;
|
|
4913
5235
|
|
|
4914
5236
|
whisper_free_state(states[i]);
|
|
4915
5237
|
}
|
|
@@ -4963,6 +5285,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
|
|
|
4963
5285
|
return ctx->state->result_all[i_segment].t1;
|
|
4964
5286
|
}
|
|
4965
5287
|
|
|
5288
|
+
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
|
5289
|
+
return state->result_all[i_segment].speaker_turn_next;
|
|
5290
|
+
}
|
|
5291
|
+
|
|
4966
5292
|
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
|
|
4967
5293
|
return ctx->state->result_all[i_segment].speaker_turn_next;
|
|
4968
5294
|
}
|
|
@@ -5106,7 +5432,8 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5106
5432
|
// b: N*N*sizeof(float)
|
|
5107
5433
|
// c: N*N*sizeof(float)
|
|
5108
5434
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
|
5109
|
-
std::vector<
|
|
5435
|
+
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
|
|
5436
|
+
std::vector<uint8_t> work;
|
|
5110
5437
|
|
|
5111
5438
|
// put a bunch of random data in the buffer
|
|
5112
5439
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
|
@@ -5158,17 +5485,15 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5158
5485
|
|
|
5159
5486
|
struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
|
|
5160
5487
|
|
|
5161
|
-
gf.n_threads = n_threads;
|
|
5162
|
-
|
|
5163
5488
|
double tsum = 0.0;
|
|
5164
5489
|
|
|
5165
5490
|
// heat-up
|
|
5166
|
-
|
|
5491
|
+
wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
|
|
5167
5492
|
|
|
5168
5493
|
for (int i = 0; i < n_max; ++i) {
|
|
5169
5494
|
const int64_t t0 = wsp_ggml_time_us();
|
|
5170
5495
|
|
|
5171
|
-
|
|
5496
|
+
wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
|
|
5172
5497
|
|
|
5173
5498
|
const int64_t t1 = wsp_ggml_time_us();
|
|
5174
5499
|
|