whisper.rn 0.1.0 → 0.1.3

This diff shows the contents of these publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
@@ -0,0 +1,4814 @@
+ #define WHISPER_BUILD
+ #include "whisper.h"
+
+ #include "ggml.h"
+
+ #include <algorithm>
+ #include <cassert>
+ #define _USE_MATH_DEFINES
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <fstream>
+ #include <map>
+ #include <string>
+ #include <thread>
+ #include <vector>
+ #include <regex>
+ #include <random>
+
+ #if defined(GGML_BIG_ENDIAN)
+ #include <bit>
+
+ template<typename T>
+ static T byteswap(T value) {
+     return std::byteswap(value);
+ }
+
+ template<>
+ float byteswap(float value) {
+     return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+ }
+
+ template<typename T>
+ static void byteswap_tensor_data(ggml_tensor * tensor) {
+     T * datum = reinterpret_cast<T *>(tensor->data);
+     for (int i = 0; i < ggml_nelements(tensor); i++) {
+         datum[i] = byteswap(datum[i]);
+     }
+ }
+
+ static void byteswap_tensor(ggml_tensor * tensor) {
+     switch (tensor->type) {
+         case GGML_TYPE_I16: {
+             byteswap_tensor_data<int16_t>(tensor);
+             break;
+         }
+         case GGML_TYPE_F16: {
+             byteswap_tensor_data<ggml_fp16_t>(tensor);
+             break;
+         }
+         case GGML_TYPE_I32: {
+             byteswap_tensor_data<int32_t>(tensor);
+             break;
+         }
+         case GGML_TYPE_F32: {
+             byteswap_tensor_data<float>(tensor);
+             break;
+         }
+         default: { // GGML_TYPE_I8
+             break;
+         }
+     }
+ }
+
+ #define BYTESWAP_VALUE(d) d = byteswap(d)
+ #define BYTESWAP_FILTERS(f)           \
+     do {                              \
+         for (auto & datum : f.data) { \
+             datum = byteswap(datum);  \
+         }                             \
+     } while (0)
+ #define BYTESWAP_TENSOR(t)  \
+     do {                    \
+         byteswap_tensor(t); \
+     } while (0)
+ #else
+ #define BYTESWAP_VALUE(d) do {} while (0)
+ #define BYTESWAP_FILTERS(f) do {} while (0)
+ #define BYTESWAP_TENSOR(t) do {} while (0)
+ #endif
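+
+ // note: ggml model files store multi-byte values little-endian, so on
+ // little-endian hosts the BYTESWAP_* macros above compile to no-ops and the
+ // data is used exactly as read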
+
+ #define WHISPER_ASSERT(x) \
+     do { \
+         if (!(x)) { \
+             fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+             abort(); \
+         } \
+     } while (0)
+
+ // define this to enable verbose trace logging - useful for debugging purposes
+ //#define WHISPER_DEBUG
+
+ #if defined(WHISPER_DEBUG)
+ #define WHISPER_PRINT_DEBUG(...) \
+     do { \
+         fprintf(stderr, __VA_ARGS__); \
+     } while (0)
+ #else
+ #define WHISPER_PRINT_DEBUG(...)
+ #endif
+
+ #define WHISPER_USE_FLASH_ATTN
+ //#define WHISPER_USE_FLASH_FF
+ #define WHISPER_MAX_DECODERS 16
+
+ #define WHISPER_USE_SCRATCH
+ #define WHISPER_MAX_SCRATCH_BUFFERS 16
+
+ // available whisper models
+ enum e_model {
+     MODEL_UNKNOWN,
+     MODEL_TINY,
+     MODEL_BASE,
+     MODEL_SMALL,
+     MODEL_MEDIUM,
+     MODEL_LARGE,
+ };
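+
+ // the model type is not stored in the file; whisper_model_load() infers it
+ // from n_audio_layer (4/6/12/24/32 -> tiny/base/small/medium/large)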
+
+ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
+     { "en", { 0, "english", } },
+     { "zh", { 1, "chinese", } },
+     { "de", { 2, "german", } },
+     { "es", { 3, "spanish", } },
+     { "ru", { 4, "russian", } },
+     { "ko", { 5, "korean", } },
+     { "fr", { 6, "french", } },
+     { "ja", { 7, "japanese", } },
+     { "pt", { 8, "portuguese", } },
+     { "tr", { 9, "turkish", } },
+     { "pl", { 10, "polish", } },
+     { "ca", { 11, "catalan", } },
+     { "nl", { 12, "dutch", } },
+     { "ar", { 13, "arabic", } },
+     { "sv", { 14, "swedish", } },
+     { "it", { 15, "italian", } },
+     { "id", { 16, "indonesian", } },
+     { "hi", { 17, "hindi", } },
+     { "fi", { 18, "finnish", } },
+     { "vi", { 19, "vietnamese", } },
+     { "iw", { 20, "hebrew", } },
+     { "uk", { 21, "ukrainian", } },
+     { "el", { 22, "greek", } },
+     { "ms", { 23, "malay", } },
+     { "cs", { 24, "czech", } },
+     { "ro", { 25, "romanian", } },
+     { "da", { 26, "danish", } },
+     { "hu", { 27, "hungarian", } },
+     { "ta", { 28, "tamil", } },
+     { "no", { 29, "norwegian", } },
+     { "th", { 30, "thai", } },
+     { "ur", { 31, "urdu", } },
+     { "hr", { 32, "croatian", } },
+     { "bg", { 33, "bulgarian", } },
+     { "lt", { 34, "lithuanian", } },
+     { "la", { 35, "latin", } },
+     { "mi", { 36, "maori", } },
+     { "ml", { 37, "malayalam", } },
+     { "cy", { 38, "welsh", } },
+     { "sk", { 39, "slovak", } },
+     { "te", { 40, "telugu", } },
+     { "fa", { 41, "persian", } },
+     { "lv", { 42, "latvian", } },
+     { "bn", { 43, "bengali", } },
+     { "sr", { 44, "serbian", } },
+     { "az", { 45, "azerbaijani", } },
+     { "sl", { 46, "slovenian", } },
+     { "kn", { 47, "kannada", } },
+     { "et", { 48, "estonian", } },
+     { "mk", { 49, "macedonian", } },
+     { "br", { 50, "breton", } },
+     { "eu", { 51, "basque", } },
+     { "is", { 52, "icelandic", } },
+     { "hy", { 53, "armenian", } },
+     { "ne", { 54, "nepali", } },
+     { "mn", { 55, "mongolian", } },
+     { "bs", { 56, "bosnian", } },
+     { "kk", { 57, "kazakh", } },
+     { "sq", { 58, "albanian", } },
+     { "sw", { 59, "swahili", } },
+     { "gl", { 60, "galician", } },
+     { "mr", { 61, "marathi", } },
+     { "pa", { 62, "punjabi", } },
+     { "si", { 63, "sinhala", } },
+     { "km", { 64, "khmer", } },
+     { "sn", { 65, "shona", } },
+     { "yo", { 66, "yoruba", } },
+     { "so", { 67, "somali", } },
+     { "af", { 68, "afrikaans", } },
+     { "oc", { 69, "occitan", } },
+     { "ka", { 70, "georgian", } },
+     { "be", { 71, "belarusian", } },
+     { "tg", { 72, "tajik", } },
+     { "sd", { 73, "sindhi", } },
+     { "gu", { 74, "gujarati", } },
+     { "am", { 75, "amharic", } },
+     { "yi", { 76, "yiddish", } },
+     { "lo", { 77, "lao", } },
+     { "uz", { 78, "uzbek", } },
+     { "fo", { 79, "faroese", } },
+     { "ht", { 80, "haitian creole", } },
+     { "ps", { 81, "pashto", } },
+     { "tk", { 82, "turkmen", } },
+     { "nn", { 83, "nynorsk", } },
+     { "mt", { 84, "maltese", } },
+     { "sa", { 85, "sanskrit", } },
+     { "lb", { 86, "luxembourgish", } },
+     { "my", { 87, "myanmar", } },
+     { "bo", { 88, "tibetan", } },
+     { "tl", { 89, "tagalog", } },
+     { "mg", { 90, "malagasy", } },
+     { "as", { 91, "assamese", } },
+     { "tt", { 92, "tatar", } },
+     { "haw", { 93, "hawaiian", } },
+     { "ln", { 94, "lingala", } },
+     { "ha", { 95, "hausa", } },
+     { "ba", { 96, "bashkir", } },
+     { "jw", { 97, "javanese", } },
+     { "su", { 98, "sundanese", } },
+ };
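+
+ // example: g_lang.at("de") yields { 2, "german" }, i.e. the internal language
+ // id and the full name for a given language code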
+
+ static const size_t MB = 1024*1024;
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+     { MODEL_TINY,   12ull*MB },
+     { MODEL_BASE,   15ull*MB },
+     { MODEL_SMALL,  23ull*MB },
+     { MODEL_MEDIUM, 31ull*MB },
+     { MODEL_LARGE,  38ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+     { MODEL_TINY,   18ull*MB },
+     { MODEL_BASE,   24ull*MB },
+     { MODEL_SMALL,  36ull*MB },
+     { MODEL_MEDIUM, 48ull*MB },
+     { MODEL_LARGE,  60ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+     { MODEL_TINY,   4ull*MB },
+     { MODEL_BASE,   4ull*MB },
+     { MODEL_SMALL,  6ull*MB },
+     { MODEL_MEDIUM, 7ull*MB },
+     { MODEL_LARGE,  9ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+     { MODEL_TINY,   4ull*MB },
+     { MODEL_BASE,   4ull*MB },
+     { MODEL_SMALL,  6ull*MB },
+     { MODEL_MEDIUM, 7ull*MB },
+     { MODEL_LARGE,  9ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
+     { MODEL_TINY,     74ull*MB },
+     { MODEL_BASE,    142ull*MB },
+     { MODEL_SMALL,   466ull*MB },
+     { MODEL_MEDIUM, 1464ull*MB },
+     { MODEL_LARGE,  2952ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+     { MODEL_TINY,    3ull*MB },
+     { MODEL_BASE,    6ull*MB },
+     { MODEL_SMALL,  16ull*MB },
+     { MODEL_MEDIUM, 43ull*MB },
+     { MODEL_LARGE,  71ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
+     { MODEL_TINY,     9ull*MB },
+     { MODEL_BASE,    18ull*MB },
+     { MODEL_SMALL,   53ull*MB },
+     { MODEL_MEDIUM, 141ull*MB },
+     { MODEL_LARGE,  235ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
+     { MODEL_TINY,    6ull*MB },
+     { MODEL_BASE,    8ull*MB },
+     { MODEL_SMALL,  13ull*MB },
+     { MODEL_MEDIUM, 22ull*MB },
+     { MODEL_LARGE,  33ull*MB },
+ };
+
+ static const std::map<e_model, size_t> MEM_REQ_DECODE = {
+     { MODEL_TINY,    3ull*MB },
+     { MODEL_BASE,    5ull*MB },
+     { MODEL_SMALL,  10ull*MB },
+     { MODEL_MEDIUM, 18ull*MB },
+     { MODEL_LARGE,  27ull*MB },
+ };
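+
+ // fixed per-model buffer sizes (in bytes) for the scratch, weight, KV-cache
+ // and compute buffers; whisper_model_load() uses them to allocate everything
+ // up front so that inference itself does not allocate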
+
+ struct whisper_mel {
+     int n_len;
+     int n_mel;
+
+     std::vector<float> data;
+ };
+
+ struct whisper_filters {
+     int32_t n_mel;
+     int32_t n_fft;
+
+     std::vector<float> data;
+ };
+
+ struct whisper_vocab {
+     using id    = int32_t;
+     using token = std::string;
+
+     int n_vocab = 51864;
+
+     std::map<token, id> token_to_id;
+     std::map<id, token> id_to_token;
+
+     id token_eot  = 50256;
+     id token_sot  = 50257;
+     id token_prev = 50360;
+     id token_solm = 50361; // ??
+     id token_not  = 50362; // no timestamps
+     id token_beg  = 50363;
+
+     // available tasks
+     static const id token_translate  = 50358;
+     static const id token_transcribe = 50359;
+
+     bool is_multilingual() const {
+         return n_vocab == 51865;
+     }
+ };
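+
+ // note: multilingual models carry one extra vocabulary entry, which shifts all
+ // of the special token ids above up by 1 (see the is_multilingual() branch in
+ // whisper_model_load())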
+
+ struct whisper_segment {
+     int64_t t0;
+     int64_t t1;
+
+     std::string text;
+
+     std::vector<whisper_token_data> tokens;
+ };
+
+ // medium
+ // hparams: {
+ //   'n_mels': 80,
+ //   'n_vocab': 51864,
+ //   'n_audio_ctx': 1500,
+ //   'n_audio_state': 1024,
+ //   'n_audio_head': 16,
+ //   'n_audio_layer': 24,
+ //   'n_text_ctx': 448,
+ //   'n_text_state': 1024,
+ //   'n_text_head': 16,
+ //   'n_text_layer': 24
+ // }
+ //
+ // default hparams (Whisper tiny)
+ struct whisper_hparams {
+     int32_t n_vocab       = 51864;
+     int32_t n_audio_ctx   = 1500;
+     int32_t n_audio_state = 384;
+     int32_t n_audio_head  = 6;
+     int32_t n_audio_layer = 4;
+     int32_t n_text_ctx    = 448;
+     int32_t n_text_state  = 384;
+     int32_t n_text_head   = 6;
+     int32_t n_text_layer  = 4;
+     int32_t n_mels        = 80;
+     int32_t f16           = 1;
+ };
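+
+ // these defaults are placeholders only; whisper_model_load() overwrites every
+ // field with the values read from the model file header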
+
+ // audio encoding layer
+ struct whisper_layer_encoder {
+     // encoder.blocks.*.attn_ln
+     struct ggml_tensor * attn_ln_0_w;
+     struct ggml_tensor * attn_ln_0_b;
+
+     // encoder.blocks.*.attn.out
+     struct ggml_tensor * attn_ln_1_w;
+     struct ggml_tensor * attn_ln_1_b;
+
+     // encoder.blocks.*.attn.query
+     struct ggml_tensor * attn_q_w;
+     struct ggml_tensor * attn_q_b;
+
+     // encoder.blocks.*.attn.key
+     struct ggml_tensor * attn_k_w;
+
+     // encoder.blocks.*.attn.value
+     struct ggml_tensor * attn_v_w;
+     struct ggml_tensor * attn_v_b;
+
+     // encoder.blocks.*.mlp_ln
+     struct ggml_tensor * mlp_ln_w;
+     struct ggml_tensor * mlp_ln_b;
+
+     // encoder.blocks.*.mlp.0
+     struct ggml_tensor * mlp_0_w;
+     struct ggml_tensor * mlp_0_b;
+
+     // encoder.blocks.*.mlp.2
+     struct ggml_tensor * mlp_1_w;
+     struct ggml_tensor * mlp_1_b;
+ };
+
+ // token decoding layer
+ struct whisper_layer_decoder {
+     // decoder.blocks.*.attn_ln
+     struct ggml_tensor * attn_ln_0_w;
+     struct ggml_tensor * attn_ln_0_b;
+
+     // decoder.blocks.*.attn.out
+     struct ggml_tensor * attn_ln_1_w;
+     struct ggml_tensor * attn_ln_1_b;
+
+     // decoder.blocks.*.attn.query
+     struct ggml_tensor * attn_q_w;
+     struct ggml_tensor * attn_q_b;
+
+     // decoder.blocks.*.attn.key
+     struct ggml_tensor * attn_k_w;
+
+     // decoder.blocks.*.attn.value
+     struct ggml_tensor * attn_v_w;
+     struct ggml_tensor * attn_v_b;
+
+     // decoder.blocks.*.cross_attn_ln
+     struct ggml_tensor * cross_attn_ln_0_w;
+     struct ggml_tensor * cross_attn_ln_0_b;
+
+     // decoder.blocks.*.cross_attn.out
+     struct ggml_tensor * cross_attn_ln_1_w;
+     struct ggml_tensor * cross_attn_ln_1_b;
+
+     // decoder.blocks.*.cross_attn.query
+     struct ggml_tensor * cross_attn_q_w;
+     struct ggml_tensor * cross_attn_q_b;
+
+     // decoder.blocks.*.cross_attn.key
+     struct ggml_tensor * cross_attn_k_w;
+
+     // decoder.blocks.*.cross_attn.value
+     struct ggml_tensor * cross_attn_v_w;
+     struct ggml_tensor * cross_attn_v_b;
+
+     // decoder.blocks.*.mlp_ln
+     struct ggml_tensor * mlp_ln_w;
+     struct ggml_tensor * mlp_ln_b;
+
+     // decoder.blocks.*.mlp.0
+     struct ggml_tensor * mlp_0_w;
+     struct ggml_tensor * mlp_0_b;
+
+     // decoder.blocks.*.mlp.2
+     struct ggml_tensor * mlp_1_w;
+     struct ggml_tensor * mlp_1_b;
+ };
+
+ struct whisper_kv_cache {
+     struct ggml_tensor * k;
+     struct ggml_tensor * v;
+
+     struct ggml_context * ctx;
+
+     std::vector<uint8_t> buf;
+
+     int n; // number of tokens currently in the cache
+ };
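+
+ // k and v are flat 1-D tensors holding n_text_state x (n_text_layer * n_ctx)
+ // elements (see kv_cache_init() below); slices for a given layer/position are
+ // taken with ggml_view_1d() where the cache is written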
+
+ struct whisper_model {
+     e_model type = MODEL_UNKNOWN;
+
+     whisper_hparams hparams;
+     whisper_filters filters;
+
+     // encoder.positional_embedding
+     struct ggml_tensor * e_pe;
+
+     // encoder.conv1
+     struct ggml_tensor * e_conv_1_w;
+     struct ggml_tensor * e_conv_1_b;
+
+     // encoder.conv2
+     struct ggml_tensor * e_conv_2_w;
+     struct ggml_tensor * e_conv_2_b;
+
+     // encoder.ln_post
+     struct ggml_tensor * e_ln_w;
+     struct ggml_tensor * e_ln_b;
+
+     // decoder.positional_embedding
+     struct ggml_tensor * d_pe;
+
+     // decoder.token_embedding
+     struct ggml_tensor * d_te;
+
+     // decoder.ln
+     struct ggml_tensor * d_ln_w;
+     struct ggml_tensor * d_ln_b;
+
+     std::vector<whisper_layer_encoder> layers_encoder;
+     std::vector<whisper_layer_decoder> layers_decoder;
+
+     // context
+     struct ggml_context * ctx;
+
+     // the model memory buffer is read-only and can be shared between processors
+     std::vector<uint8_t> * buf;
+
+     // tensors
+     int n_loaded;
+     std::map<std::string, struct ggml_tensor *> tensors;
+ };
+
+ struct whisper_sequence {
+     std::vector<whisper_token_data> tokens;
+
+     // the accumulated transcription in the current iteration (used to truncate the tokens array)
+     int result_len;
+
+     double sum_logprobs_all; // the sum of the log probabilities of the tokens
+     double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
+     double avg_logprobs;     // the average log probability of the tokens
+     double entropy;          // the entropy of the tokens
+     double score;            // likelihood rank score
+ };
+
+ // TAGS: WHISPER_DECODER_INIT
+ struct whisper_decoder {
+     // each decoder keeps its own KV cache
+     whisper_kv_cache kv_self;
+
+     // the currently generated sequence of tokens
+     whisper_sequence sequence;
+
+     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
+
+     bool failed;    // has the current segment failed to decode?
+     bool completed; // has the decoder completed the current segment?
+     bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
+
+     // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+     std::vector<float> probs;
+     std::vector<float> logits;
+     std::vector<float> logprobs;
+
+     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
+ };
+
+ struct whisper_context {
+     int64_t t_load_us   = 0;
+     int64_t t_mel_us    = 0;
+     int64_t t_sample_us = 0;
+     int64_t t_encode_us = 0;
+     int64_t t_decode_us = 0;
+     int64_t t_start_us  = 0;
+
+     int32_t n_sample = 0; // number of tokens sampled
+     int32_t n_encode = 0; // number of encoder calls
+     int32_t n_decode = 0; // number of decoder calls
+     int32_t n_fail_p = 0; // number of logprob threshold failures
+     int32_t n_fail_h = 0; // number of entropy threshold failures
+
+     ggml_type wtype; // weight type (FP32 or FP16)
+
+     whisper_mel mel;
+
+     whisper_model model;
+     whisper_vocab vocab;
+
+     // cross-attention KV cache for the decoders
+     // shared between all decoders
+     whisper_kv_cache kv_cross;
+
+     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+
+     // memory buffers used by encode / decode contexts
+     std::vector<uint8_t> buf_compute;
+     std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+
+     int    buf_last = 0;
+     size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
+
+     // decode output (2-dimensional array: [n_tokens][n_vocab])
+     std::vector<float> logits;
+
+     std::vector<whisper_segment> result_all;
+     std::vector<whisper_token>   prompt_past;
+
+     // work container used to avoid memory allocations
+     std::vector<std::pair<double, whisper_vocab::id>> logits_id;
+
+     mutable std::mt19937 rng; // used for sampling at t > 0.0
+
+     int lang_id = 0; // english by default
+
+     // [EXPERIMENTAL] token-level timestamps data
+     int64_t t_beg  = 0;
+     int64_t t_last = 0;
+     whisper_token tid_last;
+     std::vector<float> energy; // PCM signal energy
+
+     // [EXPERIMENTAL] speed-up techniques
+     int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+     void use_buf(struct ggml_context * ctx, int i) {
+ #if defined(WHISPER_USE_SCRATCH)
+         size_t last_size = 0;
+
+         if (i == -1) {
+             last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+         } else {
+             auto & buf = buf_scratch[i];
+             last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+         }
+
+         if (buf_last >= 0) {
+             buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+         }
+
+         buf_last = i;
+ #else
+         (void) i;
+         (void) ctx;
+ #endif
+     }
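+
+     // use_buf(ctx, i) redirects subsequent tensor allocations in ctx into
+     // scratch buffer i, while use_buf(ctx, -1) switches back to the main
+     // compute buffer; ggml_set_scratch() returns how much of the previous
+     // scratch buffer was used, which is tracked in buf_max_size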
+
+     size_t get_buf_max_mem(int i) const {
+ #if defined(WHISPER_USE_SCRATCH)
+         return buf_max_size[i];
+ #else
+         (void) i;
+         return 0;
+ #endif
+     }
+ };
+
+ template<typename T>
+ static void read_safe(whisper_model_loader * loader, T & dest) {
+     loader->read(loader->context, &dest, sizeof(T));
+     BYTESWAP_VALUE(dest);
+ }
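+
+ // whisper_model_load() below consumes the model file exclusively through the
+ // read/eof callbacks of whisper_model_loader (declared in whisper.h), so the
+ // data can come from a file, a memory buffer, a bundled asset, etc. a minimal
+ // file-backed loader might look like this (illustrative sketch only):
+ //
+ //     std::ifstream fin("ggml-tiny.bin", std::ios::binary); // hypothetical path
+ //
+ //     whisper_model_loader loader = {};
+ //     loader.context = &fin;
+ //     loader.read    = [](void * ctx, void * output, size_t read_size) {
+ //         auto * f = (std::ifstream *) ctx;
+ //         f->read((char *) output, read_size);
+ //         return (size_t) f->gcount();
+ //     };
+ //     loader.eof     = [](void * ctx) {
+ //         return ((std::ifstream *) ctx)->eof();
+ //     };
+ //     loader.close   = [](void * ctx) {
+ //         ((std::ifstream *) ctx)->close();
+ //     };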
+
+ static bool kv_cache_init(
+         const struct whisper_hparams & hparams,
+         const size_t mem_bytes,
+         struct whisper_kv_cache & cache,
+         ggml_type wtype,
+         int n_ctx) {
+     cache.buf.resize(mem_bytes);
+
+     struct ggml_init_params params;
+     params.mem_size   = cache.buf.size();
+     params.mem_buffer = cache.buf.data();
+
+     cache.ctx = ggml_init(params);
+
+     if (!cache.ctx) {
+         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+         return false;
+     }
+
+     const int n_text_state = hparams.n_text_state;
+     const int n_text_layer = hparams.n_text_layer;
+
+     const int n_mem      = n_text_layer*n_ctx;
+     const int n_elements = n_text_state*n_mem;
+
+     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+     return true;
+ }
+
+ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+     WHISPER_ASSERT(cache.ctx);
+
+     const int n_elements = ggml_nelements(cache.k);
+     WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+
+     const ggml_type wtype = cache.k->type;
+     WHISPER_ASSERT(wtype == cache.v->type);
+
+     WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
+
+     struct ggml_init_params params;
+     params.mem_size   = cache.buf.size();
+     params.mem_buffer = cache.buf.data();
+
+     cache.ctx = ggml_init(params);
+
+     if (!cache.ctx) {
+         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+         return false;
+     }
+
+     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+     return true;
+ }
+
+ static void kv_cache_free(struct whisper_kv_cache & cache) {
+     if (cache.ctx) {
+         ggml_free(cache.ctx);
+         cache.ctx = nullptr;
+     }
+ }
+
+ // load the model from a ggml file
+ //
+ // file format:
+ //
+ //   - hparams
+ //   - pre-computed mel filters
+ //   - vocab
+ //   - weights
+ //
+ // see the convert-pt-to-ggml.py script for details
+ //
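+ // concretely, the loader below expects (all integers little-endian):
+ //
+ //   uint32   magic "ggml" (0x67676d6c)
+ //   int32    hparams: n_vocab .. f16 (11 values, in the order read below)
+ //   int32    n_mel, int32 n_fft, then n_mel*n_fft floats (mel filterbank)
+ //   int32    n_vocab, then per token: uint32 len + len bytes
+ //   tensors: int32 n_dims, int32 name_len, int32 ftype, n_dims x int32 shape,
+ //            name_len bytes of name, followed by the raw tensor data
+ //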
+ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
+     fprintf(stderr, "%s: loading model\n", __func__);
+
+     const int64_t t_start_us = ggml_time_us();
+
+     wctx.t_start_us = t_start_us;
+
+     auto & model = wctx.model;
+     auto & vocab = wctx.vocab;
+
+     // verify magic
+     {
+         uint32_t magic;
+         read_safe(loader, magic);
+         if (magic != 0x67676d6c) {
+             fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
+             return false;
+         }
+     }
+
+     // load hparams
+     {
+         auto & hparams = model.hparams;
+
+         read_safe(loader, hparams.n_vocab);
+         read_safe(loader, hparams.n_audio_ctx);
+         read_safe(loader, hparams.n_audio_state);
+         read_safe(loader, hparams.n_audio_head);
+         read_safe(loader, hparams.n_audio_layer);
+         read_safe(loader, hparams.n_text_ctx);
+         read_safe(loader, hparams.n_text_state);
+         read_safe(loader, hparams.n_text_head);
+         read_safe(loader, hparams.n_text_layer);
+         read_safe(loader, hparams.n_mels);
+         read_safe(loader, hparams.f16);
+
+         assert(hparams.n_text_state == hparams.n_audio_state);
+
+         if (hparams.n_audio_layer == 4) {
+             model.type = e_model::MODEL_TINY;
+         }
+
+         if (hparams.n_audio_layer == 6) {
+             model.type = e_model::MODEL_BASE;
+         }
+
+         if (hparams.n_audio_layer == 12) {
+             model.type = e_model::MODEL_SMALL;
+         }
+
+         if (hparams.n_audio_layer == 24) {
+             model.type = e_model::MODEL_MEDIUM;
+         }
+
+         if (hparams.n_audio_layer == 32) {
+             model.type = e_model::MODEL_LARGE;
+         }
+
+         // for the big tensors, we have the option to store the data in 16-bit floats
+         // in order to save memory and also to speed up the computation
+         wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+         const size_t scale = model.hparams.f16 ? 1 : 2;
+
+         fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+         fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+         fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+         fprintf(stderr, "%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+         fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+         fprintf(stderr, "%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+         fprintf(stderr, "%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+         fprintf(stderr, "%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+         fprintf(stderr, "%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+         fprintf(stderr, "%s: n_mels        = %d\n", __func__, hparams.n_mels);
+         fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
+         fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
+
+         // print memory requirements
+         {
+             // this is the total memory required to run the inference
+             const size_t mem_required =
+                 MEM_REQ_SCRATCH0.at (model.type) +
+                 MEM_REQ_SCRATCH1.at (model.type) +
+                 MEM_REQ_SCRATCH2.at (model.type) +
+                 MEM_REQ_SCRATCH3.at (model.type) +
+                 scale*MEM_REQ_MODEL.at (model.type) +
+                 scale*MEM_REQ_KV_CROSS.at(model.type) +
+                 scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
+
+             // this is the memory required by one decoder
+             const size_t mem_required_decoder =
+                 scale*MEM_REQ_KV_SELF.at(model.type);
+
+             fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+                     mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+         }
+
+         // initialize all memory buffers
+         // always have at least one decoder
+
+         wctx.model.buf = new std::vector<uint8_t>();
+         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
+
+         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
+             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+             return false;
+         }
+
+         {
+             const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
+             fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+         }
+
+         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
+             fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+             return false;
+         }
+
+         {
+             const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
+             fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+         }
+
+         wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+
+         wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
+         wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
+         wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
+         wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
+     }
+
+     // load mel filters
+     {
+         auto & filters = wctx.model.filters;
+
+         read_safe(loader, filters.n_mel);
+         read_safe(loader, filters.n_fft);
+
+         filters.data.resize(filters.n_mel * filters.n_fft);
+         loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
+         BYTESWAP_FILTERS(filters);
+     }
+
+     // load vocab
+     {
+         int32_t n_vocab = 0;
+         read_safe(loader, n_vocab);
+
+         //if (n_vocab != model.hparams.n_vocab) {
+         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+         //    return false;
+         //}
+
+         std::string word;
+         std::vector<char> tmp;
+
+         tmp.reserve(128);
+
+         for (int i = 0; i < n_vocab; i++) {
+             uint32_t len;
+             read_safe(loader, len);
+
+             if (len > 0) {
+                 tmp.resize(len);
+                 loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
+                 word.assign(&tmp[0], tmp.size());
+             } else {
+                 // seems like we have an empty-string token in multi-language models (i = 50256)
+                 //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                 word = "";
+             }
+
+             vocab.token_to_id[word] = i;
+             vocab.id_to_token[i] = word;
+
+             //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+         }
+
+         vocab.n_vocab = model.hparams.n_vocab;
+         if (vocab.is_multilingual()) {
+             vocab.token_eot++;
+             vocab.token_sot++;
+             vocab.token_prev++;
+             vocab.token_solm++;
+             vocab.token_not++;
+             vocab.token_beg++;
+         }
+
+         if (n_vocab < model.hparams.n_vocab) {
+             fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
+                 if (i > vocab.token_beg) {
+                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
+                 } else if (i == vocab.token_eot) {
+                     word = "[_EOT_]";
+                 } else if (i == vocab.token_sot) {
+                     word = "[_SOT_]";
+                 } else if (i == vocab.token_prev) {
+                     word = "[_PREV_]";
+                 } else if (i == vocab.token_not) {
+                     word = "[_NOT_]";
+                 } else if (i == vocab.token_beg) {
+                     word = "[_BEG_]";
+                 } else {
+                     word = "[_extra_token_" + std::to_string(i) + "]";
+                 }
+                 vocab.token_to_id[word] = i;
+                 vocab.id_to_token[i] = word;
+             }
+         }
+
+         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+
+         wctx.logits_id.reserve(n_vocab);
+
+         // TAGS: WHISPER_DECODER_INIT
+         wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
+
+         wctx.decoders[0].probs.reserve   (vocab.n_vocab);
+         wctx.decoders[0].logits.reserve  (vocab.n_vocab);
+         wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
+     }
+
+     size_t ctx_size = 0;
+
+     const ggml_type wtype = wctx.wtype;
+
+     {
+         const auto & hparams = model.hparams;
+
+         const int n_vocab = hparams.n_vocab;
+
+         const int n_audio_ctx   = hparams.n_audio_ctx;
+         const int n_audio_state = hparams.n_audio_state;
+         const int n_audio_layer = hparams.n_audio_layer;
+
+         const int n_text_ctx   = hparams.n_text_ctx;
+         const int n_text_state = hparams.n_text_state;
+         const int n_text_layer = hparams.n_text_layer;
+
+         const int n_mels = hparams.n_mels;
+
+         // encoder
+         {
+             ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
+
+             ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
+             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32);  // e_conv_1_b
+
+             ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
+             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32);         // e_conv_2_b
+
+             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
+             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
+         }
+
+         // decoder
+         {
+             ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
+
+             ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
+
+             ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
+             ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
+         }
+
+         // encoder layers
+         {
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+             ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
+             ctx_size += n_audio_layer*(4*n_audio_state*ggml_type_size(GGML_TYPE_F32));       // mlp_0_b
+
+             ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32));         // mlp_1_b
+
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32));       // attn_q_b
+
+             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
+
+             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32));       // attn_v_b
+
+             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
+             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32));       // attn_ln_1_b
+         }
+
+         // decoder layers
+         {
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+             ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
+             ctx_size += n_text_layer*(4*n_text_state*ggml_type_size(GGML_TYPE_F32));      // mlp_0_b
+
+             ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));        // mlp_1_b
+
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // attn_q_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // attn_v_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // attn_ln_1_b
+             //
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // cross_attn_q_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // cross_attn_v_b
+
+             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
+             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32));      // cross_attn_ln_1_b
+         }
+
+         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
+
+         fprintf(stderr, "%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+     }
+
+     // create the ggml context
+     {
+         struct ggml_init_params params;
+         params.mem_size   = wctx.model.buf->size();
+         params.mem_buffer = wctx.model.buf->data();
+
+         model.ctx = ggml_init(params);
+         if (!model.ctx) {
+             fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+             return false;
+         }
+     }
+
+     // prepare memory for the weights
+     {
+         auto & ctx = model.ctx;
+
+         const auto & hparams = model.hparams;
+
+         const int n_vocab = hparams.n_vocab;
+
+         const int n_audio_ctx   = hparams.n_audio_ctx;
+         const int n_audio_state = hparams.n_audio_state;
+         const int n_audio_layer = hparams.n_audio_layer;
+
+         const int n_text_ctx   = hparams.n_text_ctx;
+         const int n_text_state = hparams.n_text_state;
+         const int n_text_layer = hparams.n_text_layer;
+
+         const int n_mels = hparams.n_mels;
+
+         model.layers_encoder.resize(n_audio_layer);
+         model.layers_decoder.resize(n_text_layer);
+
+         // encoder
+         {
+             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+
+             model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
+             model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+             model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
+             model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+             // map by name
+             model.tensors["encoder.positional_embedding"] = model.e_pe;
+
+             model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
+             model.tensors["encoder.conv1.bias"]   = model.e_conv_1_b;
+
+             model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
+             model.tensors["encoder.conv2.bias"]   = model.e_conv_2_b;
+
+             model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
+             model.tensors["encoder.ln_post.bias"]   = model.e_ln_b;
+
+             for (int i = 0; i < n_audio_layer; ++i) {
+                 auto & layer = model.layers_encoder[i];
+
+                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
+                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+
+                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
+                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+
+                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+                 // map by name
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
+
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
+             }
+         }
+
+         // decoder
+         {
+             model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
+
+             model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
+
+             model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+             // map by name
+             model.tensors["decoder.positional_embedding"] = model.d_pe;
+
+             model.tensors["decoder.token_embedding.weight"] = model.d_te;
+
+             model.tensors["decoder.ln.weight"] = model.d_ln_w;
+             model.tensors["decoder.ln.bias"]   = model.d_ln_b;
+
+             for (int i = 0; i < n_text_layer; ++i) {
+                 auto & layer = model.layers_decoder[i];
+
+                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
+                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
+
+                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
+                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+                 layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+                 layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+                 layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+                 // map by name
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]   = layer.cross_attn_ln_0_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"]   = layer.cross_attn_q_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"]   = layer.cross_attn_v_b;
+
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
+                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]   = layer.cross_attn_ln_1_b;
+             }
+         }
+     }
+
+     // load weights
+     {
+         size_t total_size = 0;
+
+         model.n_loaded = 0;
+
+         while (true) {
+             int32_t n_dims;
+             int32_t length;
+             int32_t ftype;
+
+             read_safe(loader, n_dims);
+             read_safe(loader, length);
+             read_safe(loader, ftype);
+
+             if (loader->eof(loader->context)) {
+                 break;
+             }
+
+             int32_t nelements = 1;
+             int32_t ne[3] = { 1, 1, 1 };
+             for (int i = 0; i < n_dims; ++i) {
+                 read_safe(loader, ne[i]);
+                 nelements *= ne[i];
+             }
+
+             std::string name;
+             std::vector<char> tmp(length); // create a buffer
+             loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
+             name.assign(&tmp[0], tmp.size());
+
+             if (model.tensors.find(name) == model.tensors.end()) {
+                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                 return false;
+             }
+
+             auto tensor = model.tensors[name.data()];
+             if (ggml_nelements(tensor) != nelements) {
+                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                 return false;
+             }
+
+             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                 fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                         __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
+                 return false;
+             }
+
+             const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+
+             if (nelements*bpe != ggml_nbytes(tensor)) {
+                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                 return false;
+             }
+
+             loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+             BYTESWAP_TENSOR(tensor);
+
+             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+             total_size += ggml_nbytes(tensor);
+             model.n_loaded++;
+         }
+
+         fprintf(stderr, "%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+
+         if (model.n_loaded == 0) {
+             fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+         } else if (model.n_loaded != (int) model.tensors.size()) {
+             fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+             return false;
+         }
+     }
+
+     wctx.rng = std::mt19937(0);
+
+     wctx.t_load_us = ggml_time_us() - t_start_us;
+
+     return true;
+ }
+
+ // evaluate the encoder
+ //
+ // given an audio recording (more specifically, its log mel spectrogram), runs the
+ // forward pass of the encoder part of the transformer model and returns the encoded features
+ //
+ //   - model:      the model
+ //   - n_threads:  number of threads to use
+ //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
+ //
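+ // (each mel frame covers 10 ms of audio, i.e. whisper's standard 160-sample
+ // hop at 16 kHz, so mel_offset is expressed in 10 ms steps)
+ //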
+ static bool whisper_encode(
+         whisper_context & wctx,
+         const int mel_offset,
+         const int n_threads) {
+     const int64_t t_start_us = ggml_time_us();
+
+     const auto & model   = wctx.model;
+     const auto & mel_inp = wctx.mel;
+     const auto & hparams = model.hparams;
+
+     const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+     const int n_state = hparams.n_audio_state;
+     const int n_head  = hparams.n_audio_head;
+     const int n_layer = hparams.n_audio_layer;
+
+     const int n_mels = hparams.n_mels;
+     assert(mel_inp.n_mel == n_mels);
+
+     struct ggml_init_params params;
+     params.mem_size   = wctx.buf_compute.size();
+     params.mem_buffer = wctx.buf_compute.data();
+
+     struct ggml_context * ctx0 = ggml_init(params);
+
+     wctx.use_buf(ctx0, 0);
+
+     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+     assert(mel->type == GGML_TYPE_F32);
+     {
+         float * dst = (float *) mel->data;
+         memset(dst, 0, ggml_nbytes(mel));
+
+         const int i0 = std::min(mel_offset, mel_inp.n_len);
+         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+         for (int j = 0; j < mel_inp.n_mel; ++j) {
+             for (int i = i0; i < i1; ++i) {
+                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+             }
+         }
+     }
+
+     struct ggml_tensor * cur;
+
+     // convolution + gelu
+     {
+         wctx.use_buf(ctx0, 1);
+
+         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
+         cur = ggml_add(ctx0,
+                 ggml_repeat(ctx0,
+                     model.e_conv_1_b,
+                     cur),
+                 cur);
+
+         cur = ggml_gelu(ctx0, cur);
+
+         wctx.use_buf(ctx0, 0);
+
+         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
+         cur = ggml_add(ctx0,
+                 ggml_repeat(ctx0,
+                     model.e_conv_2_b,
+                     cur),
+                 cur);
+
+         cur = ggml_gelu(ctx0, cur);
+     }
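+
+     // at this point the stride-2 second convolution has halved the time axis:
+     // 2*n_ctx mel frames in, n_ctx encoder positions out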
+
+     wctx.use_buf(ctx0, 3);
+
+     // ===================================================================
+     // NOTE: experimenting with partial evaluation of the encoder (ignore)
+     //static int iter = -1;
+     //const int n_iter = 1500/n_ctx;
+
+     //iter = (iter + 1) % n_iter;
+
+     //if (iter == 0) {
+     //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+     //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+     //}
+
+     static int iter = 0;
+
+     const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+     cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+
+     // ===================================================================
+
+     // original:
+     //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+
+     struct ggml_tensor * inpL = cur;
+
+     for (int il = 0; il < n_layer; ++il) {
+         const auto & layer = model.layers_encoder[il];
+
+         // norm
+         {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_norm(ctx0, inpL);
+
+             // cur = ln_0_w*cur + ln_0_b
+             cur = ggml_add(ctx0,
+                     ggml_mul(ctx0,
+                         ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
+                         cur),
+                     ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+         }
+
+         // self-attention
+         {
+             wctx.use_buf(ctx0, 1);
+
+             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+                     layer.attn_q_w,
+                     cur);
+
+             Qcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
+                         layer.attn_q_b,
+                         Qcur),
+                     Qcur);
+
+             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+             // note: no bias for Key
+             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+                     layer.attn_k_w,
+                     cur);
+
+             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                     layer.attn_v_w,
+                     cur);
+
+             Vcur = ggml_add(ctx0,
+                     ggml_repeat(ctx0,
+                         layer.attn_v_b,
+                         Vcur),
+                     Vcur);
+
+             // ------
+
+             wctx.use_buf(ctx0, 0);
+
+ #ifdef WHISPER_USE_FLASH_ATTN
+             struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
+                             Qcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                         0, 2, 1, 3);
+
+             struct ggml_tensor * K =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
+                             Kcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                         0, 2, 1, 3);
+
+             struct ggml_tensor * V =
+                 ggml_cpy(ctx0,
+                         ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                 Vcur,
+                                 n_state/n_head, n_head, n_ctx),
+                             1, 2, 0, 3),
+                         ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
+                         );
+
+             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
+ #else
+             struct ggml_tensor * Q =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
+                             Qcur,
+                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                         0, 2, 1, 3);
+
+             struct ggml_tensor * K =
+                 ggml_permute(ctx0,
+                         ggml_cpy(ctx0,
+                             Kcur,
+                             ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                         0, 2, 1, 3);
+
+             // K * Q
+             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+             struct ggml_tensor * KQ_scaled =
+                 ggml_scale(ctx0,
+                         KQ,
+                         ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
+                         );
+
+             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+
+             //struct ggml_tensor * V_trans =
+             //    ggml_permute(ctx0,
+             //            ggml_cpy(ctx0,
+             //                Vcur,
+             //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+             //            1, 2, 0, 3);
+
+             //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+             struct ggml_tensor * V =
+                 ggml_cpy(ctx0,
+                         ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                 Vcur,
+                                 n_state/n_head, n_head, n_ctx),
+                             0, 2, 1, 3),
+                         ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
+                         );
+
+             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
+ #endif
+             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_cpy(ctx0,
+                     KQV_merged,
+                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+         }
+
+         // projection
+         {
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_mul_mat(ctx0,
+                     layer.attn_ln_1_w,
+                     cur);
+
+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
+                     cur);
+         }
+
+         wctx.use_buf(ctx0, 2);
+
+         // add the input
+         cur = ggml_add(ctx0, cur, inpL);
+
+         struct ggml_tensor * inpFF = cur;
+
+         // feed-forward network
+         {
+             // norm
+             {
+                 wctx.use_buf(ctx0, 0);
+
+                 cur = ggml_norm(ctx0, inpFF);
+
+                 wctx.use_buf(ctx0, 1);
+
+                 // cur = mlp_ln_w*cur + mlp_ln_b
+                 cur = ggml_add(ctx0,
+                         ggml_mul(ctx0,
+                             ggml_repeat(ctx0, layer.mlp_ln_w, cur),
+                             cur),
+                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+             }
+
+ #ifdef WHISPER_USE_FLASH_FF
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_flash_ff(ctx0,
+                     ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
+                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+ #else
+             wctx.use_buf(ctx0, 0);
+
+             // fully connected
+             cur = ggml_mul_mat(ctx0,
+                     layer.mlp_0_w,
+                     cur);
+
+             wctx.use_buf(ctx0, 1);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                     cur);
+
+             wctx.use_buf(ctx0, 0);
+
+             // GELU activation
+             cur = ggml_gelu(ctx0, cur);
+
+             wctx.use_buf(ctx0, 1);
+
+             // projection
+             cur = ggml_mul_mat(ctx0,
+                     layer.mlp_1_w,
+                     cur);
+
+             wctx.use_buf(ctx0, 0);
+
+             cur = ggml_add(ctx0,
+                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                     cur);
+ #endif
+         }
+
+         wctx.use_buf(ctx0, 3);
+
+         inpL = ggml_add(ctx0, cur, inpFF);
+     }
1677
+
1678
+ cur = inpL;
1679
+
1680
+ // norm
1681
+ {
1682
+ wctx.use_buf(ctx0, 0);
1683
+
1684
+ cur = ggml_norm(ctx0, cur);
1685
+
1686
+ wctx.use_buf(ctx0, 1);
1687
+
1688
+ // cur = ln_f_g*cur + ln_f_b
1689
+ cur = ggml_add(ctx0,
1690
+ ggml_mul(ctx0,
1691
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1692
+ cur),
1693
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1694
+ }
1695
+
1696
+ wctx.use_buf(ctx0, -1);
1697
+
1698
+ // run the computation
1699
+ {
1700
+ struct ggml_cgraph gf = {};
1701
+ gf.n_threads = n_threads;
1702
+
1703
+ ggml_build_forward_expand(&gf, cur);
1704
+ ggml_graph_compute (ctx0, &gf);
1705
+
1706
+ //ggml_graph_print(&gf);
1707
+ }
1708
+
1709
+ // cur
1710
+ //{
1711
+ // printf("ne0 = %d\n", cur->ne[0]);
1712
+ // printf("ne1 = %d\n", cur->ne[1]);
1713
+ // for (int i = 0; i < 10; ++i) {
1714
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1715
+ // }
1716
+ // printf("... ");
1717
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1718
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1719
+ // }
1720
+ // printf("\n");
1721
+ //}
1722
+
1723
+ // pre-compute cross-attention memory
1724
+ {
1725
+ struct ggml_cgraph gf = {};
1726
+ gf.n_threads = n_threads;
1727
+
1728
+ // TODO: hack to disconnect the encoded features from the previous graph
1729
+ cur->op = GGML_OP_NONE;
1730
+ cur->src0 = nullptr;
1731
+ cur->src1 = nullptr;
1732
+
1733
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1734
+ auto & layer = model.layers_decoder[il];
1735
+
1736
+ wctx.use_buf(ctx0, 0);
1737
+
1738
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1739
+ layer.cross_attn_k_w,
1740
+ cur);
1741
+
1742
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1743
+
1744
+ wctx.use_buf(ctx0, 1);
1745
+
1746
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1747
+ layer.cross_attn_v_w,
1748
+ cur);
1749
+
1750
+ Vcross = ggml_add(ctx0,
1751
+ ggml_repeat(ctx0,
1752
+ layer.cross_attn_v_b,
1753
+ Vcross),
1754
+ Vcross);
1755
+
1756
+ wctx.use_buf(ctx0, -1);
1757
+
1758
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1759
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1760
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
1761
+ struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
1762
+
1763
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1764
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1765
+ }
1766
+
1767
+ ggml_graph_compute(ctx0, &gf);
1768
+ //ggml_graph_print(&gf);
1769
+ }
1770
+
1771
+ ////////////////////////////////////////////////////////////////////////////
1772
+
1773
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1774
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
1775
+ // wctx.get_buf_max_mem(0)/1024.0/1024.0,
1776
+ // wctx.get_buf_max_mem(1)/1024.0/1024.0,
1777
+ // wctx.get_buf_max_mem(2)/1024.0/1024.0,
1778
+ // wctx.get_buf_max_mem(3)/1024.0/1024.0);
1779
+
1780
+ ggml_free(ctx0);
1781
+
1782
+ wctx.t_encode_us += ggml_time_us() - t_start_us;
1783
+ wctx.n_encode++;
1784
+
1785
+ return true;
1786
+ }
1787
+
1788
+ // evaluate the decoder
1789
+ //
1790
+ // given text prompt + audio features -> computes the logits for the next token
1791
+ //
1792
+ //   - wctx:      the whisper context (holds the model)
+ //   - decoder:   the decoder state to use
1793
+ // - n_threads: number of threads to use
1794
+ // - tokens: text prompt
1795
+ // - n_tokens: number of tokens in the prompt
1796
+ // - n_past: number of past tokens to prefix the prompt with
1797
+ //
1798
+ static bool whisper_decode(
1799
+ whisper_context & wctx,
1800
+ whisper_decoder & decoder,
1801
+ const whisper_token * tokens,
1802
+ const int n_tokens,
1803
+ const int n_past,
1804
+ const int n_threads) {
1805
+ const int64_t t_start_us = ggml_time_us();
1806
+
1807
+ const auto & model = wctx.model;
1808
+ const auto & hparams = model.hparams;
1809
+
1810
+ auto & kv_self = decoder.kv_self;
1811
+
1812
+ WHISPER_ASSERT(!!kv_self.ctx);
1813
+
1814
+ auto & logits_out = wctx.logits;
1815
+
1816
+ const int n_vocab = hparams.n_vocab;
1817
+
1818
+ const int n_ctx = hparams.n_text_ctx;
1819
+ const int n_state = hparams.n_text_state;
1820
+ const int n_head = hparams.n_text_head;
1821
+ const int n_layer = hparams.n_text_layer;
1822
+
1823
+ const int N = n_tokens;
1824
+ const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
1825
+
1826
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1827
+
1828
+ struct ggml_init_params params;
1829
+ params.mem_size = wctx.buf_compute.size();
1830
+ params.mem_buffer = wctx.buf_compute.data();
1831
+
1832
+ struct ggml_context * ctx0 = ggml_init(params);
1833
+
1834
+ struct ggml_cgraph gf = {};
1835
+ gf.n_threads = n_threads;
1836
+
1837
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1838
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
1839
+
1840
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1841
+ for (int i = 0; i < N; ++i) {
1842
+ ((int32_t *) position->data)[i] = n_past + i;
1843
+ }
1844
+
1845
+ wctx.use_buf(ctx0, 3);
1846
+
1847
+ // token encoding + position encoding
1848
+ struct ggml_tensor * cur =
1849
+ ggml_add(ctx0,
1850
+ ggml_get_rows(ctx0, model.d_te, embd),
1851
+ ggml_get_rows(ctx0, model.d_pe, position));
1852
+
1853
+ struct ggml_tensor * inpL = cur;
1854
+
1855
+ for (int il = 0; il < n_layer; ++il) {
1856
+ const auto & layer = model.layers_decoder[il];
1857
+
1858
+ // norm
1859
+ {
1860
+ wctx.use_buf(ctx0, 0);
1861
+
1862
+ cur = ggml_norm(ctx0, inpL);
1863
+
1864
+ // cur = ln_0_w*cur + ln_0_b
1865
+ cur = ggml_add(ctx0,
1866
+ ggml_mul(ctx0,
1867
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1868
+ cur),
1869
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1870
+ }
1871
+
1872
+ // self-attention
1873
+ {
1874
+ wctx.use_buf(ctx0, 1);
1875
+
1876
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1877
+ layer.attn_q_w,
1878
+ cur);
1879
+
1880
+ Qcur = ggml_add(ctx0,
1881
+ ggml_repeat(ctx0,
1882
+ layer.attn_q_b,
1883
+ Qcur),
1884
+ Qcur);
1885
+
1886
+ Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1887
+
1888
+ // note: no bias for Key
1889
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1890
+ layer.attn_k_w,
1891
+ cur);
1892
+
1893
+ Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1894
+
1895
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1896
+ layer.attn_v_w,
1897
+ cur);
1898
+
1899
+ Vcur = ggml_add(ctx0,
1900
+ ggml_repeat(ctx0,
1901
+ layer.attn_v_b,
1902
+ Vcur),
1903
+ Vcur);
1904
+
1905
+ // store key and value to memory
1906
+ {
1907
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
1908
+ struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
1909
+
1910
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
1911
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1912
+ }
1913
+
1914
+ // ------
1915
+
1916
+ wctx.use_buf(ctx0, 0);
1917
+
1918
+ struct ggml_tensor * Q =
1919
+ ggml_permute(ctx0,
1920
+ ggml_cpy(ctx0,
1921
+ Qcur,
1922
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1923
+ 0, 2, 1, 3);
1924
+
1925
+ struct ggml_tensor * K =
1926
+ ggml_permute(ctx0,
1927
+ ggml_reshape_3d(ctx0,
1928
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
1929
+ n_state/n_head, n_head, n_past + N),
1930
+ 0, 2, 1, 3);
1931
+
1932
+ wctx.use_buf(ctx0, 1);
1933
+
1934
+ // K * Q
1935
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1936
+
1937
+ wctx.use_buf(ctx0, 0);
1938
+
1939
+ //struct ggml_tensor * KQ_scaled =
1940
+ // ggml_scale(ctx0,
1941
+ // KQ,
1942
+ // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1943
+ // );
1944
+
1945
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1946
+
1947
+ wctx.use_buf(ctx0, 1);
1948
+
1949
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1950
+
1951
+ wctx.use_buf(ctx0, 0);
1952
+
1953
+ struct ggml_tensor * V_trans =
1954
+ ggml_permute(ctx0,
1955
+ ggml_reshape_3d(ctx0,
1956
+ ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1957
+ n_state/n_head, n_head, n_past + N),
1958
+ 1, 2, 0, 3);
1959
+
1960
+ wctx.use_buf(ctx0, 1);
1961
+
1962
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1963
+
1964
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1965
+
1966
+ cur = ggml_cpy(ctx0,
1967
+ KQV_merged,
1968
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
1969
+ }
1970
+
1971
+ // projection
1972
+ {
1973
+ wctx.use_buf(ctx0, 0);
1974
+
1975
+ cur = ggml_mul_mat(ctx0,
1976
+ layer.attn_ln_1_w,
1977
+ cur);
1978
+
1979
+ wctx.use_buf(ctx0, 1);
1980
+
1981
+ cur = ggml_add(ctx0,
1982
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1983
+ cur);
1984
+ }
1985
+
1986
+ wctx.use_buf(ctx0, 2);
1987
+
1988
+ // add the input
1989
+ struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
1990
+
1991
+ // norm
1992
+ {
1993
+ wctx.use_buf(ctx0, 0);
1994
+
1995
+ cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1996
+
1997
+ wctx.use_buf(ctx0, 1);
1998
+
1999
+ // cur = ln_0_w*cur + ln_0_b
2000
+ cur = ggml_add(ctx0,
2001
+ ggml_mul(ctx0,
2002
+ ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
2003
+ cur),
2004
+ ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
2005
+ }
2006
+
2007
+ // cross-attention
2008
+ {
2009
+ wctx.use_buf(ctx0, 0);
2010
+
2011
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
2012
+ layer.cross_attn_q_w,
2013
+ cur);
2014
+
2015
+ Qcur = ggml_add(ctx0,
2016
+ ggml_repeat(ctx0,
2017
+ layer.cross_attn_q_b,
2018
+ Qcur),
2019
+ Qcur);
2020
+
2021
+ Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2022
+
2023
+ // Kcross is already scaled
2024
+ struct ggml_tensor * Kcross =
2025
+ ggml_reshape_3d(ctx0,
2026
+ ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
2027
+ n_state/n_head, n_head, M);
2028
+
2029
+ struct ggml_tensor * Vcross =
2030
+ ggml_reshape_3d(ctx0,
2031
+ ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
2032
+ n_state/n_head, n_head, M);
2033
+
2034
+ struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
2035
+
2036
+ // ------
2037
+
2038
+ wctx.use_buf(ctx0, 1);
2039
+
2040
+ struct ggml_tensor * Q =
2041
+ ggml_permute(ctx0,
2042
+ ggml_cpy(ctx0,
2043
+ Qcur,
2044
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
2045
+ 0, 2, 1, 3);
2046
+
2047
+ struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2048
+
2049
+ wctx.use_buf(ctx0, 0);
2050
+
2051
+ // K * Q
2052
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2053
+
2054
+ //struct ggml_tensor * KQ_scaled =
2055
+ // ggml_scale(ctx0,
2056
+ // KQ,
2057
+ // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2058
+ // );
2059
+
2060
+ // no masking for cross-attention
2061
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2062
+
2063
+ wctx.use_buf(ctx0, 1);
2064
+
2065
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2066
+
2067
+ wctx.use_buf(ctx0, 0);
2068
+
2069
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2070
+
2071
+ wctx.use_buf(ctx0, 1);
2072
+
2073
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2074
+
2075
+ // cur = KQV_merged.contiguous().view(n_state, N)
2076
+ cur = ggml_cpy(ctx0,
2077
+ KQV_merged,
2078
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
2079
+ }
2080
+
2081
+ // projection
2082
+ {
2083
+ wctx.use_buf(ctx0, 0);
2084
+
2085
+ cur = ggml_mul_mat(ctx0,
2086
+ layer.cross_attn_ln_1_w,
2087
+ cur);
2088
+
2089
+ wctx.use_buf(ctx0, 1);
2090
+
2091
+ cur = ggml_add(ctx0,
2092
+ ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2093
+ cur);
2094
+ }
2095
+
2096
+ wctx.use_buf(ctx0, 2);
2097
+
2098
+ // add the input
2099
+ cur = ggml_add(ctx0, cur, inpCA);
2100
+
2101
+ struct ggml_tensor * inpFF = cur;
2102
+
2103
+ // feed-forward network
2104
+ {
2105
+ // norm
2106
+ {
2107
+ wctx.use_buf(ctx0, 0);
2108
+
2109
+ cur = ggml_norm(ctx0, inpFF);
2110
+
2111
+ wctx.use_buf(ctx0, 1);
2112
+
2113
+ // cur = mlp_ln_w*cur + mlp_ln_b
2114
+ cur = ggml_add(ctx0,
2115
+ ggml_mul(ctx0,
2116
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
2117
+ cur),
2118
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2119
+ }
2120
+
2121
+ wctx.use_buf(ctx0, 0);
2122
+
2123
+ // fully connected
2124
+ cur = ggml_mul_mat(ctx0,
2125
+ layer.mlp_0_w,
2126
+ cur);
2127
+
2128
+ wctx.use_buf(ctx0, 1);
2129
+
2130
+ cur = ggml_add(ctx0,
2131
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
2132
+ cur);
2133
+
2134
+ wctx.use_buf(ctx0, 0);
2135
+
2136
+ // GELU activation
2137
+ cur = ggml_gelu(ctx0, cur);
2138
+
2139
+ wctx.use_buf(ctx0, 1);
2140
+
2141
+ // projection
2142
+ cur = ggml_mul_mat(ctx0,
2143
+ layer.mlp_1_w,
2144
+ cur);
2145
+
2146
+ wctx.use_buf(ctx0, 0);
2147
+
2148
+ cur = ggml_add(ctx0,
2149
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
2150
+ cur);
2151
+ }
2152
+
2153
+ wctx.use_buf(ctx0, 3);
2154
+
2155
+ inpL = ggml_add(ctx0, cur, inpFF);
2156
+ }
2157
+
2158
+ cur = inpL;
2159
+
2160
+ // norm
2161
+ {
2162
+ wctx.use_buf(ctx0, 0);
2163
+
2164
+ cur = ggml_norm(ctx0, cur);
2165
+
2166
+ wctx.use_buf(ctx0, 1);
2167
+
2168
+ cur = ggml_add(ctx0,
2169
+ ggml_mul(ctx0,
2170
+ ggml_repeat(ctx0, model.d_ln_w, cur),
2171
+ cur),
2172
+ ggml_repeat(ctx0, model.d_ln_b, cur));
2173
+ }
2174
+
2175
+ wctx.use_buf(ctx0, 0);
2176
+
2177
+ // compute logits only for the last token
2178
+ // comment this line to compute logits for all N tokens
2179
+ // might be useful in the future
2180
+ cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2181
+
2182
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2183
+
2184
+ wctx.use_buf(ctx0, -1);
2185
+
2186
+ // run the computation
2187
+ {
2188
+ ggml_build_forward_expand(&gf, logits);
2189
+ ggml_graph_compute (ctx0, &gf);
2190
+ }
2191
+
2192
+ // extract logits for all N tokens
2193
+ //logits_out.resize(N*n_vocab);
2194
+ //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
2195
+
2196
+ // extract logits only for the last token
2197
+ logits_out.resize(n_vocab);
2198
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2199
+
2200
+ if (N > 1) {
2201
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2202
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
2203
+ // wctx.get_buf_max_mem(0)/1024.0/1024.0,
2204
+ // wctx.get_buf_max_mem(1)/1024.0/1024.0,
2205
+ // wctx.get_buf_max_mem(2)/1024.0/1024.0,
2206
+ // wctx.get_buf_max_mem(3)/1024.0/1024.0);
2207
+ }
2208
+
2209
+ ggml_free(ctx0);
2210
+
2211
+ wctx.t_decode_us += ggml_time_us() - t_start_us;
2212
+ wctx.n_decode++;
2213
+
2214
+ return true;
2215
+ }
2216
+
2217
+ // 500 -> 00:00:05.000
2218
+ // 6000 -> 00:01:00.000
2219
+ static std::string to_timestamp(int64_t t, bool comma = false) {
2220
+ int64_t msec = t * 10;
2221
+ int64_t hr = msec / (1000 * 60 * 60);
2222
+ msec = msec - hr * (1000 * 60 * 60);
2223
+ int64_t min = msec / (1000 * 60);
2224
+ msec = msec - min * (1000 * 60);
2225
+ int64_t sec = msec / 1000;
2226
+ msec = msec - sec * 1000;
2227
+
2228
+ char buf[32];
2229
+ snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
2230
+
2231
+ return std::string(buf);
2232
+ }
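+ // for example, under the 10 ms/unit convention above (illustrative, not part of the API):
+ //   to_timestamp(500)        -> "00:00:05.000"
+ //   to_timestamp(6000, true) -> "00:01:00,000"   (comma form, e.g. for SRT-style output)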
2233
+
2234
+ // naive Discrete Fourier Transform
2235
+ // input is real-valued
2236
+ // output is complex-valued
2237
+ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2238
+ int N = in.size();
2239
+
2240
+ out.resize(N*2);
2241
+
2242
+ for (int k = 0; k < N; k++) {
2243
+ float re = 0;
2244
+ float im = 0;
2245
+
2246
+ for (int n = 0; n < N; n++) {
2247
+ float angle = 2*M_PI*k*n/N;
2248
+ re += in[n]*cos(angle);
2249
+ im -= in[n]*sin(angle);
2250
+ }
2251
+
2252
+ out[k*2 + 0] = re;
2253
+ out[k*2 + 1] = im;
2254
+ }
2255
+ }
2256
+
2257
+ // Cooley-Tukey FFT
2258
+ // poor man's implementation - use something better
2259
+ // input is real-valued
2260
+ // output is complex-valued
2261
+ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2262
+ out.resize(in.size()*2);
2263
+
2264
+ int N = in.size();
2265
+
2266
+ if (N == 1) {
2267
+ out[0] = in[0];
2268
+ out[1] = 0;
2269
+ return;
2270
+ }
2271
+
2272
+ if (N%2 == 1) {
2273
+ dft(in, out);
2274
+ return;
2275
+ }
2276
+
2277
+ std::vector<float> even;
2278
+ std::vector<float> odd;
2279
+
2280
+ even.reserve(N/2);
2281
+ odd.reserve(N/2);
2282
+
2283
+ for (int i = 0; i < N; i++) {
2284
+ if (i % 2 == 0) {
2285
+ even.push_back(in[i]);
2286
+ } else {
2287
+ odd.push_back(in[i]);
2288
+ }
2289
+ }
2290
+
2291
+ std::vector<float> even_fft;
2292
+ std::vector<float> odd_fft;
2293
+
2294
+ fft(even, even_fft);
2295
+ fft(odd, odd_fft);
2296
+
2297
+ for (int k = 0; k < N/2; k++) {
2298
+ float theta = 2*M_PI*k/N;
2299
+
2300
+ float re = cos(theta);
2301
+ float im = -sin(theta);
2302
+
2303
+ float re_odd = odd_fft[2*k + 0];
2304
+ float im_odd = odd_fft[2*k + 1];
2305
+
2306
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2307
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2308
+
2309
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2310
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2311
+ }
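+ // example (illustrative only): for in = { 1, 1, 1, 1 } the output is the
+ // interleaved complex spectrum { re0, im0, re1, im1, ... }, here
+ // { 4, 0,  0, 0,  0, 0,  0, 0 } - all of the energy lands in the DC bin, as expected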
2312
+ }
2313
+
2314
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2315
+ static bool log_mel_spectrogram(
2316
+ whisper_context & wctx,
2317
+ const float * samples,
2318
+ const int n_samples,
2319
+ const int /*sample_rate*/,
2320
+ const int fft_size,
2321
+ const int fft_step,
2322
+ const int n_mel,
2323
+ const int n_threads,
2324
+ const whisper_filters & filters,
2325
+ const bool speed_up,
2326
+ whisper_mel & mel) {
2327
+ const int64_t t_start_us = ggml_time_us();
2328
+
2329
+     // Hann window
2330
+ std::vector<float> hann;
2331
+ hann.resize(fft_size);
2332
+ for (int i = 0; i < fft_size; i++) {
2333
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2334
+ }
2335
+
2336
+ mel.n_mel = n_mel;
2337
+ mel.n_len = (n_samples)/fft_step;
2338
+ mel.data.resize(mel.n_mel*mel.n_len);
2339
+
2340
+ const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
2341
+
2342
+ //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2343
+ //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2344
+
2345
+ std::vector<std::thread> workers(n_threads);
2346
+ for (int iw = 0; iw < n_threads; ++iw) {
2347
+ workers[iw] = std::thread([&](int ith) {
2348
+ std::vector<float> fft_in;
2349
+ fft_in.resize(fft_size);
2350
+ for (int i = 0; i < fft_size; i++) {
2351
+ fft_in[i] = 0.0;
2352
+ }
2353
+
2354
+ std::vector<float> fft_out;
2355
+ fft_out.resize(2*fft_size);
2356
+
2357
+ for (int i = ith; i < mel.n_len; i += n_threads) {
2358
+ const int offset = i*fft_step;
2359
+
2360
+ // apply Hanning window
2361
+ for (int j = 0; j < fft_size; j++) {
2362
+ if (offset + j < n_samples) {
2363
+ fft_in[j] = hann[j]*samples[offset + j];
2364
+ } else {
2365
+ fft_in[j] = 0.0;
2366
+ }
2367
+ }
2368
+
2369
+ // FFT -> mag^2
2370
+ fft(fft_in, fft_out);
2371
+
2372
+ for (int j = 0; j < fft_size; j++) {
2373
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2374
+ }
2375
+ for (int j = 1; j < fft_size/2; j++) {
2376
+ //if (i == 0) {
2377
+ // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2378
+ //}
2379
+ fft_out[j] += fft_out[fft_size - j];
2380
+ }
2381
+ if (i == 0) {
2382
+ //for (int j = 0; j < fft_size; j++) {
2383
+ // printf("%d: %e\n", j, fft_out[j]);
2384
+ //}
2385
+ }
2386
+
2387
+ if (speed_up) {
2388
+                 // scaling down in the frequency domain results in a speed-up in the time domain
2389
+ for (int j = 0; j < n_fft; j++) {
2390
+ fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
2391
+ }
2392
+ }
2393
+
2394
+ // mel spectrogram
2395
+ for (int j = 0; j < mel.n_mel; j++) {
2396
+ double sum = 0.0;
2397
+
2398
+ for (int k = 0; k < n_fft; k++) {
2399
+ sum += fft_out[k]*filters.data[j*n_fft + k];
2400
+ }
2401
+ if (sum < 1e-10) {
2402
+ sum = 1e-10;
2403
+ }
2404
+
2405
+ sum = log10(sum);
2406
+
2407
+ mel.data[j*mel.n_len + i] = sum;
2408
+ }
2409
+ }
2410
+ }, iw);
2411
+ }
2412
+
2413
+ for (int iw = 0; iw < n_threads; ++iw) {
2414
+ workers[iw].join();
2415
+ }
2416
+
2417
+ // clamping and normalization
2418
+ double mmax = -1e20;
2419
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2420
+ if (mel.data[i] > mmax) {
2421
+ mmax = mel.data[i];
2422
+ }
2423
+ }
2424
+ //printf("%s: max = %f\n", __func__, mmax);
2425
+
2426
+ mmax -= 8.0;
2427
+
2428
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2429
+ if (mel.data[i] < mmax) {
2430
+ mel.data[i] = mmax;
2431
+ }
2432
+
2433
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
2434
+ }
2435
+
2436
+ wctx.t_mel_us += ggml_time_us() - t_start_us;
2437
+
2438
+ return true;
2439
+ }
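+ // normalization note (a sketch of the arithmetic above): after log10, every value
+ // is clamped to [max - 8, max] (an 8-order-of-magnitude dynamic range) and then
+ // mapped by (x + 4)/4; e.g. a bin at the clamp floor with max = 0 becomes
+ // (-8 + 4)/4 = -1.0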
2440
+
2441
+ // split text into tokens
2442
+ //
2443
+ // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
2444
+ //
2445
+ // Regex (Python):
2446
+ // r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
2447
+ //
2448
+ // Regex (C++):
2449
+ // R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
2450
+ //
2451
+ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
2452
+ std::vector<std::string> words;
2453
+
2454
+ // first split the text into words
2455
+ {
2456
+ std::string str = text;
2457
+ std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
2458
+
2459
+ std::regex re(pat);
2460
+ std::smatch m;
2461
+
2462
+ while (std::regex_search(str, m, re)) {
2463
+ for (auto x : m) {
2464
+ words.push_back(x);
2465
+ }
2466
+ str = m.suffix();
2467
+ }
2468
+ }
2469
+
2470
+ // find the longest tokens that form the words:
2471
+ std::vector<whisper_vocab::id> tokens;
2472
+ for (const auto & word : words) {
2473
+ if (word.empty()) continue;
2474
+
2475
+ int i = 0;
2476
+ int n = word.size();
2477
+ while (i < n) {
2478
+ int j = n;
2479
+ while (j > i) {
2480
+ auto it = vocab.token_to_id.find(word.substr(i, j-i));
2481
+ if (it != vocab.token_to_id.end()) {
2482
+ tokens.push_back(it->second);
2483
+ i = j;
2484
+ break;
2485
+ }
2486
+ --j;
2487
+ }
2488
+ if (i == n) {
2489
+ break;
2490
+ }
2491
+ if (j == i) {
2492
+ auto sub = word.substr(i, 1);
2493
+ if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
2494
+ tokens.push_back(vocab.token_to_id.at(sub));
2495
+ } else {
2496
+ fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
2497
+ }
2498
+ ++i;
2499
+ }
2500
+ }
2501
+ }
2502
+
2503
+ return tokens;
2504
+ }
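+ // example (hypothetical vocabulary, for illustration only): given the word " whisper"
+ // and a vocabulary containing " whis" and "per" but not " whisper", the greedy
+ // longest-match loop above emits the two tokens " whis" and "per"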
2505
+
2506
+ //
2507
+ // interface implementation
2508
+ //
2509
+
2510
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
2511
+ whisper_model_loader loader = {};
2512
+
2513
+ fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
2514
+
2515
+ auto fin = std::ifstream(path_model, std::ios::binary);
2516
+ if (!fin) {
2517
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
2518
+ return nullptr;
2519
+ }
2520
+
2521
+ loader.context = &fin;
2522
+ loader.read = [](void * ctx, void * output, size_t read_size) {
2523
+ std::ifstream * fin = (std::ifstream*)ctx;
2524
+ fin->read((char *)output, read_size);
2525
+ return read_size;
2526
+ };
2527
+
2528
+ loader.eof = [](void * ctx) {
2529
+ std::ifstream * fin = (std::ifstream*)ctx;
2530
+ return fin->eof();
2531
+ };
2532
+
2533
+ loader.close = [](void * ctx) {
2534
+ std::ifstream * fin = (std::ifstream*)ctx;
2535
+ fin->close();
2536
+ };
2537
+
2538
+ return whisper_init(&loader);
2539
+ }
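+ // the loader callbacks above make it easy to plug in other sources; a minimal
+ // sketch (not part of the API - the name and the FILE* source are assumptions)
+ // reading through a C FILE* could look like this:
+ #if 0
+ struct whisper_context * whisper_init_from_FILE(FILE * f) {
+     whisper_model_loader loader = {};
+     loader.context = f;
+     loader.read = [](void * ctx, void * output, size_t read_size) {
+         // fread returns the number of bytes actually read, matching the contract above
+         return fread(output, 1, read_size, (FILE *) ctx);
+     };
+     loader.eof = [](void * ctx) {
+         return feof((FILE *) ctx) != 0;
+     };
+     loader.close = [](void * ctx) {
+         fclose((FILE *) ctx);
+     };
+     return whisper_init(&loader);
+ }
+ #endif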
2540
+
2541
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2542
+ struct buf_context {
2543
+ uint8_t* buffer;
2544
+ size_t size;
2545
+ size_t current_offset;
2546
+ };
2547
+
2548
+ buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
2549
+ whisper_model_loader loader = {};
2550
+
2551
+ fprintf(stderr, "%s: loading model from buffer\n", __func__);
2552
+
2553
+ loader.context = &ctx;
2554
+
2555
+ loader.read = [](void * ctx, void * output, size_t read_size) {
2556
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2557
+
2558
+ size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
2559
+
2560
+ memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
2561
+ buf->current_offset += size_to_copy;
2562
+
2563
+ return size_to_copy;
2564
+ };
2565
+
2566
+ loader.eof = [](void * ctx) {
2567
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2568
+
2569
+ return buf->current_offset >= buf->size;
2570
+ };
2571
+
2572
+ loader.close = [](void * /*ctx*/) { };
2573
+
2574
+ return whisper_init(&loader);
2575
+ }
2576
+
2577
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2578
+ ggml_time_init();
2579
+
2580
+ whisper_context * ctx = new whisper_context;
2581
+
2582
+ if (!whisper_model_load(loader, *ctx)) {
2583
+ loader->close(loader->context);
2584
+ fprintf(stderr, "%s: failed to load model\n", __func__);
2585
+ delete ctx;
2586
+ return nullptr;
2587
+ }
2588
+
2589
+ loader->close(loader->context);
2590
+
2591
+ return ctx;
2592
+ }
2593
+
2594
+ void whisper_free(struct whisper_context * ctx) {
2595
+ if (ctx) {
2596
+ if (ctx->model.ctx) {
2597
+ ggml_free(ctx->model.ctx);
2598
+ }
2599
+ if (ctx->model.buf) {
2600
+ delete ctx->model.buf;
2601
+ }
2602
+ if (ctx->kv_cross.ctx) {
2603
+ ggml_free(ctx->kv_cross.ctx);
2604
+ }
2605
+ for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
2606
+ if (ctx->decoders[i].kv_self.ctx) {
2607
+ ggml_free(ctx->decoders[i].kv_self.ctx);
2608
+ }
2609
+ }
2610
+ delete ctx;
2611
+ }
2612
+ }
2613
+
2614
+ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2615
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
2616
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2617
+ return -1;
2618
+ }
2619
+
2620
+ return 0;
2621
+ }
2622
+
2623
+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2624
+ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2625
+ if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
2626
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2627
+ return -1;
2628
+ }
2629
+
2630
+ return 0;
2631
+ }
2632
+
2633
+ int whisper_set_mel(
2634
+ struct whisper_context * ctx,
2635
+ const float * data,
2636
+ int n_len,
2637
+ int n_mel) {
2638
+ if (n_mel != WHISPER_N_MEL) {
2639
+ fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2640
+ return -1;
2641
+ }
2642
+
2643
+ ctx->mel.n_len = n_len;
2644
+ ctx->mel.n_mel = n_mel;
2645
+
2646
+ ctx->mel.data.resize(n_len*n_mel);
2647
+ memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
2648
+
2649
+ return 0;
2650
+ }
2651
+
2652
+ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2653
+ if (!whisper_encode(*ctx, offset, n_threads)) {
2654
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2655
+ return -1;
2656
+ }
2657
+
2658
+ return 0;
2659
+ }
2660
+
2661
+ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2662
+ // TODO: add selected_decoder_id to context
2663
+ const int selected_decoder_id = 0;
2664
+
2665
+ if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2666
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2667
+ return 1;
2668
+ }
2669
+
2670
+ return 0;
2671
+ }
2672
+
2673
+ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
2674
+ const auto res = tokenize(ctx->vocab, text);
2675
+
2676
+ if (n_max_tokens < (int) res.size()) {
2677
+ fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
2678
+ return -1;
2679
+ }
2680
+
2681
+ for (int i = 0; i < (int) res.size(); i++) {
2682
+ tokens[i] = res[i];
2683
+ }
2684
+
2685
+ return res.size();
2686
+ }
2687
+
2688
+ int whisper_lang_max_id() {
2689
+ auto max_id = 0;
2690
+ for (const auto & kv : g_lang) {
2691
+ max_id = std::max(max_id, kv.second.first);
2692
+ }
2693
+
2694
+ return max_id;
2695
+ }
2696
+
2697
+ int whisper_lang_id(const char * lang) {
2698
+ if (!g_lang.count(lang)) {
2699
+ for (const auto & kv : g_lang) {
2700
+ if (kv.second.second == lang) {
2701
+ return kv.second.first;
2702
+ }
2703
+ }
2704
+
2705
+ fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2706
+ return -1;
2707
+ }
2708
+
2709
+ return g_lang.at(lang).first;
2710
+ }
2711
+
2712
+ const char * whisper_lang_str(int id) {
2713
+ for (const auto & kv : g_lang) {
2714
+ if (kv.second.first == id) {
2715
+ return kv.first.c_str();
2716
+ }
2717
+ }
2718
+
2719
+ fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
2720
+ return nullptr;
2721
+ }
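+ // for example: whisper_lang_id("en") and whisper_lang_id("english") both return 0,
+ // and whisper_lang_str(0) returns "en"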
2722
+
2723
+ int whisper_lang_auto_detect(
2724
+ struct whisper_context * ctx,
2725
+ int offset_ms,
2726
+ int n_threads,
2727
+ float * lang_probs) {
2728
+ const int seek = offset_ms/10;
2729
+
2730
+ if (seek < 0) {
2731
+ fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
2732
+ return -1;
2733
+ }
2734
+
2735
+ if (seek >= ctx->mel.n_len) {
2736
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
2737
+ return -2;
2738
+ }
2739
+
2740
+ // run the encoder
2741
+ if (whisper_encode(ctx, seek, n_threads) != 0) {
2742
+ fprintf(stderr, "%s: failed to encode\n", __func__);
2743
+ return -6;
2744
+ }
2745
+
2746
+ const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
2747
+
2748
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2749
+ fprintf(stderr, "%s: failed to decode\n", __func__);
2750
+ return -7;
2751
+ }
2752
+
2753
+ auto & logits_id = ctx->logits_id;
2754
+ logits_id.clear();
2755
+
2756
+ for (const auto & kv : g_lang) {
2757
+ const auto token_lang = whisper_token_lang(ctx, kv.second.first);
2758
+ logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
2759
+ }
2760
+
2761
+ // sort descending
2762
+ {
2763
+ using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
2764
+ std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
2765
+ return a.first > b.first;
2766
+ });
2767
+ }
2768
+
2769
+ // softmax
2770
+ {
2771
+ const auto max = logits_id[0].first;
2772
+
2773
+ double sum = 0.0f;
2774
+ for (auto & kv : logits_id) {
2775
+ kv.first = exp(kv.first - max);
2776
+ sum += kv.first;
2777
+ }
2778
+
2779
+ for (auto & kv : logits_id) {
2780
+ kv.first /= sum;
2781
+ }
2782
+ }
2783
+
2784
+ {
2785
+ for (const auto & prob : logits_id) {
2786
+ if (lang_probs) {
2787
+ lang_probs[prob.second] = prob.first;
2788
+ }
2789
+
2790
+ //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
2791
+ }
2792
+ }
2793
+
2794
+ return logits_id[0].second;
2795
+ }
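+ // a minimal usage sketch (assumes a loaded context `ctx` whose mel spectrogram
+ // has already been computed, e.g. via whisper_pcm_to_mel):
+ #if 0
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+ const int lang_id = whisper_lang_auto_detect(ctx, 0, 4, probs.data());
+ if (lang_id >= 0) {
+     printf("detected '%s' with p = %f\n", whisper_lang_str(lang_id), probs[lang_id]);
+ }
+ #endif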
2796
+
2797
+ int whisper_n_len(struct whisper_context * ctx) {
2798
+ return ctx->mel.n_len;
2799
+ }
2800
+
2801
+ int whisper_n_vocab(struct whisper_context * ctx) {
2802
+ return ctx->vocab.n_vocab;
2803
+ }
2804
+
2805
+ int whisper_n_text_ctx(struct whisper_context * ctx) {
2806
+ return ctx->model.hparams.n_text_ctx;
2807
+ }
2808
+
2809
+ int whisper_n_audio_ctx(struct whisper_context * ctx) {
2810
+ return ctx->model.hparams.n_audio_ctx;
2811
+ }
2812
+
2813
+ int whisper_is_multilingual(struct whisper_context * ctx) {
2814
+ return ctx->vocab.is_multilingual() ? 1 : 0;
2815
+ }
2816
+
2817
+ float * whisper_get_logits(struct whisper_context * ctx) {
2818
+ return ctx->logits.data();
2819
+ }
2820
+
2821
+ const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
2822
+ return ctx->vocab.id_to_token.at(token).c_str();
2823
+ }
2824
+
2825
+ whisper_token whisper_token_eot(struct whisper_context * ctx) {
2826
+ return ctx->vocab.token_eot;
2827
+ }
2828
+
2829
+ whisper_token whisper_token_sot(struct whisper_context * ctx) {
2830
+ return ctx->vocab.token_sot;
2831
+ }
2832
+
2833
+ whisper_token whisper_token_prev(struct whisper_context * ctx) {
2834
+ return ctx->vocab.token_prev;
2835
+ }
2836
+
2837
+ whisper_token whisper_token_solm(struct whisper_context * ctx) {
2838
+ return ctx->vocab.token_solm;
2839
+ }
2840
+
2841
+ whisper_token whisper_token_not(struct whisper_context * ctx) {
2842
+ return ctx->vocab.token_not;
2843
+ }
2844
+
2845
+ whisper_token whisper_token_beg(struct whisper_context * ctx) {
2846
+ return ctx->vocab.token_beg;
2847
+ }
2848
+
2849
+ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
2850
+ return whisper_token_sot(ctx) + 1 + lang_id;
2851
+ }
2852
+
2853
+ whisper_token whisper_token_translate(void) {
2854
+ return whisper_vocab::token_translate;
2855
+ }
2856
+
2857
+ whisper_token whisper_token_transcribe(void) {
2858
+ return whisper_vocab::token_transcribe;
2859
+ }
2860
+
2861
+ void whisper_print_timings(struct whisper_context * ctx) {
2862
+ const int64_t t_end_us = ggml_time_us();
2863
+
2864
+ const int32_t n_sample = std::max(1, ctx->n_sample);
2865
+ const int32_t n_encode = std::max(1, ctx->n_encode);
2866
+ const int32_t n_decode = std::max(1, ctx->n_decode);
2867
+
2868
+ fprintf(stderr, "\n");
2869
+ fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
2870
+ fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
2871
+ fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
2872
+ fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
2873
+ fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
2874
+ fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
2875
+ fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2876
+ }
2877
+
2878
+ void whisper_reset_timings(struct whisper_context * ctx) {
2879
+     ctx->t_mel_us    = 0;
+     ctx->t_sample_us = 0;
2880
+     ctx->t_encode_us = 0;
2881
+     ctx->t_decode_us = 0;
2882
+ }
2883
+
2884
+ const char * whisper_print_system_info(void) {
2885
+ static std::string s;
2886
+
2887
+ s = "";
2888
+ s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
2889
+ s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
2890
+ s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
2891
+ s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
2892
+ s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
2893
+ s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
2894
+ s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
2895
+ s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
2896
+ s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
2897
+ s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
2898
+ s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
2899
+ s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
2900
+
2901
+ return s.c_str();
2902
+ }
2903
+
2904
+ ////////////////////////////////////////////////////////////////////////////
2905
+
2906
+ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
2907
+ struct whisper_full_params result = {
2908
+ /*.strategy =*/ strategy,
2909
+
2910
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
2911
+ /*.n_max_text_ctx =*/ 16384,
2912
+ /*.offset_ms =*/ 0,
2913
+ /*.duration_ms =*/ 0,
2914
+
2915
+ /*.translate =*/ false,
2916
+ /*.no_context =*/ false,
2917
+ /*.single_segment =*/ false,
2918
+ /*.print_special =*/ false,
2919
+ /*.print_progress =*/ true,
2920
+ /*.print_realtime =*/ false,
2921
+ /*.print_timestamps =*/ true,
2922
+
2923
+ /*.token_timestamps =*/ false,
2924
+ /*.thold_pt =*/ 0.01f,
2925
+ /*.thold_ptsum =*/ 0.01f,
2926
+ /*.max_len =*/ 0,
2927
+ /*.split_on_word =*/ false,
2928
+ /*.max_tokens =*/ 0,
2929
+
2930
+ /*.speed_up =*/ false,
2931
+ /*.audio_ctx =*/ 0,
2932
+
2933
+ /*.prompt_tokens =*/ nullptr,
2934
+ /*.prompt_n_tokens =*/ 0,
2935
+
2936
+ /*.language =*/ "en",
2937
+
2938
+ /*.suppress_blank =*/ true,
2939
+ /*.suppress_non_speech_tokens =*/ false,
2940
+
2941
+ /*.temperature =*/ 0.0f,
2942
+ /*.max_initial_ts =*/ 1.0f,
2943
+ /*.length_penalty =*/ -1.0f,
2944
+
2945
+ /*.temperature_inc =*/ 0.2f,
2946
+ /*.entropy_thold =*/ 2.4f,
2947
+ /*.logprob_thold =*/ -1.0f,
2948
+ /*.no_speech_thold =*/ 0.6f,
2949
+
2950
+ /*.greedy =*/ {
2951
+ /*.best_of =*/ -1,
2952
+ },
2953
+
2954
+ /*.beam_search =*/ {
2955
+ /*.beam_size =*/ -1,
2956
+
2957
+ /*.patience =*/ -1.0f,
2958
+ },
2959
+
2960
+ /*.new_segment_callback =*/ nullptr,
2961
+ /*.new_segment_callback_user_data =*/ nullptr,
2962
+
2963
+ /*.encoder_begin_callback =*/ nullptr,
2964
+ /*.encoder_begin_callback_user_data =*/ nullptr,
2965
+
2966
+ /*.logits_filter_callback =*/ nullptr,
2967
+ /*.logits_filter_callback_user_data =*/ nullptr,
2968
+ };
2969
+
2970
+ switch (strategy) {
2971
+ case WHISPER_SAMPLING_GREEDY:
2972
+ {
2973
+ result.greedy = {
2974
+ /*.best_of =*/ 1,
2975
+ };
2976
+ } break;
2977
+ case WHISPER_SAMPLING_BEAM_SEARCH:
2978
+ {
2979
+ result.beam_search = {
2980
+ /*.beam_size =*/ 5,
2981
+
2982
+ /*.patience =*/ -1.0f,
2983
+ };
2984
+ } break;
2985
+ }
2986
+
2987
+ return result;
2988
+ }
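+ // typical usage (a sketch - `pcmf32` is an assumed name for the caller's audio buffer):
+ // take the defaults and override only what you need
+ #if 0
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+ wparams.n_threads = 8;
+ wparams.language  = "auto"; // triggers language auto-detection in whisper_full
+ whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()); // pcmf32: 16 kHz mono float PCM
+ #endif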
2989
+
2990
+ // forward declarations
2991
+ static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
2992
+ static void whisper_exp_compute_token_level_timestamps(
2993
+ struct whisper_context & ctx,
2994
+ int i_segment,
2995
+ float thold_pt,
2996
+ float thold_ptsum);
2997
+
2998
+ // trim from start (in place)
2999
+ static inline void ltrim(std::string &s) {
3000
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
3001
+ return !std::isspace(ch);
3002
+ }));
3003
+ }
3004
+
3005
+ // trim from end (in place)
3006
+ static inline void rtrim(std::string &s) {
3007
+ s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
3008
+ return !std::isspace(ch);
3009
+ }).base(), s.end());
3010
+ }
3011
+
3012
+ // trim from both ends (in place)
3013
+ static inline void trim(std::string &s) {
3014
+ rtrim(s);
3015
+ ltrim(s);
3016
+ }
3017
+
3018
+ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3019
+ if (!split_on_word) return true;
3020
+
3021
+ return txt[0] == ' ';
3022
+ }
3023
+
3024
+ // wrap the last segment to max_len characters
3025
+ // returns the number of new segments
3026
+ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
3027
+ auto segment = ctx.result_all.back();
3028
+
3029
+ int res = 1;
3030
+ int acc = 0;
3031
+
3032
+ std::string text;
3033
+
3034
+ for (int i = 0; i < (int) segment.tokens.size(); i++) {
3035
+ const auto & token = segment.tokens[i];
3036
+ if (token.id >= whisper_token_eot(&ctx)) {
3037
+ continue;
3038
+ }
3039
+
3040
+ const auto txt = whisper_token_to_str(&ctx, token.id);
3041
+ const int cur = strlen(txt);
3042
+
3043
+ if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
3044
+ // split here
3045
+ if (split_on_word) {
3046
+ trim(text);
3047
+ }
3048
+
3049
+ ctx.result_all.back().text = std::move(text);
3050
+ ctx.result_all.back().t1 = token.t0;
3051
+ ctx.result_all.back().tokens.resize(i);
3052
+
3053
+ ctx.result_all.push_back({});
3054
+ ctx.result_all.back().t0 = token.t0;
3055
+ ctx.result_all.back().t1 = segment.t1;
3056
+
3057
+ // add tokens [i, end] to the new segment
3058
+ ctx.result_all.back().tokens.insert(
3059
+ ctx.result_all.back().tokens.end(),
3060
+ segment.tokens.begin() + i,
3061
+ segment.tokens.end());
3062
+
3063
+ acc = 0;
3064
+ text = "";
3065
+
3066
+ segment = ctx.result_all.back();
3067
+ i = -1;
3068
+
3069
+ res++;
3070
+ } else {
3071
+ acc += cur;
3072
+ text += txt;
3073
+ }
3074
+ }
3075
+
3076
+ if (split_on_word) {
3077
+ trim(text);
3078
+ }
3079
+ ctx.result_all.back().text = std::move(text);
3080
+
3081
+ return res;
3082
+ }
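+ // worked example (illustrative): with max_len = 10, split_on_word = true and a
+ // segment made of the tokens " hello", " world", " again" (6 chars each), the
+ // loop above splits before each word that would push the count past 10,
+ // yielding the three segments "hello", "world", "again"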
3083
+
3084
+ static const std::vector<std::string> non_speech_tokens = {
3085
+ "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
3086
+ "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
3087
+ "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
3088
+ "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
3089
+ };
3090
+
3091
+ // process the logits for the selected decoder
3092
+ // - applies logit filters
3093
+ // - computes logprobs and probs
3094
+ static void whisper_process_logits(
3095
+ struct whisper_context & ctx,
3096
+ const struct whisper_full_params params,
3097
+ struct whisper_decoder & decoder,
3098
+ float temperature) {
3099
+ const auto & vocab = ctx.vocab;
3100
+ const auto & tokens_cur = decoder.sequence.tokens;
3101
+
3102
+ const bool is_initial = tokens_cur.size() == 0;
3103
+ const int n_logits = vocab.id_to_token.size();
3104
+
3105
+ WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
3106
+
3107
+ // extract the logits for the last token
3108
+ // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3109
+ auto & probs = decoder.probs;
3110
+ auto & logits = decoder.logits;
3111
+ auto & logprobs = decoder.logprobs;
3112
+ {
3113
+ logits.resize(n_logits);
3114
+ memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
3115
+
3116
+ if (temperature > 0.0f) {
3117
+ for (int i = 0; i < n_logits; i++) {
3118
+ logits[i] /= temperature;
3119
+ }
3120
+ }
3121
+
3122
+ // will be populated a bit later
3123
+ probs.resize(n_logits);
3124
+ logprobs.resize(n_logits);
3125
+ }
3126
+
3127
+ // apply logit filters here
3128
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
3129
+ {
3130
+ // suppress blank
3131
+ // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
3132
+ if (params.suppress_blank) {
3133
+ if (is_initial) {
3134
+ logits[vocab.token_eot] = -INFINITY;
3135
+ logits[vocab.token_to_id.at(" ")] = -INFINITY;
3136
+ }
3137
+ }
3138
+
3139
+ // suppress <|notimestamps|> token
3140
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
3141
+ logits[vocab.token_not] = -INFINITY;
3142
+
3143
+ // suppress sot and solm tokens
3144
+ logits[vocab.token_sot] = -INFINITY;
3145
+ logits[vocab.token_solm] = -INFINITY;
3146
+
3147
+ // suppress task tokens
3148
+ logits[vocab.token_translate] = -INFINITY;
3149
+ logits[vocab.token_transcribe] = -INFINITY;
3150
+
3151
+ if (params.logits_filter_callback) {
3152
+ params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3153
+ }
3154
+
3155
+ // suppress non-speech tokens
3156
+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3157
+ if (params.suppress_non_speech_tokens) {
3158
+ for (const std::string & token : non_speech_tokens) {
3159
+ const std::string suppress_tokens[] = {token, " " + token};
3160
+ for (const std::string & suppress_token : suppress_tokens) {
3161
+ if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
3162
+ logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
3163
+ }
3164
+ }
3165
+ }
3166
+
3167
+ // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3168
+ if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
3169
+ logits[vocab.token_to_id.at(" -")] = -INFINITY;
3170
+ }
3171
+ if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
3172
+ logits[vocab.token_to_id.at(" '")] = -INFINITY;
3173
+ }
3174
+ }
3175
+
3176
+ // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
3177
+ // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
3178
+ {
3179
+ const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
3180
+ const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
3181
+
3182
+ //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
3183
+
3184
+ if (last_was_timestamp) {
3185
+ if (penultimate_was_timestamp) {
3186
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
3187
+ logits[i] = -INFINITY;
3188
+ }
3189
+ } else {
3190
+ for (int i = 0; i < vocab.token_eot; ++i) {
3191
+ logits[i] = -INFINITY;
3192
+ }
3193
+ }
3194
+ }
3195
+ }
3196
+
3197
+ // the initial timestamp cannot be larger than max_initial_ts
3198
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
3199
+ if (is_initial && params.max_initial_ts > 0.0f) {
3200
+ const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
3201
+ const int tid0 = std::round(params.max_initial_ts/precision);
3202
+
3203
+ for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
3204
+ logits[i] = -INFINITY;
3205
+ }
3206
+ }
3207
+
3208
+ // condition timestamp tokens to be increasing
3209
+ // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
3210
+ if (decoder.has_ts) {
3211
+ const int tid0 = decoder.seek_delta/2;
3212
+
3213
+ for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
3214
+ logits[i] = -INFINITY;
3215
+ }
3216
+ }
3217
+
3218
+ // populate the logprobs array (log_softmax)
3219
+ {
3220
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
3221
+ float logsumexp = 0.0f;
3222
+ for (int i = 0; i < n_logits; ++i) {
3223
+ if (logits[i] > -INFINITY) {
3224
+ logsumexp += expf(logits[i] - logit_max);
3225
+ }
3226
+ }
3227
+ logsumexp = logf(logsumexp) + logit_max;
3228
+
3229
+ for (int i = 0; i < n_logits; ++i) {
3230
+ if (logits[i] > -INFINITY) {
3231
+ logprobs[i] = logits[i] - logsumexp;
3232
+ } else {
3233
+ logprobs[i] = -INFINITY;
3234
+ }
3235
+ }
3236
+ }
3237
+
3238
+ // if sum of probability over timestamps is above any other token, sample timestamp
3239
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
3240
+ {
3241
+ // logsumexp over timestamps
3242
+ float timestamp_logprob = -INFINITY;
3243
+ {
3244
+ float logsumexp = 0.0f;
3245
+ const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
3246
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
3247
+ if (logprobs[i] > -INFINITY) {
3248
+ logsumexp += expf(logprobs[i] - logprob_max);
3249
+ }
3250
+ }
3251
+ if (logsumexp > 0.0f) {
3252
+ timestamp_logprob = logf(logsumexp) + logprob_max;
3253
+ }
3254
+ }
3255
+
3256
+ const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
3257
+
3258
+ //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
3259
+
3260
+ if (timestamp_logprob > max_text_token_logprob) {
3261
+ for (int i = 0; i < vocab.token_beg; ++i) {
3262
+ logits[i] = -INFINITY;
3263
+ logprobs[i] = -INFINITY;
3264
+ }
3265
+ }
3266
+ }
3267
+ }
3268
+
3269
+ // compute probs
3270
+ {
3271
+ for (int i = 0; i < n_logits; ++i) {
3272
+ if (logits[i] == -INFINITY) {
3273
+ probs[i] = 0.0f;
3274
+ } else {
3275
+ probs[i] = expf(logprobs[i]);
3276
+ }
3277
+ }
3278
+ }
3279
+
3280
+ #if 0
3281
+ // print first 100 logits - token string : logit
3282
+ for (int i = 0; i < 100; i++) {
3283
+ const auto token = vocab.id_to_token.at(i);
3284
+ const auto prob = probs[i];
3285
+ const auto logit = logits[i];
3286
+ const auto logprob = logprobs[i];
3287
+ printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
3288
+ }
3289
+
3290
+ // "And", "and", " And", " and"
3291
+ printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
3292
+ printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
3293
+ printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
3294
+ printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
3295
+ printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
3296
+
3297
+ printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
3298
+ printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
3299
+ printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
3300
+ printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
3301
+ printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
3302
+
3303
+ printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
3304
+ printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
3305
+ printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
3306
+ printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
3307
+ printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
3308
+ #endif
3309
+ }
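+ // worked example of the timestamp-pair rule above (illustrative): if the last
+ // sampled token was a timestamp but the one before it was text, all text tokens
+ // are masked so the decoder must emit the closing timestamp (or EOT); if the
+ // last two tokens were both timestamps, all timestamp tokens are masked instead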
3310
+
3311
+ static whisper_token_data whisper_sample_token(
3312
+ whisper_context & ctx,
3313
+ const whisper_decoder & decoder,
3314
+ bool best) {
3315
+ whisper_token_data result = {
3316
+ 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
3317
+ };
3318
+
3319
+ const auto & vocab = ctx.vocab;
3320
+
3321
+ const auto & probs = decoder.probs;
3322
+ const auto & logprobs = decoder.logprobs;
3323
+
3324
+ const int n_logits = vocab.n_vocab;
3325
+
3326
+ {
3327
+ double sum_ts = 0.0;
3328
+ double max_ts = 0.0;
3329
+
3330
+ for (int i = vocab.token_beg; i < n_logits; i++) {
3331
+ if (probs[i] == -INFINITY) {
3332
+ continue;
3333
+ }
3334
+
3335
+ sum_ts += probs[i];
3336
+ if (max_ts < probs[i]) {
3337
+ max_ts = probs[i];
3338
+ result.tid = i;
3339
+ }
3340
+ }
3341
+
3342
+ result.pt = max_ts/(sum_ts + 1e-10);
3343
+ result.ptsum = sum_ts;
3344
+ }
3345
+
3346
+ if (best) {
3347
+ for (int i = 0; i < n_logits; ++i) {
3348
+ if (result.p < probs[i]) {
3349
+ result.id = i;
3350
+ result.p = probs[i];
3351
+ result.plog = logprobs[i];
3352
+ }
3353
+ }
3354
+ } else {
3355
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
3356
+
3357
+ result.id = dist(ctx.rng);
3358
+ result.p = probs[result.id];
3359
+ result.plog = logprobs[result.id];
3360
+ }
3361
+
3362
+ if (result.id >= vocab.token_beg) {
3363
+ result.tid = result.id;
3364
+ result.pt = result.p;
3365
+ }
3366
+
3367
+ ctx.n_sample++;
3368
+
3369
+ return result;
3370
+ }
3371
+
3372
+ static std::vector<whisper_token_data> whisper_sample_token_topk(
3373
+ whisper_context & ctx,
3374
+ const whisper_decoder & decoder,
3375
+ int k) {
3376
+ const auto & vocab = ctx.vocab;
3377
+
3378
+ const auto & probs = decoder.probs;
3379
+ const auto & logits = decoder.logits;
3380
+ const auto & logprobs = decoder.logprobs;
3381
+
3382
+ const int n_logits = vocab.n_vocab;
3383
+
3384
+ auto & logits_id = ctx.logits_id;
3385
+
3386
+ logits_id.clear();
3387
+ for (int i = 0; i < n_logits; ++i) {
3388
+ logits_id.push_back({ logits[i], i });
3389
+ }
3390
+
3391
+ std::partial_sort(
3392
+ logits_id.begin(),
3393
+ logits_id.begin() + k, logits_id.end(),
3394
+ [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
3395
+ return a.first > b.first;
3396
+ });
3397
+
3398
+ std::vector<whisper_token_data> result;
3399
+ result.reserve(k);
3400
+
3401
+ whisper_token tid = vocab.token_beg;
3402
+
3403
+ float pt = 0.0;
3404
+ float ptsum = 0.0;
3405
+
3406
+ {
3407
+ double sum_ts = 0.0;
3408
+ double max_ts = 0.0;
3409
+
3410
+ for (int i = vocab.token_beg; i < n_logits; i++) {
3411
+ if (probs[i] == -INFINITY) {
3412
+ continue;
3413
+ }
3414
+
3415
+ sum_ts += probs[i];
3416
+ if (max_ts < probs[i]) {
3417
+ max_ts = probs[i];
3418
+ tid = i;
3419
+ }
3420
+ }
3421
+
3422
+ pt = max_ts/(sum_ts + 1e-10);
3423
+ ptsum = sum_ts;
3424
+ }
3425
+
3426
+ for (int i = 0; i < k; ++i) {
3427
+ const auto id = logits_id[i].second;
3428
+
3429
+ result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
3430
+
3431
+ if (result[i].id >= vocab.token_beg) {
3432
+ result[i].tid = result[i].id;
3433
+ result[i].pt = result[i].p;
3434
+ }
3435
+ }
3436
+
3437
+ ctx.n_sample++;
3438
+
3439
+ return result;
3440
+ }
3441
+
3442
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
3443
+ static void whisper_sequence_score(
3444
+ const struct whisper_full_params & params,
3445
+ whisper_sequence & sequence) {
3446
+ if (sequence.result_len == 0) {
3447
+ return;
3448
+ }
3449
+
3450
+ double result = 0.0f;
3451
+
3452
+ for (int i = 0; i < sequence.result_len; ++i) {
3453
+ result += sequence.tokens[i].plog;
3454
+ }
3455
+
3456
+ sequence.sum_logprobs = result;
3457
+ sequence.avg_logprobs = result/sequence.result_len;
3458
+
3459
+ double penalty = sequence.result_len;
3460
+
3461
+ if (params.length_penalty > 0.0f) {
3462
+ penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
3463
+ }
3464
+
3465
+ sequence.score = result/penalty;
3466
+
3467
+ // compute the entropy of the sequence of the last 32 tokens
3468
+ {
3469
+ const int n = 32;
3470
+
3471
+ int cnt = 0;
3472
+ double entropy = 0.0f;
3473
+
3474
+ std::map<whisper_token, int> token_counts;
3475
+ for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
3476
+ token_counts[sequence.tokens[i].id]++;
3477
+ cnt++;
3478
+ }
3479
+
3480
+ for (const auto & kv : token_counts) {
3481
+ const auto p = kv.second/(double)cnt;
3482
+ entropy -= p*log(p);
3483
+
3484
+ //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
3485
+ }
3486
+
3487
+ sequence.entropy = entropy;
3488
+ }
3489
+ }
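To make the scoring concrete, here is a standalone sketch with toy numbers of the two quantities computed above: the length-penalized score (using the ((5 + L)/6)^alpha penalty from the referenced OpenAI decoder) and the entropy of the token histogram that is later used to detect repetition loops:

#include <cmath>
#include <cstdio>
#include <map>

int main() {
    const double sum_logprobs = -12.0;
    const int    result_len   = 10;
    const double alpha        = 1.0; // corresponds to params.length_penalty

    const double penalty = std::pow((5.0 + result_len)/6.0, alpha);
    printf("score = %.2f\n", sum_logprobs/penalty); // -12.0/2.5 = -4.80

    // a fully repetitive tail has zero entropy: one token id repeated 4 times
    std::map<int, int> token_counts = { { 42, 4 } };
    const int cnt = 4;

    double entropy = 0.0;
    for (const auto & kv : token_counts) {
        const double p = kv.second/(double) cnt;
        entropy -= p*std::log(p); // p == 1 contributes 0
    }
    printf("entropy = %.2f\n", entropy); // 0.00 -> flagged if below entropy_thold

    return 0;
}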
3490
+
3491
+ int whisper_full(
3492
+ struct whisper_context * ctx,
3493
+ struct whisper_full_params params,
3494
+ const float * samples,
3495
+ int n_samples) {
3496
+ // clear old results
3497
+ auto & result_all = ctx->result_all;
3498
+
3499
+ result_all.clear();
3500
+
3501
+ // compute log mel spectrogram
3502
+ if (params.speed_up) {
3503
+ if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
3504
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3505
+ return -1;
3506
+ }
3507
+ } else {
3508
+ if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
3509
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3510
+ return -2;
3511
+ }
3512
+ }
3513
+
3514
+ // auto-detect language if not specified
3515
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3516
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3517
+
3518
+ const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
3519
+ if (lang_id < 0) {
3520
+ fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
3521
+ return -3;
3522
+ }
3523
+ ctx->lang_id = lang_id;
3524
+ params.language = whisper_lang_str(lang_id);
3525
+
3526
+ fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3527
+ }
3528
+
3529
+ if (params.token_timestamps) {
3530
+ ctx->t_beg = 0;
3531
+ ctx->t_last = 0;
3532
+ ctx->tid_last = 0;
3533
+ ctx->energy = get_signal_energy(samples, n_samples, 32);
3534
+ }
3535
+
3536
+ const int seek_start = params.offset_ms/10;
3537
+ const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
3538
+
3539
+ // if length of spectrogram is less than 1s (100 mel frames, 10 ms each), then return
3540
+ // basically don't process anything that is less than 1s
3541
+ // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
3542
+ if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
3543
+ return 0;
3544
+ }
3545
+
3546
+ // a set of temperatures to use
3547
+ // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
3548
+ std::vector<float> temperatures;
3549
+ if (params.temperature_inc > 0.0f) {
3550
+ for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
3551
+ temperatures.push_back(t);
3552
+ }
3553
+ } else {
3554
+ temperatures.push_back(params.temperature);
3555
+ }
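+ // example: temperature = 0.0f with temperature_inc = 0.2f yields the
+ // fallback ladder { 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 }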
3556
+
3557
+ // initialize the decoders
3558
+ int n_decoders = 1;
3559
+
3560
+ switch (params.strategy) {
3561
+ case WHISPER_SAMPLING_GREEDY:
3562
+ {
3563
+ n_decoders = params.greedy.best_of;
3564
+ } break;
3565
+ case WHISPER_SAMPLING_BEAM_SEARCH:
3566
+ {
3567
+ n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
3568
+ } break;
3569
+ };
3570
+
3571
+ n_decoders = std::max(1, n_decoders);
3572
+
3573
+ // TAGS: WHISPER_DECODER_INIT
3574
+ for (int j = 1; j < n_decoders; j++) {
3575
+ auto & decoder = ctx->decoders[j];
3576
+
3577
+ if (decoder.kv_self.ctx == nullptr) {
3578
+ decoder.kv_self = ctx->decoders[0].kv_self;
3579
+ if (!kv_cache_reinit(decoder.kv_self)) {
3580
+ fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
3581
+ return -4;
3582
+ }
3583
+
3584
+ WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
3585
+
3586
+ decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
3587
+
3588
+ decoder.probs.resize (ctx->vocab.n_vocab);
3589
+ decoder.logits.resize (ctx->vocab.n_vocab);
3590
+ decoder.logprobs.resize(ctx->vocab.n_vocab);
3591
+ }
3592
+ }
3593
+
3594
+ // the accumulated text context so far
3595
+ auto & prompt_past = ctx->prompt_past;
3596
+ if (params.no_context) {
3597
+ prompt_past.clear();
3598
+ }
3599
+
3600
+ // prepend the prompt tokens to the prompt_past
3601
+ if (params.prompt_tokens && params.prompt_n_tokens > 0) {
3602
+ // parse tokens from the pointer
3603
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
3604
+ prompt_past.push_back(params.prompt_tokens[i]);
3605
+ }
3606
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
3607
+ }
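+ // example: with prompt_past = { a, b } and prompt_tokens = { x, y }, the
+ // loop appends to give { a, b, x, y } and the rotate moves the new tokens
+ // to the front, yielding { x, y, a, b }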
3608
+
3609
+ // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
3610
+ if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
3611
+ fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
3612
+ return -5;
3613
+ }
3614
+ ctx->exp_n_audio_ctx = params.audio_ctx;
3615
+
3616
+ // these tokens determine the task that will be performed
3617
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
3618
+ if (whisper_is_multilingual(ctx)) {
3619
+ const int lang_id = whisper_lang_id(params.language);
3620
+ ctx->lang_id = lang_id;
3621
+ prompt_init.push_back(whisper_token_lang(ctx, lang_id));
3622
+ if (params.translate) {
3623
+ prompt_init.push_back(whisper_token_translate());
3624
+ } else {
3625
+ prompt_init.push_back(whisper_token_transcribe());
3626
+ }
3627
+ }
3628
+
3629
+ int progress_prev = 0;
3630
+ int progress_step = 5;
3631
+
3632
+ int seek = seek_start;
3633
+
3634
+ std::vector<whisper_token> prompt;
3635
+ prompt.reserve(whisper_n_text_ctx(ctx));
3636
+
3637
+ // beam-search helpers
3638
+ struct kv_buf {
3639
+ std::vector<uint8_t> k;
3640
+ std::vector<uint8_t> v;
3641
+ };
3642
+
3643
+ std::vector<kv_buf> kv_bufs;
3644
+
3645
+ struct beam_candidate {
3646
+ int decoder_idx;
3647
+ int seek_delta;
3648
+
3649
+ bool has_ts;
3650
+
3651
+ whisper_sequence sequence;
3652
+ };
3653
+
3654
+ std::vector<beam_candidate> beam_candidates;
3655
+
3656
+ // main loop
3657
+ while (true) {
3658
+ const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
3659
+ while (progress_cur >= progress_prev + progress_step) {
3660
+ progress_prev += progress_step;
3661
+ if (params.print_progress) {
3662
+ fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
3663
+ }
3664
+ }
3665
+
3666
+ // if only 1 second or less is left to process, then stop
3667
+ if (seek + 100 >= seek_end) {
3668
+ break;
3669
+ }
3670
+
3671
+ if (params.encoder_begin_callback) {
3672
+ if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
3673
+ fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
3674
+ break;
3675
+ }
3676
+ }
3677
+
3678
+ // encode audio features starting at offset seek
3679
+ if (!whisper_encode(*ctx, seek, params.n_threads)) {
3680
+ fprintf(stderr, "%s: failed to encode\n", __func__);
3681
+ return -6;
3682
+ }
3683
+
3684
+ // if there is a very short audio segment left to process, we remove any past prompt since it tends
3685
+ // to confuse the decoder and often makes it repeat or hallucinate
3686
+ if (seek > seek_start && seek + 500 >= seek_end) {
3687
+ prompt_past.clear();
3688
+ }
3689
+
3690
+ int best_decoder_id = 0;
3691
+
3692
+ for (int it = 0; it < (int) temperatures.size(); ++it) {
3693
+ const float t_cur = temperatures[it];
3694
+
3695
+ int n_decoders_cur = 1;
3696
+
3697
+ switch (params.strategy) {
3698
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3699
+ {
3700
+ if (t_cur > 0.0f) {
3701
+ n_decoders_cur = params.greedy.best_of;
3702
+ }
3703
+ } break;
3704
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3705
+ {
3706
+ if (t_cur > 0.0f) {
3707
+ n_decoders_cur = params.greedy.best_of;
3708
+ } else {
3709
+ n_decoders_cur = params.beam_search.beam_size;
3710
+ }
3711
+ } break;
3712
+ };
3713
+
3714
+ n_decoders_cur = std::max(1, n_decoders_cur);
3715
+
3716
+ WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
3717
+
3718
+ // TAGS: WHISPER_DECODER_INIT
3719
+ for (int j = 0; j < n_decoders_cur; ++j) {
3720
+ auto & decoder = ctx->decoders[j];
3721
+
3722
+ decoder.kv_self.n = 0;
3723
+
3724
+ decoder.sequence.tokens.clear();
3725
+ decoder.sequence.result_len = 0;
3726
+ decoder.sequence.sum_logprobs_all = 0.0;
3727
+ decoder.sequence.sum_logprobs = -INFINITY;
3728
+ decoder.sequence.avg_logprobs = -INFINITY;
3729
+ decoder.sequence.entropy = 0.0;
3730
+ decoder.sequence.score = -INFINITY;
3731
+
3732
+ decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
3733
+
3734
+ decoder.failed = false;
3735
+ decoder.completed = false;
3736
+ decoder.has_ts = false;
3737
+ }
3738
+
3739
+ // init prompt and kv cache for the current iteration
3740
+ // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
3741
+ {
3742
+ prompt.clear();
3743
+
3744
+ // if we have already generated some text, use it as a prompt to condition the next generation
3745
+ if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
3746
+ int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
3747
+
3748
+ prompt = { whisper_token_prev(ctx) };
3749
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
3750
+ }
3751
+
3752
+ // init new transcription with sot, language (opt) and task tokens
3753
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
3754
+
3755
+ // print the prompt
3756
+ WHISPER_PRINT_DEBUG("\n\n");
3757
+ for (int i = 0; i < (int) prompt.size(); i++) {
3758
+ WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
3759
+ }
3760
+ WHISPER_PRINT_DEBUG("\n\n");
3761
+
3762
+ if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3763
+ fprintf(stderr, "%s: failed to decode\n", __func__);
3764
+ return -7;
3765
+ }
3766
+
3767
+ {
3768
+ const int64_t t_start_sample_us = ggml_time_us();
3769
+
3770
+ whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
3771
+
3772
+ ctx->decoders[0].kv_self.n += prompt.size();
3773
+
3774
+ for (int j = 1; j < n_decoders_cur; ++j) {
3775
+ auto & decoder = ctx->decoders[j];
3776
+
3777
+ memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
3778
+ memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
3779
+
3780
+ decoder.kv_self.n += prompt.size();
3781
+
3782
+ memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
3783
+ memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
3784
+ memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
3785
+ }
3786
+
3787
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
3788
+ }
3789
+ }
3790
+
3791
+ for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
3792
+ const int64_t t_start_sample_us = ggml_time_us();
3793
+
3794
+ // store the KV caches of all decoders when doing beam-search
3795
+ if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3796
+ kv_bufs.resize(n_decoders_cur);
3797
+ for (int j = 0; j < n_decoders_cur; ++j) {
3798
+ auto & decoder = ctx->decoders[j];
3799
+
3800
+ if (decoder.completed || decoder.failed) {
3801
+ continue;
3802
+ }
3803
+
3804
+ kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
3805
+ kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
3806
+
3807
+ memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
3808
+ memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
3809
+ }
3810
+
3811
+ beam_candidates.clear();
3812
+ }
3813
+
3814
+ // generate new sequence candidates for each decoder
3815
+ for (int j = 0; j < n_decoders_cur; ++j) {
3816
+ auto & decoder = ctx->decoders[j];
3817
+
3818
+ if (decoder.completed || decoder.failed) {
3819
+ continue;
3820
+ }
3821
+
3822
+ switch (params.strategy) {
3823
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3824
+ {
3825
+ if (t_cur < 1e-6f) {
3826
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
3827
+ } else {
3828
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
3829
+ }
3830
+
3831
+ decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
3832
+ } break;
3833
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3834
+ {
3835
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
3836
+
3837
+ for (const auto & token : tokens_new) {
3838
+ beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
3839
+ beam_candidates.back().sequence.tokens.push_back(token);
3840
+ beam_candidates.back().sequence.sum_logprobs_all += token.plog;
3841
+
3842
+ //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
3843
+ }
3844
+ } break;
3845
+ };
3846
+ }
3847
+
3848
+ // for beam-search, choose the top candidates and update the KV caches
3849
+ if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3850
+ std::sort(
3851
+ beam_candidates.begin(),
3852
+ beam_candidates.end(),
3853
+ [](const beam_candidate & a, const beam_candidate & b) {
3854
+ return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
3855
+ });
3856
+
3857
+ uint32_t cur_c = 0;
3858
+
3859
+ for (int j = 0; j < n_decoders_cur; ++j) {
3860
+ auto & decoder = ctx->decoders[j];
3861
+
3862
+ if (decoder.completed || decoder.failed) {
3863
+ continue;
3864
+ }
3865
+
3866
+ auto & cur = beam_candidates[cur_c++];
3867
+
3868
+ while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
3869
+ ++cur_c;
3870
+ }
3871
+
3872
+ decoder.sequence = cur.sequence;
3873
+ decoder.seek_delta = cur.seek_delta;
3874
+ decoder.has_ts = cur.has_ts;
3875
+
3876
+ memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
3877
+ memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
3878
+
3879
+ WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
3880
+ __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
3881
+ }
3882
+ }
3883
+
3884
+ // update the decoder state
3885
+ // - check if the sequence is completed
3886
+ // - check if the sequence is failed
3887
+ // - update sliding window based on timestamp tokens
3888
+ for (int j = 0; j < n_decoders_cur; ++j) {
3889
+ auto & decoder = ctx->decoders[j];
3890
+
3891
+ if (decoder.completed || decoder.failed) {
3892
+ continue;
3893
+ }
3894
+
3895
+ auto & has_ts = decoder.has_ts;
3896
+ auto & failed = decoder.failed;
3897
+ auto & completed = decoder.completed;
3898
+ auto & seek_delta = decoder.seek_delta;
3899
+ auto & result_len = decoder.sequence.result_len;
3900
+
3901
+ {
3902
+ const auto & token = decoder.sequence.tokens.back();
3903
+
3904
+ // timestamp token - update sliding window
3905
+ if (token.id > whisper_token_beg(ctx)) {
3906
+ const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
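+ // (timestamp tokens are spaced 20 ms apart, while seek counts 10 ms
+ // frames - hence the factor of 2)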
3907
+
3908
+ // do not allow to go back in time
3909
+ if (has_ts && seek_delta > seek_delta_new && result_len < i) {
3910
+ failed = true; // TODO: maybe this is not a failure ?
3911
+ continue;
3912
+ }
3913
+
3914
+ seek_delta = seek_delta_new;
3915
+ result_len = i + 1;
3916
+ has_ts = true;
3917
+ }
3918
+
3919
+ #ifdef WHISPER_DEBUG
3920
+ {
3921
+ const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
3922
+ WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
3923
+ __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
3924
+ }
3925
+ #endif
3926
+
3927
+ // end of segment
3928
+ if (token.id == whisper_token_eot(ctx) || // end of text token
3929
+ (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
3930
+ (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
3931
+ ) {
3932
+ if (result_len == 0) {
3933
+ if (seek + seek_delta + 100 >= seek_end) {
3934
+ result_len = i + 1;
3935
+ } else {
3936
+ failed = true;
3937
+ continue;
3938
+ }
3939
+ }
3940
+
3941
+ if (params.single_segment) {
3942
+ result_len = i + 1;
3943
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
3944
+ }
3945
+
3946
+ completed = true;
3947
+ continue;
3948
+ }
3949
+
3950
+ // TESTS: if no tensors are loaded, it means we are running tests
3951
+ if (ctx->model.n_loaded == 0) {
3952
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
3953
+ completed = true;
3954
+ continue;
3955
+ }
3956
+ }
3957
+
3958
+ // sometimes, the decoding can get stuck in a repetition loop
3959
+ // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
3960
+ if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
3961
+ failed = true;
3962
+ continue;
3963
+ }
3964
+ }
3965
+
3966
+ // check if all decoders have finished (i.e. completed or failed)
3967
+ {
3968
+ bool completed_all = true;
3969
+
3970
+ for (int j = 0; j < n_decoders_cur; ++j) {
3971
+ auto & decoder = ctx->decoders[j];
3972
+
3973
+ if (decoder.completed || decoder.failed) {
3974
+ continue;
3975
+ }
3976
+
3977
+ completed_all = false;
3978
+ }
3979
+
3980
+ if (completed_all) {
3981
+ break;
3982
+ }
3983
+ }
3984
+
3985
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
3986
+
3987
+ // obtain logits for the next token
3988
+ for (int j = 0; j < n_decoders_cur; ++j) {
3989
+ auto & decoder = ctx->decoders[j];
3990
+
3991
+ if (decoder.failed || decoder.completed) {
3992
+ continue;
3993
+ }
3994
+
3995
+ decoder.tokens_tmp.resize(1);
3996
+ decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
3997
+
3998
+ //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
3999
+
4000
+ if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4001
+ fprintf(stderr, "%s: failed to decode\n", __func__);
4002
+ return -8;
4003
+ }
4004
+
4005
+ {
4006
+ const int64_t t_start_sample_us = ggml_time_us();
4007
+
4008
+ whisper_process_logits(*ctx, params, decoder, t_cur);
4009
+
4010
+ ++decoder.kv_self.n;
4011
+
4012
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
4013
+ }
4014
+ }
4015
+ }
4016
+
4017
+ // rank the resulting sequences and select the best one
4018
+ {
4019
+ double best_score = -INFINITY;
4020
+
4021
+ for (int j = 0; j < n_decoders_cur; ++j) {
4022
+ auto & decoder = ctx->decoders[j];
4023
+
4024
+ if (decoder.failed) {
4025
+ continue;
4026
+ }
4027
+
4028
+ decoder.sequence.tokens.resize(decoder.sequence.result_len);
4029
+ whisper_sequence_score(params, decoder.sequence);
4030
+
4031
+ WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
4032
+ __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
4033
+
4034
+ if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
4035
+ WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
4036
+ __func__, j, decoder.sequence.entropy, params.entropy_thold);
4037
+
4038
+ decoder.failed = true;
4039
+ ctx->n_fail_h++;
4040
+
4041
+ continue;
4042
+ }
4043
+
4044
+ if (best_score < decoder.sequence.score) {
4045
+ best_score = decoder.sequence.score;
4046
+ best_decoder_id = j;
4047
+ }
4048
+ }
4049
+
4050
+ WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
4051
+ }
4052
+
4053
+ // was the decoding successful for the current temperature?
4054
+ {
4055
+ bool success = true;
4056
+
4057
+ const auto & decoder = ctx->decoders[best_decoder_id];
4058
+
4059
+ if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
4060
+ success = false;
4061
+ ctx->n_fail_p++;
4062
+ }
4063
+
4064
+ if (success) {
4065
+ //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
4066
+ // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
4067
+ //}
4068
+
4069
+ break;
4070
+ }
4071
+ }
4072
+
4073
+ WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
4074
+ }
4075
+
4076
+ // output results through a user-provided callback
4077
+ {
4078
+ const auto & best_decoder = ctx->decoders[best_decoder_id];
4079
+
4080
+ const auto seek_delta = best_decoder.seek_delta;
4081
+ const auto result_len = best_decoder.sequence.result_len;
4082
+
4083
+ const auto & tokens_cur = best_decoder.sequence.tokens;
4084
+
4085
+ //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
4086
+
4087
+ // update prompt_past
4088
+ prompt_past.clear();
4089
+ if (prompt.front() == whisper_token_prev(ctx)) {
4090
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
4091
+ }
4092
+
4093
+ for (int i = 0; i < result_len; ++i) {
4094
+ prompt_past.push_back(tokens_cur[i].id);
4095
+ }
4096
+
4097
+ // store the text from this iteration
4098
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
4099
+ int i0 = 0;
4100
+ auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
4101
+
4102
+ std::string text;
4103
+
4104
+ for (int i = 0; i < (int) tokens_cur.size(); i++) {
4105
+ //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
4106
+ // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
4107
+ // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
4108
+
4109
+ if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
4110
+ } else {
4111
+ text += whisper_token_to_str(ctx, tokens_cur[i].id);
4112
+ }
4113
+
4114
+ if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
4115
+ const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
4116
+
4117
+ if (!text.empty()) {
4118
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
4119
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
4120
+
4121
+ if (params.print_realtime) {
4122
+ if (params.print_timestamps) {
4123
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
4124
+ } else {
4125
+ printf("%s", text.c_str());
4126
+ fflush(stdout);
4127
+ }
4128
+ }
4129
+
4130
+ //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
4131
+
4132
+ result_all.push_back({ tt0, tt1, text, {} });
4133
+ for (int j = i0; j <= i; j++) {
4134
+ result_all.back().tokens.push_back(tokens_cur[j]);
4135
+ }
4136
+
4137
+ int n_new = 1;
4138
+
4139
+ if (params.token_timestamps) {
4140
+ whisper_exp_compute_token_level_timestamps(
4141
+ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4142
+
4143
+ if (params.max_len > 0) {
4144
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4145
+ }
4146
+ }
4147
+ if (params.new_segment_callback) {
4148
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4149
+ }
4150
+ }
4151
+ text = "";
4152
+ while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
4153
+ i++;
4154
+ }
4155
+ i--;
4156
+ t0 = t1;
4157
+ i0 = i + 1;
4158
+ }
4159
+ }
4160
+
4161
+ if (!text.empty()) {
4162
+ const auto t1 = seek + seek_delta;
4163
+
4164
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
4165
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
4166
+
4167
+ if (params.print_realtime) {
4168
+ if (params.print_timestamps) {
4169
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
4170
+ } else {
4171
+ printf("%s", text.c_str());
4172
+ fflush(stdout);
4173
+ }
4174
+ }
4175
+
4176
+ result_all.push_back({ tt0, tt1, text, {} });
4177
+ for (int j = i0; j < (int) tokens_cur.size(); j++) {
4178
+ result_all.back().tokens.push_back(tokens_cur[j]);
4179
+ }
4180
+
4181
+ int n_new = 1;
4182
+
4183
+ if (params.token_timestamps) {
4184
+ whisper_exp_compute_token_level_timestamps(
4185
+ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4186
+
4187
+ if (params.max_len > 0) {
4188
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4189
+ }
4190
+ }
4191
+ if (params.new_segment_callback) {
4192
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4193
+ }
4194
+ }
4195
+ }
4196
+
4197
+ // update audio window
4198
+ seek += seek_delta;
4199
+
4200
+ WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
4201
+ }
4202
+ }
4203
+
4204
+ return 0;
4205
+ }
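For orientation, a hypothetical caller of whisper_full() together with the segment accessors defined further below; the model path and the 16 kHz mono PCM buffer are placeholders, while whisper_init(), whisper_full_default_params() and whisper_free() are the entry points declared in whisper.h:

#include "whisper.h"

#include <cstdio>
#include <vector>

int run_transcription(const std::vector<float> & pcm) { // 16 kHz mono samples
    struct whisper_context * wctx = whisper_init("model.bin"); // placeholder path
    if (wctx == nullptr) {
        return 1;
    }

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    if (whisper_full(wctx, wparams, pcm.data(), (int) pcm.size()) != 0) {
        whisper_free(wctx);
        return 2;
    }

    // segment timestamps are in units of 10 ms
    const int n_segments = whisper_full_n_segments(wctx);
    for (int i = 0; i < n_segments; ++i) {
        printf("[%lld --> %lld] %s\n",
            (long long) whisper_full_get_segment_t0(wctx, i),
            (long long) whisper_full_get_segment_t1(wctx, i),
            whisper_full_get_segment_text(wctx, i));
    }

    whisper_free(wctx);
    return 0;
}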
4206
+
4207
+ int whisper_full_parallel(
4208
+ struct whisper_context * ctx,
4209
+ struct whisper_full_params params,
4210
+ const float * samples,
4211
+ int n_samples,
4212
+ int n_processors) {
4213
+ if (n_processors == 1) {
4214
+ return whisper_full(ctx, params, samples, n_samples);
4215
+ }
4216
+
4217
+ int ret = 0;
4218
+
4219
+ // prepare separate contexts for each thread
4220
+ std::vector<struct whisper_context> ctxs(n_processors - 1);
4221
+
4222
+ for (int i = 0; i < n_processors - 1; ++i) {
4223
+ auto & ctx_p = ctxs[i];
4224
+
4225
+ ctx_p = *ctx;
4226
+
4227
+ ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
4228
+
4229
+ ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
4230
+
4231
+ if (!kv_cache_reinit(ctx_p.kv_cross)) {
4232
+ fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
4233
+ return -1;
4234
+ }
4235
+
4236
+ // TAGS: WHISPER_DECODER_INIT
4237
+ for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4238
+ if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
4239
+ fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
4240
+ return -1;
4241
+ }
4242
+
4243
+ ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
4244
+
4245
+ ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
4246
+ ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
4247
+ ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
4248
+ }
4249
+ }
4250
+
4251
+ const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
4252
+ const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
4253
+
4254
+ // the calling thread will process the first chunk
4255
+ // while the other threads will process the remaining chunks
4256
+
4257
+ std::vector<std::thread> workers(n_processors - 1);
4258
+ for (int i = 0; i < n_processors - 1; ++i) {
4259
+ const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
4260
+ const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
4261
+
4262
+ auto params_cur = params;
4263
+
4264
+ params_cur.offset_ms = 0;
4265
+ params_cur.print_progress = false;
4266
+ params_cur.print_realtime = false;
4267
+
4268
+ params_cur.new_segment_callback = nullptr;
4269
+ params_cur.new_segment_callback_user_data = nullptr;
4270
+
4271
+ workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4272
+ }
4273
+
4274
+ {
4275
+ auto params_cur = params;
4276
+
4277
+ ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
4278
+ }
4279
+
4280
+ for (int i = 0; i < n_processors - 1; ++i) {
4281
+ workers[i].join();
4282
+ }
4283
+
4284
+ const int64_t offset_t = (int64_t) params.offset_ms/10.0;
4285
+
4286
+ // combine results into ctx->result_all
4287
+ for (int i = 0; i < n_processors - 1; ++i) {
4288
+ auto & results_i = ctxs[i].result_all;
4289
+
4290
+ for (auto & result : results_i) {
4291
+ // correct the segment timestamp taking into account the offset
4292
+ result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4293
+ result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4294
+
4295
+ // make sure that segments are not overlapping
4296
+ if (!ctx->result_all.empty()) {
4297
+ result.t0 = std::max(result.t0, ctx->result_all.back().t1);
4298
+ }
4299
+
4300
+ ctx->result_all.push_back(std::move(result));
4301
+
4302
+ // call the new_segment_callback for each segment
4303
+ if (params.new_segment_callback) {
4304
+ params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
4305
+ }
4306
+ }
4307
+
4308
+ ctx->t_mel_us += ctxs[i].t_mel_us;
4309
+ ctx->t_sample_us += ctxs[i].t_sample_us;
4310
+ ctx->t_encode_us += ctxs[i].t_encode_us;
4311
+ ctx->t_decode_us += ctxs[i].t_decode_us;
4312
+
4313
+ kv_cache_free(ctx->kv_cross);
4314
+
4315
+ for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4316
+ kv_cache_free(ctx->decoders[j].kv_self);
4317
+ }
4318
+ }
4319
+
4320
+ // average the timings
4321
+ ctx->t_mel_us /= n_processors;
4322
+ ctx->t_sample_us /= n_processors;
4323
+ ctx->t_encode_us /= n_processors;
4324
+ ctx->t_decode_us /= n_processors;
4325
+
4326
+ // print information about the audio boundaries
4327
+ fprintf(stderr, "\n");
4328
+ fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
4329
+ for (int i = 0; i < n_processors - 1; ++i) {
4330
+ fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
4331
+ }
4332
+ fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
4333
+
4334
+ return ret;
4335
+ }
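As a concrete check of the timestamp correction above: with 60 s of 16 kHz audio (960000 samples), no offset and n_processors = 2, each processor handles 480000 samples and the second chunk's segments are shifted by 100*480000/16000 = 3000, i.e. 30 s in the 10 ms units used for t0 and t1.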
4336
+
4337
+ int whisper_full_n_segments(struct whisper_context * ctx) {
4338
+ return ctx->result_all.size();
4339
+ }
4340
+
4341
+ int whisper_full_lang_id(struct whisper_context * ctx) {
4342
+ return ctx->lang_id;
4343
+ }
4344
+
4345
+ int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
4346
+ return ctx->result_all[i_segment].t0;
4347
+ }
4348
+
4349
+ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
4350
+ return ctx->result_all[i_segment].t1;
4351
+ }
4352
+
4353
+ const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
4354
+ return ctx->result_all[i_segment].text.c_str();
4355
+ }
4356
+
4357
+ int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
4358
+ return ctx->result_all[i_segment].tokens.size();
4359
+ }
4360
+
4361
+ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
4362
+ return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
4363
+ }
4364
+
4365
+ whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
4366
+ return ctx->result_all[i_segment].tokens[i_token].id;
4367
+ }
4368
+
4369
+ struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
4370
+ return ctx->result_all[i_segment].tokens[i_token];
4371
+ }
4372
+
4373
+ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
4374
+ return ctx->result_all[i_segment].tokens[i_token].p;
4375
+ }
4376
+
4377
+ // =================================================================================================
4378
+
4379
+ //
4380
+ // Temporary interface needed for exposing ggml interface
4381
+ // Will be removed in the future when ggml becomes a separate library
4382
+ //
4383
+
4384
+ WHISPER_API int whisper_bench_memcpy(int n_threads) {
4385
+ ggml_time_init();
4386
+
4387
+ size_t n = 50;
4388
+ size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
4389
+
4390
+ // 1 GB array
4391
+ const size_t size = arr*1024llu*1024llu;
4392
+
4393
+ char * src = (char *) malloc(size);
4394
+ char * dst = (char *) malloc(size);
4395
+
4396
+ for (size_t i = 0; i < size; i++) src[i] = i;
4397
+
4398
+ memcpy(dst, src, size); // heat-up
4399
+
4400
+ double tsum = 0.0;
4401
+
4402
+ for (size_t i = 0; i < n; i++) {
4403
+ const int64_t t0 = ggml_time_us();
4404
+
4405
+ memcpy(dst, src, size);
4406
+
4407
+ const int64_t t1 = ggml_time_us();
4408
+
4409
+ tsum += (t1 - t0)*1e-6;
4410
+
4411
+ src[0] = rand();
4412
+ }
4413
+
4414
+ fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4415
+
4416
+ // needed to prevent the compiler from optimizing the memcpy away
4417
+ {
4418
+ double sum = 0.0;
4419
+
4420
+ for (size_t i = 0; i < size; i++) sum += dst[i];
4421
+
4422
+ fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
4423
+ }
4424
+
4425
+ free(src);
4426
+ free(dst);
4427
+
4428
+ return 0;
4429
+ }
4430
+
4431
+ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
4432
+ ggml_time_init();
4433
+
4434
+ const int n_max = 128;
4435
+
4436
+ const std::vector<size_t> sizes = {
4437
+ 64, 128, 256, 512, 1024, 2048, 4096,
4438
+ };
4439
+
4440
+ const size_t N_max = sizes.back();
4441
+
4442
+ // a: N*N*sizeof(float)
4443
+ // b: N*N*sizeof(float)
4444
+ // c: N*N*sizeof(float)
4445
+ // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
4446
+ std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
4447
+
4448
+ for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
4449
+
4450
+ for (int j = 0; j < (int) sizes.size(); j++) {
4451
+ int n_fp16 = 0;
4452
+ int n_fp32 = 0;
4453
+
4454
+ // GFLOPS/s
4455
+ double s_fp16 = 0.0;
4456
+ double s_fp32 = 0.0;
4457
+
4458
+ const size_t N = sizes[j];
4459
+
4460
+ for (int k = 0; k < 2; ++k) {
4461
+ const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
4462
+
4463
+ double & s = k == 0 ? s_fp16 : s_fp32;
4464
+ int & n = k == 0 ? n_fp16 : n_fp32;
4465
+
4466
+ struct ggml_init_params gparams = {
4467
+ /*.mem_size =*/ buf.size(),
4468
+ /*.mem_buffer =*/ buf.data(),
4469
+ };
4470
+
4471
+ struct ggml_context * ctx0 = ggml_init(gparams);
4472
+
4473
+ struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N);
4474
+ struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
4475
+
4476
+ struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
4477
+
4478
+ struct ggml_cgraph gf = ggml_build_forward(c);
4479
+
4480
+ gf.n_threads = n_threads;
4481
+
4482
+ double tsum = 0.0;
4483
+
4484
+ // heat-up
4485
+ ggml_graph_compute(ctx0, &gf);
4486
+
4487
+ for (int i = 0; i < n_max; ++i) {
4488
+ const int64_t t0 = ggml_time_us();
4489
+
4490
+ ggml_graph_compute(ctx0, &gf);
4491
+
4492
+ const int64_t t1 = ggml_time_us();
4493
+
4494
+ tsum += (t1 - t0)*1e-6;
4495
+ n++;
4496
+
4497
+ if (tsum > 1.0 && n >= 3) {
4498
+ break;
4499
+ }
4500
+ }
4501
+
4502
+ ggml_free(ctx0);
4503
+
4504
+ s = ((2.0*N*N*N*n)/tsum)*1e-9;
4505
+ }
4506
+
4507
+ fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
4508
+ N, N, s_fp16, n_fp16, s_fp32, n_fp32);
4509
+ }
4510
+
4511
+ return 0;
4512
+ }
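As a worked example of the GFLOPS formula above: at N = 1024 one multiplication costs 2*N^3, about 2.15e9 floating-point operations, so n = 100 runs completing in tsum = 2.0 s give s = (2.15e9*100/2.0)*1e-9, about 107 GFLOPS.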
4513
+
4514
+ // =================================================================================================
4515
+
4516
+ // =================================================================================================
4517
+
4518
+ //
4519
+ // Experimental stuff below
4520
+ //
4521
+ // Not sure if these should be part of the library at all, because the quality of the results is not
4522
+ // guaranteed. Might get removed at some point unless a robust algorithm implementation is found
4523
+ //
4524
+
4525
+ // =================================================================================================
4526
+
4527
+ //
4528
+ // token-level timestamps
4529
+ //
4530
+
4531
+ static int timestamp_to_sample(int64_t t, int n_samples) {
4532
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
4533
+ }
4534
+
4535
+ static int64_t sample_to_timestamp(int i_sample) {
4536
+ return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
4537
+ }
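Both helpers assume the 10 ms timestamp units used throughout: with WHISPER_SAMPLE_RATE = 16000, t = 150 (1.5 s) maps to sample 150*16000/100 = 24000, and sample_to_timestamp(24000) maps back to 150.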
4538
+
4539
+ // a cost-function / heuristic that is high for text that takes longer to pronounce
4540
+ // obviously, can be improved
4541
+ static float voice_length(const std::string & text) {
4542
+ float res = 0.0f;
4543
+
4544
+ for (char c : text) {
4545
+ if (c == ' ') {
4546
+ res += 0.01f;
4547
+ } else if (c == ',') {
4548
+ res += 2.00f;
4549
+ } else if (c == '.') {
4550
+ res += 3.00f;
4551
+ } else if (c == '!') {
4552
+ res += 3.00f;
4553
+ } else if (c == '?') {
4554
+ res += 3.00f;
4555
+ } else if (c >= '0' && c <= '9') {
4556
+ res += 3.00f;
4557
+ } else {
4558
+ res += 1.00f;
4559
+ }
4560
+ }
4561
+
4562
+ return res;
4563
+ }
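For example, voice_length("Hello, world.") = 5*1.00 + 2.00 + 0.01 + 5*1.00 + 3.00 = 15.01: letters score 1.00 each, the comma 2.00, the space 0.01 and the period 3.00.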
4564
+
4565
+ // average the fabs of the signal
4566
+ static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
4567
+ const int hw = n_samples_per_half_window;
4568
+
4569
+ std::vector<float> result(n_samples);
4570
+
4571
+ for (int i = 0; i < n_samples; i++) {
4572
+ float sum = 0;
4573
+ for (int j = -hw; j <= hw; j++) {
4574
+ if (i + j >= 0 && i + j < n_samples) {
4575
+ sum += fabs(signal[i + j]);
4576
+ }
4577
+ }
4578
+ result[i] = sum/(2*hw + 1);
4579
+ }
4580
+
4581
+ return result;
4582
+ }
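With hw = 32, as passed by whisper_full() above, each output value is the mean of |signal| over a 65-sample window centred on the sample; near the edges the sum runs over fewer samples but is still divided by 2*hw + 1, so the reported energy tapers off at the boundaries.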
4583
+
4584
+ static void whisper_exp_compute_token_level_timestamps(
4585
+ struct whisper_context & ctx,
4586
+ int i_segment,
4587
+ float thold_pt,
4588
+ float thold_ptsum) {
4589
+ auto & segment = ctx.result_all[i_segment];
4590
+ auto & tokens = segment.tokens;
4591
+
4592
+ const int n_samples = ctx.energy.size();
4593
+
4594
+ if (n_samples == 0) {
4595
+ fprintf(stderr, "%s: no signal data available\n", __func__);
4596
+ return;
4597
+ }
4598
+
4599
+ const int64_t t0 = segment.t0;
4600
+ const int64_t t1 = segment.t1;
4601
+
4602
+ const int n = tokens.size();
4603
+
4604
+ if (n == 0) {
4605
+ return;
4606
+ }
4607
+
4608
+ if (n == 1) {
4609
+ tokens[0].t0 = t0;
4610
+ tokens[0].t1 = t1;
4611
+
4612
+ return;
4613
+ }
4614
+
4615
+ auto & t_beg = ctx.t_beg;
4616
+ auto & t_last = ctx.t_last;
4617
+ auto & tid_last = ctx.tid_last;
4618
+
4619
+ for (int j = 0; j < n; ++j) {
4620
+ auto & token = tokens[j];
4621
+
4622
+ if (j == 0) {
4623
+ if (token.id == whisper_token_beg(&ctx)) {
4624
+ tokens[j ].t0 = t0;
4625
+ tokens[j ].t1 = t0;
4626
+ tokens[j + 1].t0 = t0;
4627
+
4628
+ t_beg = t0;
4629
+ t_last = t0;
4630
+ tid_last = whisper_token_beg(&ctx);
4631
+ } else {
4632
+ tokens[j ].t0 = t_last;
4633
+ }
4634
+ }
4635
+
4636
+ const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
4637
+
4638
+ tokens[j].id = token.id;
4639
+ tokens[j].tid = token.tid;
4640
+ tokens[j].p = token.p;
4641
+ tokens[j].pt = token.pt;
4642
+ tokens[j].ptsum = token.ptsum;
4643
+
4644
+ tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
4645
+
4646
+ if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
4647
+ if (j > 0) {
4648
+ tokens[j - 1].t1 = tt;
4649
+ }
4650
+ tokens[j].t0 = tt;
4651
+ tid_last = token.tid;
4652
+ }
4653
+ }
4654
+
4655
+ tokens[n - 2].t1 = t1;
4656
+ tokens[n - 1].t0 = t1;
4657
+ tokens[n - 1].t1 = t1;
4658
+
4659
+ t_last = t1;
4660
+
4661
+ // find intervals of tokens with unknown timestamps
4662
+ // fill the timestamps by proportionally splitting the interval based on the token voice lengths
4663
+ {
4664
+ int p0 = 0;
4665
+ int p1 = 0;
4666
+
4667
+ while (true) {
4668
+ while (p1 < n && tokens[p1].t1 < 0) {
4669
+ p1++;
4670
+ }
4671
+
4672
+ if (p1 >= n) {
4673
+ p1--;
4674
+ }
4675
+
4676
+ //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
4677
+
4678
+ if (p1 > p0) {
4679
+ double psum = 0.0;
4680
+ for (int j = p0; j <= p1; j++) {
4681
+ psum += tokens[j].vlen;
4682
+ }
4683
+
4684
+ //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
4685
+
4686
+ const double dt = tokens[p1].t1 - tokens[p0].t0;
4687
+
4688
+ // split the time proportionally to the voice length
4689
+ for (int j = p0 + 1; j <= p1; j++) {
4690
+ const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
4691
+
4692
+ tokens[j - 1].t1 = ct;
4693
+ tokens[j ].t0 = ct;
4694
+ }
4695
+ }
4696
+
4697
+ p1++;
4698
+ p0 = p1;
4699
+ if (p1 >= n) {
4700
+ break;
4701
+ }
4702
+ }
4703
+ }
4704
+
4705
+ // fix up (just in case)
4706
+ for (int j = 0; j < n - 1; j++) {
4707
+ if (tokens[j].t1 < 0) {
4708
+ tokens[j + 1].t0 = tokens[j].t1;
4709
+ }
4710
+
4711
+ if (j > 0) {
4712
+ if (tokens[j - 1].t1 > tokens[j].t0) {
4713
+ tokens[j].t0 = tokens[j - 1].t1;
4714
+ tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
4715
+ }
4716
+ }
4717
+ }
4718
+
4719
+ // VAD
4720
+ // expand or contract tokens based on voice activity
4721
+ {
4722
+ const int hw = WHISPER_SAMPLE_RATE/8;
4723
+
4724
+ for (int j = 0; j < n; j++) {
4725
+ if (tokens[j].id >= whisper_token_eot(&ctx)) {
4726
+ continue;
4727
+ }
4728
+
4729
+ int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
4730
+ int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
4731
+
4732
+ const int ss0 = std::max(s0 - hw, 0);
4733
+ const int ss1 = std::min(s1 + hw, n_samples);
4734
+
4735
+ const int ns = ss1 - ss0;
4736
+
4737
+ float sum = 0.0f;
4738
+
4739
+ for (int k = ss0; k < ss1; k++) {
4740
+ sum += ctx.energy[k];
4741
+ }
4742
+
4743
+ const float thold = 0.5*sum/ns;
4744
+
4745
+ {
4746
+ int k = s0;
4747
+ if (ctx.energy[k] > thold && j > 0) {
4748
+ while (k > 0 && ctx.energy[k] > thold) {
4749
+ k--;
4750
+ }
4751
+ tokens[j].t0 = sample_to_timestamp(k);
4752
+ if (tokens[j].t0 < tokens[j - 1].t1) {
4753
+ tokens[j].t0 = tokens[j - 1].t1;
4754
+ } else {
4755
+ s0 = k;
4756
+ }
4757
+ } else {
4758
+ while (ctx.energy[k] < thold && k < s1) {
4759
+ k++;
4760
+ }
4761
+ s0 = k;
4762
+ tokens[j].t0 = sample_to_timestamp(k);
4763
+ }
4764
+ }
4765
+
4766
+ {
4767
+ int k = s1;
4768
+ if (ctx.energy[k] > thold) {
4769
+ while (k < n_samples - 1 && ctx.energy[k] > thold) {
4770
+ k++;
4771
+ }
4772
+ tokens[j].t1 = sample_to_timestamp(k);
4773
+ if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
4774
+ tokens[j].t1 = tokens[j + 1].t0;
4775
+ } else {
4776
+ s1 = k;
4777
+ }
4778
+ } else {
4779
+ while (ctx.energy[k] < thold && k > s0) {
4780
+ k--;
4781
+ }
4782
+ s1 = k;
4783
+ tokens[j].t1 = sample_to_timestamp(k);
4784
+ }
4785
+ }
4786
+ }
4787
+ }
4788
+
4789
+ // fixed token expand (optional)
4790
+ //{
4791
+ // const int t_expand = 0;
4792
+
4793
+ // for (int j = 0; j < n; j++) {
4794
+ // if (j > 0) {
4795
+ // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
4796
+ // }
4797
+ // if (j < n - 1) {
4798
+ // tokens[j].t1 = tokens[j].t1 + t_expand;
4799
+ // }
4800
+ // }
4801
+ //}
4802
+
4803
+ // debug info
4804
+ //for (int j = 0; j < n; ++j) {
4805
+ // const auto & token = tokens[j];
4806
+ // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
4807
+ // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
4808
+ // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
4809
+
4810
+ // if (tokens[j].id >= whisper_token_eot(&ctx)) {
4811
+ // continue;
4812
+ // }
4813
+ //}
4814
+ }