whispercpp 1.2.0.2 → 1.3.1

Files changed (135)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
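The changes above replace the old vendored single-file ggml/whisper.cpp sources with the upstream 1.3.x source tree (backend headers under data/ext/ggml/, CoreML and OpenVINO support under data/ext/src/) and substantially extend the Ruby binding (data/ext/ruby_whisper.cpp) and its tests. For orientation, a minimal transcription sketch against the 1.3.x Ruby API is shown below; it is an assumption-laden sketch, not part of this diff — the model name and audio path are placeholders, and the exact class and parameter names should be checked against the shipped data/README.md.

    require "whisper"

    # Sketch only: "base.en" and "audio.wav" are placeholders; see the gem's
    # README for the documented API of this version.
    whisper = Whisper::Context.new("base.en")

    params = Whisper::Params.new
    params.language = "en"

    whisper.transcribe("audio.wav", params) do |whole_text|
      puts whole_text
    end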
data/ext/whisper.cpp DELETED
@@ -1,4829 +0,0 @@
- #define WHISPER_BUILD
- #include "whisper.h"
-
- #include "ggml.h"
-
- #include <algorithm>
- #include <cassert>
- #define _USE_MATH_DEFINES
- #include <cmath>
- #include <cstdio>
- #include <cstring>
- #include <fstream>
- #include <map>
- #include <string>
- #include <thread>
- #include <vector>
- #include <regex>
- #include <random>
-
- #if defined(GGML_BIG_ENDIAN)
- #include <bit>
-
- template<typename T>
- static T byteswap(T value) {
-     return std::byteswap(value);
- }
-
- template<>
- float byteswap(float value) {
-     return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
- }
-
- template<typename T>
- static void byteswap_tensor_data(ggml_tensor * tensor) {
-     T * datum = reinterpret_cast<T *>(tensor->data);
-     for (int i = 0; i < ggml_nelements(tensor); i++) {
-         datum[i] = byteswap(datum[i]);
-     }
- }
-
- static void byteswap_tensor(ggml_tensor * tensor) {
-     switch (tensor->type) {
-         case GGML_TYPE_I16: {
-             byteswap_tensor_data<int16_t>(tensor);
-             break;
-         }
-         case GGML_TYPE_F16: {
-             byteswap_tensor_data<ggml_fp16_t>(tensor);
-             break;
-         }
-         case GGML_TYPE_I32: {
-             byteswap_tensor_data<int32_t>(tensor);
-             break;
-         }
-         case GGML_TYPE_F32: {
-             byteswap_tensor_data<float>(tensor);
-             break;
-         }
-         default: { // GGML_TYPE_I8
-             break;
-         }
-     }
- }
-
- #define BYTESWAP_VALUE(d) d = byteswap(d)
- #define BYTESWAP_FILTERS(f)           \
-     do {                              \
-         for (auto & datum : f.data) { \
-             datum = byteswap(datum);  \
-         }                             \
-     } while (0)
- #define BYTESWAP_TENSOR(t)       \
-     do {                         \
-         byteswap_tensor(tensor); \
-     } while (0)
- #else
- #define BYTESWAP_VALUE(d) do {} while (0)
- #define BYTESWAP_FILTERS(f) do {} while (0)
- #define BYTESWAP_TENSOR(t) do {} while (0)
- #endif
-
- #define WHISPER_ASSERT(x) \
-     do { \
-         if (!(x)) { \
-             fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-             abort(); \
-         } \
-     } while (0)
-
- // define this to enable verbose trace logging - useful for debugging purposes
- //#define WHISPER_DEBUG
-
- #if defined(WHISPER_DEBUG)
- #define WHISPER_PRINT_DEBUG(...) \
-     do { \
-         fprintf(stderr, __VA_ARGS__); \
-     } while (0)
- #else
- #define WHISPER_PRINT_DEBUG(...)
- #endif
-
- #define WHISPER_USE_FLASH_ATTN
- //#define WHISPER_USE_FLASH_FF
- #define WHISPER_MAX_DECODERS 16
-
- #define WHISPER_USE_SCRATCH
- #define WHISPER_MAX_SCRATCH_BUFFERS 16
-
- // available whisper models
- enum e_model {
-     MODEL_UNKNOWN,
-     MODEL_TINY,
-     MODEL_BASE,
-     MODEL_SMALL,
-     MODEL_MEDIUM,
-     MODEL_LARGE,
- };
-
- static const std::map<std::string, std::pair<int, std::string>> g_lang = {
-     { "en", { 0, "english", } },
-     { "zh", { 1, "chinese", } },
-     { "de", { 2, "german", } },
-     { "es", { 3, "spanish", } },
-     { "ru", { 4, "russian", } },
-     { "ko", { 5, "korean", } },
-     { "fr", { 6, "french", } },
-     { "ja", { 7, "japanese", } },
-     { "pt", { 8, "portuguese", } },
-     { "tr", { 9, "turkish", } },
-     { "pl", { 10, "polish", } },
-     { "ca", { 11, "catalan", } },
-     { "nl", { 12, "dutch", } },
-     { "ar", { 13, "arabic", } },
-     { "sv", { 14, "swedish", } },
-     { "it", { 15, "italian", } },
-     { "id", { 16, "indonesian", } },
-     { "hi", { 17, "hindi", } },
-     { "fi", { 18, "finnish", } },
-     { "vi", { 19, "vietnamese", } },
-     { "iw", { 20, "hebrew", } },
-     { "uk", { 21, "ukrainian", } },
-     { "el", { 22, "greek", } },
-     { "ms", { 23, "malay", } },
-     { "cs", { 24, "czech", } },
-     { "ro", { 25, "romanian", } },
-     { "da", { 26, "danish", } },
-     { "hu", { 27, "hungarian", } },
-     { "ta", { 28, "tamil", } },
-     { "no", { 29, "norwegian", } },
-     { "th", { 30, "thai", } },
-     { "ur", { 31, "urdu", } },
-     { "hr", { 32, "croatian", } },
-     { "bg", { 33, "bulgarian", } },
-     { "lt", { 34, "lithuanian", } },
-     { "la", { 35, "latin", } },
-     { "mi", { 36, "maori", } },
-     { "ml", { 37, "malayalam", } },
-     { "cy", { 38, "welsh", } },
-     { "sk", { 39, "slovak", } },
-     { "te", { 40, "telugu", } },
-     { "fa", { 41, "persian", } },
-     { "lv", { 42, "latvian", } },
-     { "bn", { 43, "bengali", } },
-     { "sr", { 44, "serbian", } },
-     { "az", { 45, "azerbaijani", } },
-     { "sl", { 46, "slovenian", } },
-     { "kn", { 47, "kannada", } },
-     { "et", { 48, "estonian", } },
-     { "mk", { 49, "macedonian", } },
-     { "br", { 50, "breton", } },
-     { "eu", { 51, "basque", } },
-     { "is", { 52, "icelandic", } },
-     { "hy", { 53, "armenian", } },
-     { "ne", { 54, "nepali", } },
-     { "mn", { 55, "mongolian", } },
-     { "bs", { 56, "bosnian", } },
-     { "kk", { 57, "kazakh", } },
-     { "sq", { 58, "albanian", } },
-     { "sw", { 59, "swahili", } },
-     { "gl", { 60, "galician", } },
-     { "mr", { 61, "marathi", } },
-     { "pa", { 62, "punjabi", } },
-     { "si", { 63, "sinhala", } },
-     { "km", { 64, "khmer", } },
-     { "sn", { 65, "shona", } },
-     { "yo", { 66, "yoruba", } },
-     { "so", { 67, "somali", } },
-     { "af", { 68, "afrikaans", } },
-     { "oc", { 69, "occitan", } },
-     { "ka", { 70, "georgian", } },
-     { "be", { 71, "belarusian", } },
-     { "tg", { 72, "tajik", } },
-     { "sd", { 73, "sindhi", } },
-     { "gu", { 74, "gujarati", } },
-     { "am", { 75, "amharic", } },
-     { "yi", { 76, "yiddish", } },
-     { "lo", { 77, "lao", } },
-     { "uz", { 78, "uzbek", } },
-     { "fo", { 79, "faroese", } },
-     { "ht", { 80, "haitian creole", } },
-     { "ps", { 81, "pashto", } },
-     { "tk", { 82, "turkmen", } },
-     { "nn", { 83, "nynorsk", } },
-     { "mt", { 84, "maltese", } },
-     { "sa", { 85, "sanskrit", } },
-     { "lb", { 86, "luxembourgish", } },
-     { "my", { 87, "myanmar", } },
-     { "bo", { 88, "tibetan", } },
-     { "tl", { 89, "tagalog", } },
-     { "mg", { 90, "malagasy", } },
-     { "as", { 91, "assamese", } },
-     { "tt", { 92, "tatar", } },
-     { "haw", { 93, "hawaiian", } },
-     { "ln", { 94, "lingala", } },
-     { "ha", { 95, "hausa", } },
-     { "ba", { 96, "bashkir", } },
-     { "jw", { 97, "javanese", } },
-     { "su", { 98, "sundanese", } },
- };
-
- static const size_t MB = 1024*1024;
-
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-     { MODEL_TINY,   12ull*MB },
-     { MODEL_BASE,   15ull*MB },
-     { MODEL_SMALL,  23ull*MB },
-     { MODEL_MEDIUM, 31ull*MB },
-     { MODEL_LARGE,  38ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-     { MODEL_TINY,   18ull*MB },
-     { MODEL_BASE,   24ull*MB },
-     { MODEL_SMALL,  36ull*MB },
-     { MODEL_MEDIUM, 48ull*MB },
-     { MODEL_LARGE,  60ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-     { MODEL_TINY,   4ull*MB },
-     { MODEL_BASE,   4ull*MB },
-     { MODEL_SMALL,  6ull*MB },
-     { MODEL_MEDIUM, 7ull*MB },
-     { MODEL_LARGE,  9ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-     { MODEL_TINY,   4ull*MB },
-     { MODEL_BASE,   4ull*MB },
-     { MODEL_SMALL,  6ull*MB },
-     { MODEL_MEDIUM, 7ull*MB },
-     { MODEL_LARGE,  9ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_MODEL = {
-     { MODEL_TINY,   74ull*MB },
-     { MODEL_BASE,   142ull*MB },
-     { MODEL_SMALL,  466ull*MB },
-     { MODEL_MEDIUM, 1464ull*MB },
-     { MODEL_LARGE,  2952ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
-     { MODEL_TINY,   3ull*MB },
-     { MODEL_BASE,   6ull*MB },
-     { MODEL_SMALL,  16ull*MB },
-     { MODEL_MEDIUM, 43ull*MB },
-     { MODEL_LARGE,  71ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
-     { MODEL_TINY,   9ull*MB },
-     { MODEL_BASE,   18ull*MB },
-     { MODEL_SMALL,  53ull*MB },
-     { MODEL_MEDIUM, 141ull*MB },
-     { MODEL_LARGE,  235ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
-     { MODEL_TINY,   6ull*MB },
-     { MODEL_BASE,   8ull*MB },
-     { MODEL_SMALL,  13ull*MB },
-     { MODEL_MEDIUM, 22ull*MB },
-     { MODEL_LARGE,  33ull*MB },
- };
-
- static const std::map<e_model, size_t> MEM_REQ_DECODE = {
-     { MODEL_TINY,   3ull*MB },
-     { MODEL_BASE,   5ull*MB },
-     { MODEL_SMALL,  10ull*MB },
-     { MODEL_MEDIUM, 18ull*MB },
-     { MODEL_LARGE,  27ull*MB },
- };
-
- struct whisper_mel {
-     int n_len;
-     int n_mel;
-
-     std::vector<float> data;
- };
-
- struct whisper_filters {
-     int32_t n_mel;
-     int32_t n_fft;
-
-     std::vector<float> data;
- };
-
- struct whisper_vocab {
-     using id    = int32_t;
-     using token = std::string;
-
-     int n_vocab = 51864;
-
-     std::map<token, id> token_to_id;
-     std::map<id, token> id_to_token;
-
-     id token_eot  = 50256;
-     id token_sot  = 50257;
-     id token_prev = 50360;
-     id token_solm = 50361; // ??
-     id token_not  = 50362; // no timestamps
-     id token_beg  = 50363;
-
-     // available tasks
-     static const id token_translate  = 50358;
-     static const id token_transcribe = 50359;
-
-     bool is_multilingual() const {
-         return n_vocab == 51865;
-     }
- };
-
- struct whisper_segment {
-     int64_t t0;
-     int64_t t1;
-
-     std::string text;
-
-     std::vector<whisper_token_data> tokens;
- };
-
- // medium
- // hparams: {
- // 'n_mels': 80,
- // 'n_vocab': 51864,
- // 'n_audio_ctx': 1500,
- // 'n_audio_state': 1024,
- // 'n_audio_head': 16,
- // 'n_audio_layer': 24,
- // 'n_text_ctx': 448,
- // 'n_text_state': 1024,
- // 'n_text_head': 16,
- // 'n_text_layer': 24
- // }
- //
- // default hparams (Whisper tiny)
- struct whisper_hparams {
-     int32_t n_vocab       = 51864;
-     int32_t n_audio_ctx   = 1500;
-     int32_t n_audio_state = 384;
-     int32_t n_audio_head  = 6;
-     int32_t n_audio_layer = 4;
-     int32_t n_text_ctx    = 448;
-     int32_t n_text_state  = 384;
-     int32_t n_text_head   = 6;
-     int32_t n_text_layer  = 4;
-     int32_t n_mels        = 80;
-     int32_t f16           = 1;
- };
-
- // audio encoding layer
- struct whisper_layer_encoder {
-     // encoder.blocks.*.attn_ln
-     struct ggml_tensor * attn_ln_0_w;
-     struct ggml_tensor * attn_ln_0_b;
-
-     // encoder.blocks.*.attn.out
-     struct ggml_tensor * attn_ln_1_w;
-     struct ggml_tensor * attn_ln_1_b;
-
-     // encoder.blocks.*.attn.query
-     struct ggml_tensor * attn_q_w;
-     struct ggml_tensor * attn_q_b;
-
-     // encoder.blocks.*.attn.key
-     struct ggml_tensor * attn_k_w;
-
-     // encoder.blocks.*.attn.value
-     struct ggml_tensor * attn_v_w;
-     struct ggml_tensor * attn_v_b;
-
-     // encoder.blocks.*.mlp_ln
-     struct ggml_tensor * mlp_ln_w;
-     struct ggml_tensor * mlp_ln_b;
-
-     // encoder.blocks.*.mlp.0
-     struct ggml_tensor * mlp_0_w;
-     struct ggml_tensor * mlp_0_b;
-
-     // encoder.blocks.*.mlp.2
-     struct ggml_tensor * mlp_1_w;
-     struct ggml_tensor * mlp_1_b;
- };
-
- // token decoding layer
- struct whisper_layer_decoder {
-     // decoder.blocks.*.attn_ln
-     struct ggml_tensor * attn_ln_0_w;
-     struct ggml_tensor * attn_ln_0_b;
-
-     // decoder.blocks.*.attn.out
-     struct ggml_tensor * attn_ln_1_w;
-     struct ggml_tensor * attn_ln_1_b;
-
-     // decoder.blocks.*.attn.query
-     struct ggml_tensor * attn_q_w;
-     struct ggml_tensor * attn_q_b;
-
-     // decoder.blocks.*.attn.key
-     struct ggml_tensor * attn_k_w;
-
-     // decoder.blocks.*.attn.value
-     struct ggml_tensor * attn_v_w;
-     struct ggml_tensor * attn_v_b;
-
-     // decoder.blocks.*.cross_attn_ln
-     struct ggml_tensor * cross_attn_ln_0_w;
-     struct ggml_tensor * cross_attn_ln_0_b;
-
-     // decoder.blocks.*.cross_attn.out
-     struct ggml_tensor * cross_attn_ln_1_w;
-     struct ggml_tensor * cross_attn_ln_1_b;
-
-     // decoder.blocks.*.cross_attn.query
-     struct ggml_tensor * cross_attn_q_w;
-     struct ggml_tensor * cross_attn_q_b;
-
-     // decoder.blocks.*.cross_attn.key
-     struct ggml_tensor * cross_attn_k_w;
-
-     // decoder.blocks.*.cross_attn.value
-     struct ggml_tensor * cross_attn_v_w;
-     struct ggml_tensor * cross_attn_v_b;
-
-     // decoder.blocks.*.mlp_ln
-     struct ggml_tensor * mlp_ln_w;
-     struct ggml_tensor * mlp_ln_b;
-
-     // decoder.blocks.*.mlp.0
-     struct ggml_tensor * mlp_0_w;
-     struct ggml_tensor * mlp_0_b;
-
-     // decoder.blocks.*.mlp.2
-     struct ggml_tensor * mlp_1_w;
-     struct ggml_tensor * mlp_1_b;
- };
-
- struct whisper_kv_cache {
-     struct ggml_tensor * k;
-     struct ggml_tensor * v;
-
-     struct ggml_context * ctx;
-
-     std::vector<uint8_t> buf;
-
-     int n; // number of tokens currently in the cache
- };
-
- struct whisper_model {
-     e_model type = MODEL_UNKNOWN;
-
-     whisper_hparams hparams;
-     whisper_filters filters;
-
-     // encoder.positional_embedding
-     struct ggml_tensor * e_pe;
-
-     // encoder.conv1
-     struct ggml_tensor * e_conv_1_w;
-     struct ggml_tensor * e_conv_1_b;
-
-     // encoder.conv2
-     struct ggml_tensor * e_conv_2_w;
-     struct ggml_tensor * e_conv_2_b;
-
-     // encoder.ln_post
-     struct ggml_tensor * e_ln_w;
-     struct ggml_tensor * e_ln_b;
-
-     // decoder.positional_embedding
-     struct ggml_tensor * d_pe;
-
-     // decoder.token_embedding
-     struct ggml_tensor * d_te;
-
-     // decoder.ln
-     struct ggml_tensor * d_ln_w;
-     struct ggml_tensor * d_ln_b;
-
-     std::vector<whisper_layer_encoder> layers_encoder;
-     std::vector<whisper_layer_decoder> layers_decoder;
-
-     // context
-     struct ggml_context * ctx;
-
-     // the model memory buffer is read-only and can be shared between processors
-     std::vector<uint8_t> * buf;
-
-     // tensors
-     int n_loaded;
-     std::map<std::string, struct ggml_tensor *> tensors;
- };
-
- struct whisper_sequence {
-     std::vector<whisper_token_data> tokens;
-
-     // the accumulated transcription in the current iteration (used to truncate the tokens array)
-     int result_len;
-
-     double sum_logprobs_all; // the sum of the log probabilities of the tokens
-     double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
-     double avg_logprobs;     // the average log probability of the tokens
-     double entropy;          // the entropy of the tokens
-     double score;            // likelihood rank score
- };
-
- // TAGS: WHISPER_DECODER_INIT
- struct whisper_decoder {
-     // each decoder keeps its own KV-cache
-     whisper_kv_cache kv_self;
-
-     // the currently generated sequence of tokens
-     whisper_sequence sequence;
-
-     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
-
-     bool failed;    // has the current segment failed to decode?
-     bool completed; // has the decoder completed the current segment?
-     bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
-
-     // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
-     std::vector<float> probs;
-     std::vector<float> logits;
-     std::vector<float> logprobs;
-
-     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
- };
-
- struct whisper_context {
-     int64_t t_load_us = 0;
-     int64_t t_mel_us = 0;
-     int64_t t_sample_us = 0;
-     int64_t t_encode_us = 0;
-     int64_t t_decode_us = 0;
-     int64_t t_start_us = 0;
-
-     int32_t n_sample = 0; // number of tokens sampled
-     int32_t n_encode = 0; // number of encoder calls
-     int32_t n_decode = 0; // number of decoder calls
-     int32_t n_fail_p = 0; // number of logprob threshold failures
-     int32_t n_fail_h = 0; // number of entropy threshold failures
-
-     ggml_type wtype; // weight type (FP32 or FP16)
-
-     whisper_mel mel;
-
-     whisper_model model;
-     whisper_vocab vocab;
-
-     // cross-attention KV cache for the decoders
-     // shared between all decoders
-     whisper_kv_cache kv_cross;
-
-     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
-
-     // memory buffers used by encode / decode contexts
-     std::vector<uint8_t> buf_compute;
-     std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
-
-     int    buf_last = 0;
-     size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
-
-     // decode output (2-dimensional array: [n_tokens][n_vocab])
-     std::vector<float> logits;
-
-     std::vector<whisper_segment> result_all;
-     std::vector<whisper_token>   prompt_past;
-
-     // work container used to avoid memory allocations
-     std::vector<std::pair<double, whisper_vocab::id>> logits_id;
-
-     mutable std::mt19937 rng; // used for sampling at t > 0.0
-
-     int lang_id = 0; // english by default
-
-     // [EXPERIMENTAL] token-level timestamps data
-     int64_t t_beg = 0;
-     int64_t t_last = 0;
-     whisper_token tid_last;
-     std::vector<float> energy; // PCM signal energy
-
-     // [EXPERIMENTAL] speed-up techniques
-     int32_t exp_n_audio_ctx = 0; // 0 - use default
-
-     // [EXPERIMENTAL] abort handling
-     bool running = true;
-
-     void use_buf(struct ggml_context * ctx, int i) {
- #if defined(WHISPER_USE_SCRATCH)
-         size_t last_size = 0;
-
-         if (i == -1) {
-             last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-         } else {
-             auto & buf = buf_scratch[i];
-             last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
-         }
-
-         if (buf_last >= 0) {
-             buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-         }
-
-         buf_last = i;
- #else
-         (void) i;
-         (void) ctx;
- #endif
-     }
-
-     size_t get_buf_max_mem(int i) const {
- #if defined(WHISPER_USE_SCRATCH)
-         return buf_max_size[i];
- #else
-         (void) i;
-         return 0;
- #endif
-     }
- };
-
- template<typename T>
- static void read_safe(whisper_model_loader * loader, T & dest) {
-     loader->read(loader->context, &dest, sizeof(T));
-     BYTESWAP_VALUE(dest);
- }
-
- static bool kv_cache_init(
-         const struct whisper_hparams & hparams,
-         const size_t                   mem_bytes,
-         struct whisper_kv_cache      & cache,
-         ggml_type                      wtype,
-         int                            n_ctx) {
-     cache.buf.resize(mem_bytes);
-
-     struct ggml_init_params params;
-     params.mem_size   = cache.buf.size();
-     params.mem_buffer = cache.buf.data();
-
-     cache.ctx = ggml_init(params);
-
-     if (!cache.ctx) {
-         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
-         return false;
-     }
-
-     const int n_text_state = hparams.n_text_state;
-     const int n_text_layer = hparams.n_text_layer;
-
-     const int n_mem      = n_text_layer*n_ctx;
-     const int n_elements = n_text_state*n_mem;
-
-     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-
-     return true;
- }
-
- static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
-     WHISPER_ASSERT(cache.ctx);
-
-     const int n_elements = ggml_nelements(cache.k);
-     WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
-
-     const ggml_type wtype = cache.k->type;
-     WHISPER_ASSERT(wtype == cache.v->type);
-
-     WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
-
-     struct ggml_init_params params;
-     params.mem_size   = cache.buf.size();
-     params.mem_buffer = cache.buf.data();
-
-     cache.ctx = ggml_init(params);
-
-     if (!cache.ctx) {
-         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
-         return false;
-     }
-
-     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-
-     return true;
- }
-
- static void kv_cache_free(struct whisper_kv_cache & cache) {
-     if (cache.ctx) {
-         ggml_free(cache.ctx);
-         cache.ctx = nullptr;
-     }
- }
-
- // load the model from a ggml file
- //
- // file format:
- //
- //   - hparams
- //   - pre-computed mel filters
- //   - vocab
- //   - weights
- //
- // see the convert-pt-to-ggml.py script for details
- //
- static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-     fprintf(stderr, "%s: loading model\n", __func__);
-
-     const int64_t t_start_us = ggml_time_us();
-
-     wctx.t_start_us = t_start_us;
-
-     auto & model = wctx.model;
-     auto & vocab = wctx.vocab;
-
-     // verify magic
-     {
-         uint32_t magic;
-         read_safe(loader, magic);
-         if (magic != 0x67676d6c) {
-             fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
-             return false;
-         }
-     }
-
-     // load hparams
-     {
-         auto & hparams = model.hparams;
-
-         read_safe(loader, hparams.n_vocab);
-         read_safe(loader, hparams.n_audio_ctx);
-         read_safe(loader, hparams.n_audio_state);
-         read_safe(loader, hparams.n_audio_head);
-         read_safe(loader, hparams.n_audio_layer);
-         read_safe(loader, hparams.n_text_ctx);
-         read_safe(loader, hparams.n_text_state);
-         read_safe(loader, hparams.n_text_head);
-         read_safe(loader, hparams.n_text_layer);
-         read_safe(loader, hparams.n_mels);
-         read_safe(loader, hparams.f16);
-
-         assert(hparams.n_text_state == hparams.n_audio_state);
-
-         if (hparams.n_audio_layer == 4) {
-             model.type = e_model::MODEL_TINY;
-         }
-
-         if (hparams.n_audio_layer == 6) {
-             model.type = e_model::MODEL_BASE;
-         }
-
-         if (hparams.n_audio_layer == 12) {
-             model.type = e_model::MODEL_SMALL;
-         }
-
-         if (hparams.n_audio_layer == 24) {
-             model.type = e_model::MODEL_MEDIUM;
-         }
-
-         if (hparams.n_audio_layer == 32) {
-             model.type = e_model::MODEL_LARGE;
-         }
-
-         // for the big tensors, we have the option to store the data in 16-bit floats
-         // in order to save memory and also to speed up the computation
-         wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-         const size_t scale = model.hparams.f16 ? 1 : 2;
-
-         fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-         fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-         fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-         fprintf(stderr, "%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-         fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-         fprintf(stderr, "%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-         fprintf(stderr, "%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-         fprintf(stderr, "%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-         fprintf(stderr, "%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-         fprintf(stderr, "%s: n_mels        = %d\n", __func__, hparams.n_mels);
-         fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
-         fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
-
-         // print memory requirements
-         {
-             // this is the total memory required to run the inference
-             const size_t mem_required =
-                       MEM_REQ_SCRATCH0.at (model.type) +
-                       MEM_REQ_SCRATCH1.at (model.type) +
-                       MEM_REQ_SCRATCH2.at (model.type) +
-                       MEM_REQ_SCRATCH3.at (model.type) +
-                 scale*MEM_REQ_MODEL.at    (model.type) +
-                 scale*MEM_REQ_KV_CROSS.at (model.type) +
-                 scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
-
-             // this is the memory required by one decoder
-             const size_t mem_required_decoder =
-                 scale*MEM_REQ_KV_SELF.at(model.type);
-
-             fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-                     mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-         }
-
-         // initialize all memory buffers
-         // always have at least one decoder
-
-         wctx.model.buf = new std::vector<uint8_t>();
-         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
-
-         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
-             return false;
-         }
-
-         {
-             const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
-             fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-         }
-
-         if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-             fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
-             return false;
-         }
-
-         {
-             const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-             fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-         }
-
-         wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-
-         wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
-         wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
-         wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
-         wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
-     }
-
-     // load mel filters
-     {
-         auto & filters = wctx.model.filters;
-
-         read_safe(loader, filters.n_mel);
-         read_safe(loader, filters.n_fft);
-
-         filters.data.resize(filters.n_mel * filters.n_fft);
-         loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
-         BYTESWAP_FILTERS(filters);
-     }
-
-     // load vocab
-     {
-         int32_t n_vocab = 0;
-         read_safe(loader, n_vocab);
-
-         //if (n_vocab != model.hparams.n_vocab) {
-         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
-         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
-         //    return false;
-         //}
-
-         std::string word;
-         std::vector<char> tmp;
-
-         tmp.reserve(128);
-
-         for (int i = 0; i < n_vocab; i++) {
-             uint32_t len;
-             read_safe(loader, len);
-
-             if (len > 0) {
-                 tmp.resize(len);
-                 loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
-                 word.assign(&tmp[0], tmp.size());
-             } else {
-                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                 //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
-                 word = "";
-             }
-
-             vocab.token_to_id[word] = i;
-             vocab.id_to_token[i] = word;
-
-             //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
-         }
-
-         vocab.n_vocab = model.hparams.n_vocab;
-         if (vocab.is_multilingual()) {
-             vocab.token_eot++;
-             vocab.token_sot++;
-             vocab.token_prev++;
-             vocab.token_solm++;
-             vocab.token_not++;
-             vocab.token_beg++;
-         }
-
-         if (n_vocab < model.hparams.n_vocab) {
-             fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
-             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
-                 if (i > vocab.token_beg) {
-                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
-                 } else if (i == vocab.token_eot) {
-                     word = "[_EOT_]";
-                 } else if (i == vocab.token_sot) {
-                     word = "[_SOT_]";
-                 } else if (i == vocab.token_prev) {
-                     word = "[_PREV_]";
-                 } else if (i == vocab.token_not) {
-                     word = "[_NOT_]";
-                 } else if (i == vocab.token_beg) {
-                     word = "[_BEG_]";
-                 } else {
-                     word = "[_extra_token_" + std::to_string(i) + "]";
-                 }
-                 vocab.token_to_id[word] = i;
-                 vocab.id_to_token[i] = word;
-             }
-         }
-
-         wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-
-         wctx.logits_id.reserve(n_vocab);
-
-         // TAGS: WHISPER_DECODER_INIT
-         wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-
-         wctx.decoders[0].probs.reserve   (vocab.n_vocab);
-         wctx.decoders[0].logits.reserve  (vocab.n_vocab);
-         wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
-     }
-
-     size_t ctx_size = 0;
-
-     const ggml_type wtype = wctx.wtype;
-
-     {
-         const auto & hparams = model.hparams;
-
-         const int n_vocab = hparams.n_vocab;
-
-         const int n_audio_ctx   = hparams.n_audio_ctx;
-         const int n_audio_state = hparams.n_audio_state;
-         const int n_audio_layer = hparams.n_audio_layer;
-
-         const int n_text_ctx   = hparams.n_text_ctx;
-         const int n_text_state = hparams.n_text_state;
-         const int n_text_layer = hparams.n_text_layer;
-
-         const int n_mels = hparams.n_mels;
-
-         // encoder
-         {
-             ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
-
-             ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
-             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32);  // e_conv_1_b
-
-             ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
-             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32);         // e_conv_2_b
-
-             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
-             ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
-         }
-
-         // decoder
-         {
-             ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
-
-             ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
-
-             ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
-             ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
-         }
-
-         // encoder layers
-         {
-             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
-             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
-
-             ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
-             ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32));      // mlp_0_b
-
-             ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
-             ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32));        // mlp_1_b
-
-             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
-             ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
-
-             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
-             ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32));      // attn_q_b
-
-             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
-
-             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
-             ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32));      // attn_v_b
-
-             ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
-             ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32));      // attn_ln_1_b
-         }
-
-         // decoder layers
-         {
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
-
-             ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
-             ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32));     // mlp_0_b
-
-             ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));       // mlp_1_b
-
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // attn_q_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // attn_v_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // attn_ln_1_b
-             //
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
-             ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // cross_attn_q_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // cross_attn_v_b
-
-             ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
-             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32));     // cross_attn_ln_1_b
-         }
-
-         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
-
-         fprintf(stderr, "%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-     }
-
-     // create the ggml context
-     {
-         struct ggml_init_params params;
-         params.mem_size   = wctx.model.buf->size();
-         params.mem_buffer = wctx.model.buf->data();
-
-         model.ctx = ggml_init(params);
-         if (!model.ctx) {
-             fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-             return false;
-         }
-     }
-
-     // prepare memory for the weights
-     {
-         auto & ctx = model.ctx;
-
-         const auto & hparams = model.hparams;
-
-         const int n_vocab = hparams.n_vocab;
-
-         const int n_audio_ctx   = hparams.n_audio_ctx;
-         const int n_audio_state = hparams.n_audio_state;
-         const int n_audio_layer = hparams.n_audio_layer;
-
-         const int n_text_ctx   = hparams.n_text_ctx;
-         const int n_text_state = hparams.n_text_state;
-         const int n_text_layer = hparams.n_text_layer;
-
-         const int n_mels = hparams.n_mels;
-
-         model.layers_encoder.resize(n_audio_layer);
-         model.layers_decoder.resize(n_text_layer);
-
-         // encoder
-         {
-             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
-             model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
-             model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-             model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
-             model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-             model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-             // map by name
-             model.tensors["encoder.positional_embedding"] = model.e_pe;
-
-             model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
-             model.tensors["encoder.conv1.bias"]   = model.e_conv_1_b;
-
-             model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
-             model.tensors["encoder.conv2.bias"]   = model.e_conv_2_b;
-
-             model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
-             model.tensors["encoder.ln_post.bias"]   = model.e_ln_b;
-
-             for (int i = 0; i < n_audio_layer; ++i) {
-                 auto & layer = model.layers_encoder[i];
-
-                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
-                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
-
-                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
-                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-
-                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 // map by name
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
-             }
-         }
-
-         // decoder
-         {
-             model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
-             model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
-             model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-             model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-             // map by name
-             model.tensors["decoder.positional_embedding"] = model.d_pe;
-
-             model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
-             model.tensors["decoder.ln.weight"] = model.d_ln_w;
-             model.tensors["decoder.ln.bias"]   = model.d_ln_b;
-
-             for (int i = 0; i < n_text_layer; ++i) {
-                 auto & layer = model.layers_decoder[i];
-
-                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
-                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
-                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
-                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
-                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
-                 layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                 // map by name
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]   = layer.cross_attn_ln_0_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"]   = layer.cross_attn_q_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"]   = layer.cross_attn_v_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]   = layer.cross_attn_ln_1_b;
-             }
-         }
-     }
-
-     // load weights
-     {
-         size_t total_size = 0;
-
-         model.n_loaded = 0;
-
-         while (true) {
-             int32_t n_dims;
-             int32_t length;
-             int32_t ftype;
-
-             read_safe(loader, n_dims);
-             read_safe(loader, length);
-             read_safe(loader, ftype);
-
-             if (loader->eof(loader->context)) {
-                 break;
-             }
-
-             int32_t nelements = 1;
-             int32_t ne[3] = { 1, 1, 1 };
-             for (int i = 0; i < n_dims; ++i) {
-                 read_safe(loader, ne[i]);
-                 nelements *= ne[i];
-             }
-
-             std::string name;
-             std::vector<char> tmp(length); // create a buffer
-             loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
-             name.assign(&tmp[0], tmp.size());
-
-             if (model.tensors.find(name) == model.tensors.end()) {
-                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
-                 return false;
-             }
-
-             auto tensor = model.tensors[name.data()];
-             if (ggml_nelements(tensor) != nelements) {
-                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                 return false;
-             }
-
-             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                 fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                         __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
-                 return false;
-             }
-
-             const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
-
-             if (nelements*bpe != ggml_nbytes(tensor)) {
-                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                 return false;
-             }
-
-             loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
-             BYTESWAP_TENSOR(tensor);
-
-             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
-             total_size += ggml_nbytes(tensor);
-             model.n_loaded++;
-         }
-
-         fprintf(stderr, "%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
-
-         if (model.n_loaded == 0) {
-             fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
-         } else if (model.n_loaded != (int) model.tensors.size()) {
-             fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
-             return false;
-         }
-     }
-
-     wctx.rng = std::mt19937(0);
-
-     wctx.t_load_us = ggml_time_us() - t_start_us;
-
-     return true;
- }
1351
-
-// evaluate the encoder
-//
-// given an audio recording (more specifically, its log mel spectrogram), runs the forward pass of the encoder
-// part of the transformer model and returns the encoded features
-//
-//   - model:      the model
-//   - n_threads:  number of threads to use
-//   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
-//
-static bool whisper_encode(
-        whisper_context & wctx,
-              const int   mel_offset,
-              const int   n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    const auto & model   = wctx.model;
-    const auto & mel_inp = wctx.mel;
-    const auto & hparams = model.hparams;
-
-    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
-    const int n_state = hparams.n_audio_state;
-    const int n_head  = hparams.n_audio_head;
-    const int n_layer = hparams.n_audio_layer;
-
-    const int n_mels = hparams.n_mels;
-    assert(mel_inp.n_mel == n_mels);
-
-    struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    wctx.use_buf(ctx0, 0);
-
-    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
-    assert(mel->type == GGML_TYPE_F32);
-    {
-        float * dst = (float *) mel->data;
-        memset(dst, 0, ggml_nbytes(mel));
-
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
-        const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-        for (int j = 0; j < mel_inp.n_mel; ++j) {
-            for (int i = i0; i < i1; ++i) {
-                dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
-            }
-        }
-    }
-
-    struct ggml_tensor * cur;
-
-    // convolution + gelu
-    {
-        wctx.use_buf(ctx0, 1);
-
-        cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
-        cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_1_b,
-                    cur),
-                cur);
-
-        cur = ggml_gelu(ctx0, cur);
-
-        wctx.use_buf(ctx0, 0);
-
-        cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
-        cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_2_b,
-                    cur),
-                cur);
-
-        cur = ggml_gelu(ctx0, cur);
-    }
-
-    wctx.use_buf(ctx0, 3);
-
-    // ===================================================================
-    // NOTE: experimenting with partial evaluation of the encoder (ignore)
-    //static int iter = -1;
-    //const int n_iter = 1500/n_ctx;
-
-    //iter = (iter + 1) % n_iter;
-
-    //if (iter == 0) {
-    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
-    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
-    //}
-
-    static int iter = 0;
-
-    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
-    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
-
-    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
-    cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
-
-    // ===================================================================
-
-    // original:
-    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
-
-    struct ggml_tensor * inpL = cur;
-
-    for (int il = 0; il < n_layer; ++il) {
-        const auto & layer = model.layers_encoder[il];
-
-        // norm
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_norm(ctx0, inpL);
-
-            // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
-        }
-
-        // self-attention
-        {
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                    layer.attn_q_w,
-                    cur);
-
-            Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
-
-            //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                    layer.attn_k_w,
-                    cur);
-
-            //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
-
-            Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
-
-            // ------
-
-            wctx.use_buf(ctx0, 0);
-
-#ifdef WHISPER_USE_FLASH_ATTN
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
-                        );
-
-            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
-#else
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-                        0, 2, 1, 3);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-                        );
-
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
-
-            //struct ggml_tensor * V_trans =
-            //    ggml_permute(ctx0,
-            //            ggml_cpy(ctx0,
-            //                Vcur,
-            //                ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
-            //            1, 2, 0, 3);
-
-            //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
-
-            struct ggml_tensor * V =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                Vcur,
-                                n_state/n_head, n_head, n_ctx),
-                            0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
-                        );
-
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
-#endif
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
-        }
-
-        // projection
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_mul_mat(ctx0,
-                    layer.attn_ln_1_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
-        }
-
-        wctx.use_buf(ctx0, 2);
-
-        // add the input
-        cur = ggml_add(ctx0, cur, inpL);
-
-        struct ggml_tensor * inpFF = cur;
-
-        // feed-forward network
-        {
-            // norm
-            {
-                wctx.use_buf(ctx0, 0);
-
-                cur = ggml_norm(ctx0, inpFF);
-
-                wctx.use_buf(ctx0, 1);
-
-                // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-            }
-
-#ifdef WHISPER_USE_FLASH_FF
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
-            wctx.use_buf(ctx0, 0);
-
-            // fully connected
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_0_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
-
-            wctx.use_buf(ctx0, 0);
-
-            // GELU activation
-            cur = ggml_gelu(ctx0, cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            // projection
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_1_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
-#endif
-        }
-
-        wctx.use_buf(ctx0, 3);
-
-        inpL = ggml_add(ctx0, cur, inpFF);
-    }
-
-    cur = inpL;
-
-    // norm
-    {
-        wctx.use_buf(ctx0, 0);
-
-        cur = ggml_norm(ctx0, cur);
-
-        wctx.use_buf(ctx0, 1);
-
-        // cur = ln_f_g*cur + ln_f_b
-        cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.e_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.e_ln_b, cur));
-    }
-
-    wctx.use_buf(ctx0, -1);
-
-    // run the computation
-    {
-        struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
-
-        ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute       (ctx0, &gf);
-
-        //ggml_graph_print(&gf);
-    }
-
-    // cur
-    //{
-    //    printf("ne0 = %d\n", cur->ne[0]);
-    //    printf("ne1 = %d\n", cur->ne[1]);
-    //    for (int i = 0; i < 10; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("... ");
-    //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("\n");
-    //}
-
-    // pre-compute cross-attention memory
-    {
-        struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
-
-        // TODO: hack to disconnect the encoded features from the previous graph
-        cur->op = GGML_OP_NONE;
-        cur->src0 = nullptr;
-        cur->src1 = nullptr;
-
-        for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-            auto & layer = model.layers_decoder[il];
-
-            wctx.use_buf(ctx0, 0);
-
-            struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_k_w,
-                    cur);
-
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_v_w,
-                    cur);
-
-            Vcross = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_v_b,
-                        Vcross),
-                    Vcross);
-
-            wctx.use_buf(ctx0, -1);
-
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
-
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
-        }
-
-        ggml_graph_compute(ctx0, &gf);
-        //ggml_graph_print(&gf);
-    }
-
-    ////////////////////////////////////////////////////////////////////////////
-
-    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-    //        ggml_used_mem(ctx0)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
-
-    ggml_free(ctx0);
-
-    wctx.t_encode_us += ggml_time_us() - t_start_us;
-    wctx.n_encode++;
-
-    return true;
-}
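Shape-wise, the encoder stem removed above takes the mel input [2*n_ctx, n_mels], keeps the time axis through the stride-1 convolution, and halves it through the stride-2 convolution before adding the positional embedding e_pe. A small arithmetic sketch (the concrete numbers are for the "base" model and are my assumption, not taken from this diff):

    #include <cstdio>

    int main() {
        const int n_ctx   = 1500;  // n_audio_ctx for the base model
        const int n_mels  = 80;
        const int n_state = 512;   // n_audio_state for the base model

        const int t_in  = 2*n_ctx;  // 3000 mel frames (30 s at 100 frames/s)
        const int t_out = t_in/2;   // 1500 after the stride-2 convolution

        printf("mel: [%d, %d] -> encoder features: [%d, %d]\n",
               t_in, n_mels, t_out, n_state);
        return 0;
    }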
-
-// evaluate the decoder
-//
-// given text prompt + audio features -> computes the logits for the next token
-//
-//   - model:     the model
-//   - n_threads: number of threads to use
-//   - tokens:    text prompt
-//   - n_tokens:  number of tokens in the prompt
-//   - n_past:    number of past tokens to prefix the prompt with
-//
-static bool whisper_decode(
-        whisper_context & wctx,
-        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
-              const int   n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
-    const auto & model   = wctx.model;
-    const auto & hparams = model.hparams;
-
-    auto & kv_self = decoder.kv_self;
-
-    WHISPER_ASSERT(!!kv_self.ctx);
-
-    auto & logits_out = wctx.logits;
-
-    const int n_vocab = hparams.n_vocab;
-
-    const int n_ctx   = hparams.n_text_ctx;
-    const int n_state = hparams.n_text_state;
-    const int n_head  = hparams.n_text_head;
-    const int n_layer = hparams.n_text_layer;
-
-    const int N = n_tokens;
-    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
-
-    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
-
-    struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
-
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
-    }
-
-    wctx.use_buf(ctx0, 3);
-
-    // token encoding + position encoding
-    struct ggml_tensor * cur =
-        ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.d_te, embd),
-                ggml_get_rows(ctx0, model.d_pe, position));
-
-    struct ggml_tensor * inpL = cur;
-
-    for (int il = 0; il < n_layer; ++il) {
-        const auto & layer = model.layers_decoder[il];
-
-        // norm
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_norm(ctx0, inpL);
-
-            // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
-        }
-
-        // self-attention
-        {
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                    layer.attn_q_w,
-                    cur);
-
-            Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
-
-            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                    layer.attn_k_w,
-                    cur);
-
-            Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
-
-            Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
-
-            // store key and value to memory
-            {
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
-
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
-            }
-
-            // ------
-
-            wctx.use_buf(ctx0, 0);
-
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
-                        0, 2, 1, 3);
-
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        0, 2, 1, 3);
-
-            wctx.use_buf(ctx0, 1);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-            wctx.use_buf(ctx0, 0);
-
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
-
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-
-            wctx.use_buf(ctx0, 0);
-
-            struct ggml_tensor * V_trans =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        1, 2, 0, 3);
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
-
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
-        }
-
-        // projection
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_mul_mat(ctx0,
-                    layer.attn_ln_1_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
-        }
-
-        wctx.use_buf(ctx0, 2);
-
-        // add the input
-        struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
-
-        // norm
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
-
-            wctx.use_buf(ctx0, 1);
-
-            // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
-        }
-
-        // cross-attention
-        {
-            wctx.use_buf(ctx0, 0);
-
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                    layer.cross_attn_q_w,
-                    cur);
-
-            Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_q_b,
-                        Qcur),
-                    Qcur);
-
-            Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
-            // Kcross is already scaled
-            struct ggml_tensor * Kcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
-                        n_state/n_head, n_head, M);
-
-            struct ggml_tensor * Vcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
-                        n_state/n_head, n_head, M);
-
-            struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
-
-            // ------
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
-                        0, 2, 1, 3);
-
-            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
-            wctx.use_buf(ctx0, 0);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
-
-            // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
-
-            wctx.use_buf(ctx0, 0);
-
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
-
-            wctx.use_buf(ctx0, 1);
-
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            // cur = KQV_merged.contiguous().view(n_state, N)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
-        }
-
-        // projection
-        {
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_mul_mat(ctx0,
-                    layer.cross_attn_ln_1_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
-                    cur);
-        }
-
-        wctx.use_buf(ctx0, 2);
-
-        // add the input
-        cur = ggml_add(ctx0, cur, inpCA);
-
-        struct ggml_tensor * inpFF = cur;
-
-        // feed-forward network
-        {
-            // norm
-            {
-                wctx.use_buf(ctx0, 0);
-
-                cur = ggml_norm(ctx0, inpFF);
-
-                wctx.use_buf(ctx0, 1);
-
-                // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-            }
-
-            wctx.use_buf(ctx0, 0);
-
-            // fully connected
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_0_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
-
-            wctx.use_buf(ctx0, 0);
-
-            // GELU activation
-            cur = ggml_gelu(ctx0, cur);
-
-            wctx.use_buf(ctx0, 1);
-
-            // projection
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_1_w,
-                    cur);
-
-            wctx.use_buf(ctx0, 0);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
-        }
-
-        wctx.use_buf(ctx0, 3);
-
-        inpL = ggml_add(ctx0, cur, inpFF);
-    }
-
-    cur = inpL;
-
-    // norm
-    {
-        wctx.use_buf(ctx0, 0);
-
-        cur = ggml_norm(ctx0, cur);
-
-        wctx.use_buf(ctx0, 1);
-
-        cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.d_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.d_ln_b, cur));
-    }
-
-    wctx.use_buf(ctx0, 0);
-
-    // compute logits only for the last token
-    // comment this line to compute logits for all N tokens
-    // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
-
-    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
-
-    wctx.use_buf(ctx0, -1);
-
-    // run the computation
-    {
-        ggml_build_forward_expand(&gf, logits);
-        ggml_graph_compute       (ctx0, &gf);
-    }
-
-    // extract logits for all N tokens
-    //logits_out.resize(N*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
-
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-
-    if (N > 1) {
-        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-        //        ggml_used_mem(ctx0)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
-    }
-
-    ggml_free(ctx0);
-
-    wctx.t_decode_us += ggml_time_us() - t_start_us;
-    wctx.n_decode++;
-
-    return true;
-}
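The self-attention KV cache written above is one flat buffer indexed as [layer][position][state]; the ggml_view_1d calls compute a byte offset of element_size * n_state * (il*n_ctx + n_past). A small sketch of that arithmetic (the fp16 element size and n_text_ctx value are assumptions for illustration):

    #include <cstddef>
    #include <cstdio>

    // Byte offset of layer `il`, token position `n_past` in the flat KV buffer.
    size_t kv_offset(size_t esize, int n_state, int n_ctx, int il, int n_past) {
        return esize*n_state*((size_t) il*n_ctx + n_past);
    }

    int main() {
        const size_t esize  = 2;    // assuming an fp16 cache (sizeof(ggml_fp16_t))
        const int    n_state = 512; // base-model n_text_state (assumed)
        const int    n_ctx   = 448; // n_text_ctx (assumed)
        printf("layer 3, n_past 10 -> byte offset %zu\n",
               kv_offset(esize, n_state, n_ctx, 3, 10));
        return 0;
    }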
-
-//  500 -> 00:00:05.000
-// 6000 -> 00:01:00.000
-static std::string to_timestamp(int64_t t, bool comma = false) {
-    int64_t msec = t * 10;
-    int64_t hr = msec / (1000 * 60 * 60);
-    msec = msec - hr * (1000 * 60 * 60);
-    int64_t min = msec / (1000 * 60);
-    msec = msec - min * (1000 * 60);
-    int64_t sec = msec / 1000;
-    msec = msec - sec * 1000;
-
-    char buf[32];
-    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
-
-    return std::string(buf);
-}
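Since t is expressed in 10 ms units, a few worked values for the helper above (a sketch, assuming it is called from the same translation unit where the static function is visible):

    #include <cstdio>

    int main() {
        printf("%s\n", to_timestamp(500).c_str());        // 00:00:05.000 (500 * 10 ms)
        printf("%s\n", to_timestamp(6000).c_str());       // 00:01:00.000 (1 minute)
        printf("%s\n", to_timestamp(500, true).c_str());  // 00:00:05,000 (SRT-style comma)
        return 0;
    }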
-
-// naive Discrete Fourier Transform
-// input is real-valued
-// output is complex-valued
-static void dft(const std::vector<float> & in, std::vector<float> & out) {
-    int N = in.size();
-
-    out.resize(N*2);
-
-    for (int k = 0; k < N; k++) {
-        float re = 0;
-        float im = 0;
-
-        for (int n = 0; n < N; n++) {
-            float angle = 2*M_PI*k*n/N;
-            re += in[n]*cos(angle);
-            im -= in[n]*sin(angle);
-        }
-
-        out[k*2 + 0] = re;
-        out[k*2 + 1] = im;
-    }
-}
-
-// Cooley-Tukey FFT
-// poor man's implementation - use something better
-// input is real-valued
-// output is complex-valued
-static void fft(const std::vector<float> & in, std::vector<float> & out) {
-    out.resize(in.size()*2);
-
-    int N = in.size();
-
-    if (N == 1) {
-        out[0] = in[0];
-        out[1] = 0;
-        return;
-    }
-
-    if (N%2 == 1) {
-        dft(in, out);
-        return;
-    }
-
-    std::vector<float> even;
-    std::vector<float> odd;
-
-    even.reserve(N/2);
-    odd.reserve(N/2);
-
-    for (int i = 0; i < N; i++) {
-        if (i % 2 == 0) {
-            even.push_back(in[i]);
-        } else {
-            odd.push_back(in[i]);
-        }
-    }
-
-    std::vector<float> even_fft;
-    std::vector<float> odd_fft;
-
-    fft(even, even_fft);
-    fft(odd, odd_fft);
-
-    for (int k = 0; k < N/2; k++) {
-        float theta = 2*M_PI*k/N;
-
-        float re = cos(theta);
-        float im = -sin(theta);
-
-        float re_odd = odd_fft[2*k + 0];
-        float im_odd = odd_fft[2*k + 1];
-
-        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
-        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
-
-        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
-        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
-    }
-}
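The recursion splits even and odd samples and falls back to the O(N^2) DFT for odd lengths, so for any power-of-two input the two must agree. A quick self-check sketch (assumes it lives in the same translation unit as the static fft/dft above):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8};
        std::vector<float> a, b;
        fft(in, a);   // Cooley-Tukey
        dft(in, b);   // naive reference
        float max_err = 0.0f;
        for (size_t i = 0; i < a.size(); ++i) {
            max_err = std::max(max_err, std::fabs(a[i] - b[i]));
        }
        printf("max |fft - dft| = %g\n", max_err);  // expect roundoff-level error
        return 0;
    }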
-
-// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
-static bool log_mel_spectrogram(
-          whisper_context & wctx,
-              const float * samples,
-                const int   n_samples,
-                const int   /*sample_rate*/,
-                const int   fft_size,
-                const int   fft_step,
-                const int   n_mel,
-                const int   n_threads,
-    const whisper_filters & filters,
-               const bool   speed_up,
-              whisper_mel & mel) {
-    const int64_t t_start_us = ggml_time_us();
-
-    // Hanning window
-    std::vector<float> hann;
-    hann.resize(fft_size);
-    for (int i = 0; i < fft_size; i++) {
-        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
-    }
-
-    mel.n_mel = n_mel;
-    mel.n_len = (n_samples)/fft_step;
-    mel.data.resize(mel.n_mel*mel.n_len);
-
-    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
-
-    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
-    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
-
-    std::vector<std::thread> workers(n_threads);
-    for (int iw = 0; iw < n_threads; ++iw) {
-        workers[iw] = std::thread([&](int ith) {
-            std::vector<float> fft_in;
-            fft_in.resize(fft_size);
-            for (int i = 0; i < fft_size; i++) {
-                fft_in[i] = 0.0;
-            }
-
-            std::vector<float> fft_out;
-            fft_out.resize(2*fft_size);
-
-            for (int i = ith; i < mel.n_len; i += n_threads) {
-                const int offset = i*fft_step;
-
-                // apply Hanning window
-                for (int j = 0; j < fft_size; j++) {
-                    if (offset + j < n_samples) {
-                        fft_in[j] = hann[j]*samples[offset + j];
-                    } else {
-                        fft_in[j] = 0.0;
-                    }
-                }
-
-                // FFT -> mag^2
-                fft(fft_in, fft_out);
-
-                for (int j = 0; j < fft_size; j++) {
-                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
-                }
-                for (int j = 1; j < fft_size/2; j++) {
-                    //if (i == 0) {
-                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
-                    //}
-                    fft_out[j] += fft_out[fft_size - j];
-                }
-                if (i == 0) {
-                    //for (int j = 0; j < fft_size; j++) {
-                    //    printf("%d: %e\n", j, fft_out[j]);
-                    //}
-                }
-
-                if (speed_up) {
-                    // scaling down in the frequency domain results in a speed-up in the time domain
-                    for (int j = 0; j < n_fft; j++) {
-                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
-                    }
-                }
-
-                // mel spectrogram
-                for (int j = 0; j < mel.n_mel; j++) {
-                    double sum = 0.0;
-
-                    for (int k = 0; k < n_fft; k++) {
-                        sum += fft_out[k]*filters.data[j*n_fft + k];
-                    }
-                    if (sum < 1e-10) {
-                        sum = 1e-10;
-                    }
-
-                    sum = log10(sum);
-
-                    mel.data[j*mel.n_len + i] = sum;
-                }
-            }
-        }, iw);
-    }
-
-    for (int iw = 0; iw < n_threads; ++iw) {
-        workers[iw].join();
-    }
-
-    // clamping and normalization
-    double mmax = -1e20;
-    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
-        if (mel.data[i] > mmax) {
-            mmax = mel.data[i];
-        }
-    }
-    //printf("%s: max = %f\n", __func__, mmax);
-
-    mmax -= 8.0;
-
-    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
-        if (mel.data[i] < mmax) {
-            mel.data[i] = mmax;
-        }
-
-        mel.data[i] = (mel.data[i] + 4.0)/4.0;
-    }
-
-    wctx.t_mel_us += ggml_time_us() - t_start_us;
-
-    return true;
-}
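Frame-count arithmetic for the spectrogram above, using what I believe are the standard Whisper constants (16 kHz audio, 400-sample window, 160-sample hop): each hop is 10 ms, so n_len = n_samples/fft_step comes out to 100 frames per second, and a 30 s chunk yields 3000 frames, i.e. 2*n_audio_ctx for the encoder.

    #include <cstdio>

    int main() {
        const int sample_rate = 16000;  // WHISPER_SAMPLE_RATE (assumed value)
        const int fft_step    = 160;    // WHISPER_HOP_LENGTH  (assumed value)
        const int n_samples   = 30*sample_rate;
        printf("n_len = %d frames\n", n_samples/fft_step);  // 3000
        return 0;
    }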
-
-// split text into tokens
-//
-// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
-//
-// Regex (Python):
-// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-//
-// Regex (C++):
-// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
-//
-static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
-    std::vector<std::string> words;
-
-    // first split the text into words
-    {
-        std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
-            }
-            str = m.suffix();
-        }
-    }
-
-    // find the longest tokens that form the words:
-    std::vector<whisper_vocab::id> tokens;
-    for (const auto & word : words) {
-        if (word.empty()) continue;
-
-        int i = 0;
-        int n = word.size();
-        while (i < n) {
-            int j = n;
-            while (j > i) {
-                auto it = vocab.token_to_id.find(word.substr(i, j-i));
-                if (it != vocab.token_to_id.end()) {
-                    tokens.push_back(it->second);
-                    i = j;
-                    break;
-                }
-                --j;
-            }
-            if (i == n) {
-                break;
-            }
-            if (j == i) {
-                auto sub = word.substr(i, 1);
-                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
-                    tokens.push_back(vocab.token_to_id.at(sub));
-                } else {
-                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
-                }
-                ++i;
-            }
-        }
-    }
-
-    return tokens;
-}
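The inner loop above is greedy longest-match: for each word it repeatedly takes the longest prefix that exists in the vocabulary. A standalone sketch with a toy vocabulary (the tokens and ids are hypothetical, purely for illustration):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        std::map<std::string, int> vocab = {{" trans", 1}, {"cribe", 2}, {" t", 3}};
        std::string word = " transcribe";
        std::vector<int> ids;
        size_t i = 0;
        while (i < word.size()) {
            size_t j = word.size();
            for (; j > i; --j) {
                auto it = vocab.find(word.substr(i, j - i));
                if (it != vocab.end()) { ids.push_back(it->second); i = j; break; }
            }
            if (j == i) ++i;  // no match at all: skip one character
        }
        for (int id : ids) printf("%d ", id);  // prints: 1 2
        printf("\n");
        return 0;
    }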
-
-//
-// interface implementation
-//
-
-struct whisper_context * whisper_init_from_file(const char * path_model) {
-    whisper_model_loader loader = {};
-
-    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
-
-    auto fin = std::ifstream(path_model, std::ios::binary);
-    if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
-        return nullptr;
-    }
-
-    loader.context = &fin;
-    loader.read = [](void * ctx, void * output, size_t read_size) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        fin->read((char *)output, read_size);
-        return read_size;
-    };
-
-    loader.eof = [](void * ctx) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        return fin->eof();
-    };
-
-    loader.close = [](void * ctx) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        fin->close();
-    };
-
-    return whisper_init(&loader);
-}
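Typical usage of the file-based constructor above (the model path is an example; any ggml-format Whisper model file works):

    #include "whisper.h"

    int main() {
        struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
        if (ctx == nullptr) {
            return 1;  // whisper_init_from_file already printed the reason
        }
        // ... run whisper_full() / whisper_encode() / whisper_decode() ...
        whisper_free(ctx);
        return 0;
    }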
-
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
-    struct buf_context {
-        uint8_t* buffer;
-        size_t size;
-        size_t current_offset;
-    };
-
-    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
-    whisper_model_loader loader = {};
-
-    fprintf(stderr, "%s: loading model from buffer\n", __func__);
-
-    loader.context = &ctx;
-
-    loader.read = [](void * ctx, void * output, size_t read_size) {
-        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
-
-        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
-
-        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
-        buf->current_offset += size_to_copy;
-
-        return size_to_copy;
-    };
-
-    loader.eof = [](void * ctx) {
-        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
-
-        return buf->current_offset >= buf->size;
-    };
-
-    loader.close = [](void * /*ctx*/) { };
-
-    return whisper_init(&loader);
-}
-
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
-    ggml_time_init();
-
-    whisper_context * ctx = new whisper_context;
-
-    if (!whisper_model_load(loader, *ctx)) {
-        loader->close(loader->context);
-        fprintf(stderr, "%s: failed to load model\n", __func__);
-        delete ctx;
-        return nullptr;
-    }
-
-    loader->close(loader->context);
-
-    return ctx;
-}
-
-void whisper_free(struct whisper_context * ctx) {
-    if (ctx) {
-        if (ctx->model.ctx) {
-            ggml_free(ctx->model.ctx);
-        }
-        if (ctx->model.buf) {
-            delete ctx->model.buf;
-        }
-        if (ctx->kv_cross.ctx) {
-            ggml_free(ctx->kv_cross.ctx);
-        }
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            if (ctx->decoders[i].kv_self.ctx) {
-                ggml_free(ctx->decoders[i].kv_self.ctx);
-            }
-        }
-        delete ctx;
-    }
-}
-
-int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
-        return -1;
-    }
-
-    return 0;
-}
-
-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
-int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
-        return -1;
-    }
-
-    return 0;
-}
-
-int whisper_set_mel(
-        struct whisper_context * ctx,
-                   const float * data,
-                           int   n_len,
-                           int   n_mel) {
-    if (n_mel != WHISPER_N_MEL) {
-        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
-        return -1;
-    }
-
-    ctx->mel.n_len = n_len;
-    ctx->mel.n_mel = n_mel;
-
-    ctx->mel.data.resize(n_len*n_mel);
-    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
-
-    return 0;
-}
-
-int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode(*ctx, offset, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
-        return -1;
-    }
-
-    return 0;
-}
-
-int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to context
-    const int selected_decoder_id = 0;
-
-    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
-        return 1;
-    }
-
-    return 0;
-}
-
-int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
-    const auto res = tokenize(ctx->vocab, text);
-
-    if (n_max_tokens < (int) res.size()) {
-        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
-        return -1;
-    }
-
-    for (int i = 0; i < (int) res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
-}
-
-int whisper_lang_max_id() {
-    auto max_id = 0;
-    for (const auto & kv : g_lang) {
-        max_id = std::max(max_id, kv.second.first);
-    }
-
-    return max_id;
-}
-
-int whisper_lang_id(const char * lang) {
-    if (!g_lang.count(lang)) {
-        for (const auto & kv : g_lang) {
-            if (kv.second.second == lang) {
-                return kv.second.first;
-            }
-        }
-
-        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
-        return -1;
-    }
-
-    return g_lang.at(lang).first;
-}
-
-const char * whisper_lang_str(int id) {
-    for (const auto & kv : g_lang) {
-        if (kv.second.first == id) {
-            return kv.first.c_str();
-        }
-    }
-
-    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
-    return nullptr;
-}
-
-int whisper_lang_auto_detect(
-        struct whisper_context * ctx,
-                           int   offset_ms,
-                           int   n_threads,
-                         float * lang_probs) {
-    const int seek = offset_ms/10;
-
-    if (seek < 0) {
-        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
-        return -1;
-    }
-
-    if (seek >= ctx->mel.n_len) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
-        return -2;
-    }
-
-    // run the encoder
-    if (whisper_encode(ctx, seek, n_threads) != 0) {
-        fprintf(stderr, "%s: failed to encode\n", __func__);
-        return -6;
-    }
-
-    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
-
-    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        fprintf(stderr, "%s: failed to decode\n", __func__);
-        return -7;
-    }
-
-    auto & logits_id = ctx->logits_id;
-    logits_id.clear();
-
-    for (const auto & kv : g_lang) {
-        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
-    }
-
-    // sort descending
-    {
-        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
-        std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
-            return a.first > b.first;
-        });
-    }
-
-    // softmax
-    {
-        const auto max = logits_id[0].first;
-
-        double sum = 0.0f;
-        for (auto & kv : logits_id) {
-            kv.first = exp(kv.first - max);
-            sum += kv.first;
-        }
-
-        for (auto & kv : logits_id) {
-            kv.first /= sum;
-        }
-    }
-
-    {
-        for (const auto & prob : logits_id) {
-            if (lang_probs) {
-                lang_probs[prob.second] = prob.first;
-            }
-
-            //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
-        }
-    }
-
-    return logits_id[0].second;
-}
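Usage sketch for the auto-detection above: compute the mel spectrogram first, then ask for the most probable language at offset 0. The probability array must have room for whisper_lang_max_id() + 1 entries; the thread count here is an arbitrary choice.

    #include "whisper.h"
    #include <cstdio>
    #include <vector>

    void detect(struct whisper_context * ctx, const float * pcm, int n_samples) {
        whisper_pcm_to_mel(ctx, pcm, n_samples, 4);
        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
        const int lang_id = whisper_lang_auto_detect(ctx, 0, 4, probs.data());
        if (lang_id >= 0) {
            printf("detected '%s' (p = %.3f)\n", whisper_lang_str(lang_id), probs[lang_id]);
        }
    }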
-
-int whisper_n_len(struct whisper_context * ctx) {
-    return ctx->mel.n_len;
-}
-
-int whisper_n_vocab(struct whisper_context * ctx) {
-    return ctx->vocab.n_vocab;
-}
-
-int whisper_n_text_ctx(struct whisper_context * ctx) {
-    return ctx->model.hparams.n_text_ctx;
-}
-
-int whisper_n_audio_ctx(struct whisper_context * ctx) {
-    return ctx->model.hparams.n_audio_ctx;
-}
-
-int whisper_is_multilingual(struct whisper_context * ctx) {
-    return ctx->vocab.is_multilingual() ? 1 : 0;
-}
-
-float * whisper_get_logits(struct whisper_context * ctx) {
-    return ctx->logits.data();
-}
-
-const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
-    return ctx->vocab.id_to_token.at(token).c_str();
-}
-
-whisper_token whisper_token_eot(struct whisper_context * ctx) {
-    return ctx->vocab.token_eot;
-}
-
-whisper_token whisper_token_sot(struct whisper_context * ctx) {
-    return ctx->vocab.token_sot;
-}
-
-whisper_token whisper_token_prev(struct whisper_context * ctx) {
-    return ctx->vocab.token_prev;
-}
-
-whisper_token whisper_token_solm(struct whisper_context * ctx) {
-    return ctx->vocab.token_solm;
-}
-
-whisper_token whisper_token_not(struct whisper_context * ctx) {
-    return ctx->vocab.token_not;
-}
-
-whisper_token whisper_token_beg(struct whisper_context * ctx) {
-    return ctx->vocab.token_beg;
-}
-
-whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
-    return whisper_token_sot(ctx) + 1 + lang_id;
-}
-
-whisper_token whisper_token_translate(void) {
-    return whisper_vocab::token_translate;
-}
-
-whisper_token whisper_token_transcribe(void) {
-    return whisper_vocab::token_transcribe;
-}
-
-void whisper_print_timings(struct whisper_context * ctx) {
-    const int64_t t_end_us = ggml_time_us();
-
-    const int32_t n_sample = std::max(1, ctx->n_sample);
-    const int32_t n_encode = std::max(1, ctx->n_encode);
-    const int32_t n_decode = std::max(1, ctx->n_decode);
-
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
-    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
-    fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
-    fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
-    fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
-    fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
-}
-
-void whisper_reset_timings(struct whisper_context * ctx) {
-    ctx->t_sample_us = 0;
-    ctx->t_encode_us = 0;
-    ctx->t_decode_us = 0;
-}
-
-const char * whisper_print_system_info(void) {
-    static std::string s;
-
-    s  = "";
-    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
-    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
-    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
-    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
-    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
-    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
-    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
-    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
-    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
-    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
-
-    return s.c_str();
-}
-
-////////////////////////////////////////////////////////////////////////////
-
-struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
-    struct whisper_full_params result = {
-        /*.strategy         =*/ strategy,
-
-        /*.n_threads        =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-        /*.n_max_text_ctx   =*/ 16384,
-        /*.offset_ms        =*/ 0,
-        /*.duration_ms      =*/ 0,
-
-        /*.translate        =*/ false,
-        /*.no_context       =*/ false,
-        /*.single_segment   =*/ false,
-        /*.print_special    =*/ false,
-        /*.print_progress   =*/ true,
-        /*.print_realtime   =*/ false,
-        /*.print_timestamps =*/ true,
-
-        /*.token_timestamps =*/ false,
-        /*.thold_pt         =*/ 0.01f,
-        /*.thold_ptsum      =*/ 0.01f,
-        /*.max_len          =*/ 0,
-        /*.split_on_word    =*/ false,
-        /*.max_tokens       =*/ 0,
-
-        /*.speed_up         =*/ false,
-        /*.audio_ctx        =*/ 0,
-
-        /*.prompt_tokens    =*/ nullptr,
-        /*.prompt_n_tokens  =*/ 0,
-
-        /*.language         =*/ "en",
-
-        /*.suppress_blank             =*/ true,
-        /*.suppress_non_speech_tokens =*/ false,
-
-        /*.temperature      =*/ 0.0f,
-        /*.max_initial_ts   =*/ 1.0f,
-        /*.length_penalty   =*/ -1.0f,
-
-        /*.temperature_inc  =*/ 0.2f,
-        /*.entropy_thold    =*/ 2.4f,
-        /*.logprob_thold    =*/ -1.0f,
-        /*.no_speech_thold  =*/ 0.6f,
-
-        /*.greedy           =*/ {
-            /*.best_of =*/ -1,
-        },
-
-        /*.beam_search      =*/ {
-            /*.beam_size =*/ -1,
-
-            /*.patience  =*/ -1.0f,
-        },
-
-        /*.new_segment_callback           =*/ nullptr,
-        /*.new_segment_callback_user_data =*/ nullptr,
-
-        /*.encoder_begin_callback           =*/ nullptr,
-        /*.encoder_begin_callback_user_data =*/ nullptr,
-
-        /*.logits_filter_callback           =*/ nullptr,
-        /*.logits_filter_callback_user_data =*/ nullptr,
-    };
-
-    switch (strategy) {
-        case WHISPER_SAMPLING_GREEDY:
-            {
-                result.greedy = {
-                    /*.best_of =*/ 1,
-                };
-            } break;
-        case WHISPER_SAMPLING_BEAM_SEARCH:
-            {
-                result.beam_search = {
-                    /*.beam_size =*/ 5,
-
-                    /*.patience  =*/ -1.0f,
-                };
-            } break;
-    }
-
-    return result;
-}
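A usage sketch: start from the defaults returned above and override a few fields before running the full pipeline. The particular overrides are arbitrary examples; the fields are the ones initialized in the struct above.

    #include "whisper.h"

    void transcribe(struct whisper_context * ctx, const float * pcm, int n_samples) {
        whisper_full_params p = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        p.n_threads        = 8;
        p.language         = "en";
        p.print_progress   = false;
        p.token_timestamps = true;
        whisper_full(ctx, p, pcm, n_samples);
    }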
-
2993
- // forward declarations
2994
- static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
2995
- static void whisper_exp_compute_token_level_timestamps(
2996
- struct whisper_context & ctx,
2997
- int i_segment,
2998
- float thold_pt,
2999
- float thold_ptsum);
3000
-
3001
- // trim from start (in place)
3002
- static inline void ltrim(std::string &s) {
3003
- s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
3004
- return !std::isspace(ch);
3005
- }));
3006
- }
3007
-
3008
- // trim from end (in place)
3009
- static inline void rtrim(std::string &s) {
3010
- s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
3011
- return !std::isspace(ch);
3012
- }).base(), s.end());
3013
- }
3014
-
3015
- // trim from both ends (in place)
3016
- static inline void trim(std::string &s) {
3017
- rtrim(s);
3018
- ltrim(s);
3019
- }
3020
-
3021
- static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3022
- if (!split_on_word) return true;
3023
-
3024
- return txt[0] == ' ';
3025
- }
3026
-
3027
- // wrap the last segment to max_len characters
3028
- // returns the number of new segments
3029
- static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
3030
- auto segment = ctx.result_all.back();
3031
-
3032
- int res = 1;
3033
- int acc = 0;
3034
-
3035
- std::string text;
3036
-
3037
- for (int i = 0; i < (int) segment.tokens.size(); i++) {
3038
- const auto & token = segment.tokens[i];
3039
- if (token.id >= whisper_token_eot(&ctx)) {
3040
- continue;
3041
- }
3042
-
3043
- const auto txt = whisper_token_to_str(&ctx, token.id);
3044
- const int cur = strlen(txt);
3045
-
3046
- if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
3047
- // split here
3048
- if (split_on_word) {
3049
- trim(text);
3050
- }
3051
-
3052
- ctx.result_all.back().text = std::move(text);
3053
- ctx.result_all.back().t1 = token.t0;
3054
- ctx.result_all.back().tokens.resize(i);
3055
-
3056
- ctx.result_all.push_back({});
3057
- ctx.result_all.back().t0 = token.t0;
3058
- ctx.result_all.back().t1 = segment.t1;
3059
-
3060
- // add tokens [i, end] to the new segment
3061
- ctx.result_all.back().tokens.insert(
3062
- ctx.result_all.back().tokens.end(),
3063
- segment.tokens.begin() + i,
3064
- segment.tokens.end());
3065
-
3066
- acc = 0;
3067
- text = "";
3068
-
3069
- segment = ctx.result_all.back();
3070
- i = -1;
3071
-
3072
- res++;
3073
- } else {
3074
- acc += cur;
3075
- text += txt;
3076
- }
3077
- }
3078
-
3079
- if (split_on_word) {
3080
- trim(text);
3081
- }
3082
- ctx.result_all.back().text = std::move(text);
3083
-
3084
- return res;
3085
- }
3086
-
3087
- static const std::vector<std::string> non_speech_tokens = {
3088
- "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
3089
- "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
3090
- "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
3091
- "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
3092
- };
3093
-
3094
- // process the logits for the selected decoder
3095
- // - applies logit filters
3096
- // - computes logprobs and probs
3097
- static void whisper_process_logits(
3098
- struct whisper_context & ctx,
3099
- const struct whisper_full_params params,
3100
- struct whisper_decoder & decoder,
3101
- float temperature) {
3102
- const auto & vocab = ctx.vocab;
3103
- const auto & tokens_cur = decoder.sequence.tokens;
3104
-
3105
- const bool is_initial = tokens_cur.size() == 0;
3106
- const int n_logits = vocab.id_to_token.size();
3107
-
3108
- WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
3109
-
3110
- // extract the logits for the last token
3111
- // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3112
- auto & probs = decoder.probs;
3113
- auto & logits = decoder.logits;
3114
- auto & logprobs = decoder.logprobs;
3115
- {
3116
- logits.resize(n_logits);
3117
- memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
3118
-
3119
- if (temperature > 0.0f) {
3120
- for (int i = 0; i < n_logits; i++) {
3121
- logits[i] /= temperature;
3122
- }
3123
- }
3124
-
3125
- // will be populated a bit later
3126
- probs.resize(n_logits);
3127
- logprobs.resize(n_logits);
3128
- }
3129
-
3130
- // apply logit filters here
3131
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
3132
- {
3133
- // suppress blank
3134
- // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
3135
- if (params.suppress_blank) {
3136
- if (is_initial) {
3137
- logits[vocab.token_eot] = -INFINITY;
3138
- logits[vocab.token_to_id.at(" ")] = -INFINITY;
3139
- }
3140
- }
3141
-
3142
- // suppress <|notimestamps|> token
3143
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
3144
- logits[vocab.token_not] = -INFINITY;
3145
-
3146
- // suppress sot and solm tokens
3147
- logits[vocab.token_sot] = -INFINITY;
3148
- logits[vocab.token_solm] = -INFINITY;
3149
-
3150
- // suppress task tokens
3151
- logits[vocab.token_translate] = -INFINITY;
3152
- logits[vocab.token_transcribe] = -INFINITY;
3153
-
3154
- if (params.logits_filter_callback) {
3155
- params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3156
- }
3157
-
3158
- // suppress non-speech tokens
3159
- // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3160
- if (params.suppress_non_speech_tokens) {
3161
- for (const std::string & token : non_speech_tokens) {
3162
- const std::string suppress_tokens[] = {token, " " + token};
3163
- for (const std::string & suppress_token : suppress_tokens) {
3164
- if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
3165
- logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
3166
- }
3167
- }
3168
- }
3169
-
3170
- // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
- if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
- logits[vocab.token_to_id.at(" -")] = -INFINITY;
- }
- if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
- logits[vocab.token_to_id.at(" '")] = -INFINITY;
- }
- }
-
- // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
- // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
- {
- const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
- const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
-
- //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
-
- if (last_was_timestamp) {
- if (penultimate_was_timestamp) {
- for (int i = vocab.token_beg; i < n_logits; ++i) {
- logits[i] = -INFINITY;
- }
- } else {
- for (int i = 0; i < vocab.token_eot; ++i) {
- logits[i] = -INFINITY;
- }
- }
- }
- }
-
- // the initial timestamp cannot be larger than max_initial_ts
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
- if (is_initial && params.max_initial_ts > 0.0f) {
- const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
- const int tid0 = std::round(params.max_initial_ts/precision);
-
- for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
- logits[i] = -INFINITY;
- }
- }
-
- // condition timestamp tokens to be increasing
- // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
- if (decoder.has_ts) {
- const int tid0 = decoder.seek_delta/2;
-
- for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
- logits[i] = -INFINITY;
- }
- }
-
- // populate the logprobs array (log_softmax)
- {
- const float logit_max = *std::max_element(logits.begin(), logits.end());
- float logsumexp = 0.0f;
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logsumexp += expf(logits[i] - logit_max);
- }
- }
- logsumexp = logf(logsumexp) + logit_max;
-
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] > -INFINITY) {
- logprobs[i] = logits[i] - logsumexp;
- } else {
- logprobs[i] = -INFINITY;
- }
- }
- }
-
- // if sum of probability over timestamps is above any other token, sample timestamp
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
- {
- // logsumexp over timestamps
- float timestamp_logprob = -INFINITY;
- {
- float logsumexp = 0.0f;
- const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
- for (int i = vocab.token_beg; i < n_logits; ++i) {
- if (logprobs[i] > -INFINITY) {
- logsumexp += expf(logprobs[i] - logprob_max);
- }
- }
- if (logsumexp > 0.0f) {
- timestamp_logprob = logf(logsumexp) + logprob_max;
- }
- }
-
- const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
-
- //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
-
- if (timestamp_logprob > max_text_token_logprob) {
- for (int i = 0; i < vocab.token_beg; ++i) {
- logits[i] = -INFINITY;
- logprobs[i] = -INFINITY;
- }
- }
- }
- }
-
- // compute probs
- {
- for (int i = 0; i < n_logits; ++i) {
- if (logits[i] == -INFINITY) {
- probs[i] = 0.0f;
- } else {
- probs[i] = expf(logprobs[i]);
- }
- }
- }
-
- #if 0
- // print first 100 logits - token string : logit
- for (int i = 0; i < 100; i++) {
- const auto token = vocab.id_to_token.at(i);
- const auto prob = probs[i];
- const auto logit = logits[i];
- const auto logprob = logprobs[i];
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
- }
-
- // "And", "and", " And", " and"
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
-
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
- #endif
- }
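
The log_softmax block above is the standard max-shifted formulation. A minimal standalone sketch of the same transform (the function name and signature are illustrative, not part of the library):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // numerically stable log-softmax over a (possibly masked) logit vector:
    // logprob[i] = logit[i] - log(sum_j exp(logit[j])); subtracting the max
    // first keeps the exponentials from overflowing, and -INFINITY entries
    // stay masked; assumes a non-empty input
    static std::vector<float> log_softmax(const std::vector<float> & logits) {
        const float lmax = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (const float l : logits) {
            if (l > -INFINITY) {
                sum += std::exp(l - lmax);
            }
        }
        const float lse = std::log(sum) + lmax; // log-sum-exp
        std::vector<float> out(logits.size());
        for (size_t i = 0; i < logits.size(); ++i) {
            out[i] = logits[i] > -INFINITY ? logits[i] - lse : -INFINITY;
        }
        return out;
    }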
-
- static whisper_token_data whisper_sample_token(
- whisper_context & ctx,
- const whisper_decoder & decoder,
- bool best) {
- whisper_token_data result = {
- 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
- };
-
- const auto & vocab = ctx.vocab;
-
- const auto & probs = decoder.probs;
- const auto & logprobs = decoder.logprobs;
-
- const int n_logits = vocab.n_vocab;
-
- {
- double sum_ts = 0.0;
- double max_ts = 0.0;
-
- for (int i = vocab.token_beg; i < n_logits; i++) {
- if (probs[i] == -INFINITY) {
- continue;
- }
-
- sum_ts += probs[i];
- if (max_ts < probs[i]) {
- max_ts = probs[i];
- result.tid = i;
- }
- }
-
- result.pt = max_ts/(sum_ts + 1e-10);
- result.ptsum = sum_ts;
- }
-
- if (best) {
- for (int i = 0; i < n_logits; ++i) {
- if (result.p < probs[i]) {
- result.id = i;
- result.p = probs[i];
- result.plog = logprobs[i];
- }
- }
- } else {
- std::discrete_distribution<> dist(probs.begin(), probs.end());
-
- result.id = dist(ctx.rng);
- result.p = probs[result.id];
- result.plog = logprobs[result.id];
- }
-
- if (result.id >= vocab.token_beg) {
- result.tid = result.id;
- result.pt = result.p;
- }
-
- ctx.n_sample++;
-
- return result;
- }
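
whisper_sample_token therefore has two modes: arg-max when best is set, otherwise a draw from the token distribution. A self-contained sketch of that choice over a plain probability vector (names are illustrative):

    #include <random>
    #include <vector>

    // choose a token id from a probability vector: arg-max when best == true,
    // otherwise a random draw (std::discrete_distribution normalizes the
    // weights itself, and masked entries carry weight 0)
    static int pick_token(const std::vector<float> & probs, bool best, std::mt19937 & rng) {
        if (best) {
            int id = 0;
            for (size_t i = 1; i < probs.size(); ++i) {
                if (probs[i] > probs[id]) {
                    id = (int) i;
                }
            }
            return id;
        }
        std::discrete_distribution<int> dist(probs.begin(), probs.end());
        return dist(rng);
    }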
-
- static std::vector<whisper_token_data> whisper_sample_token_topk(
- whisper_context & ctx,
- const whisper_decoder & decoder,
- int k) {
- const auto & vocab = ctx.vocab;
-
- const auto & probs = decoder.probs;
- const auto & logits = decoder.logits;
- const auto & logprobs = decoder.logprobs;
-
- const int n_logits = vocab.n_vocab;
-
- auto & logits_id = ctx.logits_id;
-
- logits_id.clear();
- for (int i = 0; i < n_logits; ++i) {
- logits_id.push_back({ logits[i], i });
- }
-
- std::partial_sort(
- logits_id.begin(),
- logits_id.begin() + k, logits_id.end(),
- [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
- return a.first > b.first;
- });
-
- std::vector<whisper_token_data> result;
- result.reserve(k);
-
- whisper_token tid = vocab.token_beg;
-
- float pt = 0.0;
- float ptsum = 0.0;
-
- {
- double sum_ts = 0.0;
- double max_ts = 0.0;
-
- for (int i = vocab.token_beg; i < n_logits; i++) {
- if (probs[i] == -INFINITY) {
- continue;
- }
-
- sum_ts += probs[i];
- if (max_ts < probs[i]) {
- max_ts = probs[i];
- tid = i;
- }
- }
-
- pt = max_ts/(sum_ts + 1e-10);
- ptsum = sum_ts;
- }
-
- for (int i = 0; i < k; ++i) {
- const auto id = logits_id[i].second;
-
- result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
-
- if (result[i].id >= vocab.token_beg) {
- result[i].tid = result[i].id;
- result[i].pt = result[i].p;
- }
- }
-
- ctx.n_sample++;
-
- return result;
- }
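
Since only the k best logits are needed, std::partial_sort keeps the selection at O(n log k) rather than sorting the whole vocabulary. The same selection step in isolation (illustrative helper; assumes k does not exceed the input size):

    #include <algorithm>
    #include <utility>
    #include <vector>

    // indices of the k largest logits, in descending order of logit
    static std::vector<int> top_k_indices(const std::vector<float> & logits, int k) {
        std::vector<std::pair<float, int>> id;
        id.reserve(logits.size());
        for (size_t i = 0; i < logits.size(); ++i) {
            id.push_back({ logits[i], (int) i });
        }
        std::partial_sort(id.begin(), id.begin() + k, id.end(),
            [](const std::pair<float, int> & a, const std::pair<float, int> & b) {
                return a.first > b.first;
            });
        std::vector<int> res(k);
        for (int i = 0; i < k; ++i) {
            res[i] = id[i].second;
        }
        return res;
    }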
-
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
- static void whisper_sequence_score(
- const struct whisper_full_params & params,
- whisper_sequence & sequence) {
- if (sequence.result_len == 0) {
- return;
- }
-
- double result = 0.0f;
-
- for (int i = 0; i < sequence.result_len; ++i) {
- result += sequence.tokens[i].plog;
- }
-
- sequence.sum_logprobs = result;
- sequence.avg_logprobs = result/sequence.result_len;
-
- double penalty = sequence.result_len;
-
- if (params.length_penalty > 0.0f) {
- penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
- }
-
- sequence.score = result/penalty;
-
- // compute the entropy of the sequence of the last 32 tokens
- {
- const int n = 32;
-
- int cnt = 0;
- double entropy = 0.0f;
-
- std::map<whisper_token, int> token_counts;
- for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
- token_counts[sequence.tokens[i].id]++;
- cnt++;
- }
-
- for (const auto & kv : token_counts) {
- const auto p = kv.second/(double)cnt;
- entropy -= p*log(p);
-
- //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
- }
-
- sequence.entropy = entropy;
- }
- }
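
The score above is the summed log-probability divided by a length penalty, per the referenced OpenAI decoder: the raw token count by default, or ((5 + len)/6)^length_penalty when length_penalty > 0. Factored out as a sketch (names are illustrative):

    #include <cmath>
    #include <vector>

    // ranking score for a candidate sequence given its per-token logprobs
    static double sequence_score(const std::vector<double> & token_logprobs, float length_penalty) {
        double sum = 0.0;
        for (const double lp : token_logprobs) {
            sum += lp;
        }
        double penalty = (double) token_logprobs.size();
        if (length_penalty > 0.0f) {
            penalty = std::pow((5.0 + penalty)/6.0, (double) length_penalty);
        }
        return sum/penalty;
    }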
-
- int whisper_full(
- struct whisper_context * ctx,
- struct whisper_full_params params,
- const float * samples,
- int n_samples) {
- // clear old results
- auto & result_all = ctx->result_all;
-
- result_all.clear();
-
- // compute log mel spectrogram
- if (params.speed_up) {
- if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
- }
- } else {
- if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -2;
- }
- }
-
- // auto-detect language if not specified
- if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
- std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
-
- const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
- if (lang_id < 0) {
- fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
- return -3;
- }
- ctx->lang_id = lang_id;
- params.language = whisper_lang_str(lang_id);
-
- fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
- }
-
- if (params.token_timestamps) {
- ctx->t_beg = 0;
- ctx->t_last = 0;
- ctx->tid_last = 0;
- ctx->energy = get_signal_energy(samples, n_samples, 32);
- }
-
- const int seek_start = params.offset_ms/10;
- const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
-
- // if length of spectrogram is less than 1s (100 samples), then return
- // basically don't process anything that is less than 1s
- // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
- if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
- return 0;
- }
-
- // a set of temperatures to use
- // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
- std::vector<float> temperatures;
- if (params.temperature_inc > 0.0f) {
- for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
- temperatures.push_back(t);
- }
- } else {
- temperatures.push_back(params.temperature);
- }
-
- // initialize the decoders
- int n_decoders = 1;
-
- switch (params.strategy) {
- case WHISPER_SAMPLING_GREEDY:
- {
- n_decoders = params.greedy.best_of;
- } break;
- case WHISPER_SAMPLING_BEAM_SEARCH:
- {
- n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
- } break;
- };
-
- n_decoders = std::max(1, n_decoders);
-
- // TAGS: WHISPER_DECODER_INIT
- for (int j = 1; j < n_decoders && ctx->running; j++) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.kv_self.ctx == nullptr) {
- decoder.kv_self = ctx->decoders[0].kv_self;
- if (!kv_cache_reinit(decoder.kv_self)) {
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
- return -4;
- }
-
- WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
-
- decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
-
- decoder.probs.resize (ctx->vocab.n_vocab);
- decoder.logits.resize (ctx->vocab.n_vocab);
- decoder.logprobs.resize(ctx->vocab.n_vocab);
- }
- }
-
- // the accumulated text context so far
- auto & prompt_past = ctx->prompt_past;
- if (params.no_context) {
- prompt_past.clear();
- }
-
- // prepend the prompt tokens to the prompt_past
- if (params.prompt_tokens && params.prompt_n_tokens > 0) {
- // parse tokens from the pointer
- for (int i = 0; i < params.prompt_n_tokens; i++) {
- prompt_past.push_back(params.prompt_tokens[i]);
- }
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
- }
-
- // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
- if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
- fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
- return -5;
- }
- ctx->exp_n_audio_ctx = params.audio_ctx;
-
- // these tokens determine the task that will be performed
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
- if (whisper_is_multilingual(ctx)) {
- const int lang_id = whisper_lang_id(params.language);
- ctx->lang_id = lang_id;
- prompt_init.push_back(whisper_token_lang(ctx, lang_id));
- if (params.translate) {
- prompt_init.push_back(whisper_token_translate());
- } else {
- prompt_init.push_back(whisper_token_transcribe());
- }
- }
-
- int progress_prev = 0;
- int progress_step = 5;
-
- int seek = seek_start;
-
- std::vector<whisper_token> prompt;
- prompt.reserve(whisper_n_text_ctx(ctx));
-
- // beam-search helpers
- struct kv_buf {
- std::vector<uint8_t> k;
- std::vector<uint8_t> v;
- };
-
- std::vector<kv_buf> kv_bufs;
-
- struct beam_candidate {
- int decoder_idx;
- int seek_delta;
-
- bool has_ts;
-
- whisper_sequence sequence;
- };
-
- std::vector<beam_candidate> beam_candidates;
-
- // main loop
- while (ctx->running) {
- const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
- while (progress_cur >= progress_prev + progress_step) {
- progress_prev += progress_step;
- if (params.print_progress) {
- fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
- }
- }
-
- // if only 1 second left, then stop
- if (seek + 100 >= seek_end) {
- break;
- }
-
- if (params.encoder_begin_callback) {
- if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
- fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
- break;
- }
- }
-
- // encode audio features starting at offset seek
- if (!whisper_encode(*ctx, seek, params.n_threads)) {
- fprintf(stderr, "%s: failed to encode\n", __func__);
- return -6;
- }
-
- // if there is a very short audio segment left to process, we remove any past prompt since it tends
- // to confuse the decoder and often makes it repeat or hallucinate stuff
- if (seek > seek_start && seek + 500 >= seek_end) {
- prompt_past.clear();
- }
-
- int best_decoder_id = 0;
-
- for (int it = 0; it < (int) temperatures.size(); ++it) {
- const float t_cur = temperatures[it];
-
- int n_decoders_cur = 1;
-
- switch (params.strategy) {
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
- {
- if (t_cur > 0.0f) {
- n_decoders_cur = params.greedy.best_of;
- }
- } break;
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
- {
- if (t_cur > 0.0f) {
- n_decoders_cur = params.greedy.best_of;
- } else {
- n_decoders_cur = params.beam_search.beam_size;
- }
- } break;
- };
-
- n_decoders_cur = std::max(1, n_decoders_cur);
-
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
-
- // TAGS: WHISPER_DECODER_INIT
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- decoder.kv_self.n = 0;
-
- decoder.sequence.tokens.clear();
- decoder.sequence.result_len = 0;
- decoder.sequence.sum_logprobs_all = 0.0;
- decoder.sequence.sum_logprobs = -INFINITY;
- decoder.sequence.avg_logprobs = -INFINITY;
- decoder.sequence.entropy = 0.0;
- decoder.sequence.score = -INFINITY;
-
- decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
-
- decoder.failed = false;
- decoder.completed = false;
- decoder.has_ts = false;
- }
-
- // init prompt and kv cache for the current iteration
- // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
- {
- prompt.clear();
-
- // if we have already generated some text, use it as a prompt to condition the next generation
- if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
- int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
-
- prompt = { whisper_token_prev(ctx) };
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
- }
-
- // init new transcription with sot, language (opt) and task tokens
- prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
-
- // print the prompt
- WHISPER_PRINT_DEBUG("\n\n");
- for (int i = 0; i < (int) prompt.size(); i++) {
- WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
- }
- WHISPER_PRINT_DEBUG("\n\n");
-
- if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
- fprintf(stderr, "%s: failed to decode\n", __func__);
- return -7;
- }
-
- {
- const int64_t t_start_sample_us = ggml_time_us();
-
- whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
-
- ctx->decoders[0].kv_self.n += prompt.size();
-
- for (int j = 1; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
- memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-
- decoder.kv_self.n += prompt.size();
-
- memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
- memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
- memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
- }
-
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
- }
-
- for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
- const int64_t t_start_sample_us = ggml_time_us();
-
- // store the KV caches of all decoders when doing beam-search
- if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
- kv_bufs.resize(n_decoders_cur);
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.completed || decoder.failed) {
- continue;
- }
-
- kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
- kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
-
- memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
- memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
- }
-
- beam_candidates.clear();
- }
-
- // generate new sequence candidates for each decoder
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.completed || decoder.failed) {
- continue;
- }
-
- switch (params.strategy) {
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
- {
- if (t_cur < 1e-6f) {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
- } else {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
- }
-
- decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
- } break;
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
- {
- const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
-
- for (const auto & token : tokens_new) {
- beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
- beam_candidates.back().sequence.tokens.push_back(token);
- beam_candidates.back().sequence.sum_logprobs_all += token.plog;
-
- //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
- }
- } break;
- };
- }
-
- // for beam-search, choose the top candidates and update the KV caches
- if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
- std::sort(
- beam_candidates.begin(),
- beam_candidates.end(),
- [](const beam_candidate & a, const beam_candidate & b) {
- return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
- });
-
- uint32_t cur_c = 0;
-
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.completed || decoder.failed) {
- continue;
- }
-
- auto & cur = beam_candidates[cur_c++];
-
- while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
- ++cur_c;
- }
-
- decoder.sequence = cur.sequence;
- decoder.seek_delta = cur.seek_delta;
- decoder.has_ts = cur.has_ts;
-
- memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
- memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
-
- WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
- __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
- }
- }
-
- // update the decoder state
- // - check if the sequence is completed
- // - check if the sequence is failed
- // - update sliding window based on timestamp tokens
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.completed || decoder.failed) {
- continue;
- }
-
- auto & has_ts = decoder.has_ts;
- auto & failed = decoder.failed;
- auto & completed = decoder.completed;
- auto & seek_delta = decoder.seek_delta;
- auto & result_len = decoder.sequence.result_len;
-
- {
- const auto & token = decoder.sequence.tokens.back();
-
- // timestamp token - update sliding window
- if (token.id > whisper_token_beg(ctx)) {
- const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
-
- // do not allow going back in time
- if (has_ts && seek_delta > seek_delta_new && result_len < i) {
- failed = true; // TODO: maybe this is not a failure ?
- continue;
- }
-
- seek_delta = seek_delta_new;
- result_len = i + 1;
- has_ts = true;
- }
-
- #ifdef WHISPER_DEBUG
- {
- const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
- WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
- __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
- }
- #endif
-
- // end of segment
- if (token.id == whisper_token_eot(ctx) || // end of text token
- (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
- ) {
- if (result_len == 0) {
- if (seek + seek_delta + 100 >= seek_end) {
- result_len = i + 1;
- } else {
- failed = true;
- continue;
- }
- }
-
- if (params.single_segment) {
- result_len = i + 1;
- seek_delta = 100*WHISPER_CHUNK_SIZE;
- }
-
- completed = true;
- continue;
- }
-
- // TESTS: if no tensors are loaded, it means we are running tests
- if (ctx->model.n_loaded == 0) {
- seek_delta = 100*WHISPER_CHUNK_SIZE;
- completed = true;
- continue;
- }
- }
-
- // sometimes, the decoding can get stuck in a repetition loop
- // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
- if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
- failed = true;
- continue;
- }
- }
-
- // check if all decoders have finished (i.e. completed or failed)
- {
- bool completed_all = true;
-
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.completed || decoder.failed) {
- continue;
- }
-
- completed_all = false;
- }
-
- if (completed_all) {
- break;
- }
- }
-
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
- // obtain logits for the next token
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.failed || decoder.completed) {
- continue;
- }
-
- decoder.tokens_tmp.resize(1);
- decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
-
- //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
-
- if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
- fprintf(stderr, "%s: failed to decode\n", __func__);
- return -8;
- }
-
- {
- const int64_t t_start_sample_us = ggml_time_us();
-
- whisper_process_logits(*ctx, params, decoder, t_cur);
-
- ++decoder.kv_self.n;
-
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
- }
- }
-
- // rank the resulting sequences and select the best one
- {
- double best_score = -INFINITY;
-
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = ctx->decoders[j];
-
- if (decoder.failed) {
- continue;
- }
-
- decoder.sequence.tokens.resize(decoder.sequence.result_len);
- whisper_sequence_score(params, decoder.sequence);
-
- WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
- __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
-
- if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
- WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
- __func__, j, decoder.sequence.entropy, params.entropy_thold);
-
- decoder.failed = true;
- ctx->n_fail_h++;
-
- continue;
- }
-
- if (best_score < decoder.sequence.score) {
- best_score = decoder.sequence.score;
- best_decoder_id = j;
- }
- }
-
- WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
- }
-
- // was the decoding successful for the current temperature?
- {
- bool success = true;
-
- const auto & decoder = ctx->decoders[best_decoder_id];
-
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
- success = false;
- ctx->n_fail_p++;
- }
-
- if (success) {
- //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
- // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
- //}
-
- break;
- }
- }
-
- WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
- }
-
- // output results through a user-provided callback
- {
- const auto & best_decoder = ctx->decoders[best_decoder_id];
-
- const auto seek_delta = best_decoder.seek_delta;
- const auto result_len = best_decoder.sequence.result_len;
-
- const auto & tokens_cur = best_decoder.sequence.tokens;
-
- //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
-
- // update prompt_past
- prompt_past.clear();
- if (prompt.front() == whisper_token_prev(ctx)) {
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
- }
-
- for (int i = 0; i < result_len; ++i) {
- prompt_past.push_back(tokens_cur[i].id);
- }
-
- // store the text from this iteration
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
- int i0 = 0;
- auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
-
- std::string text;
-
- for (int i = 0; i < (int) tokens_cur.size(); i++) {
- //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
- // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
- // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
-
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
- } else {
- text += whisper_token_to_str(ctx, tokens_cur[i].id);
- }
-
- if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
- const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
-
- if (!text.empty()) {
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
-
- if (params.print_realtime) {
- if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
- } else {
- printf("%s", text.c_str());
- fflush(stdout);
- }
- }
-
- //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
-
- result_all.push_back({ tt0, tt1, text, {} });
- for (int j = i0; j <= i; j++) {
- result_all.back().tokens.push_back(tokens_cur[j]);
- }
-
- int n_new = 1;
-
- if (params.token_timestamps) {
- whisper_exp_compute_token_level_timestamps(
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
-
- if (params.max_len > 0) {
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
- }
- }
- if (params.new_segment_callback) {
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
- }
- }
- text = "";
- while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
- i++;
- }
- i--;
- t0 = t1;
- i0 = i + 1;
- }
- }
-
- if (!text.empty()) {
- const auto t1 = seek + seek_delta;
-
- const auto tt0 = params.speed_up ? 2*t0 : t0;
- const auto tt1 = params.speed_up ? 2*t1 : t1;
-
- if (params.print_realtime) {
- if (params.print_timestamps) {
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
- } else {
- printf("%s", text.c_str());
- fflush(stdout);
- }
- }
-
- result_all.push_back({ tt0, tt1, text, {} });
- for (int j = i0; j < (int) tokens_cur.size(); j++) {
- result_all.back().tokens.push_back(tokens_cur[j]);
- }
-
- int n_new = 1;
-
- if (params.token_timestamps) {
- whisper_exp_compute_token_level_timestamps(
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
-
- if (params.max_len > 0) {
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
- }
- }
- if (params.new_segment_callback) {
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
- }
- }
- }
-
- // update audio window
- seek += seek_delta;
-
- WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
- }
- }
-
- return 0;
- }
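
The temperature loop inside whisper_full retries the whole decode at increasing temperatures until one candidate clears the entropy and log-probability thresholds. A stripped-down sketch of that fallback policy; pass_result and decode_once are hypothetical stand-ins, not library symbols:

    #include <functional>
    #include <vector>

    // stand-in for the outcome of one full decoding pass
    struct pass_result {
        bool failed = false;
        double avg_logprobs = 0.0;
    };

    // walk the temperature schedule until a pass clears logprob_thold,
    // mirroring the retry policy of the temperature loop above
    static bool decode_with_fallback(const std::vector<float> & temperatures,
                                     double logprob_thold,
                                     const std::function<pass_result(float)> & decode_once) {
        for (const float t : temperatures) {
            const pass_result res = decode_once(t);
            if (!res.failed && res.avg_logprobs >= logprob_thold) {
                return true; // accept this temperature's result
            }
        }
        return false; // every temperature failed
    }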
-
- void whisper_running_abort(struct whisper_context * ctx) {
- ctx->running = false;
- }
-
- void whisper_running_restore(struct whisper_context * ctx) {
- ctx->running = true;
- }
-
- bool whisper_running_state(struct whisper_context * ctx) {
- return ctx->running;
- }
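
whisper_full polls this flag once per audio window, so an abort requested from another thread takes effect at the next window boundary rather than immediately. A hedged usage sketch (abort_after_seconds is illustrative; ctx is assumed to be the context currently inside whisper_full):

    #include <chrono>
    #include <thread>

    // cancel a long transcription from a watchdog thread
    static void abort_after_seconds(struct whisper_context * ctx, int seconds) {
        std::thread([ctx, seconds] {
            std::this_thread::sleep_for(std::chrono::seconds(seconds));
            whisper_running_abort(ctx); // polled once per audio window
        }).detach();
    }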
-
- int whisper_full_parallel(
- struct whisper_context * ctx,
- struct whisper_full_params params,
- const float * samples,
- int n_samples,
- int n_processors) {
- if (n_processors == 1) {
- return whisper_full(ctx, params, samples, n_samples);
- }
-
- int ret = 0;
-
- // prepare separate contexts for each thread
- std::vector<struct whisper_context> ctxs(n_processors - 1);
-
- for (int i = 0; i < n_processors - 1; ++i) {
- auto & ctx_p = ctxs[i];
-
- ctx_p = *ctx;
-
- ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
-
- ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
-
- if (!kv_cache_reinit(ctx_p.kv_cross)) {
- fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
- return -1;
- }
-
- // TAGS: WHISPER_DECODER_INIT
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
- if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
- return -1;
- }
-
- ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-
- ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
- ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
- ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
- }
- }
-
- const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
- const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
-
- // the calling thread will process the first chunk
- // while the other threads will process the remaining chunks
-
- std::vector<std::thread> workers(n_processors - 1);
- for (int i = 0; i < n_processors - 1; ++i) {
- const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
- const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
-
- auto params_cur = params;
-
- params_cur.offset_ms = 0;
- params_cur.print_progress = false;
- params_cur.print_realtime = false;
-
- params_cur.new_segment_callback = nullptr;
- params_cur.new_segment_callback_user_data = nullptr;
-
- workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
- }
-
- {
- auto params_cur = params;
-
- ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
- }
-
- for (int i = 0; i < n_processors - 1; ++i) {
- workers[i].join();
- }
-
- const int64_t offset_t = (int64_t) params.offset_ms/10.0;
-
- // combine results into ctx->result_all
- for (int i = 0; i < n_processors - 1; ++i) {
- auto & results_i = ctxs[i].result_all;
-
- for (auto & result : results_i) {
- // correct the segment timestamp taking into account the offset
- result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
- result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
-
- // make sure that segments are not overlapping
- if (!ctx->result_all.empty()) {
- result.t0 = std::max(result.t0, ctx->result_all.back().t1);
- }
-
- ctx->result_all.push_back(std::move(result));
-
- // call the new_segment_callback for each segment
- if (params.new_segment_callback) {
- params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
- }
- }
-
- ctx->t_mel_us += ctxs[i].t_mel_us;
- ctx->t_sample_us += ctxs[i].t_sample_us;
- ctx->t_encode_us += ctxs[i].t_encode_us;
- ctx->t_decode_us += ctxs[i].t_decode_us;
-
- // free the worker's caches (the primary context keeps its own)
- kv_cache_free(ctxs[i].kv_cross);
-
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
- kv_cache_free(ctxs[i].decoders[j].kv_self);
- }
- }
-
- // average the timings
- ctx->t_mel_us /= n_processors;
- ctx->t_sample_us /= n_processors;
- ctx->t_encode_us /= n_processors;
- ctx->t_decode_us /= n_processors;
-
- // print information about the audio boundaries
- fprintf(stderr, "\n");
- fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
- for (int i = 0; i < n_processors - 1; ++i) {
- fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
- }
- fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
-
- return ret;
- }
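
The split above gives each worker an equal share of the samples after the initial offset, with the last worker absorbing the remainder of the integer division; the timestamps are then shifted back into the global timeline. The chunk arithmetic in isolation (illustrative helper, not a library function):

    // sample range assigned to worker i (0-based)
    static void chunk_bounds(int n_samples, int offset_samples, int n_processors,
                             int i, int & start, int & count) {
        const int per_proc = (n_samples - offset_samples)/n_processors;
        start = offset_samples + i*per_proc;
        // the last worker takes whatever remains after integer division
        count = (i == n_processors - 1) ? n_samples - start : per_proc;
    }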
-
- int whisper_full_n_segments(struct whisper_context * ctx) {
- return ctx->result_all.size();
- }
-
- int whisper_full_lang_id(struct whisper_context * ctx) {
- return ctx->lang_id;
- }
-
- int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
- return ctx->result_all[i_segment].t0;
- }
-
- int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
- return ctx->result_all[i_segment].t1;
- }
-
- const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
- return ctx->result_all[i_segment].text.c_str();
- }
-
- int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
- return ctx->result_all[i_segment].tokens.size();
- }
-
- const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
- return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
- }
-
- whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
- return ctx->result_all[i_segment].tokens[i_token].id;
- }
-
- struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
- return ctx->result_all[i_segment].tokens[i_token];
- }
-
- float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
- return ctx->result_all[i_segment].tokens[i_token].p;
- }
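
Together these accessors form the read-out API after whisper_full returns. A typical iteration (sketch; assumes a context populated by a successful whisper_full call, and that timestamps are in units of 10 ms):

    #include <cstdint>
    #include <cstdio>

    // print every segment with its timestamps, using only the accessors above
    static void print_segments(struct whisper_context * ctx) {
        const int n = whisper_full_n_segments(ctx);
        for (int i = 0; i < n; ++i) {
            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
            printf("[%lld ms --> %lld ms] %s\n",
                   10*(long long) t0, 10*(long long) t1,
                   whisper_full_get_segment_text(ctx, i));
        }
    }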
-
- // =================================================================================================
-
- //
- // Temporary interface needed for exposing ggml interface
- // Will be removed in the future when ggml becomes a separate library
- //
-
- WHISPER_API int whisper_bench_memcpy(int n_threads) {
- ggml_time_init();
-
- size_t n = 50;
- size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
-
- // 1 GB array
- const size_t size = arr*1024llu*1024llu;
-
- char * src = (char *) malloc(size);
- char * dst = (char *) malloc(size);
-
- for (size_t i = 0; i < size; i++) src[i] = i;
-
- memcpy(dst, src, size); // heat-up
-
- double tsum = 0.0;
-
- for (size_t i = 0; i < n; i++) {
- const int64_t t0 = ggml_time_us();
-
- memcpy(dst, src, size);
-
- const int64_t t1 = ggml_time_us();
-
- tsum += (t1 - t0)*1e-6;
-
- src[0] = rand();
- }
-
- fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
-
- // needed to prevent the compiler from optimizing the memcpy away
- {
- double sum = 0.0;
-
- for (size_t i = 0; i < size; i++) sum += dst[i];
-
- fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
- }
-
- free(src);
- free(dst);
-
- return 0;
- }
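
The printed figure is simply bytes moved per second across the timed copies. The formula factored out, with a worked example under assumed numbers (the helper name is illustrative):

    #include <cstddef>

    // n copies of `size` bytes in `tsum` seconds, reported in GiB/s
    static double memcpy_gibps(size_t n, size_t size, double tsum) {
        return (double) (n*size)/(tsum*1024.0*1024.0*1024.0);
    }
    // e.g. 50 copies of a 1 GiB buffer in 5.0 s -> 10.00 GB/s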
-
- WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
- ggml_time_init();
-
- const int n_max = 128;
-
- const std::vector<size_t> sizes = {
- 64, 128, 256, 512, 1024, 2048, 4096,
- };
-
- const size_t N_max = sizes.back();
-
- // a: N*N*sizeof(float)
- // b: N*N*sizeof(float)
- // c: N*N*sizeof(float)
- // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
- std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
-
- for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
-
- for (int j = 0; j < (int) sizes.size(); j++) {
- int n_fp16 = 0;
- int n_fp32 = 0;
-
- // GFLOPS
- double s_fp16 = 0.0;
- double s_fp32 = 0.0;
-
- const size_t N = sizes[j];
-
- for (int k = 0; k < 2; ++k) {
- const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
- double & s = k == 0 ? s_fp16 : s_fp32;
- int & n = k == 0 ? n_fp16 : n_fp32;
-
- struct ggml_init_params gparams = {
- /*.mem_size =*/ buf.size(),
- /*.mem_buffer =*/ buf.data(),
- };
-
- struct ggml_context * ctx0 = ggml_init(gparams);
-
- struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N);
- struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
-
- struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
-
- struct ggml_cgraph gf = ggml_build_forward(c);
-
- gf.n_threads = n_threads;
-
- double tsum = 0.0;
-
- // heat-up
- ggml_graph_compute(ctx0, &gf);
-
- for (int i = 0; i < n_max; ++i) {
- const int64_t t0 = ggml_time_us();
-
- ggml_graph_compute(ctx0, &gf);
-
- const int64_t t1 = ggml_time_us();
-
- tsum += (t1 - t0)*1e-6;
- n++;
-
- if (tsum > 1.0 && n >= 3) {
- break;
- }
- }
-
- ggml_free(ctx0);
-
- s = ((2.0*N*N*N*n)/tsum)*1e-9;
- }
-
- fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
- N, N, s_fp16, n_fp16, s_fp32, n_fp32);
- }
-
- return 0;
- }
-
- // =================================================================================================
-
- // =================================================================================================
-
- //
- // Experimental stuff below
- //
- // Not sure if these should be part of the library at all, because the quality of the results is not
- // guaranteed. Might get removed at some point unless a robust algorithm implementation is found
- //
-
- // =================================================================================================
-
- //
- // token-level timestamps
- //
-
- static int timestamp_to_sample(int64_t t, int n_samples) {
- return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
- }
-
- static int64_t sample_to_timestamp(int i_sample) {
- return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
- }
-
- // a cost-function / heuristic that is high for text that takes longer to pronounce
- // obviously, can be improved
- static float voice_length(const std::string & text) {
- float res = 0.0f;
-
- for (char c : text) {
- if (c == ' ') {
- res += 0.01f;
- } else if (c == ',') {
- res += 2.00f;
- } else if (c == '.') {
- res += 3.00f;
- } else if (c == '!') {
- res += 3.00f;
- } else if (c == '?') {
- res += 3.00f;
- } else if (c >= '0' && c <= '9') {
- res += 3.00f;
- } else {
- res += 1.00f;
- }
- }
-
- return res;
- }
-
- // average the fabs of the signal
- static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
- const int hw = n_samples_per_half_window;
-
- std::vector<float> result(n_samples);
-
- for (int i = 0; i < n_samples; i++) {
- float sum = 0;
- for (int j = -hw; j <= hw; j++) {
- if (i + j >= 0 && i + j < n_samples) {
- sum += fabs(signal[i + j]);
- }
- }
- result[i] = sum/(2*hw + 1);
- }
-
- return result;
- }
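
The windowed average above is O(n·hw). For reference, an equivalent O(n) formulation using a prefix sum; this is a sketch, not part of the library, and it keeps the fixed (2*hw + 1) divisor so the behavior at the signal edges matches:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // same half-window energy as get_signal_energy(), computed via a prefix sum
    static std::vector<float> get_signal_energy_fast(const float * signal, int n_samples, int hw) {
        std::vector<double> pre(n_samples + 1, 0.0);
        for (int i = 0; i < n_samples; i++) {
            pre[i + 1] = pre[i] + std::fabs(signal[i]);
        }
        std::vector<float> result(n_samples);
        for (int i = 0; i < n_samples; i++) {
            const int lo = std::max(0, i - hw);
            const int hi = std::min(n_samples, i + hw + 1);
            result[i] = (float) ((pre[hi] - pre[lo])/(2*hw + 1));
        }
        return result;
    }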
-
- static void whisper_exp_compute_token_level_timestamps(
- struct whisper_context & ctx,
- int i_segment,
- float thold_pt,
- float thold_ptsum) {
- auto & segment = ctx.result_all[i_segment];
- auto & tokens = segment.tokens;
-
- const int n_samples = ctx.energy.size();
-
- if (n_samples == 0) {
- fprintf(stderr, "%s: no signal data available\n", __func__);
- return;
- }
-
- const int64_t t0 = segment.t0;
- const int64_t t1 = segment.t1;
-
- const int n = tokens.size();
-
- if (n == 0) {
- return;
- }
-
- if (n == 1) {
- tokens[0].t0 = t0;
- tokens[0].t1 = t1;
-
- return;
- }
-
- auto & t_beg = ctx.t_beg;
- auto & t_last = ctx.t_last;
- auto & tid_last = ctx.tid_last;
-
- for (int j = 0; j < n; ++j) {
- auto & token = tokens[j];
-
- if (j == 0) {
- if (token.id == whisper_token_beg(&ctx)) {
- tokens[j ].t0 = t0;
- tokens[j ].t1 = t0;
- tokens[j + 1].t0 = t0;
-
- t_beg = t0;
- t_last = t0;
- tid_last = whisper_token_beg(&ctx);
- } else {
- tokens[j ].t0 = t_last;
- }
- }
-
- const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
-
- tokens[j].id = token.id;
- tokens[j].tid = token.tid;
- tokens[j].p = token.p;
- tokens[j].pt = token.pt;
- tokens[j].ptsum = token.ptsum;
-
- tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
-
- if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
- if (j > 0) {
- tokens[j - 1].t1 = tt;
- }
- tokens[j].t0 = tt;
- tid_last = token.tid;
- }
- }
-
- tokens[n - 2].t1 = t1;
- tokens[n - 1].t0 = t1;
- tokens[n - 1].t1 = t1;
-
- t_last = t1;
-
- // find intervals of tokens with unknown timestamps
- // fill the timestamps by proportionally splitting the interval based on the token voice lengths
- {
- int p0 = 0;
- int p1 = 0;
-
- while (true) {
- while (p1 < n && tokens[p1].t1 < 0) {
- p1++;
- }
-
- if (p1 >= n) {
- p1--;
- }
-
- //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
-
- if (p1 > p0) {
- double psum = 0.0;
- for (int j = p0; j <= p1; j++) {
- psum += tokens[j].vlen;
- }
-
- //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
-
- const double dt = tokens[p1].t1 - tokens[p0].t0;
-
- // split the time proportionally to the voice length
- for (int j = p0 + 1; j <= p1; j++) {
- const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
-
- tokens[j - 1].t1 = ct;
- tokens[j ].t0 = ct;
- }
- }
-
- p1++;
- p0 = p1;
- if (p1 >= n) {
- break;
- }
- }
- }
-
- // fix up (just in case)
- for (int j = 0; j < n - 1; j++) {
- if (tokens[j].t1 < 0) {
- tokens[j + 1].t0 = tokens[j].t1;
- }
-
- if (j > 0) {
- if (tokens[j - 1].t1 > tokens[j].t0) {
- tokens[j].t0 = tokens[j - 1].t1;
- tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
- }
- }
- }
-
- // VAD
- // expand or contract tokens based on voice activity
- {
- const int hw = WHISPER_SAMPLE_RATE/8;
-
- for (int j = 0; j < n; j++) {
- if (tokens[j].id >= whisper_token_eot(&ctx)) {
- continue;
- }
-
- int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
- int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
-
- const int ss0 = std::max(s0 - hw, 0);
- const int ss1 = std::min(s1 + hw, n_samples);
-
- const int ns = ss1 - ss0;
-
- float sum = 0.0f;
-
- for (int k = ss0; k < ss1; k++) {
- sum += ctx.energy[k];
- }
-
- const float thold = 0.5*sum/ns;
-
- {
- int k = s0;
- if (ctx.energy[k] > thold && j > 0) {
- while (k > 0 && ctx.energy[k] > thold) {
- k--;
- }
- tokens[j].t0 = sample_to_timestamp(k);
- if (tokens[j].t0 < tokens[j - 1].t1) {
- tokens[j].t0 = tokens[j - 1].t1;
- } else {
- s0 = k;
- }
- } else {
- while (ctx.energy[k] < thold && k < s1) {
- k++;
- }
- s0 = k;
- tokens[j].t0 = sample_to_timestamp(k);
- }
- }
-
- {
- int k = s1;
- if (ctx.energy[k] > thold) {
- while (k < n_samples - 1 && ctx.energy[k] > thold) {
- k++;
- }
- tokens[j].t1 = sample_to_timestamp(k);
- // bound by the token count n, not the window sample count ns
- if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
- tokens[j].t1 = tokens[j + 1].t0;
- } else {
- s1 = k;
- }
- } else {
- while (ctx.energy[k] < thold && k > s0) {
- k--;
- }
- s1 = k;
- tokens[j].t1 = sample_to_timestamp(k);
- }
- }
- }
- }
-
- // fixed token expand (optional)
- //{
- // const int t_expand = 0;
-
- // for (int j = 0; j < n; j++) {
- // if (j > 0) {
- // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
- // }
- // if (j < n - 1) {
- // tokens[j].t1 = tokens[j].t1 + t_expand;
- // }
- // }
- //}
-
- // debug info
- //for (int j = 0; j < n; ++j) {
- // const auto & token = tokens[j];
- // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
- // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
- // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));

- // if (tokens[j].id >= whisper_token_eot(&ctx)) {
- // continue;
- // }
- //}
- }
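
The proportional fill above distributes an unknown interval across tokens in ratio of their voice lengths, so each boundary advances by dt·vlen/psum per token. Just that step as a sketch, with ts_token standing in for whisper_token_data:

    #include <vector>

    struct ts_token { double t0, t1, vlen; }; // illustrative stand-in

    // split [tokens[p0].t0, tokens[p1].t1] over tokens[p0..p1] in ratio of vlen,
    // as in the interval-filling loop above
    static void split_proportionally(std::vector<ts_token> & tokens, int p0, int p1) {
        double psum = 0.0;
        for (int j = p0; j <= p1; j++) {
            psum += tokens[j].vlen;
        }
        const double dt = tokens[p1].t1 - tokens[p0].t0;
        for (int j = p0 + 1; j <= p1; j++) {
            const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
            tokens[j - 1].t1 = ct;
            tokens[j    ].t0 = ct;
        }
    }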