whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +7 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +188 -109
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +451 -282
- package/cpp/ggml-alloc.h +74 -8
- package/cpp/ggml-backend-impl.h +112 -0
- package/cpp/ggml-backend.c +1357 -0
- package/cpp/ggml-backend.h +181 -0
- package/cpp/ggml-impl.h +243 -0
- package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
- package/cpp/ggml-metal.h +28 -1
- package/cpp/ggml-metal.m +1128 -308
- package/cpp/ggml-quants.c +7382 -0
- package/cpp/ggml-quants.h +224 -0
- package/cpp/ggml.c +3848 -5245
- package/cpp/ggml.h +353 -155
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1750 -964
- package/cpp/whisper.h +97 -15
- package/ios/RNWhisper.mm +15 -9
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +8 -12
- package/ios/RNWhisperContext.mm +132 -138
- package/jest/mock.js +1 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +28 -9
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +28 -9
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +7 -1
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +7 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/NativeRNWhisper.ts +8 -1
- package/src/index.ts +29 -17
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -2
package/cpp/whisper.cpp
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
1
|
#include "whisper.h"
|
|
2
|
+
|
|
2
3
|
#ifdef WHISPER_USE_COREML
|
|
3
4
|
#include "coreml/whisper-encoder.h"
|
|
4
5
|
#endif
|
|
5
6
|
|
|
6
7
|
#ifdef WSP_GGML_USE_METAL
|
|
7
|
-
#
|
|
8
|
+
#include "ggml-metal.h"
|
|
9
|
+
#endif
|
|
10
|
+
|
|
11
|
+
#ifdef WSP_GGML_USE_CUBLAS
|
|
12
|
+
#include "ggml-cuda.h"
|
|
8
13
|
#endif
|
|
9
14
|
|
|
10
15
|
#ifdef WHISPER_USE_OPENVINO
|
|
@@ -13,7 +18,9 @@
|
|
|
13
18
|
|
|
14
19
|
#include "ggml.h"
|
|
15
20
|
#include "ggml-alloc.h"
|
|
21
|
+
#include "ggml-backend.h"
|
|
16
22
|
|
|
23
|
+
#include <atomic>
|
|
17
24
|
#include <algorithm>
|
|
18
25
|
#include <cassert>
|
|
19
26
|
#define _USE_MATH_DEFINES
|
|
@@ -97,10 +104,32 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
97
104
|
#define BYTESWAP_TENSOR(t) do {} while (0)
|
|
98
105
|
#endif
|
|
99
106
|
|
|
107
|
+
#ifdef __GNUC__
|
|
108
|
+
#ifdef __MINGW32__
|
|
109
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
|
110
|
+
#else
|
|
111
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
|
112
|
+
#endif
|
|
113
|
+
#else
|
|
114
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...)
|
|
115
|
+
#endif
|
|
116
|
+
|
|
117
|
+
//
|
|
118
|
+
// logging
|
|
119
|
+
//
|
|
120
|
+
|
|
121
|
+
WHISPER_ATTRIBUTE_FORMAT(2, 3)
|
|
122
|
+
static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...);
|
|
123
|
+
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);
|
|
124
|
+
|
|
125
|
+
#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
|
126
|
+
#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
127
|
+
#define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
128
|
+
|
|
100
129
|
#define WHISPER_ASSERT(x) \
|
|
101
130
|
do { \
|
|
102
131
|
if (!(x)) { \
|
|
103
|
-
|
|
132
|
+
WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
|
104
133
|
abort(); \
|
|
105
134
|
} \
|
|
106
135
|
} while (0)
|
|
@@ -119,15 +148,16 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
119
148
|
|
|
120
149
|
//#define WHISPER_USE_FLASH_ATTN
|
|
121
150
|
//#define WHISPER_USE_FLASH_FF
|
|
122
|
-
#define WHISPER_MAX_DECODERS
|
|
151
|
+
#define WHISPER_MAX_DECODERS 8
|
|
152
|
+
#define WHISPER_MAX_NODES 4096
|
|
123
153
|
|
|
124
154
|
//
|
|
125
155
|
// ggml helpers
|
|
126
156
|
//
|
|
127
157
|
|
|
128
158
|
static void wsp_ggml_graph_compute_helper(
|
|
159
|
+
struct wsp_ggml_cgraph * graph,
|
|
129
160
|
std::vector<uint8_t> & buf,
|
|
130
|
-
wsp_ggml_cgraph * graph,
|
|
131
161
|
int n_threads,
|
|
132
162
|
whisper_abort_callback abort_callback,
|
|
133
163
|
void * abort_callback_data) {
|
|
@@ -144,6 +174,21 @@ static void wsp_ggml_graph_compute_helper(
|
|
|
144
174
|
wsp_ggml_graph_compute(graph, &plan);
|
|
145
175
|
}
|
|
146
176
|
|
|
177
|
+
static void wsp_ggml_graph_compute_helper(
|
|
178
|
+
struct wsp_ggml_backend * backend,
|
|
179
|
+
struct wsp_ggml_cgraph * graph,
|
|
180
|
+
int n_threads) {
|
|
181
|
+
if (wsp_ggml_backend_is_cpu(backend)) {
|
|
182
|
+
wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
|
|
183
|
+
}
|
|
184
|
+
#ifdef WSP_GGML_USE_METAL
|
|
185
|
+
if (wsp_ggml_backend_is_metal(backend)) {
|
|
186
|
+
wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
|
|
187
|
+
}
|
|
188
|
+
#endif
|
|
189
|
+
wsp_ggml_backend_graph_compute(backend, graph);
|
|
190
|
+
}
|
|
191
|
+
|
|
147
192
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
148
193
|
// the idea is to represent the original matrix multiplication:
|
|
149
194
|
//
|
|
@@ -178,6 +223,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * c
|
|
|
178
223
|
}
|
|
179
224
|
|
|
180
225
|
// TODO: check if other platforms can benefit from this optimization
|
|
226
|
+
// TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly
|
|
181
227
|
#if defined(WSP_GGML_USE_METAL)
|
|
182
228
|
#define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
|
|
183
229
|
#endif
|
|
@@ -192,6 +238,15 @@ enum e_model {
|
|
|
192
238
|
MODEL_LARGE,
|
|
193
239
|
};
|
|
194
240
|
|
|
241
|
+
static const std::map<e_model, std::string> g_model_name = {
|
|
242
|
+
{ MODEL_UNKNOWN, "unknown" },
|
|
243
|
+
{ MODEL_TINY, "tiny" },
|
|
244
|
+
{ MODEL_BASE, "base" },
|
|
245
|
+
{ MODEL_SMALL, "small" },
|
|
246
|
+
{ MODEL_MEDIUM, "medium" },
|
|
247
|
+
{ MODEL_LARGE, "large" },
|
|
248
|
+
};
|
|
249
|
+
|
|
195
250
|
static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
196
251
|
{ "en", { 0, "english", } },
|
|
197
252
|
{ "zh", { 1, "chinese", } },
|
|
@@ -292,75 +347,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
292
347
|
{ "ba", { 96, "bashkir", } },
|
|
293
348
|
{ "jw", { 97, "javanese", } },
|
|
294
349
|
{ "su", { 98, "sundanese", } },
|
|
295
|
-
}
|
|
296
|
-
|
|
297
|
-
static const size_t MB = 1ull*1024*1024;
|
|
298
|
-
|
|
299
|
-
// TODO: avoid using GGUF
|
|
300
|
-
static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
301
|
-
{ WSP_GGML_TYPE_F32,
|
|
302
|
-
{
|
|
303
|
-
{ MODEL_TINY, 74ull*MB },
|
|
304
|
-
{ MODEL_BASE, 142ull*MB },
|
|
305
|
-
{ MODEL_SMALL, 466ull*MB },
|
|
306
|
-
{ MODEL_MEDIUM, 1464ull*MB },
|
|
307
|
-
{ MODEL_LARGE, 2952ull*MB },
|
|
308
|
-
},
|
|
309
|
-
},
|
|
310
|
-
{ WSP_GGML_TYPE_F16,
|
|
311
|
-
{
|
|
312
|
-
{ MODEL_TINY, 74ull*MB },
|
|
313
|
-
{ MODEL_BASE, 142ull*MB },
|
|
314
|
-
{ MODEL_SMALL, 466ull*MB },
|
|
315
|
-
{ MODEL_MEDIUM, 1464ull*MB },
|
|
316
|
-
{ MODEL_LARGE, 2952ull*MB },
|
|
317
|
-
},
|
|
318
|
-
},
|
|
319
|
-
{ WSP_GGML_TYPE_Q4_0,
|
|
320
|
-
{
|
|
321
|
-
{ MODEL_TINY, 26ull*MB },
|
|
322
|
-
{ MODEL_BASE, 50ull*MB },
|
|
323
|
-
{ MODEL_SMALL, 154ull*MB },
|
|
324
|
-
{ MODEL_MEDIUM, 470ull*MB },
|
|
325
|
-
{ MODEL_LARGE, 940ull*MB },
|
|
326
|
-
},
|
|
327
|
-
},
|
|
328
|
-
{ WSP_GGML_TYPE_Q4_1,
|
|
329
|
-
{
|
|
330
|
-
{ MODEL_TINY, 32ull*MB },
|
|
331
|
-
{ MODEL_BASE, 58ull*MB },
|
|
332
|
-
{ MODEL_SMALL, 182ull*MB },
|
|
333
|
-
{ MODEL_MEDIUM, 562ull*MB },
|
|
334
|
-
{ MODEL_LARGE, 1124ull*MB },
|
|
335
|
-
},
|
|
336
|
-
},
|
|
337
|
-
{ WSP_GGML_TYPE_Q5_0,
|
|
338
|
-
{
|
|
339
|
-
{ MODEL_TINY, 30ull*MB },
|
|
340
|
-
{ MODEL_BASE, 54ull*MB },
|
|
341
|
-
{ MODEL_SMALL, 170ull*MB },
|
|
342
|
-
{ MODEL_MEDIUM, 516ull*MB },
|
|
343
|
-
{ MODEL_LARGE, 1034ull*MB },
|
|
344
|
-
},
|
|
345
|
-
},
|
|
346
|
-
{ WSP_GGML_TYPE_Q5_1,
|
|
347
|
-
{
|
|
348
|
-
{ MODEL_TINY, 32ull*MB },
|
|
349
|
-
{ MODEL_BASE, 58ull*MB },
|
|
350
|
-
{ MODEL_SMALL, 182ull*MB },
|
|
351
|
-
{ MODEL_MEDIUM, 562ull*MB },
|
|
352
|
-
{ MODEL_LARGE, 1124ull*MB },
|
|
353
|
-
},
|
|
354
|
-
},
|
|
355
|
-
{ WSP_GGML_TYPE_Q8_0,
|
|
356
|
-
{
|
|
357
|
-
{ MODEL_TINY, 45ull*MB },
|
|
358
|
-
{ MODEL_BASE, 84ull*MB },
|
|
359
|
-
{ MODEL_SMALL, 268ull*MB },
|
|
360
|
-
{ MODEL_MEDIUM, 834ull*MB },
|
|
361
|
-
{ MODEL_LARGE, 1674ull*MB },
|
|
362
|
-
},
|
|
363
|
-
},
|
|
350
|
+
{ "yue", { 99, "cantonese", } },
|
|
364
351
|
};
|
|
365
352
|
|
|
366
353
|
struct whisper_mel {
|
|
@@ -401,7 +388,11 @@ struct whisper_vocab {
|
|
|
401
388
|
id token_beg = 50363; // begin timestamps
|
|
402
389
|
|
|
403
390
|
bool is_multilingual() const {
|
|
404
|
-
return n_vocab
|
|
391
|
+
return n_vocab >= 51865;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
int num_languages() const {
|
|
395
|
+
return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
|
|
405
396
|
}
|
|
406
397
|
};
|
|
407
398
|
|
|
@@ -416,6 +407,121 @@ struct whisper_segment {
|
|
|
416
407
|
bool speaker_turn_next;
|
|
417
408
|
};
|
|
418
409
|
|
|
410
|
+
struct whisper_batch {
|
|
411
|
+
int32_t n_tokens;
|
|
412
|
+
|
|
413
|
+
whisper_token * token;
|
|
414
|
+
whisper_pos * pos;
|
|
415
|
+
int32_t * n_seq_id;
|
|
416
|
+
whisper_seq_id ** seq_id; // null terminated
|
|
417
|
+
int8_t * logits;
|
|
418
|
+
};
|
|
419
|
+
|
|
420
|
+
static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
|
|
421
|
+
whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
|
|
422
|
+
|
|
423
|
+
batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens));
|
|
424
|
+
batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens));
|
|
425
|
+
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
|
|
426
|
+
batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
|
|
427
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
428
|
+
batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
|
|
429
|
+
}
|
|
430
|
+
batch.seq_id[n_tokens] = nullptr;
|
|
431
|
+
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
|
432
|
+
|
|
433
|
+
return batch;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
static void whisper_batch_free(struct whisper_batch batch) {
|
|
437
|
+
if (batch.token) free(batch.token);
|
|
438
|
+
if (batch.pos) free(batch.pos);
|
|
439
|
+
if (batch.n_seq_id) free(batch.n_seq_id);
|
|
440
|
+
if (batch.seq_id) {
|
|
441
|
+
for (int i = 0; batch.seq_id[i]; ++i) {
|
|
442
|
+
free(batch.seq_id[i]);
|
|
443
|
+
}
|
|
444
|
+
free(batch.seq_id);
|
|
445
|
+
}
|
|
446
|
+
if (batch.logits) free(batch.logits);
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
|
|
450
|
+
batch.n_tokens = n_tokens;
|
|
451
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
452
|
+
if (tokens) {
|
|
453
|
+
batch.token[i] = tokens[i];
|
|
454
|
+
}
|
|
455
|
+
batch.pos [i] = n_past + i;
|
|
456
|
+
batch.n_seq_id[i] = 1;
|
|
457
|
+
batch.seq_id [i][0] = seq_id;
|
|
458
|
+
batch.logits [i] = 0;
|
|
459
|
+
}
|
|
460
|
+
batch.logits[n_tokens - 1] = 1;
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
464
|
+
template<typename A, typename B>
|
|
465
|
+
struct whisper_pair {
|
|
466
|
+
A first;
|
|
467
|
+
B second;
|
|
468
|
+
|
|
469
|
+
// Define a constructor that takes two arguments.
|
|
470
|
+
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
471
|
+
// Define a constructor that takes no argument.
|
|
472
|
+
whisper_pair() : first(A()), second(B()) {}
|
|
473
|
+
};
|
|
474
|
+
|
|
475
|
+
// wsp_ggml_allocr wrapper for whisper usage
|
|
476
|
+
struct whisper_allocr {
|
|
477
|
+
wsp_ggml_allocr * alloc = nullptr;
|
|
478
|
+
|
|
479
|
+
std::vector<uint8_t> meta;
|
|
480
|
+
|
|
481
|
+
wsp_ggml_backend_buffer_t buffer;
|
|
482
|
+
};
|
|
483
|
+
|
|
484
|
+
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
485
|
+
return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
489
|
+
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
490
|
+
auto & alloc = allocr.alloc;
|
|
491
|
+
auto & meta = allocr.meta;
|
|
492
|
+
|
|
493
|
+
alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
|
|
494
|
+
|
|
495
|
+
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
496
|
+
|
|
497
|
+
wsp_ggml_allocr_alloc_graph(alloc, get_graph());
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
|
|
501
|
+
if (allocr.alloc == nullptr) {
|
|
502
|
+
// this can be null if we use external encoder like CoreML or OpenVINO
|
|
503
|
+
return;
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
auto & alloc = allocr.alloc;
|
|
507
|
+
auto & buffer = allocr.buffer;
|
|
508
|
+
|
|
509
|
+
size_t size = wsp_ggml_allocr_max_size(alloc);
|
|
510
|
+
|
|
511
|
+
wsp_ggml_allocr_free(alloc);
|
|
512
|
+
|
|
513
|
+
buffer = wsp_ggml_backend_alloc_buffer(backend, size);
|
|
514
|
+
alloc = wsp_ggml_allocr_new_from_buffer(buffer);
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
518
|
+
if (allocr.alloc) {
|
|
519
|
+
wsp_ggml_allocr_free(allocr.alloc);
|
|
520
|
+
wsp_ggml_backend_buffer_free(allocr.buffer);
|
|
521
|
+
allocr.alloc = nullptr;
|
|
522
|
+
}
|
|
523
|
+
}
|
|
524
|
+
|
|
419
525
|
// medium
|
|
420
526
|
// hparams: {
|
|
421
527
|
// 'n_mels': 80,
|
|
@@ -533,16 +639,31 @@ struct whisper_layer_decoder {
|
|
|
533
639
|
struct wsp_ggml_tensor * mlp_1_b;
|
|
534
640
|
};
|
|
535
641
|
|
|
642
|
+
struct whisper_kv_cell {
|
|
643
|
+
whisper_pos pos = -1;
|
|
644
|
+
|
|
645
|
+
std::set<whisper_seq_id> seq_id;
|
|
646
|
+
|
|
647
|
+
bool has_seq_id(const whisper_seq_id & id) const {
|
|
648
|
+
return seq_id.find(id) != seq_id.end();
|
|
649
|
+
}
|
|
650
|
+
};
|
|
651
|
+
|
|
536
652
|
struct whisper_kv_cache {
|
|
653
|
+
uint32_t head = 0;
|
|
654
|
+
uint32_t size = 0;
|
|
655
|
+
|
|
656
|
+
// computed before each graph build
|
|
657
|
+
uint32_t n = 0;
|
|
658
|
+
|
|
659
|
+
std::vector<whisper_kv_cell> cells;
|
|
660
|
+
|
|
537
661
|
struct wsp_ggml_tensor * k;
|
|
538
662
|
struct wsp_ggml_tensor * v;
|
|
539
663
|
|
|
540
664
|
struct wsp_ggml_context * ctx;
|
|
541
665
|
|
|
542
|
-
|
|
543
|
-
std::vector<uint8_t> buf;
|
|
544
|
-
|
|
545
|
-
int n; // number of tokens currently in the cache
|
|
666
|
+
wsp_ggml_backend_buffer_t buffer;
|
|
546
667
|
};
|
|
547
668
|
|
|
548
669
|
struct whisper_model {
|
|
@@ -579,17 +700,36 @@ struct whisper_model {
|
|
|
579
700
|
std::vector<whisper_layer_encoder> layers_encoder;
|
|
580
701
|
std::vector<whisper_layer_decoder> layers_decoder;
|
|
581
702
|
|
|
582
|
-
// context
|
|
703
|
+
// ggml context that contains all the meta information about the model tensors
|
|
583
704
|
struct wsp_ggml_context * ctx;
|
|
584
705
|
|
|
585
|
-
// the model
|
|
586
|
-
|
|
706
|
+
// the model backend data is read-only and can be shared between processors
|
|
707
|
+
struct wsp_ggml_backend_buffer * buffer;
|
|
587
708
|
|
|
588
709
|
// tensors
|
|
589
710
|
int n_loaded;
|
|
590
711
|
std::map<std::string, struct wsp_ggml_tensor *> tensors;
|
|
591
712
|
};
|
|
592
713
|
|
|
714
|
+
struct whisper_partial_utf8 {
|
|
715
|
+
uint32_t value; // bit value so far (unshifted)
|
|
716
|
+
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
|
717
|
+
};
|
|
718
|
+
|
|
719
|
+
struct whisper_grammar {
|
|
720
|
+
/*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
|
|
721
|
+
std::vector<std::vector<const whisper_grammar_element *>> stacks;
|
|
722
|
+
|
|
723
|
+
// buffer for partially generated UTF-8 sequence from accepted tokens
|
|
724
|
+
whisper_partial_utf8 partial_utf8;
|
|
725
|
+
};
|
|
726
|
+
|
|
727
|
+
struct whisper_grammar_candidate {
|
|
728
|
+
whisper_token id;
|
|
729
|
+
const uint32_t * code_points;
|
|
730
|
+
whisper_partial_utf8 partial_utf8;
|
|
731
|
+
};
|
|
732
|
+
|
|
593
733
|
struct whisper_sequence {
|
|
594
734
|
std::vector<whisper_token_data> tokens;
|
|
595
735
|
|
|
@@ -605,12 +745,13 @@ struct whisper_sequence {
|
|
|
605
745
|
|
|
606
746
|
// TAGS: WHISPER_DECODER_INIT
|
|
607
747
|
struct whisper_decoder {
|
|
608
|
-
// each decoder keeps its own KV-cache
|
|
609
|
-
whisper_kv_cache kv_self;
|
|
610
|
-
|
|
611
748
|
// the currently generated sequence of tokens
|
|
612
749
|
whisper_sequence sequence;
|
|
613
750
|
|
|
751
|
+
// grammar parse state of generated sequence of tokens
|
|
752
|
+
whisper_grammar grammar;
|
|
753
|
+
|
|
754
|
+
int i_batch; // the index of the token in the current batch
|
|
614
755
|
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
|
|
615
756
|
|
|
616
757
|
bool failed; // has the current segment failed to decode?
|
|
@@ -622,93 +763,42 @@ struct whisper_decoder {
|
|
|
622
763
|
std::vector<float> logits;
|
|
623
764
|
std::vector<float> logprobs;
|
|
624
765
|
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
629
|
-
template<typename A, typename B>
|
|
630
|
-
struct whisper_pair {
|
|
631
|
-
A first;
|
|
632
|
-
B second;
|
|
633
|
-
|
|
634
|
-
// Define a constructor that takes two arguments.
|
|
635
|
-
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
636
|
-
// Define a constructor that takes no argument.
|
|
637
|
-
whisper_pair() : first(A()), second(B()) {}
|
|
638
|
-
};
|
|
639
|
-
|
|
640
|
-
// beam-search helpers
|
|
641
|
-
struct kv_buf {
|
|
642
|
-
std::vector<uint8_t> k;
|
|
643
|
-
std::vector<uint8_t> v;
|
|
644
|
-
};
|
|
645
|
-
|
|
646
|
-
// wsp_ggml_allocr wrapper for whisper usage
|
|
647
|
-
struct whisper_allocr {
|
|
648
|
-
wsp_ggml_allocr * alloc = nullptr;
|
|
766
|
+
// work container used to avoid memory allocations
|
|
767
|
+
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
649
768
|
|
|
650
|
-
std::
|
|
651
|
-
std::vector<uint8_t> data;
|
|
769
|
+
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
652
770
|
};
|
|
653
771
|
|
|
654
|
-
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
655
|
-
return allocr.meta.size() + allocr.data.size();
|
|
656
|
-
}
|
|
657
|
-
|
|
658
|
-
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
659
|
-
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
660
|
-
const int tensor_alignment = 32;
|
|
661
|
-
|
|
662
|
-
auto & alloc = allocr.alloc;
|
|
663
|
-
auto & meta = allocr.meta;
|
|
664
|
-
auto & data = allocr.data;
|
|
665
|
-
|
|
666
|
-
meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
|
|
667
|
-
|
|
668
|
-
alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
|
|
669
|
-
|
|
670
|
-
const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
|
|
671
|
-
|
|
672
|
-
wsp_ggml_allocr_free(alloc);
|
|
673
|
-
|
|
674
|
-
data.resize(alloc_size);
|
|
675
|
-
|
|
676
|
-
alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
|
677
|
-
}
|
|
678
|
-
|
|
679
|
-
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
680
|
-
if (allocr.alloc) {
|
|
681
|
-
wsp_ggml_allocr_free(allocr.alloc);
|
|
682
|
-
allocr.alloc = nullptr;
|
|
683
|
-
}
|
|
684
|
-
}
|
|
685
|
-
|
|
686
772
|
struct whisper_state {
|
|
687
773
|
int64_t t_sample_us = 0;
|
|
688
774
|
int64_t t_encode_us = 0;
|
|
689
775
|
int64_t t_decode_us = 0;
|
|
776
|
+
int64_t t_batchd_us = 0;
|
|
690
777
|
int64_t t_prompt_us = 0;
|
|
691
778
|
int64_t t_mel_us = 0;
|
|
692
779
|
|
|
693
780
|
int32_t n_sample = 0; // number of tokens sampled
|
|
694
781
|
int32_t n_encode = 0; // number of encoder calls
|
|
695
|
-
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1
|
|
696
|
-
int32_t
|
|
782
|
+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
|
783
|
+
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
|
|
784
|
+
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
|
697
785
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
698
786
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
699
787
|
|
|
788
|
+
// unified self-attention KV cache for all decoders
|
|
789
|
+
whisper_kv_cache kv_self;
|
|
790
|
+
|
|
700
791
|
// cross-attention KV cache for the decoders
|
|
701
792
|
// shared between all decoders
|
|
702
793
|
whisper_kv_cache kv_cross;
|
|
794
|
+
|
|
703
795
|
whisper_mel mel;
|
|
704
796
|
|
|
705
|
-
|
|
797
|
+
whisper_batch batch;
|
|
706
798
|
|
|
707
|
-
|
|
708
|
-
std::vector<kv_buf> kv_swap_bufs;
|
|
799
|
+
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
|
709
800
|
|
|
710
|
-
|
|
711
|
-
std::vector<uint8_t> work_buffer;
|
|
801
|
+
wsp_ggml_backend_t backend = nullptr;
|
|
712
802
|
|
|
713
803
|
// ggml-alloc:
|
|
714
804
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
@@ -722,36 +812,34 @@ struct whisper_state {
|
|
|
722
812
|
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
723
813
|
struct wsp_ggml_tensor * embd_enc = nullptr;
|
|
724
814
|
|
|
815
|
+
// helpers for GPU offloading
|
|
816
|
+
std::vector<float> inp_mel;
|
|
817
|
+
std::vector<float> inp_mask;
|
|
818
|
+
|
|
725
819
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
726
820
|
std::vector<float> logits;
|
|
727
821
|
|
|
728
822
|
std::vector<whisper_segment> result_all;
|
|
729
823
|
std::vector<whisper_token> prompt_past;
|
|
730
824
|
|
|
731
|
-
// work container used to avoid memory allocations
|
|
732
|
-
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
733
|
-
|
|
734
|
-
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
735
|
-
|
|
736
825
|
int lang_id = 0; // english by default
|
|
737
826
|
|
|
738
|
-
std::string path_model; // populated by
|
|
827
|
+
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
828
|
+
|
|
739
829
|
#ifdef WHISPER_USE_COREML
|
|
740
830
|
whisper_coreml_context * ctx_coreml = nullptr;
|
|
741
831
|
#endif
|
|
742
832
|
|
|
743
|
-
#ifdef WSP_GGML_USE_METAL
|
|
744
|
-
wsp_ggml_metal_context * ctx_metal = nullptr;
|
|
745
|
-
#endif
|
|
746
|
-
|
|
747
833
|
#ifdef WHISPER_USE_OPENVINO
|
|
748
834
|
whisper_openvino_context * ctx_openvino = nullptr;
|
|
749
835
|
#endif
|
|
750
836
|
|
|
751
837
|
// [EXPERIMENTAL] token-level timestamps data
|
|
752
|
-
int64_t t_beg
|
|
838
|
+
int64_t t_beg = 0;
|
|
753
839
|
int64_t t_last = 0;
|
|
840
|
+
|
|
754
841
|
whisper_token tid_last;
|
|
842
|
+
|
|
755
843
|
std::vector<float> energy; // PCM signal energy
|
|
756
844
|
|
|
757
845
|
// [EXPERIMENTAL] speed-up techniques
|
|
@@ -765,37 +853,25 @@ struct whisper_context {
|
|
|
765
853
|
wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
|
|
766
854
|
wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16)
|
|
767
855
|
|
|
856
|
+
whisper_context_params params;
|
|
857
|
+
|
|
768
858
|
whisper_model model;
|
|
769
859
|
whisper_vocab vocab;
|
|
860
|
+
|
|
770
861
|
whisper_state * state = nullptr;
|
|
771
862
|
|
|
772
|
-
|
|
773
|
-
#ifdef WHISPER_USE_COREML
|
|
774
|
-
bool load_coreml = true;
|
|
775
|
-
#endif
|
|
776
|
-
};
|
|
863
|
+
wsp_ggml_backend_t backend = nullptr;
|
|
777
864
|
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
}
|
|
865
|
+
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
866
|
+
};
|
|
781
867
|
|
|
782
|
-
|
|
868
|
+
struct whisper_global {
|
|
869
|
+
// We save the log callback globally
|
|
870
|
+
wsp_ggml_log_callback log_callback = whisper_log_callback_default;
|
|
871
|
+
void * log_callback_user_data = nullptr;
|
|
872
|
+
};
|
|
783
873
|
|
|
784
|
-
|
|
785
|
-
#ifdef __MINGW32__
|
|
786
|
-
__attribute__((gnu_format(printf, 1, 2)))
|
|
787
|
-
#else
|
|
788
|
-
__attribute__((format(printf, 1, 2)))
|
|
789
|
-
#endif
|
|
790
|
-
#endif
|
|
791
|
-
static void log(const char * fmt, ...) {
|
|
792
|
-
if (!whisper_log) return;
|
|
793
|
-
char buf[1024];
|
|
794
|
-
va_list args;
|
|
795
|
-
va_start(args, fmt);
|
|
796
|
-
vsnprintf(buf, sizeof(buf), fmt, args);
|
|
797
|
-
whisper_log(buf);
|
|
798
|
-
}
|
|
874
|
+
static whisper_global g_state;
|
|
799
875
|
|
|
800
876
|
template<typename T>
|
|
801
877
|
static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
@@ -806,6 +882,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
806
882
|
static bool kv_cache_init(
|
|
807
883
|
const struct whisper_hparams & hparams,
|
|
808
884
|
struct whisper_kv_cache & cache,
|
|
885
|
+
wsp_ggml_backend_t backend,
|
|
809
886
|
wsp_ggml_type wtype,
|
|
810
887
|
int n_ctx) {
|
|
811
888
|
const int64_t n_text_state = hparams.n_text_state;
|
|
@@ -814,64 +891,204 @@ static bool kv_cache_init(
|
|
|
814
891
|
const int64_t n_mem = n_text_layer*n_ctx;
|
|
815
892
|
const int64_t n_elements = n_text_state*n_mem;
|
|
816
893
|
|
|
817
|
-
const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
|
|
818
|
-
|
|
819
|
-
cache.buf.resize(mem_bytes);
|
|
820
|
-
|
|
821
894
|
struct wsp_ggml_init_params params = {
|
|
822
|
-
/*.mem_size =*/
|
|
823
|
-
/*.mem_buffer =*/
|
|
824
|
-
/*.no_alloc =*/
|
|
895
|
+
/*.mem_size =*/ 2*wsp_ggml_tensor_overhead(),
|
|
896
|
+
/*.mem_buffer =*/ nullptr,
|
|
897
|
+
/*.no_alloc =*/ true,
|
|
825
898
|
};
|
|
826
899
|
|
|
900
|
+
cache.head = 0;
|
|
901
|
+
cache.size = n_ctx;
|
|
902
|
+
|
|
903
|
+
cache.cells.clear();
|
|
904
|
+
cache.cells.resize(n_ctx);
|
|
905
|
+
|
|
827
906
|
cache.ctx = wsp_ggml_init(params);
|
|
828
907
|
|
|
829
908
|
if (!cache.ctx) {
|
|
830
|
-
|
|
909
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
|
|
831
910
|
return false;
|
|
832
911
|
}
|
|
833
912
|
|
|
834
913
|
cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
835
914
|
cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
836
915
|
|
|
916
|
+
const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
|
|
917
|
+
|
|
918
|
+
cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
|
|
919
|
+
|
|
920
|
+
// allocate the tensors into the backend buffer
|
|
921
|
+
{
|
|
922
|
+
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
|
|
923
|
+
|
|
924
|
+
wsp_ggml_allocr_alloc(alloc, cache.k);
|
|
925
|
+
wsp_ggml_allocr_alloc(alloc, cache.v);
|
|
926
|
+
|
|
927
|
+
wsp_ggml_allocr_free(alloc);
|
|
928
|
+
}
|
|
929
|
+
|
|
837
930
|
return true;
|
|
838
931
|
}
|
|
839
932
|
|
|
840
|
-
static
|
|
841
|
-
|
|
933
|
+
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
|
934
|
+
if (cache.ctx) {
|
|
935
|
+
wsp_ggml_free(cache.ctx);
|
|
936
|
+
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
937
|
+
cache.ctx = nullptr;
|
|
938
|
+
}
|
|
939
|
+
}
|
|
940
|
+
|
|
941
|
+
static bool whisper_kv_cache_find_slot(
|
|
942
|
+
struct whisper_kv_cache & cache,
|
|
943
|
+
const struct whisper_batch & batch) {
|
|
944
|
+
const uint32_t n_ctx = cache.size;
|
|
945
|
+
const uint32_t n_tokens = batch.n_tokens;
|
|
842
946
|
|
|
843
|
-
|
|
844
|
-
|
|
947
|
+
if (n_tokens > n_ctx) {
|
|
948
|
+
WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
|
|
949
|
+
return false;
|
|
950
|
+
}
|
|
845
951
|
|
|
846
|
-
|
|
847
|
-
WHISPER_ASSERT(wtype == cache.v->type);
|
|
952
|
+
uint32_t n_tested = 0;
|
|
848
953
|
|
|
849
|
-
|
|
954
|
+
while (true) {
|
|
955
|
+
if (cache.head + n_tokens > n_ctx) {
|
|
956
|
+
n_tested += n_ctx - cache.head;
|
|
957
|
+
cache.head = 0;
|
|
958
|
+
continue;
|
|
959
|
+
}
|
|
850
960
|
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
961
|
+
bool found = true;
|
|
962
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
963
|
+
if (cache.cells[cache.head + i].pos >= 0) {
|
|
964
|
+
found = false;
|
|
965
|
+
cache.head += i + 1;
|
|
966
|
+
n_tested += i + 1;
|
|
967
|
+
break;
|
|
968
|
+
}
|
|
969
|
+
}
|
|
856
970
|
|
|
857
|
-
|
|
971
|
+
if (found) {
|
|
972
|
+
break;
|
|
973
|
+
}
|
|
858
974
|
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
975
|
+
if (n_tested >= n_ctx) {
|
|
976
|
+
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
|
977
|
+
return false;
|
|
978
|
+
}
|
|
862
979
|
}
|
|
863
980
|
|
|
864
|
-
|
|
865
|
-
|
|
981
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
982
|
+
cache.cells[cache.head + i].pos = batch.pos[i];
|
|
983
|
+
|
|
984
|
+
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
|
985
|
+
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
|
|
986
|
+
}
|
|
987
|
+
}
|
|
866
988
|
|
|
867
989
|
return true;
|
|
868
990
|
}
|
|
869
991
|
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
cache.
|
|
992
|
+
// find how many cells are currently in use
|
|
993
|
+
static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
|
|
994
|
+
for (uint32_t i = cache.size - 1; i > 0; --i) {
|
|
995
|
+
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
|
996
|
+
return i + 1;
|
|
997
|
+
}
|
|
998
|
+
}
|
|
999
|
+
|
|
1000
|
+
return 1;
|
|
1001
|
+
}
|
|
1002
|
+
|
|
1003
|
+
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
1004
|
+
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
|
|
1005
|
+
cache.cells[i].pos = -1;
|
|
1006
|
+
cache.cells[i].seq_id.clear();
|
|
874
1007
|
}
|
|
1008
|
+
cache.head = 0;
|
|
1009
|
+
}
|
|
1010
|
+
|
|
1011
|
+
static void whisper_kv_cache_seq_rm(
|
|
1012
|
+
struct whisper_kv_cache & cache,
|
|
1013
|
+
whisper_seq_id seq_id,
|
|
1014
|
+
whisper_pos p0,
|
|
1015
|
+
whisper_pos p1) {
|
|
1016
|
+
uint32_t new_head = cache.size;
|
|
1017
|
+
|
|
1018
|
+
if (p0 < 0) p0 = 0;
|
|
1019
|
+
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
|
|
1020
|
+
|
|
1021
|
+
for (uint32_t i = 0; i < cache.size; ++i) {
|
|
1022
|
+
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
|
1023
|
+
if (seq_id < 0) {
|
|
1024
|
+
cache.cells[i].seq_id.clear();
|
|
1025
|
+
} else if (cache.cells[i].has_seq_id(seq_id)) {
|
|
1026
|
+
cache.cells[i].seq_id.erase(seq_id);
|
|
1027
|
+
} else {
|
|
1028
|
+
continue;
|
|
1029
|
+
}
|
|
1030
|
+
if (cache.cells[i].seq_id.empty()) {
|
|
1031
|
+
cache.cells[i].pos = -1;
|
|
1032
|
+
if (new_head == cache.size) new_head = i;
|
|
1033
|
+
}
|
|
1034
|
+
}
|
|
1035
|
+
}
|
|
1036
|
+
|
|
1037
|
+
// If we freed up a slot, set head to it so searching can start there.
|
|
1038
|
+
if (new_head != cache.size) cache.head = new_head;
|
|
1039
|
+
}
|
|
1040
|
+
|
|
1041
|
+
static void whisper_kv_cache_seq_cp(
|
|
1042
|
+
struct whisper_kv_cache & cache,
|
|
1043
|
+
whisper_seq_id seq_id_src,
|
|
1044
|
+
whisper_seq_id seq_id_dst,
|
|
1045
|
+
whisper_pos p0,
|
|
1046
|
+
whisper_pos p1) {
|
|
1047
|
+
if (p0 < 0) p0 = 0;
|
|
1048
|
+
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
|
|
1049
|
+
|
|
1050
|
+
cache.head = 0;
|
|
1051
|
+
|
|
1052
|
+
for (uint32_t i = 0; i < cache.size; ++i) {
|
|
1053
|
+
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
|
1054
|
+
cache.cells[i].seq_id.insert(seq_id_dst);
|
|
1055
|
+
}
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
|
|
1060
|
+
wsp_ggml_backend_t backend_gpu = NULL;
|
|
1061
|
+
|
|
1062
|
+
// initialize the backends
|
|
1063
|
+
#ifdef WSP_GGML_USE_CUBLAS
|
|
1064
|
+
if (params.use_gpu && wsp_ggml_cublas_loaded()) {
|
|
1065
|
+
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
|
1066
|
+
backend_gpu = wsp_ggml_backend_cuda_init(0);
|
|
1067
|
+
if (!backend_gpu) {
|
|
1068
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
|
|
1069
|
+
}
|
|
1070
|
+
}
|
|
1071
|
+
#endif
|
|
1072
|
+
|
|
1073
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1074
|
+
if (params.use_gpu) {
|
|
1075
|
+
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
|
|
1076
|
+
wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
|
|
1077
|
+
backend_gpu = wsp_ggml_backend_metal_init();
|
|
1078
|
+
if (!backend_gpu) {
|
|
1079
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
|
|
1080
|
+
} else if (!wsp_ggml_backend_metal_supports_family(backend_gpu, 7)) {
|
|
1081
|
+
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
|
|
1082
|
+
wsp_ggml_backend_free(backend_gpu);
|
|
1083
|
+
backend_gpu = NULL;
|
|
1084
|
+
}
|
|
1085
|
+
}
|
|
1086
|
+
#endif
|
|
1087
|
+
|
|
1088
|
+
if (backend_gpu) {
|
|
1089
|
+
return backend_gpu;
|
|
1090
|
+
}
|
|
1091
|
+
return wsp_ggml_backend_cpu_init();
|
|
875
1092
|
}
|
|
876
1093
|
|
|
877
1094
|
// load the model from a ggml file
|
|
@@ -886,7 +1103,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
|
|
|
886
1103
|
// see the convert-pt-to-ggml.py script for details
|
|
887
1104
|
//
|
|
888
1105
|
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
|
|
889
|
-
|
|
1106
|
+
WHISPER_LOG_INFO("%s: loading model\n", __func__);
|
|
890
1107
|
|
|
891
1108
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
892
1109
|
|
|
@@ -900,7 +1117,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
900
1117
|
uint32_t magic;
|
|
901
1118
|
read_safe(loader, magic);
|
|
902
1119
|
if (magic != WSP_GGML_FILE_MAGIC) {
|
|
903
|
-
|
|
1120
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
|
904
1121
|
return false;
|
|
905
1122
|
}
|
|
906
1123
|
}
|
|
@@ -923,6 +1140,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
923
1140
|
|
|
924
1141
|
assert(hparams.n_text_state == hparams.n_audio_state);
|
|
925
1142
|
|
|
1143
|
+
std::string mver = "";
|
|
1144
|
+
|
|
926
1145
|
if (hparams.n_audio_layer == 4) {
|
|
927
1146
|
model.type = e_model::MODEL_TINY;
|
|
928
1147
|
}
|
|
@@ -941,6 +1160,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
941
1160
|
|
|
942
1161
|
if (hparams.n_audio_layer == 32) {
|
|
943
1162
|
model.type = e_model::MODEL_LARGE;
|
|
1163
|
+
|
|
1164
|
+
if (hparams.n_vocab == 51866) {
|
|
1165
|
+
mver = " v3";
|
|
1166
|
+
}
|
|
944
1167
|
}
|
|
945
1168
|
|
|
946
1169
|
const int32_t qntvr = hparams.ftype / WSP_GGML_QNT_VERSION_FACTOR;
|
|
@@ -951,41 +1174,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
951
1174
|
// in order to save memory and also to speed up the computation
|
|
952
1175
|
wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype));
|
|
953
1176
|
if (wctx.wtype == WSP_GGML_TYPE_COUNT) {
|
|
954
|
-
|
|
1177
|
+
WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
|
|
955
1178
|
return false;
|
|
956
1179
|
}
|
|
957
1180
|
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
log("%s: qntvr = %d\n", __func__, qntvr);
|
|
972
|
-
log("%s: type = %d\n", __func__, model.type);
|
|
973
|
-
|
|
974
|
-
// print memory requirements
|
|
975
|
-
{
|
|
976
|
-
// TODO
|
|
977
|
-
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
978
|
-
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
979
|
-
}
|
|
980
|
-
|
|
981
|
-
// initialize all memory buffers
|
|
982
|
-
// always have at least one decoder
|
|
983
|
-
|
|
984
|
-
wctx.model.buf = new std::vector<uint8_t>();
|
|
985
|
-
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
|
|
986
|
-
|
|
987
|
-
// we skip initialization of the state until it is needed
|
|
988
|
-
// because it might be that state will always be provided externally.
|
|
1181
|
+
WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
|
1182
|
+
WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
|
1183
|
+
WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
|
1184
|
+
WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
|
1185
|
+
WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
|
1186
|
+
WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
|
|
1187
|
+
WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
|
|
1188
|
+
WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
|
1189
|
+
WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
|
1190
|
+
WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
|
|
1191
|
+
WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype);
|
|
1192
|
+
WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
|
|
1193
|
+
WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
|
|
989
1194
|
}
|
|
990
1195
|
|
|
991
1196
|
// load mel filters
|
|
@@ -1006,7 +1211,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1006
1211
|
read_safe(loader, n_vocab);
|
|
1007
1212
|
|
|
1008
1213
|
//if (n_vocab != model.hparams.n_vocab) {
|
|
1009
|
-
//
|
|
1214
|
+
// WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
|
1010
1215
|
// __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
|
1011
1216
|
// return false;
|
|
1012
1217
|
//}
|
|
@@ -1026,7 +1231,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1026
1231
|
word.assign(&tmp[0], tmp.size());
|
|
1027
1232
|
} else {
|
|
1028
1233
|
// seems like we have an empty-string token in multi-language models (i = 50256)
|
|
1029
|
-
//
|
|
1234
|
+
//WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
|
1030
1235
|
word = "";
|
|
1031
1236
|
}
|
|
1032
1237
|
|
|
@@ -1040,17 +1245,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1040
1245
|
if (vocab.is_multilingual()) {
|
|
1041
1246
|
vocab.token_eot++;
|
|
1042
1247
|
vocab.token_sot++;
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
vocab.
|
|
1046
|
-
|
|
1047
|
-
vocab.
|
|
1048
|
-
vocab.
|
|
1049
|
-
vocab.
|
|
1248
|
+
|
|
1249
|
+
// account for variable number of language tokens
|
|
1250
|
+
const int dt = vocab.num_languages() - 98;
|
|
1251
|
+
|
|
1252
|
+
vocab.token_translate += dt;
|
|
1253
|
+
vocab.token_transcribe += dt;
|
|
1254
|
+
vocab.token_solm += dt;
|
|
1255
|
+
vocab.token_prev += dt;
|
|
1256
|
+
vocab.token_nosp += dt;
|
|
1257
|
+
vocab.token_not += dt;
|
|
1258
|
+
vocab.token_beg += dt;
|
|
1050
1259
|
}
|
|
1051
1260
|
|
|
1052
1261
|
if (n_vocab < model.hparams.n_vocab) {
|
|
1053
|
-
|
|
1262
|
+
WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
|
|
1054
1263
|
for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
|
|
1055
1264
|
if (i > vocab.token_beg) {
|
|
1056
1265
|
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
|
|
@@ -1058,6 +1267,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1058
1267
|
word = "[_EOT_]";
|
|
1059
1268
|
} else if (i == vocab.token_sot) {
|
|
1060
1269
|
word = "[_SOT_]";
|
|
1270
|
+
} else if (i == vocab.token_translate) {
|
|
1271
|
+
word = "[_TRANSLATE_]";
|
|
1272
|
+
} else if (i == vocab.token_transcribe) {
|
|
1273
|
+
word = "[_TRANSCRIBE_]";
|
|
1061
1274
|
} else if (i == vocab.token_solm) {
|
|
1062
1275
|
word = "[_SOLM_]";
|
|
1063
1276
|
} else if (i == vocab.token_prev) {
|
|
@@ -1068,6 +1281,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1068
1281
|
word = "[_NOT_]";
|
|
1069
1282
|
} else if (i == vocab.token_beg) {
|
|
1070
1283
|
word = "[_BEG_]";
|
|
1284
|
+
} else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
|
|
1285
|
+
word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
|
|
1071
1286
|
} else {
|
|
1072
1287
|
word = "[_extra_token_" + std::to_string(i) + "]";
|
|
1073
1288
|
}
|
|
@@ -1075,139 +1290,36 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1075
1290
|
vocab.id_to_token[i] = word;
|
|
1076
1291
|
}
|
|
1077
1292
|
}
|
|
1078
|
-
}
|
|
1079
1293
|
|
|
1080
|
-
|
|
1294
|
+
WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages());
|
|
1295
|
+
}
|
|
1081
1296
|
|
|
1082
1297
|
const wsp_ggml_type wtype = wctx.wtype;
|
|
1083
1298
|
const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
|
|
1084
1299
|
|
|
1300
|
+
// create the ggml context
|
|
1085
1301
|
{
|
|
1086
1302
|
const auto & hparams = model.hparams;
|
|
1087
1303
|
|
|
1088
|
-
const int n_vocab = hparams.n_vocab;
|
|
1089
|
-
|
|
1090
|
-
const int n_audio_ctx = hparams.n_audio_ctx;
|
|
1091
|
-
const int n_audio_state = hparams.n_audio_state;
|
|
1092
1304
|
const int n_audio_layer = hparams.n_audio_layer;
|
|
1305
|
+
const int n_text_layer = hparams.n_text_layer;
|
|
1093
1306
|
|
|
1094
|
-
const
|
|
1095
|
-
const int n_text_state = hparams.n_text_state;
|
|
1096
|
-
const int n_text_layer = hparams.n_text_layer;
|
|
1097
|
-
|
|
1098
|
-
const int n_mels = hparams.n_mels;
|
|
1099
|
-
|
|
1100
|
-
// encoder
|
|
1101
|
-
{
|
|
1102
|
-
ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe;
|
|
1103
|
-
|
|
1104
|
-
ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w
|
|
1105
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b
|
|
1106
|
-
|
|
1107
|
-
ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w
|
|
1108
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b
|
|
1109
|
-
|
|
1110
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w;
|
|
1111
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b;
|
|
1112
|
-
}
|
|
1113
|
-
|
|
1114
|
-
// decoder
|
|
1115
|
-
{
|
|
1116
|
-
ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe;
|
|
1117
|
-
|
|
1118
|
-
ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te;
|
|
1119
|
-
|
|
1120
|
-
ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w;
|
|
1121
|
-
ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b;
|
|
1122
|
-
}
|
|
1123
|
-
|
|
1124
|
-
// encoder layers
|
|
1125
|
-
{
|
|
1126
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
|
|
1127
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
|
|
1128
|
-
|
|
1129
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
|
|
1130
|
-
ctx_size += n_audio_layer*( 4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
|
|
1131
|
-
|
|
1132
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
|
|
1133
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
|
|
1134
|
-
|
|
1135
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
|
|
1136
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
|
|
1137
|
-
|
|
1138
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
|
|
1139
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
|
|
1307
|
+
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
|
1140
1308
|
|
|
1141
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
|
|
1142
|
-
|
|
1143
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
|
|
1144
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
|
|
1145
|
-
|
|
1146
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1147
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
|
|
1148
|
-
}
|
|
1149
|
-
|
|
1150
|
-
// decoder layers
|
|
1151
|
-
{
|
|
1152
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
|
|
1153
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
|
|
1154
|
-
|
|
1155
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
|
|
1156
|
-
ctx_size += n_text_layer*( 4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
|
|
1157
|
-
|
|
1158
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
|
|
1159
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
|
|
1160
|
-
|
|
1161
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
|
|
1162
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
|
|
1163
|
-
|
|
1164
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
|
|
1165
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
|
|
1166
|
-
|
|
1167
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
|
|
1168
|
-
|
|
1169
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
|
|
1170
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
|
|
1171
|
-
|
|
1172
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1173
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
|
|
1174
|
-
//
|
|
1175
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w
|
|
1176
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b
|
|
1177
|
-
|
|
1178
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w
|
|
1179
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b
|
|
1180
|
-
|
|
1181
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w
|
|
1182
|
-
|
|
1183
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w
|
|
1184
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b
|
|
1185
|
-
|
|
1186
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w
|
|
1187
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
1188
|
-
}
|
|
1189
|
-
|
|
1190
|
-
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
|
|
1191
|
-
|
|
1192
|
-
log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
|
1193
|
-
}
|
|
1194
|
-
|
|
1195
|
-
// create the ggml context
|
|
1196
|
-
{
|
|
1197
1309
|
struct wsp_ggml_init_params params = {
|
|
1198
|
-
/*.mem_size =*/
|
|
1199
|
-
/*.mem_buffer =*/
|
|
1200
|
-
/*.no_alloc =*/
|
|
1310
|
+
/*.mem_size =*/ n_tensors*wsp_ggml_tensor_overhead(),
|
|
1311
|
+
/*.mem_buffer =*/ nullptr,
|
|
1312
|
+
/*.no_alloc =*/ true,
|
|
1201
1313
|
};
|
|
1202
1314
|
|
|
1203
1315
|
model.ctx = wsp_ggml_init(params);
|
|
1204
1316
|
if (!model.ctx) {
|
|
1205
|
-
|
|
1317
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__);
|
|
1206
1318
|
return false;
|
|
1207
1319
|
}
|
|
1208
1320
|
}
|
|
1209
1321
|
|
|
1210
|
-
// prepare
|
|
1322
|
+
// prepare tensors for the weights
|
|
1211
1323
|
{
|
|
1212
1324
|
auto & ctx = model.ctx;
|
|
1213
1325
|
|
|
@@ -1230,16 +1342,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1230
1342
|
|
|
1231
1343
|
// encoder
|
|
1232
1344
|
{
|
|
1233
|
-
model.e_pe
|
|
1345
|
+
model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
1234
1346
|
|
|
1235
|
-
model.e_conv_1_w
|
|
1236
|
-
model.e_conv_1_b
|
|
1347
|
+
model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
|
1348
|
+
model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1237
1349
|
|
|
1238
|
-
model.e_conv_2_w
|
|
1239
|
-
model.e_conv_2_b
|
|
1350
|
+
model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
|
1351
|
+
model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1240
1352
|
|
|
1241
|
-
model.e_ln_w
|
|
1242
|
-
model.e_ln_b
|
|
1353
|
+
model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1354
|
+
model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1243
1355
|
|
|
1244
1356
|
// map by name
|
|
1245
1357
|
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
|
@@ -1403,12 +1515,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1403
1515
|
}
|
|
1404
1516
|
}
|
|
1405
1517
|
|
|
1518
|
+
wctx.backend = whisper_backend_init(wctx.params);
|
|
1519
|
+
|
|
1520
|
+
{
|
|
1521
|
+
size_t size_main = 0;
|
|
1522
|
+
|
|
1523
|
+
for (const auto & t : model.tensors) {
|
|
1524
|
+
size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
|
|
1525
|
+
}
|
|
1526
|
+
|
|
1527
|
+
model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
|
|
1528
|
+
|
|
1529
|
+
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
|
|
1533
|
+
|
|
1534
|
+
// allocate tensors in the backend buffers
|
|
1535
|
+
{
|
|
1536
|
+
for (const auto & t : model.tensors) {
|
|
1537
|
+
wsp_ggml_allocr_alloc(alloc, t.second);
|
|
1538
|
+
}
|
|
1539
|
+
}
|
|
1540
|
+
|
|
1406
1541
|
// load weights
|
|
1407
1542
|
{
|
|
1408
1543
|
size_t total_size = 0;
|
|
1409
1544
|
|
|
1410
1545
|
model.n_loaded = 0;
|
|
1411
1546
|
|
|
1547
|
+
std::vector<char> read_buf;
|
|
1548
|
+
|
|
1412
1549
|
while (true) {
|
|
1413
1550
|
int32_t n_dims;
|
|
1414
1551
|
int32_t length;
|
|
@@ -1435,20 +1572,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1435
1572
|
name.assign(&tmp[0], tmp.size());
|
|
1436
1573
|
|
|
1437
1574
|
if (model.tensors.find(name) == model.tensors.end()) {
|
|
1438
|
-
|
|
1575
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
1439
1576
|
return false;
|
|
1440
1577
|
}
|
|
1441
1578
|
|
|
1442
1579
|
auto tensor = model.tensors[name.data()];
|
|
1580
|
+
|
|
1443
1581
|
if (wsp_ggml_nelements(tensor) != nelements) {
|
|
1444
|
-
|
|
1445
|
-
|
|
1582
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
|
1583
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
|
1446
1584
|
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
|
1447
1585
|
return false;
|
|
1448
1586
|
}
|
|
1449
1587
|
|
|
1450
1588
|
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
|
1451
|
-
|
|
1589
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
|
1452
1590
|
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
|
1453
1591
|
return false;
|
|
1454
1592
|
}
|
|
@@ -1456,29 +1594,49 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1456
1594
|
const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
|
|
1457
1595
|
|
|
1458
1596
|
if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
|
|
1459
|
-
|
|
1597
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
1460
1598
|
__func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
|
|
1461
1599
|
return false;
|
|
1462
1600
|
}
|
|
1463
1601
|
|
|
1464
|
-
|
|
1465
|
-
BYTESWAP_TENSOR(tensor);
|
|
1602
|
+
wsp_ggml_backend_t backend = wctx.backend;
|
|
1466
1603
|
|
|
1467
|
-
//printf("%
|
|
1604
|
+
//printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
|
|
1605
|
+
|
|
1606
|
+
if ((wsp_ggml_backend_is_cpu(backend)
|
|
1607
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1608
|
+
|| wsp_ggml_backend_is_metal(backend)
|
|
1609
|
+
#endif
|
|
1610
|
+
)) {
|
|
1611
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1612
|
+
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
1613
|
+
BYTESWAP_TENSOR(tensor);
|
|
1614
|
+
} else {
|
|
1615
|
+
// read into a temporary buffer first, then copy to device memory
|
|
1616
|
+
read_buf.resize(wsp_ggml_nbytes(tensor));
|
|
1617
|
+
|
|
1618
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
|
1619
|
+
|
|
1620
|
+
wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
|
|
1621
|
+
}
|
|
1622
|
+
|
|
1623
|
+
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
|
|
1468
1624
|
total_size += wsp_ggml_nbytes(tensor);
|
|
1469
1625
|
model.n_loaded++;
|
|
1470
1626
|
}
|
|
1471
1627
|
|
|
1472
|
-
|
|
1628
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
|
1473
1629
|
|
|
1474
1630
|
if (model.n_loaded == 0) {
|
|
1475
|
-
|
|
1631
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
|
1476
1632
|
} else if (model.n_loaded != (int) model.tensors.size()) {
|
|
1477
|
-
|
|
1633
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
|
1478
1634
|
return false;
|
|
1479
1635
|
}
|
|
1480
1636
|
}
|
|
1481
1637
|
|
|
1638
|
+
wsp_ggml_allocr_free(alloc);
|
|
1639
|
+
|
|
1482
1640
|
wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
|
|
1483
1641
|
|
|
1484
1642
|
return true;
|
|
@@ -1534,10 +1692,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1534
1692
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1535
1693
|
assert(mel_inp.n_mel == n_mels);
|
|
1536
1694
|
|
|
1537
|
-
|
|
1695
|
+
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
1696
|
+
|
|
1697
|
+
float * dst = wstate.inp_mel.data();
|
|
1538
1698
|
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1539
1699
|
|
|
1540
|
-
const int i0 = std::min(mel_offset,
|
|
1700
|
+
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
1541
1701
|
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
1542
1702
|
|
|
1543
1703
|
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
@@ -1545,6 +1705,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1545
1705
|
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
1546
1706
|
}
|
|
1547
1707
|
}
|
|
1708
|
+
|
|
1709
|
+
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
1548
1710
|
}
|
|
1549
1711
|
|
|
1550
1712
|
struct wsp_ggml_tensor * cur = nullptr;
|
|
@@ -1553,24 +1715,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1553
1715
|
// convolution + gelu
|
|
1554
1716
|
{
|
|
1555
1717
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
|
1556
|
-
cur = wsp_ggml_add(ctx0,
|
|
1557
|
-
wsp_ggml_repeat(ctx0,
|
|
1558
|
-
model.e_conv_1_b,
|
|
1559
|
-
cur),
|
|
1560
|
-
cur);
|
|
1718
|
+
cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);
|
|
1561
1719
|
|
|
1562
1720
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1563
1721
|
|
|
1564
1722
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
|
1565
|
-
cur = wsp_ggml_add(ctx0,
|
|
1566
|
-
wsp_ggml_repeat(ctx0,
|
|
1567
|
-
model.e_conv_2_b,
|
|
1568
|
-
cur),
|
|
1569
|
-
cur);
|
|
1723
|
+
cur = wsp_ggml_add(ctx0, cur, model.e_conv_2_b);
|
|
1570
1724
|
|
|
1571
1725
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1572
1726
|
}
|
|
1573
1727
|
|
|
1728
|
+
wsp_ggml_set_name(cur, "embd_conv");
|
|
1574
1729
|
wstate.embd_conv = cur;
|
|
1575
1730
|
} else {
|
|
1576
1731
|
#ifdef WHISPER_USE_COREML
|
|
@@ -1578,7 +1733,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1578
1733
|
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1579
1734
|
|
|
1580
1735
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1581
|
-
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
|
1736
|
+
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
|
1582
1737
|
}
|
|
1583
1738
|
#endif
|
|
1584
1739
|
#ifdef WHISPER_USE_OPENVINO
|
|
@@ -1590,6 +1745,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1590
1745
|
}
|
|
1591
1746
|
#endif
|
|
1592
1747
|
|
|
1748
|
+
wsp_ggml_set_name(cur, "embd_enc");
|
|
1593
1749
|
wstate.embd_enc = cur;
|
|
1594
1750
|
}
|
|
1595
1751
|
|
|
@@ -1619,19 +1775,26 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1619
1775
|
|
|
1620
1776
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
1621
1777
|
|
|
1622
|
-
wsp_ggml_cgraph * gf =
|
|
1778
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
1623
1779
|
|
|
1624
1780
|
wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1625
1781
|
|
|
1782
|
+
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
|
|
1783
|
+
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
1784
|
+
|
|
1785
|
+
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1786
|
+
// wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
|
|
1787
|
+
//}
|
|
1788
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1789
|
+
|
|
1626
1790
|
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1627
1791
|
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
1628
1792
|
|
|
1629
1793
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1630
|
-
|
|
1794
|
+
const float val = 1.0f/sqrtf(float(n_state)/n_head);
|
|
1795
|
+
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
1631
1796
|
}
|
|
1632
1797
|
|
|
1633
|
-
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1634
|
-
|
|
1635
1798
|
// ===================================================================
|
|
1636
1799
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1637
1800
|
//static int iter = -1;
|
|
@@ -1650,7 +1813,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1650
1813
|
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1651
1814
|
|
|
1652
1815
|
struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
1653
|
-
|
|
1654
1816
|
cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
|
|
1655
1817
|
|
|
1656
1818
|
// ===================================================================
|
|
@@ -1838,11 +2000,11 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1838
2000
|
////////////////////////////////////////////////////////////////////////////
|
|
1839
2001
|
|
|
1840
2002
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
1841
|
-
// wsp_ggml_used_mem(ctx0)/
|
|
1842
|
-
// wstate.get_buf_max_mem(0)/
|
|
1843
|
-
// wstate.get_buf_max_mem(1)/
|
|
1844
|
-
// wstate.get_buf_max_mem(2)/
|
|
1845
|
-
// wstate.get_buf_max_mem(3)/
|
|
2003
|
+
// wsp_ggml_used_mem(ctx0)/1e6,
|
|
2004
|
+
// wstate.get_buf_max_mem(0)/1e6,
|
|
2005
|
+
// wstate.get_buf_max_mem(1)/1e6,
|
|
2006
|
+
// wstate.get_buf_max_mem(2)/1e6,
|
|
2007
|
+
// wstate.get_buf_max_mem(3)/1e6);
|
|
1846
2008
|
|
|
1847
2009
|
wsp_ggml_free(ctx0);
|
|
1848
2010
|
|
|
@@ -1872,13 +2034,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
1872
2034
|
|
|
1873
2035
|
wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
1874
2036
|
|
|
2037
|
+
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
2038
|
+
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
2039
|
+
|
|
2040
|
+
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2041
|
+
// wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
|
|
2042
|
+
//}
|
|
1875
2043
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
1876
2044
|
|
|
1877
2045
|
struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1878
2046
|
wsp_ggml_allocr_alloc(alloc, Kscale);
|
|
1879
2047
|
|
|
1880
2048
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1881
|
-
|
|
2049
|
+
const float val = pow(float(n_state) / n_head, -0.25);
|
|
2050
|
+
wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
|
|
1882
2051
|
}
|
|
1883
2052
|
|
|
1884
2053
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
@@ -1949,7 +2118,7 @@ static bool whisper_encode_internal(
|
|
|
1949
2118
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1950
2119
|
|
|
1951
2120
|
if (!whisper_encode_external(wstate)) {
|
|
1952
|
-
wsp_ggml_graph_compute_helper(wstate.
|
|
2121
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
1953
2122
|
}
|
|
1954
2123
|
}
|
|
1955
2124
|
|
|
@@ -1963,16 +2132,7 @@ static bool whisper_encode_internal(
|
|
|
1963
2132
|
|
|
1964
2133
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1965
2134
|
|
|
1966
|
-
|
|
1967
|
-
if (wstate.ctx_metal) {
|
|
1968
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1969
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1970
|
-
} else {
|
|
1971
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1972
|
-
}
|
|
1973
|
-
#else
|
|
1974
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1975
|
-
#endif
|
|
2135
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
1976
2136
|
}
|
|
1977
2137
|
|
|
1978
2138
|
// cross
|
|
@@ -1985,49 +2145,40 @@ static bool whisper_encode_internal(
|
|
|
1985
2145
|
|
|
1986
2146
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1987
2147
|
|
|
1988
|
-
|
|
1989
|
-
if (wstate.ctx_metal) {
|
|
1990
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1991
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1992
|
-
} else {
|
|
1993
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1994
|
-
}
|
|
1995
|
-
#else
|
|
1996
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1997
|
-
#endif
|
|
2148
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
1998
2149
|
}
|
|
1999
2150
|
|
|
2000
|
-
// wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
2001
|
-
|
|
2002
2151
|
wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
|
|
2003
2152
|
wstate.n_encode++;
|
|
2004
2153
|
|
|
2005
|
-
return
|
|
2154
|
+
return !(abort_callback && abort_callback(abort_callback_data));
|
|
2006
2155
|
}
|
|
2007
2156
|
|
|
2008
2157
|
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2009
2158
|
whisper_context & wctx,
|
|
2010
2159
|
whisper_state & wstate,
|
|
2011
|
-
|
|
2012
|
-
const whisper_token * tokens,
|
|
2013
|
-
int n_tokens,
|
|
2014
|
-
int n_past) {
|
|
2160
|
+
const whisper_batch & batch) {
|
|
2015
2161
|
const auto & model = wctx.model;
|
|
2016
2162
|
const auto & hparams = model.hparams;
|
|
2017
2163
|
|
|
2018
|
-
auto & kv_self =
|
|
2164
|
+
auto & kv_self = wstate.kv_self;
|
|
2019
2165
|
|
|
2020
2166
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
2021
2167
|
|
|
2022
|
-
|
|
2168
|
+
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
2169
|
+
|
|
2170
|
+
const int n_ctx = kv_self.size;
|
|
2023
2171
|
const int n_state = hparams.n_text_state;
|
|
2024
2172
|
const int n_head = hparams.n_text_head;
|
|
2025
2173
|
const int n_layer = hparams.n_text_layer;
|
|
2026
2174
|
|
|
2027
|
-
const int
|
|
2028
|
-
const int
|
|
2175
|
+
const int n_tokens = batch.n_tokens;
|
|
2176
|
+
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
2029
2177
|
|
|
2030
|
-
|
|
2178
|
+
const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
|
|
2179
|
+
const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
|
|
2180
|
+
|
|
2181
|
+
//WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
|
2031
2182
|
|
|
2032
2183
|
struct wsp_ggml_init_params params = {
|
|
2033
2184
|
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
|
@@ -2037,23 +2188,22 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2037
2188
|
|
|
2038
2189
|
struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
|
|
2039
2190
|
|
|
2040
|
-
wsp_ggml_cgraph * gf =
|
|
2191
|
+
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
2041
2192
|
|
|
2042
|
-
|
|
2043
|
-
|
|
2044
|
-
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
2193
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2045
2194
|
wsp_ggml_allocr_alloc(alloc, embd);
|
|
2046
2195
|
|
|
2047
2196
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2048
|
-
|
|
2197
|
+
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2049
2198
|
}
|
|
2050
2199
|
|
|
2051
|
-
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32,
|
|
2200
|
+
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2052
2201
|
wsp_ggml_allocr_alloc(alloc, position);
|
|
2053
2202
|
|
|
2054
2203
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2055
|
-
for (int i = 0; i <
|
|
2056
|
-
|
|
2204
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
2205
|
+
const int32_t val = batch.pos[i];
|
|
2206
|
+
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2057
2207
|
}
|
|
2058
2208
|
}
|
|
2059
2209
|
|
|
@@ -2061,7 +2211,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2061
2211
|
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
2062
2212
|
|
|
2063
2213
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2064
|
-
|
|
2214
|
+
const float val = pow(float(n_state)/n_head, -0.25);
|
|
2215
|
+
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
2216
|
+
}
|
|
2217
|
+
|
|
2218
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
|
|
2219
|
+
wsp_ggml_allocr_alloc(alloc, KQ_mask);
|
|
2220
|
+
|
|
2221
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2222
|
+
wstate.inp_mask.resize(n_kv*n_tokens);
|
|
2223
|
+
|
|
2224
|
+
float * data = wstate.inp_mask.data();
|
|
2225
|
+
memset(data, 0, wsp_ggml_nbytes(KQ_mask));
|
|
2226
|
+
|
|
2227
|
+
for (int h = 0; h < 1; ++h) {
|
|
2228
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
2229
|
+
const whisper_pos pos = batch.pos[j];
|
|
2230
|
+
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2231
|
+
|
|
2232
|
+
for (int i = 0; i < n_kv; ++i) {
|
|
2233
|
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
|
2234
|
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
|
2235
|
+
}
|
|
2236
|
+
}
|
|
2237
|
+
}
|
|
2238
|
+
}
|
|
2239
|
+
|
|
2240
|
+
wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
|
|
2065
2241
|
}
|
|
2066
2242
|
|
|
2067
2243
|
// token encoding + position encoding
|
|
@@ -2116,12 +2292,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2116
2292
|
Vcur,
|
|
2117
2293
|
layer.attn_v_b);
|
|
2118
2294
|
|
|
2119
|
-
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state,
|
|
2295
|
+
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
|
2120
2296
|
|
|
2121
|
-
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k,
|
|
2122
|
-
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v,
|
|
2297
|
+
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2298
|
+
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
|
2123
2299
|
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2124
|
-
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state +
|
|
2300
|
+
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
|
|
2125
2301
|
|
|
2126
2302
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2127
2303
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
@@ -2131,12 +2307,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2131
2307
|
|
|
2132
2308
|
struct wsp_ggml_tensor * Q =
|
|
2133
2309
|
wsp_ggml_permute(ctx0,
|
|
2134
|
-
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
|
|
2310
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
|
2135
2311
|
0, 2, 1, 3);
|
|
2136
2312
|
|
|
2137
2313
|
struct wsp_ggml_tensor * K =
|
|
2138
2314
|
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2139
|
-
n_state/n_head,
|
|
2315
|
+
n_state/n_head, n_kv, n_head,
|
|
2140
2316
|
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2141
2317
|
wsp_ggml_element_size(kv_self.k)*n_state/n_head,
|
|
2142
2318
|
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
@@ -2146,16 +2322,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2146
2322
|
|
|
2147
2323
|
//struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
|
|
2148
2324
|
|
|
2149
|
-
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2325
|
+
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2326
|
+
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
|
|
2150
2327
|
|
|
2151
2328
|
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
|
|
2152
2329
|
|
|
2153
2330
|
struct wsp_ggml_tensor * V =
|
|
2154
2331
|
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2155
|
-
|
|
2332
|
+
n_kv, n_state/n_head, n_head,
|
|
2156
2333
|
n_ctx*wsp_ggml_element_size(kv_self.v),
|
|
2157
2334
|
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
|
|
2158
|
-
|
|
2335
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
|
|
2159
2336
|
|
|
2160
2337
|
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2161
2338
|
|
|
@@ -2163,7 +2340,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2163
2340
|
|
|
2164
2341
|
cur = wsp_ggml_cpy(ctx0,
|
|
2165
2342
|
KQV_merged,
|
|
2166
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state,
|
|
2343
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2167
2344
|
}
|
|
2168
2345
|
|
|
2169
2346
|
// projection
|
|
@@ -2207,33 +2384,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2207
2384
|
// Kcross is already scaled
|
|
2208
2385
|
struct wsp_ggml_tensor * Kcross =
|
|
2209
2386
|
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2210
|
-
n_state/n_head,
|
|
2387
|
+
n_state/n_head, n_audio_ctx, n_head,
|
|
2211
2388
|
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2212
2389
|
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2213
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*
|
|
2390
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2214
2391
|
|
|
2215
2392
|
//struct wsp_ggml_tensor * Vcross =
|
|
2216
2393
|
// wsp_ggml_reshape_3d(ctx0,
|
|
2217
|
-
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v,
|
|
2218
|
-
// n_state/n_head, n_head,
|
|
2394
|
+
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2395
|
+
// n_state/n_head, n_head, n_audio_ctx);
|
|
2219
2396
|
|
|
2220
2397
|
//struct wsp_ggml_tensor * V_trans =
|
|
2221
2398
|
// wsp_ggml_cpy(ctx0,
|
|
2222
2399
|
// wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
|
2223
|
-
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type,
|
|
2400
|
+
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
|
2224
2401
|
|
|
2225
2402
|
struct wsp_ggml_tensor * V =
|
|
2226
2403
|
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2227
|
-
|
|
2228
|
-
|
|
2229
|
-
|
|
2230
|
-
|
|
2404
|
+
n_audio_ctx, n_state/n_head, n_head,
|
|
2405
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2406
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
|
2407
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2231
2408
|
|
|
2232
2409
|
// ------
|
|
2233
2410
|
|
|
2234
2411
|
struct wsp_ggml_tensor * Q =
|
|
2235
2412
|
wsp_ggml_permute(ctx0,
|
|
2236
|
-
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
|
|
2413
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
|
2237
2414
|
0, 2, 1, 3);
|
|
2238
2415
|
|
|
2239
2416
|
// K * Q
|
|
@@ -2254,10 +2431,10 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2254
2431
|
|
|
2255
2432
|
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2256
2433
|
|
|
2257
|
-
// cur = KQV_merged.contiguous().view(n_state,
|
|
2434
|
+
// cur = KQV_merged.contiguous().view(n_state, n_tokens)
|
|
2258
2435
|
cur = wsp_ggml_cpy(ctx0,
|
|
2259
2436
|
KQV_merged,
|
|
2260
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state,
|
|
2437
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2261
2438
|
}
|
|
2262
2439
|
|
|
2263
2440
|
// projection
|
|
@@ -2329,9 +2506,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2329
2506
|
}
|
|
2330
2507
|
|
|
2331
2508
|
// compute logits only for the last token
|
|
2332
|
-
// comment this line to compute logits for all
|
|
2509
|
+
// comment this line to compute logits for all n_tokens
|
|
2333
2510
|
// might be useful in the future
|
|
2334
|
-
cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
|
2511
|
+
//cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
|
2335
2512
|
|
|
2336
2513
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2337
2514
|
|
|
@@ -2355,10 +2532,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2355
2532
|
static bool whisper_decode_internal(
|
|
2356
2533
|
whisper_context & wctx,
|
|
2357
2534
|
whisper_state & wstate,
|
|
2358
|
-
|
|
2359
|
-
const whisper_token * tokens,
|
|
2360
|
-
const int n_tokens,
|
|
2361
|
-
const int n_past,
|
|
2535
|
+
const whisper_batch & batch,
|
|
2362
2536
|
const int n_threads,
|
|
2363
2537
|
whisper_abort_callback abort_callback,
|
|
2364
2538
|
void * abort_callback_data) {
|
|
@@ -2367,65 +2541,72 @@ static bool whisper_decode_internal(
|
|
|
2367
2541
|
const auto & model = wctx.model;
|
|
2368
2542
|
const auto & hparams = model.hparams;
|
|
2369
2543
|
|
|
2370
|
-
const int n_vocab
|
|
2544
|
+
const int n_vocab = hparams.n_vocab;
|
|
2545
|
+
const int n_tokens = batch.n_tokens;
|
|
2371
2546
|
|
|
2372
2547
|
auto & logits_out = wstate.logits;
|
|
2373
2548
|
|
|
2374
2549
|
struct wsp_ggml_tensor * logits;
|
|
2375
2550
|
|
|
2551
|
+
// find KV slot for the batch
|
|
2552
|
+
{
|
|
2553
|
+
auto & kv_self = wstate.kv_self;
|
|
2554
|
+
|
|
2555
|
+
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
|
|
2556
|
+
return false;
|
|
2557
|
+
}
|
|
2558
|
+
|
|
2559
|
+
kv_self.n = whisper_kv_cache_cell_max(kv_self);
|
|
2560
|
+
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
|
2561
|
+
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
|
2562
|
+
}
|
|
2563
|
+
|
|
2376
2564
|
// decoder
|
|
2377
2565
|
{
|
|
2378
2566
|
auto & alloc = wstate.alloc_decode.alloc;
|
|
2379
2567
|
|
|
2380
2568
|
wsp_ggml_allocr_reset(alloc);
|
|
2381
2569
|
|
|
2382
|
-
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate,
|
|
2570
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
|
|
2383
2571
|
|
|
2384
2572
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
2385
2573
|
|
|
2386
2574
|
logits = gf->nodes[gf->n_nodes - 1];
|
|
2387
2575
|
|
|
2388
|
-
|
|
2389
|
-
if (wstate.ctx_metal) {
|
|
2390
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
2391
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
2392
|
-
} else {
|
|
2393
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2394
|
-
}
|
|
2395
|
-
#else
|
|
2396
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2397
|
-
#endif
|
|
2576
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
2398
2577
|
}
|
|
2399
2578
|
|
|
2400
|
-
|
|
2401
|
-
|
|
2402
|
-
|
|
2403
|
-
|
|
2404
|
-
|
|
2405
|
-
|
|
2406
|
-
|
|
2579
|
+
logits_out.resize(n_tokens*n_vocab);
|
|
2580
|
+
for (int i = 0; i < n_tokens; i++) {
|
|
2581
|
+
if (batch.logits[i] == 0) {
|
|
2582
|
+
continue;
|
|
2583
|
+
}
|
|
2584
|
+
wsp_ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
|
|
2585
|
+
}
|
|
2407
2586
|
|
|
2408
|
-
if (n_tokens > 1) {
|
|
2587
|
+
if (batch.n_tokens > 1) {
|
|
2409
2588
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
2410
|
-
// wsp_ggml_used_mem(ctx0)/
|
|
2411
|
-
// wstate.get_buf_max_mem(0)/
|
|
2412
|
-
// wstate.get_buf_max_mem(1)/
|
|
2413
|
-
// wstate.get_buf_max_mem(2)/
|
|
2414
|
-
// wstate.get_buf_max_mem(3)/
|
|
2589
|
+
// wsp_ggml_used_mem(ctx0)/1e6,
|
|
2590
|
+
// wstate.get_buf_max_mem(0)/1e6,
|
|
2591
|
+
// wstate.get_buf_max_mem(1)/1e6,
|
|
2592
|
+
// wstate.get_buf_max_mem(2)/1e6,
|
|
2593
|
+
// wstate.get_buf_max_mem(3)/1e6);
|
|
2415
2594
|
}
|
|
2416
2595
|
|
|
2417
|
-
if (n_tokens == 1) {
|
|
2596
|
+
if (batch.n_tokens == 1) {
|
|
2418
2597
|
wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
|
|
2419
2598
|
wstate.n_decode++;
|
|
2599
|
+
} else if (batch.n_tokens < 16) {
|
|
2600
|
+
wstate.t_batchd_us += wsp_ggml_time_us() - t_start_us;
|
|
2601
|
+
wstate.n_batchd += n_tokens;
|
|
2420
2602
|
} else {
|
|
2421
2603
|
wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
|
|
2422
|
-
wstate.n_prompt
|
|
2604
|
+
wstate.n_prompt += n_tokens;
|
|
2423
2605
|
}
|
|
2424
2606
|
|
|
2425
|
-
return
|
|
2607
|
+
return !(abort_callback && abort_callback(abort_callback_data));
|
|
2426
2608
|
}
|
|
2427
2609
|
|
|
2428
|
-
|
|
2429
2610
|
// 500 -> 00:05.000
|
|
2430
2611
|
// 6000 -> 01:00.000
|
|
2431
2612
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
@@ -2769,7 +2950,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
|
2769
2950
|
--j;
|
|
2770
2951
|
}
|
|
2771
2952
|
if (!found) {
|
|
2772
|
-
|
|
2953
|
+
WHISPER_LOG_ERROR("unknown token\n");
|
|
2773
2954
|
++i;
|
|
2774
2955
|
}
|
|
2775
2956
|
}
|
|
@@ -2832,94 +3013,105 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
|
2832
3013
|
|
|
2833
3014
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
2834
3015
|
fill_sin_cos_table();
|
|
3016
|
+
|
|
2835
3017
|
whisper_state * state = new whisper_state;
|
|
2836
3018
|
|
|
2837
|
-
|
|
2838
|
-
|
|
3019
|
+
state->backend = whisper_backend_init(ctx->params);
|
|
3020
|
+
|
|
3021
|
+
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
|
3022
|
+
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
|
3023
|
+
const int factor = 3;
|
|
3024
|
+
|
|
3025
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
|
|
3026
|
+
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
2839
3027
|
delete state;
|
|
2840
3028
|
return nullptr;
|
|
2841
3029
|
}
|
|
2842
3030
|
|
|
2843
3031
|
{
|
|
2844
|
-
const size_t memory_size = wsp_ggml_nbytes(state->
|
|
2845
|
-
|
|
3032
|
+
const size_t memory_size = wsp_ggml_nbytes(state->kv_self.k) + wsp_ggml_nbytes(state->kv_self.v);
|
|
3033
|
+
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
2846
3034
|
}
|
|
2847
3035
|
|
|
2848
|
-
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
2849
|
-
|
|
3036
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
3037
|
+
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
2850
3038
|
delete state;
|
|
2851
3039
|
return nullptr;
|
|
2852
3040
|
}
|
|
2853
3041
|
|
|
2854
3042
|
{
|
|
2855
3043
|
const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v);
|
|
2856
|
-
|
|
3044
|
+
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
2857
3045
|
}
|
|
2858
3046
|
|
|
3047
|
+
|
|
2859
3048
|
#ifdef WHISPER_USE_COREML
|
|
2860
|
-
if (ctx->
|
|
3049
|
+
if (ctx->params.use_coreml) {
|
|
2861
3050
|
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
|
2862
3051
|
|
|
2863
|
-
|
|
2864
|
-
|
|
3052
|
+
WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
3053
|
+
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
2865
3054
|
|
|
2866
3055
|
state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
|
|
2867
3056
|
if (!state->ctx_coreml) {
|
|
2868
|
-
|
|
3057
|
+
WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2869
3058
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
2870
3059
|
delete state;
|
|
2871
3060
|
return nullptr;
|
|
2872
3061
|
#endif
|
|
2873
3062
|
} else {
|
|
2874
|
-
|
|
3063
|
+
WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
|
|
3064
|
+
}
|
|
2875
3065
|
}
|
|
2876
|
-
}
|
|
2877
3066
|
#endif
|
|
2878
3067
|
|
|
2879
3068
|
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
|
2880
3069
|
|
|
2881
|
-
state->
|
|
3070
|
+
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
|
|
2882
3071
|
|
|
2883
3072
|
// TAGS: WHISPER_DECODER_INIT
|
|
2884
3073
|
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
|
2885
3074
|
|
|
2886
|
-
state->decoders[0].probs.reserve
|
|
2887
|
-
state->decoders[0].logits.reserve
|
|
2888
|
-
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
|
3075
|
+
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
|
3076
|
+
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
|
3077
|
+
state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
|
|
3078
|
+
state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
|
|
3079
|
+
|
|
3080
|
+
state->decoders[0].rng = std::mt19937(0);
|
|
2889
3081
|
|
|
2890
3082
|
// conv allocator
|
|
2891
3083
|
{
|
|
2892
|
-
whisper_allocr_graph_init(state->alloc_conv,
|
|
3084
|
+
whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
|
|
2893
3085
|
[&]() {
|
|
2894
3086
|
return whisper_build_graph_conv(*ctx, *state, 0);
|
|
2895
3087
|
});
|
|
2896
3088
|
|
|
2897
|
-
|
|
3089
|
+
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
|
|
2898
3090
|
}
|
|
2899
3091
|
|
|
2900
3092
|
// encoder allocator
|
|
2901
3093
|
if (!whisper_encode_external(*state)) {
|
|
2902
|
-
whisper_allocr_graph_init(state->alloc_encode,
|
|
3094
|
+
whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
|
|
2903
3095
|
[&]() {
|
|
2904
3096
|
return whisper_build_graph_encoder(*ctx, *state);
|
|
2905
3097
|
});
|
|
2906
3098
|
|
|
2907
|
-
|
|
3099
|
+
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
|
|
2908
3100
|
}
|
|
2909
3101
|
|
|
2910
3102
|
// cross allocator
|
|
2911
3103
|
{
|
|
2912
|
-
whisper_allocr_graph_init(state->alloc_cross,
|
|
3104
|
+
whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
|
|
2913
3105
|
[&]() {
|
|
2914
3106
|
return whisper_build_graph_cross(*ctx, *state);
|
|
2915
3107
|
});
|
|
2916
3108
|
|
|
2917
|
-
|
|
3109
|
+
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
|
|
2918
3110
|
}
|
|
2919
3111
|
|
|
2920
3112
|
// decoder allocator
|
|
2921
3113
|
{
|
|
2922
|
-
whisper_allocr_graph_init(state->alloc_decode,
|
|
3114
|
+
whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
|
|
2923
3115
|
[&]() {
|
|
2924
3116
|
const auto & hparams = ctx->model.hparams;
|
|
2925
3117
|
|
|
@@ -2927,90 +3119,22 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch
|
|
|
2927
3119
|
const int n_tokens = hparams.n_text_ctx;
|
|
2928
3120
|
const int n_past = 0;
|
|
2929
3121
|
|
|
2930
|
-
|
|
2931
|
-
});
|
|
2932
|
-
|
|
2933
|
-
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
|
2934
|
-
}
|
|
2935
|
-
|
|
2936
|
-
#ifdef WSP_GGML_USE_METAL
|
|
2937
|
-
state->ctx_metal = wsp_ggml_metal_init(1);
|
|
2938
|
-
if (!state->ctx_metal) {
|
|
2939
|
-
log("%s: wsp_ggml_metal_init() failed\n", __func__);
|
|
2940
|
-
delete state;
|
|
2941
|
-
return nullptr;
|
|
2942
|
-
}
|
|
2943
|
-
|
|
2944
|
-
log("%s: Metal context initialized\n", __func__);
|
|
2945
|
-
|
|
2946
|
-
// this allocates all Metal resources and memory buffers
|
|
2947
|
-
|
|
2948
|
-
void * data_ptr = NULL;
|
|
2949
|
-
size_t data_size = 0;
|
|
2950
|
-
|
|
2951
|
-
// TODO: add mmap support
|
|
2952
|
-
//if (params.use_mmap) {
|
|
2953
|
-
// data_ptr = ctx->model.mapping->addr;
|
|
2954
|
-
// data_size = ctx->model.mapping->size;
|
|
2955
|
-
//} else {
|
|
2956
|
-
// data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2957
|
-
// data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2958
|
-
//}
|
|
2959
|
-
|
|
2960
|
-
data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2961
|
-
data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
3122
|
+
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
|
2962
3123
|
|
|
2963
|
-
|
|
2964
|
-
|
|
2965
|
-
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
|
3124
|
+
return whisper_build_graph_decoder(*ctx, *state, state->batch);
|
|
3125
|
+
});
|
|
2966
3126
|
|
|
2967
|
-
|
|
2968
|
-
if (!(result)) { \
|
|
2969
|
-
log("%s: failed to add metal buffer\n", __func__); \
|
|
2970
|
-
delete state; \
|
|
2971
|
-
return nullptr; \
|
|
3127
|
+
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
|
|
2972
3128
|
}
|
|
2973
3129
|
|
|
2974
|
-
|
|
2975
|
-
|
|
2976
|
-
|
|
2977
|
-
|
|
2978
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
|
|
2979
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
|
|
2980
|
-
|
|
2981
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
|
|
2982
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
|
|
2983
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
|
|
2984
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
|
|
2985
|
-
|
|
2986
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
|
2987
|
-
|
|
2988
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
|
2989
|
-
#undef WHISPER_METAL_CHECK_BUF
|
|
2990
|
-
#endif
|
|
2991
|
-
|
|
2992
|
-
state->rng = std::mt19937(0);
|
|
3130
|
+
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
|
|
3131
|
+
whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
|
|
3132
|
+
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
|
3133
|
+
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
|
2993
3134
|
|
|
2994
3135
|
return state;
|
|
2995
3136
|
}
|
|
2996
3137
|
|
|
2997
|
-
#ifdef WHISPER_USE_COREML
|
|
2998
|
-
struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
|
|
2999
|
-
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
|
|
3000
|
-
if (!ctx) {
|
|
3001
|
-
return nullptr;
|
|
3002
|
-
}
|
|
3003
|
-
ctx->load_coreml = false;
|
|
3004
|
-
ctx->state = whisper_init_state(ctx);
|
|
3005
|
-
if (!ctx->state) {
|
|
3006
|
-
whisper_free(ctx);
|
|
3007
|
-
return nullptr;
|
|
3008
|
-
}
|
|
3009
|
-
|
|
3010
|
-
return ctx;
|
|
3011
|
-
}
|
|
3012
|
-
#endif
|
|
3013
|
-
|
|
3014
3138
|
int whisper_ctx_init_openvino_encoder(
|
|
3015
3139
|
struct whisper_context * ctx,
|
|
3016
3140
|
const char * model_path,
|
|
@@ -3025,7 +3149,7 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3025
3149
|
return 1;
|
|
3026
3150
|
#else
|
|
3027
3151
|
if (!model_path && ctx->path_model.empty()) {
|
|
3028
|
-
|
|
3152
|
+
WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
|
|
3029
3153
|
return 1;
|
|
3030
3154
|
}
|
|
3031
3155
|
|
|
@@ -3045,27 +3169,35 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3045
3169
|
path_cache = cache_dir;
|
|
3046
3170
|
}
|
|
3047
3171
|
|
|
3048
|
-
|
|
3049
|
-
|
|
3172
|
+
WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
|
|
3173
|
+
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
3050
3174
|
|
|
3051
3175
|
ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
|
|
3052
3176
|
if (!ctx->state->ctx_openvino) {
|
|
3053
|
-
|
|
3177
|
+
WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
|
|
3054
3178
|
return 1;
|
|
3055
3179
|
} else {
|
|
3056
|
-
|
|
3180
|
+
WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
|
|
3057
3181
|
}
|
|
3058
3182
|
|
|
3059
3183
|
return 0;
|
|
3060
3184
|
#endif
|
|
3061
3185
|
}
|
|
3062
3186
|
|
|
3063
|
-
struct
|
|
3064
|
-
|
|
3187
|
+
struct whisper_context_params whisper_context_default_params() {
|
|
3188
|
+
struct whisper_context_params result = {
|
|
3189
|
+
/*.use_gpu =*/ true,
|
|
3190
|
+
/*.use_coreml =*/ false,
|
|
3191
|
+
};
|
|
3192
|
+
return result;
|
|
3193
|
+
}
|
|
3194
|
+
|
|
3195
|
+
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
|
|
3196
|
+
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
|
3065
3197
|
|
|
3066
3198
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
3067
3199
|
if (!fin) {
|
|
3068
|
-
|
|
3200
|
+
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
|
3069
3201
|
return nullptr;
|
|
3070
3202
|
}
|
|
3071
3203
|
|
|
@@ -3089,7 +3221,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
|
|
|
3089
3221
|
fin->close();
|
|
3090
3222
|
};
|
|
3091
3223
|
|
|
3092
|
-
auto ctx =
|
|
3224
|
+
auto ctx = whisper_init_with_params_no_state(&loader, params);
|
|
3093
3225
|
|
|
3094
3226
|
if (ctx) {
|
|
3095
3227
|
ctx->path_model = path_model;
|
|
@@ -3098,7 +3230,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
|
|
|
3098
3230
|
return ctx;
|
|
3099
3231
|
}
|
|
3100
3232
|
|
|
3101
|
-
struct whisper_context *
|
|
3233
|
+
struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
|
|
3102
3234
|
struct buf_context {
|
|
3103
3235
|
uint8_t* buffer;
|
|
3104
3236
|
size_t size;
|
|
@@ -3107,7 +3239,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
|
|
|
3107
3239
|
|
|
3108
3240
|
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
|
3109
3241
|
|
|
3110
|
-
|
|
3242
|
+
WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
|
|
3111
3243
|
|
|
3112
3244
|
whisper_model_loader loader = {};
|
|
3113
3245
|
|
|
@@ -3132,17 +3264,18 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
|
|
|
3132
3264
|
|
|
3133
3265
|
loader.close = [](void * /*ctx*/) { };
|
|
3134
3266
|
|
|
3135
|
-
return
|
|
3267
|
+
return whisper_init_with_params_no_state(&loader, params);
|
|
3136
3268
|
}
|
|
3137
3269
|
|
|
3138
|
-
struct whisper_context *
|
|
3270
|
+
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
|
3139
3271
|
wsp_ggml_time_init();
|
|
3140
3272
|
|
|
3141
3273
|
whisper_context * ctx = new whisper_context;
|
|
3274
|
+
ctx->params = params;
|
|
3142
3275
|
|
|
3143
3276
|
if (!whisper_model_load(loader, *ctx)) {
|
|
3144
3277
|
loader->close(loader->context);
|
|
3145
|
-
|
|
3278
|
+
WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
|
|
3146
3279
|
delete ctx;
|
|
3147
3280
|
return nullptr;
|
|
3148
3281
|
}
|
|
@@ -3152,8 +3285,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
|
|
3152
3285
|
return ctx;
|
|
3153
3286
|
}
|
|
3154
3287
|
|
|
3155
|
-
struct whisper_context *
|
|
3156
|
-
whisper_context * ctx =
|
|
3288
|
+
struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
|
|
3289
|
+
whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
|
|
3157
3290
|
if (!ctx) {
|
|
3158
3291
|
return nullptr;
|
|
3159
3292
|
}
|
|
@@ -3167,8 +3300,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
|
|
|
3167
3300
|
return ctx;
|
|
3168
3301
|
}
|
|
3169
3302
|
|
|
3170
|
-
struct whisper_context *
|
|
3171
|
-
whisper_context * ctx =
|
|
3303
|
+
struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
|
|
3304
|
+
whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
|
|
3172
3305
|
if (!ctx) {
|
|
3173
3306
|
return nullptr;
|
|
3174
3307
|
}
|
|
@@ -3182,8 +3315,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
|
|
|
3182
3315
|
return ctx;
|
|
3183
3316
|
}
|
|
3184
3317
|
|
|
3185
|
-
struct whisper_context *
|
|
3186
|
-
whisper_context * ctx =
|
|
3318
|
+
struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
|
3319
|
+
whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
|
|
3187
3320
|
if (!ctx) {
|
|
3188
3321
|
return nullptr;
|
|
3189
3322
|
}
|
|
@@ -3197,15 +3330,36 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
|
|
3197
3330
|
return ctx;
|
|
3198
3331
|
}
|
|
3199
3332
|
|
|
3333
|
+
struct whisper_context * whisper_init_from_file(const char * path_model) {
|
|
3334
|
+
return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
|
|
3335
|
+
}
|
|
3336
|
+
|
|
3337
|
+
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
|
|
3338
|
+
return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
|
|
3339
|
+
}
|
|
3340
|
+
|
|
3341
|
+
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
|
3342
|
+
return whisper_init_with_params(loader, whisper_context_default_params());
|
|
3343
|
+
}
|
|
3344
|
+
|
|
3345
|
+
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
|
3346
|
+
return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
|
|
3347
|
+
}
|
|
3348
|
+
|
|
3349
|
+
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
|
|
3350
|
+
return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
|
|
3351
|
+
}
|
|
3352
|
+
|
|
3353
|
+
struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
|
|
3354
|
+
return whisper_init_with_params_no_state(loader, whisper_context_default_params());
|
|
3355
|
+
}
|
|
3356
|
+
|
|
3200
3357
|
void whisper_free_state(struct whisper_state * state)
|
|
3201
3358
|
{
|
|
3202
3359
|
if (state) {
|
|
3360
|
+
kv_cache_free(state->kv_self);
|
|
3203
3361
|
kv_cache_free(state->kv_cross);
|
|
3204
3362
|
|
|
3205
|
-
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
|
3206
|
-
kv_cache_free(state->decoders[i].kv_self);
|
|
3207
|
-
}
|
|
3208
|
-
|
|
3209
3363
|
#ifdef WHISPER_USE_COREML
|
|
3210
3364
|
if (state->ctx_coreml != nullptr) {
|
|
3211
3365
|
whisper_coreml_free(state->ctx_coreml);
|
|
@@ -3213,13 +3367,6 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3213
3367
|
}
|
|
3214
3368
|
#endif
|
|
3215
3369
|
|
|
3216
|
-
#ifdef WSP_GGML_USE_METAL
|
|
3217
|
-
if (state->ctx_metal) {
|
|
3218
|
-
wsp_ggml_metal_free(state->ctx_metal);
|
|
3219
|
-
state->ctx_metal = nullptr;
|
|
3220
|
-
}
|
|
3221
|
-
#endif
|
|
3222
|
-
|
|
3223
3370
|
#ifdef WHISPER_USE_OPENVINO
|
|
3224
3371
|
if (state->ctx_openvino != nullptr) {
|
|
3225
3372
|
whisper_openvino_free(state->ctx_openvino);
|
|
@@ -3227,10 +3374,14 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3227
3374
|
}
|
|
3228
3375
|
#endif
|
|
3229
3376
|
|
|
3377
|
+
whisper_batch_free(state->batch);
|
|
3378
|
+
|
|
3230
3379
|
whisper_allocr_free(state->alloc_conv);
|
|
3231
|
-
whisper_allocr_free(state->alloc_decode);
|
|
3232
|
-
whisper_allocr_free(state->alloc_cross);
|
|
3233
3380
|
whisper_allocr_free(state->alloc_encode);
|
|
3381
|
+
whisper_allocr_free(state->alloc_cross);
|
|
3382
|
+
whisper_allocr_free(state->alloc_decode);
|
|
3383
|
+
|
|
3384
|
+
wsp_ggml_backend_free(state->backend);
|
|
3234
3385
|
|
|
3235
3386
|
delete state;
|
|
3236
3387
|
}
|
|
@@ -3241,16 +3392,25 @@ void whisper_free(struct whisper_context * ctx) {
|
|
|
3241
3392
|
if (ctx->model.ctx) {
|
|
3242
3393
|
wsp_ggml_free(ctx->model.ctx);
|
|
3243
3394
|
}
|
|
3244
|
-
|
|
3245
|
-
|
|
3395
|
+
|
|
3396
|
+
if (ctx->model.buffer) {
|
|
3397
|
+
wsp_ggml_backend_buffer_free(ctx->model.buffer);
|
|
3246
3398
|
}
|
|
3247
3399
|
|
|
3248
3400
|
whisper_free_state(ctx->state);
|
|
3249
3401
|
|
|
3402
|
+
wsp_ggml_backend_free(ctx->backend);
|
|
3403
|
+
|
|
3250
3404
|
delete ctx;
|
|
3251
3405
|
}
|
|
3252
3406
|
}
|
|
3253
3407
|
|
|
3408
|
+
void whisper_free_context_params(struct whisper_context_params * params) {
|
|
3409
|
+
if (params) {
|
|
3410
|
+
delete params;
|
|
3411
|
+
}
|
|
3412
|
+
}
|
|
3413
|
+
|
|
3254
3414
|
void whisper_free_params(struct whisper_full_params * params) {
|
|
3255
3415
|
if (params) {
|
|
3256
3416
|
delete params;
|
|
@@ -3258,8 +3418,8 @@ void whisper_free_params(struct whisper_full_params * params) {
|
|
|
3258
3418
|
}
|
|
3259
3419
|
|
|
3260
3420
|
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3261
|
-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH,
|
|
3262
|
-
|
|
3421
|
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3422
|
+
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3263
3423
|
return -1;
|
|
3264
3424
|
}
|
|
3265
3425
|
|
|
@@ -3272,8 +3432,8 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
|
3272
3432
|
|
|
3273
3433
|
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3274
3434
|
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3275
|
-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH,
|
|
3276
|
-
|
|
3435
|
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3436
|
+
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3277
3437
|
return -1;
|
|
3278
3438
|
}
|
|
3279
3439
|
|
|
@@ -3295,13 +3455,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float *
|
|
|
3295
3455
|
// TODO
|
|
3296
3456
|
|
|
3297
3457
|
int whisper_set_mel_with_state(
|
|
3298
|
-
struct whisper_context *
|
|
3458
|
+
struct whisper_context * ctx,
|
|
3299
3459
|
struct whisper_state * state,
|
|
3300
3460
|
const float * data,
|
|
3301
3461
|
int n_len,
|
|
3302
3462
|
int n_mel) {
|
|
3303
|
-
if (n_mel !=
|
|
3304
|
-
|
|
3463
|
+
if (n_mel != ctx->model.filters.n_mel) {
|
|
3464
|
+
WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
|
|
3305
3465
|
return -1;
|
|
3306
3466
|
}
|
|
3307
3467
|
|
|
@@ -3325,7 +3485,7 @@ int whisper_set_mel(
|
|
|
3325
3485
|
|
|
3326
3486
|
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
|
3327
3487
|
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
|
3328
|
-
|
|
3488
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3329
3489
|
return -1;
|
|
3330
3490
|
}
|
|
3331
3491
|
|
|
@@ -3334,7 +3494,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3334
3494
|
|
|
3335
3495
|
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
3336
3496
|
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
|
3337
|
-
|
|
3497
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3338
3498
|
return -1;
|
|
3339
3499
|
}
|
|
3340
3500
|
|
|
@@ -3342,10 +3502,12 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
3342
3502
|
}
|
|
3343
3503
|
|
|
3344
3504
|
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3345
|
-
|
|
3505
|
+
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
|
|
3346
3506
|
|
|
3347
|
-
|
|
3348
|
-
|
|
3507
|
+
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
|
3508
|
+
|
|
3509
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
|
|
3510
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3349
3511
|
return 1;
|
|
3350
3512
|
}
|
|
3351
3513
|
|
|
@@ -3353,27 +3515,19 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3353
3515
|
}
|
|
3354
3516
|
|
|
3355
3517
|
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3356
|
-
// TODO: add selected_decoder_id to state
|
|
3357
|
-
const int selected_decoder_id = 0;
|
|
3358
|
-
|
|
3359
3518
|
if (ctx->state == nullptr) {
|
|
3360
|
-
|
|
3361
|
-
return
|
|
3362
|
-
}
|
|
3363
|
-
|
|
3364
|
-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3365
|
-
log("%s: failed to eval\n", __func__);
|
|
3366
|
-
return 1;
|
|
3519
|
+
WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
|
|
3520
|
+
return -1;
|
|
3367
3521
|
}
|
|
3368
3522
|
|
|
3369
|
-
return
|
|
3523
|
+
return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
|
|
3370
3524
|
}
|
|
3371
3525
|
|
|
3372
3526
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
3373
3527
|
const auto res = tokenize(ctx->vocab, text);
|
|
3374
3528
|
|
|
3375
3529
|
if (n_max_tokens < (int) res.size()) {
|
|
3376
|
-
|
|
3530
|
+
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
3377
3531
|
return -1;
|
|
3378
3532
|
}
|
|
3379
3533
|
|
|
@@ -3401,7 +3555,7 @@ int whisper_lang_id(const char * lang) {
|
|
|
3401
3555
|
}
|
|
3402
3556
|
}
|
|
3403
3557
|
|
|
3404
|
-
|
|
3558
|
+
WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
|
|
3405
3559
|
return -1;
|
|
3406
3560
|
}
|
|
3407
3561
|
return g_lang.at(lang).first;
|
|
@@ -3414,7 +3568,18 @@ const char * whisper_lang_str(int id) {
|
|
|
3414
3568
|
}
|
|
3415
3569
|
}
|
|
3416
3570
|
|
|
3417
|
-
|
|
3571
|
+
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
|
3572
|
+
return nullptr;
|
|
3573
|
+
}
|
|
3574
|
+
|
|
3575
|
+
const char * whisper_lang_str_full(int id) {
|
|
3576
|
+
for (const auto & kv : g_lang) {
|
|
3577
|
+
if (kv.second.first == id) {
|
|
3578
|
+
return kv.second.second.c_str();
|
|
3579
|
+
}
|
|
3580
|
+
}
|
|
3581
|
+
|
|
3582
|
+
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
|
3418
3583
|
return nullptr;
|
|
3419
3584
|
}
|
|
3420
3585
|
|
|
@@ -3427,29 +3592,29 @@ int whisper_lang_auto_detect_with_state(
|
|
|
3427
3592
|
const int seek = offset_ms/10;
|
|
3428
3593
|
|
|
3429
3594
|
if (seek < 0) {
|
|
3430
|
-
|
|
3595
|
+
WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
|
3431
3596
|
return -1;
|
|
3432
3597
|
}
|
|
3433
3598
|
|
|
3434
3599
|
if (seek >= state->mel.n_len_org) {
|
|
3435
|
-
|
|
3600
|
+
WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
|
|
3436
3601
|
return -2;
|
|
3437
3602
|
}
|
|
3438
3603
|
|
|
3439
3604
|
// run the encoder
|
|
3440
3605
|
if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
|
|
3441
|
-
|
|
3606
|
+
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
3442
3607
|
return -6;
|
|
3443
3608
|
}
|
|
3444
3609
|
|
|
3445
3610
|
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
|
3446
3611
|
|
|
3447
3612
|
if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
|
3448
|
-
|
|
3613
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
3449
3614
|
return -7;
|
|
3450
3615
|
}
|
|
3451
3616
|
|
|
3452
|
-
auto & logits_id = state->logits_id;
|
|
3617
|
+
auto & logits_id = state->decoders[0].logits_id;
|
|
3453
3618
|
logits_id.clear();
|
|
3454
3619
|
|
|
3455
3620
|
for (const auto & kv : g_lang) {
|
|
@@ -3645,27 +3810,31 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
|
|
|
3645
3810
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
3646
3811
|
const int64_t t_end_us = wsp_ggml_time_us();
|
|
3647
3812
|
|
|
3648
|
-
|
|
3649
|
-
|
|
3813
|
+
WHISPER_LOG_INFO("\n");
|
|
3814
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
|
3650
3815
|
if (ctx->state != nullptr) {
|
|
3651
3816
|
|
|
3652
3817
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3653
3818
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3654
3819
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3820
|
+
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
|
3655
3821
|
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3656
3822
|
|
|
3657
|
-
|
|
3658
|
-
|
|
3659
|
-
|
|
3660
|
-
|
|
3661
|
-
|
|
3662
|
-
|
|
3823
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3824
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3825
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3826
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3827
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3828
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
|
3829
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
|
3663
3830
|
}
|
|
3664
|
-
|
|
3831
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
3665
3832
|
}
|
|
3666
3833
|
|
|
3667
3834
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
3835
|
+
ctx->t_start_us = wsp_ggml_time_us();
|
|
3668
3836
|
if (ctx->state != nullptr) {
|
|
3837
|
+
ctx->state->t_mel_us = 0;
|
|
3669
3838
|
ctx->state->t_sample_us = 0;
|
|
3670
3839
|
ctx->state->t_encode_us = 0;
|
|
3671
3840
|
ctx->state->t_decode_us = 0;
|
|
@@ -3673,6 +3842,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3673
3842
|
ctx->state->n_sample = 0;
|
|
3674
3843
|
ctx->state->n_encode = 0;
|
|
3675
3844
|
ctx->state->n_decode = 0;
|
|
3845
|
+
ctx->state->n_batchd = 0;
|
|
3676
3846
|
ctx->state->n_prompt = 0;
|
|
3677
3847
|
}
|
|
3678
3848
|
}
|
|
@@ -3711,14 +3881,441 @@ const char * whisper_print_system_info(void) {
|
|
|
3711
3881
|
s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
|
|
3712
3882
|
s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
|
|
3713
3883
|
s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
|
|
3884
|
+
s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cublas()) + " | ";
|
|
3714
3885
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
3715
3886
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
|
3716
3887
|
|
|
3717
3888
|
return s.c_str();
|
|
3718
3889
|
}
|
|
3719
3890
|
|
|
3891
|
+
//////////////////////////////////
|
|
3892
|
+
// Grammar - ported from llama.cpp
|
|
3893
|
+
//////////////////////////////////
|
|
3894
|
+
|
|
3895
|
+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
|
3896
|
+
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
|
3897
|
+
std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
3898
|
+
const char * src,
|
|
3899
|
+
whisper_partial_utf8 partial_start) {
|
|
3900
|
+
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
|
3901
|
+
const char * pos = src;
|
|
3902
|
+
std::vector<uint32_t> code_points;
|
|
3903
|
+
uint32_t value = partial_start.value;
|
|
3904
|
+
int n_remain = partial_start.n_remain;
|
|
3905
|
+
|
|
3906
|
+
// continue previous decode, if applicable
|
|
3907
|
+
while (*pos != 0 && n_remain > 0) {
|
|
3908
|
+
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
|
3909
|
+
if ((next_byte >> 6) != 2) {
|
|
3910
|
+
// invalid sequence, abort
|
|
3911
|
+
code_points.push_back(0);
|
|
3912
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
|
3913
|
+
}
|
|
3914
|
+
value = (value << 6) + (next_byte & 0x3F);
|
|
3915
|
+
++pos;
|
|
3916
|
+
--n_remain;
|
|
3917
|
+
}
|
|
3918
|
+
|
|
3919
|
+
if (partial_start.n_remain > 0 && n_remain == 0) {
|
|
3920
|
+
code_points.push_back(value);
|
|
3921
|
+
}
|
|
3922
|
+
|
|
3923
|
+
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
|
3924
|
+
while (*pos != 0) {
|
|
3925
|
+
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
|
3926
|
+
uint8_t highbits = first_byte >> 4;
|
|
3927
|
+
n_remain = lookup[highbits] - 1;
|
|
3928
|
+
|
|
3929
|
+
if (n_remain < 0) {
|
|
3930
|
+
// invalid sequence, abort
|
|
3931
|
+
code_points.clear();
|
|
3932
|
+
code_points.push_back(0);
|
|
3933
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
|
|
3934
|
+
}
|
|
3935
|
+
|
|
3936
|
+
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
|
3937
|
+
value = first_byte & mask;
|
|
3938
|
+
++pos;
|
|
3939
|
+
while (*pos != 0 && n_remain > 0) {
|
|
3940
|
+
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
|
3941
|
+
++pos;
|
|
3942
|
+
--n_remain;
|
|
3943
|
+
}
|
|
3944
|
+
if (n_remain == 0) {
|
|
3945
|
+
code_points.push_back(value);
|
|
3946
|
+
}
|
|
3947
|
+
}
|
|
3948
|
+
code_points.push_back(0);
|
|
3949
|
+
|
|
3950
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
|
|
3951
|
+
}
|
|
3952
|
+
|
|
3953
|
+
// returns true iff pos points to the end of one of the definitions of a rule
|
|
3954
|
+
static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
|
|
3955
|
+
switch (pos->type) {
|
|
3956
|
+
case WHISPER_GRETYPE_END: return true; // NOLINT
|
|
3957
|
+
case WHISPER_GRETYPE_ALT: return true; // NOLINT
|
|
3958
|
+
default: return false;
|
|
3959
|
+
}
|
|
3960
|
+
}
|
|
3961
|
+
|
|
3962
|
+
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
|
3963
|
+
// asserts that pos is pointing to a char range element
|
|
3964
|
+
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
|
|
3965
|
+
const whisper_grammar_element * pos,
|
|
3966
|
+
const uint32_t chr) {
|
|
3967
|
+
|
|
3968
|
+
bool found = false;
|
|
3969
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
|
3970
|
+
|
|
3971
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
|
|
3972
|
+
|
|
3973
|
+
do {
|
|
3974
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
|
3975
|
+
// inclusive range, e.g. [a-z]
|
|
3976
|
+
found = found || (pos->value <= chr && chr <= pos[1].value);
|
|
3977
|
+
pos += 2;
|
|
3978
|
+
} else {
|
|
3979
|
+
// exact char match, e.g. [a] or "a"
|
|
3980
|
+
found = found || pos->value == chr;
|
|
3981
|
+
pos += 1;
|
|
3982
|
+
}
|
|
3983
|
+
} while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
|
|
3984
|
+
|
|
3985
|
+
return std::make_pair(found == is_positive_char, pos);
|
|
3986
|
+
}
|
|
3987
|
+
|
|
3988
|
+
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
|
|
3989
|
+
// range at pos (regular or inverse range)
|
|
3990
|
+
// asserts that pos is pointing to a char range element
|
|
3991
|
+
static bool whisper_grammar_match_partial_char(
|
|
3992
|
+
const whisper_grammar_element * pos,
|
|
3993
|
+
const whisper_partial_utf8 partial_utf8) {
|
|
3994
|
+
|
|
3995
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
|
3996
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
|
|
3997
|
+
|
|
3998
|
+
uint32_t partial_value = partial_utf8.value;
|
|
3999
|
+
int n_remain = partial_utf8.n_remain;
|
|
4000
|
+
|
|
4001
|
+
// invalid sequence or 7-bit char split across 2 bytes (overlong)
|
|
4002
|
+
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
|
|
4003
|
+
return false;
|
|
4004
|
+
}
|
|
4005
|
+
|
|
4006
|
+
// range of possible code points this partial UTF-8 sequence could complete to
|
|
4007
|
+
uint32_t low = partial_value << (n_remain * 6);
|
|
4008
|
+
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
|
|
4009
|
+
|
|
4010
|
+
if (low == 0) {
|
|
4011
|
+
if (n_remain == 2) {
|
|
4012
|
+
low = 1 << 11;
|
|
4013
|
+
} else if (n_remain == 3) {
|
|
4014
|
+
low = 1 << 16;
|
|
4015
|
+
}
|
|
4016
|
+
}
|
|
4017
|
+
|
|
4018
|
+
do {
|
|
4019
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
|
4020
|
+
// inclusive range, e.g. [a-z]
|
|
4021
|
+
if (pos->value <= high && low <= pos[1].value) {
|
|
4022
|
+
return is_positive_char;
|
|
4023
|
+
}
|
|
4024
|
+
pos += 2;
|
|
4025
|
+
} else {
|
|
4026
|
+
// exact char match, e.g. [a] or "a"
|
|
4027
|
+
if (low <= pos->value && pos->value <= high) {
|
|
4028
|
+
return is_positive_char;
|
|
4029
|
+
}
|
|
4030
|
+
pos += 1;
|
|
4031
|
+
}
|
|
4032
|
+
} while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
|
|
4033
|
+
|
|
4034
|
+
return !is_positive_char;
|
|
4035
|
+
}
|
|
4036
|
+
|
|
4037
|
+
|
|
4038
|
+
// transforms a grammar pushdown stack into N possible stacks, all ending
|
|
4039
|
+
// at a character range (terminal element)
|
|
4040
|
+
static void whisper_grammar_advance_stack(
|
|
4041
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4042
|
+
const std::vector<const whisper_grammar_element *> & stack,
|
|
4043
|
+
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
|
4044
|
+
|
|
4045
|
+
if (stack.empty()) {
|
|
4046
|
+
new_stacks.push_back(stack);
|
|
4047
|
+
return;
|
|
4048
|
+
}
|
|
4049
|
+
|
|
4050
|
+
const whisper_grammar_element * pos = stack.back();
|
|
4051
|
+
|
|
4052
|
+
switch (pos->type) {
|
|
4053
|
+
case WHISPER_GRETYPE_RULE_REF: {
|
|
4054
|
+
const size_t rule_id = static_cast<size_t>(pos->value);
|
|
4055
|
+
const whisper_grammar_element * subpos = rules[rule_id].data();
|
|
4056
|
+
do {
|
|
4057
|
+
// init new stack without the top (pos)
|
|
4058
|
+
std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
|
4059
|
+
if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
|
|
4060
|
+
// if this rule ref is followed by another element, add that to stack
|
|
4061
|
+
new_stack.push_back(pos + 1);
|
|
4062
|
+
}
|
|
4063
|
+
if (!whisper_grammar_is_end_of_sequence(subpos)) {
|
|
4064
|
+
// if alternate is nonempty, add to stack
|
|
4065
|
+
new_stack.push_back(subpos);
|
|
4066
|
+
}
|
|
4067
|
+
whisper_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
4068
|
+
while (!whisper_grammar_is_end_of_sequence(subpos)) {
|
|
4069
|
+
// scan to end of alternate def
|
|
4070
|
+
subpos++;
|
|
4071
|
+
}
|
|
4072
|
+
if (subpos->type == WHISPER_GRETYPE_ALT) {
|
|
4073
|
+
// there's another alternate def of this rule to process
|
|
4074
|
+
subpos++;
|
|
4075
|
+
} else {
|
|
4076
|
+
break;
|
|
4077
|
+
}
|
|
4078
|
+
} while (true);
|
|
4079
|
+
break;
|
|
4080
|
+
}
|
|
4081
|
+
case WHISPER_GRETYPE_CHAR:
|
|
4082
|
+
case WHISPER_GRETYPE_CHAR_NOT:
|
|
4083
|
+
new_stacks.push_back(stack);
|
|
4084
|
+
break;
|
|
4085
|
+
default:
|
|
4086
|
+
// end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
|
|
4087
|
+
// (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
|
4088
|
+
// those
|
|
4089
|
+
WHISPER_ASSERT(false);
|
|
4090
|
+
}
|
|
4091
|
+
}
|
|
4092
|
+
|
|
4093
|
+
// takes a set of possible pushdown stacks on a grammar, which are required to
|
|
4094
|
+
// be positioned at a character range (see `whisper_grammar_advance_stack`), and
|
|
4095
|
+
// produces the N possible stacks if the given char is accepted at those
|
|
4096
|
+
// positions
|
|
4097
|
+
static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
|
|
4098
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4099
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4100
|
+
const uint32_t chr) {
|
|
4101
|
+
|
|
4102
|
+
std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
|
|
4103
|
+
|
|
4104
|
+
for (const auto & stack : stacks) {
|
|
4105
|
+
if (stack.empty()) {
|
|
4106
|
+
continue;
|
|
4107
|
+
}
|
|
4108
|
+
|
|
4109
|
+
auto match = whisper_grammar_match_char(stack.back(), chr);
|
|
4110
|
+
if (match.first) {
|
|
4111
|
+
const whisper_grammar_element * pos = match.second;
|
|
4112
|
+
|
|
4113
|
+
// update top of stack to next element, if any
|
|
4114
|
+
std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
|
4115
|
+
if (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4116
|
+
new_stack.push_back(pos);
|
|
4117
|
+
}
|
|
4118
|
+
whisper_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
4119
|
+
}
|
|
4120
|
+
}
|
|
4121
|
+
|
|
4122
|
+
return new_stacks;
|
|
4123
|
+
}
|
|
4124
|
+
|
|
4125
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
|
|
4126
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4127
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4128
|
+
const std::vector<whisper_grammar_candidate> & candidates);
|
|
4129
|
+
|
|
4130
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
|
|
4131
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4132
|
+
const std::vector<const whisper_grammar_element *> & stack,
|
|
4133
|
+
const std::vector<whisper_grammar_candidate> & candidates) {
|
|
4134
|
+
|
|
4135
|
+
std::vector<whisper_grammar_candidate> rejects;
|
|
4136
|
+
|
|
4137
|
+
if (stack.empty()) {
|
|
4138
|
+
for (auto tok : candidates) {
|
|
4139
|
+
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
|
|
4140
|
+
rejects.push_back(tok);
|
|
4141
|
+
}
|
|
4142
|
+
}
|
|
4143
|
+
return rejects;
|
|
4144
|
+
}
|
|
4145
|
+
|
|
4146
|
+
const whisper_grammar_element * stack_pos = stack.back();
|
|
4147
|
+
|
|
4148
|
+
std::vector<whisper_grammar_candidate> next_candidates;
|
|
4149
|
+
for (auto tok : candidates) {
|
|
4150
|
+
if (*tok.code_points == 0) {
|
|
4151
|
+
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
|
4152
|
+
// that cannot satisfy this position in grammar
|
|
4153
|
+
if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
|
|
4154
|
+
rejects.push_back(tok);
|
|
4155
|
+
}
|
|
4156
|
+
} else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
|
|
4157
|
+
next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
|
|
4158
|
+
} else {
|
|
4159
|
+
rejects.push_back(tok);
|
|
4160
|
+
}
|
|
4161
|
+
}
|
|
4162
|
+
|
|
4163
|
+
const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
|
|
4164
|
+
|
|
4165
|
+
// update top of stack to next element, if any
|
|
4166
|
+
std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
|
|
4167
|
+
if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
|
|
4168
|
+
stack_after.push_back(stack_pos_after);
|
|
4169
|
+
}
|
|
4170
|
+
std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
|
|
4171
|
+
whisper_grammar_advance_stack(rules, stack_after, next_stacks);
|
|
4172
|
+
|
|
4173
|
+
auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
|
4174
|
+
for (auto tok : next_rejects) {
|
|
4175
|
+
rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
|
|
4176
|
+
}
|
|
4177
|
+
|
|
4178
|
+
return rejects;
|
|
4179
|
+
}
|
|
4180
|
+
|
|
4181
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
|
|
4182
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4183
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4184
|
+
const std::vector<whisper_grammar_candidate> & candidates) {
|
|
4185
|
+
if (candidates.empty() || stacks.empty()) {
|
|
4186
|
+
return std::vector<whisper_grammar_candidate>();
|
|
4187
|
+
}
|
|
4188
|
+
|
|
4189
|
+
auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
|
|
4190
|
+
|
|
4191
|
+
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
|
4192
|
+
rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
|
4193
|
+
}
|
|
4194
|
+
return rejects;
|
|
4195
|
+
}
|
|
4196
|
+
|
|
4197
|
+
static struct whisper_grammar whisper_grammar_init(
|
|
4198
|
+
const whisper_grammar_element ** rules,
|
|
4199
|
+
size_t n_rules,
|
|
4200
|
+
size_t i_start_rule) {
|
|
4201
|
+
const whisper_grammar_element * pos;
|
|
4202
|
+
|
|
4203
|
+
// copy rule definitions into vectors
|
|
4204
|
+
std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
|
|
4205
|
+
for (size_t i = 0; i < n_rules; i++) {
|
|
4206
|
+
for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
|
|
4207
|
+
vec_rules[i].push_back(*pos);
|
|
4208
|
+
}
|
|
4209
|
+
vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
|
|
4210
|
+
}
|
|
4211
|
+
|
|
4212
|
+
// loop over alternates of start rule to build initial stacks
|
|
4213
|
+
std::vector<std::vector<const whisper_grammar_element *>> stacks;
|
|
4214
|
+
pos = rules[i_start_rule];
|
|
4215
|
+
do {
|
|
4216
|
+
std::vector<const whisper_grammar_element *> stack;
|
|
4217
|
+
if (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4218
|
+
// if alternate is nonempty, add to stack
|
|
4219
|
+
stack.push_back(pos);
|
|
4220
|
+
}
|
|
4221
|
+
whisper_grammar_advance_stack(vec_rules, stack, stacks);
|
|
4222
|
+
while (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4223
|
+
// scan to end of alternate def
|
|
4224
|
+
pos++;
|
|
4225
|
+
}
|
|
4226
|
+
if (pos->type == WHISPER_GRETYPE_ALT) {
|
|
4227
|
+
// there's another alternate def of this rule to process
|
|
4228
|
+
pos++;
|
|
4229
|
+
} else {
|
|
4230
|
+
break;
|
|
4231
|
+
}
|
|
4232
|
+
} while (true);
|
|
4233
|
+
|
|
4234
|
+
return { std::move(vec_rules), std::move(stacks), {} };
|
|
4235
|
+
}
|
|
4236
|
+
|
|
4237
|
+
// Applies the grammar as a soft constraint on the next-token logits.
//
// Every text token (id below end-of-text) whose text cannot be accepted from
// the grammar's current parse state has its logit reduced by
// params.grammar_penalty. Decoding continues from grammar.partial_utf8, so
// tokens that complete a pending multi-byte UTF-8 sequence are judged
// correctly. Does nothing when the grammar has no rules or no live stacks.
//
// NOTE: the end-of-text token (and any id at or above it) is never
// penalized here, even when no parse stack is complete.
static void whisper_suppress_invalid_grammar(
             whisper_context & ctx,
    const whisper_full_params & params,
           std::vector<float> & logits,
    const    whisper_grammar  & grammar) {

    if (grammar.rules.empty() || grammar.stacks.empty()) {
        return;
    }

    const whisper_token eot = whisper_token_eot(&ctx);

    // candidates_grammar holds raw pointers into the code-point buffers owned
    // by candidates_decoded. Reserving up front avoids repeated reallocation
    // of both vectors while they are filled.
    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
    std::vector<whisper_grammar_candidate>                              candidates_grammar;
    if (eot > 0) {
        candidates_decoded.reserve(static_cast<size_t>(eot));
        candidates_grammar.reserve(static_cast<size_t>(eot));
    }

    for (whisper_token id = 0; id < eot; ++id) {
        const std::string & text = ctx.vocab.id_to_token[id];
        if (!text.empty()) {
            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
        }
    }

    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);

    for (const auto & reject : rejects) {
        logits[reject.id] -= params.grammar_penalty;
    }
}
|
|
4280
|
+
|
|
4281
|
+
// Advances the grammar's parse state by the text of `token`.
//
// The token text is decoded to code points, continuing any partial UTF-8
// sequence stored in grammar.partial_utf8. Each complete code point is fed to
// whisper_grammar_accept, and the trailing partial sequence (if any) is
// stored back for the next token. The call is a no-op when the grammar has no
// rules or no live stacks, and for tokens whose text starts with "[_".
static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
    const bool inactive = grammar.rules.empty() || grammar.stacks.empty();
    if (inactive) {
        return;
    }

    const std::string & text = ctx.vocab.id_to_token[token];

    // tokens spelled with a leading "[_" are not grammar text; skip them
    const bool starts_with_marker = text.rfind("[_", 0) == 0;
    if (starts_with_marker) {
        return;
    }

    // the decoded code-point list ends with a 0 sentinel that must not be
    // fed to the grammar, hence the "- 1"
    const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
    const auto & cps = decoded.first;
    const size_t n_full = cps.size() - 1;
    for (size_t k = 0; k < n_full; ++k) {
        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, cps[k]);
    }

    grammar.partial_utf8 = decoded.second;
}
|
|
4304
|
+
|
|
4305
|
+
//////////////
|
|
4306
|
+
// END grammar
|
|
4307
|
+
//////////////
|
|
4308
|
+
|
|
3720
4309
|
////////////////////////////////////////////////////////////////////////////
|
|
3721
4310
|
|
|
4311
|
+
struct whisper_context_params * whisper_context_default_params_by_ref() {
|
|
4312
|
+
struct whisper_context_params params = whisper_context_default_params();
|
|
4313
|
+
|
|
4314
|
+
struct whisper_context_params* result = new whisper_context_params();
|
|
4315
|
+
*result = params;
|
|
4316
|
+
return result;
|
|
4317
|
+
}
|
|
4318
|
+
|
|
3722
4319
|
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
|
|
3723
4320
|
struct whisper_full_params params = whisper_full_default_params(strategy);
|
|
3724
4321
|
|
|
@@ -3738,6 +4335,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3738
4335
|
|
|
3739
4336
|
/*.translate =*/ false,
|
|
3740
4337
|
/*.no_context =*/ true,
|
|
4338
|
+
/*.no_timestamps =*/ false,
|
|
3741
4339
|
/*.single_segment =*/ false,
|
|
3742
4340
|
/*.print_special =*/ false,
|
|
3743
4341
|
/*.print_progress =*/ true,
|
|
@@ -3771,7 +4369,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3771
4369
|
/*.max_initial_ts =*/ 1.0f,
|
|
3772
4370
|
/*.length_penalty =*/ -1.0f,
|
|
3773
4371
|
|
|
3774
|
-
/*.temperature_inc =*/ 0.
|
|
4372
|
+
/*.temperature_inc =*/ 0.2f,
|
|
3775
4373
|
/*.entropy_thold =*/ 2.4f,
|
|
3776
4374
|
/*.logprob_thold =*/ -1.0f,
|
|
3777
4375
|
/*.no_speech_thold =*/ 0.6f,
|
|
@@ -3795,24 +4393,29 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3795
4393
|
/*.encoder_begin_callback =*/ nullptr,
|
|
3796
4394
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
3797
4395
|
|
|
3798
|
-
/*.abort_callback
|
|
3799
|
-
/*.abort_callback_user_data
|
|
4396
|
+
/*.abort_callback =*/ nullptr,
|
|
4397
|
+
/*.abort_callback_user_data =*/ nullptr,
|
|
3800
4398
|
|
|
3801
4399
|
/*.logits_filter_callback =*/ nullptr,
|
|
3802
4400
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
|
4401
|
+
|
|
4402
|
+
/*.grammar_rules =*/ nullptr,
|
|
4403
|
+
/*.n_grammar_rules =*/ 0,
|
|
4404
|
+
/*.i_start_rule =*/ 0,
|
|
4405
|
+
/*.grammar_penalty =*/ 100.0f,
|
|
3803
4406
|
};
|
|
3804
4407
|
|
|
3805
4408
|
switch (strategy) {
|
|
3806
4409
|
case WHISPER_SAMPLING_GREEDY:
|
|
3807
4410
|
{
|
|
3808
4411
|
result.greedy = {
|
|
3809
|
-
/*.best_of =*/
|
|
4412
|
+
/*.best_of =*/ 5,
|
|
3810
4413
|
};
|
|
3811
4414
|
} break;
|
|
3812
4415
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
3813
4416
|
{
|
|
3814
4417
|
result.beam_search = {
|
|
3815
|
-
/*.beam_size =*/
|
|
4418
|
+
/*.beam_size =*/ 5,
|
|
3816
4419
|
|
|
3817
4420
|
/*.patience =*/ -1.0f,
|
|
3818
4421
|
};
|
|
@@ -3902,11 +4505,12 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
|
3902
4505
|
// process the logits for the selected decoder
|
|
3903
4506
|
// - applies logit filters
|
|
3904
4507
|
// - computes logprobs and probs
|
|
4508
|
+
// TODO: optimize
|
|
3905
4509
|
static void whisper_process_logits(
|
|
3906
4510
|
struct whisper_context & ctx,
|
|
3907
4511
|
struct whisper_state & state,
|
|
3908
|
-
const struct whisper_full_params params,
|
|
3909
4512
|
struct whisper_decoder & decoder,
|
|
4513
|
+
const struct whisper_full_params params,
|
|
3910
4514
|
float temperature) {
|
|
3911
4515
|
const auto & vocab = ctx.vocab;
|
|
3912
4516
|
const auto & tokens_cur = decoder.sequence.tokens;
|
|
@@ -3923,7 +4527,7 @@ static void whisper_process_logits(
|
|
|
3923
4527
|
auto & logprobs = decoder.logprobs;
|
|
3924
4528
|
{
|
|
3925
4529
|
logits.resize(n_logits);
|
|
3926
|
-
memcpy(logits.data(), state.logits.data() +
|
|
4530
|
+
memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
|
|
3927
4531
|
|
|
3928
4532
|
if (temperature > 0.0f) {
|
|
3929
4533
|
for (int i = 0; i < n_logits; i++) {
|
|
@@ -3951,6 +4555,11 @@ static void whisper_process_logits(
|
|
|
3951
4555
|
// suppress <|notimestamps|> token
|
|
3952
4556
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
|
|
3953
4557
|
logits[vocab.token_not] = -INFINITY;
|
|
4558
|
+
if (params.no_timestamps) {
|
|
4559
|
+
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
|
4560
|
+
logits[i] = -INFINITY;
|
|
4561
|
+
}
|
|
4562
|
+
}
|
|
3954
4563
|
|
|
3955
4564
|
// suppress sot and nosp tokens
|
|
3956
4565
|
logits[vocab.token_sot] = -INFINITY;
|
|
@@ -3964,6 +4573,15 @@ static void whisper_process_logits(
|
|
|
3964
4573
|
// suppress task tokens
|
|
3965
4574
|
logits[vocab.token_translate] = -INFINITY;
|
|
3966
4575
|
logits[vocab.token_transcribe] = -INFINITY;
|
|
4576
|
+
logits[vocab.token_prev] = -INFINITY;
|
|
4577
|
+
|
|
4578
|
+
// suppress lang tokens
|
|
4579
|
+
for (size_t i = 0; i < g_lang.size(); ++i) {
|
|
4580
|
+
logits[whisper_token_lang(&ctx, i)] = -INFINITY;
|
|
4581
|
+
}
|
|
4582
|
+
|
|
4583
|
+
// suppress prev token
|
|
4584
|
+
logits[vocab.token_prev] = -INFINITY;
|
|
3967
4585
|
|
|
3968
4586
|
if (params.logits_filter_callback) {
|
|
3969
4587
|
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
@@ -3996,7 +4614,7 @@ static void whisper_process_logits(
|
|
|
3996
4614
|
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
|
|
3997
4615
|
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
|
|
3998
4616
|
|
|
3999
|
-
//
|
|
4617
|
+
//WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
|
|
4000
4618
|
|
|
4001
4619
|
if (last_was_timestamp) {
|
|
4002
4620
|
if (penultimate_was_timestamp) {
|
|
@@ -4072,13 +4690,37 @@ static void whisper_process_logits(
|
|
|
4072
4690
|
|
|
4073
4691
|
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
|
|
4074
4692
|
|
|
4075
|
-
//
|
|
4693
|
+
//WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
|
|
4076
4694
|
|
|
4077
4695
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
4078
4696
|
for (int i = 0; i < vocab.token_beg; ++i) {
|
|
4079
4697
|
logits[i] = -INFINITY;
|
|
4080
4698
|
logprobs[i] = -INFINITY;
|
|
4081
4699
|
}
|
|
4700
|
+
} else {
|
|
4701
|
+
if (params.n_grammar_rules > 0) {
|
|
4702
|
+
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
|
4703
|
+
|
|
4704
|
+
// populate the logprobs array (log_softmax)
|
|
4705
|
+
{
|
|
4706
|
+
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
|
4707
|
+
float logsumexp = 0.0f;
|
|
4708
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4709
|
+
if (logits[i] > -INFINITY) {
|
|
4710
|
+
logsumexp += expf(logits[i] - logit_max);
|
|
4711
|
+
}
|
|
4712
|
+
}
|
|
4713
|
+
logsumexp = logf(logsumexp) + logit_max;
|
|
4714
|
+
|
|
4715
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4716
|
+
if (logits[i] > -INFINITY) {
|
|
4717
|
+
logprobs[i] = logits[i] - logsumexp;
|
|
4718
|
+
} else {
|
|
4719
|
+
logprobs[i] = -INFINITY;
|
|
4720
|
+
}
|
|
4721
|
+
}
|
|
4722
|
+
}
|
|
4723
|
+
}
|
|
4082
4724
|
}
|
|
4083
4725
|
}
|
|
4084
4726
|
}
|
|
@@ -4096,38 +4738,60 @@ static void whisper_process_logits(
|
|
|
4096
4738
|
|
|
4097
4739
|
#if 0
|
|
4098
4740
|
// print first 100 logits - token string : logit
|
|
4099
|
-
for (int i = 0; i <
|
|
4100
|
-
|
|
4101
|
-
|
|
4102
|
-
|
|
4103
|
-
|
|
4104
|
-
|
|
4741
|
+
//for (int i = 0; i < 10; i++) {
|
|
4742
|
+
// const auto token = vocab.id_to_token.at(i);
|
|
4743
|
+
// const auto prob = probs[i];
|
|
4744
|
+
// const auto logit = logits[i];
|
|
4745
|
+
// const auto logprob = logprobs[i];
|
|
4746
|
+
// printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
|
|
4747
|
+
//}
|
|
4748
|
+
|
|
4749
|
+
// print sorted
|
|
4750
|
+
{
|
|
4751
|
+
std::vector<std::pair<float, int>> pairs;
|
|
4752
|
+
|
|
4753
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4754
|
+
pairs.push_back(std::make_pair(probs[i], i));
|
|
4755
|
+
}
|
|
4756
|
+
|
|
4757
|
+
std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
|
4758
|
+
return a.first > b.first;
|
|
4759
|
+
});
|
|
4760
|
+
|
|
4761
|
+
for (int i = 0; i < 10; i++) {
|
|
4762
|
+
const auto token = vocab.id_to_token.at(pairs[i].second);
|
|
4763
|
+
const auto prob = pairs[i].first;
|
|
4764
|
+
const auto logit = logits[pairs[i].second];
|
|
4765
|
+
const auto logprob = logprobs[pairs[i].second];
|
|
4766
|
+
printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
|
|
4767
|
+
}
|
|
4768
|
+
|
|
4769
|
+
printf("----------------\n");
|
|
4105
4770
|
}
|
|
4106
4771
|
|
|
4107
4772
|
// "And", "and", " And", " and"
|
|
4108
|
-
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
|
4109
|
-
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
|
4110
|
-
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
|
4111
|
-
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
|
4112
|
-
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
|
4113
|
-
|
|
4114
|
-
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
|
4115
|
-
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
|
4116
|
-
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
4117
|
-
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
4118
|
-
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
4119
|
-
|
|
4120
|
-
printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
|
4121
|
-
printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
|
4122
|
-
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
|
4123
|
-
printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
|
4124
|
-
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
|
4773
|
+
//printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
|
4774
|
+
//printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
|
4775
|
+
//printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
|
4776
|
+
//printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
|
4777
|
+
//printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
|
4778
|
+
|
|
4779
|
+
//printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
|
4780
|
+
//printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
|
4781
|
+
//printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
4782
|
+
//printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
4783
|
+
//printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
4784
|
+
|
|
4785
|
+
//printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
|
4786
|
+
//printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
|
4787
|
+
//printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
|
4788
|
+
//printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
|
4789
|
+
//printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
|
4125
4790
|
#endif
|
|
4126
4791
|
}
|
|
4127
4792
|
|
|
4128
4793
|
static whisper_token_data whisper_sample_token(
|
|
4129
4794
|
whisper_context & ctx,
|
|
4130
|
-
whisper_state & state,
|
|
4131
4795
|
const whisper_decoder & decoder,
|
|
4132
4796
|
bool best) {
|
|
4133
4797
|
whisper_token_data result = {
|
|
@@ -4172,7 +4836,7 @@ static whisper_token_data whisper_sample_token(
|
|
|
4172
4836
|
} else {
|
|
4173
4837
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
4174
4838
|
|
|
4175
|
-
result.id = dist(
|
|
4839
|
+
result.id = dist(decoder.rng);
|
|
4176
4840
|
result.p = probs[result.id];
|
|
4177
4841
|
result.plog = logprobs[result.id];
|
|
4178
4842
|
}
|
|
@@ -4182,15 +4846,12 @@ static whisper_token_data whisper_sample_token(
|
|
|
4182
4846
|
result.pt = result.p;
|
|
4183
4847
|
}
|
|
4184
4848
|
|
|
4185
|
-
state.n_sample++;
|
|
4186
|
-
|
|
4187
4849
|
return result;
|
|
4188
4850
|
}
|
|
4189
4851
|
|
|
4190
4852
|
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
4191
4853
|
whisper_context & ctx,
|
|
4192
|
-
|
|
4193
|
-
const whisper_decoder & decoder,
|
|
4854
|
+
whisper_decoder & decoder,
|
|
4194
4855
|
int k) {
|
|
4195
4856
|
const auto & vocab = ctx.vocab;
|
|
4196
4857
|
|
|
@@ -4200,7 +4861,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4200
4861
|
|
|
4201
4862
|
const int n_logits = vocab.n_vocab;
|
|
4202
4863
|
|
|
4203
|
-
auto & logits_id =
|
|
4864
|
+
auto & logits_id = decoder.logits_id;
|
|
4204
4865
|
|
|
4205
4866
|
logits_id.resize(n_logits);
|
|
4206
4867
|
for (int i = 0; i < n_logits; ++i) {
|
|
@@ -4246,8 +4907,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4246
4907
|
ptsum = sum_ts;
|
|
4247
4908
|
}
|
|
4248
4909
|
|
|
4910
|
+
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
4911
|
+
|
|
4249
4912
|
for (int i = 0; i < k; ++i) {
|
|
4250
|
-
const auto id =
|
|
4913
|
+
const auto id = dist(decoder.rng);
|
|
4914
|
+
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
|
4251
4915
|
|
|
4252
4916
|
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
|
4253
4917
|
|
|
@@ -4257,8 +4921,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4257
4921
|
}
|
|
4258
4922
|
}
|
|
4259
4923
|
|
|
4260
|
-
state.n_sample++;
|
|
4261
|
-
|
|
4262
4924
|
return result;
|
|
4263
4925
|
}
|
|
4264
4926
|
|
|
@@ -4311,115 +4973,6 @@ static void whisper_sequence_score(
|
|
|
4311
4973
|
}
|
|
4312
4974
|
}
|
|
4313
4975
|
|
|
4314
|
-
static bool whisper_kv_swap_fast(
|
|
4315
|
-
std::vector<int> & view,
|
|
4316
|
-
whisper_decoder src[],
|
|
4317
|
-
std::vector<kv_buf> & kv_swap_bufs,
|
|
4318
|
-
const int & n_decoders) {
|
|
4319
|
-
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
|
|
4320
|
-
|
|
4321
|
-
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
|
|
4322
|
-
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
|
|
4323
|
-
|
|
4324
|
-
// (buffer->decoder or decoder->decoder)
|
|
4325
|
-
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
|
|
4326
|
-
|
|
4327
|
-
// (decoder<->decoder)
|
|
4328
|
-
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
|
|
4329
|
-
std::vector<whisper_pair<int, int>> p_swap_vec;
|
|
4330
|
-
p_swap_vec.reserve(n_decoders);
|
|
4331
|
-
|
|
4332
|
-
// see https://github.com/ggerganov/whisper.cpp/wiki
|
|
4333
|
-
for (int i = 0; i < n_decoders; i++) {
|
|
4334
|
-
// zero-copy (no modification)
|
|
4335
|
-
if (i == view[i] || view[i] < 0) {
|
|
4336
|
-
continue;
|
|
4337
|
-
}
|
|
4338
|
-
|
|
4339
|
-
bool is_one_copy = true;
|
|
4340
|
-
// since we modify data sequentially, we only consider decoder indices after current index
|
|
4341
|
-
for (int j = i + 1; j < n_decoders; j++) {
|
|
4342
|
-
if (i == view[j]) {
|
|
4343
|
-
// detect symmetric diagram
|
|
4344
|
-
if (j == view[i]) {
|
|
4345
|
-
p_swap_set.insert(i);
|
|
4346
|
-
p_swap_set.insert(j);
|
|
4347
|
-
p_swap_vec.emplace_back(i, j);
|
|
4348
|
-
} else {
|
|
4349
|
-
two_copy.insert(i);
|
|
4350
|
-
is_one_copy = false;
|
|
4351
|
-
}
|
|
4352
|
-
break;
|
|
4353
|
-
}
|
|
4354
|
-
}
|
|
4355
|
-
if (is_one_copy) {
|
|
4356
|
-
one_copy.insert(i);
|
|
4357
|
-
}
|
|
4358
|
-
}
|
|
4359
|
-
|
|
4360
|
-
kv_swap_bufs.resize(n_decoders);
|
|
4361
|
-
|
|
4362
|
-
for (int i = 0; i < n_decoders; i++) {
|
|
4363
|
-
kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
|
|
4364
|
-
kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
|
|
4365
|
-
}
|
|
4366
|
-
|
|
4367
|
-
for (auto & i : two_copy) {
|
|
4368
|
-
// make a copy of KV caches
|
|
4369
|
-
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
|
|
4370
|
-
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
|
|
4371
|
-
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
|
|
4372
|
-
}
|
|
4373
|
-
|
|
4374
|
-
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
|
|
4375
|
-
for (auto & i : two_copy) {
|
|
4376
|
-
// skip the decoder indices that require pointer swapping
|
|
4377
|
-
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4378
|
-
continue;
|
|
4379
|
-
}
|
|
4380
|
-
|
|
4381
|
-
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4382
|
-
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4383
|
-
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4384
|
-
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4385
|
-
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4386
|
-
} else {
|
|
4387
|
-
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4388
|
-
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4389
|
-
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4390
|
-
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4391
|
-
}
|
|
4392
|
-
}
|
|
4393
|
-
|
|
4394
|
-
// then modify one-copy decoder KV caches
|
|
4395
|
-
for (auto & i : one_copy) {
|
|
4396
|
-
// skip the decoder indices that require pointer swapping
|
|
4397
|
-
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4398
|
-
continue;
|
|
4399
|
-
}
|
|
4400
|
-
|
|
4401
|
-
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4402
|
-
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4403
|
-
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4404
|
-
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4405
|
-
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4406
|
-
} else {
|
|
4407
|
-
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4408
|
-
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4409
|
-
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4410
|
-
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4411
|
-
}
|
|
4412
|
-
}
|
|
4413
|
-
|
|
4414
|
-
// swap the pointers
|
|
4415
|
-
for (auto & i : p_swap_vec) {
|
|
4416
|
-
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
|
|
4417
|
-
std::swap(src[i.first].kv_self, src[i.second].kv_self);
|
|
4418
|
-
}
|
|
4419
|
-
|
|
4420
|
-
return true;
|
|
4421
|
-
}
|
|
4422
|
-
|
|
4423
4976
|
int whisper_full_with_state(
|
|
4424
4977
|
struct whisper_context * ctx,
|
|
4425
4978
|
struct whisper_state * state,
|
|
@@ -4435,11 +4988,11 @@ int whisper_full_with_state(
|
|
|
4435
4988
|
// compute log mel spectrogram
|
|
4436
4989
|
if (params.speed_up) {
|
|
4437
4990
|
// TODO: Replace PV with more advanced algorithm
|
|
4438
|
-
|
|
4991
|
+
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4439
4992
|
return -1;
|
|
4440
4993
|
} else {
|
|
4441
4994
|
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
4442
|
-
|
|
4995
|
+
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4443
4996
|
return -2;
|
|
4444
4997
|
}
|
|
4445
4998
|
}
|
|
@@ -4451,13 +5004,13 @@ int whisper_full_with_state(
|
|
|
4451
5004
|
|
|
4452
5005
|
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
|
4453
5006
|
if (lang_id < 0) {
|
|
4454
|
-
|
|
5007
|
+
WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
|
|
4455
5008
|
return -3;
|
|
4456
5009
|
}
|
|
4457
5010
|
state->lang_id = lang_id;
|
|
4458
5011
|
params.language = whisper_lang_str(lang_id);
|
|
4459
5012
|
|
|
4460
|
-
|
|
5013
|
+
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
4461
5014
|
if (params.detect_language) {
|
|
4462
5015
|
return 0;
|
|
4463
5016
|
}
|
|
@@ -4479,6 +5032,7 @@ int whisper_full_with_state(
|
|
|
4479
5032
|
// basically don't process anything that is less than 1.0s
|
|
4480
5033
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
|
4481
5034
|
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
|
5035
|
+
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
|
|
4482
5036
|
return 0;
|
|
4483
5037
|
}
|
|
4484
5038
|
|
|
@@ -4509,40 +5063,23 @@ int whisper_full_with_state(
|
|
|
4509
5063
|
|
|
4510
5064
|
n_decoders = std::max(1, n_decoders);
|
|
4511
5065
|
|
|
5066
|
+
if (n_decoders > WHISPER_MAX_DECODERS) {
|
|
5067
|
+
WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
|
|
5068
|
+
return -4;
|
|
5069
|
+
}
|
|
5070
|
+
|
|
4512
5071
|
// TAGS: WHISPER_DECODER_INIT
|
|
4513
5072
|
for (int j = 1; j < n_decoders; j++) {
|
|
4514
5073
|
auto & decoder = state->decoders[j];
|
|
4515
5074
|
|
|
4516
|
-
|
|
4517
|
-
decoder.kv_self = state->decoders[0].kv_self;
|
|
4518
|
-
if (!kv_cache_reinit(decoder.kv_self)) {
|
|
4519
|
-
log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
|
4520
|
-
return -4;
|
|
4521
|
-
}
|
|
4522
|
-
|
|
4523
|
-
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
4524
|
-
|
|
4525
|
-
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
|
4526
|
-
|
|
4527
|
-
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
4528
|
-
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
4529
|
-
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
5075
|
+
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
|
4530
5076
|
|
|
4531
|
-
|
|
4532
|
-
|
|
4533
|
-
|
|
4534
|
-
|
|
4535
|
-
log("%s: failed to add metal buffer\n", __func__); \
|
|
4536
|
-
return 0; \
|
|
4537
|
-
}
|
|
5077
|
+
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
5078
|
+
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
5079
|
+
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
5080
|
+
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
|
4538
5081
|
|
|
4539
|
-
|
|
4540
|
-
auto & kv_self = decoder.kv_self;
|
|
4541
|
-
|
|
4542
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
|
|
4543
|
-
#undef WHISPER_METAL_CHECK_BUF
|
|
4544
|
-
#endif
|
|
4545
|
-
}
|
|
5082
|
+
decoder.rng = std::mt19937(0);
|
|
4546
5083
|
}
|
|
4547
5084
|
|
|
4548
5085
|
// the accumulated text context so far
|
|
@@ -4557,7 +5094,7 @@ int whisper_full_with_state(
|
|
|
4557
5094
|
|
|
4558
5095
|
// initial prompt
|
|
4559
5096
|
if (!params.prompt_tokens && params.initial_prompt) {
|
|
4560
|
-
prompt_tokens.resize(
|
|
5097
|
+
prompt_tokens.resize(1024);
|
|
4561
5098
|
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
|
4562
5099
|
params.prompt_tokens = prompt_tokens.data();
|
|
4563
5100
|
params.prompt_n_tokens = prompt_tokens.size();
|
|
@@ -4575,13 +5112,14 @@ int whisper_full_with_state(
|
|
|
4575
5112
|
|
|
4576
5113
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
|
4577
5114
|
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
|
4578
|
-
|
|
5115
|
+
WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
|
4579
5116
|
return -5;
|
|
4580
5117
|
}
|
|
4581
5118
|
state->exp_n_audio_ctx = params.audio_ctx;
|
|
4582
5119
|
|
|
4583
5120
|
// these tokens determine the task that will be performed
|
|
4584
|
-
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
5121
|
+
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
|
|
5122
|
+
|
|
4585
5123
|
if (whisper_is_multilingual(ctx)) {
|
|
4586
5124
|
const int lang_id = whisper_lang_id(params.language);
|
|
4587
5125
|
state->lang_id = lang_id;
|
|
@@ -4593,6 +5131,19 @@ int whisper_full_with_state(
|
|
|
4593
5131
|
}
|
|
4594
5132
|
}
|
|
4595
5133
|
|
|
5134
|
+
// distilled models require the "no_timestamps" token
|
|
5135
|
+
{
|
|
5136
|
+
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
|
|
5137
|
+
if (is_distil && !params.no_timestamps) {
|
|
5138
|
+
WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
|
|
5139
|
+
params.no_timestamps = true;
|
|
5140
|
+
}
|
|
5141
|
+
}
|
|
5142
|
+
|
|
5143
|
+
if (params.no_timestamps) {
|
|
5144
|
+
prompt_init.push_back(whisper_token_not(ctx));
|
|
5145
|
+
}
|
|
5146
|
+
|
|
4596
5147
|
int seek = seek_start;
|
|
4597
5148
|
|
|
4598
5149
|
std::vector<whisper_token> prompt;
|
|
@@ -4605,8 +5156,10 @@ int whisper_full_with_state(
|
|
|
4605
5156
|
bool has_ts;
|
|
4606
5157
|
|
|
4607
5158
|
whisper_sequence sequence;
|
|
5159
|
+
whisper_grammar grammar;
|
|
4608
5160
|
};
|
|
4609
5161
|
|
|
5162
|
+
std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
|
|
4610
5163
|
std::vector<beam_candidate> beam_candidates;
|
|
4611
5164
|
|
|
4612
5165
|
// main loop
|
|
@@ -4615,24 +5168,24 @@ int whisper_full_with_state(
|
|
|
4615
5168
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
4616
5169
|
|
|
4617
5170
|
params.progress_callback(
|
|
4618
|
-
ctx,
|
|
5171
|
+
ctx, state, progress_cur, params.progress_callback_user_data);
|
|
4619
5172
|
}
|
|
4620
5173
|
|
|
4621
|
-
//
|
|
5174
|
+
// if only 1 second left, then stop
|
|
4622
5175
|
if (seek + 100 >= seek_end) {
|
|
4623
5176
|
break;
|
|
4624
5177
|
}
|
|
4625
5178
|
|
|
4626
5179
|
if (params.encoder_begin_callback) {
|
|
4627
5180
|
if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
|
|
4628
|
-
|
|
5181
|
+
WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
4629
5182
|
break;
|
|
4630
5183
|
}
|
|
4631
5184
|
}
|
|
4632
5185
|
|
|
4633
5186
|
// encode audio features starting at offset seek
|
|
4634
5187
|
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4635
|
-
|
|
5188
|
+
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
4636
5189
|
return -6;
|
|
4637
5190
|
}
|
|
4638
5191
|
|
|
@@ -4668,14 +5221,12 @@ int whisper_full_with_state(
|
|
|
4668
5221
|
|
|
4669
5222
|
n_decoders_cur = std::max(1, n_decoders_cur);
|
|
4670
5223
|
|
|
4671
|
-
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
|
5224
|
+
WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
|
4672
5225
|
|
|
4673
5226
|
// TAGS: WHISPER_DECODER_INIT
|
|
4674
5227
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4675
5228
|
auto & decoder = state->decoders[j];
|
|
4676
5229
|
|
|
4677
|
-
decoder.kv_self.n = 0;
|
|
4678
|
-
|
|
4679
5230
|
decoder.sequence.tokens.clear();
|
|
4680
5231
|
decoder.sequence.result_len = 0;
|
|
4681
5232
|
decoder.sequence.sum_logprobs_all = 0.0;
|
|
@@ -4689,10 +5240,16 @@ int whisper_full_with_state(
|
|
|
4689
5240
|
decoder.failed = false;
|
|
4690
5241
|
decoder.completed = false;
|
|
4691
5242
|
decoder.has_ts = false;
|
|
5243
|
+
|
|
5244
|
+
if (params.grammar_rules != nullptr) {
|
|
5245
|
+
decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
|
|
5246
|
+
} else {
|
|
5247
|
+
decoder.grammar = {};
|
|
5248
|
+
}
|
|
4692
5249
|
}
|
|
4693
5250
|
|
|
4694
5251
|
// init prompt and kv cache for the current iteration
|
|
4695
|
-
//
|
|
5252
|
+
// TODO: do not recompute the prompt if it is the same as previous time
|
|
4696
5253
|
{
|
|
4697
5254
|
prompt.clear();
|
|
4698
5255
|
|
|
@@ -4714,25 +5271,26 @@ int whisper_full_with_state(
|
|
|
4714
5271
|
}
|
|
4715
5272
|
WHISPER_PRINT_DEBUG("\n\n");
|
|
4716
5273
|
|
|
4717
|
-
|
|
4718
|
-
|
|
5274
|
+
whisper_kv_cache_clear(state->kv_self);
|
|
5275
|
+
|
|
5276
|
+
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
|
5277
|
+
|
|
5278
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5279
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
4719
5280
|
return -7;
|
|
4720
5281
|
}
|
|
4721
5282
|
|
|
4722
5283
|
{
|
|
4723
5284
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4724
5285
|
|
|
4725
|
-
|
|
5286
|
+
state->decoders[0].i_batch = prompt.size() - 1;
|
|
4726
5287
|
|
|
4727
|
-
state->decoders[0]
|
|
5288
|
+
whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
|
|
4728
5289
|
|
|
4729
5290
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
4730
5291
|
auto & decoder = state->decoders[j];
|
|
4731
5292
|
|
|
4732
|
-
|
|
4733
|
-
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
|
|
4734
|
-
|
|
4735
|
-
decoder.kv_self.n += prompt.size();
|
|
5293
|
+
whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
|
|
4736
5294
|
|
|
4737
5295
|
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
4738
5296
|
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
@@ -4747,41 +5305,81 @@ int whisper_full_with_state(
|
|
|
4747
5305
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4748
5306
|
|
|
4749
5307
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
4750
|
-
|
|
5308
|
+
for (auto & bc : bc_per_dec) {
|
|
5309
|
+
bc.clear();
|
|
5310
|
+
}
|
|
4751
5311
|
}
|
|
4752
5312
|
|
|
4753
|
-
//
|
|
4754
|
-
|
|
4755
|
-
|
|
5313
|
+
// sampling
|
|
5314
|
+
// TODO: avoid memory allocations, optimize, avoid threads?
|
|
5315
|
+
{
|
|
5316
|
+
std::atomic<int> j_cur(0);
|
|
4756
5317
|
|
|
4757
|
-
|
|
4758
|
-
|
|
4759
|
-
|
|
5318
|
+
auto process = [&]() {
|
|
5319
|
+
while (true) {
|
|
5320
|
+
const int j = j_cur.fetch_add(1);
|
|
4760
5321
|
|
|
4761
|
-
|
|
4762
|
-
|
|
4763
|
-
|
|
4764
|
-
if (t_cur < 1e-6f) {
|
|
4765
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
|
|
4766
|
-
} else {
|
|
4767
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
|
|
4768
|
-
}
|
|
5322
|
+
if (j >= n_decoders_cur) {
|
|
5323
|
+
break;
|
|
5324
|
+
}
|
|
4769
5325
|
|
|
4770
|
-
|
|
4771
|
-
} break;
|
|
4772
|
-
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
4773
|
-
{
|
|
4774
|
-
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
|
|
5326
|
+
auto & decoder = state->decoders[j];
|
|
4775
5327
|
|
|
4776
|
-
|
|
4777
|
-
|
|
4778
|
-
|
|
4779
|
-
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
|
|
5328
|
+
if (decoder.completed || decoder.failed) {
|
|
5329
|
+
continue;
|
|
5330
|
+
}
|
|
4780
5331
|
|
|
4781
|
-
|
|
4782
|
-
|
|
4783
|
-
|
|
5332
|
+
switch (params.strategy) {
|
|
5333
|
+
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
5334
|
+
{
|
|
5335
|
+
if (t_cur < 1e-6f) {
|
|
5336
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
|
5337
|
+
} else {
|
|
5338
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
|
5339
|
+
}
|
|
5340
|
+
|
|
5341
|
+
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
|
5342
|
+
} break;
|
|
5343
|
+
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
5344
|
+
{
|
|
5345
|
+
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
|
5346
|
+
|
|
5347
|
+
for (const auto & token : tokens_new) {
|
|
5348
|
+
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
|
|
5349
|
+
bc_per_dec[j].back().sequence.tokens.push_back(token);
|
|
5350
|
+
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
|
|
5351
|
+
}
|
|
5352
|
+
} break;
|
|
5353
|
+
};
|
|
5354
|
+
}
|
|
4784
5355
|
};
|
|
5356
|
+
|
|
5357
|
+
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
|
5358
|
+
|
|
5359
|
+
if (n_threads == 1) {
|
|
5360
|
+
process();
|
|
5361
|
+
} else {
|
|
5362
|
+
std::vector<std::thread> threads(n_threads - 1);
|
|
5363
|
+
|
|
5364
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5365
|
+
threads[t] = std::thread(process);
|
|
5366
|
+
}
|
|
5367
|
+
|
|
5368
|
+
process();
|
|
5369
|
+
|
|
5370
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5371
|
+
threads[t].join();
|
|
5372
|
+
}
|
|
5373
|
+
}
|
|
5374
|
+
}
|
|
5375
|
+
|
|
5376
|
+
beam_candidates.clear();
|
|
5377
|
+
for (const auto & bc : bc_per_dec) {
|
|
5378
|
+
beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
|
|
5379
|
+
|
|
5380
|
+
if (!bc.empty()) {
|
|
5381
|
+
state->n_sample += 1;
|
|
5382
|
+
}
|
|
4785
5383
|
}
|
|
4786
5384
|
|
|
4787
5385
|
// for beam-search, choose the top candidates and update the KV caches
|
|
@@ -4794,7 +5392,6 @@ int whisper_full_with_state(
|
|
|
4794
5392
|
});
|
|
4795
5393
|
|
|
4796
5394
|
uint32_t cur_c = 0;
|
|
4797
|
-
std::vector<int> decoder_idx(n_decoders_cur, -1);
|
|
4798
5395
|
|
|
4799
5396
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4800
5397
|
auto & decoder = state->decoders[j];
|
|
@@ -4803,23 +5400,38 @@ int whisper_full_with_state(
|
|
|
4803
5400
|
continue;
|
|
4804
5401
|
}
|
|
4805
5402
|
|
|
5403
|
+
if (cur_c >= beam_candidates.size()) {
|
|
5404
|
+
cur_c = 0;
|
|
5405
|
+
}
|
|
5406
|
+
|
|
4806
5407
|
auto & cur = beam_candidates[cur_c++];
|
|
4807
5408
|
|
|
4808
5409
|
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
|
|
4809
5410
|
++cur_c;
|
|
4810
5411
|
}
|
|
4811
5412
|
|
|
4812
|
-
decoder.sequence = cur.sequence;
|
|
4813
5413
|
decoder.seek_delta = cur.seek_delta;
|
|
4814
5414
|
decoder.has_ts = cur.has_ts;
|
|
5415
|
+
decoder.sequence = cur.sequence;
|
|
5416
|
+
decoder.grammar = cur.grammar;
|
|
5417
|
+
|
|
5418
|
+
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
|
4815
5419
|
|
|
4816
|
-
decoder_idx[j] = cur.decoder_idx;
|
|
4817
5420
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
4818
5421
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
4819
5422
|
}
|
|
4820
5423
|
|
|
4821
|
-
|
|
4822
|
-
|
|
5424
|
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
5425
|
+
auto & decoder = state->decoders[j];
|
|
5426
|
+
|
|
5427
|
+
if (decoder.completed || decoder.failed) {
|
|
5428
|
+
continue;
|
|
5429
|
+
}
|
|
5430
|
+
|
|
5431
|
+
whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
|
|
5432
|
+
whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
|
|
5433
|
+
whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
|
|
5434
|
+
}
|
|
4823
5435
|
}
|
|
4824
5436
|
|
|
4825
5437
|
// update the decoder state
|
|
@@ -4848,6 +5460,7 @@ int whisper_full_with_state(
|
|
|
4848
5460
|
|
|
4849
5461
|
// do not allow to go back in time
|
|
4850
5462
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
5463
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
|
|
4851
5464
|
failed = true; // TODO: maybe this is not a failure ?
|
|
4852
5465
|
continue;
|
|
4853
5466
|
}
|
|
@@ -4857,6 +5470,8 @@ int whisper_full_with_state(
|
|
|
4857
5470
|
has_ts = true;
|
|
4858
5471
|
}
|
|
4859
5472
|
|
|
5473
|
+
whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
|
|
5474
|
+
|
|
4860
5475
|
#ifdef WHISPER_DEBUG
|
|
4861
5476
|
{
|
|
4862
5477
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
@@ -4874,6 +5489,7 @@ int whisper_full_with_state(
|
|
|
4874
5489
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
4875
5490
|
result_len = i + 1;
|
|
4876
5491
|
} else {
|
|
5492
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
|
4877
5493
|
failed = true;
|
|
4878
5494
|
continue;
|
|
4879
5495
|
}
|
|
@@ -4884,6 +5500,7 @@ int whisper_full_with_state(
|
|
|
4884
5500
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
4885
5501
|
}
|
|
4886
5502
|
|
|
5503
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
|
|
4887
5504
|
completed = true;
|
|
4888
5505
|
continue;
|
|
4889
5506
|
}
|
|
@@ -4899,6 +5516,7 @@ int whisper_full_with_state(
|
|
|
4899
5516
|
// sometimes, the decoding can get stuck in a repetition loop
|
|
4900
5517
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
4901
5518
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
5519
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
|
|
4902
5520
|
failed = true;
|
|
4903
5521
|
continue;
|
|
4904
5522
|
}
|
|
@@ -4926,32 +5544,83 @@ int whisper_full_with_state(
|
|
|
4926
5544
|
state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
|
|
4927
5545
|
|
|
4928
5546
|
// obtain logits for the next token
|
|
4929
|
-
|
|
4930
|
-
auto &
|
|
5547
|
+
{
|
|
5548
|
+
auto & batch = state->batch;
|
|
4931
5549
|
|
|
4932
|
-
|
|
4933
|
-
|
|
4934
|
-
|
|
5550
|
+
batch.n_tokens = 0;
|
|
5551
|
+
|
|
5552
|
+
const int n_past = prompt.size() + i;
|
|
5553
|
+
|
|
5554
|
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
5555
|
+
auto & decoder = state->decoders[j];
|
|
5556
|
+
|
|
5557
|
+
if (decoder.failed || decoder.completed) {
|
|
5558
|
+
continue;
|
|
5559
|
+
}
|
|
4935
5560
|
|
|
4936
|
-
|
|
4937
|
-
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
|
|
5561
|
+
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
|
4938
5562
|
|
|
4939
|
-
|
|
5563
|
+
decoder.i_batch = batch.n_tokens;
|
|
4940
5564
|
|
|
4941
|
-
|
|
4942
|
-
|
|
5565
|
+
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
|
|
5566
|
+
batch.pos [batch.n_tokens] = n_past;
|
|
5567
|
+
batch.n_seq_id[batch.n_tokens] = 1;
|
|
5568
|
+
batch.seq_id [batch.n_tokens][0] = j;
|
|
5569
|
+
batch.logits [batch.n_tokens] = 1;
|
|
5570
|
+
batch.n_tokens++;
|
|
5571
|
+
}
|
|
5572
|
+
|
|
5573
|
+
assert(batch.n_tokens > 0);
|
|
5574
|
+
|
|
5575
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5576
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
4943
5577
|
return -8;
|
|
4944
5578
|
}
|
|
4945
5579
|
|
|
5580
|
+
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
5581
|
+
|
|
5582
|
+
// TODO: avoid memory allocations, optimize, avoid threads?
|
|
4946
5583
|
{
|
|
4947
|
-
|
|
5584
|
+
std::atomic<int> j_cur(0);
|
|
5585
|
+
|
|
5586
|
+
auto process = [&]() {
|
|
5587
|
+
while (true) {
|
|
5588
|
+
const int j = j_cur.fetch_add(1);
|
|
5589
|
+
|
|
5590
|
+
if (j >= n_decoders_cur) {
|
|
5591
|
+
break;
|
|
5592
|
+
}
|
|
4948
5593
|
|
|
4949
|
-
|
|
5594
|
+
auto & decoder = state->decoders[j];
|
|
4950
5595
|
|
|
4951
|
-
|
|
5596
|
+
if (decoder.failed || decoder.completed) {
|
|
5597
|
+
continue;
|
|
5598
|
+
}
|
|
5599
|
+
|
|
5600
|
+
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
|
|
5601
|
+
}
|
|
5602
|
+
};
|
|
5603
|
+
|
|
5604
|
+
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
|
5605
|
+
|
|
5606
|
+
if (n_threads == 1) {
|
|
5607
|
+
process();
|
|
5608
|
+
} else {
|
|
5609
|
+
std::vector<std::thread> threads(n_threads - 1);
|
|
5610
|
+
|
|
5611
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5612
|
+
threads[t] = std::thread(process);
|
|
5613
|
+
}
|
|
5614
|
+
|
|
5615
|
+
process();
|
|
4952
5616
|
|
|
4953
|
-
|
|
5617
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5618
|
+
threads[t].join();
|
|
5619
|
+
}
|
|
5620
|
+
}
|
|
4954
5621
|
}
|
|
5622
|
+
|
|
5623
|
+
state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
|
|
4955
5624
|
}
|
|
4956
5625
|
}
|
|
4957
5626
|
|
|
@@ -4991,28 +5660,27 @@ int whisper_full_with_state(
|
|
|
4991
5660
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
4992
5661
|
}
|
|
4993
5662
|
|
|
5663
|
+
bool success = true;
|
|
5664
|
+
|
|
4994
5665
|
// was the decoding successful for the current temperature?
|
|
4995
5666
|
// do fallback only if:
|
|
4996
5667
|
// - we are not at the last temperature
|
|
4997
|
-
|
|
4998
|
-
if (it != (int) temperatures.size() - 1 &&
|
|
4999
|
-
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
|
|
5000
|
-
bool success = true;
|
|
5001
|
-
|
|
5668
|
+
if (it != (int) temperatures.size() - 1) {
|
|
5002
5669
|
const auto & decoder = state->decoders[best_decoder_id];
|
|
5003
5670
|
|
|
5004
5671
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
|
5672
|
+
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
|
5005
5673
|
success = false;
|
|
5006
5674
|
state->n_fail_p++;
|
|
5007
5675
|
}
|
|
5676
|
+
}
|
|
5008
5677
|
|
|
5009
|
-
|
|
5010
|
-
|
|
5011
|
-
|
|
5012
|
-
|
|
5678
|
+
if (success) {
|
|
5679
|
+
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
5680
|
+
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
5681
|
+
//}
|
|
5013
5682
|
|
|
5014
|
-
|
|
5015
|
-
}
|
|
5683
|
+
break;
|
|
5016
5684
|
}
|
|
5017
5685
|
|
|
5018
5686
|
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
@@ -5248,11 +5916,13 @@ int whisper_full_parallel(
|
|
|
5248
5916
|
ctx->state->t_sample_us += states[i]->t_sample_us;
|
|
5249
5917
|
ctx->state->t_encode_us += states[i]->t_encode_us;
|
|
5250
5918
|
ctx->state->t_decode_us += states[i]->t_decode_us;
|
|
5919
|
+
ctx->state->t_batchd_us += states[i]->t_batchd_us;
|
|
5251
5920
|
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
|
5252
5921
|
|
|
5253
5922
|
ctx->state->n_sample += states[i]->n_sample;
|
|
5254
5923
|
ctx->state->n_encode += states[i]->n_encode;
|
|
5255
5924
|
ctx->state->n_decode += states[i]->n_decode;
|
|
5925
|
+
ctx->state->n_batchd += states[i]->n_batchd;
|
|
5256
5926
|
ctx->state->n_prompt += states[i]->n_prompt;
|
|
5257
5927
|
|
|
5258
5928
|
whisper_free_state(states[i]);
|
|
@@ -5265,12 +5935,12 @@ int whisper_full_parallel(
|
|
|
5265
5935
|
ctx->state->t_decode_us /= n_processors;
|
|
5266
5936
|
|
|
5267
5937
|
// print information about the audio boundaries
|
|
5268
|
-
|
|
5269
|
-
|
|
5938
|
+
WHISPER_LOG_WARN("\n");
|
|
5939
|
+
WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
|
|
5270
5940
|
for (int i = 0; i < n_processors - 1; ++i) {
|
|
5271
|
-
|
|
5941
|
+
WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
|
|
5272
5942
|
}
|
|
5273
|
-
|
|
5943
|
+
WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
|
|
5274
5944
|
|
|
5275
5945
|
return ret;
|
|
5276
5946
|
}
|
|
@@ -5385,8 +6055,45 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5385
6055
|
size_t n = 20;
|
|
5386
6056
|
size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
|
|
5387
6057
|
|
|
5388
|
-
// 1GB
|
|
5389
|
-
const size_t size = arr*
|
|
6058
|
+
// 1GB array
|
|
6059
|
+
const size_t size = arr*1e6;
|
|
6060
|
+
|
|
6061
|
+
double sum = 0.0;
|
|
6062
|
+
|
|
6063
|
+
// heat-up
|
|
6064
|
+
{
|
|
6065
|
+
char * src = (char *) malloc(size);
|
|
6066
|
+
char * dst = (char *) malloc(size);
|
|
6067
|
+
|
|
6068
|
+
for (size_t i = 0; i < size; i++) src[i] = i;
|
|
6069
|
+
|
|
6070
|
+
memcpy(dst, src, size); // heat-up
|
|
6071
|
+
|
|
6072
|
+
double tsum = 0.0;
|
|
6073
|
+
|
|
6074
|
+
for (size_t i = 0; i < n; i++) {
|
|
6075
|
+
const int64_t t0 = wsp_ggml_time_us();
|
|
6076
|
+
|
|
6077
|
+
memcpy(dst, src, size);
|
|
6078
|
+
|
|
6079
|
+
const int64_t t1 = wsp_ggml_time_us();
|
|
6080
|
+
|
|
6081
|
+
tsum += (t1 - t0)*1e-6;
|
|
6082
|
+
|
|
6083
|
+
src[rand() % size] = rand() % 256;
|
|
6084
|
+
}
|
|
6085
|
+
|
|
6086
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
|
|
6087
|
+
s += strbuf;
|
|
6088
|
+
|
|
6089
|
+
// needed to prevent the compiler from optimizing the memcpy away
|
|
6090
|
+
{
|
|
6091
|
+
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
6092
|
+
}
|
|
6093
|
+
|
|
6094
|
+
free(src);
|
|
6095
|
+
free(dst);
|
|
6096
|
+
}
|
|
5390
6097
|
|
|
5391
6098
|
// single-thread
|
|
5392
6099
|
{
|
|
@@ -5398,7 +6105,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5398
6105
|
memcpy(dst, src, size); // heat-up
|
|
5399
6106
|
|
|
5400
6107
|
double tsum = 0.0;
|
|
5401
|
-
double sum = 0.0;
|
|
5402
6108
|
|
|
5403
6109
|
for (size_t i = 0; i < n; i++) {
|
|
5404
6110
|
const int64_t t0 = wsp_ggml_time_us();
|
|
@@ -5412,21 +6118,73 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5412
6118
|
src[rand() % size] = rand() % 256;
|
|
5413
6119
|
}
|
|
5414
6120
|
|
|
5415
|
-
snprintf(strbuf, sizeof(strbuf), "memcpy:
|
|
6121
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
|
|
5416
6122
|
s += strbuf;
|
|
5417
6123
|
|
|
5418
6124
|
// needed to prevent the compiler from optimizing the memcpy away
|
|
5419
6125
|
{
|
|
5420
6126
|
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
6127
|
+
}
|
|
6128
|
+
|
|
6129
|
+
free(src);
|
|
6130
|
+
free(dst);
|
|
6131
|
+
}
|
|
6132
|
+
|
|
6133
|
+
// multi-thread
|
|
6134
|
+
|
|
6135
|
+
for (uint32_t k = 1; k <= n_threads; k++) {
|
|
6136
|
+
char * src = (char *) malloc(size);
|
|
6137
|
+
char * dst = (char *) malloc(size);
|
|
6138
|
+
|
|
6139
|
+
for (size_t i = 0; i < size; i++) src[i] = i;
|
|
6140
|
+
|
|
6141
|
+
memcpy(dst, src, size); // heat-up
|
|
6142
|
+
|
|
6143
|
+
double tsum = 0.0;
|
|
6144
|
+
|
|
6145
|
+
auto helper = [&](int th) {
|
|
6146
|
+
const int64_t i0 = (th + 0)*size/k;
|
|
6147
|
+
const int64_t i1 = (th + 1)*size/k;
|
|
6148
|
+
|
|
6149
|
+
for (size_t i = 0; i < n; i++) {
|
|
6150
|
+
memcpy(dst + i0, src + i0, i1 - i0);
|
|
6151
|
+
|
|
6152
|
+
src[i0 + rand() % (i1 - i0)] = rand() % 256;
|
|
6153
|
+
};
|
|
6154
|
+
};
|
|
6155
|
+
|
|
6156
|
+
const int64_t t0 = wsp_ggml_time_us();
|
|
6157
|
+
|
|
6158
|
+
std::vector<std::thread> threads(k - 1);
|
|
6159
|
+
for (uint32_t th = 0; th < k - 1; ++th) {
|
|
6160
|
+
threads[th] = std::thread(helper, th);
|
|
6161
|
+
}
|
|
6162
|
+
|
|
6163
|
+
helper(k - 1);
|
|
6164
|
+
|
|
6165
|
+
for (uint32_t th = 0; th < k - 1; ++th) {
|
|
6166
|
+
threads[th].join();
|
|
6167
|
+
}
|
|
5421
6168
|
|
|
5422
|
-
|
|
5423
|
-
|
|
6169
|
+
const int64_t t1 = wsp_ggml_time_us();
|
|
6170
|
+
|
|
6171
|
+
tsum += (t1 - t0)*1e-6;
|
|
6172
|
+
|
|
6173
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
|
|
6174
|
+
s += strbuf;
|
|
6175
|
+
|
|
6176
|
+
// needed to prevent the compiler from optimizing the memcpy away
|
|
6177
|
+
{
|
|
6178
|
+
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
5424
6179
|
}
|
|
5425
6180
|
|
|
5426
6181
|
free(src);
|
|
5427
6182
|
free(dst);
|
|
5428
6183
|
}
|
|
5429
6184
|
|
|
6185
|
+
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
|
|
6186
|
+
s += strbuf;
|
|
6187
|
+
|
|
5430
6188
|
return s.c_str();
|
|
5431
6189
|
}
|
|
5432
6190
|
|
|
@@ -5454,7 +6212,7 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5454
6212
|
// b: N*N*sizeof(float)
|
|
5455
6213
|
// c: N*N*sizeof(float)
|
|
5456
6214
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
|
5457
|
-
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
|
|
6215
|
+
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());
|
|
5458
6216
|
std::vector<uint8_t> work;
|
|
5459
6217
|
|
|
5460
6218
|
// put a bunch of random data in the buffer
|
|
@@ -5505,17 +6263,19 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5505
6263
|
|
|
5506
6264
|
struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b);
|
|
5507
6265
|
|
|
5508
|
-
struct wsp_ggml_cgraph gf =
|
|
6266
|
+
struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
|
|
6267
|
+
|
|
6268
|
+
wsp_ggml_build_forward_expand(gf, c);
|
|
5509
6269
|
|
|
5510
6270
|
double tsum = 0.0;
|
|
5511
6271
|
|
|
5512
6272
|
// heat-up
|
|
5513
|
-
wsp_ggml_graph_compute_helper(
|
|
6273
|
+
wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
|
|
5514
6274
|
|
|
5515
6275
|
for (int i = 0; i < n_max; ++i) {
|
|
5516
6276
|
const int64_t t0 = wsp_ggml_time_us();
|
|
5517
6277
|
|
|
5518
|
-
wsp_ggml_graph_compute_helper(
|
|
6278
|
+
wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
|
|
5519
6279
|
|
|
5520
6280
|
const int64_t t1 = wsp_ggml_time_us();
|
|
5521
6281
|
|
|
@@ -5633,7 +6393,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
5633
6393
|
const int n_samples = state.energy.size();
|
|
5634
6394
|
|
|
5635
6395
|
if (n_samples == 0) {
|
|
5636
|
-
|
|
6396
|
+
WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
|
|
5637
6397
|
return;
|
|
5638
6398
|
}
|
|
5639
6399
|
|
|
@@ -5854,6 +6614,32 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
5854
6614
|
//}
|
|
5855
6615
|
}
|
|
5856
6616
|
|
|
5857
|
-
void
|
|
5858
|
-
|
|
6617
|
+
// Install the global log sink used by whisper's logging helpers.
// A null `log_callback` restores the built-in stderr logger; `user_data` is stored
// verbatim and handed back to the callback on every message.
void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
    if (log_callback == nullptr) {
        g_state.log_callback = whisper_log_callback_default;
    } else {
        g_state.log_callback = log_callback;
    }
    g_state.log_callback_user_data = user_data;
}
|
|
6621
|
+
|
|
6622
|
+
// Format a printf-style message and forward it, with `level`, to the registered
// log callback (g_state.log_callback / g_state.log_callback_user_data).
// Messages that fit in a 1024-byte stack buffer are formatted once; longer ones are
// formatted a second time into a heap buffer sized from the first pass's result.
WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
static void whisper_log_internal(wsp_ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);

    // vsnprintf consumes `args` (its value is indeterminate afterwards), so keep an
    // untouched copy for the possible second formatting pass.
    va_list args_copy;
    va_copy(args_copy, args);

    char buffer[1024];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // formatting failed (e.g. encoding error); the buffer contents are
        // unspecified, so do not forward anything
    } else if (len < (int) sizeof(buffer)) {
        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
    } else {
        // message was truncated: re-format into a buffer of the exact required size
        std::vector<char> buffer2(len + 1);
        vsnprintf(buffer2.data(), buffer2.size(), format, args_copy);
        buffer2[len] = 0;
        g_state.log_callback(level, buffer2.data(), g_state.log_callback_user_data);
    }

    va_end(args_copy);
    va_end(args);
}
|
|
6639
|
+
|
|
6640
|
+
// Built-in log sink: writes the already-formatted text to stderr and flushes
// immediately. The severity level and user data are intentionally ignored.
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
    static_cast<void>(level);
    static_cast<void>(user_data);

    FILE * const sink = stderr;
    fputs(text, sink);
    fflush(sink);
}
|