whispercpp 1.2.0.2 → 1.3.1
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +46 -86
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -7
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/ggml/include/ggml.h +2285 -0
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/include/whisper.h +672 -0
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1608 -159
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/src/whisper.cpp +7393 -0
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -8616
- data/ext/ggml.h +0 -748
- data/ext/whisper.cpp +0 -4829
- data/ext/whisper.h +0 -402
data/ext/whisper.cpp
DELETED
@@ -1,4829 +0,0 @@
|
|
1
|
-
#define WHISPER_BUILD
|
2
|
-
#include "whisper.h"
|
3
|
-
|
4
|
-
#include "ggml.h"
|
5
|
-
|
6
|
-
#include <algorithm>
|
7
|
-
#include <cassert>
|
8
|
-
#define _USE_MATH_DEFINES
|
9
|
-
#include <cmath>
|
10
|
-
#include <cstdio>
|
11
|
-
#include <cstring>
|
12
|
-
#include <fstream>
|
13
|
-
#include <map>
|
14
|
-
#include <string>
|
15
|
-
#include <thread>
|
16
|
-
#include <vector>
|
17
|
-
#include <regex>
|
18
|
-
#include <random>
|
19
|
-
|
20
|
-
#if defined(GGML_BIG_ENDIAN)
|
21
|
-
#include <bit>
|
22
|
-
|
23
|
-
template<typename T>
|
24
|
-
static T byteswap(T value) {
|
25
|
-
return std::byteswap(value);
|
26
|
-
}
|
27
|
-
|
28
|
-
template<>
|
29
|
-
float byteswap(float value) {
|
30
|
-
return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
|
31
|
-
}
|
32
|
-
|
33
|
-
template<typename T>
|
34
|
-
static void byteswap_tensor_data(ggml_tensor * tensor) {
|
35
|
-
T * datum = reinterpret_cast<T *>(tensor->data);
|
36
|
-
for (int i = 0; i < ggml_nelements(tensor); i++) {
|
37
|
-
datum[i] = byteswap(datum[i]);
|
38
|
-
}
|
39
|
-
}
|
40
|
-
|
41
|
-
static void byteswap_tensor(ggml_tensor * tensor) {
|
42
|
-
switch (tensor->type) {
|
43
|
-
case GGML_TYPE_I16: {
|
44
|
-
byteswap_tensor_data<int16_t>(tensor);
|
45
|
-
break;
|
46
|
-
}
|
47
|
-
case GGML_TYPE_F16: {
|
48
|
-
byteswap_tensor_data<ggml_fp16_t>(tensor);
|
49
|
-
break;
|
50
|
-
}
|
51
|
-
case GGML_TYPE_I32: {
|
52
|
-
byteswap_tensor_data<int32_t>(tensor);
|
53
|
-
break;
|
54
|
-
}
|
55
|
-
case GGML_TYPE_F32: {
|
56
|
-
byteswap_tensor_data<float>(tensor);
|
57
|
-
break;
|
58
|
-
}
|
59
|
-
default: { // GML_TYPE_I8
|
60
|
-
break;
|
61
|
-
}
|
62
|
-
}
|
63
|
-
}
|
64
|
-
|
65
|
-
#define BYTESWAP_VALUE(d) d = byteswap(d)
|
66
|
-
#define BYTESWAP_FILTERS(f) \
|
67
|
-
do { \
|
68
|
-
for (auto & datum : f.data) { \
|
69
|
-
datum = byteswap(datum); \
|
70
|
-
} \
|
71
|
-
} while (0)
|
72
|
-
#define BYTESWAP_TENSOR(t) \
|
73
|
-
do { \
|
74
|
-
byteswap_tensor(tensor); \
|
75
|
-
} while (0)
|
76
|
-
#else
|
77
|
-
#define BYTESWAP_VALUE(d) do {} while (0)
|
78
|
-
#define BYTESWAP_FILTERS(f) do {} while (0)
|
79
|
-
#define BYTESWAP_TENSOR(t) do {} while (0)
|
80
|
-
#endif
|
81
|
-
|
82
|
-
#define WHISPER_ASSERT(x) \
|
83
|
-
do { \
|
84
|
-
if (!(x)) { \
|
85
|
-
fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
86
|
-
abort(); \
|
87
|
-
} \
|
88
|
-
} while (0)
|
89
|
-
|
90
|
-
// define this to enable verbose trace logging - useful for debugging purposes
|
91
|
-
//#define WHISPER_DEBUG
|
92
|
-
|
93
|
-
#if defined(WHISPER_DEBUG)
|
94
|
-
#define WHISPER_PRINT_DEBUG(...) \
|
95
|
-
do { \
|
96
|
-
fprintf(stderr, __VA_ARGS__); \
|
97
|
-
} while (0)
|
98
|
-
#else
|
99
|
-
#define WHISPER_PRINT_DEBUG(...)
|
100
|
-
#endif
|
101
|
-
|
102
|
-
#define WHISPER_USE_FLASH_ATTN
|
103
|
-
//#define WHISPER_USE_FLASH_FF
|
104
|
-
#define WHISPER_MAX_DECODERS 16
|
105
|
-
|
106
|
-
#define WHISPER_USE_SCRATCH
|
107
|
-
#define WHISPER_MAX_SCRATCH_BUFFERS 16
|
108
|
-
|
109
|
-
// available whisper models
|
110
|
-
enum e_model {
|
111
|
-
MODEL_UNKNOWN,
|
112
|
-
MODEL_TINY,
|
113
|
-
MODEL_BASE,
|
114
|
-
MODEL_SMALL,
|
115
|
-
MODEL_MEDIUM,
|
116
|
-
MODEL_LARGE,
|
117
|
-
};
|
118
|
-
|
119
|
-
static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
120
|
-
{ "en", { 0, "english", } },
|
121
|
-
{ "zh", { 1, "chinese", } },
|
122
|
-
{ "de", { 2, "german", } },
|
123
|
-
{ "es", { 3, "spanish", } },
|
124
|
-
{ "ru", { 4, "russian", } },
|
125
|
-
{ "ko", { 5, "korean", } },
|
126
|
-
{ "fr", { 6, "french", } },
|
127
|
-
{ "ja", { 7, "japanese", } },
|
128
|
-
{ "pt", { 8, "portuguese", } },
|
129
|
-
{ "tr", { 9, "turkish", } },
|
130
|
-
{ "pl", { 10, "polish", } },
|
131
|
-
{ "ca", { 11, "catalan", } },
|
132
|
-
{ "nl", { 12, "dutch", } },
|
133
|
-
{ "ar", { 13, "arabic", } },
|
134
|
-
{ "sv", { 14, "swedish", } },
|
135
|
-
{ "it", { 15, "italian", } },
|
136
|
-
{ "id", { 16, "indonesian", } },
|
137
|
-
{ "hi", { 17, "hindi", } },
|
138
|
-
{ "fi", { 18, "finnish", } },
|
139
|
-
{ "vi", { 19, "vietnamese", } },
|
140
|
-
{ "iw", { 20, "hebrew", } },
|
141
|
-
{ "uk", { 21, "ukrainian", } },
|
142
|
-
{ "el", { 22, "greek", } },
|
143
|
-
{ "ms", { 23, "malay", } },
|
144
|
-
{ "cs", { 24, "czech", } },
|
145
|
-
{ "ro", { 25, "romanian", } },
|
146
|
-
{ "da", { 26, "danish", } },
|
147
|
-
{ "hu", { 27, "hungarian", } },
|
148
|
-
{ "ta", { 28, "tamil", } },
|
149
|
-
{ "no", { 29, "norwegian", } },
|
150
|
-
{ "th", { 30, "thai", } },
|
151
|
-
{ "ur", { 31, "urdu", } },
|
152
|
-
{ "hr", { 32, "croatian", } },
|
153
|
-
{ "bg", { 33, "bulgarian", } },
|
154
|
-
{ "lt", { 34, "lithuanian", } },
|
155
|
-
{ "la", { 35, "latin", } },
|
156
|
-
{ "mi", { 36, "maori", } },
|
157
|
-
{ "ml", { 37, "malayalam", } },
|
158
|
-
{ "cy", { 38, "welsh", } },
|
159
|
-
{ "sk", { 39, "slovak", } },
|
160
|
-
{ "te", { 40, "telugu", } },
|
161
|
-
{ "fa", { 41, "persian", } },
|
162
|
-
{ "lv", { 42, "latvian", } },
|
163
|
-
{ "bn", { 43, "bengali", } },
|
164
|
-
{ "sr", { 44, "serbian", } },
|
165
|
-
{ "az", { 45, "azerbaijani", } },
|
166
|
-
{ "sl", { 46, "slovenian", } },
|
167
|
-
{ "kn", { 47, "kannada", } },
|
168
|
-
{ "et", { 48, "estonian", } },
|
169
|
-
{ "mk", { 49, "macedonian", } },
|
170
|
-
{ "br", { 50, "breton", } },
|
171
|
-
{ "eu", { 51, "basque", } },
|
172
|
-
{ "is", { 52, "icelandic", } },
|
173
|
-
{ "hy", { 53, "armenian", } },
|
174
|
-
{ "ne", { 54, "nepali", } },
|
175
|
-
{ "mn", { 55, "mongolian", } },
|
176
|
-
{ "bs", { 56, "bosnian", } },
|
177
|
-
{ "kk", { 57, "kazakh", } },
|
178
|
-
{ "sq", { 58, "albanian", } },
|
179
|
-
{ "sw", { 59, "swahili", } },
|
180
|
-
{ "gl", { 60, "galician", } },
|
181
|
-
{ "mr", { 61, "marathi", } },
|
182
|
-
{ "pa", { 62, "punjabi", } },
|
183
|
-
{ "si", { 63, "sinhala", } },
|
184
|
-
{ "km", { 64, "khmer", } },
|
185
|
-
{ "sn", { 65, "shona", } },
|
186
|
-
{ "yo", { 66, "yoruba", } },
|
187
|
-
{ "so", { 67, "somali", } },
|
188
|
-
{ "af", { 68, "afrikaans", } },
|
189
|
-
{ "oc", { 69, "occitan", } },
|
190
|
-
{ "ka", { 70, "georgian", } },
|
191
|
-
{ "be", { 71, "belarusian", } },
|
192
|
-
{ "tg", { 72, "tajik", } },
|
193
|
-
{ "sd", { 73, "sindhi", } },
|
194
|
-
{ "gu", { 74, "gujarati", } },
|
195
|
-
{ "am", { 75, "amharic", } },
|
196
|
-
{ "yi", { 76, "yiddish", } },
|
197
|
-
{ "lo", { 77, "lao", } },
|
198
|
-
{ "uz", { 78, "uzbek", } },
|
199
|
-
{ "fo", { 79, "faroese", } },
|
200
|
-
{ "ht", { 80, "haitian creole", } },
|
201
|
-
{ "ps", { 81, "pashto", } },
|
202
|
-
{ "tk", { 82, "turkmen", } },
|
203
|
-
{ "nn", { 83, "nynorsk", } },
|
204
|
-
{ "mt", { 84, "maltese", } },
|
205
|
-
{ "sa", { 85, "sanskrit", } },
|
206
|
-
{ "lb", { 86, "luxembourgish", } },
|
207
|
-
{ "my", { 87, "myanmar", } },
|
208
|
-
{ "bo", { 88, "tibetan", } },
|
209
|
-
{ "tl", { 89, "tagalog", } },
|
210
|
-
{ "mg", { 90, "malagasy", } },
|
211
|
-
{ "as", { 91, "assamese", } },
|
212
|
-
{ "tt", { 92, "tatar", } },
|
213
|
-
{ "haw", { 93, "hawaiian", } },
|
214
|
-
{ "ln", { 94, "lingala", } },
|
215
|
-
{ "ha", { 95, "hausa", } },
|
216
|
-
{ "ba", { 96, "bashkir", } },
|
217
|
-
{ "jw", { 97, "javanese", } },
|
218
|
-
{ "su", { 98, "sundanese", } },
|
219
|
-
};
|
220
|
-
|
221
|
-
static const size_t MB = 1024*1024;
|
222
|
-
|
223
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
|
224
|
-
{ MODEL_TINY, 12ull*MB },
|
225
|
-
{ MODEL_BASE, 15ull*MB },
|
226
|
-
{ MODEL_SMALL, 23ull*MB },
|
227
|
-
{ MODEL_MEDIUM, 31ull*MB },
|
228
|
-
{ MODEL_LARGE, 38ull*MB },
|
229
|
-
};
|
230
|
-
|
231
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
232
|
-
{ MODEL_TINY, 18ull*MB },
|
233
|
-
{ MODEL_BASE, 24ull*MB },
|
234
|
-
{ MODEL_SMALL, 36ull*MB },
|
235
|
-
{ MODEL_MEDIUM, 48ull*MB },
|
236
|
-
{ MODEL_LARGE, 60ull*MB },
|
237
|
-
};
|
238
|
-
|
239
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
|
240
|
-
{ MODEL_TINY, 4ull*MB },
|
241
|
-
{ MODEL_BASE, 4ull*MB },
|
242
|
-
{ MODEL_SMALL, 6ull*MB },
|
243
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
244
|
-
{ MODEL_LARGE, 9ull*MB },
|
245
|
-
};
|
246
|
-
|
247
|
-
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
248
|
-
{ MODEL_TINY, 4ull*MB },
|
249
|
-
{ MODEL_BASE, 4ull*MB },
|
250
|
-
{ MODEL_SMALL, 6ull*MB },
|
251
|
-
{ MODEL_MEDIUM, 7ull*MB },
|
252
|
-
{ MODEL_LARGE, 9ull*MB },
|
253
|
-
};
|
254
|
-
|
255
|
-
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
|
256
|
-
{ MODEL_TINY, 74ull*MB },
|
257
|
-
{ MODEL_BASE, 142ull*MB },
|
258
|
-
{ MODEL_SMALL, 466ull*MB },
|
259
|
-
{ MODEL_MEDIUM, 1464ull*MB },
|
260
|
-
{ MODEL_LARGE, 2952ull*MB },
|
261
|
-
};
|
262
|
-
|
263
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
264
|
-
{ MODEL_TINY, 3ull*MB },
|
265
|
-
{ MODEL_BASE, 6ull*MB },
|
266
|
-
{ MODEL_SMALL, 16ull*MB },
|
267
|
-
{ MODEL_MEDIUM, 43ull*MB },
|
268
|
-
{ MODEL_LARGE, 71ull*MB },
|
269
|
-
};
|
270
|
-
|
271
|
-
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
272
|
-
{ MODEL_TINY, 9ull*MB },
|
273
|
-
{ MODEL_BASE, 18ull*MB },
|
274
|
-
{ MODEL_SMALL, 53ull*MB },
|
275
|
-
{ MODEL_MEDIUM, 141ull*MB },
|
276
|
-
{ MODEL_LARGE, 235ull*MB },
|
277
|
-
};
|
278
|
-
|
279
|
-
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
280
|
-
{ MODEL_TINY, 6ull*MB },
|
281
|
-
{ MODEL_BASE, 8ull*MB },
|
282
|
-
{ MODEL_SMALL, 13ull*MB },
|
283
|
-
{ MODEL_MEDIUM, 22ull*MB },
|
284
|
-
{ MODEL_LARGE, 33ull*MB },
|
285
|
-
};
|
286
|
-
|
287
|
-
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
288
|
-
{ MODEL_TINY, 3ull*MB },
|
289
|
-
{ MODEL_BASE, 5ull*MB },
|
290
|
-
{ MODEL_SMALL, 10ull*MB },
|
291
|
-
{ MODEL_MEDIUM, 18ull*MB },
|
292
|
-
{ MODEL_LARGE, 27ull*MB },
|
293
|
-
};
|
294
|
-
|
295
|
-
struct whisper_mel {
|
296
|
-
int n_len;
|
297
|
-
int n_mel;
|
298
|
-
|
299
|
-
std::vector<float> data;
|
300
|
-
};
|
301
|
-
|
302
|
-
struct whisper_filters {
|
303
|
-
int32_t n_mel;
|
304
|
-
int32_t n_fft;
|
305
|
-
|
306
|
-
std::vector<float> data;
|
307
|
-
};
|
308
|
-
|
309
|
-
struct whisper_vocab {
|
310
|
-
using id = int32_t;
|
311
|
-
using token = std::string;
|
312
|
-
|
313
|
-
int n_vocab = 51864;
|
314
|
-
|
315
|
-
std::map<token, id> token_to_id;
|
316
|
-
std::map<id, token> id_to_token;
|
317
|
-
|
318
|
-
id token_eot = 50256;
|
319
|
-
id token_sot = 50257;
|
320
|
-
id token_prev = 50360;
|
321
|
-
id token_solm = 50361; // ??
|
322
|
-
id token_not = 50362; // no timestamps
|
323
|
-
id token_beg = 50363;
|
324
|
-
|
325
|
-
// available tasks
|
326
|
-
static const id token_translate = 50358;
|
327
|
-
static const id token_transcribe = 50359;
|
328
|
-
|
329
|
-
bool is_multilingual() const {
|
330
|
-
return n_vocab == 51865;
|
331
|
-
}
|
332
|
-
};
|
333
|
-
|
334
|
-
struct whisper_segment {
|
335
|
-
int64_t t0;
|
336
|
-
int64_t t1;
|
337
|
-
|
338
|
-
std::string text;
|
339
|
-
|
340
|
-
std::vector<whisper_token_data> tokens;
|
341
|
-
};
|
342
|
-
|
343
|
-
// medium
|
344
|
-
// hparams: {
|
345
|
-
// 'n_mels': 80,
|
346
|
-
// 'n_vocab': 51864,
|
347
|
-
// 'n_audio_ctx': 1500,
|
348
|
-
// 'n_audio_state': 1024,
|
349
|
-
// 'n_audio_head': 16,
|
350
|
-
// 'n_audio_layer': 24,
|
351
|
-
// 'n_text_ctx': 448,
|
352
|
-
// 'n_text_state': 1024,
|
353
|
-
// 'n_text_head': 16,
|
354
|
-
// 'n_text_layer': 24
|
355
|
-
// }
|
356
|
-
//
|
357
|
-
// default hparams (Whisper tiny)
|
358
|
-
struct whisper_hparams {
|
359
|
-
int32_t n_vocab = 51864;
|
360
|
-
int32_t n_audio_ctx = 1500;
|
361
|
-
int32_t n_audio_state = 384;
|
362
|
-
int32_t n_audio_head = 6;
|
363
|
-
int32_t n_audio_layer = 4;
|
364
|
-
int32_t n_text_ctx = 448;
|
365
|
-
int32_t n_text_state = 384;
|
366
|
-
int32_t n_text_head = 6;
|
367
|
-
int32_t n_text_layer = 4;
|
368
|
-
int32_t n_mels = 80;
|
369
|
-
int32_t f16 = 1;
|
370
|
-
};
|
371
|
-
|
372
|
-
// audio encoding layer
|
373
|
-
struct whisper_layer_encoder {
|
374
|
-
// encoder.blocks.*.attn_ln
|
375
|
-
struct ggml_tensor * attn_ln_0_w;
|
376
|
-
struct ggml_tensor * attn_ln_0_b;
|
377
|
-
|
378
|
-
// encoder.blocks.*.attn.out
|
379
|
-
struct ggml_tensor * attn_ln_1_w;
|
380
|
-
struct ggml_tensor * attn_ln_1_b;
|
381
|
-
|
382
|
-
// encoder.blocks.*.attn.query
|
383
|
-
struct ggml_tensor * attn_q_w;
|
384
|
-
struct ggml_tensor * attn_q_b;
|
385
|
-
|
386
|
-
// encoder.blocks.*.attn.key
|
387
|
-
struct ggml_tensor * attn_k_w;
|
388
|
-
|
389
|
-
// encoder.blocks.*.attn.value
|
390
|
-
struct ggml_tensor * attn_v_w;
|
391
|
-
struct ggml_tensor * attn_v_b;
|
392
|
-
|
393
|
-
// encoder.blocks.*.mlp_ln
|
394
|
-
struct ggml_tensor * mlp_ln_w;
|
395
|
-
struct ggml_tensor * mlp_ln_b;
|
396
|
-
|
397
|
-
// encoder.blocks.*.mlp.0
|
398
|
-
struct ggml_tensor * mlp_0_w;
|
399
|
-
struct ggml_tensor * mlp_0_b;
|
400
|
-
|
401
|
-
// encoder.blocks.*.mlp.2
|
402
|
-
struct ggml_tensor * mlp_1_w;
|
403
|
-
struct ggml_tensor * mlp_1_b;
|
404
|
-
};
|
405
|
-
|
406
|
-
// token decoding layer
|
407
|
-
struct whisper_layer_decoder {
|
408
|
-
// decoder.blocks.*.attn_ln
|
409
|
-
struct ggml_tensor * attn_ln_0_w;
|
410
|
-
struct ggml_tensor * attn_ln_0_b;
|
411
|
-
|
412
|
-
// decoder.blocks.*.attn.out
|
413
|
-
struct ggml_tensor * attn_ln_1_w;
|
414
|
-
struct ggml_tensor * attn_ln_1_b;
|
415
|
-
|
416
|
-
// decoder.blocks.*.attn.query
|
417
|
-
struct ggml_tensor * attn_q_w;
|
418
|
-
struct ggml_tensor * attn_q_b;
|
419
|
-
|
420
|
-
// decoder.blocks.*.attn.key
|
421
|
-
struct ggml_tensor * attn_k_w;
|
422
|
-
|
423
|
-
// decoder.blocks.*.attn.value
|
424
|
-
struct ggml_tensor * attn_v_w;
|
425
|
-
struct ggml_tensor * attn_v_b;
|
426
|
-
|
427
|
-
// decoder.blocks.*.cross_attn_ln
|
428
|
-
struct ggml_tensor * cross_attn_ln_0_w;
|
429
|
-
struct ggml_tensor * cross_attn_ln_0_b;
|
430
|
-
|
431
|
-
// decoder.blocks.*.cross_attn.out
|
432
|
-
struct ggml_tensor * cross_attn_ln_1_w;
|
433
|
-
struct ggml_tensor * cross_attn_ln_1_b;
|
434
|
-
|
435
|
-
// decoder.blocks.*.cross_attn.query
|
436
|
-
struct ggml_tensor * cross_attn_q_w;
|
437
|
-
struct ggml_tensor * cross_attn_q_b;
|
438
|
-
|
439
|
-
// decoder.blocks.*.cross_attn.key
|
440
|
-
struct ggml_tensor * cross_attn_k_w;
|
441
|
-
|
442
|
-
// decoder.blocks.*.cross_attn.value
|
443
|
-
struct ggml_tensor * cross_attn_v_w;
|
444
|
-
struct ggml_tensor * cross_attn_v_b;
|
445
|
-
|
446
|
-
// decoder.blocks.*.mlp_ln
|
447
|
-
struct ggml_tensor * mlp_ln_w;
|
448
|
-
struct ggml_tensor * mlp_ln_b;
|
449
|
-
|
450
|
-
// decoder.blocks.*.mlp.0
|
451
|
-
struct ggml_tensor * mlp_0_w;
|
452
|
-
struct ggml_tensor * mlp_0_b;
|
453
|
-
|
454
|
-
// decoder.blocks.*.mlp.2
|
455
|
-
struct ggml_tensor * mlp_1_w;
|
456
|
-
struct ggml_tensor * mlp_1_b;
|
457
|
-
};
|
458
|
-
|
459
|
-
struct whisper_kv_cache {
|
460
|
-
struct ggml_tensor * k;
|
461
|
-
struct ggml_tensor * v;
|
462
|
-
|
463
|
-
struct ggml_context * ctx;
|
464
|
-
|
465
|
-
std::vector<uint8_t> buf;
|
466
|
-
|
467
|
-
int n; // number of tokens currently in the cache
|
468
|
-
};
|
469
|
-
|
470
|
-
struct whisper_model {
|
471
|
-
e_model type = MODEL_UNKNOWN;
|
472
|
-
|
473
|
-
whisper_hparams hparams;
|
474
|
-
whisper_filters filters;
|
475
|
-
|
476
|
-
// encoder.positional_embedding
|
477
|
-
struct ggml_tensor * e_pe;
|
478
|
-
|
479
|
-
// encoder.conv1
|
480
|
-
struct ggml_tensor * e_conv_1_w;
|
481
|
-
struct ggml_tensor * e_conv_1_b;
|
482
|
-
|
483
|
-
// encoder.conv2
|
484
|
-
struct ggml_tensor * e_conv_2_w;
|
485
|
-
struct ggml_tensor * e_conv_2_b;
|
486
|
-
|
487
|
-
// encoder.ln_post
|
488
|
-
struct ggml_tensor * e_ln_w;
|
489
|
-
struct ggml_tensor * e_ln_b;
|
490
|
-
|
491
|
-
// decoder.positional_embedding
|
492
|
-
struct ggml_tensor * d_pe;
|
493
|
-
|
494
|
-
// decoder.token_embedding
|
495
|
-
struct ggml_tensor * d_te;
|
496
|
-
|
497
|
-
// decoder.ln
|
498
|
-
struct ggml_tensor * d_ln_w;
|
499
|
-
struct ggml_tensor * d_ln_b;
|
500
|
-
|
501
|
-
std::vector<whisper_layer_encoder> layers_encoder;
|
502
|
-
std::vector<whisper_layer_decoder> layers_decoder;
|
503
|
-
|
504
|
-
// context
|
505
|
-
struct ggml_context * ctx;
|
506
|
-
|
507
|
-
// the model memory buffer is read-only and can be shared between processors
|
508
|
-
std::vector<uint8_t> * buf;
|
509
|
-
|
510
|
-
// tensors
|
511
|
-
int n_loaded;
|
512
|
-
std::map<std::string, struct ggml_tensor *> tensors;
|
513
|
-
};
|
514
|
-
|
515
|
-
struct whisper_sequence {
|
516
|
-
std::vector<whisper_token_data> tokens;
|
517
|
-
|
518
|
-
// the accumulated transcription in the current interation (used to truncate the tokens array)
|
519
|
-
int result_len;
|
520
|
-
|
521
|
-
double sum_logprobs_all; // the sum of the log probabilities of the tokens
|
522
|
-
double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
|
523
|
-
double avg_logprobs; // the average log probability of the tokens
|
524
|
-
double entropy; // the entropy of the tokens
|
525
|
-
double score; // likelihood rank score
|
526
|
-
};
|
527
|
-
|
528
|
-
// TAGS: WHISPER_DECODER_INIT
|
529
|
-
struct whisper_decoder {
|
530
|
-
// each decoders keeps its own KV-cache
|
531
|
-
whisper_kv_cache kv_self;
|
532
|
-
|
533
|
-
// the currently generated sequence of tokens
|
534
|
-
whisper_sequence sequence;
|
535
|
-
|
536
|
-
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
|
537
|
-
|
538
|
-
bool failed; // has the current segment failed to decode?
|
539
|
-
bool completed; // has the decoder completed the current segment?
|
540
|
-
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
|
541
|
-
|
542
|
-
// new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
|
543
|
-
std::vector<float> probs;
|
544
|
-
std::vector<float> logits;
|
545
|
-
std::vector<float> logprobs;
|
546
|
-
|
547
|
-
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
548
|
-
};
|
549
|
-
|
550
|
-
struct whisper_context {
|
551
|
-
int64_t t_load_us = 0;
|
552
|
-
int64_t t_mel_us = 0;
|
553
|
-
int64_t t_sample_us = 0;
|
554
|
-
int64_t t_encode_us = 0;
|
555
|
-
int64_t t_decode_us = 0;
|
556
|
-
int64_t t_start_us = 0;
|
557
|
-
|
558
|
-
int32_t n_sample = 0; // number of tokens sampled
|
559
|
-
int32_t n_encode = 0; // number of encoder calls
|
560
|
-
int32_t n_decode = 0; // number of decoder calls
|
561
|
-
int32_t n_fail_p = 0; // number of logprob threshold failures
|
562
|
-
int32_t n_fail_h = 0; // number of entropy threshold failures
|
563
|
-
|
564
|
-
ggml_type wtype; // weight type (FP32 or FP16)
|
565
|
-
|
566
|
-
whisper_mel mel;
|
567
|
-
|
568
|
-
whisper_model model;
|
569
|
-
whisper_vocab vocab;
|
570
|
-
|
571
|
-
// cross-attention KV cache for the decoders
|
572
|
-
// shared between all decoders
|
573
|
-
whisper_kv_cache kv_cross;
|
574
|
-
|
575
|
-
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
576
|
-
|
577
|
-
// memory buffers used by encode / decode contexts
|
578
|
-
std::vector<uint8_t> buf_compute;
|
579
|
-
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
580
|
-
|
581
|
-
int buf_last = 0;
|
582
|
-
size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
|
583
|
-
|
584
|
-
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
585
|
-
std::vector<float> logits;
|
586
|
-
|
587
|
-
std::vector<whisper_segment> result_all;
|
588
|
-
std::vector<whisper_token> prompt_past;
|
589
|
-
|
590
|
-
// work container used to avoid memory allocations
|
591
|
-
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
|
592
|
-
|
593
|
-
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
594
|
-
|
595
|
-
int lang_id = 0; // english by default
|
596
|
-
|
597
|
-
// [EXPERIMENTAL] token-level timestamps data
|
598
|
-
int64_t t_beg = 0;
|
599
|
-
int64_t t_last = 0;
|
600
|
-
whisper_token tid_last;
|
601
|
-
std::vector<float> energy; // PCM signal energy
|
602
|
-
|
603
|
-
// [EXPERIMENTAL] speed-up techniques
|
604
|
-
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
605
|
-
|
606
|
-
// [EXPERIMENTAL] abort handling
|
607
|
-
bool running = true;
|
608
|
-
|
609
|
-
void use_buf(struct ggml_context * ctx, int i) {
|
610
|
-
#if defined(WHISPER_USE_SCRATCH)
|
611
|
-
size_t last_size = 0;
|
612
|
-
|
613
|
-
if (i == -1) {
|
614
|
-
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
615
|
-
} else {
|
616
|
-
auto & buf = buf_scratch[i];
|
617
|
-
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
618
|
-
}
|
619
|
-
|
620
|
-
if (buf_last >= 0) {
|
621
|
-
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
622
|
-
}
|
623
|
-
|
624
|
-
buf_last = i;
|
625
|
-
#else
|
626
|
-
(void) i;
|
627
|
-
(void) ctx;
|
628
|
-
#endif
|
629
|
-
}
|
630
|
-
|
631
|
-
size_t get_buf_max_mem(int i) const {
|
632
|
-
#if defined(WHISPER_USE_SCRATCH)
|
633
|
-
return buf_max_size[i];
|
634
|
-
#else
|
635
|
-
(void) i;
|
636
|
-
return 0;
|
637
|
-
#endif
|
638
|
-
}
|
639
|
-
};
|
640
|
-
|
641
|
-
template<typename T>
|
642
|
-
static void read_safe(whisper_model_loader * loader, T & dest) {
|
643
|
-
loader->read(loader->context, &dest, sizeof(T));
|
644
|
-
BYTESWAP_VALUE(dest);
|
645
|
-
}
|
646
|
-
|
647
|
-
static bool kv_cache_init(
|
648
|
-
const struct whisper_hparams & hparams,
|
649
|
-
const size_t mem_bytes,
|
650
|
-
struct whisper_kv_cache & cache,
|
651
|
-
ggml_type wtype,
|
652
|
-
int n_ctx) {
|
653
|
-
cache.buf.resize(mem_bytes);
|
654
|
-
|
655
|
-
struct ggml_init_params params;
|
656
|
-
params.mem_size = cache.buf.size();
|
657
|
-
params.mem_buffer = cache.buf.data();
|
658
|
-
|
659
|
-
cache.ctx = ggml_init(params);
|
660
|
-
|
661
|
-
if (!cache.ctx) {
|
662
|
-
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
|
663
|
-
return false;
|
664
|
-
}
|
665
|
-
|
666
|
-
const int n_text_state = hparams.n_text_state;
|
667
|
-
const int n_text_layer = hparams.n_text_layer;
|
668
|
-
|
669
|
-
const int n_mem = n_text_layer*n_ctx;
|
670
|
-
const int n_elements = n_text_state*n_mem;
|
671
|
-
|
672
|
-
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
673
|
-
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
674
|
-
|
675
|
-
return true;
|
676
|
-
}
|
677
|
-
|
678
|
-
static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
|
679
|
-
WHISPER_ASSERT(cache.ctx);
|
680
|
-
|
681
|
-
const int n_elements = ggml_nelements(cache.k);
|
682
|
-
WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
|
683
|
-
|
684
|
-
const ggml_type wtype = cache.k->type;
|
685
|
-
WHISPER_ASSERT(wtype == cache.v->type);
|
686
|
-
|
687
|
-
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
|
688
|
-
|
689
|
-
struct ggml_init_params params;
|
690
|
-
params.mem_size = cache.buf.size();
|
691
|
-
params.mem_buffer = cache.buf.data();
|
692
|
-
|
693
|
-
cache.ctx = ggml_init(params);
|
694
|
-
|
695
|
-
if (!cache.ctx) {
|
696
|
-
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
|
697
|
-
return false;
|
698
|
-
}
|
699
|
-
|
700
|
-
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
701
|
-
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
702
|
-
|
703
|
-
return true;
|
704
|
-
}
|
705
|
-
|
706
|
-
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
707
|
-
if (cache.ctx) {
|
708
|
-
ggml_free(cache.ctx);
|
709
|
-
cache.ctx = nullptr;
|
710
|
-
}
|
711
|
-
}
|
712
|
-
|
713
|
-
// load the model from a ggml file
|
714
|
-
//
|
715
|
-
// file format:
|
716
|
-
//
|
717
|
-
// - hparams
|
718
|
-
// - pre-computed mel filters
|
719
|
-
// - vocab
|
720
|
-
// - weights
|
721
|
-
//
|
722
|
-
// see the convert-pt-to-ggml.py script for details
|
723
|
-
//
|
724
|
-
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
|
725
|
-
fprintf(stderr, "%s: loading model\n", __func__);
|
726
|
-
|
727
|
-
const int64_t t_start_us = ggml_time_us();
|
728
|
-
|
729
|
-
wctx.t_start_us = t_start_us;
|
730
|
-
|
731
|
-
auto & model = wctx.model;
|
732
|
-
auto & vocab = wctx.vocab;
|
733
|
-
|
734
|
-
// verify magic
|
735
|
-
{
|
736
|
-
uint32_t magic;
|
737
|
-
read_safe(loader, magic);
|
738
|
-
if (magic != 0x67676d6c) {
|
739
|
-
fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
|
740
|
-
return false;
|
741
|
-
}
|
742
|
-
}
|
743
|
-
|
744
|
-
//load hparams
|
745
|
-
{
|
746
|
-
auto & hparams = model.hparams;
|
747
|
-
|
748
|
-
read_safe(loader, hparams.n_vocab);
|
749
|
-
read_safe(loader, hparams.n_audio_ctx);
|
750
|
-
read_safe(loader, hparams.n_audio_state);
|
751
|
-
read_safe(loader, hparams.n_audio_head);
|
752
|
-
read_safe(loader, hparams.n_audio_layer);
|
753
|
-
read_safe(loader, hparams.n_text_ctx);
|
754
|
-
read_safe(loader, hparams.n_text_state);
|
755
|
-
read_safe(loader, hparams.n_text_head);
|
756
|
-
read_safe(loader, hparams.n_text_layer);
|
757
|
-
read_safe(loader, hparams.n_mels);
|
758
|
-
read_safe(loader, hparams.f16);
|
759
|
-
|
760
|
-
assert(hparams.n_text_state == hparams.n_audio_state);
|
761
|
-
|
762
|
-
if (hparams.n_audio_layer == 4) {
|
763
|
-
model.type = e_model::MODEL_TINY;
|
764
|
-
}
|
765
|
-
|
766
|
-
if (hparams.n_audio_layer == 6) {
|
767
|
-
model.type = e_model::MODEL_BASE;
|
768
|
-
}
|
769
|
-
|
770
|
-
if (hparams.n_audio_layer == 12) {
|
771
|
-
model.type = e_model::MODEL_SMALL;
|
772
|
-
}
|
773
|
-
|
774
|
-
if (hparams.n_audio_layer == 24) {
|
775
|
-
model.type = e_model::MODEL_MEDIUM;
|
776
|
-
}
|
777
|
-
|
778
|
-
if (hparams.n_audio_layer == 32) {
|
779
|
-
model.type = e_model::MODEL_LARGE;
|
780
|
-
}
|
781
|
-
|
782
|
-
// for the big tensors, we have the option to store the data in 16-bit floats
|
783
|
-
// in order to save memory and also to speed up the computation
|
784
|
-
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
785
|
-
|
786
|
-
const size_t scale = model.hparams.f16 ? 1 : 2;
|
787
|
-
|
788
|
-
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
789
|
-
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
790
|
-
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
791
|
-
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
792
|
-
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
793
|
-
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
|
794
|
-
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
|
795
|
-
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
796
|
-
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
797
|
-
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
798
|
-
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
|
799
|
-
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
|
800
|
-
|
801
|
-
// print memory requirements
|
802
|
-
{
|
803
|
-
// this is the total memory required to run the inference
|
804
|
-
const size_t mem_required =
|
805
|
-
MEM_REQ_SCRATCH0.at (model.type) +
|
806
|
-
MEM_REQ_SCRATCH1.at (model.type) +
|
807
|
-
MEM_REQ_SCRATCH2.at (model.type) +
|
808
|
-
MEM_REQ_SCRATCH3.at (model.type) +
|
809
|
-
scale*MEM_REQ_MODEL.at (model.type) +
|
810
|
-
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
811
|
-
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
812
|
-
|
813
|
-
// this is the memory required by one decoder
|
814
|
-
const size_t mem_required_decoder =
|
815
|
-
scale*MEM_REQ_KV_SELF.at(model.type);
|
816
|
-
|
817
|
-
fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
818
|
-
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
819
|
-
}
|
820
|
-
|
821
|
-
// initialize all memory buffers
|
822
|
-
// always have at least one decoder
|
823
|
-
|
824
|
-
wctx.model.buf = new std::vector<uint8_t>();
|
825
|
-
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
826
|
-
|
827
|
-
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
|
828
|
-
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
829
|
-
return false;
|
830
|
-
}
|
831
|
-
|
832
|
-
{
|
833
|
-
const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
|
834
|
-
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
835
|
-
}
|
836
|
-
|
837
|
-
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
|
838
|
-
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
839
|
-
return false;
|
840
|
-
}
|
841
|
-
|
842
|
-
{
|
843
|
-
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
|
844
|
-
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
845
|
-
}
|
846
|
-
|
847
|
-
wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
|
848
|
-
|
849
|
-
wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
|
850
|
-
wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
|
851
|
-
wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
|
852
|
-
wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
|
853
|
-
}
|
854
|
-
|
855
|
-
// load mel filters
|
856
|
-
{
|
857
|
-
auto & filters = wctx.model.filters;
|
858
|
-
|
859
|
-
read_safe(loader, filters.n_mel);
|
860
|
-
read_safe(loader, filters.n_fft);
|
861
|
-
|
862
|
-
filters.data.resize(filters.n_mel * filters.n_fft);
|
863
|
-
loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
|
864
|
-
BYTESWAP_FILTERS(filters);
|
865
|
-
}
|
866
|
-
|
867
|
-
// load vocab
|
868
|
-
{
|
869
|
-
int32_t n_vocab = 0;
|
870
|
-
read_safe(loader, n_vocab);
|
871
|
-
|
872
|
-
//if (n_vocab != model.hparams.n_vocab) {
|
873
|
-
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
874
|
-
// __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
875
|
-
// return false;
|
876
|
-
//}
|
877
|
-
|
878
|
-
std::string word;
|
879
|
-
std::vector<char> tmp;
|
880
|
-
|
881
|
-
tmp.reserve(128);
|
882
|
-
|
883
|
-
for (int i = 0; i < n_vocab; i++) {
|
884
|
-
uint32_t len;
|
885
|
-
read_safe(loader, len);
|
886
|
-
|
887
|
-
if (len > 0) {
|
888
|
-
tmp.resize(len);
|
889
|
-
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
|
890
|
-
word.assign(&tmp[0], tmp.size());
|
891
|
-
} else {
|
892
|
-
// seems like we have an empty-string token in multi-language models (i = 50256)
|
893
|
-
//fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
894
|
-
word = "";
|
895
|
-
}
|
896
|
-
|
897
|
-
vocab.token_to_id[word] = i;
|
898
|
-
vocab.id_to_token[i] = word;
|
899
|
-
|
900
|
-
//printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
|
901
|
-
}
|
902
|
-
|
903
|
-
vocab.n_vocab = model.hparams.n_vocab;
|
904
|
-
if (vocab.is_multilingual()) {
|
905
|
-
vocab.token_eot++;
|
906
|
-
vocab.token_sot++;
|
907
|
-
vocab.token_prev++;
|
908
|
-
vocab.token_solm++;
|
909
|
-
vocab.token_not++;
|
910
|
-
vocab.token_beg++;
|
911
|
-
}
|
912
|
-
|
913
|
-
if (n_vocab < model.hparams.n_vocab) {
|
914
|
-
fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
|
915
|
-
for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
|
916
|
-
if (i > vocab.token_beg) {
|
917
|
-
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
|
918
|
-
} else if (i == vocab.token_eot) {
|
919
|
-
word = "[_EOT_]";
|
920
|
-
} else if (i == vocab.token_sot) {
|
921
|
-
word = "[_SOT_]";
|
922
|
-
} else if (i == vocab.token_prev) {
|
923
|
-
word = "[_PREV_]";
|
924
|
-
} else if (i == vocab.token_not) {
|
925
|
-
word = "[_NOT_]";
|
926
|
-
} else if (i == vocab.token_beg) {
|
927
|
-
word = "[_BEG_]";
|
928
|
-
} else {
|
929
|
-
word = "[_extra_token_" + std::to_string(i) + "]";
|
930
|
-
}
|
931
|
-
vocab.token_to_id[word] = i;
|
932
|
-
vocab.id_to_token[i] = word;
|
933
|
-
}
|
934
|
-
}
|
935
|
-
|
936
|
-
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
937
|
-
|
938
|
-
wctx.logits_id.reserve(n_vocab);
|
939
|
-
|
940
|
-
// TAGS: WHISPER_DECODER_INIT
|
941
|
-
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
942
|
-
|
943
|
-
wctx.decoders[0].probs.reserve (vocab.n_vocab);
|
944
|
-
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
945
|
-
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
946
|
-
}
|
947
|
-
|
948
|
-
size_t ctx_size = 0;
|
949
|
-
|
950
|
-
const ggml_type wtype = wctx.wtype;
|
951
|
-
|
952
|
-
{
|
953
|
-
const auto & hparams = model.hparams;
|
954
|
-
|
955
|
-
const int n_vocab = hparams.n_vocab;
|
956
|
-
|
957
|
-
const int n_audio_ctx = hparams.n_audio_ctx;
|
958
|
-
const int n_audio_state = hparams.n_audio_state;
|
959
|
-
const int n_audio_layer = hparams.n_audio_layer;
|
960
|
-
|
961
|
-
const int n_text_ctx = hparams.n_text_ctx;
|
962
|
-
const int n_text_state = hparams.n_text_state;
|
963
|
-
const int n_text_layer = hparams.n_text_layer;
|
964
|
-
|
965
|
-
const int n_mels = hparams.n_mels;
|
966
|
-
|
967
|
-
// encoder
|
968
|
-
{
|
969
|
-
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
|
970
|
-
|
971
|
-
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
|
972
|
-
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
|
973
|
-
|
974
|
-
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
|
975
|
-
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
|
976
|
-
|
977
|
-
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
|
978
|
-
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
|
979
|
-
}
|
980
|
-
|
981
|
-
// decoder
|
982
|
-
{
|
983
|
-
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
|
984
|
-
|
985
|
-
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
|
986
|
-
|
987
|
-
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
|
988
|
-
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
|
989
|
-
}
|
990
|
-
|
991
|
-
// encoder layers
|
992
|
-
{
|
993
|
-
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
|
994
|
-
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
|
995
|
-
|
996
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
|
997
|
-
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
|
998
|
-
|
999
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
|
1000
|
-
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
|
1001
|
-
|
1002
|
-
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
|
1003
|
-
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
|
1004
|
-
|
1005
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
|
1006
|
-
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
|
1007
|
-
|
1008
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
|
1009
|
-
|
1010
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
|
1011
|
-
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
|
1012
|
-
|
1013
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
|
1014
|
-
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
|
1015
|
-
}
|
1016
|
-
|
1017
|
-
// decoder layers
|
1018
|
-
{
|
1019
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
|
1020
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
|
1021
|
-
|
1022
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
|
1023
|
-
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
|
1024
|
-
|
1025
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
|
1026
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
|
1027
|
-
|
1028
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
|
1029
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
|
1030
|
-
|
1031
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
|
1032
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
|
1033
|
-
|
1034
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
|
1035
|
-
|
1036
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
|
1037
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
|
1038
|
-
|
1039
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
|
1040
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
|
1041
|
-
//
|
1042
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
|
1043
|
-
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
|
1044
|
-
|
1045
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
|
1046
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
|
1047
|
-
|
1048
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
|
1049
|
-
|
1050
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
|
1051
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
|
1052
|
-
|
1053
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
|
1054
|
-
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
1055
|
-
}
|
1056
|
-
|
1057
|
-
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
1058
|
-
|
1059
|
-
fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
1060
|
-
}
|
1061
|
-
|
1062
|
-
// create the ggml context
|
1063
|
-
{
|
1064
|
-
struct ggml_init_params params;
|
1065
|
-
params.mem_size = wctx.model.buf->size();
|
1066
|
-
params.mem_buffer = wctx.model.buf->data();
|
1067
|
-
|
1068
|
-
model.ctx = ggml_init(params);
|
1069
|
-
if (!model.ctx) {
|
1070
|
-
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
1071
|
-
return false;
|
1072
|
-
}
|
1073
|
-
}
|
1074
|
-
|
1075
|
-
// prepare memory for the weights
|
1076
|
-
{
|
1077
|
-
auto & ctx = model.ctx;
|
1078
|
-
|
1079
|
-
const auto & hparams = model.hparams;
|
1080
|
-
|
1081
|
-
const int n_vocab = hparams.n_vocab;
|
1082
|
-
|
1083
|
-
const int n_audio_ctx = hparams.n_audio_ctx;
|
1084
|
-
const int n_audio_state = hparams.n_audio_state;
|
1085
|
-
const int n_audio_layer = hparams.n_audio_layer;
|
1086
|
-
|
1087
|
-
const int n_text_ctx = hparams.n_text_ctx;
|
1088
|
-
const int n_text_state = hparams.n_text_state;
|
1089
|
-
const int n_text_layer = hparams.n_text_layer;
|
1090
|
-
|
1091
|
-
const int n_mels = hparams.n_mels;
|
1092
|
-
|
1093
|
-
model.layers_encoder.resize(n_audio_layer);
|
1094
|
-
model.layers_decoder.resize(n_text_layer);
|
1095
|
-
|
1096
|
-
// encoder
|
1097
|
-
{
|
1098
|
-
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
1099
|
-
|
1100
|
-
model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
|
1101
|
-
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1102
|
-
|
1103
|
-
model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
|
1104
|
-
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1105
|
-
|
1106
|
-
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1107
|
-
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1108
|
-
|
1109
|
-
// map by name
|
1110
|
-
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
1111
|
-
|
1112
|
-
model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
|
1113
|
-
model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
|
1114
|
-
|
1115
|
-
model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
|
1116
|
-
model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
|
1117
|
-
|
1118
|
-
model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
|
1119
|
-
model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
|
1120
|
-
|
1121
|
-
for (int i = 0; i < n_audio_layer; ++i) {
|
1122
|
-
auto & layer = model.layers_encoder[i];
|
1123
|
-
|
1124
|
-
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1125
|
-
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1126
|
-
|
1127
|
-
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
|
1128
|
-
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
|
1129
|
-
|
1130
|
-
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
|
1131
|
-
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1132
|
-
|
1133
|
-
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1134
|
-
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1135
|
-
|
1136
|
-
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
1137
|
-
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1138
|
-
|
1139
|
-
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
1140
|
-
|
1141
|
-
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
1142
|
-
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1143
|
-
|
1144
|
-
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
1145
|
-
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1146
|
-
|
1147
|
-
// map by name
|
1148
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
1149
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
1150
|
-
|
1151
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
1152
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
1153
|
-
|
1154
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
1155
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
1156
|
-
|
1157
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
1158
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
1159
|
-
|
1160
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
1161
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
1162
|
-
|
1163
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
1164
|
-
|
1165
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
1166
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
1167
|
-
|
1168
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
1169
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
1170
|
-
}
|
1171
|
-
}
|
1172
|
-
|
1173
|
-
// decoder
|
1174
|
-
{
|
1175
|
-
model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
|
1176
|
-
|
1177
|
-
model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
|
1178
|
-
|
1179
|
-
model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1180
|
-
model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1181
|
-
|
1182
|
-
// map by name
|
1183
|
-
model.tensors["decoder.positional_embedding"] = model.d_pe;
|
1184
|
-
|
1185
|
-
model.tensors["decoder.token_embedding.weight"] = model.d_te;
|
1186
|
-
|
1187
|
-
model.tensors["decoder.ln.weight"] = model.d_ln_w;
|
1188
|
-
model.tensors["decoder.ln.bias"] = model.d_ln_b;
|
1189
|
-
|
1190
|
-
for (int i = 0; i < n_text_layer; ++i) {
|
1191
|
-
auto & layer = model.layers_decoder[i];
|
1192
|
-
|
1193
|
-
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1194
|
-
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1195
|
-
|
1196
|
-
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
|
1197
|
-
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
|
1198
|
-
|
1199
|
-
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
|
1200
|
-
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1201
|
-
|
1202
|
-
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1203
|
-
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1204
|
-
|
1205
|
-
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1206
|
-
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1207
|
-
|
1208
|
-
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1209
|
-
|
1210
|
-
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1211
|
-
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1212
|
-
|
1213
|
-
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1214
|
-
                layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);

                layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
                layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);

                layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
                layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);

                layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);

                layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
                layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);

                layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
                layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);

                // map by name
                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;

                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;

                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;

                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
            }
        }
    }

    // load weights
    {
        size_t total_size = 0;

        model.n_loaded = 0;

        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ftype;

            read_safe(loader, n_dims);
            read_safe(loader, length);
            read_safe(loader, ftype);

            if (loader->eof(loader->context)) {
                break;
            }

            int32_t nelements = 1;
            int32_t ne[3] = { 1, 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                read_safe(loader, ne[i]);
                nelements *= ne[i];
            }

            std::string name;
            std::vector<char> tmp(length); // create a buffer
            loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
            name.assign(&tmp[0], tmp.size());

            if (model.tensors.find(name) == model.tensors.end()) {
                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
                return false;
            }

            auto tensor = model.tensors[name.data()];
            if (ggml_nelements(tensor) != nelements) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
                return false;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                        __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
                return false;
            }

            const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);

            if (nelements*bpe != ggml_nbytes(tensor)) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                return false;
            }

            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
            BYTESWAP_TENSOR(tensor);

            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
            total_size += ggml_nbytes(tensor);
            model.n_loaded++;
        }

        fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);

        if (model.n_loaded == 0) {
            fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
        } else if (model.n_loaded != (int) model.tensors.size()) {
            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
            return false;
        }
    }

    wctx.rng = std::mt19937(0);

    wctx.t_load_us = ggml_time_us() - t_start_us;

    return true;
}
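
// For reference, the tensor records consumed by the "load weights" loop above
// follow this layout (reconstructed from the reads in the loop itself; the
// field names here are descriptive only, not part of any written spec):
//
//   int32 n_dims                  // number of dimensions (1..3)
//   int32 length                  // length of the tensor name in bytes
//   int32 ftype                   // 0 = f32, otherwise f16
//   int32 ne[n_dims]              // size of each dimension
//   char  name[length]            // tensor name, not null-terminated
//   byte  data[ggml_nbytes(...)]  // raw tensor data
//
// Records are read back-to-back until the loader reports EOF.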

// evaluate the encoder
//
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
// part of the transformer model and returns the encoded features
//
//   - model:      the model
//   - n_threads:  number of threads to use
//   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
static bool whisper_encode(
        whisper_context & wctx,
        const int mel_offset,
        const int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    const auto & model = wctx.model;
    const auto & mel_inp = wctx.mel;
    const auto & hparams = model.hparams;

    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
    const int n_state = hparams.n_audio_state;
    const int n_head  = hparams.n_audio_head;
    const int n_layer = hparams.n_audio_layer;

    const int n_mels = hparams.n_mels;
    assert(mel_inp.n_mel == n_mels);

    struct ggml_init_params params;
    params.mem_size   = wctx.buf_compute.size();
    params.mem_buffer = wctx.buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

    wctx.use_buf(ctx0, 0);

    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
    assert(mel->type == GGML_TYPE_F32);
    {
        float * dst = (float *) mel->data;
        memset(dst, 0, ggml_nbytes(mel));

        const int i0 = std::min(mel_offset, mel_inp.n_len);
        const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);

        for (int j = 0; j < mel_inp.n_mel; ++j) {
            for (int i = i0; i < i1; ++i) {
                dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
            }
        }
    }

    struct ggml_tensor * cur;

    // convolution + gelu
    {
        wctx.use_buf(ctx0, 1);

        cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
        cur = ggml_add(ctx0,
                ggml_repeat(ctx0,
                    model.e_conv_1_b,
                    cur),
                cur);

        cur = ggml_gelu(ctx0, cur);

        wctx.use_buf(ctx0, 0);

        cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
        cur = ggml_add(ctx0,
                ggml_repeat(ctx0,
                    model.e_conv_2_b,
                    cur),
                cur);

        cur = ggml_gelu(ctx0, cur);
    }

    wctx.use_buf(ctx0, 3);

    // ===================================================================
    // NOTE: experimenting with partial evaluation of the encoder (ignore)
    //static int iter = -1;
    //const int n_iter = 1500/n_ctx;

    //iter = (iter + 1) % n_iter;

    //if (iter == 0) {
    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
    //}

    static int iter = 0;

    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;

    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);

    cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));

    // ===================================================================

    // original:
    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));

    struct ggml_tensor * inpL = cur;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers_encoder[il];

        // norm
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_norm(ctx0, inpL);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctx0,
                    ggml_mul(ctx0,
                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
        }

        // self-attention
        {
            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                    layer.attn_q_w,
                    cur);

            Qcur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.attn_q_b,
                        Qcur),
                    Qcur);

            //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            // note: no bias for Key
            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                    layer.attn_k_w,
                    cur);

            //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                    layer.attn_v_w,
                    cur);

            Vcur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.attn_v_b,
                        Vcur),
                    Vcur);

            // ------

            wctx.use_buf(ctx0, 0);

#ifdef WHISPER_USE_FLASH_ATTN
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Qcur,
                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Kcur,
                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                        0, 2, 1, 3);

            struct ggml_tensor * V =
                ggml_cpy(ctx0,
                        ggml_permute(ctx0,
                            ggml_reshape_3d(ctx0,
                                Vcur,
                                n_state/n_head, n_head, n_ctx),
                            1, 2, 0, 3),
                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
                        );

            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Qcur,
                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Kcur,
                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
                        );

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);

            //struct ggml_tensor * V_trans =
            //    ggml_permute(ctx0,
            //            ggml_cpy(ctx0,
            //                Vcur,
            //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
            //            1, 2, 0, 3);

            //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            struct ggml_tensor * V =
                ggml_cpy(ctx0,
                        ggml_permute(ctx0,
                            ggml_reshape_3d(ctx0,
                                Vcur,
                                n_state/n_head, n_head, n_ctx),
                            0, 2, 1, 3),
                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
                        );

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
#endif
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            wctx.use_buf(ctx0, 1);

            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
        }

        // projection
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_mul_mat(ctx0,
                    layer.attn_ln_1_w,
                    cur);

            wctx.use_buf(ctx0, 1);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                    cur);
        }

        wctx.use_buf(ctx0, 2);

        // add the input
        cur = ggml_add(ctx0, cur, inpL);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                wctx.use_buf(ctx0, 0);

                cur = ggml_norm(ctx0, inpFF);

                wctx.use_buf(ctx0, 1);

                // cur = mlp_ln_w*cur + mlp_ln_b
                cur = ggml_add(ctx0,
                        ggml_mul(ctx0,
                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                            cur),
                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
            }

#ifdef WHISPER_USE_FLASH_FF
            wctx.use_buf(ctx0, 0);

            cur = ggml_flash_ff(ctx0,
                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
            wctx.use_buf(ctx0, 0);

            // fully connected
            cur = ggml_mul_mat(ctx0,
                    layer.mlp_0_w,
                    cur);

            wctx.use_buf(ctx0, 1);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
                    cur);

            wctx.use_buf(ctx0, 0);

            // GELU activation
            cur = ggml_gelu(ctx0, cur);

            wctx.use_buf(ctx0, 1);

            // projection
            cur = ggml_mul_mat(ctx0,
                    layer.mlp_1_w,
                    cur);

            wctx.use_buf(ctx0, 0);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
                    cur);
#endif
        }

        wctx.use_buf(ctx0, 3);

        inpL = ggml_add(ctx0, cur, inpFF);
    }

    cur = inpL;

    // norm
    {
        wctx.use_buf(ctx0, 0);

        cur = ggml_norm(ctx0, cur);

        wctx.use_buf(ctx0, 1);

        // cur = ln_f_g*cur + ln_f_b
        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.e_ln_w, cur),
                    cur),
                ggml_repeat(ctx0, model.e_ln_b, cur));
    }

    wctx.use_buf(ctx0, -1);

    // run the computation
    {
        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        ggml_build_forward_expand(&gf, cur);
        ggml_graph_compute       (ctx0, &gf);

        //ggml_graph_print(&gf);
    }

    // cur
    //{
    //    printf("ne0 = %d\n", cur->ne[0]);
    //    printf("ne1 = %d\n", cur->ne[1]);
    //    for (int i = 0; i < 10; ++i) {
    //        printf("%8.4f ", ((float *)(cur->data))[i]);
    //    }
    //    printf("... ");
    //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
    //        printf("%8.4f ", ((float *)(cur->data))[i]);
    //    }
    //    printf("\n");
    //}

    // pre-compute cross-attention memory
    {
        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        // TODO: hack to disconnect the encoded features from the previous graph
        cur->op   = GGML_OP_NONE;
        cur->src0 = nullptr;
        cur->src1 = nullptr;

        for (int il = 0; il < model.hparams.n_text_layer; ++il) {
            auto & layer = model.layers_decoder[il];

            wctx.use_buf(ctx0, 0);

            struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                    layer.cross_attn_k_w,
                    cur);

            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                    layer.cross_attn_v_w,
                    cur);

            Vcross = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.cross_attn_v_b,
                        Vcross),
                    Vcross);

            wctx.use_buf(ctx0, -1);

            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));

            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
        }

        ggml_graph_compute(ctx0, &gf);
        //ggml_graph_print(&gf);
    }

    ////////////////////////////////////////////////////////////////////////////

    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
    //        ggml_used_mem(ctx0)/1024.0/1024.0,
    //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
    //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
    //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
    //        wctx.get_buf_max_mem(3)/1024.0/1024.0);

    ggml_free(ctx0);

    wctx.t_encode_us += ggml_time_us() - t_start_us;
    wctx.n_encode++;

    return true;
}
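
// Usage sketch (illustrative only): the static whisper_encode above is reached
// through the public wrappers defined later in this file. A typical call
// sequence, assuming `ctx` was created with whisper_init_from_file() and
// `pcm` holds 16 kHz mono float samples:
//
//   whisper_pcm_to_mel(ctx, pcm.data(), pcm.size(), 4); // PCM -> log mel
//   whisper_encode    (ctx, 0, 4);                      // encode from offset 0
//
// Note that the second argument of whisper_encode is an offset into the mel
// spectrogram (i.e. an audio position), not a sample offset.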

// evaluate the decoder
//
// given text prompt + audio features -> computes the logits for the next token
//
//   - model:     the model
//   - n_threads: number of threads to use
//   - tokens:    text prompt
//   - n_tokens:  number of tokens in the prompt
//   - n_past:    number of past tokens to prefix the prompt with
//
static bool whisper_decode(
        whisper_context & wctx,
        whisper_decoder & decoder,
        const whisper_token * tokens,
        const int n_tokens,
        const int n_past,
        const int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    const auto & model   = wctx.model;
    const auto & hparams = model.hparams;

    auto & kv_self = decoder.kv_self;

    WHISPER_ASSERT(!!kv_self.ctx);

    auto & logits_out = wctx.logits;

    const int n_vocab = hparams.n_vocab;

    const int n_ctx   = hparams.n_text_ctx;
    const int n_state = hparams.n_text_state;
    const int n_head  = hparams.n_text_head;
    const int n_layer = hparams.n_text_layer;

    const int N = n_tokens;
    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;

    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);

    struct ggml_init_params params;
    params.mem_size   = wctx.buf_compute.size();
    params.mem_buffer = wctx.buf_compute.data();

    struct ggml_context * ctx0 = ggml_init(params);

    struct ggml_cgraph gf = {};
    gf.n_threads = n_threads;

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, tokens, N*ggml_element_size(embd));

    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    for (int i = 0; i < N; ++i) {
        ((int32_t *) position->data)[i] = n_past + i;
    }

    wctx.use_buf(ctx0, 3);

    // token encoding + position encoding
    struct ggml_tensor * cur =
        ggml_add(ctx0,
                ggml_get_rows(ctx0, model.d_te, embd),
                ggml_get_rows(ctx0, model.d_pe, position));

    struct ggml_tensor * inpL = cur;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers_decoder[il];

        // norm
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_norm(ctx0, inpL);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctx0,
                    ggml_mul(ctx0,
                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
        }

        // self-attention
        {
            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                    layer.attn_q_w,
                    cur);

            Qcur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.attn_q_b,
                        Qcur),
                    Qcur);

            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            // note: no bias for Key
            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                    layer.attn_k_w,
                    cur);

            Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                    layer.attn_v_w,
                    cur);

            Vcur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.attn_v_b,
                        Vcur),
                    Vcur);

            // store key and value to memory
            {
                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }

            // ------

            wctx.use_buf(ctx0, 0);

            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Qcur,
                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
                            n_state/n_head, n_head, n_past + N),
                        0, 2, 1, 3);

            wctx.use_buf(ctx0, 1);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            wctx.use_buf(ctx0, 0);

            //struct ggml_tensor * KQ_scaled =
            //    ggml_scale(ctx0,
            //            KQ,
            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
            //            );

            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            wctx.use_buf(ctx0, 0);

            struct ggml_tensor * V_trans =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
                            n_state/n_head, n_head, n_past + N),
                        1, 2, 0, 3);

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
        }

        // projection
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_mul_mat(ctx0,
                    layer.attn_ln_1_w,
                    cur);

            wctx.use_buf(ctx0, 1);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                    cur);
        }

        wctx.use_buf(ctx0, 2);

        // add the input
        struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);

        // norm
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here

            wctx.use_buf(ctx0, 1);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctx0,
                    ggml_mul(ctx0,
                        ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
        }

        // cross-attention
        {
            wctx.use_buf(ctx0, 0);

            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                    layer.cross_attn_q_w,
                    cur);

            Qcur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.cross_attn_q_b,
                        Qcur),
                    Qcur);

            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            // Kcross is already scaled
            struct ggml_tensor * Kcross =
                ggml_reshape_3d(ctx0,
                        ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
                        n_state/n_head, n_head, M);

            struct ggml_tensor * Vcross =
                ggml_reshape_3d(ctx0,
                        ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
                        n_state/n_head, n_head, M);

            struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);

            // ------

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Qcur,
                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);

            wctx.use_buf(ctx0, 0);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            //struct ggml_tensor * KQ_scaled =
            //    ggml_scale(ctx0,
            //            KQ,
            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
            //            );

            // no masking for cross-attention
            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

            wctx.use_buf(ctx0, 0);

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            wctx.use_buf(ctx0, 1);

            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_state, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
        }

        // projection
        {
            wctx.use_buf(ctx0, 0);

            cur = ggml_mul_mat(ctx0,
                    layer.cross_attn_ln_1_w,
                    cur);

            wctx.use_buf(ctx0, 1);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                    cur);
        }

        wctx.use_buf(ctx0, 2);

        // add the input
        cur = ggml_add(ctx0, cur, inpCA);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                wctx.use_buf(ctx0, 0);

                cur = ggml_norm(ctx0, inpFF);

                wctx.use_buf(ctx0, 1);

                // cur = mlp_ln_w*cur + mlp_ln_b
                cur = ggml_add(ctx0,
                        ggml_mul(ctx0,
                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                            cur),
                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
            }

            wctx.use_buf(ctx0, 0);

            // fully connected
            cur = ggml_mul_mat(ctx0,
                    layer.mlp_0_w,
                    cur);

            wctx.use_buf(ctx0, 1);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
                    cur);

            wctx.use_buf(ctx0, 0);

            // GELU activation
            cur = ggml_gelu(ctx0, cur);

            wctx.use_buf(ctx0, 1);

            // projection
            cur = ggml_mul_mat(ctx0,
                    layer.mlp_1_w,
                    cur);

            wctx.use_buf(ctx0, 0);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
                    cur);
        }

        wctx.use_buf(ctx0, 3);

        inpL = ggml_add(ctx0, cur, inpFF);
    }

    cur = inpL;

    // norm
    {
        wctx.use_buf(ctx0, 0);

        cur = ggml_norm(ctx0, cur);

        wctx.use_buf(ctx0, 1);

        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.d_ln_w, cur),
                    cur),
                ggml_repeat(ctx0, model.d_ln_b, cur));
    }

    wctx.use_buf(ctx0, 0);

    // compute logits only for the last token
    // comment this line to compute logits for all N tokens
    // might be useful in the future
    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);

    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

    wctx.use_buf(ctx0, -1);

    // run the computation
    {
        ggml_build_forward_expand(&gf, logits);
        ggml_graph_compute       (ctx0, &gf);
    }

    // extract logits for all N tokens
    //logits_out.resize(N*n_vocab);
    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);

    // extract logits only for the last token
    logits_out.resize(n_vocab);
    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);

    if (N > 1) {
        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
        //        ggml_used_mem(ctx0)/1024.0/1024.0,
        //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
        //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
        //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
        //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
    }

    ggml_free(ctx0);

    wctx.t_decode_us += ggml_time_us() - t_start_us;
    wctx.n_decode++;

    return true;
}
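
// Usage sketch (illustrative only): a minimal greedy loop over the public
// wrapper, assuming the audio has already been encoded as shown above:
//
//   std::vector<whisper_token> toks = { whisper_token_sot(ctx) };
//   int n_past = 0;
//   whisper_decode(ctx, toks.data(), toks.size(), n_past, 4);
//   const float * logits = whisper_get_logits(ctx); // n_vocab values for the last token
//   // pick the argmax over whisper_n_vocab(ctx) entries, append the token,
//   // advance n_past by the number of tokens fed, and repeat until EOT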

//  500 -> 00:00:05.000
// 6000 -> 00:01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
    int64_t msec = t * 10;
    int64_t hr = msec / (1000 * 60 * 60);
    msec = msec - hr * (1000 * 60 * 60);
    int64_t min = msec / (1000 * 60);
    msec = msec - min * (1000 * 60);
    int64_t sec = msec / 1000;
    msec = msec - sec * 1000;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);

    return std::string(buf);
}
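
// Worked examples: t is in units of 10 ms, so to_timestamp(500) yields
// "00:00:05.000" and to_timestamp(6000) yields "00:01:00.000". Passing
// comma = true swaps the decimal separator ("00:00:05,000"), which matches
// the SRT subtitle convention.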

// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const std::vector<float> & in, std::vector<float> & out) {
    int N = in.size();

    out.resize(N*2);

    for (int k = 0; k < N; k++) {
        float re = 0;
        float im = 0;

        for (int n = 0; n < N; n++) {
            float angle = 2*M_PI*k*n/N;
            re += in[n]*cos(angle);
            im -= in[n]*sin(angle);
        }

        out[k*2 + 0] = re;
        out[k*2 + 1] = im;
    }
}

// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(const std::vector<float> & in, std::vector<float> & out) {
    out.resize(in.size()*2);

    int N = in.size();

    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    if (N%2 == 1) {
        dft(in, out);
        return;
    }

    std::vector<float> even;
    std::vector<float> odd;

    even.reserve(N/2);
    odd.reserve(N/2);

    for (int i = 0; i < N; i++) {
        if (i % 2 == 0) {
            even.push_back(in[i]);
        } else {
            odd.push_back(in[i]);
        }
    }

    std::vector<float> even_fft;
    std::vector<float> odd_fft;

    fft(even, even_fft);
    fft(odd, odd_fft);

    for (int k = 0; k < N/2; k++) {
        float theta = 2*M_PI*k/N;

        float re = cos(theta);
        float im = -sin(theta);

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
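
// Quick sanity check (illustrative): an impulse transforms to a flat spectrum.
//
//   std::vector<float> in = { 1, 0, 0, 0 };
//   std::vector<float> out;
//   fft(in, out); // out = { 1,0, 1,0, 1,0, 1,0 } (re,im interleaved)
//
// Only power-of-two sizes take the fast recursive path; any odd-sized
// sub-problem falls back to the O(N^2) dft() above.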

// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram(
        whisper_context & wctx,
        const float * samples,
        const int n_samples,
        const int /*sample_rate*/,
        const int fft_size,
        const int fft_step,
        const int n_mel,
        const int n_threads,
        const whisper_filters & filters,
        const bool speed_up,
        whisper_mel & mel) {
    const int64_t t_start_us = ggml_time_us();

    // Hanning window
    std::vector<float> hann;
    hann.resize(fft_size);
    for (int i = 0; i < fft_size; i++) {
        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
    }

    mel.n_mel = n_mel;
    mel.n_len = (n_samples)/fft_step;
    mel.data.resize(mel.n_mel*mel.n_len);

    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);

    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);

    std::vector<std::thread> workers(n_threads);
    for (int iw = 0; iw < n_threads; ++iw) {
        workers[iw] = std::thread([&](int ith) {
            std::vector<float> fft_in;
            fft_in.resize(fft_size);
            for (int i = 0; i < fft_size; i++) {
                fft_in[i] = 0.0;
            }

            std::vector<float> fft_out;
            fft_out.resize(2*fft_size);

            for (int i = ith; i < mel.n_len; i += n_threads) {
                const int offset = i*fft_step;

                // apply Hanning window
                for (int j = 0; j < fft_size; j++) {
                    if (offset + j < n_samples) {
                        fft_in[j] = hann[j]*samples[offset + j];
                    } else {
                        fft_in[j] = 0.0;
                    }
                }

                // FFT -> mag^2
                fft(fft_in, fft_out);

                for (int j = 0; j < fft_size; j++) {
                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
                }
                for (int j = 1; j < fft_size/2; j++) {
                    //if (i == 0) {
                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
                    //}
                    fft_out[j] += fft_out[fft_size - j];
                }
                if (i == 0) {
                    //for (int j = 0; j < fft_size; j++) {
                    //    printf("%d: %e\n", j, fft_out[j]);
                    //}
                }

                if (speed_up) {
                    // scale down in the frequency domain results in a speed up in the time domain
                    for (int j = 0; j < n_fft; j++) {
                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
                    }
                }

                // mel spectrogram
                for (int j = 0; j < mel.n_mel; j++) {
                    double sum = 0.0;

                    for (int k = 0; k < n_fft; k++) {
                        sum += fft_out[k]*filters.data[j*n_fft + k];
                    }
                    if (sum < 1e-10) {
                        sum = 1e-10;
                    }

                    sum = log10(sum);

                    mel.data[j*mel.n_len + i] = sum;
                }
            }
        }, iw);
    }

    for (int iw = 0; iw < n_threads; ++iw) {
        workers[iw].join();
    }

    // clamping and normalization
    double mmax = -1e20;
    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] > mmax) {
            mmax = mel.data[i];
        }
    }
    //printf("%s: max = %f\n", __func__, mmax);

    mmax -= 8.0;

    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] < mmax) {
            mel.data[i] = mmax;
        }

        mel.data[i] = (mel.data[i] + 4.0)/4.0;
    }

    wctx.t_mel_us += ggml_time_us() - t_start_us;

    return true;
}

// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        std::string str = text;
        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";

        std::regex re(pat);
        std::smatch m;

        while (std::regex_search(str, m, re)) {
            for (auto x : m) {
                words.push_back(x);
            }
            str = m.suffix();
        }
    }

    // find the longest tokens that form the words:
    std::vector<whisper_vocab::id> tokens;
    for (const auto & word : words) {
        if (word.empty()) continue;

        int i = 0;
        int n = word.size();
        while (i < n) {
            int j = n;
            while (j > i) {
                auto it = vocab.token_to_id.find(word.substr(i, j-i));
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    i = j;
                    break;
                }
                --j;
            }
            if (i == n) {
                break;
            }
            if (j == i) {
                auto sub = word.substr(i, 1);
                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
                    tokens.push_back(vocab.token_to_id.at(sub));
                } else {
                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
                }
                ++i;
            }
        }
    }

    return tokens;
}
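
// Example of the greedy longest-match-first behavior above (illustrative;
// the actual ids depend on the loaded vocabulary): for the word " hello",
// the inner loop first probes " hello" as a whole, then " hell", " hel",
// and so on until a vocabulary hit is found, emits that id, and restarts
// from the remaining suffix. A single character that is missing from the
// vocabulary is skipped with a warning rather than aborting the whole text.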

//
// interface implementation
//

struct whisper_context * whisper_init_from_file(const char * path_model) {
    whisper_model_loader loader = {};

    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);

    auto fin = std::ifstream(path_model, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
        return nullptr;
    }

    loader.context = &fin;
    loader.read = [](void * ctx, void * output, size_t read_size) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->read((char *)output, read_size);
        return read_size;
    };

    loader.eof = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        return fin->eof();
    };

    loader.close = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->close();
    };

    return whisper_init(&loader);
}

struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
    struct buf_context {
        uint8_t* buffer;
        size_t size;
        size_t current_offset;
    };

    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
    whisper_model_loader loader = {};

    fprintf(stderr, "%s: loading model from buffer\n", __func__);

    loader.context = &ctx;

    loader.read = [](void * ctx, void * output, size_t read_size) {
        buf_context * buf = reinterpret_cast<buf_context *>(ctx);

        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;

        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
        buf->current_offset += size_to_copy;

        return size_to_copy;
    };

    loader.eof = [](void * ctx) {
        buf_context * buf = reinterpret_cast<buf_context *>(ctx);

        return buf->current_offset >= buf->size;
    };

    loader.close = [](void * /*ctx*/) { };

    return whisper_init(&loader);
}
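
// Both constructors above are thin adapters over the same whisper_model_loader
// callback interface (read/eof/close), so a custom source - e.g. a memory-mapped
// region or a download stream - only needs to fill in those three callbacks plus
// a context pointer and call whisper_init(). The sketch below is illustrative;
// `my_stream_t` and its functions are hypothetical:
//
//   whisper_model_loader loader = {};
//   loader.context = my_stream; // my_stream_t *
//   loader.read    = [](void * ctx, void * dst, size_t n) {
//       return my_stream_read((my_stream_t *) ctx, dst, n); // bytes copied
//   };
//   loader.eof     = [](void * ctx) { return my_stream_eof((my_stream_t *) ctx); };
//   loader.close   = [](void * ctx) { my_stream_close((my_stream_t *) ctx); };
//   whisper_context * wctx = whisper_init(&loader);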

struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
    ggml_time_init();

    whisper_context * ctx = new whisper_context;

    if (!whisper_model_load(loader, *ctx)) {
        loader->close(loader->context);
        fprintf(stderr, "%s: failed to load model\n", __func__);
        delete ctx;
        return nullptr;
    }

    loader->close(loader->context);

    return ctx;
}

void whisper_free(struct whisper_context * ctx) {
    if (ctx) {
        if (ctx->model.ctx) {
            ggml_free(ctx->model.ctx);
        }
        if (ctx->model.buf) {
            delete ctx->model.buf;
        }
        if (ctx->kv_cross.ctx) {
            ggml_free(ctx->kv_cross.ctx);
        }
        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
            if (ctx->decoders[i].kv_self.ctx) {
                ggml_free(ctx->decoders[i].kv_self.ctx);
            }
        }
        delete ctx;
    }
}

int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
        return -1;
    }

    return 0;
}

// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
        return -1;
    }

    return 0;
}

int whisper_set_mel(
        struct whisper_context * ctx,
        const float * data,
        int n_len,
        int n_mel) {
    if (n_mel != WHISPER_N_MEL) {
        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
        return -1;
    }

    ctx->mel.n_len = n_len;
    ctx->mel.n_mel = n_mel;

    ctx->mel.data.resize(n_len*n_mel);
    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));

    return 0;
}

int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
    if (!whisper_encode(*ctx, offset, n_threads)) {
        fprintf(stderr, "%s: failed to eval\n", __func__);
        return -1;
    }

    return 0;
}

int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
    // TODO: add selected_decoder_id to context
    const int selected_decoder_id = 0;

    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
        fprintf(stderr, "%s: failed to eval\n", __func__);
        return 1;
    }

    return 0;
}

int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
    const auto res = tokenize(ctx->vocab, text);

    if (n_max_tokens < (int) res.size()) {
        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
        return -1;
    }

    for (int i = 0; i < (int) res.size(); i++) {
        tokens[i] = res[i];
    }

    return res.size();
}

int whisper_lang_max_id() {
    auto max_id = 0;
    for (const auto & kv : g_lang) {
        max_id = std::max(max_id, kv.second.first);
    }

    return max_id;
}

int whisper_lang_id(const char * lang) {
    if (!g_lang.count(lang)) {
        for (const auto & kv : g_lang) {
            if (kv.second.second == lang) {
                return kv.second.first;
            }
        }

        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
        return -1;
    }

    return g_lang.at(lang).first;
}

const char * whisper_lang_str(int id) {
    for (const auto & kv : g_lang) {
        if (kv.second.first == id) {
            return kv.first.c_str();
        }
    }

    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
    return nullptr;
}

int whisper_lang_auto_detect(
        struct whisper_context * ctx,
        int offset_ms,
        int n_threads,
        float * lang_probs) {
    const int seek = offset_ms/10;

    if (seek < 0) {
        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
        return -1;
    }

    if (seek >= ctx->mel.n_len) {
        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
        return -2;
    }

    // run the encoder
    if (whisper_encode(ctx, seek, n_threads) != 0) {
        fprintf(stderr, "%s: failed to encode\n", __func__);
        return -6;
    }

    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };

    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
        fprintf(stderr, "%s: failed to decode\n", __func__);
        return -7;
    }

    auto & logits_id = ctx->logits_id;
    logits_id.clear();

    for (const auto & kv : g_lang) {
        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
    }

    // sort descending
    {
        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
        std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
            return a.first > b.first;
        });
    }

    // softmax
    {
        const auto max = logits_id[0].first;

        double sum = 0.0f;
        for (auto & kv : logits_id) {
            kv.first = exp(kv.first - max);
            sum += kv.first;
        }

        for (auto & kv : logits_id) {
            kv.first /= sum;
        }
    }

    {
        for (const auto & prob : logits_id) {
            if (lang_probs) {
                lang_probs[prob.second] = prob.first;
            }

            //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
        }
    }

    return logits_id[0].second;
}
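
// Usage sketch (illustrative): the probability buffer, when provided, is
// indexed by language id, so it should hold at least whisper_lang_max_id() + 1
// entries:
//
//   std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
//   const int lang_id = whisper_lang_auto_detect(ctx, 0, 4, probs.data());
//   if (lang_id >= 0) {
//       printf("detected '%s' (p = %.3f)\n", whisper_lang_str(lang_id), probs[lang_id]);
//   }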

int whisper_n_len(struct whisper_context * ctx) {
    return ctx->mel.n_len;
}

int whisper_n_vocab(struct whisper_context * ctx) {
    return ctx->vocab.n_vocab;
}

int whisper_n_text_ctx(struct whisper_context * ctx) {
    return ctx->model.hparams.n_text_ctx;
}

int whisper_n_audio_ctx(struct whisper_context * ctx) {
    return ctx->model.hparams.n_audio_ctx;
}

int whisper_is_multilingual(struct whisper_context * ctx) {
    return ctx->vocab.is_multilingual() ? 1 : 0;
}

float * whisper_get_logits(struct whisper_context * ctx) {
    return ctx->logits.data();
}

const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
    return ctx->vocab.id_to_token.at(token).c_str();
}

whisper_token whisper_token_eot(struct whisper_context * ctx) {
    return ctx->vocab.token_eot;
}

whisper_token whisper_token_sot(struct whisper_context * ctx) {
    return ctx->vocab.token_sot;
}

whisper_token whisper_token_prev(struct whisper_context * ctx) {
    return ctx->vocab.token_prev;
}

whisper_token whisper_token_solm(struct whisper_context * ctx) {
    return ctx->vocab.token_solm;
}

whisper_token whisper_token_not(struct whisper_context * ctx) {
    return ctx->vocab.token_not;
}

whisper_token whisper_token_beg(struct whisper_context * ctx) {
    return ctx->vocab.token_beg;
}

whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
    return whisper_token_sot(ctx) + 1 + lang_id;
}

whisper_token whisper_token_translate(void) {
    return whisper_vocab::token_translate;
}

whisper_token whisper_token_transcribe(void) {
    return whisper_vocab::token_transcribe;
}

void whisper_print_timings(struct whisper_context * ctx) {
    const int64_t t_end_us = ggml_time_us();

    const int32_t n_sample = std::max(1, ctx->n_sample);
    const int32_t n_encode = std::max(1, ctx->n_encode);
    const int32_t n_decode = std::max(1, ctx->n_decode);

    fprintf(stderr, "\n");
    fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
    fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
    fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
    fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
    fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
    fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
    fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}

void whisper_reset_timings(struct whisper_context * ctx) {
    ctx->t_sample_us = 0;
    ctx->t_encode_us = 0;
    ctx->t_decode_us = 0;
}

const char * whisper_print_system_info(void) {
    static std::string s;

    s = "";
    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";

    return s.c_str();
}

////////////////////////////////////////////////////////////////////////////

struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
    struct whisper_full_params result = {
        /*.strategy =*/ strategy,

        /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
        /*.n_max_text_ctx =*/ 16384,
        /*.offset_ms =*/ 0,
        /*.duration_ms =*/ 0,

        /*.translate =*/ false,
        /*.no_context =*/ false,
        /*.single_segment =*/ false,
        /*.print_special =*/ false,
        /*.print_progress =*/ true,
        /*.print_realtime =*/ false,
        /*.print_timestamps =*/ true,

        /*.token_timestamps =*/ false,
        /*.thold_pt =*/ 0.01f,
        /*.thold_ptsum =*/ 0.01f,
        /*.max_len =*/ 0,
        /*.split_on_word =*/ false,
        /*.max_tokens =*/ 0,

        /*.speed_up =*/ false,
        /*.audio_ctx =*/ 0,

        /*.prompt_tokens =*/ nullptr,
        /*.prompt_n_tokens =*/ 0,

        /*.language =*/ "en",

        /*.suppress_blank =*/ true,
        /*.suppress_non_speech_tokens =*/ false,

        /*.temperature =*/ 0.0f,
        /*.max_initial_ts =*/ 1.0f,
        /*.length_penalty =*/ -1.0f,

        /*.temperature_inc =*/ 0.2f,
        /*.entropy_thold =*/ 2.4f,
        /*.logprob_thold =*/ -1.0f,
        /*.no_speech_thold =*/ 0.6f,

        /*.greedy =*/ {
            /*.best_of =*/ -1,
        },

        /*.beam_search =*/ {
            /*.beam_size =*/ -1,

            /*.patience =*/ -1.0f,
        },

        /*.new_segment_callback =*/ nullptr,
        /*.new_segment_callback_user_data =*/ nullptr,

        /*.encoder_begin_callback =*/ nullptr,
        /*.encoder_begin_callback_user_data =*/ nullptr,

        /*.logits_filter_callback =*/ nullptr,
        /*.logits_filter_callback_user_data =*/ nullptr,
    };

    switch (strategy) {
        case WHISPER_SAMPLING_GREEDY:
            {
                result.greedy = {
                    /*.best_of =*/ 1,
                };
            } break;
        case WHISPER_SAMPLING_BEAM_SEARCH:
            {
                result.beam_search = {
                    /*.beam_size =*/ 5,

                    /*.patience =*/ -1.0f,
                };
            } break;
    }

    return result;
}
|
2992
|
-
|
2993
|
-
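#if 0
// Illustrative sketch: the intended pattern is to start from the defaults for
// a given strategy and then override individual fields. Note that
// WHISPER_SAMPLING_GREEDY sets best_of = 1 above, while
// WHISPER_SAMPLING_BEAM_SEARCH sets beam_size = 5 and leaves patience at -1.
static struct whisper_full_params example_make_params() {
    struct whisper_full_params p = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
    p.n_threads        = 8;       // instead of std::min(4, hardware_concurrency)
    p.language         = "auto";  // triggers language auto-detection in whisper_full()
    p.token_timestamps = true;    // enable the experimental token-level timestamps
    p.max_len          = 60;      // wrap segments at ~60 characters (see whisper_wrap_segment)
    return p;
}
#endif
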
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
        struct whisper_context & ctx,
        int   i_segment,
        float thold_pt,
        float thold_ptsum);

// trim from start (in place)
static inline void ltrim(std::string &s) {
    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
        return !std::isspace(ch);
    }));
}

// trim from end (in place)
static inline void rtrim(std::string &s) {
    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
        return !std::isspace(ch);
    }).base(), s.end());
}

// trim from both ends (in place)
static inline void trim(std::string &s) {
    rtrim(s);
    ltrim(s);
}

static inline bool should_split_on_word(const char * txt, bool split_on_word) {
    if (!split_on_word) return true;

    return txt[0] == ' ';
}

// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
    auto segment = ctx.result_all.back();

    int res = 1;
    int acc = 0;

    std::string text;

    for (int i = 0; i < (int) segment.tokens.size(); i++) {
        const auto & token = segment.tokens[i];
        if (token.id >= whisper_token_eot(&ctx)) {
            continue;
        }

        const auto txt = whisper_token_to_str(&ctx, token.id);
        const int cur = strlen(txt);

        if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
            // split here
            if (split_on_word) {
                trim(text);
            }

            ctx.result_all.back().text = std::move(text);
            ctx.result_all.back().t1 = token.t0;
            ctx.result_all.back().tokens.resize(i);

            ctx.result_all.push_back({});
            ctx.result_all.back().t0 = token.t0;
            ctx.result_all.back().t1 = segment.t1;

            // add tokens [i, end] to the new segment
            ctx.result_all.back().tokens.insert(
                ctx.result_all.back().tokens.end(),
                segment.tokens.begin() + i,
                segment.tokens.end());

            acc = 0;
            text = "";

            segment = ctx.result_all.back();
            i = -1;

            res++;
        } else {
            acc += cur;
            text += txt;
        }
    }

    if (split_on_word) {
        trim(text);
    }
    ctx.result_all.back().text = std::move(text);

    return res;
}

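#if 0
// Illustrative sketch of the wrapping rule above, detached from the
// whisper_context bookkeeping: token strings are accumulated into the current
// line until adding the next one would overrun max_len; with split_on_word
// enabled, a break is only taken at tokens that start with a space (word
// boundaries). Hypothetical helper, for illustration only:
static std::vector<std::string> example_wrap(const std::vector<std::string> & toks, int max_len, bool split_on_word) {
    std::vector<std::string> lines(1);
    int acc = 0;
    for (const auto & t : toks) {
        const int cur = (int) t.size();
        if (acc + cur > max_len && acc > 0 && (!split_on_word || t[0] == ' ')) {
            lines.push_back(""); // start a new line at this word boundary
            acc = 0;
        }
        lines.back() += t;
        acc += cur;
    }
    return lines;
}
#endif
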
static const std::vector<std::string> non_speech_tokens = {
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪", "♩", "♪", "♫", "♬", "♭", "♮", "♯"
};

// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
        struct whisper_context & ctx,
        const struct whisper_full_params params,
        struct whisper_decoder & decoder,
        float temperature) {
    const auto & vocab = ctx.vocab;
    const auto & tokens_cur = decoder.sequence.tokens;

    const bool is_initial = tokens_cur.size() == 0;
    const int n_logits = vocab.id_to_token.size();

    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);

    // extract the logits for the last token
    // we will be mutating them, so we don't want to use the ctx.logits buffer directly
    auto & probs = decoder.probs;
    auto & logits = decoder.logits;
    auto & logprobs = decoder.logprobs;
    {
        logits.resize(n_logits);
        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));

        if (temperature > 0.0f) {
            for (int i = 0; i < n_logits; i++) {
                logits[i] /= temperature;
            }
        }

        // will be populated a bit later
        probs.resize(n_logits);
        logprobs.resize(n_logits);
    }

    // apply logit filters here
    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
    {
        // suppress blank
        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
        if (params.suppress_blank) {
            if (is_initial) {
                logits[vocab.token_eot] = -INFINITY;
                logits[vocab.token_to_id.at(" ")] = -INFINITY;
            }
        }

        // suppress <|notimestamps|> token
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
        logits[vocab.token_not] = -INFINITY;

        // suppress sot and solm tokens
        logits[vocab.token_sot] = -INFINITY;
        logits[vocab.token_solm] = -INFINITY;

        // suppress task tokens
        logits[vocab.token_translate] = -INFINITY;
        logits[vocab.token_transcribe] = -INFINITY;

        if (params.logits_filter_callback) {
            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
        }

        // suppress non-speech tokens
        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
        if (params.suppress_non_speech_tokens) {
            for (const std::string & token : non_speech_tokens) {
                const std::string suppress_tokens[] = {token, " " + token};
                for (const std::string & suppress_token : suppress_tokens) {
                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
                        logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
                    }
                }
            }

            // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
                logits[vocab.token_to_id.at(" -")] = -INFINITY;
            }
            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
                logits[vocab.token_to_id.at(" '")] = -INFINITY;
            }
        }

        // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
        {
            const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
            const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;

            //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) {
                    for (int i = vocab.token_beg; i < n_logits; ++i) {
                        logits[i] = -INFINITY;
                    }
                } else {
                    for (int i = 0; i < vocab.token_eot; ++i) {
                        logits[i] = -INFINITY;
                    }
                }
            }
        }

        // the initial timestamp cannot be larger than max_initial_ts
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
        if (is_initial && params.max_initial_ts > 0.0f) {
            const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
            const int tid0 = std::round(params.max_initial_ts/precision);

            for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
                logits[i] = -INFINITY;
            }
        }

        // condition timestamp tokens to be increasing
        // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
        if (decoder.has_ts) {
            const int tid0 = decoder.seek_delta/2;

            for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
                logits[i] = -INFINITY;
            }
        }

        // populate the logprobs array (log_softmax)
        {
            const float logit_max = *std::max_element(logits.begin(), logits.end());
            float logsumexp = 0.0f;
            for (int i = 0; i < n_logits; ++i) {
                if (logits[i] > -INFINITY) {
                    logsumexp += expf(logits[i] - logit_max);
                }
            }
            logsumexp = logf(logsumexp) + logit_max;

            for (int i = 0; i < n_logits; ++i) {
                if (logits[i] > -INFINITY) {
                    logprobs[i] = logits[i] - logsumexp;
                } else {
                    logprobs[i] = -INFINITY;
                }
            }
        }

        // if sum of probability over timestamps is above any other token, sample timestamp
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
        {
            // logsumexp over timestamps
            float timestamp_logprob = -INFINITY;
            {
                float logsumexp = 0.0f;
                const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
                for (int i = vocab.token_beg; i < n_logits; ++i) {
                    if (logprobs[i] > -INFINITY) {
                        logsumexp += expf(logprobs[i] - logprob_max);
                    }
                }
                if (logsumexp > 0.0f) {
                    timestamp_logprob = logf(logsumexp) + logprob_max;
                }
            }

            const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);

            //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);

            if (timestamp_logprob > max_text_token_logprob) {
                for (int i = 0; i < vocab.token_beg; ++i) {
                    logits[i] = -INFINITY;
                    logprobs[i] = -INFINITY;
                }
            }
        }
    }

    // compute probs
    {
        for (int i = 0; i < n_logits; ++i) {
            if (logits[i] == -INFINITY) {
                probs[i] = 0.0f;
            } else {
                probs[i] = expf(logprobs[i]);
            }
        }
    }

#if 0
    // print first 100 logits - token string : logit
    for (int i = 0; i < 100; i++) {
        const auto token = vocab.id_to_token.at(i);
        const auto prob = probs[i];
        const auto logit = logits[i];
        const auto logprob = logprobs[i];
        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
    }

    // "And", "and", " And", " and"
    printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
    printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
    printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);

    printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
    printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
    printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);

    printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
    printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
    printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}

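#if 0
// Illustrative sketch of the log_softmax used above: subtracting the maximum
// logit before exponentiating keeps expf() in a safe range, and masked entries
// (-INFINITY) are skipped so they contribute exactly zero probability mass.
static void example_log_softmax(const std::vector<float> & logits, std::vector<float> & logprobs) {
    const float max_l = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float l : logits) {
        if (l > -INFINITY) {
            sum += expf(l - max_l);
        }
    }
    const float lse = logf(sum) + max_l; // log(sum_i exp(logit_i))
    logprobs.resize(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = logits[i] > -INFINITY ? logits[i] - lse : -INFINITY;
    }
}
#endif
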
static whisper_token_data whisper_sample_token(
        whisper_context & ctx,
        const whisper_decoder & decoder,
        bool best) {
    whisper_token_data result = {
        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
    };

    const auto & vocab = ctx.vocab;

    const auto & probs = decoder.probs;
    const auto & logprobs = decoder.logprobs;

    const int n_logits = vocab.n_vocab;

    {
        double sum_ts = 0.0;
        double max_ts = 0.0;

        for (int i = vocab.token_beg; i < n_logits; i++) {
            if (probs[i] == -INFINITY) {
                continue;
            }

            sum_ts += probs[i];
            if (max_ts < probs[i]) {
                max_ts = probs[i];
                result.tid = i;
            }
        }

        result.pt = max_ts/(sum_ts + 1e-10);
        result.ptsum = sum_ts;
    }

    if (best) {
        for (int i = 0; i < n_logits; ++i) {
            if (result.p < probs[i]) {
                result.id = i;
                result.p = probs[i];
                result.plog = logprobs[i];
            }
        }
    } else {
        std::discrete_distribution<> dist(probs.begin(), probs.end());

        result.id = dist(ctx.rng);
        result.p = probs[result.id];
        result.plog = logprobs[result.id];
    }

    if (result.id >= vocab.token_beg) {
        result.tid = result.id;
        result.pt = result.p;
    }

    ctx.n_sample++;

    return result;
}

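#if 0
// Illustrative sketch of the two sampling modes above: with best == true the
// argmax over probs is taken (temperature 0), otherwise a token is drawn from
// the distribution with std::discrete_distribution, which normalizes the
// weights internally. Hypothetical helper, for illustration only:
static int example_pick_token(std::mt19937 & rng, const std::vector<float> & probs, bool best) {
    if (best) {
        return (int) std::distance(probs.begin(), std::max_element(probs.begin(), probs.end()));
    }
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return dist(rng);
}
#endif
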
static std::vector<whisper_token_data> whisper_sample_token_topk(
        whisper_context & ctx,
        const whisper_decoder & decoder,
        int k) {
    const auto & vocab = ctx.vocab;

    const auto & probs = decoder.probs;
    const auto & logits = decoder.logits;
    const auto & logprobs = decoder.logprobs;

    const int n_logits = vocab.n_vocab;

    auto & logits_id = ctx.logits_id;

    logits_id.clear();
    for (int i = 0; i < n_logits; ++i) {
        logits_id.push_back({ logits[i], i });
    }

    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + k, logits_id.end(),
            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
                return a.first > b.first;
            });

    std::vector<whisper_token_data> result;
    result.reserve(k);

    whisper_token tid = vocab.token_beg;

    float pt = 0.0;
    float ptsum = 0.0;

    {
        double sum_ts = 0.0;
        double max_ts = 0.0;

        for (int i = vocab.token_beg; i < n_logits; i++) {
            if (probs[i] == -INFINITY) {
                continue;
            }

            sum_ts += probs[i];
            if (max_ts < probs[i]) {
                max_ts = probs[i];
                tid = i;
            }
        }

        pt = max_ts/(sum_ts + 1e-10);
        ptsum = sum_ts;
    }

    for (int i = 0; i < k; ++i) {
        const auto id = logits_id[i].second;

        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });

        if (result[i].id >= vocab.token_beg) {
            result[i].tid = result[i].id;
            result[i].pt = result[i].p;
        }
    }

    ctx.n_sample++;

    return result;
}

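#if 0
// Illustrative sketch of the top-k selection above (assumes k <= ids.size()):
// std::partial_sort places the k largest (logit, token) pairs in descending
// order at the front in O(n log k), which is all that is needed to seed k
// beam candidates; the tail of the vector is left in unspecified order.
static std::vector<int> example_top_k_ids(std::vector<std::pair<double, int>> ids, int k) {
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
            [](const std::pair<double, int> & a, const std::pair<double, int> & b) {
                return a.first > b.first;
            });
    std::vector<int> out;
    for (int i = 0; i < k; ++i) {
        out.push_back(ids[i].second);
    }
    return out;
}
#endif
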
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
static void whisper_sequence_score(
        const struct whisper_full_params & params,
        whisper_sequence & sequence) {
    if (sequence.result_len == 0) {
        return;
    }

    double result = 0.0f;

    for (int i = 0; i < sequence.result_len; ++i) {
        result += sequence.tokens[i].plog;
    }

    sequence.sum_logprobs = result;
    sequence.avg_logprobs = result/sequence.result_len;

    double penalty = sequence.result_len;

    if (params.length_penalty > 0.0f) {
        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
    }

    sequence.score = result/penalty;

    // compute the entropy of the sequence of the last 32 tokens
    {
        const int n = 32;

        int cnt = 0;
        double entropy = 0.0f;

        std::map<whisper_token, int> token_counts;
        for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
            token_counts[sequence.tokens[i].id]++;
            cnt++;
        }

        for (const auto & kv : token_counts) {
            const auto p = kv.second/(double)cnt;
            entropy -= p*log(p);

            //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
        }

        sequence.entropy = entropy;
    }
}

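#if 0
// Worked example of the scoring above (values chosen for illustration):
// with length_penalty <= 0 the score is the average logprob, e.g.
// sum_logprobs = -12.0 over 24 tokens -> score = -12.0/24 = -0.5.
// With length_penalty = 1.0 the ((5 + L)/6)^alpha penalty is used instead:
// ((5 + 24)/6)^1.0 ~= 4.833, so score ~= -12.0/4.833 ~= -2.48.
// The entropy over the last 32 tokens is what whisper_full() compares against
// params.entropy_thold (2.4 by default) to detect repetition loops: a
// sequence that repeats a single token has entropy 0, far below the threshold.
#endif
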
int whisper_full(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples) {
    // clear old results
    auto & result_all = ctx->result_all;

    result_all.clear();

    // compute log mel spectrogram
    if (params.speed_up) {
        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
            return -1;
        }
    } else {
        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
            return -2;
        }
    }

    // auto-detect language if not specified
    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
        if (lang_id < 0) {
            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
            return -3;
        }
        ctx->lang_id = lang_id;
        params.language = whisper_lang_str(lang_id);

        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
    }

    if (params.token_timestamps) {
        ctx->t_beg = 0;
        ctx->t_last = 0;
        ctx->tid_last = 0;
        ctx->energy = get_signal_energy(samples, n_samples, 32);
    }

    const int seek_start = params.offset_ms/10;
    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);

    // if length of spectrogram is less than 1s (100 samples), then return
    // basically don't process anything that is less than 1s
    // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
    if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
        return 0;
    }

    // a set of temperatures to use
    // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
    std::vector<float> temperatures;
    if (params.temperature_inc > 0.0f) {
        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
            temperatures.push_back(t);
        }
    } else {
        temperatures.push_back(params.temperature);
    }

    // initialize the decoders
    int n_decoders = 1;

    switch (params.strategy) {
        case WHISPER_SAMPLING_GREEDY:
            {
                n_decoders = params.greedy.best_of;
            } break;
        case WHISPER_SAMPLING_BEAM_SEARCH:
            {
                n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
            } break;
    };

    n_decoders = std::max(1, n_decoders);

    // TAGS: WHISPER_DECODER_INIT
    for (int j = 1; j < n_decoders && ctx->running; j++) {
        auto & decoder = ctx->decoders[j];

        if (decoder.kv_self.ctx == nullptr) {
            decoder.kv_self = ctx->decoders[0].kv_self;
            if (!kv_cache_reinit(decoder.kv_self)) {
                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                return -4;
            }

            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);

            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());

            decoder.probs.resize   (ctx->vocab.n_vocab);
            decoder.logits.resize  (ctx->vocab.n_vocab);
            decoder.logprobs.resize(ctx->vocab.n_vocab);
        }
    }

    // the accumulated text context so far
    auto & prompt_past = ctx->prompt_past;
    if (params.no_context) {
        prompt_past.clear();
    }

    // prepend the prompt tokens to the prompt_past
    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
        // parse tokens from the pointer
        for (int i = 0; i < params.prompt_n_tokens; i++) {
            prompt_past.push_back(params.prompt_tokens[i]);
        }
        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
    }

    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
        return -5;
    }
    ctx->exp_n_audio_ctx = params.audio_ctx;

    // these tokens determine the task that will be performed
    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
    if (whisper_is_multilingual(ctx)) {
        const int lang_id = whisper_lang_id(params.language);
        ctx->lang_id = lang_id;
        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
        if (params.translate) {
            prompt_init.push_back(whisper_token_translate());
        } else {
            prompt_init.push_back(whisper_token_transcribe());
        }
    }

    int progress_prev = 0;
    int progress_step = 5;

    int seek = seek_start;

    std::vector<whisper_token> prompt;
    prompt.reserve(whisper_n_text_ctx(ctx));

    // beam-search helpers
    struct kv_buf {
        std::vector<uint8_t> k;
        std::vector<uint8_t> v;
    };

    std::vector<kv_buf> kv_bufs;

    struct beam_candidate {
        int decoder_idx;
        int seek_delta;

        bool has_ts;

        whisper_sequence sequence;
    };

    std::vector<beam_candidate> beam_candidates;

    // main loop
    while (ctx->running) {
        const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
        while (progress_cur >= progress_prev + progress_step) {
            progress_prev += progress_step;
            if (params.print_progress) {
                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
            }
        }

        // if only 1 second is left, then stop
        if (seek + 100 >= seek_end) {
            break;
        }

        if (params.encoder_begin_callback) {
            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
                break;
            }
        }

        // encode audio features starting at offset seek
        if (!whisper_encode(*ctx, seek, params.n_threads)) {
            fprintf(stderr, "%s: failed to encode\n", __func__);
            return -6;
        }

        // if there is a very short audio segment left to process, we remove any past prompt since it tends
        // to confuse the decoder and often makes it repeat or hallucinate
        if (seek > seek_start && seek + 500 >= seek_end) {
            prompt_past.clear();
        }

        int best_decoder_id = 0;

        for (int it = 0; it < (int) temperatures.size(); ++it) {
            const float t_cur = temperatures[it];

            int n_decoders_cur = 1;

            switch (params.strategy) {
                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                    {
                        if (t_cur > 0.0f) {
                            n_decoders_cur = params.greedy.best_of;
                        }
                    } break;
                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                    {
                        if (t_cur > 0.0f) {
                            n_decoders_cur = params.greedy.best_of;
                        } else {
                            n_decoders_cur = params.beam_search.beam_size;
                        }
                    } break;
            };

            n_decoders_cur = std::max(1, n_decoders_cur);

            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);

            // TAGS: WHISPER_DECODER_INIT
            for (int j = 0; j < n_decoders_cur; ++j) {
                auto & decoder = ctx->decoders[j];

                decoder.kv_self.n = 0;

                decoder.sequence.tokens.clear();
                decoder.sequence.result_len = 0;
                decoder.sequence.sum_logprobs_all = 0.0;
                decoder.sequence.sum_logprobs = -INFINITY;
                decoder.sequence.avg_logprobs = -INFINITY;
                decoder.sequence.entropy = 0.0;
                decoder.sequence.score = -INFINITY;

                decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;

                decoder.failed = false;
                decoder.completed = false;
                decoder.has_ts = false;
            }

            // init prompt and kv cache for the current iteration
            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
            {
                prompt.clear();

                // if we have already generated some text, use it as a prompt to condition the next generation
                if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));

                    prompt = { whisper_token_prev(ctx) };
                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
                }

                // init new transcription with sot, language (opt) and task tokens
                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());

                // print the prompt
                WHISPER_PRINT_DEBUG("\n\n");
                for (int i = 0; i < (int) prompt.size(); i++) {
                    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
                }
                WHISPER_PRINT_DEBUG("\n\n");

                if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                    fprintf(stderr, "%s: failed to decode\n", __func__);
                    return -7;
                }

                {
                    const int64_t t_start_sample_us = ggml_time_us();

                    whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);

                    ctx->decoders[0].kv_self.n += prompt.size();

                    for (int j = 1; j < n_decoders_cur; ++j) {
                        auto & decoder = ctx->decoders[j];

                        memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
                        memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));

                        decoder.kv_self.n += prompt.size();

                        memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
                        memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
                        memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                    }

                    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
                }
            }

            for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
                const int64_t t_start_sample_us = ggml_time_us();

                // store the KV caches of all decoders when doing beam-search
                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                    kv_bufs.resize(n_decoders_cur);
                    for (int j = 0; j < n_decoders_cur; ++j) {
                        auto & decoder = ctx->decoders[j];

                        if (decoder.completed || decoder.failed) {
                            continue;
                        }

                        kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
                        kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));

                        memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
                        memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
                    }

                    beam_candidates.clear();
                }

                // generate new sequence candidates for each decoder
                for (int j = 0; j < n_decoders_cur; ++j) {
                    auto & decoder = ctx->decoders[j];

                    if (decoder.completed || decoder.failed) {
                        continue;
                    }

                    switch (params.strategy) {
                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                            {
                                if (t_cur < 1e-6f) {
                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
                                } else {
                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
                                }

                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
                            } break;
                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                            {
                                const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);

                                for (const auto & token : tokens_new) {
                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
                                    beam_candidates.back().sequence.tokens.push_back(token);
                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;

                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
                                }
                            } break;
                    };
                }

                // for beam-search, choose the top candidates and update the KV caches
                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                    std::sort(
                            beam_candidates.begin(),
                            beam_candidates.end(),
                            [](const beam_candidate & a, const beam_candidate & b) {
                                return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
                            });

                    uint32_t cur_c = 0;

                    for (int j = 0; j < n_decoders_cur; ++j) {
                        auto & decoder = ctx->decoders[j];

                        if (decoder.completed || decoder.failed) {
                            continue;
                        }

                        auto & cur = beam_candidates[cur_c++];

                        while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
                            ++cur_c;
                        }

                        decoder.sequence = cur.sequence;
                        decoder.seek_delta = cur.seek_delta;
                        decoder.has_ts = cur.has_ts;

                        memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
                        memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());

                        WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                    }
                }

                // update the decoder state
                // - check if the sequence is completed
                // - check if the sequence is failed
                // - update sliding window based on timestamp tokens
                for (int j = 0; j < n_decoders_cur; ++j) {
                    auto & decoder = ctx->decoders[j];

                    if (decoder.completed || decoder.failed) {
                        continue;
                    }

                    auto & has_ts = decoder.has_ts;
                    auto & failed = decoder.failed;
                    auto & completed = decoder.completed;
                    auto & seek_delta = decoder.seek_delta;
                    auto & result_len = decoder.sequence.result_len;

                    {
                        const auto & token = decoder.sequence.tokens.back();

                        // timestamp token - update sliding window
                        if (token.id > whisper_token_beg(ctx)) {
                            const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));

                            // do not allow going back in time
                            if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                                failed = true; // TODO: maybe this is not a failure ?
                                continue;
                            }

                            seek_delta = seek_delta_new;
                            result_len = i + 1;
                            has_ts = true;
                        }

#ifdef WHISPER_DEBUG
                        {
                            const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
                            WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
                                    __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
                        }
#endif

                        // end of segment
                        if (token.id == whisper_token_eot(ctx) ||               // end of text token
                           (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
                           (has_ts && seek + seek_delta + 100 >= seek_end)      // end of audio reached
                           ) {
                            if (result_len == 0) {
                                if (seek + seek_delta + 100 >= seek_end) {
                                    result_len = i + 1;
                                } else {
                                    failed = true;
                                    continue;
                                }
                            }

                            if (params.single_segment) {
                                result_len = i + 1;
                                seek_delta = 100*WHISPER_CHUNK_SIZE;
                            }

                            completed = true;
                            continue;
                        }

                        // TESTS: if no tensors are loaded, it means we are running tests
                        if (ctx->model.n_loaded == 0) {
                            seek_delta = 100*WHISPER_CHUNK_SIZE;
                            completed = true;
                            continue;
                        }
                    }

                    // sometimes, the decoding can get stuck in a repetition loop
                    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                        failed = true;
                        continue;
                    }
                }

                // check if all decoders have finished (i.e. completed or failed)
                {
                    bool completed_all = true;

                    for (int j = 0; j < n_decoders_cur; ++j) {
                        auto & decoder = ctx->decoders[j];

                        if (decoder.completed || decoder.failed) {
                            continue;
                        }

                        completed_all = false;
                    }

                    if (completed_all) {
                        break;
                    }
                }

                ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

                // obtain logits for the next token
                for (int j = 0; j < n_decoders_cur; ++j) {
                    auto & decoder = ctx->decoders[j];

                    if (decoder.failed || decoder.completed) {
                        continue;
                    }

                    decoder.tokens_tmp.resize(1);
                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;

                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);

                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
                        fprintf(stderr, "%s: failed to decode\n", __func__);
                        return -8;
                    }

                    {
                        const int64_t t_start_sample_us = ggml_time_us();

                        whisper_process_logits(*ctx, params, decoder, t_cur);

                        ++decoder.kv_self.n;

                        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
                    }
                }
            }

            // rank the resulting sequences and select the best one
            {
                double best_score = -INFINITY;

                for (int j = 0; j < n_decoders_cur; ++j) {
                    auto & decoder = ctx->decoders[j];

                    if (decoder.failed) {
                        continue;
                    }

                    decoder.sequence.tokens.resize(decoder.sequence.result_len);
                    whisper_sequence_score(params, decoder.sequence);

                    WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
                            __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);

                    if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
                        WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
                                __func__, j, decoder.sequence.entropy, params.entropy_thold);

                        decoder.failed = true;
                        ctx->n_fail_h++;

                        continue;
                    }

                    if (best_score < decoder.sequence.score) {
                        best_score = decoder.sequence.score;
                        best_decoder_id = j;
                    }
                }

                WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
            }

            // was the decoding successful for the current temperature?
            {
                bool success = true;

                const auto & decoder = ctx->decoders[best_decoder_id];

                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
                    success = false;
                    ctx->n_fail_p++;
                }

                if (success) {
                    //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
                    //    WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
                    //}

                    break;
                }
            }

            WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
        }

        // output results through a user-provided callback
        {
            const auto & best_decoder = ctx->decoders[best_decoder_id];

            const auto seek_delta = best_decoder.seek_delta;
            const auto result_len = best_decoder.sequence.result_len;

            const auto & tokens_cur = best_decoder.sequence.tokens;

            //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);

            // update prompt_past
            prompt_past.clear();
            if (prompt.front() == whisper_token_prev(ctx)) {
                prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
            }

            for (int i = 0; i < result_len; ++i) {
                prompt_past.push_back(tokens_cur[i].id);
            }

            // store the text from this iteration
            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
                int i0 = 0;
                auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));

                std::string text;

                for (int i = 0; i < (int) tokens_cur.size(); i++) {
                    //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
                    //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
                    //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);

                    if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
                    } else {
                        text += whisper_token_to_str(ctx, tokens_cur[i].id);
                    }

                    if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
                        const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));

                        if (!text.empty()) {
                            const auto tt0 = params.speed_up ? 2*t0 : t0;
                            const auto tt1 = params.speed_up ? 2*t1 : t1;

                            if (params.print_realtime) {
                                if (params.print_timestamps) {
                                    printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                                } else {
                                    printf("%s", text.c_str());
                                    fflush(stdout);
                                }
                            }

                            //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);

                            result_all.push_back({ tt0, tt1, text, {} });
                            for (int j = i0; j <= i; j++) {
                                result_all.back().tokens.push_back(tokens_cur[j]);
                            }

                            int n_new = 1;

                            if (params.token_timestamps) {
                                whisper_exp_compute_token_level_timestamps(
                                        *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

                                if (params.max_len > 0) {
                                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                                }
                            }
                            if (params.new_segment_callback) {
                                params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                            }
                        }
                        text = "";
                        while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
                            i++;
                        }
                        i--;
                        t0 = t1;
                        i0 = i + 1;
                    }
                }

                if (!text.empty()) {
                    const auto t1 = seek + seek_delta;

                    const auto tt0 = params.speed_up ? 2*t0 : t0;
                    const auto tt1 = params.speed_up ? 2*t1 : t1;

                    if (params.print_realtime) {
                        if (params.print_timestamps) {
                            printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
                        } else {
                            printf("%s", text.c_str());
                            fflush(stdout);
                        }
                    }

                    result_all.push_back({ tt0, tt1, text, {} });
                    for (int j = i0; j < (int) tokens_cur.size(); j++) {
                        result_all.back().tokens.push_back(tokens_cur[j]);
                    }

                    int n_new = 1;

                    if (params.token_timestamps) {
                        whisper_exp_compute_token_level_timestamps(
                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

                        if (params.max_len > 0) {
                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                        }
                    }
                    if (params.new_segment_callback) {
                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                    }
                }
            }

            // update audio window
            seek += seek_delta;

            WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
        }
    }

    return 0;
}

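#if 0
// Illustrative end-to-end sketch (editorial; error handling trimmed).
// Assumes an init entry point like whisper_init_from_file() and a matching
// whisper_free() - the exact init call lives elsewhere in this file and may
// differ in this bundled copy. "model.bin" and the 16 kHz mono float PCM in
// `pcm` are placeholders. whisper_full() runs the full mel -> encode ->
// decode pipeline above, retrying internally over the temperature schedule
// t = temperature, temperature + temperature_inc, ... (0.0, 0.2, ..., 1.0 by
// default) whenever a decode fails the entropy/logprob checks.
static void example_transcribe(const std::vector<float> & pcm) {
    struct whisper_context * wctx = whisper_init_from_file("model.bin");
    struct whisper_full_params p = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    if (whisper_full(wctx, p, pcm.data(), (int) pcm.size()) == 0) {
        for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
            printf("%s\n", whisper_full_get_segment_text(wctx, i));
        }
    }
    whisper_free(wctx);
}
#endif
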
void whisper_running_abort(struct whisper_context * ctx) {
    ctx->running = false;
}

void whisper_running_restore(struct whisper_context * ctx) {
    ctx->running = true;
}

bool whisper_running_state(struct whisper_context * ctx) {
    return ctx->running;
}

int whisper_full_parallel(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples,
        int n_processors) {
    if (n_processors == 1) {
        return whisper_full(ctx, params, samples, n_samples);
    }

    int ret = 0;

    // prepare separate contexts for each thread
    std::vector<struct whisper_context> ctxs(n_processors - 1);

    for (int i = 0; i < n_processors - 1; ++i) {
        auto & ctx_p = ctxs[i];

        ctx_p = *ctx;

        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);

        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);

        if (!kv_cache_reinit(ctx_p.kv_cross)) {
            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
            return false;
        }

        // TAGS: WHISPER_DECODER_INIT
        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
                return false;
            }

            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);

            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
        }
    }

    const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
    const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;

    // the calling thread will process the first chunk
    // while the other threads will process the remaining chunks

    std::vector<std::thread> workers(n_processors - 1);
    for (int i = 0; i < n_processors - 1; ++i) {
        const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
        const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;

        auto params_cur = params;

        params_cur.offset_ms = 0;
        params_cur.print_progress = false;
        params_cur.print_realtime = false;

        params_cur.new_segment_callback = nullptr;
        params_cur.new_segment_callback_user_data = nullptr;

        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
    }

    {
        auto params_cur = params;

        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
    }

    for (int i = 0; i < n_processors - 1; ++i) {
        workers[i].join();
    }

    const int64_t offset_t = (int64_t) params.offset_ms/10.0;

    // combine results into ctx->result_all
    for (int i = 0; i < n_processors - 1; ++i) {
        auto & results_i = ctxs[i].result_all;

        for (auto & result : results_i) {
            // correct the segment timestamp taking into account the offset
            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;

            // make sure that segments are not overlapping
            if (!ctx->result_all.empty()) {
                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
            }

            ctx->result_all.push_back(std::move(result));

            // call the new_segment_callback for each segment
            if (params.new_segment_callback) {
                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
            }
        }

        ctx->t_mel_us += ctxs[i].t_mel_us;
        ctx->t_sample_us += ctxs[i].t_sample_us;
        ctx->t_encode_us += ctxs[i].t_encode_us;
        ctx->t_decode_us += ctxs[i].t_decode_us;

        kv_cache_free(ctx->kv_cross);

        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
            kv_cache_free(ctx->decoders[j].kv_self);
        }
    }

    // average the timings
    ctx->t_mel_us /= n_processors;
    ctx->t_sample_us /= n_processors;
    ctx->t_encode_us /= n_processors;
    ctx->t_decode_us /= n_processors;

    // print information about the audio boundaries
    fprintf(stderr, "\n");
    fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
    for (int i = 0; i < n_processors - 1; ++i) {
        fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
    }
    fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);

    return ret;
}

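#if 0
// Worked example of the splitting above (values chosen for illustration):
// with 16 kHz input, 60 s of audio (960000 samples), offset_ms = 0 and
// n_processors = 3, each chunk gets 960000/3 = 320000 samples (20 s).
// Worker i starts at (i + 1)*320000 samples, and its segment timestamps are
// later shifted by 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE
// centiseconds, i.e. +2000 (20 s) for worker 0 and +4000 (40 s) for worker 1,
// so the merged result_all timeline lines up with the original audio.
#endif
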
int whisper_full_n_segments(struct whisper_context * ctx) {
    return ctx->result_all.size();
}

int whisper_full_lang_id(struct whisper_context * ctx) {
    return ctx->lang_id;
}

int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
    return ctx->result_all[i_segment].t0;
}

int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
    return ctx->result_all[i_segment].t1;
}

const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
    return ctx->result_all[i_segment].text.c_str();
}

int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
    return ctx->result_all[i_segment].tokens.size();
}

const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
}

whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
    return ctx->result_all[i_segment].tokens[i_token].id;
}

struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
    return ctx->result_all[i_segment].tokens[i_token];
}

float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
    return ctx->result_all[i_segment].tokens[i_token].p;
}

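#if 0
// Illustrative sketch: the accessors above are the read-only API over
// ctx->result_all. Timestamps are in centiseconds (units of 10 ms), matching
// the seek/seek_delta bookkeeping in whisper_full().
static void example_dump_segments(struct whisper_context * wctx) {
    for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
        const int64_t t0 = whisper_full_get_segment_t0(wctx, i);
        const int64_t t1 = whisper_full_get_segment_t1(wctx, i);
        printf("[%lld -> %lld] %s\n", (long long) t0, (long long) t1,
                whisper_full_get_segment_text(wctx, i));
    }
}
#endif
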
-// =================================================================================================
-
-//
-// Temporary interface needed for exposing ggml interface
-// Will be removed in the future when ggml becomes a separate library
-//
-
-WHISPER_API int whisper_bench_memcpy(int n_threads) {
-    ggml_time_init();
-
-    size_t n = 50;
-    size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
-
-    // 1 GB array
-    const size_t size = arr*1024llu*1024llu;
-
-    char * src = (char *) malloc(size);
-    char * dst = (char *) malloc(size);
-
-    for (size_t i = 0; i < size; i++) src[i] = i;
-
-    memcpy(dst, src, size); // heat-up
-
-    double tsum = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        const int64_t t0 = ggml_time_us();
-
-        memcpy(dst, src, size);
-
-        const int64_t t1 = ggml_time_us();
-
-        tsum += (t1 - t0)*1e-6;
-
-        src[0] = rand();
-    }
-
-    fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
-
-    // needed to prevent the compiler from optimizing the memcpy away
-    {
-        double sum = 0.0;
-
-        for (size_t i = 0; i < size; i++) sum += dst[i];
-
-        fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
-    }
-
-    free(src);
-    free(dst);
-
-    return 0;
-}
-
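The bandwidth figure is simply total bytes copied over total time: n copies of a size-byte buffer give (n*size)/(tsum*1024^3) GB/s, counting each copied byte once rather than as read-plus-write traffic. With the defaults above (n = 50, size = 1 GiB), a run where the 50 copies take 5.0 s in total reports 50/5.0 = 10.00 GB/s. The trailing checksum loop exists only so that dst stays observable; without it, an optimizing compiler could elide the benchmarked memcpy calls entirely.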
-WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
-    ggml_time_init();
-
-    const int n_max = 128;
-
-    const std::vector<size_t> sizes = {
-        64, 128, 256, 512, 1024, 2048, 4096,
-    };
-
-    const size_t N_max = sizes.back();
-
-    // a: N*N*sizeof(float)
-    // b: N*N*sizeof(float)
-    // c: N*N*sizeof(float)
-    // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
-
-    for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
-
-    for (int j = 0; j < (int) sizes.size(); j++) {
-        int n_fp16 = 0;
-        int n_fp32 = 0;
-
-        // GFLOPS/s
-        double s_fp16 = 0.0;
-        double s_fp32 = 0.0;
-
-        const size_t N = sizes[j];
-
-        for (int k = 0; k < 2; ++k) {
-            const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-            double & s = k == 0 ? s_fp16 : s_fp32;
-            int    & n = k == 0 ? n_fp16 : n_fp32;
-
-            struct ggml_init_params gparams = {
-                /*.mem_size   =*/ buf.size(),
-                /*.mem_buffer =*/ buf.data(),
-            };
-
-            struct ggml_context * ctx0 = ggml_init(gparams);
-
-            struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype,         N, N);
-            struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
-
-            struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
-
-            struct ggml_cgraph gf = ggml_build_forward(c);
-
-            gf.n_threads = n_threads;
-
-            double tsum = 0.0;
-
-            // heat-up
-            ggml_graph_compute(ctx0, &gf);
-
-            for (int i = 0; i < n_max; ++i) {
-                const int64_t t0 = ggml_time_us();
-
-                ggml_graph_compute(ctx0, &gf);
-
-                const int64_t t1 = ggml_time_us();
-
-                tsum += (t1 - t0)*1e-6;
-                n++;
-
-                if (tsum > 1.0 && n >= 3) {
-                    break;
-                }
-            }
-
-            ggml_free(ctx0);
-
-            s = ((2.0*N*N*N*n)/tsum)*1e-9;
-        }
-
-        fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
-                N, N, s_fp16, n_fp16, s_fp32, n_fp32);
-    }
-
-    return 0;
-}
-
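The GFLOPS estimate uses the standard operation count for a dense N x N by N x N multiply: each of the N^2 output cells requires N multiplications and roughly N additions, about 2*N^3 floating-point operations per product, hence s = (2*N^3*n/tsum)*1e-9. For example, n = 10 runs at N = 1024 completing in tsum = 1.0 s work out to 2*1024^3*10*1e-9, approximately 21.5 GFLOPS.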
-// =================================================================================================
-
-// =================================================================================================
-
-//
-// Experimental stuff below
-//
-// Not sure if these should be part of the library at all, because the quality of the results is not
-// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
-//
-
-// =================================================================================================
-
-//
-// token-level timestamps
-//
-
-static int timestamp_to_sample(int64_t t, int n_samples) {
-    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
-}
-
-static int64_t sample_to_timestamp(int i_sample) {
-    return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
-}
-
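Both helpers assume the centisecond (1/100 s) timestamp convention used throughout this file. At WHISPER_SAMPLE_RATE = 16000, a timestamp of t = 150 (1.5 s) maps to 150*16000/100 = 24000 samples, and sample_to_timestamp(24000) = 100*24000/16000 = 150 maps back; the clamp in timestamp_to_sample keeps the index inside [0, n_samples - 1].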
-// a cost-function / heuristic that is high for text that takes longer to pronounce
-// obviously, can be improved
-static float voice_length(const std::string & text) {
-    float res = 0.0f;
-
-    for (char c : text) {
-        if (c == ' ') {
-            res += 0.01f;
-        } else if (c == ',') {
-            res += 2.00f;
-        } else if (c == '.') {
-            res += 3.00f;
-        } else if (c == '!') {
-            res += 3.00f;
-        } else if (c == '?') {
-            res += 3.00f;
-        } else if (c >= '0' && c <= '9') {
-            res += 3.00f;
-        } else {
-            res += 1.00f;
-        }
-    }
-
-    return res;
-}
-
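As a concrete reading of the weights: voice_length("Hello, world!") scores the ten letters at 1.00 each, the comma at 2.00, the space at 0.01, and the '!' at 3.00, for a total of 15.01. Punctuation (which implies a pause) and digits (which are read out as words) are deliberately expensive, while spaces are nearly free.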
-// average the fabs of the signal
-static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
-    const int hw = n_samples_per_half_window;
-
-    std::vector<float> result(n_samples);
-
-    for (int i = 0; i < n_samples; i++) {
-        float sum = 0;
-        for (int j = -hw; j <= hw; j++) {
-            if (i + j >= 0 && i + j < n_samples) {
-                sum += fabs(signal[i + j]);
-            }
-        }
-        result[i] = sum/(2*hw + 1);
-    }
-
-    return result;
-}
-
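This is a centered moving average of the absolute signal over a window of 2*hw + 1 samples. Note that the divisor is always the full window even when part of it falls outside the buffer, so values near the edges are attenuated: with hw = 1 and signal = {1, 1, 1}, the middle sample averages to 1.0 while the first yields (1 + 1)/3, about 0.67.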
-static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context & ctx,
-        int   i_segment,
-        float thold_pt,
-        float thold_ptsum) {
-    auto & segment = ctx.result_all[i_segment];
-    auto & tokens  = segment.tokens;
-
-    const int n_samples = ctx.energy.size();
-
-    if (n_samples == 0) {
-        fprintf(stderr, "%s: no signal data available\n", __func__);
-        return;
-    }
-
-    const int64_t t0 = segment.t0;
-    const int64_t t1 = segment.t1;
-
-    const int n = tokens.size();
-
-    if (n == 0) {
-        return;
-    }
-
-    if (n == 1) {
-        tokens[0].t0 = t0;
-        tokens[0].t1 = t1;
-
-        return;
-    }
-
-    auto & t_beg    = ctx.t_beg;
-    auto & t_last   = ctx.t_last;
-    auto & tid_last = ctx.tid_last;
-
-    for (int j = 0; j < n; ++j) {
-        auto & token = tokens[j];
-
-        if (j == 0) {
-            if (token.id == whisper_token_beg(&ctx)) {
-                tokens[j    ].t0 = t0;
-                tokens[j    ].t1 = t0;
-                tokens[j + 1].t0 = t0;
-
-                t_beg    = t0;
-                t_last   = t0;
-                tid_last = whisper_token_beg(&ctx);
-            } else {
-                tokens[j    ].t0 = t_last;
-            }
-        }
-
-        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
-
-        tokens[j].id    = token.id;
-        tokens[j].tid   = token.tid;
-        tokens[j].p     = token.p;
-        tokens[j].pt    = token.pt;
-        tokens[j].ptsum = token.ptsum;
-
-        tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
-
-        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
-            if (j > 0) {
-                tokens[j - 1].t1 = tt;
-            }
-            tokens[j].t0 = tt;
-            tid_last = token.tid;
-        }
-    }
-
-    tokens[n - 2].t1 = t1;
-    tokens[n - 1].t0 = t1;
-    tokens[n - 1].t1 = t1;
-
-    t_last = t1;
-
-    // find intervals of tokens with unknown timestamps
-    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
-    {
-        int p0 = 0;
-        int p1 = 0;
-
-        while (true) {
-            while (p1 < n && tokens[p1].t1 < 0) {
-                p1++;
-            }
-
-            if (p1 >= n) {
-                p1--;
-            }
-
-            //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
-
-            if (p1 > p0) {
-                double psum = 0.0;
-                for (int j = p0; j <= p1; j++) {
-                    psum += tokens[j].vlen;
-                }
-
-                //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
-
-                const double dt = tokens[p1].t1 - tokens[p0].t0;
-
-                // split the time proportionally to the voice length
-                for (int j = p0 + 1; j <= p1; j++) {
-                    const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
-
-                    tokens[j - 1].t1 = ct;
-                    tokens[j    ].t0 = ct;
-                }
-            }
-
-            p1++;
-            p0 = p1;
-            if (p1 >= n) {
-                break;
-            }
-        }
-    }
-
-    // fix up (just in case)
-    for (int j = 0; j < n - 1; j++) {
-        if (tokens[j].t1 < 0) {
-            tokens[j + 1].t0 = tokens[j].t1;
-        }
-
-        if (j > 0) {
-            if (tokens[j - 1].t1 > tokens[j].t0) {
-                tokens[j].t0 = tokens[j - 1].t1;
-                tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
-            }
-        }
-    }
-
-    // VAD
-    // expand or contract tokens based on voice activity
-    {
-        const int hw = WHISPER_SAMPLE_RATE/8;
-
-        for (int j = 0; j < n; j++) {
-            if (tokens[j].id >= whisper_token_eot(&ctx)) {
-                continue;
-            }
-
-            int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
-            int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
-
-            const int ss0 = std::max(s0 - hw, 0);
-            const int ss1 = std::min(s1 + hw, n_samples);
-
-            const int ns = ss1 - ss0;
-
-            float sum = 0.0f;
-
-            for (int k = ss0; k < ss1; k++) {
-                sum += ctx.energy[k];
-            }
-
-            const float thold = 0.5*sum/ns;
-
-            {
-                int k = s0;
-                if (ctx.energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx.energy[k] > thold) {
-                        k--;
-                    }
-                    tokens[j].t0 = sample_to_timestamp(k);
-                    if (tokens[j].t0 < tokens[j - 1].t1) {
-                        tokens[j].t0 = tokens[j - 1].t1;
-                    } else {
-                        s0 = k;
-                    }
-                } else {
-                    while (ctx.energy[k] < thold && k < s1) {
-                        k++;
-                    }
-                    s0 = k;
-                    tokens[j].t0 = sample_to_timestamp(k);
-                }
-            }
-
-            {
-                int k = s1;
-                if (ctx.energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
-                        k++;
-                    }
-                    tokens[j].t1 = sample_to_timestamp(k);
-                    if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { // j + 1 must stay within the n tokens
-                        tokens[j].t1 = tokens[j + 1].t0;
-                    } else {
-                        s1 = k;
-                    }
-                } else {
-                    while (ctx.energy[k] < thold && k > s0) {
-                        k--;
-                    }
-                    s1 = k;
-                    tokens[j].t1 = sample_to_timestamp(k);
-                }
-            }
-        }
-    }
-
-    // fixed token expand (optional)
-    //{
-    //    const int t_expand = 0;
-
-    //    for (int j = 0; j < n; j++) {
-    //        if (j > 0) {
-    //            tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
-    //        }
-    //        if (j < n - 1) {
-    //            tokens[j].t1 = tokens[j].t1 + t_expand;
-    //        }
-    //    }
-    //}
-
-    // debug info
-    //for (int j = 0; j < n; ++j) {
-    //    const auto & token = tokens[j];
-    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
-    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
-    //           tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
-
-    //    if (tokens[j].id >= whisper_token_eot(&ctx)) {
-    //        continue;
-    //    }
-    //}
-}
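A worked pass through the proportional split in whisper_exp_compute_token_level_timestamps: suppose tokens p0..p1 span dt = 100 centiseconds and carry voice lengths 1, 2, 1 (psum = 4). The interior boundaries land at t0 + 100*1/4 = t0 + 25 and t0 + 25 + 100*2/4 = t0 + 75, so the longer middle token receives half of the interval. The VAD pass then nudges each boundary toward the nearest crossing of the local energy threshold, which is half the mean energy within WHISPER_SAMPLE_RATE/8 samples (125 ms at 16 kHz) on either side of the token.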