whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +187 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +7 -0
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1010 -253
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +618 -187
- package/cpp/ggml-quants.c +64 -59
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +751 -1466
- package/cpp/ggml.h +90 -25
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +62 -134
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +6 -5
- package/src/version.json +1 -1
package/cpp/whisper.cpp
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
1
|
#include "whisper.h"
|
|
2
|
+
|
|
2
3
|
#ifdef WHISPER_USE_COREML
|
|
3
4
|
#include "coreml/whisper-encoder.h"
|
|
4
5
|
#endif
|
|
5
6
|
|
|
6
7
|
#ifdef WSP_GGML_USE_METAL
|
|
7
|
-
#
|
|
8
|
+
#include "ggml-metal.h"
|
|
9
|
+
#endif
|
|
10
|
+
|
|
11
|
+
#ifdef WSP_GGML_USE_CUBLAS
|
|
12
|
+
#include "ggml-cuda.h"
|
|
8
13
|
#endif
|
|
9
14
|
|
|
10
15
|
#ifdef WHISPER_USE_OPENVINO
|
|
@@ -13,7 +18,9 @@
|
|
|
13
18
|
|
|
14
19
|
#include "ggml.h"
|
|
15
20
|
#include "ggml-alloc.h"
|
|
21
|
+
#include "ggml-backend.h"
|
|
16
22
|
|
|
23
|
+
#include <atomic>
|
|
17
24
|
#include <algorithm>
|
|
18
25
|
#include <cassert>
|
|
19
26
|
#define _USE_MATH_DEFINES
|
|
@@ -97,10 +104,32 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
97
104
|
#define BYTESWAP_TENSOR(t) do {} while (0)
|
|
98
105
|
#endif
|
|
99
106
|
|
|
107
|
+
#ifdef __GNUC__
|
|
108
|
+
#ifdef __MINGW32__
|
|
109
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
|
110
|
+
#else
|
|
111
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
|
112
|
+
#endif
|
|
113
|
+
#else
|
|
114
|
+
#define WHISPER_ATTRIBUTE_FORMAT(...)
|
|
115
|
+
#endif
|
|
116
|
+
|
|
117
|
+
//
|
|
118
|
+
// logging
|
|
119
|
+
//
|
|
120
|
+
|
|
121
|
+
WHISPER_ATTRIBUTE_FORMAT(2, 3)
|
|
122
|
+
static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...);
|
|
123
|
+
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);
|
|
124
|
+
|
|
125
|
+
#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
|
126
|
+
#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
127
|
+
#define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
128
|
+
|
|
100
129
|
#define WHISPER_ASSERT(x) \
|
|
101
130
|
do { \
|
|
102
131
|
if (!(x)) { \
|
|
103
|
-
|
|
132
|
+
WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
|
104
133
|
abort(); \
|
|
105
134
|
} \
|
|
106
135
|
} while (0)
|
|
@@ -119,7 +148,7 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
119
148
|
|
|
120
149
|
//#define WHISPER_USE_FLASH_ATTN
|
|
121
150
|
//#define WHISPER_USE_FLASH_FF
|
|
122
|
-
#define WHISPER_MAX_DECODERS
|
|
151
|
+
#define WHISPER_MAX_DECODERS 8
|
|
123
152
|
#define WHISPER_MAX_NODES 4096
|
|
124
153
|
|
|
125
154
|
//
|
|
@@ -127,8 +156,8 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
127
156
|
//
|
|
128
157
|
|
|
129
158
|
static void wsp_ggml_graph_compute_helper(
|
|
159
|
+
struct wsp_ggml_cgraph * graph,
|
|
130
160
|
std::vector<uint8_t> & buf,
|
|
131
|
-
wsp_ggml_cgraph * graph,
|
|
132
161
|
int n_threads,
|
|
133
162
|
whisper_abort_callback abort_callback,
|
|
134
163
|
void * abort_callback_data) {
|
|
@@ -145,6 +174,21 @@ static void wsp_ggml_graph_compute_helper(
|
|
|
145
174
|
wsp_ggml_graph_compute(graph, &plan);
|
|
146
175
|
}
|
|
147
176
|
|
|
177
|
+
static void wsp_ggml_graph_compute_helper(
|
|
178
|
+
struct wsp_ggml_backend * backend,
|
|
179
|
+
struct wsp_ggml_cgraph * graph,
|
|
180
|
+
int n_threads) {
|
|
181
|
+
if (wsp_ggml_backend_is_cpu(backend)) {
|
|
182
|
+
wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
|
|
183
|
+
}
|
|
184
|
+
#ifdef WSP_GGML_USE_METAL
|
|
185
|
+
if (wsp_ggml_backend_is_metal(backend)) {
|
|
186
|
+
wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
|
|
187
|
+
}
|
|
188
|
+
#endif
|
|
189
|
+
wsp_ggml_backend_graph_compute(backend, graph);
|
|
190
|
+
}
|
|
191
|
+
|
|
148
192
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
149
193
|
// the idea is to represent the original matrix multiplication:
|
|
150
194
|
//
|
|
@@ -179,6 +223,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * c
|
|
|
179
223
|
}
|
|
180
224
|
|
|
181
225
|
// TODO: check if other platforms can benefit from this optimization
|
|
226
|
+
// TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly
|
|
182
227
|
#if defined(WSP_GGML_USE_METAL)
|
|
183
228
|
#define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
|
|
184
229
|
#endif
|
|
@@ -305,75 +350,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
305
350
|
{ "yue", { 99, "cantonese", } },
|
|
306
351
|
};
|
|
307
352
|
|
|
308
|
-
static const size_t MB = 1ull*1024*1024;
|
|
309
|
-
|
|
310
|
-
// TODO: avoid using GGUF
|
|
311
|
-
static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
312
|
-
{ WSP_GGML_TYPE_F32,
|
|
313
|
-
{
|
|
314
|
-
{ MODEL_TINY, 74ull*MB },
|
|
315
|
-
{ MODEL_BASE, 142ull*MB },
|
|
316
|
-
{ MODEL_SMALL, 466ull*MB },
|
|
317
|
-
{ MODEL_MEDIUM, 1464ull*MB },
|
|
318
|
-
{ MODEL_LARGE, 2952ull*MB },
|
|
319
|
-
},
|
|
320
|
-
},
|
|
321
|
-
{ WSP_GGML_TYPE_F16,
|
|
322
|
-
{
|
|
323
|
-
{ MODEL_TINY, 74ull*MB },
|
|
324
|
-
{ MODEL_BASE, 142ull*MB },
|
|
325
|
-
{ MODEL_SMALL, 466ull*MB },
|
|
326
|
-
{ MODEL_MEDIUM, 1464ull*MB },
|
|
327
|
-
{ MODEL_LARGE, 2952ull*MB },
|
|
328
|
-
},
|
|
329
|
-
},
|
|
330
|
-
{ WSP_GGML_TYPE_Q4_0,
|
|
331
|
-
{
|
|
332
|
-
{ MODEL_TINY, 26ull*MB },
|
|
333
|
-
{ MODEL_BASE, 50ull*MB },
|
|
334
|
-
{ MODEL_SMALL, 154ull*MB },
|
|
335
|
-
{ MODEL_MEDIUM, 470ull*MB },
|
|
336
|
-
{ MODEL_LARGE, 940ull*MB },
|
|
337
|
-
},
|
|
338
|
-
},
|
|
339
|
-
{ WSP_GGML_TYPE_Q4_1,
|
|
340
|
-
{
|
|
341
|
-
{ MODEL_TINY, 32ull*MB },
|
|
342
|
-
{ MODEL_BASE, 58ull*MB },
|
|
343
|
-
{ MODEL_SMALL, 182ull*MB },
|
|
344
|
-
{ MODEL_MEDIUM, 562ull*MB },
|
|
345
|
-
{ MODEL_LARGE, 1124ull*MB },
|
|
346
|
-
},
|
|
347
|
-
},
|
|
348
|
-
{ WSP_GGML_TYPE_Q5_0,
|
|
349
|
-
{
|
|
350
|
-
{ MODEL_TINY, 30ull*MB },
|
|
351
|
-
{ MODEL_BASE, 54ull*MB },
|
|
352
|
-
{ MODEL_SMALL, 170ull*MB },
|
|
353
|
-
{ MODEL_MEDIUM, 516ull*MB },
|
|
354
|
-
{ MODEL_LARGE, 1034ull*MB },
|
|
355
|
-
},
|
|
356
|
-
},
|
|
357
|
-
{ WSP_GGML_TYPE_Q5_1,
|
|
358
|
-
{
|
|
359
|
-
{ MODEL_TINY, 32ull*MB },
|
|
360
|
-
{ MODEL_BASE, 58ull*MB },
|
|
361
|
-
{ MODEL_SMALL, 182ull*MB },
|
|
362
|
-
{ MODEL_MEDIUM, 562ull*MB },
|
|
363
|
-
{ MODEL_LARGE, 1124ull*MB },
|
|
364
|
-
},
|
|
365
|
-
},
|
|
366
|
-
{ WSP_GGML_TYPE_Q8_0,
|
|
367
|
-
{
|
|
368
|
-
{ MODEL_TINY, 45ull*MB },
|
|
369
|
-
{ MODEL_BASE, 84ull*MB },
|
|
370
|
-
{ MODEL_SMALL, 268ull*MB },
|
|
371
|
-
{ MODEL_MEDIUM, 834ull*MB },
|
|
372
|
-
{ MODEL_LARGE, 1674ull*MB },
|
|
373
|
-
},
|
|
374
|
-
},
|
|
375
|
-
};
|
|
376
|
-
|
|
377
353
|
struct whisper_mel {
|
|
378
354
|
int n_len;
|
|
379
355
|
int n_len_org;
|
|
@@ -431,6 +407,121 @@ struct whisper_segment {
|
|
|
431
407
|
bool speaker_turn_next;
|
|
432
408
|
};
|
|
433
409
|
|
|
410
|
+
struct whisper_batch {
|
|
411
|
+
int32_t n_tokens;
|
|
412
|
+
|
|
413
|
+
whisper_token * token;
|
|
414
|
+
whisper_pos * pos;
|
|
415
|
+
int32_t * n_seq_id;
|
|
416
|
+
whisper_seq_id ** seq_id; // null terminated
|
|
417
|
+
int8_t * logits;
|
|
418
|
+
};
|
|
419
|
+
|
|
420
|
+
static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
|
|
421
|
+
whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
|
|
422
|
+
|
|
423
|
+
batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens));
|
|
424
|
+
batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens));
|
|
425
|
+
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
|
|
426
|
+
batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
|
|
427
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
428
|
+
batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
|
|
429
|
+
}
|
|
430
|
+
batch.seq_id[n_tokens] = nullptr;
|
|
431
|
+
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
|
432
|
+
|
|
433
|
+
return batch;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
static void whisper_batch_free(struct whisper_batch batch) {
|
|
437
|
+
if (batch.token) free(batch.token);
|
|
438
|
+
if (batch.pos) free(batch.pos);
|
|
439
|
+
if (batch.n_seq_id) free(batch.n_seq_id);
|
|
440
|
+
if (batch.seq_id) {
|
|
441
|
+
for (int i = 0; batch.seq_id[i]; ++i) {
|
|
442
|
+
free(batch.seq_id[i]);
|
|
443
|
+
}
|
|
444
|
+
free(batch.seq_id);
|
|
445
|
+
}
|
|
446
|
+
if (batch.logits) free(batch.logits);
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
// Fill `batch` with `n_tokens` consecutive entries that all belong to the
// single sequence `seq_id`.
//
//   tokens   - token ids to copy into the batch; may be nullptr, in which case
//              batch.token is left untouched
//   n_tokens - number of entries to prepare; stored in batch.n_tokens
//   n_past   - position assigned to the first entry; entry i gets n_past + i
//   seq_id   - sequence id written to every entry
//
// The logits flag is cleared for every entry except the last one. The batch
// buffers must already be large enough for `n_tokens` entries.
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
    batch.n_tokens = n_tokens;

    for (int i = 0; i < n_tokens; ++i) {
        if (tokens) {
            batch.token[i] = tokens[i];
        }
        batch.pos     [i]    = n_past + i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq_id;
        batch.logits  [i]    = 0;
    }

    // with no entries there is no "last" one - writing logits[n_tokens - 1]
    // would land before the start of the array
    if (n_tokens > 0) {
        batch.logits[n_tokens - 1] = 1;
    }
}
|
|
462
|
+
|
|
463
|
+
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
464
|
+
template<typename A, typename B>
|
|
465
|
+
struct whisper_pair {
|
|
466
|
+
A first;
|
|
467
|
+
B second;
|
|
468
|
+
|
|
469
|
+
// Define a constructor that takes two arguments.
|
|
470
|
+
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
471
|
+
// Define a constructor that takes no argument.
|
|
472
|
+
whisper_pair() : first(A()), second(B()) {}
|
|
473
|
+
};
|
|
474
|
+
|
|
475
|
+
// wsp_ggml_allocr wrapper for whisper usage
|
|
476
|
+
struct whisper_allocr {
|
|
477
|
+
wsp_ggml_allocr * alloc = nullptr;
|
|
478
|
+
|
|
479
|
+
std::vector<uint8_t> meta;
|
|
480
|
+
|
|
481
|
+
wsp_ggml_backend_buffer_t buffer;
|
|
482
|
+
};
|
|
483
|
+
|
|
484
|
+
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
485
|
+
return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
489
|
+
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
490
|
+
auto & alloc = allocr.alloc;
|
|
491
|
+
auto & meta = allocr.meta;
|
|
492
|
+
|
|
493
|
+
alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
|
|
494
|
+
|
|
495
|
+
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
496
|
+
|
|
497
|
+
wsp_ggml_allocr_alloc_graph(alloc, get_graph());
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
|
|
501
|
+
if (allocr.alloc == nullptr) {
|
|
502
|
+
// this can be null if we use external encoder like CoreML or OpenVINO
|
|
503
|
+
return;
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
auto & alloc = allocr.alloc;
|
|
507
|
+
auto & buffer = allocr.buffer;
|
|
508
|
+
|
|
509
|
+
size_t size = wsp_ggml_allocr_max_size(alloc);
|
|
510
|
+
|
|
511
|
+
wsp_ggml_allocr_free(alloc);
|
|
512
|
+
|
|
513
|
+
buffer = wsp_ggml_backend_alloc_buffer(backend, size);
|
|
514
|
+
alloc = wsp_ggml_allocr_new_from_buffer(buffer);
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
518
|
+
if (allocr.alloc) {
|
|
519
|
+
wsp_ggml_allocr_free(allocr.alloc);
|
|
520
|
+
wsp_ggml_backend_buffer_free(allocr.buffer);
|
|
521
|
+
allocr.alloc = nullptr;
|
|
522
|
+
}
|
|
523
|
+
}
|
|
524
|
+
|
|
434
525
|
// medium
|
|
435
526
|
// hparams: {
|
|
436
527
|
// 'n_mels': 80,
|
|
@@ -548,16 +639,31 @@ struct whisper_layer_decoder {
|
|
|
548
639
|
struct wsp_ggml_tensor * mlp_1_b;
|
|
549
640
|
};
|
|
550
641
|
|
|
642
|
+
struct whisper_kv_cell {
|
|
643
|
+
whisper_pos pos = -1;
|
|
644
|
+
|
|
645
|
+
std::set<whisper_seq_id> seq_id;
|
|
646
|
+
|
|
647
|
+
bool has_seq_id(const whisper_seq_id & id) const {
|
|
648
|
+
return seq_id.find(id) != seq_id.end();
|
|
649
|
+
}
|
|
650
|
+
};
|
|
651
|
+
|
|
551
652
|
struct whisper_kv_cache {
|
|
653
|
+
uint32_t head = 0;
|
|
654
|
+
uint32_t size = 0;
|
|
655
|
+
|
|
656
|
+
// computed before each graph build
|
|
657
|
+
uint32_t n = 0;
|
|
658
|
+
|
|
659
|
+
std::vector<whisper_kv_cell> cells;
|
|
660
|
+
|
|
552
661
|
struct wsp_ggml_tensor * k;
|
|
553
662
|
struct wsp_ggml_tensor * v;
|
|
554
663
|
|
|
555
664
|
struct wsp_ggml_context * ctx;
|
|
556
665
|
|
|
557
|
-
|
|
558
|
-
std::vector<uint8_t> buf;
|
|
559
|
-
|
|
560
|
-
int n; // number of tokens currently in the cache
|
|
666
|
+
wsp_ggml_backend_buffer_t buffer;
|
|
561
667
|
};
|
|
562
668
|
|
|
563
669
|
struct whisper_model {
|
|
@@ -594,17 +700,36 @@ struct whisper_model {
|
|
|
594
700
|
std::vector<whisper_layer_encoder> layers_encoder;
|
|
595
701
|
std::vector<whisper_layer_decoder> layers_decoder;
|
|
596
702
|
|
|
597
|
-
// context
|
|
703
|
+
// ggml context that contains all the meta information about the model tensors
|
|
598
704
|
struct wsp_ggml_context * ctx;
|
|
599
705
|
|
|
600
|
-
// the model
|
|
601
|
-
|
|
706
|
+
// the model backend data is read-only and can be shared between processors
|
|
707
|
+
struct wsp_ggml_backend_buffer * buffer;
|
|
602
708
|
|
|
603
709
|
// tensors
|
|
604
710
|
int n_loaded;
|
|
605
711
|
std::map<std::string, struct wsp_ggml_tensor *> tensors;
|
|
606
712
|
};
|
|
607
713
|
|
|
714
|
+
struct whisper_partial_utf8 {
|
|
715
|
+
uint32_t value; // bit value so far (unshifted)
|
|
716
|
+
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
|
717
|
+
};
|
|
718
|
+
|
|
719
|
+
struct whisper_grammar {
|
|
720
|
+
/*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
|
|
721
|
+
std::vector<std::vector<const whisper_grammar_element *>> stacks;
|
|
722
|
+
|
|
723
|
+
// buffer for partially generated UTF-8 sequence from accepted tokens
|
|
724
|
+
whisper_partial_utf8 partial_utf8;
|
|
725
|
+
};
|
|
726
|
+
|
|
727
|
+
struct whisper_grammar_candidate {
|
|
728
|
+
whisper_token id;
|
|
729
|
+
const uint32_t * code_points;
|
|
730
|
+
whisper_partial_utf8 partial_utf8;
|
|
731
|
+
};
|
|
732
|
+
|
|
608
733
|
struct whisper_sequence {
|
|
609
734
|
std::vector<whisper_token_data> tokens;
|
|
610
735
|
|
|
@@ -620,12 +745,13 @@ struct whisper_sequence {
|
|
|
620
745
|
|
|
621
746
|
// TAGS: WHISPER_DECODER_INIT
|
|
622
747
|
struct whisper_decoder {
|
|
623
|
-
// each decoder keeps its own KV-cache
|
|
624
|
-
whisper_kv_cache kv_self;
|
|
625
|
-
|
|
626
748
|
// the currently generated sequence of tokens
|
|
627
749
|
whisper_sequence sequence;
|
|
628
750
|
|
|
751
|
+
// grammar parse state of generated sequence of tokens
|
|
752
|
+
whisper_grammar grammar;
|
|
753
|
+
|
|
754
|
+
int i_batch; // the index of the token in the current batch
|
|
629
755
|
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
|
|
630
756
|
|
|
631
757
|
bool failed; // has the current segment failed to decode?
|
|
@@ -637,93 +763,42 @@ struct whisper_decoder {
|
|
|
637
763
|
std::vector<float> logits;
|
|
638
764
|
std::vector<float> logprobs;
|
|
639
765
|
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
|
644
|
-
template<typename A, typename B>
|
|
645
|
-
struct whisper_pair {
|
|
646
|
-
A first;
|
|
647
|
-
B second;
|
|
648
|
-
|
|
649
|
-
// Define a constructor that takes two arguments.
|
|
650
|
-
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
|
651
|
-
// Define a constructor that takes no argument.
|
|
652
|
-
whisper_pair() : first(A()), second(B()) {}
|
|
653
|
-
};
|
|
654
|
-
|
|
655
|
-
// beam-search helpers
|
|
656
|
-
struct kv_buf {
|
|
657
|
-
std::vector<uint8_t> k;
|
|
658
|
-
std::vector<uint8_t> v;
|
|
659
|
-
};
|
|
660
|
-
|
|
661
|
-
// wsp_ggml_allocr wrapper for whisper usage
|
|
662
|
-
struct whisper_allocr {
|
|
663
|
-
wsp_ggml_allocr * alloc = nullptr;
|
|
766
|
+
// work container used to avoid memory allocations
|
|
767
|
+
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
664
768
|
|
|
665
|
-
std::
|
|
666
|
-
std::vector<uint8_t> data;
|
|
769
|
+
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
667
770
|
};
|
|
668
771
|
|
|
669
|
-
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
|
670
|
-
return allocr.meta.size() + allocr.data.size();
|
|
671
|
-
}
|
|
672
|
-
|
|
673
|
-
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
|
674
|
-
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
|
|
675
|
-
const int tensor_alignment = 32;
|
|
676
|
-
|
|
677
|
-
auto & alloc = allocr.alloc;
|
|
678
|
-
auto & meta = allocr.meta;
|
|
679
|
-
auto & data = allocr.data;
|
|
680
|
-
|
|
681
|
-
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
682
|
-
|
|
683
|
-
alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
|
|
684
|
-
|
|
685
|
-
const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
|
|
686
|
-
|
|
687
|
-
wsp_ggml_allocr_free(alloc);
|
|
688
|
-
|
|
689
|
-
data.resize(alloc_size);
|
|
690
|
-
|
|
691
|
-
alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
|
692
|
-
}
|
|
693
|
-
|
|
694
|
-
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
|
695
|
-
if (allocr.alloc) {
|
|
696
|
-
wsp_ggml_allocr_free(allocr.alloc);
|
|
697
|
-
allocr.alloc = nullptr;
|
|
698
|
-
}
|
|
699
|
-
}
|
|
700
|
-
|
|
701
772
|
struct whisper_state {
|
|
702
773
|
int64_t t_sample_us = 0;
|
|
703
774
|
int64_t t_encode_us = 0;
|
|
704
775
|
int64_t t_decode_us = 0;
|
|
776
|
+
int64_t t_batchd_us = 0;
|
|
705
777
|
int64_t t_prompt_us = 0;
|
|
706
778
|
int64_t t_mel_us = 0;
|
|
707
779
|
|
|
708
780
|
int32_t n_sample = 0; // number of tokens sampled
|
|
709
781
|
int32_t n_encode = 0; // number of encoder calls
|
|
710
|
-
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1
|
|
711
|
-
int32_t
|
|
782
|
+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
|
783
|
+
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
|
|
784
|
+
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
|
712
785
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
713
786
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
714
787
|
|
|
788
|
+
// unified self-attention KV cache for all decoders
|
|
789
|
+
whisper_kv_cache kv_self;
|
|
790
|
+
|
|
715
791
|
// cross-attention KV cache for the decoders
|
|
716
792
|
// shared between all decoders
|
|
717
793
|
whisper_kv_cache kv_cross;
|
|
794
|
+
|
|
718
795
|
whisper_mel mel;
|
|
719
796
|
|
|
720
|
-
|
|
797
|
+
whisper_batch batch;
|
|
721
798
|
|
|
722
|
-
|
|
723
|
-
std::vector<kv_buf> kv_swap_bufs;
|
|
799
|
+
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
|
724
800
|
|
|
725
|
-
|
|
726
|
-
std::vector<uint8_t> work_buffer;
|
|
801
|
+
wsp_ggml_backend_t backend = nullptr;
|
|
727
802
|
|
|
728
803
|
// ggml-alloc:
|
|
729
804
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
|
@@ -737,36 +812,34 @@ struct whisper_state {
|
|
|
737
812
|
struct wsp_ggml_tensor * embd_conv = nullptr;
|
|
738
813
|
struct wsp_ggml_tensor * embd_enc = nullptr;
|
|
739
814
|
|
|
815
|
+
// helpers for GPU offloading
|
|
816
|
+
std::vector<float> inp_mel;
|
|
817
|
+
std::vector<float> inp_mask;
|
|
818
|
+
|
|
740
819
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
741
820
|
std::vector<float> logits;
|
|
742
821
|
|
|
743
822
|
std::vector<whisper_segment> result_all;
|
|
744
823
|
std::vector<whisper_token> prompt_past;
|
|
745
824
|
|
|
746
|
-
// work container used to avoid memory allocations
|
|
747
|
-
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
|
748
|
-
|
|
749
|
-
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
|
750
|
-
|
|
751
825
|
int lang_id = 0; // english by default
|
|
752
826
|
|
|
753
827
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
828
|
+
|
|
754
829
|
#ifdef WHISPER_USE_COREML
|
|
755
830
|
whisper_coreml_context * ctx_coreml = nullptr;
|
|
756
831
|
#endif
|
|
757
832
|
|
|
758
|
-
#ifdef WSP_GGML_USE_METAL
|
|
759
|
-
wsp_ggml_metal_context * ctx_metal = nullptr;
|
|
760
|
-
#endif
|
|
761
|
-
|
|
762
833
|
#ifdef WHISPER_USE_OPENVINO
|
|
763
834
|
whisper_openvino_context * ctx_openvino = nullptr;
|
|
764
835
|
#endif
|
|
765
836
|
|
|
766
837
|
// [EXPERIMENTAL] token-level timestamps data
|
|
767
|
-
int64_t t_beg
|
|
838
|
+
int64_t t_beg = 0;
|
|
768
839
|
int64_t t_last = 0;
|
|
840
|
+
|
|
769
841
|
whisper_token tid_last;
|
|
842
|
+
|
|
770
843
|
std::vector<float> energy; // PCM signal energy
|
|
771
844
|
|
|
772
845
|
// [EXPERIMENTAL] speed-up techniques
|
|
@@ -780,35 +853,25 @@ struct whisper_context {
|
|
|
780
853
|
wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
|
|
781
854
|
wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16)
|
|
782
855
|
|
|
856
|
+
whisper_context_params params;
|
|
857
|
+
|
|
783
858
|
whisper_model model;
|
|
784
859
|
whisper_vocab vocab;
|
|
860
|
+
|
|
785
861
|
whisper_state * state = nullptr;
|
|
786
862
|
|
|
863
|
+
wsp_ggml_backend_t backend = nullptr;
|
|
864
|
+
|
|
787
865
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
|
788
|
-
whisper_context_params params;
|
|
789
866
|
};
|
|
790
867
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
868
|
+
struct whisper_global {
|
|
869
|
+
// We save the log callback globally
|
|
870
|
+
wsp_ggml_log_callback log_callback = whisper_log_callback_default;
|
|
871
|
+
void * log_callback_user_data = nullptr;
|
|
872
|
+
};
|
|
796
873
|
|
|
797
|
-
|
|
798
|
-
#ifdef __MINGW32__
|
|
799
|
-
__attribute__((gnu_format(printf, 1, 2)))
|
|
800
|
-
#else
|
|
801
|
-
__attribute__((format(printf, 1, 2)))
|
|
802
|
-
#endif
|
|
803
|
-
#endif
|
|
804
|
-
static void log(const char * fmt, ...) {
|
|
805
|
-
if (!whisper_log) return;
|
|
806
|
-
char buf[1024];
|
|
807
|
-
va_list args;
|
|
808
|
-
va_start(args, fmt);
|
|
809
|
-
vsnprintf(buf, sizeof(buf), fmt, args);
|
|
810
|
-
whisper_log(buf);
|
|
811
|
-
}
|
|
874
|
+
static whisper_global g_state;
|
|
812
875
|
|
|
813
876
|
template<typename T>
|
|
814
877
|
static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
@@ -819,6 +882,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
|
819
882
|
static bool kv_cache_init(
|
|
820
883
|
const struct whisper_hparams & hparams,
|
|
821
884
|
struct whisper_kv_cache & cache,
|
|
885
|
+
wsp_ggml_backend_t backend,
|
|
822
886
|
wsp_ggml_type wtype,
|
|
823
887
|
int n_ctx) {
|
|
824
888
|
const int64_t n_text_state = hparams.n_text_state;
|
|
@@ -827,66 +891,206 @@ static bool kv_cache_init(
|
|
|
827
891
|
const int64_t n_mem = n_text_layer*n_ctx;
|
|
828
892
|
const int64_t n_elements = n_text_state*n_mem;
|
|
829
893
|
|
|
830
|
-
const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
|
|
831
|
-
|
|
832
|
-
cache.buf.resize(mem_bytes);
|
|
833
|
-
|
|
834
894
|
struct wsp_ggml_init_params params = {
|
|
835
|
-
/*.mem_size =*/
|
|
836
|
-
/*.mem_buffer =*/
|
|
837
|
-
/*.no_alloc =*/
|
|
895
|
+
/*.mem_size =*/ 2*wsp_ggml_tensor_overhead(),
|
|
896
|
+
/*.mem_buffer =*/ nullptr,
|
|
897
|
+
/*.no_alloc =*/ true,
|
|
838
898
|
};
|
|
839
899
|
|
|
900
|
+
cache.head = 0;
|
|
901
|
+
cache.size = n_ctx;
|
|
902
|
+
|
|
903
|
+
cache.cells.clear();
|
|
904
|
+
cache.cells.resize(n_ctx);
|
|
905
|
+
|
|
840
906
|
cache.ctx = wsp_ggml_init(params);
|
|
841
907
|
|
|
842
908
|
if (!cache.ctx) {
|
|
843
|
-
|
|
909
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
|
|
844
910
|
return false;
|
|
845
911
|
}
|
|
846
912
|
|
|
847
913
|
cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
848
914
|
cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
|
849
915
|
|
|
916
|
+
const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
|
|
917
|
+
|
|
918
|
+
cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
|
|
919
|
+
|
|
920
|
+
// allocate the tensors into the backend buffer
|
|
921
|
+
{
|
|
922
|
+
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
|
|
923
|
+
|
|
924
|
+
wsp_ggml_allocr_alloc(alloc, cache.k);
|
|
925
|
+
wsp_ggml_allocr_alloc(alloc, cache.v);
|
|
926
|
+
|
|
927
|
+
wsp_ggml_allocr_free(alloc);
|
|
928
|
+
}
|
|
929
|
+
|
|
850
930
|
return true;
|
|
851
931
|
}
|
|
852
932
|
|
|
853
|
-
static
|
|
854
|
-
|
|
933
|
+
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
|
934
|
+
if (cache.ctx) {
|
|
935
|
+
wsp_ggml_free(cache.ctx);
|
|
936
|
+
wsp_ggml_backend_buffer_free(cache.buffer);
|
|
937
|
+
cache.ctx = nullptr;
|
|
938
|
+
}
|
|
939
|
+
}
|
|
855
940
|
|
|
856
|
-
|
|
857
|
-
|
|
941
|
+
static bool whisper_kv_cache_find_slot(
|
|
942
|
+
struct whisper_kv_cache & cache,
|
|
943
|
+
const struct whisper_batch & batch) {
|
|
944
|
+
const uint32_t n_ctx = cache.size;
|
|
945
|
+
const uint32_t n_tokens = batch.n_tokens;
|
|
858
946
|
|
|
859
|
-
|
|
860
|
-
|
|
947
|
+
if (n_tokens > n_ctx) {
|
|
948
|
+
WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
|
|
949
|
+
return false;
|
|
950
|
+
}
|
|
861
951
|
|
|
862
|
-
|
|
952
|
+
uint32_t n_tested = 0;
|
|
863
953
|
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
954
|
+
while (true) {
|
|
955
|
+
if (cache.head + n_tokens > n_ctx) {
|
|
956
|
+
n_tested += n_ctx - cache.head;
|
|
957
|
+
cache.head = 0;
|
|
958
|
+
continue;
|
|
959
|
+
}
|
|
869
960
|
|
|
870
|
-
|
|
961
|
+
bool found = true;
|
|
962
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
963
|
+
if (cache.cells[cache.head + i].pos >= 0) {
|
|
964
|
+
found = false;
|
|
965
|
+
cache.head += i + 1;
|
|
966
|
+
n_tested += i + 1;
|
|
967
|
+
break;
|
|
968
|
+
}
|
|
969
|
+
}
|
|
871
970
|
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
971
|
+
if (found) {
|
|
972
|
+
break;
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
if (n_tested >= n_ctx) {
|
|
976
|
+
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
|
977
|
+
return false;
|
|
978
|
+
}
|
|
875
979
|
}
|
|
876
980
|
|
|
877
|
-
|
|
878
|
-
|
|
981
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
982
|
+
cache.cells[cache.head + i].pos = batch.pos[i];
|
|
983
|
+
|
|
984
|
+
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
|
985
|
+
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
|
|
986
|
+
}
|
|
987
|
+
}
|
|
879
988
|
|
|
880
989
|
return true;
|
|
881
990
|
}
|
|
882
991
|
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
cache.
|
|
992
|
+
// find how many cells are currently in use
|
|
993
|
+
static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
|
|
994
|
+
for (uint32_t i = cache.size - 1; i > 0; --i) {
|
|
995
|
+
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
|
996
|
+
return i + 1;
|
|
997
|
+
}
|
|
998
|
+
}
|
|
999
|
+
|
|
1000
|
+
return 1;
|
|
1001
|
+
}
|
|
1002
|
+
|
|
1003
|
+
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
1004
|
+
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
|
|
1005
|
+
cache.cells[i].pos = -1;
|
|
1006
|
+
cache.cells[i].seq_id.clear();
|
|
1007
|
+
}
|
|
1008
|
+
cache.head = 0;
|
|
1009
|
+
}
|
|
1010
|
+
|
|
1011
|
+
static void whisper_kv_cache_seq_rm(
|
|
1012
|
+
struct whisper_kv_cache & cache,
|
|
1013
|
+
whisper_seq_id seq_id,
|
|
1014
|
+
whisper_pos p0,
|
|
1015
|
+
whisper_pos p1) {
|
|
1016
|
+
uint32_t new_head = cache.size;
|
|
1017
|
+
|
|
1018
|
+
if (p0 < 0) p0 = 0;
|
|
1019
|
+
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
|
|
1020
|
+
|
|
1021
|
+
for (uint32_t i = 0; i < cache.size; ++i) {
|
|
1022
|
+
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
|
1023
|
+
if (seq_id < 0) {
|
|
1024
|
+
cache.cells[i].seq_id.clear();
|
|
1025
|
+
} else if (cache.cells[i].has_seq_id(seq_id)) {
|
|
1026
|
+
cache.cells[i].seq_id.erase(seq_id);
|
|
1027
|
+
} else {
|
|
1028
|
+
continue;
|
|
1029
|
+
}
|
|
1030
|
+
if (cache.cells[i].seq_id.empty()) {
|
|
1031
|
+
cache.cells[i].pos = -1;
|
|
1032
|
+
if (new_head == cache.size) new_head = i;
|
|
1033
|
+
}
|
|
1034
|
+
}
|
|
1035
|
+
}
|
|
1036
|
+
|
|
1037
|
+
// If we freed up a slot, set head to it so searching can start there.
|
|
1038
|
+
if (new_head != cache.size) cache.head = new_head;
|
|
1039
|
+
}
|
|
1040
|
+
|
|
1041
|
+
static void whisper_kv_cache_seq_cp(
|
|
1042
|
+
struct whisper_kv_cache & cache,
|
|
1043
|
+
whisper_seq_id seq_id_src,
|
|
1044
|
+
whisper_seq_id seq_id_dst,
|
|
1045
|
+
whisper_pos p0,
|
|
1046
|
+
whisper_pos p1) {
|
|
1047
|
+
if (p0 < 0) p0 = 0;
|
|
1048
|
+
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
|
|
1049
|
+
|
|
1050
|
+
cache.head = 0;
|
|
1051
|
+
|
|
1052
|
+
for (uint32_t i = 0; i < cache.size; ++i) {
|
|
1053
|
+
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
|
1054
|
+
cache.cells[i].seq_id.insert(seq_id_dst);
|
|
1055
|
+
}
|
|
887
1056
|
}
|
|
888
1057
|
}
|
|
889
1058
|
|
|
1059
|
+
// Pick and initialize the compute backend for a context.
//
// A GPU backend is only attempted when `params.use_gpu` is set and the build
// enables one (WSP_GGML_USE_CUBLAS and/or WSP_GGML_USE_METAL). A Metal device
// that fails the family-7 check is released again. Whenever no GPU backend
// ends up available, the CPU backend is returned instead.
static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
    wsp_ggml_backend_t gpu = NULL;

    // initialize the backends
#ifdef WSP_GGML_USE_CUBLAS
    if (params.use_gpu && wsp_ggml_cublas_loaded()) {
        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
        gpu = wsp_ggml_backend_cuda_init(0);
        if (gpu == NULL) {
            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
        }
    }
#endif

#ifdef WSP_GGML_USE_METAL
    if (params.use_gpu) {
        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
        wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
        gpu = wsp_ggml_backend_metal_init();
        if (gpu == NULL) {
            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
        } else {
            const bool family_ok = wsp_ggml_backend_metal_supports_family(gpu, 7);
            if (!family_ok) {
                WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
                wsp_ggml_backend_free(gpu);
                gpu = NULL;
            }
        }
    }
#endif

    return gpu != NULL ? gpu : wsp_ggml_backend_cpu_init();
}
|
|
1093
|
+
|
|
890
1094
|
// load the model from a ggml file
|
|
891
1095
|
//
|
|
892
1096
|
// file format:
|
|
@@ -899,7 +1103,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
|
|
|
899
1103
|
// see the convert-pt-to-ggml.py script for details
|
|
900
1104
|
//
|
|
901
1105
|
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
|
|
902
|
-
|
|
1106
|
+
WHISPER_LOG_INFO("%s: loading model\n", __func__);
|
|
903
1107
|
|
|
904
1108
|
const int64_t t_start_us = wsp_ggml_time_us();
|
|
905
1109
|
|
|
@@ -913,7 +1117,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
913
1117
|
uint32_t magic;
|
|
914
1118
|
read_safe(loader, magic);
|
|
915
1119
|
if (magic != WSP_GGML_FILE_MAGIC) {
|
|
916
|
-
|
|
1120
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
|
917
1121
|
return false;
|
|
918
1122
|
}
|
|
919
1123
|
}
|
|
@@ -970,41 +1174,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
970
1174
|
// in order to save memory and also to speed up the computation
|
|
971
1175
|
wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype));
|
|
972
1176
|
if (wctx.wtype == WSP_GGML_TYPE_COUNT) {
|
|
973
|
-
|
|
1177
|
+
WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
|
|
974
1178
|
return false;
|
|
975
1179
|
}
|
|
976
1180
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
log("%s: qntvr = %d\n", __func__, qntvr);
|
|
991
|
-
log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
|
|
992
|
-
|
|
993
|
-
// print memory requirements
|
|
994
|
-
{
|
|
995
|
-
// TODO
|
|
996
|
-
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
997
|
-
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
998
|
-
}
|
|
999
|
-
|
|
1000
|
-
// initialize all memory buffers
|
|
1001
|
-
// always have at least one decoder
|
|
1002
|
-
|
|
1003
|
-
wctx.model.buf = new std::vector<uint8_t>();
|
|
1004
|
-
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
|
|
1005
|
-
|
|
1006
|
-
// we skip initialization of the state until it is needed
|
|
1007
|
-
// because it might be that state will always be provided externally.
|
|
1181
|
+
WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
|
1182
|
+
WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
|
1183
|
+
WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
|
1184
|
+
WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
|
1185
|
+
WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
|
1186
|
+
WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
|
|
1187
|
+
WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
|
|
1188
|
+
WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
|
1189
|
+
WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
|
1190
|
+
WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
|
|
1191
|
+
WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype);
|
|
1192
|
+
WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
|
|
1193
|
+
WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
|
|
1008
1194
|
}
|
|
1009
1195
|
|
|
1010
1196
|
// load mel filters
|
|
@@ -1025,7 +1211,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1025
1211
|
read_safe(loader, n_vocab);
|
|
1026
1212
|
|
|
1027
1213
|
//if (n_vocab != model.hparams.n_vocab) {
|
|
1028
|
-
//
|
|
1214
|
+
// WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
|
1029
1215
|
// __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
|
1030
1216
|
// return false;
|
|
1031
1217
|
//}
|
|
@@ -1045,7 +1231,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1045
1231
|
word.assign(&tmp[0], tmp.size());
|
|
1046
1232
|
} else {
|
|
1047
1233
|
// seems like we have an empty-string token in multi-language models (i = 50256)
|
|
1048
|
-
//
|
|
1234
|
+
//WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
|
1049
1235
|
word = "";
|
|
1050
1236
|
}
|
|
1051
1237
|
|
|
@@ -1073,7 +1259,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1073
1259
|
}
|
|
1074
1260
|
|
|
1075
1261
|
if (n_vocab < model.hparams.n_vocab) {
|
|
1076
|
-
|
|
1262
|
+
WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
|
|
1077
1263
|
for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
|
|
1078
1264
|
if (i > vocab.token_beg) {
|
|
1079
1265
|
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
|
|
@@ -1081,6 +1267,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1081
1267
|
word = "[_EOT_]";
|
|
1082
1268
|
} else if (i == vocab.token_sot) {
|
|
1083
1269
|
word = "[_SOT_]";
|
|
1270
|
+
} else if (i == vocab.token_translate) {
|
|
1271
|
+
word = "[_TRANSLATE_]";
|
|
1272
|
+
} else if (i == vocab.token_transcribe) {
|
|
1273
|
+
word = "[_TRANSCRIBE_]";
|
|
1084
1274
|
} else if (i == vocab.token_solm) {
|
|
1085
1275
|
word = "[_SOLM_]";
|
|
1086
1276
|
} else if (i == vocab.token_prev) {
|
|
@@ -1091,6 +1281,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1091
1281
|
word = "[_NOT_]";
|
|
1092
1282
|
} else if (i == vocab.token_beg) {
|
|
1093
1283
|
word = "[_BEG_]";
|
|
1284
|
+
} else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
|
|
1285
|
+
word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
|
|
1094
1286
|
} else {
|
|
1095
1287
|
word = "[_extra_token_" + std::to_string(i) + "]";
|
|
1096
1288
|
}
|
|
@@ -1099,140 +1291,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1099
1291
|
}
|
|
1100
1292
|
}
|
|
1101
1293
|
|
|
1102
|
-
|
|
1294
|
+
WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages());
|
|
1103
1295
|
}
|
|
1104
1296
|
|
|
1105
|
-
size_t ctx_size = 0;
|
|
1106
|
-
|
|
1107
1297
|
const wsp_ggml_type wtype = wctx.wtype;
|
|
1108
1298
|
const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
|
|
1109
1299
|
|
|
1300
|
+
// create the ggml context
|
|
1110
1301
|
{
|
|
1111
1302
|
const auto & hparams = model.hparams;
|
|
1112
1303
|
|
|
1113
|
-
const int n_vocab = hparams.n_vocab;
|
|
1114
|
-
|
|
1115
|
-
const int n_audio_ctx = hparams.n_audio_ctx;
|
|
1116
|
-
const int n_audio_state = hparams.n_audio_state;
|
|
1117
1304
|
const int n_audio_layer = hparams.n_audio_layer;
|
|
1305
|
+
const int n_text_layer = hparams.n_text_layer;
|
|
1118
1306
|
|
|
1119
|
-
const
|
|
1120
|
-
const int n_text_state = hparams.n_text_state;
|
|
1121
|
-
const int n_text_layer = hparams.n_text_layer;
|
|
1122
|
-
|
|
1123
|
-
const int n_mels = hparams.n_mels;
|
|
1124
|
-
|
|
1125
|
-
// encoder
|
|
1126
|
-
{
|
|
1127
|
-
ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe;
|
|
1128
|
-
|
|
1129
|
-
ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w
|
|
1130
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b
|
|
1131
|
-
|
|
1132
|
-
ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w
|
|
1133
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b
|
|
1134
|
-
|
|
1135
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w;
|
|
1136
|
-
ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b;
|
|
1137
|
-
}
|
|
1138
|
-
|
|
1139
|
-
// decoder
|
|
1140
|
-
{
|
|
1141
|
-
ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe;
|
|
1142
|
-
|
|
1143
|
-
ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te;
|
|
1144
|
-
|
|
1145
|
-
ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w;
|
|
1146
|
-
ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b;
|
|
1147
|
-
}
|
|
1148
|
-
|
|
1149
|
-
// encoder layers
|
|
1150
|
-
{
|
|
1151
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
|
|
1152
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
|
|
1153
|
-
|
|
1154
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
|
|
1155
|
-
ctx_size += n_audio_layer*( 4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
|
|
1156
|
-
|
|
1157
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
|
|
1158
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
|
|
1159
|
-
|
|
1160
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
|
|
1161
|
-
ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
|
|
1162
|
-
|
|
1163
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
|
|
1164
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
|
|
1165
|
-
|
|
1166
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
|
|
1167
|
-
|
|
1168
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
|
|
1169
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
|
|
1170
|
-
|
|
1171
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1172
|
-
ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
|
|
1173
|
-
}
|
|
1174
|
-
|
|
1175
|
-
// decoder layers
|
|
1176
|
-
{
|
|
1177
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
|
|
1178
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
|
|
1179
|
-
|
|
1180
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
|
|
1181
|
-
ctx_size += n_text_layer*( 4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
|
|
1182
|
-
|
|
1183
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
|
|
1184
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
|
|
1307
|
+
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
|
1185
1308
|
|
|
1186
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
|
|
1187
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
|
|
1188
|
-
|
|
1189
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
|
|
1190
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
|
|
1191
|
-
|
|
1192
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
|
|
1193
|
-
|
|
1194
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
|
|
1195
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
|
|
1196
|
-
|
|
1197
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1198
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
|
|
1199
|
-
//
|
|
1200
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w
|
|
1201
|
-
ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b
|
|
1202
|
-
|
|
1203
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w
|
|
1204
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b
|
|
1205
|
-
|
|
1206
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w
|
|
1207
|
-
|
|
1208
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w
|
|
1209
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b
|
|
1210
|
-
|
|
1211
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w
|
|
1212
|
-
ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
1213
|
-
}
|
|
1214
|
-
|
|
1215
|
-
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
|
|
1216
|
-
|
|
1217
|
-
log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
|
1218
|
-
}
|
|
1219
|
-
|
|
1220
|
-
// create the ggml context
|
|
1221
|
-
{
|
|
1222
1309
|
struct wsp_ggml_init_params params = {
|
|
1223
|
-
/*.mem_size =*/
|
|
1224
|
-
/*.mem_buffer =*/
|
|
1225
|
-
/*.no_alloc =*/
|
|
1310
|
+
/*.mem_size =*/ n_tensors*wsp_ggml_tensor_overhead(),
|
|
1311
|
+
/*.mem_buffer =*/ nullptr,
|
|
1312
|
+
/*.no_alloc =*/ true,
|
|
1226
1313
|
};
|
|
1227
1314
|
|
|
1228
1315
|
model.ctx = wsp_ggml_init(params);
|
|
1229
1316
|
if (!model.ctx) {
|
|
1230
|
-
|
|
1317
|
+
WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__);
|
|
1231
1318
|
return false;
|
|
1232
1319
|
}
|
|
1233
1320
|
}
|
|
1234
1321
|
|
|
1235
|
-
// prepare
|
|
1322
|
+
// prepare tensors for the weights
|
|
1236
1323
|
{
|
|
1237
1324
|
auto & ctx = model.ctx;
|
|
1238
1325
|
|
|
@@ -1255,16 +1342,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1255
1342
|
|
|
1256
1343
|
// encoder
|
|
1257
1344
|
{
|
|
1258
|
-
model.e_pe
|
|
1345
|
+
model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
1259
1346
|
|
|
1260
|
-
model.e_conv_1_w
|
|
1261
|
-
model.e_conv_1_b
|
|
1347
|
+
model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
|
1348
|
+
model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1262
1349
|
|
|
1263
|
-
model.e_conv_2_w
|
|
1264
|
-
model.e_conv_2_b
|
|
1350
|
+
model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
|
1351
|
+
model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1265
1352
|
|
|
1266
|
-
model.e_ln_w
|
|
1267
|
-
model.e_ln_b
|
|
1353
|
+
model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1354
|
+
model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1268
1355
|
|
|
1269
1356
|
// map by name
|
|
1270
1357
|
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
|
@@ -1428,12 +1515,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1428
1515
|
}
|
|
1429
1516
|
}
|
|
1430
1517
|
|
|
1518
|
+
wctx.backend = whisper_backend_init(wctx.params);
|
|
1519
|
+
|
|
1520
|
+
{
|
|
1521
|
+
size_t size_main = 0;
|
|
1522
|
+
|
|
1523
|
+
for (const auto & t : model.tensors) {
|
|
1524
|
+
size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
|
|
1525
|
+
}
|
|
1526
|
+
|
|
1527
|
+
model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
|
|
1528
|
+
|
|
1529
|
+
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
|
|
1533
|
+
|
|
1534
|
+
// allocate tensors in the backend buffers
|
|
1535
|
+
{
|
|
1536
|
+
for (const auto & t : model.tensors) {
|
|
1537
|
+
wsp_ggml_allocr_alloc(alloc, t.second);
|
|
1538
|
+
}
|
|
1539
|
+
}
|
|
1540
|
+
|
|
1431
1541
|
// load weights
|
|
1432
1542
|
{
|
|
1433
1543
|
size_t total_size = 0;
|
|
1434
1544
|
|
|
1435
1545
|
model.n_loaded = 0;
|
|
1436
1546
|
|
|
1547
|
+
std::vector<char> read_buf;
|
|
1548
|
+
|
|
1437
1549
|
while (true) {
|
|
1438
1550
|
int32_t n_dims;
|
|
1439
1551
|
int32_t length;
|
|
@@ -1460,20 +1572,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1460
1572
|
name.assign(&tmp[0], tmp.size());
|
|
1461
1573
|
|
|
1462
1574
|
if (model.tensors.find(name) == model.tensors.end()) {
|
|
1463
|
-
|
|
1575
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
1464
1576
|
return false;
|
|
1465
1577
|
}
|
|
1466
1578
|
|
|
1467
1579
|
auto tensor = model.tensors[name.data()];
|
|
1580
|
+
|
|
1468
1581
|
if (wsp_ggml_nelements(tensor) != nelements) {
|
|
1469
|
-
|
|
1470
|
-
|
|
1582
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
|
1583
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
|
1471
1584
|
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
|
1472
1585
|
return false;
|
|
1473
1586
|
}
|
|
1474
1587
|
|
|
1475
1588
|
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
|
1476
|
-
|
|
1589
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
|
1477
1590
|
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
|
1478
1591
|
return false;
|
|
1479
1592
|
}
|
|
@@ -1481,29 +1594,49 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1481
1594
|
const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
|
|
1482
1595
|
|
|
1483
1596
|
if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
|
|
1484
|
-
|
|
1597
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
1485
1598
|
__func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
|
|
1486
1599
|
return false;
|
|
1487
1600
|
}
|
|
1488
1601
|
|
|
1489
|
-
|
|
1490
|
-
|
|
1602
|
+
wsp_ggml_backend_t backend = wctx.backend;
|
|
1603
|
+
|
|
1604
|
+
//printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
|
|
1605
|
+
|
|
1606
|
+
if ((wsp_ggml_backend_is_cpu(backend)
|
|
1607
|
+
#ifdef WSP_GGML_USE_METAL
|
|
1608
|
+
|| wsp_ggml_backend_is_metal(backend)
|
|
1609
|
+
#endif
|
|
1610
|
+
)) {
|
|
1611
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1612
|
+
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
1613
|
+
BYTESWAP_TENSOR(tensor);
|
|
1614
|
+
} else {
|
|
1615
|
+
// read into a temporary buffer first, then copy to device memory
|
|
1616
|
+
read_buf.resize(wsp_ggml_nbytes(tensor));
|
|
1617
|
+
|
|
1618
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
|
1619
|
+
|
|
1620
|
+
wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
|
|
1621
|
+
}
|
|
1491
1622
|
|
|
1492
|
-
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/
|
|
1623
|
+
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
|
|
1493
1624
|
total_size += wsp_ggml_nbytes(tensor);
|
|
1494
1625
|
model.n_loaded++;
|
|
1495
1626
|
}
|
|
1496
1627
|
|
|
1497
|
-
|
|
1628
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
|
1498
1629
|
|
|
1499
1630
|
if (model.n_loaded == 0) {
|
|
1500
|
-
|
|
1631
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
|
1501
1632
|
} else if (model.n_loaded != (int) model.tensors.size()) {
|
|
1502
|
-
|
|
1633
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
|
1503
1634
|
return false;
|
|
1504
1635
|
}
|
|
1505
1636
|
}
|
|
1506
1637
|
|
|
1638
|
+
wsp_ggml_allocr_free(alloc);
|
|
1639
|
+
|
|
1507
1640
|
wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
|
|
1508
1641
|
|
|
1509
1642
|
return true;
|
|
@@ -1559,10 +1692,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1559
1692
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1560
1693
|
assert(mel_inp.n_mel == n_mels);
|
|
1561
1694
|
|
|
1562
|
-
|
|
1695
|
+
wstate.inp_mel.resize(wsp_ggml_nelements(mel));
|
|
1696
|
+
|
|
1697
|
+
float * dst = wstate.inp_mel.data();
|
|
1563
1698
|
memset(dst, 0, wsp_ggml_nbytes(mel));
|
|
1564
1699
|
|
|
1565
|
-
const int i0 = std::min(mel_offset,
|
|
1700
|
+
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
|
1566
1701
|
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
|
1567
1702
|
|
|
1568
1703
|
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
|
@@ -1570,6 +1705,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1570
1705
|
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
|
1571
1706
|
}
|
|
1572
1707
|
}
|
|
1708
|
+
|
|
1709
|
+
wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
|
|
1573
1710
|
}
|
|
1574
1711
|
|
|
1575
1712
|
struct wsp_ggml_tensor * cur = nullptr;
|
|
@@ -1577,25 +1714,18 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1577
1714
|
if (!whisper_encode_external(wstate)) {
|
|
1578
1715
|
// convolution + gelu
|
|
1579
1716
|
{
|
|
1580
|
-
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
|
1581
|
-
cur = wsp_ggml_add(ctx0,
|
|
1582
|
-
wsp_ggml_repeat(ctx0,
|
|
1583
|
-
model.e_conv_1_b,
|
|
1584
|
-
cur),
|
|
1585
|
-
cur);
|
|
1717
|
+
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
|
1718
|
+
cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);
|
|
1586
1719
|
|
|
1587
1720
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1588
1721
|
|
|
1589
1722
|
cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
|
1590
|
-
cur = wsp_ggml_add(ctx0,
|
|
1591
|
-
wsp_ggml_repeat(ctx0,
|
|
1592
|
-
model.e_conv_2_b,
|
|
1593
|
-
cur),
|
|
1594
|
-
cur);
|
|
1723
|
+
cur = wsp_ggml_add(ctx0, cur, model.e_conv_2_b);
|
|
1595
1724
|
|
|
1596
1725
|
cur = wsp_ggml_gelu(ctx0, cur);
|
|
1597
1726
|
}
|
|
1598
1727
|
|
|
1728
|
+
wsp_ggml_set_name(cur, "embd_conv");
|
|
1599
1729
|
wstate.embd_conv = cur;
|
|
1600
1730
|
} else {
|
|
1601
1731
|
#ifdef WHISPER_USE_COREML
|
|
@@ -1603,7 +1733,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1603
1733
|
wsp_ggml_allocr_alloc(alloc, cur);
|
|
1604
1734
|
|
|
1605
1735
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1606
|
-
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
|
1736
|
+
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
|
1607
1737
|
}
|
|
1608
1738
|
#endif
|
|
1609
1739
|
#ifdef WHISPER_USE_OPENVINO
|
|
@@ -1615,6 +1745,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
|
|
|
1615
1745
|
}
|
|
1616
1746
|
#endif
|
|
1617
1747
|
|
|
1748
|
+
wsp_ggml_set_name(cur, "embd_enc");
|
|
1618
1749
|
wstate.embd_enc = cur;
|
|
1619
1750
|
}
|
|
1620
1751
|
|
|
@@ -1648,15 +1779,22 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1648
1779
|
|
|
1649
1780
|
wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
1650
1781
|
|
|
1782
|
+
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
|
|
1783
|
+
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
1784
|
+
|
|
1785
|
+
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1786
|
+
// wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
|
|
1787
|
+
//}
|
|
1788
|
+
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1789
|
+
|
|
1651
1790
|
struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1652
1791
|
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
1653
1792
|
|
|
1654
1793
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1655
|
-
|
|
1794
|
+
const float val = 1.0f/sqrtf(float(n_state)/n_head);
|
|
1795
|
+
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
1656
1796
|
}
|
|
1657
1797
|
|
|
1658
|
-
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
|
|
1659
|
-
|
|
1660
1798
|
// ===================================================================
|
|
1661
1799
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1662
1800
|
//static int iter = -1;
|
|
@@ -1675,7 +1813,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1675
1813
|
const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1676
1814
|
|
|
1677
1815
|
struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
1678
|
-
|
|
1679
1816
|
cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
|
|
1680
1817
|
|
|
1681
1818
|
// ===================================================================
|
|
@@ -1863,11 +2000,11 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
|
|
|
1863
2000
|
////////////////////////////////////////////////////////////////////////////
|
|
1864
2001
|
|
|
1865
2002
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
1866
|
-
// wsp_ggml_used_mem(ctx0)/
|
|
1867
|
-
// wstate.get_buf_max_mem(0)/
|
|
1868
|
-
// wstate.get_buf_max_mem(1)/
|
|
1869
|
-
// wstate.get_buf_max_mem(2)/
|
|
1870
|
-
// wstate.get_buf_max_mem(3)/
|
|
2003
|
+
// wsp_ggml_used_mem(ctx0)/1e6,
|
|
2004
|
+
// wstate.get_buf_max_mem(0)/1e6,
|
|
2005
|
+
// wstate.get_buf_max_mem(1)/1e6,
|
|
2006
|
+
// wstate.get_buf_max_mem(2)/1e6,
|
|
2007
|
+
// wstate.get_buf_max_mem(3)/1e6);
|
|
1871
2008
|
|
|
1872
2009
|
wsp_ggml_free(ctx0);
|
|
1873
2010
|
|
|
@@ -1897,13 +2034,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
|
|
|
1897
2034
|
|
|
1898
2035
|
wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
1899
2036
|
|
|
2037
|
+
//struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
|
|
2038
|
+
//wsp_ggml_allocr_alloc(alloc, cur);
|
|
2039
|
+
|
|
2040
|
+
//if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2041
|
+
// wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
|
|
2042
|
+
//}
|
|
1900
2043
|
struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
|
|
1901
2044
|
|
|
1902
2045
|
struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
|
|
1903
2046
|
wsp_ggml_allocr_alloc(alloc, Kscale);
|
|
1904
2047
|
|
|
1905
2048
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
1906
|
-
|
|
2049
|
+
const float val = pow(float(n_state) / n_head, -0.25);
|
|
2050
|
+
wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
|
|
1907
2051
|
}
|
|
1908
2052
|
|
|
1909
2053
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
@@ -1974,7 +2118,7 @@ static bool whisper_encode_internal(
|
|
|
1974
2118
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1975
2119
|
|
|
1976
2120
|
if (!whisper_encode_external(wstate)) {
|
|
1977
|
-
wsp_ggml_graph_compute_helper(wstate.
|
|
2121
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
1978
2122
|
}
|
|
1979
2123
|
}
|
|
1980
2124
|
|
|
@@ -1988,16 +2132,7 @@ static bool whisper_encode_internal(
|
|
|
1988
2132
|
|
|
1989
2133
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
1990
2134
|
|
|
1991
|
-
|
|
1992
|
-
if (wstate.ctx_metal) {
|
|
1993
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
1994
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
1995
|
-
} else {
|
|
1996
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
1997
|
-
}
|
|
1998
|
-
#else
|
|
1999
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2000
|
-
#endif
|
|
2135
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
2001
2136
|
}
|
|
2002
2137
|
|
|
2003
2138
|
// cross
|
|
@@ -2010,49 +2145,40 @@ static bool whisper_encode_internal(
|
|
|
2010
2145
|
|
|
2011
2146
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
2012
2147
|
|
|
2013
|
-
|
|
2014
|
-
if (wstate.ctx_metal) {
|
|
2015
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
2016
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
2017
|
-
} else {
|
|
2018
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2019
|
-
}
|
|
2020
|
-
#else
|
|
2021
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2022
|
-
#endif
|
|
2148
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
2023
2149
|
}
|
|
2024
2150
|
|
|
2025
|
-
// wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
2026
|
-
|
|
2027
2151
|
wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
|
|
2028
2152
|
wstate.n_encode++;
|
|
2029
2153
|
|
|
2030
|
-
return
|
|
2154
|
+
return !(abort_callback && abort_callback(abort_callback_data));
|
|
2031
2155
|
}
|
|
2032
2156
|
|
|
2033
2157
|
static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
2034
2158
|
whisper_context & wctx,
|
|
2035
2159
|
whisper_state & wstate,
|
|
2036
|
-
|
|
2037
|
-
const whisper_token * tokens,
|
|
2038
|
-
int n_tokens,
|
|
2039
|
-
int n_past) {
|
|
2160
|
+
const whisper_batch & batch) {
|
|
2040
2161
|
const auto & model = wctx.model;
|
|
2041
2162
|
const auto & hparams = model.hparams;
|
|
2042
2163
|
|
|
2043
|
-
auto & kv_self =
|
|
2164
|
+
auto & kv_self = wstate.kv_self;
|
|
2044
2165
|
|
|
2045
2166
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
2046
2167
|
|
|
2047
|
-
|
|
2168
|
+
wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
|
2169
|
+
|
|
2170
|
+
const int n_ctx = kv_self.size;
|
|
2048
2171
|
const int n_state = hparams.n_text_state;
|
|
2049
2172
|
const int n_head = hparams.n_text_head;
|
|
2050
2173
|
const int n_layer = hparams.n_text_layer;
|
|
2051
2174
|
|
|
2052
|
-
const int
|
|
2053
|
-
const int
|
|
2175
|
+
const int n_tokens = batch.n_tokens;
|
|
2176
|
+
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
2054
2177
|
|
|
2055
|
-
|
|
2178
|
+
const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
|
|
2179
|
+
const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
|
|
2180
|
+
|
|
2181
|
+
//WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
|
2056
2182
|
|
|
2057
2183
|
struct wsp_ggml_init_params params = {
|
|
2058
2184
|
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
|
@@ -2064,21 +2190,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2064
2190
|
|
|
2065
2191
|
wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
|
2066
2192
|
|
|
2067
|
-
|
|
2068
|
-
|
|
2069
|
-
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
|
|
2193
|
+
struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2070
2194
|
wsp_ggml_allocr_alloc(alloc, embd);
|
|
2071
2195
|
|
|
2072
2196
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2073
|
-
|
|
2197
|
+
wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
|
|
2074
2198
|
}
|
|
2075
2199
|
|
|
2076
|
-
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32,
|
|
2200
|
+
struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
|
|
2077
2201
|
wsp_ggml_allocr_alloc(alloc, position);
|
|
2078
2202
|
|
|
2079
2203
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2080
|
-
for (int i = 0; i <
|
|
2081
|
-
|
|
2204
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
2205
|
+
const int32_t val = batch.pos[i];
|
|
2206
|
+
wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
|
2082
2207
|
}
|
|
2083
2208
|
}
|
|
2084
2209
|
|
|
@@ -2086,7 +2211,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2086
2211
|
wsp_ggml_allocr_alloc(alloc, KQscale);
|
|
2087
2212
|
|
|
2088
2213
|
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2089
|
-
|
|
2214
|
+
const float val = pow(float(n_state)/n_head, -0.25);
|
|
2215
|
+
wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
|
2216
|
+
}
|
|
2217
|
+
|
|
2218
|
+
struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
|
|
2219
|
+
wsp_ggml_allocr_alloc(alloc, KQ_mask);
|
|
2220
|
+
|
|
2221
|
+
if (!wsp_ggml_allocr_is_measure(alloc)) {
|
|
2222
|
+
wstate.inp_mask.resize(n_kv*n_tokens);
|
|
2223
|
+
|
|
2224
|
+
float * data = wstate.inp_mask.data();
|
|
2225
|
+
memset(data, 0, wsp_ggml_nbytes(KQ_mask));
|
|
2226
|
+
|
|
2227
|
+
for (int h = 0; h < 1; ++h) {
|
|
2228
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
2229
|
+
const whisper_pos pos = batch.pos[j];
|
|
2230
|
+
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
|
2231
|
+
|
|
2232
|
+
for (int i = 0; i < n_kv; ++i) {
|
|
2233
|
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
|
2234
|
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
|
2235
|
+
}
|
|
2236
|
+
}
|
|
2237
|
+
}
|
|
2238
|
+
}
|
|
2239
|
+
|
|
2240
|
+
wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
|
|
2090
2241
|
}
|
|
2091
2242
|
|
|
2092
2243
|
// token encoding + position encoding
|
|
@@ -2141,12 +2292,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2141
2292
|
Vcur,
|
|
2142
2293
|
layer.attn_v_b);
|
|
2143
2294
|
|
|
2144
|
-
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state,
|
|
2295
|
+
Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
|
2145
2296
|
|
|
2146
|
-
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k,
|
|
2147
|
-
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v,
|
|
2297
|
+
struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
|
2298
|
+
struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
|
2148
2299
|
( n_ctx)*wsp_ggml_element_size(kv_self.v),
|
|
2149
|
-
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state +
|
|
2300
|
+
(il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
|
|
2150
2301
|
|
|
2151
2302
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
|
|
2152
2303
|
wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
|
|
@@ -2156,12 +2307,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2156
2307
|
|
|
2157
2308
|
struct wsp_ggml_tensor * Q =
|
|
2158
2309
|
wsp_ggml_permute(ctx0,
|
|
2159
|
-
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
|
|
2310
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
|
2160
2311
|
0, 2, 1, 3);
|
|
2161
2312
|
|
|
2162
2313
|
struct wsp_ggml_tensor * K =
|
|
2163
2314
|
wsp_ggml_view_3d(ctx0, kv_self.k,
|
|
2164
|
-
n_state/n_head,
|
|
2315
|
+
n_state/n_head, n_kv, n_head,
|
|
2165
2316
|
wsp_ggml_element_size(kv_self.k)*n_state,
|
|
2166
2317
|
wsp_ggml_element_size(kv_self.k)*n_state/n_head,
|
|
2167
2318
|
wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
|
@@ -2171,16 +2322,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2171
2322
|
|
|
2172
2323
|
//struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
|
|
2173
2324
|
|
|
2174
|
-
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2325
|
+
//struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
2326
|
+
struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
|
|
2175
2327
|
|
|
2176
2328
|
struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
|
|
2177
2329
|
|
|
2178
2330
|
struct wsp_ggml_tensor * V =
|
|
2179
2331
|
wsp_ggml_view_3d(ctx0, kv_self.v,
|
|
2180
|
-
|
|
2332
|
+
n_kv, n_state/n_head, n_head,
|
|
2181
2333
|
n_ctx*wsp_ggml_element_size(kv_self.v),
|
|
2182
2334
|
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
|
|
2183
|
-
|
|
2335
|
+
n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
|
|
2184
2336
|
|
|
2185
2337
|
struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2186
2338
|
|
|
@@ -2188,7 +2340,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2188
2340
|
|
|
2189
2341
|
cur = wsp_ggml_cpy(ctx0,
|
|
2190
2342
|
KQV_merged,
|
|
2191
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state,
|
|
2343
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2192
2344
|
}
|
|
2193
2345
|
|
|
2194
2346
|
// projection
|
|
@@ -2232,33 +2384,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2232
2384
|
// Kcross is already scaled
|
|
2233
2385
|
struct wsp_ggml_tensor * Kcross =
|
|
2234
2386
|
wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
|
|
2235
|
-
n_state/n_head,
|
|
2387
|
+
n_state/n_head, n_audio_ctx, n_head,
|
|
2236
2388
|
wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
|
|
2237
2389
|
wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
|
2238
|
-
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*
|
|
2390
|
+
wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
|
2239
2391
|
|
|
2240
2392
|
//struct wsp_ggml_tensor * Vcross =
|
|
2241
2393
|
// wsp_ggml_reshape_3d(ctx0,
|
|
2242
|
-
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v,
|
|
2243
|
-
// n_state/n_head, n_head,
|
|
2394
|
+
// wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2395
|
+
// n_state/n_head, n_head, n_audio_ctx);
|
|
2244
2396
|
|
|
2245
2397
|
//struct wsp_ggml_tensor * V_trans =
|
|
2246
2398
|
// wsp_ggml_cpy(ctx0,
|
|
2247
2399
|
// wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
|
2248
|
-
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type,
|
|
2400
|
+
// wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
|
2249
2401
|
|
|
2250
2402
|
struct wsp_ggml_tensor * V =
|
|
2251
2403
|
wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2252
|
-
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
|
|
2404
|
+
n_audio_ctx, n_state/n_head, n_head,
|
|
2405
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
|
|
2406
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
|
2407
|
+
n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
|
2256
2408
|
|
|
2257
2409
|
// ------
|
|
2258
2410
|
|
|
2259
2411
|
struct wsp_ggml_tensor * Q =
|
|
2260
2412
|
wsp_ggml_permute(ctx0,
|
|
2261
|
-
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head,
|
|
2413
|
+
wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
|
2262
2414
|
0, 2, 1, 3);
|
|
2263
2415
|
|
|
2264
2416
|
// K * Q
|
|
@@ -2279,10 +2431,10 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2279
2431
|
|
|
2280
2432
|
struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2281
2433
|
|
|
2282
|
-
// cur = KQV_merged.contiguous().view(n_state,
|
|
2434
|
+
// cur = KQV_merged.contiguous().view(n_state, n_tokens)
|
|
2283
2435
|
cur = wsp_ggml_cpy(ctx0,
|
|
2284
2436
|
KQV_merged,
|
|
2285
|
-
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state,
|
|
2437
|
+
wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
|
|
2286
2438
|
}
|
|
2287
2439
|
|
|
2288
2440
|
// projection
|
|
@@ -2354,9 +2506,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2354
2506
|
}
|
|
2355
2507
|
|
|
2356
2508
|
// compute logits only for the last token
|
|
2357
|
-
// comment this line to compute logits for all
|
|
2509
|
+
// comment this line to compute logits for all n_tokens
|
|
2358
2510
|
// might be useful in the future
|
|
2359
|
-
cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
|
2511
|
+
//cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
|
2360
2512
|
|
|
2361
2513
|
struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2362
2514
|
|
|
@@ -2380,10 +2532,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
|
|
|
2380
2532
|
static bool whisper_decode_internal(
|
|
2381
2533
|
whisper_context & wctx,
|
|
2382
2534
|
whisper_state & wstate,
|
|
2383
|
-
|
|
2384
|
-
const whisper_token * tokens,
|
|
2385
|
-
const int n_tokens,
|
|
2386
|
-
const int n_past,
|
|
2535
|
+
const whisper_batch & batch,
|
|
2387
2536
|
const int n_threads,
|
|
2388
2537
|
whisper_abort_callback abort_callback,
|
|
2389
2538
|
void * abort_callback_data) {
|
|
@@ -2392,65 +2541,72 @@ static bool whisper_decode_internal(
|
|
|
2392
2541
|
const auto & model = wctx.model;
|
|
2393
2542
|
const auto & hparams = model.hparams;
|
|
2394
2543
|
|
|
2395
|
-
const int n_vocab
|
|
2544
|
+
const int n_vocab = hparams.n_vocab;
|
|
2545
|
+
const int n_tokens = batch.n_tokens;
|
|
2396
2546
|
|
|
2397
2547
|
auto & logits_out = wstate.logits;
|
|
2398
2548
|
|
|
2399
2549
|
struct wsp_ggml_tensor * logits;
|
|
2400
2550
|
|
|
2551
|
+
// find KV slot for the batch
|
|
2552
|
+
{
|
|
2553
|
+
auto & kv_self = wstate.kv_self;
|
|
2554
|
+
|
|
2555
|
+
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
|
|
2556
|
+
return false;
|
|
2557
|
+
}
|
|
2558
|
+
|
|
2559
|
+
kv_self.n = whisper_kv_cache_cell_max(kv_self);
|
|
2560
|
+
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
|
2561
|
+
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
|
2562
|
+
}
|
|
2563
|
+
|
|
2401
2564
|
// decoder
|
|
2402
2565
|
{
|
|
2403
2566
|
auto & alloc = wstate.alloc_decode.alloc;
|
|
2404
2567
|
|
|
2405
2568
|
wsp_ggml_allocr_reset(alloc);
|
|
2406
2569
|
|
|
2407
|
-
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate,
|
|
2570
|
+
wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
|
|
2408
2571
|
|
|
2409
2572
|
wsp_ggml_allocr_alloc_graph(alloc, gf);
|
|
2410
2573
|
|
|
2411
2574
|
logits = gf->nodes[gf->n_nodes - 1];
|
|
2412
2575
|
|
|
2413
|
-
|
|
2414
|
-
if (wstate.ctx_metal) {
|
|
2415
|
-
wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
|
2416
|
-
wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
|
2417
|
-
} else {
|
|
2418
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2419
|
-
}
|
|
2420
|
-
#else
|
|
2421
|
-
wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
|
2422
|
-
#endif
|
|
2576
|
+
wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
|
2423
2577
|
}
|
|
2424
2578
|
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
|
|
2430
|
-
|
|
2431
|
-
|
|
2579
|
+
logits_out.resize(n_tokens*n_vocab);
|
|
2580
|
+
for (int i = 0; i < n_tokens; i++) {
|
|
2581
|
+
if (batch.logits[i] == 0) {
|
|
2582
|
+
continue;
|
|
2583
|
+
}
|
|
2584
|
+
wsp_ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
|
|
2585
|
+
}
|
|
2432
2586
|
|
|
2433
|
-
if (n_tokens > 1) {
|
|
2587
|
+
if (batch.n_tokens > 1) {
|
|
2434
2588
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
2435
|
-
// wsp_ggml_used_mem(ctx0)/
|
|
2436
|
-
// wstate.get_buf_max_mem(0)/
|
|
2437
|
-
// wstate.get_buf_max_mem(1)/
|
|
2438
|
-
// wstate.get_buf_max_mem(2)/
|
|
2439
|
-
// wstate.get_buf_max_mem(3)/
|
|
2589
|
+
// wsp_ggml_used_mem(ctx0)/1e6,
|
|
2590
|
+
// wstate.get_buf_max_mem(0)/1e6,
|
|
2591
|
+
// wstate.get_buf_max_mem(1)/1e6,
|
|
2592
|
+
// wstate.get_buf_max_mem(2)/1e6,
|
|
2593
|
+
// wstate.get_buf_max_mem(3)/1e6);
|
|
2440
2594
|
}
|
|
2441
2595
|
|
|
2442
|
-
if (n_tokens == 1) {
|
|
2596
|
+
if (batch.n_tokens == 1) {
|
|
2443
2597
|
wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
|
|
2444
2598
|
wstate.n_decode++;
|
|
2599
|
+
} else if (batch.n_tokens < 16) {
|
|
2600
|
+
wstate.t_batchd_us += wsp_ggml_time_us() - t_start_us;
|
|
2601
|
+
wstate.n_batchd += n_tokens;
|
|
2445
2602
|
} else {
|
|
2446
2603
|
wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
|
|
2447
|
-
wstate.n_prompt
|
|
2604
|
+
wstate.n_prompt += n_tokens;
|
|
2448
2605
|
}
|
|
2449
2606
|
|
|
2450
|
-
return
|
|
2607
|
+
return !(abort_callback && abort_callback(abort_callback_data));
|
|
2451
2608
|
}
|
|
2452
2609
|
|
|
2453
|
-
|
|
2454
2610
|
// 500 -> 00:05.000
|
|
2455
2611
|
// 6000 -> 01:00.000
|
|
2456
2612
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
@@ -2794,7 +2950,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
|
2794
2950
|
--j;
|
|
2795
2951
|
}
|
|
2796
2952
|
if (!found) {
|
|
2797
|
-
|
|
2953
|
+
WHISPER_LOG_ERROR("unknown token\n");
|
|
2798
2954
|
++i;
|
|
2799
2955
|
}
|
|
2800
2956
|
}
|
|
@@ -2857,95 +3013,105 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
|
2857
3013
|
|
|
2858
3014
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
2859
3015
|
fill_sin_cos_table();
|
|
3016
|
+
|
|
2860
3017
|
whisper_state * state = new whisper_state;
|
|
2861
3018
|
|
|
2862
|
-
|
|
2863
|
-
|
|
3019
|
+
state->backend = whisper_backend_init(ctx->params);
|
|
3020
|
+
|
|
3021
|
+
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
|
3022
|
+
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
|
3023
|
+
const int factor = 3;
|
|
3024
|
+
|
|
3025
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
|
|
3026
|
+
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
2864
3027
|
delete state;
|
|
2865
3028
|
return nullptr;
|
|
2866
3029
|
}
|
|
2867
3030
|
|
|
2868
3031
|
{
|
|
2869
|
-
const size_t memory_size = wsp_ggml_nbytes(state->
|
|
2870
|
-
|
|
3032
|
+
const size_t memory_size = wsp_ggml_nbytes(state->kv_self.k) + wsp_ggml_nbytes(state->kv_self.v);
|
|
3033
|
+
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
2871
3034
|
}
|
|
2872
3035
|
|
|
2873
|
-
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
2874
|
-
|
|
3036
|
+
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
3037
|
+
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
2875
3038
|
delete state;
|
|
2876
3039
|
return nullptr;
|
|
2877
3040
|
}
|
|
2878
3041
|
|
|
2879
3042
|
{
|
|
2880
3043
|
const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v);
|
|
2881
|
-
|
|
3044
|
+
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
|
2882
3045
|
}
|
|
2883
3046
|
|
|
2884
|
-
|
|
3047
|
+
|
|
2885
3048
|
#ifdef WHISPER_USE_COREML
|
|
2886
3049
|
if (ctx->params.use_coreml) {
|
|
2887
3050
|
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
|
2888
3051
|
|
|
2889
|
-
|
|
2890
|
-
|
|
3052
|
+
WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
3053
|
+
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
2891
3054
|
|
|
2892
3055
|
state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
|
|
2893
3056
|
if (!state->ctx_coreml) {
|
|
2894
|
-
|
|
3057
|
+
WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2895
3058
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
2896
3059
|
delete state;
|
|
2897
3060
|
return nullptr;
|
|
2898
3061
|
#endif
|
|
2899
3062
|
} else {
|
|
2900
|
-
|
|
3063
|
+
WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
|
|
2901
3064
|
}
|
|
2902
3065
|
}
|
|
2903
3066
|
#endif
|
|
2904
3067
|
|
|
2905
3068
|
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
|
2906
3069
|
|
|
2907
|
-
state->
|
|
3070
|
+
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
|
|
2908
3071
|
|
|
2909
3072
|
// TAGS: WHISPER_DECODER_INIT
|
|
2910
3073
|
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
|
2911
3074
|
|
|
2912
|
-
state->decoders[0].probs.reserve
|
|
2913
|
-
state->decoders[0].logits.reserve
|
|
2914
|
-
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
|
3075
|
+
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
|
3076
|
+
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
|
3077
|
+
state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
|
|
3078
|
+
state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
|
|
3079
|
+
|
|
3080
|
+
state->decoders[0].rng = std::mt19937(0);
|
|
2915
3081
|
|
|
2916
3082
|
// conv allocator
|
|
2917
3083
|
{
|
|
2918
|
-
whisper_allocr_graph_init(state->alloc_conv,
|
|
3084
|
+
whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
|
|
2919
3085
|
[&]() {
|
|
2920
3086
|
return whisper_build_graph_conv(*ctx, *state, 0);
|
|
2921
3087
|
});
|
|
2922
3088
|
|
|
2923
|
-
|
|
3089
|
+
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
|
|
2924
3090
|
}
|
|
2925
3091
|
|
|
2926
3092
|
// encoder allocator
|
|
2927
3093
|
if (!whisper_encode_external(*state)) {
|
|
2928
|
-
whisper_allocr_graph_init(state->alloc_encode,
|
|
3094
|
+
whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
|
|
2929
3095
|
[&]() {
|
|
2930
3096
|
return whisper_build_graph_encoder(*ctx, *state);
|
|
2931
3097
|
});
|
|
2932
3098
|
|
|
2933
|
-
|
|
3099
|
+
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
|
|
2934
3100
|
}
|
|
2935
3101
|
|
|
2936
3102
|
// cross allocator
|
|
2937
3103
|
{
|
|
2938
|
-
whisper_allocr_graph_init(state->alloc_cross,
|
|
3104
|
+
whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
|
|
2939
3105
|
[&]() {
|
|
2940
3106
|
return whisper_build_graph_cross(*ctx, *state);
|
|
2941
3107
|
});
|
|
2942
3108
|
|
|
2943
|
-
|
|
3109
|
+
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
|
|
2944
3110
|
}
|
|
2945
3111
|
|
|
2946
3112
|
// decoder allocator
|
|
2947
3113
|
{
|
|
2948
|
-
whisper_allocr_graph_init(state->alloc_decode,
|
|
3114
|
+
whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
|
|
2949
3115
|
[&]() {
|
|
2950
3116
|
const auto & hparams = ctx->model.hparams;
|
|
2951
3117
|
|
|
@@ -2953,74 +3119,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2953
3119
|
const int n_tokens = hparams.n_text_ctx;
|
|
2954
3120
|
const int n_past = 0;
|
|
2955
3121
|
|
|
2956
|
-
|
|
2957
|
-
});
|
|
2958
|
-
|
|
2959
|
-
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
|
2960
|
-
}
|
|
2961
|
-
|
|
2962
|
-
#ifdef WSP_GGML_USE_METAL
|
|
2963
|
-
if (ctx->params.use_gpu) {
|
|
2964
|
-
state->ctx_metal = wsp_ggml_metal_init(1);
|
|
2965
|
-
if (!state->ctx_metal) {
|
|
2966
|
-
log("%s: wsp_ggml_metal_init() failed\n", __func__);
|
|
2967
|
-
delete state;
|
|
2968
|
-
return nullptr;
|
|
2969
|
-
}
|
|
2970
|
-
}
|
|
2971
|
-
|
|
2972
|
-
if (state->ctx_metal) {
|
|
2973
|
-
log("%s: Metal context initialized\n", __func__);
|
|
2974
|
-
|
|
2975
|
-
// this allocates all Metal resources and memory buffers
|
|
2976
|
-
|
|
2977
|
-
void * data_ptr = NULL;
|
|
2978
|
-
size_t data_size = 0;
|
|
2979
|
-
|
|
2980
|
-
// TODO: add mmap support
|
|
2981
|
-
//if (params.use_mmap) {
|
|
2982
|
-
// data_ptr = ctx->model.mapping->addr;
|
|
2983
|
-
// data_size = ctx->model.mapping->size;
|
|
2984
|
-
//} else {
|
|
2985
|
-
// data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2986
|
-
// data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2987
|
-
//}
|
|
2988
|
-
|
|
2989
|
-
data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
|
|
2990
|
-
data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
|
|
2991
|
-
|
|
2992
|
-
const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
|
|
2993
|
-
|
|
2994
|
-
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
|
2995
|
-
|
|
2996
|
-
#define WHISPER_METAL_CHECK_BUF(result) \
|
|
2997
|
-
if (!(result)) { \
|
|
2998
|
-
log("%s: failed to add metal buffer\n", __func__); \
|
|
2999
|
-
delete state; \
|
|
3000
|
-
return nullptr; \
|
|
3001
|
-
}
|
|
3002
|
-
|
|
3003
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
|
3004
|
-
|
|
3005
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
|
|
3006
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
|
|
3007
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
|
|
3008
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
|
|
3009
|
-
|
|
3010
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
|
|
3011
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
|
|
3012
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
|
|
3013
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
|
|
3122
|
+
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
|
|
3014
3123
|
|
|
3015
|
-
|
|
3016
|
-
|
|
3017
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
|
3018
|
-
#undef WHISPER_METAL_CHECK_BUF
|
|
3124
|
+
return whisper_build_graph_decoder(*ctx, *state, state->batch);
|
|
3125
|
+
});
|
|
3019
3126
|
|
|
3127
|
+
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
|
|
3020
3128
|
}
|
|
3021
|
-
#endif
|
|
3022
3129
|
|
|
3023
|
-
state->
|
|
3130
|
+
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
|
|
3131
|
+
whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
|
|
3132
|
+
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
|
3133
|
+
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
|
3024
3134
|
|
|
3025
3135
|
return state;
|
|
3026
3136
|
}
|
|
@@ -3039,7 +3149,7 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3039
3149
|
return 1;
|
|
3040
3150
|
#else
|
|
3041
3151
|
if (!model_path && ctx->path_model.empty()) {
|
|
3042
|
-
|
|
3152
|
+
WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
|
|
3043
3153
|
return 1;
|
|
3044
3154
|
}
|
|
3045
3155
|
|
|
@@ -3059,15 +3169,15 @@ int whisper_ctx_init_openvino_encoder(
|
|
|
3059
3169
|
path_cache = cache_dir;
|
|
3060
3170
|
}
|
|
3061
3171
|
|
|
3062
|
-
|
|
3063
|
-
|
|
3172
|
+
WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
|
|
3173
|
+
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
|
3064
3174
|
|
|
3065
3175
|
ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
|
|
3066
3176
|
if (!ctx->state->ctx_openvino) {
|
|
3067
|
-
|
|
3177
|
+
WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
|
|
3068
3178
|
return 1;
|
|
3069
3179
|
} else {
|
|
3070
|
-
|
|
3180
|
+
WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
|
|
3071
3181
|
}
|
|
3072
3182
|
|
|
3073
3183
|
return 0;
|
|
@@ -3083,11 +3193,11 @@ struct whisper_context_params whisper_context_default_params() {
|
|
|
3083
3193
|
}
|
|
3084
3194
|
|
|
3085
3195
|
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
|
|
3086
|
-
|
|
3196
|
+
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
|
3087
3197
|
|
|
3088
3198
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
3089
3199
|
if (!fin) {
|
|
3090
|
-
|
|
3200
|
+
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
|
3091
3201
|
return nullptr;
|
|
3092
3202
|
}
|
|
3093
3203
|
|
|
@@ -3129,7 +3239,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|
|
3129
3239
|
|
|
3130
3240
|
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
|
3131
3241
|
|
|
3132
|
-
|
|
3242
|
+
WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
|
|
3133
3243
|
|
|
3134
3244
|
whisper_model_loader loader = {};
|
|
3135
3245
|
|
|
@@ -3165,7 +3275,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
|
|
|
3165
3275
|
|
|
3166
3276
|
if (!whisper_model_load(loader, *ctx)) {
|
|
3167
3277
|
loader->close(loader->context);
|
|
3168
|
-
|
|
3278
|
+
WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
|
|
3169
3279
|
delete ctx;
|
|
3170
3280
|
return nullptr;
|
|
3171
3281
|
}
|
|
@@ -3247,12 +3357,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
|
|
3247
3357
|
void whisper_free_state(struct whisper_state * state)
|
|
3248
3358
|
{
|
|
3249
3359
|
if (state) {
|
|
3360
|
+
kv_cache_free(state->kv_self);
|
|
3250
3361
|
kv_cache_free(state->kv_cross);
|
|
3251
3362
|
|
|
3252
|
-
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
|
3253
|
-
kv_cache_free(state->decoders[i].kv_self);
|
|
3254
|
-
}
|
|
3255
|
-
|
|
3256
3363
|
#ifdef WHISPER_USE_COREML
|
|
3257
3364
|
if (state->ctx_coreml != nullptr) {
|
|
3258
3365
|
whisper_coreml_free(state->ctx_coreml);
|
|
@@ -3260,13 +3367,6 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3260
3367
|
}
|
|
3261
3368
|
#endif
|
|
3262
3369
|
|
|
3263
|
-
#ifdef WSP_GGML_USE_METAL
|
|
3264
|
-
if (state->ctx_metal) {
|
|
3265
|
-
wsp_ggml_metal_free(state->ctx_metal);
|
|
3266
|
-
state->ctx_metal = nullptr;
|
|
3267
|
-
}
|
|
3268
|
-
#endif
|
|
3269
|
-
|
|
3270
3370
|
#ifdef WHISPER_USE_OPENVINO
|
|
3271
3371
|
if (state->ctx_openvino != nullptr) {
|
|
3272
3372
|
whisper_openvino_free(state->ctx_openvino);
|
|
@@ -3274,10 +3374,14 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
3274
3374
|
}
|
|
3275
3375
|
#endif
|
|
3276
3376
|
|
|
3377
|
+
whisper_batch_free(state->batch);
|
|
3378
|
+
|
|
3277
3379
|
whisper_allocr_free(state->alloc_conv);
|
|
3278
|
-
whisper_allocr_free(state->alloc_decode);
|
|
3279
|
-
whisper_allocr_free(state->alloc_cross);
|
|
3280
3380
|
whisper_allocr_free(state->alloc_encode);
|
|
3381
|
+
whisper_allocr_free(state->alloc_cross);
|
|
3382
|
+
whisper_allocr_free(state->alloc_decode);
|
|
3383
|
+
|
|
3384
|
+
wsp_ggml_backend_free(state->backend);
|
|
3281
3385
|
|
|
3282
3386
|
delete state;
|
|
3283
3387
|
}
|
|
@@ -3288,12 +3392,15 @@ void whisper_free(struct whisper_context * ctx) {
|
|
|
3288
3392
|
if (ctx->model.ctx) {
|
|
3289
3393
|
wsp_ggml_free(ctx->model.ctx);
|
|
3290
3394
|
}
|
|
3291
|
-
|
|
3292
|
-
|
|
3395
|
+
|
|
3396
|
+
if (ctx->model.buffer) {
|
|
3397
|
+
wsp_ggml_backend_buffer_free(ctx->model.buffer);
|
|
3293
3398
|
}
|
|
3294
3399
|
|
|
3295
3400
|
whisper_free_state(ctx->state);
|
|
3296
3401
|
|
|
3402
|
+
wsp_ggml_backend_free(ctx->backend);
|
|
3403
|
+
|
|
3297
3404
|
delete ctx;
|
|
3298
3405
|
}
|
|
3299
3406
|
}
|
|
@@ -3312,7 +3419,7 @@ void whisper_free_params(struct whisper_full_params * params) {
|
|
|
3312
3419
|
|
|
3313
3420
|
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3314
3421
|
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3315
|
-
|
|
3422
|
+
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3316
3423
|
return -1;
|
|
3317
3424
|
}
|
|
3318
3425
|
|
|
@@ -3326,7 +3433,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
|
3326
3433
|
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
|
3327
3434
|
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3328
3435
|
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
|
3329
|
-
|
|
3436
|
+
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3330
3437
|
return -1;
|
|
3331
3438
|
}
|
|
3332
3439
|
|
|
@@ -3354,7 +3461,7 @@ int whisper_set_mel_with_state(
|
|
|
3354
3461
|
int n_len,
|
|
3355
3462
|
int n_mel) {
|
|
3356
3463
|
if (n_mel != ctx->model.filters.n_mel) {
|
|
3357
|
-
|
|
3464
|
+
WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
|
|
3358
3465
|
return -1;
|
|
3359
3466
|
}
|
|
3360
3467
|
|
|
@@ -3378,7 +3485,7 @@ int whisper_set_mel(
|
|
|
3378
3485
|
|
|
3379
3486
|
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
|
3380
3487
|
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
|
3381
|
-
|
|
3488
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3382
3489
|
return -1;
|
|
3383
3490
|
}
|
|
3384
3491
|
|
|
@@ -3387,7 +3494,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3387
3494
|
|
|
3388
3495
|
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
3389
3496
|
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
|
3390
|
-
|
|
3497
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3391
3498
|
return -1;
|
|
3392
3499
|
}
|
|
3393
3500
|
|
|
@@ -3395,10 +3502,12 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
3395
3502
|
}
|
|
3396
3503
|
|
|
3397
3504
|
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3398
|
-
|
|
3505
|
+
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
|
|
3506
|
+
|
|
3507
|
+
whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
|
|
3399
3508
|
|
|
3400
|
-
if (!whisper_decode_internal(*ctx, *state, state->
|
|
3401
|
-
|
|
3509
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
|
|
3510
|
+
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3402
3511
|
return 1;
|
|
3403
3512
|
}
|
|
3404
3513
|
|
|
@@ -3406,27 +3515,19 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
|
|
|
3406
3515
|
}
|
|
3407
3516
|
|
|
3408
3517
|
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
3409
|
-
// TODO: add selected_decoder_id to state
|
|
3410
|
-
const int selected_decoder_id = 0;
|
|
3411
|
-
|
|
3412
3518
|
if (ctx->state == nullptr) {
|
|
3413
|
-
|
|
3414
|
-
return
|
|
3415
|
-
}
|
|
3416
|
-
|
|
3417
|
-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
|
3418
|
-
log("%s: failed to eval\n", __func__);
|
|
3419
|
-
return 1;
|
|
3519
|
+
WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
|
|
3520
|
+
return -1;
|
|
3420
3521
|
}
|
|
3421
3522
|
|
|
3422
|
-
return
|
|
3523
|
+
return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
|
|
3423
3524
|
}
|
|
3424
3525
|
|
|
3425
3526
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
3426
3527
|
const auto res = tokenize(ctx->vocab, text);
|
|
3427
3528
|
|
|
3428
3529
|
if (n_max_tokens < (int) res.size()) {
|
|
3429
|
-
|
|
3530
|
+
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
3430
3531
|
return -1;
|
|
3431
3532
|
}
|
|
3432
3533
|
|
|
@@ -3454,7 +3555,7 @@ int whisper_lang_id(const char * lang) {
|
|
|
3454
3555
|
}
|
|
3455
3556
|
}
|
|
3456
3557
|
|
|
3457
|
-
|
|
3558
|
+
WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
|
|
3458
3559
|
return -1;
|
|
3459
3560
|
}
|
|
3460
3561
|
return g_lang.at(lang).first;
|
|
@@ -3467,7 +3568,18 @@ const char * whisper_lang_str(int id) {
|
|
|
3467
3568
|
}
|
|
3468
3569
|
}
|
|
3469
3570
|
|
|
3470
|
-
|
|
3571
|
+
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
|
3572
|
+
return nullptr;
|
|
3573
|
+
}
|
|
3574
|
+
|
|
3575
|
+
const char * whisper_lang_str_full(int id) {
|
|
3576
|
+
for (const auto & kv : g_lang) {
|
|
3577
|
+
if (kv.second.first == id) {
|
|
3578
|
+
return kv.second.second.c_str();
|
|
3579
|
+
}
|
|
3580
|
+
}
|
|
3581
|
+
|
|
3582
|
+
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
|
3471
3583
|
return nullptr;
|
|
3472
3584
|
}
|
|
3473
3585
|
|
|
@@ -3480,29 +3592,29 @@ int whisper_lang_auto_detect_with_state(
|
|
|
3480
3592
|
const int seek = offset_ms/10;
|
|
3481
3593
|
|
|
3482
3594
|
if (seek < 0) {
|
|
3483
|
-
|
|
3595
|
+
WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
|
3484
3596
|
return -1;
|
|
3485
3597
|
}
|
|
3486
3598
|
|
|
3487
3599
|
if (seek >= state->mel.n_len_org) {
|
|
3488
|
-
|
|
3600
|
+
WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
|
|
3489
3601
|
return -2;
|
|
3490
3602
|
}
|
|
3491
3603
|
|
|
3492
3604
|
// run the encoder
|
|
3493
3605
|
if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
|
|
3494
|
-
|
|
3606
|
+
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
3495
3607
|
return -6;
|
|
3496
3608
|
}
|
|
3497
3609
|
|
|
3498
3610
|
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
|
3499
3611
|
|
|
3500
3612
|
if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
|
3501
|
-
|
|
3613
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
3502
3614
|
return -7;
|
|
3503
3615
|
}
|
|
3504
3616
|
|
|
3505
|
-
auto & logits_id = state->logits_id;
|
|
3617
|
+
auto & logits_id = state->decoders[0].logits_id;
|
|
3506
3618
|
logits_id.clear();
|
|
3507
3619
|
|
|
3508
3620
|
for (const auto & kv : g_lang) {
|
|
@@ -3698,28 +3810,31 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
|
|
|
3698
3810
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
3699
3811
|
const int64_t t_end_us = wsp_ggml_time_us();
|
|
3700
3812
|
|
|
3701
|
-
|
|
3702
|
-
|
|
3813
|
+
WHISPER_LOG_INFO("\n");
|
|
3814
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
|
3703
3815
|
if (ctx->state != nullptr) {
|
|
3704
3816
|
|
|
3705
3817
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3706
3818
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3707
3819
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3820
|
+
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
|
3708
3821
|
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
3709
3822
|
|
|
3710
|
-
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3714
|
-
|
|
3715
|
-
|
|
3823
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3824
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3825
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3826
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3827
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3828
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
|
3829
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
|
3716
3830
|
}
|
|
3717
|
-
|
|
3831
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
3718
3832
|
}
|
|
3719
3833
|
|
|
3720
3834
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
3721
3835
|
ctx->t_start_us = wsp_ggml_time_us();
|
|
3722
3836
|
if (ctx->state != nullptr) {
|
|
3837
|
+
ctx->state->t_mel_us = 0;
|
|
3723
3838
|
ctx->state->t_sample_us = 0;
|
|
3724
3839
|
ctx->state->t_encode_us = 0;
|
|
3725
3840
|
ctx->state->t_decode_us = 0;
|
|
@@ -3727,6 +3842,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3727
3842
|
ctx->state->n_sample = 0;
|
|
3728
3843
|
ctx->state->n_encode = 0;
|
|
3729
3844
|
ctx->state->n_decode = 0;
|
|
3845
|
+
ctx->state->n_batchd = 0;
|
|
3730
3846
|
ctx->state->n_prompt = 0;
|
|
3731
3847
|
}
|
|
3732
3848
|
}
|
|
@@ -3765,12 +3881,431 @@ const char * whisper_print_system_info(void) {
|
|
|
3765
3881
|
s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
|
|
3766
3882
|
s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
|
|
3767
3883
|
s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
|
|
3884
|
+
s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cublas()) + " | ";
|
|
3768
3885
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
3769
3886
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
|
3770
3887
|
|
|
3771
|
-
return s.c_str();
|
|
3888
|
+
return s.c_str();
|
|
3889
|
+
}
|
|
3890
|
+
|
|
3891
|
+
//////////////////////////////////
|
|
3892
|
+
// Grammar - ported from llama.cpp
|
|
3893
|
+
//////////////////////////////////
|
|
3894
|
+
|
|
3895
|
+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
|
3896
|
+
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
|
3897
|
+
std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
3898
|
+
const char * src,
|
|
3899
|
+
whisper_partial_utf8 partial_start) {
|
|
3900
|
+
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
|
3901
|
+
const char * pos = src;
|
|
3902
|
+
std::vector<uint32_t> code_points;
|
|
3903
|
+
uint32_t value = partial_start.value;
|
|
3904
|
+
int n_remain = partial_start.n_remain;
|
|
3905
|
+
|
|
3906
|
+
// continue previous decode, if applicable
|
|
3907
|
+
while (*pos != 0 && n_remain > 0) {
|
|
3908
|
+
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
|
3909
|
+
if ((next_byte >> 6) != 2) {
|
|
3910
|
+
// invalid sequence, abort
|
|
3911
|
+
code_points.push_back(0);
|
|
3912
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
|
3913
|
+
}
|
|
3914
|
+
value = (value << 6) + (next_byte & 0x3F);
|
|
3915
|
+
++pos;
|
|
3916
|
+
--n_remain;
|
|
3917
|
+
}
|
|
3918
|
+
|
|
3919
|
+
if (partial_start.n_remain > 0 && n_remain == 0) {
|
|
3920
|
+
code_points.push_back(value);
|
|
3921
|
+
}
|
|
3922
|
+
|
|
3923
|
+
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
|
3924
|
+
while (*pos != 0) {
|
|
3925
|
+
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
|
3926
|
+
uint8_t highbits = first_byte >> 4;
|
|
3927
|
+
n_remain = lookup[highbits] - 1;
|
|
3928
|
+
|
|
3929
|
+
if (n_remain < 0) {
|
|
3930
|
+
// invalid sequence, abort
|
|
3931
|
+
code_points.clear();
|
|
3932
|
+
code_points.push_back(0);
|
|
3933
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
|
|
3934
|
+
}
|
|
3935
|
+
|
|
3936
|
+
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
|
3937
|
+
value = first_byte & mask;
|
|
3938
|
+
++pos;
|
|
3939
|
+
while (*pos != 0 && n_remain > 0) {
|
|
3940
|
+
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
|
3941
|
+
++pos;
|
|
3942
|
+
--n_remain;
|
|
3943
|
+
}
|
|
3944
|
+
if (n_remain == 0) {
|
|
3945
|
+
code_points.push_back(value);
|
|
3946
|
+
}
|
|
3947
|
+
}
|
|
3948
|
+
code_points.push_back(0);
|
|
3949
|
+
|
|
3950
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
|
|
3951
|
+
}
|
|
3952
|
+
|
|
3953
|
+
// returns true iff pos points to the end of one of the definitions of a rule
|
|
3954
|
+
static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
|
|
3955
|
+
switch (pos->type) {
|
|
3956
|
+
case WHISPER_GRETYPE_END: return true; // NOLINT
|
|
3957
|
+
case WHISPER_GRETYPE_ALT: return true; // NOLINT
|
|
3958
|
+
default: return false;
|
|
3959
|
+
}
|
|
3960
|
+
}
|
|
3961
|
+
|
|
3962
|
+
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
|
3963
|
+
// asserts that pos is pointing to a char range element
|
|
3964
|
+
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
|
|
3965
|
+
const whisper_grammar_element * pos,
|
|
3966
|
+
const uint32_t chr) {
|
|
3967
|
+
|
|
3968
|
+
bool found = false;
|
|
3969
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
|
3970
|
+
|
|
3971
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
|
|
3972
|
+
|
|
3973
|
+
do {
|
|
3974
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
|
3975
|
+
// inclusive range, e.g. [a-z]
|
|
3976
|
+
found = found || (pos->value <= chr && chr <= pos[1].value);
|
|
3977
|
+
pos += 2;
|
|
3978
|
+
} else {
|
|
3979
|
+
// exact char match, e.g. [a] or "a"
|
|
3980
|
+
found = found || pos->value == chr;
|
|
3981
|
+
pos += 1;
|
|
3982
|
+
}
|
|
3983
|
+
} while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
|
|
3984
|
+
|
|
3985
|
+
return std::make_pair(found == is_positive_char, pos);
|
|
3986
|
+
}
|
|
3987
|
+
|
|
3988
|
+
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
|
|
3989
|
+
// range at pos (regular or inverse range)
|
|
3990
|
+
// asserts that pos is pointing to a char range element
|
|
3991
|
+
static bool whisper_grammar_match_partial_char(
|
|
3992
|
+
const whisper_grammar_element * pos,
|
|
3993
|
+
const whisper_partial_utf8 partial_utf8) {
|
|
3994
|
+
|
|
3995
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
|
3996
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
|
|
3997
|
+
|
|
3998
|
+
uint32_t partial_value = partial_utf8.value;
|
|
3999
|
+
int n_remain = partial_utf8.n_remain;
|
|
4000
|
+
|
|
4001
|
+
// invalid sequence or 7-bit char split across 2 bytes (overlong)
|
|
4002
|
+
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
|
|
4003
|
+
return false;
|
|
4004
|
+
}
|
|
4005
|
+
|
|
4006
|
+
// range of possible code points this partial UTF-8 sequence could complete to
|
|
4007
|
+
uint32_t low = partial_value << (n_remain * 6);
|
|
4008
|
+
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
|
|
4009
|
+
|
|
4010
|
+
if (low == 0) {
|
|
4011
|
+
if (n_remain == 2) {
|
|
4012
|
+
low = 1 << 11;
|
|
4013
|
+
} else if (n_remain == 3) {
|
|
4014
|
+
low = 1 << 16;
|
|
4015
|
+
}
|
|
4016
|
+
}
|
|
4017
|
+
|
|
4018
|
+
do {
|
|
4019
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
|
4020
|
+
// inclusive range, e.g. [a-z]
|
|
4021
|
+
if (pos->value <= high && low <= pos[1].value) {
|
|
4022
|
+
return is_positive_char;
|
|
4023
|
+
}
|
|
4024
|
+
pos += 2;
|
|
4025
|
+
} else {
|
|
4026
|
+
// exact char match, e.g. [a] or "a"
|
|
4027
|
+
if (low <= pos->value && pos->value <= high) {
|
|
4028
|
+
return is_positive_char;
|
|
4029
|
+
}
|
|
4030
|
+
pos += 1;
|
|
4031
|
+
}
|
|
4032
|
+
} while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
|
|
4033
|
+
|
|
4034
|
+
return !is_positive_char;
|
|
4035
|
+
}
|
|
4036
|
+
|
|
4037
|
+
|
|
4038
|
+
// transforms a grammar pushdown stack into N possible stacks, all ending
|
|
4039
|
+
// at a character range (terminal element)
|
|
4040
|
+
static void whisper_grammar_advance_stack(
|
|
4041
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4042
|
+
const std::vector<const whisper_grammar_element *> & stack,
|
|
4043
|
+
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
|
4044
|
+
|
|
4045
|
+
if (stack.empty()) {
|
|
4046
|
+
new_stacks.push_back(stack);
|
|
4047
|
+
return;
|
|
4048
|
+
}
|
|
4049
|
+
|
|
4050
|
+
const whisper_grammar_element * pos = stack.back();
|
|
4051
|
+
|
|
4052
|
+
switch (pos->type) {
|
|
4053
|
+
case WHISPER_GRETYPE_RULE_REF: {
|
|
4054
|
+
const size_t rule_id = static_cast<size_t>(pos->value);
|
|
4055
|
+
const whisper_grammar_element * subpos = rules[rule_id].data();
|
|
4056
|
+
do {
|
|
4057
|
+
// init new stack without the top (pos)
|
|
4058
|
+
std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
|
4059
|
+
if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
|
|
4060
|
+
// if this rule ref is followed by another element, add that to stack
|
|
4061
|
+
new_stack.push_back(pos + 1);
|
|
4062
|
+
}
|
|
4063
|
+
if (!whisper_grammar_is_end_of_sequence(subpos)) {
|
|
4064
|
+
// if alternate is nonempty, add to stack
|
|
4065
|
+
new_stack.push_back(subpos);
|
|
4066
|
+
}
|
|
4067
|
+
whisper_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
4068
|
+
while (!whisper_grammar_is_end_of_sequence(subpos)) {
|
|
4069
|
+
// scan to end of alternate def
|
|
4070
|
+
subpos++;
|
|
4071
|
+
}
|
|
4072
|
+
if (subpos->type == WHISPER_GRETYPE_ALT) {
|
|
4073
|
+
// there's another alternate def of this rule to process
|
|
4074
|
+
subpos++;
|
|
4075
|
+
} else {
|
|
4076
|
+
break;
|
|
4077
|
+
}
|
|
4078
|
+
} while (true);
|
|
4079
|
+
break;
|
|
4080
|
+
}
|
|
4081
|
+
case WHISPER_GRETYPE_CHAR:
|
|
4082
|
+
case WHISPER_GRETYPE_CHAR_NOT:
|
|
4083
|
+
new_stacks.push_back(stack);
|
|
4084
|
+
break;
|
|
4085
|
+
default:
|
|
4086
|
+
// end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
|
|
4087
|
+
// (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
|
4088
|
+
// those
|
|
4089
|
+
WHISPER_ASSERT(false);
|
|
4090
|
+
}
|
|
4091
|
+
}
|
|
4092
|
+
|
|
4093
|
+
// takes a set of possible pushdown stacks on a grammar, which are required to
|
|
4094
|
+
// be positioned at a character range (see `whisper_grammar_advance_stack`), and
|
|
4095
|
+
// produces the N possible stacks if the given char is accepted at those
|
|
4096
|
+
// positions
|
|
4097
|
+
static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
|
|
4098
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4099
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4100
|
+
const uint32_t chr) {
|
|
4101
|
+
|
|
4102
|
+
std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
|
|
4103
|
+
|
|
4104
|
+
for (const auto & stack : stacks) {
|
|
4105
|
+
if (stack.empty()) {
|
|
4106
|
+
continue;
|
|
4107
|
+
}
|
|
4108
|
+
|
|
4109
|
+
auto match = whisper_grammar_match_char(stack.back(), chr);
|
|
4110
|
+
if (match.first) {
|
|
4111
|
+
const whisper_grammar_element * pos = match.second;
|
|
4112
|
+
|
|
4113
|
+
// update top of stack to next element, if any
|
|
4114
|
+
std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
|
4115
|
+
if (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4116
|
+
new_stack.push_back(pos);
|
|
4117
|
+
}
|
|
4118
|
+
whisper_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
4119
|
+
}
|
|
4120
|
+
}
|
|
4121
|
+
|
|
4122
|
+
return new_stacks;
|
|
4123
|
+
}
|
|
4124
|
+
|
|
4125
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
|
|
4126
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4127
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4128
|
+
const std::vector<whisper_grammar_candidate> & candidates);
|
|
4129
|
+
|
|
4130
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
|
|
4131
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4132
|
+
const std::vector<const whisper_grammar_element *> & stack,
|
|
4133
|
+
const std::vector<whisper_grammar_candidate> & candidates) {
|
|
4134
|
+
|
|
4135
|
+
std::vector<whisper_grammar_candidate> rejects;
|
|
4136
|
+
|
|
4137
|
+
if (stack.empty()) {
|
|
4138
|
+
for (auto tok : candidates) {
|
|
4139
|
+
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
|
|
4140
|
+
rejects.push_back(tok);
|
|
4141
|
+
}
|
|
4142
|
+
}
|
|
4143
|
+
return rejects;
|
|
4144
|
+
}
|
|
4145
|
+
|
|
4146
|
+
const whisper_grammar_element * stack_pos = stack.back();
|
|
4147
|
+
|
|
4148
|
+
std::vector<whisper_grammar_candidate> next_candidates;
|
|
4149
|
+
for (auto tok : candidates) {
|
|
4150
|
+
if (*tok.code_points == 0) {
|
|
4151
|
+
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
|
4152
|
+
// that cannot satisfy this position in grammar
|
|
4153
|
+
if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
|
|
4154
|
+
rejects.push_back(tok);
|
|
4155
|
+
}
|
|
4156
|
+
} else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
|
|
4157
|
+
next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
|
|
4158
|
+
} else {
|
|
4159
|
+
rejects.push_back(tok);
|
|
4160
|
+
}
|
|
4161
|
+
}
|
|
4162
|
+
|
|
4163
|
+
const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
|
|
4164
|
+
|
|
4165
|
+
// update top of stack to next element, if any
|
|
4166
|
+
std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
|
|
4167
|
+
if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
|
|
4168
|
+
stack_after.push_back(stack_pos_after);
|
|
4169
|
+
}
|
|
4170
|
+
std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
|
|
4171
|
+
whisper_grammar_advance_stack(rules, stack_after, next_stacks);
|
|
4172
|
+
|
|
4173
|
+
auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
|
4174
|
+
for (auto tok : next_rejects) {
|
|
4175
|
+
rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
|
|
4176
|
+
}
|
|
4177
|
+
|
|
4178
|
+
return rejects;
|
|
4179
|
+
}
|
|
4180
|
+
|
|
4181
|
+
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
|
|
4182
|
+
const std::vector<std::vector<whisper_grammar_element>> & rules,
|
|
4183
|
+
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
|
|
4184
|
+
const std::vector<whisper_grammar_candidate> & candidates) {
|
|
4185
|
+
if (candidates.empty() || stacks.empty()) {
|
|
4186
|
+
return std::vector<whisper_grammar_candidate>();
|
|
4187
|
+
}
|
|
4188
|
+
|
|
4189
|
+
auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
|
|
4190
|
+
|
|
4191
|
+
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
|
4192
|
+
rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
|
4193
|
+
}
|
|
4194
|
+
return rejects;
|
|
4195
|
+
}
|
|
4196
|
+
|
|
4197
|
+
static struct whisper_grammar whisper_grammar_init(
|
|
4198
|
+
const whisper_grammar_element ** rules,
|
|
4199
|
+
size_t n_rules,
|
|
4200
|
+
size_t i_start_rule) {
|
|
4201
|
+
const whisper_grammar_element * pos;
|
|
4202
|
+
|
|
4203
|
+
// copy rule definitions into vectors
|
|
4204
|
+
std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
|
|
4205
|
+
for (size_t i = 0; i < n_rules; i++) {
|
|
4206
|
+
for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
|
|
4207
|
+
vec_rules[i].push_back(*pos);
|
|
4208
|
+
}
|
|
4209
|
+
vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
|
|
4210
|
+
}
|
|
4211
|
+
|
|
4212
|
+
// loop over alternates of start rule to build initial stacks
|
|
4213
|
+
std::vector<std::vector<const whisper_grammar_element *>> stacks;
|
|
4214
|
+
pos = rules[i_start_rule];
|
|
4215
|
+
do {
|
|
4216
|
+
std::vector<const whisper_grammar_element *> stack;
|
|
4217
|
+
if (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4218
|
+
// if alternate is nonempty, add to stack
|
|
4219
|
+
stack.push_back(pos);
|
|
4220
|
+
}
|
|
4221
|
+
whisper_grammar_advance_stack(vec_rules, stack, stacks);
|
|
4222
|
+
while (!whisper_grammar_is_end_of_sequence(pos)) {
|
|
4223
|
+
// scan to end of alternate def
|
|
4224
|
+
pos++;
|
|
4225
|
+
}
|
|
4226
|
+
if (pos->type == WHISPER_GRETYPE_ALT) {
|
|
4227
|
+
// there's another alternate def of this rule to process
|
|
4228
|
+
pos++;
|
|
4229
|
+
} else {
|
|
4230
|
+
break;
|
|
4231
|
+
}
|
|
4232
|
+
} while (true);
|
|
4233
|
+
|
|
4234
|
+
return { std::move(vec_rules), std::move(stacks), {} };
|
|
4235
|
+
}
|
|
4236
|
+
|
|
4237
|
+
static void whisper_suppress_invalid_grammar(
|
|
4238
|
+
whisper_context & ctx,
|
|
4239
|
+
const whisper_full_params & params,
|
|
4240
|
+
std::vector<float> & logits,
|
|
4241
|
+
const whisper_grammar & grammar) {
|
|
4242
|
+
|
|
4243
|
+
if (grammar.rules.empty() || grammar.stacks.empty()) {
|
|
4244
|
+
return;
|
|
4245
|
+
}
|
|
4246
|
+
|
|
4247
|
+
//bool allow_eot = false;
|
|
4248
|
+
//for (const auto & stack : grammar.stacks) {
|
|
4249
|
+
// if (stack.empty()) {
|
|
4250
|
+
// allow_eot = true;
|
|
4251
|
+
// break;
|
|
4252
|
+
// }
|
|
4253
|
+
//}
|
|
4254
|
+
|
|
4255
|
+
const whisper_token eot = whisper_token_eot(&ctx);
|
|
4256
|
+
|
|
4257
|
+
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
|
|
4258
|
+
std::vector<whisper_grammar_candidate> candidates_grammar;
|
|
4259
|
+
|
|
4260
|
+
for (whisper_token id = 0; id < eot; ++id) {
|
|
4261
|
+
const std::string & text = ctx.vocab.id_to_token[id];
|
|
4262
|
+
if (!text.empty()) {
|
|
4263
|
+
candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
|
|
4264
|
+
candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
|
4265
|
+
}
|
|
4266
|
+
}
|
|
4267
|
+
|
|
4268
|
+
const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
|
|
4269
|
+
|
|
4270
|
+
for (const auto & reject : rejects) {
|
|
4271
|
+
logits[reject.id] -= params.grammar_penalty;
|
|
4272
|
+
}
|
|
4273
|
+
|
|
4274
|
+
// when the grammar allows a continuation, we penalize the end-of-text token
|
|
4275
|
+
//if (!allow_eot) {
|
|
4276
|
+
// logits[eot] -= params.grammar_penalty;
|
|
4277
|
+
//}
|
|
4278
|
+
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
|
4279
|
+
}
|
|
4280
|
+
|
|
4281
|
+
static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
|
|
4282
|
+
if (grammar.rules.empty() || grammar.stacks.empty()) {
|
|
4283
|
+
return;
|
|
4284
|
+
}
|
|
4285
|
+
|
|
4286
|
+
//fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
|
|
4287
|
+
|
|
4288
|
+
const std::string & text = ctx.vocab.id_to_token[token];
|
|
4289
|
+
|
|
4290
|
+
if (text.rfind("[_", 0) == 0) {
|
|
4291
|
+
// fprintf(stderr, " (skipped)\n");
|
|
4292
|
+
return;
|
|
4293
|
+
}
|
|
4294
|
+
// fprintf(stderr, "\n");
|
|
4295
|
+
|
|
4296
|
+
// Note terminating 0 in decoded string
|
|
4297
|
+
const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
|
|
4298
|
+
const auto & code_points = decoded.first;
|
|
4299
|
+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
|
4300
|
+
grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
|
|
4301
|
+
}
|
|
4302
|
+
grammar.partial_utf8 = decoded.second;
|
|
3772
4303
|
}
|
|
3773
4304
|
|
|
4305
|
+
//////////////
|
|
4306
|
+
// END grammar
|
|
4307
|
+
//////////////
|
|
4308
|
+
|
|
3774
4309
|
////////////////////////////////////////////////////////////////////////////
|
|
3775
4310
|
|
|
3776
4311
|
struct whisper_context_params * whisper_context_default_params_by_ref() {
|
|
@@ -3800,6 +4335,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3800
4335
|
|
|
3801
4336
|
/*.translate =*/ false,
|
|
3802
4337
|
/*.no_context =*/ true,
|
|
4338
|
+
/*.no_timestamps =*/ false,
|
|
3803
4339
|
/*.single_segment =*/ false,
|
|
3804
4340
|
/*.print_special =*/ false,
|
|
3805
4341
|
/*.print_progress =*/ true,
|
|
@@ -3833,7 +4369,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3833
4369
|
/*.max_initial_ts =*/ 1.0f,
|
|
3834
4370
|
/*.length_penalty =*/ -1.0f,
|
|
3835
4371
|
|
|
3836
|
-
/*.temperature_inc =*/ 0.
|
|
4372
|
+
/*.temperature_inc =*/ 0.2f,
|
|
3837
4373
|
/*.entropy_thold =*/ 2.4f,
|
|
3838
4374
|
/*.logprob_thold =*/ -1.0f,
|
|
3839
4375
|
/*.no_speech_thold =*/ 0.6f,
|
|
@@ -3862,19 +4398,24 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3862
4398
|
|
|
3863
4399
|
/*.logits_filter_callback =*/ nullptr,
|
|
3864
4400
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
|
4401
|
+
|
|
4402
|
+
/*.grammar_rules =*/ nullptr,
|
|
4403
|
+
/*.n_grammar_rules =*/ 0,
|
|
4404
|
+
/*.i_start_rule =*/ 0,
|
|
4405
|
+
/*.grammar_penalty =*/ 100.0f,
|
|
3865
4406
|
};
|
|
3866
4407
|
|
|
3867
4408
|
switch (strategy) {
|
|
3868
4409
|
case WHISPER_SAMPLING_GREEDY:
|
|
3869
4410
|
{
|
|
3870
4411
|
result.greedy = {
|
|
3871
|
-
/*.best_of =*/
|
|
4412
|
+
/*.best_of =*/ 5,
|
|
3872
4413
|
};
|
|
3873
4414
|
} break;
|
|
3874
4415
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
3875
4416
|
{
|
|
3876
4417
|
result.beam_search = {
|
|
3877
|
-
/*.beam_size =*/
|
|
4418
|
+
/*.beam_size =*/ 5,
|
|
3878
4419
|
|
|
3879
4420
|
/*.patience =*/ -1.0f,
|
|
3880
4421
|
};
|
|
@@ -3964,11 +4505,12 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
|
3964
4505
|
// process the logits for the selected decoder
|
|
3965
4506
|
// - applies logit filters
|
|
3966
4507
|
// - computes logprobs and probs
|
|
4508
|
+
// TODO: optimize
|
|
3967
4509
|
static void whisper_process_logits(
|
|
3968
4510
|
struct whisper_context & ctx,
|
|
3969
4511
|
struct whisper_state & state,
|
|
3970
|
-
const struct whisper_full_params params,
|
|
3971
4512
|
struct whisper_decoder & decoder,
|
|
4513
|
+
const struct whisper_full_params params,
|
|
3972
4514
|
float temperature) {
|
|
3973
4515
|
const auto & vocab = ctx.vocab;
|
|
3974
4516
|
const auto & tokens_cur = decoder.sequence.tokens;
|
|
@@ -3985,7 +4527,7 @@ static void whisper_process_logits(
|
|
|
3985
4527
|
auto & logprobs = decoder.logprobs;
|
|
3986
4528
|
{
|
|
3987
4529
|
logits.resize(n_logits);
|
|
3988
|
-
memcpy(logits.data(), state.logits.data() +
|
|
4530
|
+
memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
|
|
3989
4531
|
|
|
3990
4532
|
if (temperature > 0.0f) {
|
|
3991
4533
|
for (int i = 0; i < n_logits; i++) {
|
|
@@ -4013,6 +4555,11 @@ static void whisper_process_logits(
|
|
|
4013
4555
|
// suppress <|notimestamps|> token
|
|
4014
4556
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
|
|
4015
4557
|
logits[vocab.token_not] = -INFINITY;
|
|
4558
|
+
if (params.no_timestamps) {
|
|
4559
|
+
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
|
4560
|
+
logits[i] = -INFINITY;
|
|
4561
|
+
}
|
|
4562
|
+
}
|
|
4016
4563
|
|
|
4017
4564
|
// suppress sot and nosp tokens
|
|
4018
4565
|
logits[vocab.token_sot] = -INFINITY;
|
|
@@ -4028,6 +4575,14 @@ static void whisper_process_logits(
|
|
|
4028
4575
|
logits[vocab.token_transcribe] = -INFINITY;
|
|
4029
4576
|
logits[vocab.token_prev] = -INFINITY;
|
|
4030
4577
|
|
|
4578
|
+
// suppress lang tokens
|
|
4579
|
+
for (size_t i = 0; i < g_lang.size(); ++i) {
|
|
4580
|
+
logits[whisper_token_lang(&ctx, i)] = -INFINITY;
|
|
4581
|
+
}
|
|
4582
|
+
|
|
4583
|
+
// suppress prev token
|
|
4584
|
+
logits[vocab.token_prev] = -INFINITY;
|
|
4585
|
+
|
|
4031
4586
|
if (params.logits_filter_callback) {
|
|
4032
4587
|
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
4033
4588
|
}
|
|
@@ -4059,7 +4614,7 @@ static void whisper_process_logits(
|
|
|
4059
4614
|
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
|
|
4060
4615
|
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
|
|
4061
4616
|
|
|
4062
|
-
//
|
|
4617
|
+
//WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
|
|
4063
4618
|
|
|
4064
4619
|
if (last_was_timestamp) {
|
|
4065
4620
|
if (penultimate_was_timestamp) {
|
|
@@ -4135,13 +4690,37 @@ static void whisper_process_logits(
|
|
|
4135
4690
|
|
|
4136
4691
|
const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
|
|
4137
4692
|
|
|
4138
|
-
//
|
|
4693
|
+
//WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
|
|
4139
4694
|
|
|
4140
4695
|
if (timestamp_logprob > max_text_token_logprob) {
|
|
4141
4696
|
for (int i = 0; i < vocab.token_beg; ++i) {
|
|
4142
4697
|
logits[i] = -INFINITY;
|
|
4143
4698
|
logprobs[i] = -INFINITY;
|
|
4144
4699
|
}
|
|
4700
|
+
} else {
|
|
4701
|
+
if (params.n_grammar_rules > 0) {
|
|
4702
|
+
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
|
4703
|
+
|
|
4704
|
+
// populate the logprobs array (log_softmax)
|
|
4705
|
+
{
|
|
4706
|
+
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
|
4707
|
+
float logsumexp = 0.0f;
|
|
4708
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4709
|
+
if (logits[i] > -INFINITY) {
|
|
4710
|
+
logsumexp += expf(logits[i] - logit_max);
|
|
4711
|
+
}
|
|
4712
|
+
}
|
|
4713
|
+
logsumexp = logf(logsumexp) + logit_max;
|
|
4714
|
+
|
|
4715
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4716
|
+
if (logits[i] > -INFINITY) {
|
|
4717
|
+
logprobs[i] = logits[i] - logsumexp;
|
|
4718
|
+
} else {
|
|
4719
|
+
logprobs[i] = -INFINITY;
|
|
4720
|
+
}
|
|
4721
|
+
}
|
|
4722
|
+
}
|
|
4723
|
+
}
|
|
4145
4724
|
}
|
|
4146
4725
|
}
|
|
4147
4726
|
}
|
|
@@ -4159,38 +4738,60 @@ static void whisper_process_logits(
|
|
|
4159
4738
|
|
|
4160
4739
|
#if 0
|
|
4161
4740
|
// print first 100 logits - token string : logit
|
|
4162
|
-
for (int i = 0; i <
|
|
4163
|
-
|
|
4164
|
-
|
|
4165
|
-
|
|
4166
|
-
|
|
4167
|
-
|
|
4741
|
+
//for (int i = 0; i < 10; i++) {
|
|
4742
|
+
// const auto token = vocab.id_to_token.at(i);
|
|
4743
|
+
// const auto prob = probs[i];
|
|
4744
|
+
// const auto logit = logits[i];
|
|
4745
|
+
// const auto logprob = logprobs[i];
|
|
4746
|
+
// printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
|
|
4747
|
+
//}
|
|
4748
|
+
|
|
4749
|
+
// print sorted
|
|
4750
|
+
{
|
|
4751
|
+
std::vector<std::pair<float, int>> pairs;
|
|
4752
|
+
|
|
4753
|
+
for (int i = 0; i < n_logits; ++i) {
|
|
4754
|
+
pairs.push_back(std::make_pair(probs[i], i));
|
|
4755
|
+
}
|
|
4756
|
+
|
|
4757
|
+
std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
|
4758
|
+
return a.first > b.first;
|
|
4759
|
+
});
|
|
4760
|
+
|
|
4761
|
+
for (int i = 0; i < 10; i++) {
|
|
4762
|
+
const auto token = vocab.id_to_token.at(pairs[i].second);
|
|
4763
|
+
const auto prob = pairs[i].first;
|
|
4764
|
+
const auto logit = logits[pairs[i].second];
|
|
4765
|
+
const auto logprob = logprobs[pairs[i].second];
|
|
4766
|
+
printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
|
|
4767
|
+
}
|
|
4768
|
+
|
|
4769
|
+
printf("----------------\n");
|
|
4168
4770
|
}
|
|
4169
4771
|
|
|
4170
4772
|
// "And", "and", " And", " and"
|
|
4171
|
-
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
|
4172
|
-
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
|
4173
|
-
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
|
4174
|
-
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
|
4175
|
-
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
|
4176
|
-
|
|
4177
|
-
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
|
4178
|
-
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
|
4179
|
-
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
4180
|
-
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
4181
|
-
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
4182
|
-
|
|
4183
|
-
printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
|
4184
|
-
printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
|
4185
|
-
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
|
4186
|
-
printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
|
4187
|
-
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
|
4773
|
+
//printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
|
4774
|
+
//printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
|
4775
|
+
//printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
|
4776
|
+
//printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
|
4777
|
+
//printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
|
4778
|
+
|
|
4779
|
+
//printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
|
4780
|
+
//printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
|
4781
|
+
//printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
|
4782
|
+
//printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
|
4783
|
+
//printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
|
4784
|
+
|
|
4785
|
+
//printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
|
4786
|
+
//printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
|
4787
|
+
//printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
|
4788
|
+
//printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
|
4789
|
+
//printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
|
4188
4790
|
#endif
|
|
4189
4791
|
}
|
|
4190
4792
|
|
|
4191
4793
|
static whisper_token_data whisper_sample_token(
|
|
4192
4794
|
whisper_context & ctx,
|
|
4193
|
-
whisper_state & state,
|
|
4194
4795
|
const whisper_decoder & decoder,
|
|
4195
4796
|
bool best) {
|
|
4196
4797
|
whisper_token_data result = {
|
|
@@ -4235,7 +4836,7 @@ static whisper_token_data whisper_sample_token(
|
|
|
4235
4836
|
} else {
|
|
4236
4837
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
4237
4838
|
|
|
4238
|
-
result.id = dist(
|
|
4839
|
+
result.id = dist(decoder.rng);
|
|
4239
4840
|
result.p = probs[result.id];
|
|
4240
4841
|
result.plog = logprobs[result.id];
|
|
4241
4842
|
}
|
|
@@ -4245,15 +4846,12 @@ static whisper_token_data whisper_sample_token(
|
|
|
4245
4846
|
result.pt = result.p;
|
|
4246
4847
|
}
|
|
4247
4848
|
|
|
4248
|
-
state.n_sample++;
|
|
4249
|
-
|
|
4250
4849
|
return result;
|
|
4251
4850
|
}
|
|
4252
4851
|
|
|
4253
4852
|
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
4254
4853
|
whisper_context & ctx,
|
|
4255
|
-
|
|
4256
|
-
const whisper_decoder & decoder,
|
|
4854
|
+
whisper_decoder & decoder,
|
|
4257
4855
|
int k) {
|
|
4258
4856
|
const auto & vocab = ctx.vocab;
|
|
4259
4857
|
|
|
@@ -4263,7 +4861,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4263
4861
|
|
|
4264
4862
|
const int n_logits = vocab.n_vocab;
|
|
4265
4863
|
|
|
4266
|
-
auto & logits_id =
|
|
4864
|
+
auto & logits_id = decoder.logits_id;
|
|
4267
4865
|
|
|
4268
4866
|
logits_id.resize(n_logits);
|
|
4269
4867
|
for (int i = 0; i < n_logits; ++i) {
|
|
@@ -4309,8 +4907,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4309
4907
|
ptsum = sum_ts;
|
|
4310
4908
|
}
|
|
4311
4909
|
|
|
4910
|
+
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
4911
|
+
|
|
4312
4912
|
for (int i = 0; i < k; ++i) {
|
|
4313
|
-
const auto id =
|
|
4913
|
+
const auto id = dist(decoder.rng);
|
|
4914
|
+
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
|
4314
4915
|
|
|
4315
4916
|
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
|
4316
4917
|
|
|
@@ -4320,8 +4921,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
4320
4921
|
}
|
|
4321
4922
|
}
|
|
4322
4923
|
|
|
4323
|
-
state.n_sample++;
|
|
4324
|
-
|
|
4325
4924
|
return result;
|
|
4326
4925
|
}
|
|
4327
4926
|
|
|
@@ -4374,115 +4973,6 @@ static void whisper_sequence_score(
|
|
|
4374
4973
|
}
|
|
4375
4974
|
}
|
|
4376
4975
|
|
|
4377
|
-
static bool whisper_kv_swap_fast(
|
|
4378
|
-
std::vector<int> & view,
|
|
4379
|
-
whisper_decoder src[],
|
|
4380
|
-
std::vector<kv_buf> & kv_swap_bufs,
|
|
4381
|
-
const int & n_decoders) {
|
|
4382
|
-
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
|
|
4383
|
-
|
|
4384
|
-
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
|
|
4385
|
-
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
|
|
4386
|
-
|
|
4387
|
-
// (buffer->decoder or decoder->decoder)
|
|
4388
|
-
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
|
|
4389
|
-
|
|
4390
|
-
// (decoder<->decoder)
|
|
4391
|
-
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
|
|
4392
|
-
std::vector<whisper_pair<int, int>> p_swap_vec;
|
|
4393
|
-
p_swap_vec.reserve(n_decoders);
|
|
4394
|
-
|
|
4395
|
-
// see https://github.com/ggerganov/whisper.cpp/wiki
|
|
4396
|
-
for (int i = 0; i < n_decoders; i++) {
|
|
4397
|
-
// zero-copy (no modification)
|
|
4398
|
-
if (i == view[i] || view[i] < 0) {
|
|
4399
|
-
continue;
|
|
4400
|
-
}
|
|
4401
|
-
|
|
4402
|
-
bool is_one_copy = true;
|
|
4403
|
-
// since we modify data sequentially, we only consider decoder indices after current index
|
|
4404
|
-
for (int j = i + 1; j < n_decoders; j++) {
|
|
4405
|
-
if (i == view[j]) {
|
|
4406
|
-
// detect symmetric diagram
|
|
4407
|
-
if (j == view[i]) {
|
|
4408
|
-
p_swap_set.insert(i);
|
|
4409
|
-
p_swap_set.insert(j);
|
|
4410
|
-
p_swap_vec.emplace_back(i, j);
|
|
4411
|
-
} else {
|
|
4412
|
-
two_copy.insert(i);
|
|
4413
|
-
is_one_copy = false;
|
|
4414
|
-
}
|
|
4415
|
-
break;
|
|
4416
|
-
}
|
|
4417
|
-
}
|
|
4418
|
-
if (is_one_copy) {
|
|
4419
|
-
one_copy.insert(i);
|
|
4420
|
-
}
|
|
4421
|
-
}
|
|
4422
|
-
|
|
4423
|
-
kv_swap_bufs.resize(n_decoders);
|
|
4424
|
-
|
|
4425
|
-
for (int i = 0; i < n_decoders; i++) {
|
|
4426
|
-
kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
|
|
4427
|
-
kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
|
|
4428
|
-
}
|
|
4429
|
-
|
|
4430
|
-
for (auto & i : two_copy) {
|
|
4431
|
-
// make a copy of KV caches
|
|
4432
|
-
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
|
|
4433
|
-
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
|
|
4434
|
-
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
|
|
4435
|
-
}
|
|
4436
|
-
|
|
4437
|
-
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
|
|
4438
|
-
for (auto & i : two_copy) {
|
|
4439
|
-
// skip the decoder indices that require pointer swapping
|
|
4440
|
-
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4441
|
-
continue;
|
|
4442
|
-
}
|
|
4443
|
-
|
|
4444
|
-
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4445
|
-
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4446
|
-
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4447
|
-
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4448
|
-
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4449
|
-
} else {
|
|
4450
|
-
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4451
|
-
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4452
|
-
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4453
|
-
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4454
|
-
}
|
|
4455
|
-
}
|
|
4456
|
-
|
|
4457
|
-
// then modify one-copy decoder KV caches
|
|
4458
|
-
for (auto & i : one_copy) {
|
|
4459
|
-
// skip the decoder indices that require pointer swapping
|
|
4460
|
-
if (p_swap_set.find(i) != p_swap_set.end()) {
|
|
4461
|
-
continue;
|
|
4462
|
-
}
|
|
4463
|
-
|
|
4464
|
-
if (two_copy.find(view[i]) != two_copy.end()) {
|
|
4465
|
-
// modify KV caches of decoder using data from kv_swap_bufs
|
|
4466
|
-
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
|
4467
|
-
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
|
4468
|
-
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
|
4469
|
-
} else {
|
|
4470
|
-
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
|
4471
|
-
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
|
4472
|
-
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
|
|
4473
|
-
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
|
|
4474
|
-
}
|
|
4475
|
-
}
|
|
4476
|
-
|
|
4477
|
-
// swap the pointers
|
|
4478
|
-
for (auto & i : p_swap_vec) {
|
|
4479
|
-
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
|
|
4480
|
-
std::swap(src[i.first].kv_self, src[i.second].kv_self);
|
|
4481
|
-
}
|
|
4482
|
-
|
|
4483
|
-
return true;
|
|
4484
|
-
}
|
|
4485
|
-
|
|
4486
4976
|
int whisper_full_with_state(
|
|
4487
4977
|
struct whisper_context * ctx,
|
|
4488
4978
|
struct whisper_state * state,
|
|
@@ -4498,11 +4988,11 @@ int whisper_full_with_state(
|
|
|
4498
4988
|
// compute log mel spectrogram
|
|
4499
4989
|
if (params.speed_up) {
|
|
4500
4990
|
// TODO: Replace PV with more advanced algorithm
|
|
4501
|
-
|
|
4991
|
+
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4502
4992
|
return -1;
|
|
4503
4993
|
} else {
|
|
4504
4994
|
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
4505
|
-
|
|
4995
|
+
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
4506
4996
|
return -2;
|
|
4507
4997
|
}
|
|
4508
4998
|
}
|
|
@@ -4514,13 +5004,13 @@ int whisper_full_with_state(
|
|
|
4514
5004
|
|
|
4515
5005
|
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
|
4516
5006
|
if (lang_id < 0) {
|
|
4517
|
-
|
|
5007
|
+
WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
|
|
4518
5008
|
return -3;
|
|
4519
5009
|
}
|
|
4520
5010
|
state->lang_id = lang_id;
|
|
4521
5011
|
params.language = whisper_lang_str(lang_id);
|
|
4522
5012
|
|
|
4523
|
-
|
|
5013
|
+
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
4524
5014
|
if (params.detect_language) {
|
|
4525
5015
|
return 0;
|
|
4526
5016
|
}
|
|
@@ -4542,6 +5032,7 @@ int whisper_full_with_state(
|
|
|
4542
5032
|
// basically don't process anything that is less than 1.0s
|
|
4543
5033
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
|
4544
5034
|
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
|
5035
|
+
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
|
|
4545
5036
|
return 0;
|
|
4546
5037
|
}
|
|
4547
5038
|
|
|
@@ -4572,42 +5063,23 @@ int whisper_full_with_state(
|
|
|
4572
5063
|
|
|
4573
5064
|
n_decoders = std::max(1, n_decoders);
|
|
4574
5065
|
|
|
5066
|
+
if (n_decoders > WHISPER_MAX_DECODERS) {
|
|
5067
|
+
WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
|
|
5068
|
+
return -4;
|
|
5069
|
+
}
|
|
5070
|
+
|
|
4575
5071
|
// TAGS: WHISPER_DECODER_INIT
|
|
4576
5072
|
for (int j = 1; j < n_decoders; j++) {
|
|
4577
5073
|
auto & decoder = state->decoders[j];
|
|
4578
5074
|
|
|
4579
|
-
|
|
4580
|
-
decoder.kv_self = state->decoders[0].kv_self;
|
|
4581
|
-
if (!kv_cache_reinit(decoder.kv_self)) {
|
|
4582
|
-
log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
|
4583
|
-
return -4;
|
|
4584
|
-
}
|
|
4585
|
-
|
|
4586
|
-
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
4587
|
-
|
|
4588
|
-
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
|
4589
|
-
|
|
4590
|
-
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
4591
|
-
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
4592
|
-
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
5075
|
+
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
|
4593
5076
|
|
|
4594
|
-
|
|
4595
|
-
|
|
4596
|
-
|
|
4597
|
-
|
|
4598
|
-
if (!(result)) { \
|
|
4599
|
-
log("%s: failed to add metal buffer\n", __func__); \
|
|
4600
|
-
return 0; \
|
|
4601
|
-
}
|
|
5077
|
+
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
5078
|
+
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
5079
|
+
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
5080
|
+
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
|
4602
5081
|
|
|
4603
|
-
|
|
4604
|
-
auto & kv_self = decoder.kv_self;
|
|
4605
|
-
|
|
4606
|
-
WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
|
|
4607
|
-
#undef WHISPER_METAL_CHECK_BUF
|
|
4608
|
-
}
|
|
4609
|
-
#endif
|
|
4610
|
-
}
|
|
5082
|
+
decoder.rng = std::mt19937(0);
|
|
4611
5083
|
}
|
|
4612
5084
|
|
|
4613
5085
|
// the accumulated text context so far
|
|
@@ -4640,13 +5112,13 @@ int whisper_full_with_state(
|
|
|
4640
5112
|
|
|
4641
5113
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
|
4642
5114
|
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
|
4643
|
-
|
|
5115
|
+
WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
|
4644
5116
|
return -5;
|
|
4645
5117
|
}
|
|
4646
5118
|
state->exp_n_audio_ctx = params.audio_ctx;
|
|
4647
5119
|
|
|
4648
5120
|
// these tokens determine the task that will be performed
|
|
4649
|
-
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
5121
|
+
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
|
|
4650
5122
|
|
|
4651
5123
|
if (whisper_is_multilingual(ctx)) {
|
|
4652
5124
|
const int lang_id = whisper_lang_id(params.language);
|
|
@@ -4659,17 +5131,19 @@ int whisper_full_with_state(
|
|
|
4659
5131
|
}
|
|
4660
5132
|
}
|
|
4661
5133
|
|
|
5134
|
+
// distilled models require the "no_timestamps" token
|
|
4662
5135
|
{
|
|
4663
5136
|
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
|
|
4664
|
-
|
|
4665
|
-
|
|
4666
|
-
|
|
4667
|
-
if (is_distil) {
|
|
4668
|
-
log("%s: using distilled model - forcing no_timestamps\n", __func__);
|
|
4669
|
-
prompt_init.push_back(whisper_token_not(ctx));
|
|
5137
|
+
if (is_distil && !params.no_timestamps) {
|
|
5138
|
+
WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
|
|
5139
|
+
params.no_timestamps = true;
|
|
4670
5140
|
}
|
|
4671
5141
|
}
|
|
4672
5142
|
|
|
5143
|
+
if (params.no_timestamps) {
|
|
5144
|
+
prompt_init.push_back(whisper_token_not(ctx));
|
|
5145
|
+
}
|
|
5146
|
+
|
|
4673
5147
|
int seek = seek_start;
|
|
4674
5148
|
|
|
4675
5149
|
std::vector<whisper_token> prompt;
|
|
@@ -4682,8 +5156,10 @@ int whisper_full_with_state(
|
|
|
4682
5156
|
bool has_ts;
|
|
4683
5157
|
|
|
4684
5158
|
whisper_sequence sequence;
|
|
5159
|
+
whisper_grammar grammar;
|
|
4685
5160
|
};
|
|
4686
5161
|
|
|
5162
|
+
std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
|
|
4687
5163
|
std::vector<beam_candidate> beam_candidates;
|
|
4688
5164
|
|
|
4689
5165
|
// main loop
|
|
@@ -4692,24 +5168,24 @@ int whisper_full_with_state(
|
|
|
4692
5168
|
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
|
|
4693
5169
|
|
|
4694
5170
|
params.progress_callback(
|
|
4695
|
-
ctx,
|
|
5171
|
+
ctx, state, progress_cur, params.progress_callback_user_data);
|
|
4696
5172
|
}
|
|
4697
5173
|
|
|
4698
|
-
//
|
|
5174
|
+
// if only 1 second left, then stop
|
|
4699
5175
|
if (seek + 100 >= seek_end) {
|
|
4700
5176
|
break;
|
|
4701
5177
|
}
|
|
4702
5178
|
|
|
4703
5179
|
if (params.encoder_begin_callback) {
|
|
4704
5180
|
if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
|
|
4705
|
-
|
|
5181
|
+
WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
4706
5182
|
break;
|
|
4707
5183
|
}
|
|
4708
5184
|
}
|
|
4709
5185
|
|
|
4710
5186
|
// encode audio features starting at offset seek
|
|
4711
5187
|
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
4712
|
-
|
|
5188
|
+
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
4713
5189
|
return -6;
|
|
4714
5190
|
}
|
|
4715
5191
|
|
|
@@ -4745,14 +5221,12 @@ int whisper_full_with_state(
|
|
|
4745
5221
|
|
|
4746
5222
|
n_decoders_cur = std::max(1, n_decoders_cur);
|
|
4747
5223
|
|
|
4748
|
-
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
|
5224
|
+
WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
|
4749
5225
|
|
|
4750
5226
|
// TAGS: WHISPER_DECODER_INIT
|
|
4751
5227
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4752
5228
|
auto & decoder = state->decoders[j];
|
|
4753
5229
|
|
|
4754
|
-
decoder.kv_self.n = 0;
|
|
4755
|
-
|
|
4756
5230
|
decoder.sequence.tokens.clear();
|
|
4757
5231
|
decoder.sequence.result_len = 0;
|
|
4758
5232
|
decoder.sequence.sum_logprobs_all = 0.0;
|
|
@@ -4766,10 +5240,16 @@ int whisper_full_with_state(
|
|
|
4766
5240
|
decoder.failed = false;
|
|
4767
5241
|
decoder.completed = false;
|
|
4768
5242
|
decoder.has_ts = false;
|
|
5243
|
+
|
|
5244
|
+
if (params.grammar_rules != nullptr) {
|
|
5245
|
+
decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
|
|
5246
|
+
} else {
|
|
5247
|
+
decoder.grammar = {};
|
|
5248
|
+
}
|
|
4769
5249
|
}
|
|
4770
5250
|
|
|
4771
5251
|
// init prompt and kv cache for the current iteration
|
|
4772
|
-
//
|
|
5252
|
+
// TODO: do not recompute the prompt if it is the same as previous time
|
|
4773
5253
|
{
|
|
4774
5254
|
prompt.clear();
|
|
4775
5255
|
|
|
@@ -4791,25 +5271,26 @@ int whisper_full_with_state(
|
|
|
4791
5271
|
}
|
|
4792
5272
|
WHISPER_PRINT_DEBUG("\n\n");
|
|
4793
5273
|
|
|
4794
|
-
|
|
4795
|
-
|
|
5274
|
+
whisper_kv_cache_clear(state->kv_self);
|
|
5275
|
+
|
|
5276
|
+
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
|
5277
|
+
|
|
5278
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5279
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
4796
5280
|
return -7;
|
|
4797
5281
|
}
|
|
4798
5282
|
|
|
4799
5283
|
{
|
|
4800
5284
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4801
5285
|
|
|
4802
|
-
|
|
5286
|
+
state->decoders[0].i_batch = prompt.size() - 1;
|
|
4803
5287
|
|
|
4804
|
-
state->decoders[0]
|
|
5288
|
+
whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
|
|
4805
5289
|
|
|
4806
5290
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
4807
5291
|
auto & decoder = state->decoders[j];
|
|
4808
5292
|
|
|
4809
|
-
|
|
4810
|
-
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
|
|
4811
|
-
|
|
4812
|
-
decoder.kv_self.n += prompt.size();
|
|
5293
|
+
whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
|
|
4813
5294
|
|
|
4814
5295
|
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
4815
5296
|
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
@@ -4824,41 +5305,81 @@ int whisper_full_with_state(
|
|
|
4824
5305
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
4825
5306
|
|
|
4826
5307
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
4827
|
-
|
|
5308
|
+
for (auto & bc : bc_per_dec) {
|
|
5309
|
+
bc.clear();
|
|
5310
|
+
}
|
|
4828
5311
|
}
|
|
4829
5312
|
|
|
4830
|
-
//
|
|
4831
|
-
|
|
4832
|
-
|
|
5313
|
+
// sampling
|
|
5314
|
+
// TODO: avoid memory allocations, optimize, avoid threads?
|
|
5315
|
+
{
|
|
5316
|
+
std::atomic<int> j_cur(0);
|
|
4833
5317
|
|
|
4834
|
-
|
|
4835
|
-
|
|
4836
|
-
|
|
5318
|
+
auto process = [&]() {
|
|
5319
|
+
while (true) {
|
|
5320
|
+
const int j = j_cur.fetch_add(1);
|
|
4837
5321
|
|
|
4838
|
-
|
|
4839
|
-
|
|
4840
|
-
|
|
4841
|
-
if (t_cur < 1e-6f) {
|
|
4842
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
|
|
4843
|
-
} else {
|
|
4844
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
|
|
4845
|
-
}
|
|
5322
|
+
if (j >= n_decoders_cur) {
|
|
5323
|
+
break;
|
|
5324
|
+
}
|
|
4846
5325
|
|
|
4847
|
-
|
|
4848
|
-
} break;
|
|
4849
|
-
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
4850
|
-
{
|
|
4851
|
-
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
|
|
5326
|
+
auto & decoder = state->decoders[j];
|
|
4852
5327
|
|
|
4853
|
-
|
|
4854
|
-
|
|
4855
|
-
|
|
4856
|
-
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
|
|
5328
|
+
if (decoder.completed || decoder.failed) {
|
|
5329
|
+
continue;
|
|
5330
|
+
}
|
|
4857
5331
|
|
|
4858
|
-
|
|
4859
|
-
|
|
4860
|
-
|
|
5332
|
+
switch (params.strategy) {
|
|
5333
|
+
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
5334
|
+
{
|
|
5335
|
+
if (t_cur < 1e-6f) {
|
|
5336
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
|
5337
|
+
} else {
|
|
5338
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
|
5339
|
+
}
|
|
5340
|
+
|
|
5341
|
+
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
|
5342
|
+
} break;
|
|
5343
|
+
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
5344
|
+
{
|
|
5345
|
+
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
|
5346
|
+
|
|
5347
|
+
for (const auto & token : tokens_new) {
|
|
5348
|
+
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
|
|
5349
|
+
bc_per_dec[j].back().sequence.tokens.push_back(token);
|
|
5350
|
+
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
|
|
5351
|
+
}
|
|
5352
|
+
} break;
|
|
5353
|
+
};
|
|
5354
|
+
}
|
|
4861
5355
|
};
|
|
5356
|
+
|
|
5357
|
+
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
|
5358
|
+
|
|
5359
|
+
if (n_threads == 1) {
|
|
5360
|
+
process();
|
|
5361
|
+
} else {
|
|
5362
|
+
std::vector<std::thread> threads(n_threads - 1);
|
|
5363
|
+
|
|
5364
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5365
|
+
threads[t] = std::thread(process);
|
|
5366
|
+
}
|
|
5367
|
+
|
|
5368
|
+
process();
|
|
5369
|
+
|
|
5370
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5371
|
+
threads[t].join();
|
|
5372
|
+
}
|
|
5373
|
+
}
|
|
5374
|
+
}
|
|
5375
|
+
|
|
5376
|
+
beam_candidates.clear();
|
|
5377
|
+
for (const auto & bc : bc_per_dec) {
|
|
5378
|
+
beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
|
|
5379
|
+
|
|
5380
|
+
if (!bc.empty()) {
|
|
5381
|
+
state->n_sample += 1;
|
|
5382
|
+
}
|
|
4862
5383
|
}
|
|
4863
5384
|
|
|
4864
5385
|
// for beam-search, choose the top candidates and update the KV caches
|
|
@@ -4871,7 +5392,6 @@ int whisper_full_with_state(
|
|
|
4871
5392
|
});
|
|
4872
5393
|
|
|
4873
5394
|
uint32_t cur_c = 0;
|
|
4874
|
-
std::vector<int> decoder_idx(n_decoders_cur, -1);
|
|
4875
5395
|
|
|
4876
5396
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4877
5397
|
auto & decoder = state->decoders[j];
|
|
@@ -4880,23 +5400,38 @@ int whisper_full_with_state(
|
|
|
4880
5400
|
continue;
|
|
4881
5401
|
}
|
|
4882
5402
|
|
|
5403
|
+
if (cur_c >= beam_candidates.size()) {
|
|
5404
|
+
cur_c = 0;
|
|
5405
|
+
}
|
|
5406
|
+
|
|
4883
5407
|
auto & cur = beam_candidates[cur_c++];
|
|
4884
5408
|
|
|
4885
5409
|
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
|
|
4886
5410
|
++cur_c;
|
|
4887
5411
|
}
|
|
4888
5412
|
|
|
4889
|
-
decoder.sequence = cur.sequence;
|
|
4890
5413
|
decoder.seek_delta = cur.seek_delta;
|
|
4891
5414
|
decoder.has_ts = cur.has_ts;
|
|
5415
|
+
decoder.sequence = cur.sequence;
|
|
5416
|
+
decoder.grammar = cur.grammar;
|
|
5417
|
+
|
|
5418
|
+
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
|
4892
5419
|
|
|
4893
|
-
decoder_idx[j] = cur.decoder_idx;
|
|
4894
5420
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
|
4895
5421
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
|
4896
5422
|
}
|
|
4897
5423
|
|
|
4898
|
-
|
|
4899
|
-
|
|
5424
|
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
5425
|
+
auto & decoder = state->decoders[j];
|
|
5426
|
+
|
|
5427
|
+
if (decoder.completed || decoder.failed) {
|
|
5428
|
+
continue;
|
|
5429
|
+
}
|
|
5430
|
+
|
|
5431
|
+
whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
|
|
5432
|
+
whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
|
|
5433
|
+
whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
|
|
5434
|
+
}
|
|
4900
5435
|
}
|
|
4901
5436
|
|
|
4902
5437
|
// update the decoder state
|
|
@@ -4925,6 +5460,7 @@ int whisper_full_with_state(
|
|
|
4925
5460
|
|
|
4926
5461
|
// do not allow to go back in time
|
|
4927
5462
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
|
5463
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
|
|
4928
5464
|
failed = true; // TODO: maybe this is not a failure ?
|
|
4929
5465
|
continue;
|
|
4930
5466
|
}
|
|
@@ -4934,6 +5470,8 @@ int whisper_full_with_state(
|
|
|
4934
5470
|
has_ts = true;
|
|
4935
5471
|
}
|
|
4936
5472
|
|
|
5473
|
+
whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
|
|
5474
|
+
|
|
4937
5475
|
#ifdef WHISPER_DEBUG
|
|
4938
5476
|
{
|
|
4939
5477
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
|
@@ -4951,6 +5489,7 @@ int whisper_full_with_state(
|
|
|
4951
5489
|
if (seek + seek_delta + 100 >= seek_end) {
|
|
4952
5490
|
result_len = i + 1;
|
|
4953
5491
|
} else {
|
|
5492
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
|
4954
5493
|
failed = true;
|
|
4955
5494
|
continue;
|
|
4956
5495
|
}
|
|
@@ -4961,6 +5500,7 @@ int whisper_full_with_state(
|
|
|
4961
5500
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
|
4962
5501
|
}
|
|
4963
5502
|
|
|
5503
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
|
|
4964
5504
|
completed = true;
|
|
4965
5505
|
continue;
|
|
4966
5506
|
}
|
|
@@ -4976,6 +5516,7 @@ int whisper_full_with_state(
|
|
|
4976
5516
|
// sometimes, the decoding can get stuck in a repetition loop
|
|
4977
5517
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
|
4978
5518
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
|
5519
|
+
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
|
|
4979
5520
|
failed = true;
|
|
4980
5521
|
continue;
|
|
4981
5522
|
}
|
|
@@ -5003,32 +5544,83 @@ int whisper_full_with_state(
|
|
|
5003
5544
|
state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
|
|
5004
5545
|
|
|
5005
5546
|
// obtain logits for the next token
|
|
5006
|
-
|
|
5007
|
-
auto &
|
|
5547
|
+
{
|
|
5548
|
+
auto & batch = state->batch;
|
|
5008
5549
|
|
|
5009
|
-
|
|
5010
|
-
|
|
5011
|
-
|
|
5550
|
+
batch.n_tokens = 0;
|
|
5551
|
+
|
|
5552
|
+
const int n_past = prompt.size() + i;
|
|
5553
|
+
|
|
5554
|
+
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
5555
|
+
auto & decoder = state->decoders[j];
|
|
5556
|
+
|
|
5557
|
+
if (decoder.failed || decoder.completed) {
|
|
5558
|
+
continue;
|
|
5559
|
+
}
|
|
5560
|
+
|
|
5561
|
+
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
|
5012
5562
|
|
|
5013
|
-
|
|
5014
|
-
|
|
5563
|
+
decoder.i_batch = batch.n_tokens;
|
|
5564
|
+
|
|
5565
|
+
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
|
|
5566
|
+
batch.pos [batch.n_tokens] = n_past;
|
|
5567
|
+
batch.n_seq_id[batch.n_tokens] = 1;
|
|
5568
|
+
batch.seq_id [batch.n_tokens][0] = j;
|
|
5569
|
+
batch.logits [batch.n_tokens] = 1;
|
|
5570
|
+
batch.n_tokens++;
|
|
5571
|
+
}
|
|
5015
5572
|
|
|
5016
|
-
|
|
5573
|
+
assert(batch.n_tokens > 0);
|
|
5017
5574
|
|
|
5018
|
-
if (!whisper_decode_internal(*ctx, *state,
|
|
5019
|
-
|
|
5575
|
+
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
5576
|
+
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
5020
5577
|
return -8;
|
|
5021
5578
|
}
|
|
5022
5579
|
|
|
5580
|
+
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
5581
|
+
|
|
5582
|
+
// TODO: avoid memory allocations, optimize, avoid threads?
|
|
5023
5583
|
{
|
|
5024
|
-
|
|
5584
|
+
std::atomic<int> j_cur(0);
|
|
5585
|
+
|
|
5586
|
+
auto process = [&]() {
|
|
5587
|
+
while (true) {
|
|
5588
|
+
const int j = j_cur.fetch_add(1);
|
|
5589
|
+
|
|
5590
|
+
if (j >= n_decoders_cur) {
|
|
5591
|
+
break;
|
|
5592
|
+
}
|
|
5593
|
+
|
|
5594
|
+
auto & decoder = state->decoders[j];
|
|
5595
|
+
|
|
5596
|
+
if (decoder.failed || decoder.completed) {
|
|
5597
|
+
continue;
|
|
5598
|
+
}
|
|
5599
|
+
|
|
5600
|
+
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
|
|
5601
|
+
}
|
|
5602
|
+
};
|
|
5603
|
+
|
|
5604
|
+
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
|
5025
5605
|
|
|
5026
|
-
|
|
5606
|
+
if (n_threads == 1) {
|
|
5607
|
+
process();
|
|
5608
|
+
} else {
|
|
5609
|
+
std::vector<std::thread> threads(n_threads - 1);
|
|
5610
|
+
|
|
5611
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5612
|
+
threads[t] = std::thread(process);
|
|
5613
|
+
}
|
|
5027
5614
|
|
|
5028
|
-
|
|
5615
|
+
process();
|
|
5029
5616
|
|
|
5030
|
-
|
|
5617
|
+
for (int t = 0; t < n_threads - 1; ++t) {
|
|
5618
|
+
threads[t].join();
|
|
5619
|
+
}
|
|
5620
|
+
}
|
|
5031
5621
|
}
|
|
5622
|
+
|
|
5623
|
+
state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
|
|
5032
5624
|
}
|
|
5033
5625
|
}
|
|
5034
5626
|
|
|
@@ -5068,28 +5660,27 @@ int whisper_full_with_state(
|
|
|
5068
5660
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
|
5069
5661
|
}
|
|
5070
5662
|
|
|
5663
|
+
bool success = true;
|
|
5664
|
+
|
|
5071
5665
|
// was the decoding successful for the current temperature?
|
|
5072
5666
|
// do fallback only if:
|
|
5073
5667
|
// - we are not at the last temperature
|
|
5074
|
-
|
|
5075
|
-
if (it != (int) temperatures.size() - 1 &&
|
|
5076
|
-
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
|
|
5077
|
-
bool success = true;
|
|
5078
|
-
|
|
5668
|
+
if (it != (int) temperatures.size() - 1) {
|
|
5079
5669
|
const auto & decoder = state->decoders[best_decoder_id];
|
|
5080
5670
|
|
|
5081
5671
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
|
5672
|
+
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
|
5082
5673
|
success = false;
|
|
5083
5674
|
state->n_fail_p++;
|
|
5084
5675
|
}
|
|
5676
|
+
}
|
|
5085
5677
|
|
|
5086
|
-
|
|
5087
|
-
|
|
5088
|
-
|
|
5089
|
-
|
|
5678
|
+
if (success) {
|
|
5679
|
+
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
|
5680
|
+
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
|
5681
|
+
//}
|
|
5090
5682
|
|
|
5091
|
-
|
|
5092
|
-
}
|
|
5683
|
+
break;
|
|
5093
5684
|
}
|
|
5094
5685
|
|
|
5095
5686
|
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
|
@@ -5325,11 +5916,13 @@ int whisper_full_parallel(
|
|
|
5325
5916
|
ctx->state->t_sample_us += states[i]->t_sample_us;
|
|
5326
5917
|
ctx->state->t_encode_us += states[i]->t_encode_us;
|
|
5327
5918
|
ctx->state->t_decode_us += states[i]->t_decode_us;
|
|
5919
|
+
ctx->state->t_batchd_us += states[i]->t_batchd_us;
|
|
5328
5920
|
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
|
5329
5921
|
|
|
5330
5922
|
ctx->state->n_sample += states[i]->n_sample;
|
|
5331
5923
|
ctx->state->n_encode += states[i]->n_encode;
|
|
5332
5924
|
ctx->state->n_decode += states[i]->n_decode;
|
|
5925
|
+
ctx->state->n_batchd += states[i]->n_batchd;
|
|
5333
5926
|
ctx->state->n_prompt += states[i]->n_prompt;
|
|
5334
5927
|
|
|
5335
5928
|
whisper_free_state(states[i]);
|
|
@@ -5342,12 +5935,12 @@ int whisper_full_parallel(
|
|
|
5342
5935
|
ctx->state->t_decode_us /= n_processors;
|
|
5343
5936
|
|
|
5344
5937
|
// print information about the audio boundaries
|
|
5345
|
-
|
|
5346
|
-
|
|
5938
|
+
WHISPER_LOG_WARN("\n");
|
|
5939
|
+
WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
|
|
5347
5940
|
for (int i = 0; i < n_processors - 1; ++i) {
|
|
5348
|
-
|
|
5941
|
+
WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
|
|
5349
5942
|
}
|
|
5350
|
-
|
|
5943
|
+
WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
|
|
5351
5944
|
|
|
5352
5945
|
return ret;
|
|
5353
5946
|
}
|
|
@@ -5462,8 +6055,45 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5462
6055
|
size_t n = 20;
|
|
5463
6056
|
size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
|
|
5464
6057
|
|
|
5465
|
-
// 1GB
|
|
5466
|
-
const size_t size = arr*
|
|
6058
|
+
// 1GB array
|
|
6059
|
+
const size_t size = arr*1e6;
|
|
6060
|
+
|
|
6061
|
+
double sum = 0.0;
|
|
6062
|
+
|
|
6063
|
+
// heat-up
|
|
6064
|
+
{
|
|
6065
|
+
char * src = (char *) malloc(size);
|
|
6066
|
+
char * dst = (char *) malloc(size);
|
|
6067
|
+
|
|
6068
|
+
for (size_t i = 0; i < size; i++) src[i] = i;
|
|
6069
|
+
|
|
6070
|
+
memcpy(dst, src, size); // heat-up
|
|
6071
|
+
|
|
6072
|
+
double tsum = 0.0;
|
|
6073
|
+
|
|
6074
|
+
for (size_t i = 0; i < n; i++) {
|
|
6075
|
+
const int64_t t0 = wsp_ggml_time_us();
|
|
6076
|
+
|
|
6077
|
+
memcpy(dst, src, size);
|
|
6078
|
+
|
|
6079
|
+
const int64_t t1 = wsp_ggml_time_us();
|
|
6080
|
+
|
|
6081
|
+
tsum += (t1 - t0)*1e-6;
|
|
6082
|
+
|
|
6083
|
+
src[rand() % size] = rand() % 256;
|
|
6084
|
+
}
|
|
6085
|
+
|
|
6086
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
|
|
6087
|
+
s += strbuf;
|
|
6088
|
+
|
|
6089
|
+
// needed to prevent the compiler from optimizing the memcpy away
|
|
6090
|
+
{
|
|
6091
|
+
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
6092
|
+
}
|
|
6093
|
+
|
|
6094
|
+
free(src);
|
|
6095
|
+
free(dst);
|
|
6096
|
+
}
|
|
5467
6097
|
|
|
5468
6098
|
// single-thread
|
|
5469
6099
|
{
|
|
@@ -5475,7 +6105,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5475
6105
|
memcpy(dst, src, size); // heat-up
|
|
5476
6106
|
|
|
5477
6107
|
double tsum = 0.0;
|
|
5478
|
-
double sum = 0.0;
|
|
5479
6108
|
|
|
5480
6109
|
for (size_t i = 0; i < n; i++) {
|
|
5481
6110
|
const int64_t t0 = wsp_ggml_time_us();
|
|
@@ -5489,21 +6118,73 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
5489
6118
|
src[rand() % size] = rand() % 256;
|
|
5490
6119
|
}
|
|
5491
6120
|
|
|
5492
|
-
snprintf(strbuf, sizeof(strbuf), "memcpy:
|
|
6121
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
|
|
5493
6122
|
s += strbuf;
|
|
5494
6123
|
|
|
5495
6124
|
// needed to prevent the compiler from optimizing the memcpy away
|
|
5496
6125
|
{
|
|
5497
6126
|
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
6127
|
+
}
|
|
6128
|
+
|
|
6129
|
+
free(src);
|
|
6130
|
+
free(dst);
|
|
6131
|
+
}
|
|
6132
|
+
|
|
6133
|
+
// multi-thread
|
|
6134
|
+
|
|
6135
|
+
for (uint32_t k = 1; k <= n_threads; k++) {
|
|
6136
|
+
char * src = (char *) malloc(size);
|
|
6137
|
+
char * dst = (char *) malloc(size);
|
|
5498
6138
|
|
|
5499
|
-
|
|
5500
|
-
|
|
6139
|
+
for (size_t i = 0; i < size; i++) src[i] = i;
|
|
6140
|
+
|
|
6141
|
+
memcpy(dst, src, size); // heat-up
|
|
6142
|
+
|
|
6143
|
+
double tsum = 0.0;
|
|
6144
|
+
|
|
6145
|
+
auto helper = [&](int th) {
|
|
6146
|
+
const int64_t i0 = (th + 0)*size/k;
|
|
6147
|
+
const int64_t i1 = (th + 1)*size/k;
|
|
6148
|
+
|
|
6149
|
+
for (size_t i = 0; i < n; i++) {
|
|
6150
|
+
memcpy(dst + i0, src + i0, i1 - i0);
|
|
6151
|
+
|
|
6152
|
+
src[i0 + rand() % (i1 - i0)] = rand() % 256;
|
|
6153
|
+
};
|
|
6154
|
+
};
|
|
6155
|
+
|
|
6156
|
+
const int64_t t0 = wsp_ggml_time_us();
|
|
6157
|
+
|
|
6158
|
+
std::vector<std::thread> threads(k - 1);
|
|
6159
|
+
for (uint32_t th = 0; th < k - 1; ++th) {
|
|
6160
|
+
threads[th] = std::thread(helper, th);
|
|
6161
|
+
}
|
|
6162
|
+
|
|
6163
|
+
helper(k - 1);
|
|
6164
|
+
|
|
6165
|
+
for (uint32_t th = 0; th < k - 1; ++th) {
|
|
6166
|
+
threads[th].join();
|
|
6167
|
+
}
|
|
6168
|
+
|
|
6169
|
+
const int64_t t1 = wsp_ggml_time_us();
|
|
6170
|
+
|
|
6171
|
+
tsum += (t1 - t0)*1e-6;
|
|
6172
|
+
|
|
6173
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
|
|
6174
|
+
s += strbuf;
|
|
6175
|
+
|
|
6176
|
+
// needed to prevent the compiler from optimizing the memcpy away
|
|
6177
|
+
{
|
|
6178
|
+
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
5501
6179
|
}
|
|
5502
6180
|
|
|
5503
6181
|
free(src);
|
|
5504
6182
|
free(dst);
|
|
5505
6183
|
}
|
|
5506
6184
|
|
|
6185
|
+
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
|
|
6186
|
+
s += strbuf;
|
|
6187
|
+
|
|
5507
6188
|
return s.c_str();
|
|
5508
6189
|
}
|
|
5509
6190
|
|
|
@@ -5589,12 +6270,12 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
5589
6270
|
double tsum = 0.0;
|
|
5590
6271
|
|
|
5591
6272
|
// heat-up
|
|
5592
|
-
wsp_ggml_graph_compute_helper(
|
|
6273
|
+
wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
|
|
5593
6274
|
|
|
5594
6275
|
for (int i = 0; i < n_max; ++i) {
|
|
5595
6276
|
const int64_t t0 = wsp_ggml_time_us();
|
|
5596
6277
|
|
|
5597
|
-
wsp_ggml_graph_compute_helper(
|
|
6278
|
+
wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
|
|
5598
6279
|
|
|
5599
6280
|
const int64_t t1 = wsp_ggml_time_us();
|
|
5600
6281
|
|
|
@@ -5712,7 +6393,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
5712
6393
|
const int n_samples = state.energy.size();
|
|
5713
6394
|
|
|
5714
6395
|
if (n_samples == 0) {
|
|
5715
|
-
|
|
6396
|
+
WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
|
|
5716
6397
|
return;
|
|
5717
6398
|
}
|
|
5718
6399
|
|
|
@@ -5933,6 +6614,32 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
5933
6614
|
//}
|
|
5934
6615
|
}
|
|
5935
6616
|
|
|
5936
|
-
void
|
|
5937
|
-
|
|
6617
|
+
// Install the process-wide log sink used by whisper_log_internal().
//
// log_callback - function that receives every formatted log message; passing a
//                null pointer restores whisper_log_callback_default
// user_data    - opaque pointer stored as-is and handed back to the callback on
//                every invocation
void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
    if (log_callback) {
        g_state.log_callback = log_callback;
    } else {
        // no sink supplied: fall back to the built-in default
        g_state.log_callback = whisper_log_callback_default;
    }

    g_state.log_callback_user_data = user_data;
}
|
|
6621
|
+
|
|
6622
|
+
// Format a printf-style message and forward it to the currently installed log
// callback (g_state.log_callback) together with its user data.
//
// level  - severity passed through to the callback unchanged
// format - printf-style format string; the variadic arguments must match it
//
// Messages that fit into a 1 KiB stack buffer are formatted in place; longer
// messages are formatted a second time into a heap buffer of the exact size.
WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
static void whisper_log_internal(wsp_ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);

    // vsnprintf leaves the va_list it was given in an indeterminate state, so a
    // second formatting pass needs its own copy taken *before* the first pass
    va_list args_retry;
    va_copy(args_retry, args);

    char buffer[1024];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // formatting failed and the buffer contents are unspecified - emit the raw
        // format string so the message is not lost and no garbage is logged
        g_state.log_callback(level, format, g_state.log_callback_user_data);
    } else if ((size_t) len < sizeof(buffer)) {
        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
    } else {
        // message was truncated: len is the full length excluding the terminator
        std::vector<char> buffer2((size_t) len + 1);
        vsnprintf(buffer2.data(), buffer2.size(), format, args_retry);
        buffer2[len] = 0;
        g_state.log_callback(level, buffer2.data(), g_state.log_callback_user_data);
    }

    va_end(args_retry);
    va_end(args);
}
|
|
6639
|
+
|
|
6640
|
+
// Built-in log sink: writes the message verbatim to stderr and flushes right
// away. The severity level and the user data pointer are ignored.
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
    // parameters required by the callback signature but unused here
    (void) user_data;
    (void) level;

    fputs(text, stderr);
    fflush(stderr);
}
|