whispercpp 1.2.0.2 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
data/ext/whisper.cpp DELETED
@@ -1,4829 +0,0 @@
1
- #define WHISPER_BUILD
2
- #include "whisper.h"
3
-
4
- #include "ggml.h"
5
-
6
- #include <algorithm>
7
- #include <cassert>
8
- #define _USE_MATH_DEFINES
9
- #include <cmath>
10
- #include <cstdio>
11
- #include <cstring>
12
- #include <fstream>
13
- #include <map>
14
- #include <string>
15
- #include <thread>
16
- #include <vector>
17
- #include <regex>
18
- #include <random>
19
-
20
- #if defined(GGML_BIG_ENDIAN)
21
- #include <bit>
22
-
23
- template<typename T>
24
- static T byteswap(T value) {
25
- return std::byteswap(value);
26
- }
27
-
28
- template<>
29
- float byteswap(float value) {
30
- return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
31
- }
32
-
33
- template<typename T>
34
- static void byteswap_tensor_data(ggml_tensor * tensor) {
35
- T * datum = reinterpret_cast<T *>(tensor->data);
36
- for (int i = 0; i < ggml_nelements(tensor); i++) {
37
- datum[i] = byteswap(datum[i]);
38
- }
39
- }
40
-
41
- static void byteswap_tensor(ggml_tensor * tensor) {
42
- switch (tensor->type) {
43
- case GGML_TYPE_I16: {
44
- byteswap_tensor_data<int16_t>(tensor);
45
- break;
46
- }
47
- case GGML_TYPE_F16: {
48
- byteswap_tensor_data<ggml_fp16_t>(tensor);
49
- break;
50
- }
51
- case GGML_TYPE_I32: {
52
- byteswap_tensor_data<int32_t>(tensor);
53
- break;
54
- }
55
- case GGML_TYPE_F32: {
56
- byteswap_tensor_data<float>(tensor);
57
- break;
58
- }
59
- default: { // GML_TYPE_I8
60
- break;
61
- }
62
- }
63
- }
64
-
65
- #define BYTESWAP_VALUE(d) d = byteswap(d)
66
- #define BYTESWAP_FILTERS(f) \
67
- do { \
68
- for (auto & datum : f.data) { \
69
- datum = byteswap(datum); \
70
- } \
71
- } while (0)
72
- #define BYTESWAP_TENSOR(t) \
73
- do { \
74
- byteswap_tensor(tensor); \
75
- } while (0)
76
- #else
77
- #define BYTESWAP_VALUE(d) do {} while (0)
78
- #define BYTESWAP_FILTERS(f) do {} while (0)
79
- #define BYTESWAP_TENSOR(t) do {} while (0)
80
- #endif
81
-
82
- #define WHISPER_ASSERT(x) \
83
- do { \
84
- if (!(x)) { \
85
- fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
86
- abort(); \
87
- } \
88
- } while (0)
89
-
90
- // define this to enable verbose trace logging - useful for debugging purposes
91
- //#define WHISPER_DEBUG
92
-
93
- #if defined(WHISPER_DEBUG)
94
- #define WHISPER_PRINT_DEBUG(...) \
95
- do { \
96
- fprintf(stderr, __VA_ARGS__); \
97
- } while (0)
98
- #else
99
- #define WHISPER_PRINT_DEBUG(...)
100
- #endif
101
-
102
- #define WHISPER_USE_FLASH_ATTN
103
- //#define WHISPER_USE_FLASH_FF
104
- #define WHISPER_MAX_DECODERS 16
105
-
106
- #define WHISPER_USE_SCRATCH
107
- #define WHISPER_MAX_SCRATCH_BUFFERS 16
108
-
109
- // available whisper models
110
- enum e_model {
111
- MODEL_UNKNOWN,
112
- MODEL_TINY,
113
- MODEL_BASE,
114
- MODEL_SMALL,
115
- MODEL_MEDIUM,
116
- MODEL_LARGE,
117
- };
118
-
119
- static const std::map<std::string, std::pair<int, std::string>> g_lang = {
120
- { "en", { 0, "english", } },
121
- { "zh", { 1, "chinese", } },
122
- { "de", { 2, "german", } },
123
- { "es", { 3, "spanish", } },
124
- { "ru", { 4, "russian", } },
125
- { "ko", { 5, "korean", } },
126
- { "fr", { 6, "french", } },
127
- { "ja", { 7, "japanese", } },
128
- { "pt", { 8, "portuguese", } },
129
- { "tr", { 9, "turkish", } },
130
- { "pl", { 10, "polish", } },
131
- { "ca", { 11, "catalan", } },
132
- { "nl", { 12, "dutch", } },
133
- { "ar", { 13, "arabic", } },
134
- { "sv", { 14, "swedish", } },
135
- { "it", { 15, "italian", } },
136
- { "id", { 16, "indonesian", } },
137
- { "hi", { 17, "hindi", } },
138
- { "fi", { 18, "finnish", } },
139
- { "vi", { 19, "vietnamese", } },
140
- { "iw", { 20, "hebrew", } },
141
- { "uk", { 21, "ukrainian", } },
142
- { "el", { 22, "greek", } },
143
- { "ms", { 23, "malay", } },
144
- { "cs", { 24, "czech", } },
145
- { "ro", { 25, "romanian", } },
146
- { "da", { 26, "danish", } },
147
- { "hu", { 27, "hungarian", } },
148
- { "ta", { 28, "tamil", } },
149
- { "no", { 29, "norwegian", } },
150
- { "th", { 30, "thai", } },
151
- { "ur", { 31, "urdu", } },
152
- { "hr", { 32, "croatian", } },
153
- { "bg", { 33, "bulgarian", } },
154
- { "lt", { 34, "lithuanian", } },
155
- { "la", { 35, "latin", } },
156
- { "mi", { 36, "maori", } },
157
- { "ml", { 37, "malayalam", } },
158
- { "cy", { 38, "welsh", } },
159
- { "sk", { 39, "slovak", } },
160
- { "te", { 40, "telugu", } },
161
- { "fa", { 41, "persian", } },
162
- { "lv", { 42, "latvian", } },
163
- { "bn", { 43, "bengali", } },
164
- { "sr", { 44, "serbian", } },
165
- { "az", { 45, "azerbaijani", } },
166
- { "sl", { 46, "slovenian", } },
167
- { "kn", { 47, "kannada", } },
168
- { "et", { 48, "estonian", } },
169
- { "mk", { 49, "macedonian", } },
170
- { "br", { 50, "breton", } },
171
- { "eu", { 51, "basque", } },
172
- { "is", { 52, "icelandic", } },
173
- { "hy", { 53, "armenian", } },
174
- { "ne", { 54, "nepali", } },
175
- { "mn", { 55, "mongolian", } },
176
- { "bs", { 56, "bosnian", } },
177
- { "kk", { 57, "kazakh", } },
178
- { "sq", { 58, "albanian", } },
179
- { "sw", { 59, "swahili", } },
180
- { "gl", { 60, "galician", } },
181
- { "mr", { 61, "marathi", } },
182
- { "pa", { 62, "punjabi", } },
183
- { "si", { 63, "sinhala", } },
184
- { "km", { 64, "khmer", } },
185
- { "sn", { 65, "shona", } },
186
- { "yo", { 66, "yoruba", } },
187
- { "so", { 67, "somali", } },
188
- { "af", { 68, "afrikaans", } },
189
- { "oc", { 69, "occitan", } },
190
- { "ka", { 70, "georgian", } },
191
- { "be", { 71, "belarusian", } },
192
- { "tg", { 72, "tajik", } },
193
- { "sd", { 73, "sindhi", } },
194
- { "gu", { 74, "gujarati", } },
195
- { "am", { 75, "amharic", } },
196
- { "yi", { 76, "yiddish", } },
197
- { "lo", { 77, "lao", } },
198
- { "uz", { 78, "uzbek", } },
199
- { "fo", { 79, "faroese", } },
200
- { "ht", { 80, "haitian creole", } },
201
- { "ps", { 81, "pashto", } },
202
- { "tk", { 82, "turkmen", } },
203
- { "nn", { 83, "nynorsk", } },
204
- { "mt", { 84, "maltese", } },
205
- { "sa", { 85, "sanskrit", } },
206
- { "lb", { 86, "luxembourgish", } },
207
- { "my", { 87, "myanmar", } },
208
- { "bo", { 88, "tibetan", } },
209
- { "tl", { 89, "tagalog", } },
210
- { "mg", { 90, "malagasy", } },
211
- { "as", { 91, "assamese", } },
212
- { "tt", { 92, "tatar", } },
213
- { "haw", { 93, "hawaiian", } },
214
- { "ln", { 94, "lingala", } },
215
- { "ha", { 95, "hausa", } },
216
- { "ba", { 96, "bashkir", } },
217
- { "jw", { 97, "javanese", } },
218
- { "su", { 98, "sundanese", } },
219
- };
220
-
221
- static const size_t MB = 1024*1024;
222
-
223
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
224
- { MODEL_TINY, 12ull*MB },
225
- { MODEL_BASE, 15ull*MB },
226
- { MODEL_SMALL, 23ull*MB },
227
- { MODEL_MEDIUM, 31ull*MB },
228
- { MODEL_LARGE, 38ull*MB },
229
- };
230
-
231
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
232
- { MODEL_TINY, 18ull*MB },
233
- { MODEL_BASE, 24ull*MB },
234
- { MODEL_SMALL, 36ull*MB },
235
- { MODEL_MEDIUM, 48ull*MB },
236
- { MODEL_LARGE, 60ull*MB },
237
- };
238
-
239
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
240
- { MODEL_TINY, 4ull*MB },
241
- { MODEL_BASE, 4ull*MB },
242
- { MODEL_SMALL, 6ull*MB },
243
- { MODEL_MEDIUM, 7ull*MB },
244
- { MODEL_LARGE, 9ull*MB },
245
- };
246
-
247
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
248
- { MODEL_TINY, 4ull*MB },
249
- { MODEL_BASE, 4ull*MB },
250
- { MODEL_SMALL, 6ull*MB },
251
- { MODEL_MEDIUM, 7ull*MB },
252
- { MODEL_LARGE, 9ull*MB },
253
- };
254
-
255
- static const std::map<e_model, size_t> MEM_REQ_MODEL = {
256
- { MODEL_TINY, 74ull*MB },
257
- { MODEL_BASE, 142ull*MB },
258
- { MODEL_SMALL, 466ull*MB },
259
- { MODEL_MEDIUM, 1464ull*MB },
260
- { MODEL_LARGE, 2952ull*MB },
261
- };
262
-
263
- static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
264
- { MODEL_TINY, 3ull*MB },
265
- { MODEL_BASE, 6ull*MB },
266
- { MODEL_SMALL, 16ull*MB },
267
- { MODEL_MEDIUM, 43ull*MB },
268
- { MODEL_LARGE, 71ull*MB },
269
- };
270
-
271
- static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
272
- { MODEL_TINY, 9ull*MB },
273
- { MODEL_BASE, 18ull*MB },
274
- { MODEL_SMALL, 53ull*MB },
275
- { MODEL_MEDIUM, 141ull*MB },
276
- { MODEL_LARGE, 235ull*MB },
277
- };
278
-
279
- static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
280
- { MODEL_TINY, 6ull*MB },
281
- { MODEL_BASE, 8ull*MB },
282
- { MODEL_SMALL, 13ull*MB },
283
- { MODEL_MEDIUM, 22ull*MB },
284
- { MODEL_LARGE, 33ull*MB },
285
- };
286
-
287
- static const std::map<e_model, size_t> MEM_REQ_DECODE = {
288
- { MODEL_TINY, 3ull*MB },
289
- { MODEL_BASE, 5ull*MB },
290
- { MODEL_SMALL, 10ull*MB },
291
- { MODEL_MEDIUM, 18ull*MB },
292
- { MODEL_LARGE, 27ull*MB },
293
- };
294
-
295
- struct whisper_mel {
296
- int n_len;
297
- int n_mel;
298
-
299
- std::vector<float> data;
300
- };
301
-
302
- struct whisper_filters {
303
- int32_t n_mel;
304
- int32_t n_fft;
305
-
306
- std::vector<float> data;
307
- };
308
-
309
- struct whisper_vocab {
310
- using id = int32_t;
311
- using token = std::string;
312
-
313
- int n_vocab = 51864;
314
-
315
- std::map<token, id> token_to_id;
316
- std::map<id, token> id_to_token;
317
-
318
- id token_eot = 50256;
319
- id token_sot = 50257;
320
- id token_prev = 50360;
321
- id token_solm = 50361; // ??
322
- id token_not = 50362; // no timestamps
323
- id token_beg = 50363;
324
-
325
- // available tasks
326
- static const id token_translate = 50358;
327
- static const id token_transcribe = 50359;
328
-
329
- bool is_multilingual() const {
330
- return n_vocab == 51865;
331
- }
332
- };
333
-
334
- struct whisper_segment {
335
- int64_t t0;
336
- int64_t t1;
337
-
338
- std::string text;
339
-
340
- std::vector<whisper_token_data> tokens;
341
- };
342
-
343
- // medium
344
- // hparams: {
345
- // 'n_mels': 80,
346
- // 'n_vocab': 51864,
347
- // 'n_audio_ctx': 1500,
348
- // 'n_audio_state': 1024,
349
- // 'n_audio_head': 16,
350
- // 'n_audio_layer': 24,
351
- // 'n_text_ctx': 448,
352
- // 'n_text_state': 1024,
353
- // 'n_text_head': 16,
354
- // 'n_text_layer': 24
355
- // }
356
- //
357
- // default hparams (Whisper tiny)
358
- struct whisper_hparams {
359
- int32_t n_vocab = 51864;
360
- int32_t n_audio_ctx = 1500;
361
- int32_t n_audio_state = 384;
362
- int32_t n_audio_head = 6;
363
- int32_t n_audio_layer = 4;
364
- int32_t n_text_ctx = 448;
365
- int32_t n_text_state = 384;
366
- int32_t n_text_head = 6;
367
- int32_t n_text_layer = 4;
368
- int32_t n_mels = 80;
369
- int32_t f16 = 1;
370
- };
371
-
372
- // audio encoding layer
373
- struct whisper_layer_encoder {
374
- // encoder.blocks.*.attn_ln
375
- struct ggml_tensor * attn_ln_0_w;
376
- struct ggml_tensor * attn_ln_0_b;
377
-
378
- // encoder.blocks.*.attn.out
379
- struct ggml_tensor * attn_ln_1_w;
380
- struct ggml_tensor * attn_ln_1_b;
381
-
382
- // encoder.blocks.*.attn.query
383
- struct ggml_tensor * attn_q_w;
384
- struct ggml_tensor * attn_q_b;
385
-
386
- // encoder.blocks.*.attn.key
387
- struct ggml_tensor * attn_k_w;
388
-
389
- // encoder.blocks.*.attn.value
390
- struct ggml_tensor * attn_v_w;
391
- struct ggml_tensor * attn_v_b;
392
-
393
- // encoder.blocks.*.mlp_ln
394
- struct ggml_tensor * mlp_ln_w;
395
- struct ggml_tensor * mlp_ln_b;
396
-
397
- // encoder.blocks.*.mlp.0
398
- struct ggml_tensor * mlp_0_w;
399
- struct ggml_tensor * mlp_0_b;
400
-
401
- // encoder.blocks.*.mlp.2
402
- struct ggml_tensor * mlp_1_w;
403
- struct ggml_tensor * mlp_1_b;
404
- };
405
-
406
- // token decoding layer
407
- struct whisper_layer_decoder {
408
- // decoder.blocks.*.attn_ln
409
- struct ggml_tensor * attn_ln_0_w;
410
- struct ggml_tensor * attn_ln_0_b;
411
-
412
- // decoder.blocks.*.attn.out
413
- struct ggml_tensor * attn_ln_1_w;
414
- struct ggml_tensor * attn_ln_1_b;
415
-
416
- // decoder.blocks.*.attn.query
417
- struct ggml_tensor * attn_q_w;
418
- struct ggml_tensor * attn_q_b;
419
-
420
- // decoder.blocks.*.attn.key
421
- struct ggml_tensor * attn_k_w;
422
-
423
- // decoder.blocks.*.attn.value
424
- struct ggml_tensor * attn_v_w;
425
- struct ggml_tensor * attn_v_b;
426
-
427
- // decoder.blocks.*.cross_attn_ln
428
- struct ggml_tensor * cross_attn_ln_0_w;
429
- struct ggml_tensor * cross_attn_ln_0_b;
430
-
431
- // decoder.blocks.*.cross_attn.out
432
- struct ggml_tensor * cross_attn_ln_1_w;
433
- struct ggml_tensor * cross_attn_ln_1_b;
434
-
435
- // decoder.blocks.*.cross_attn.query
436
- struct ggml_tensor * cross_attn_q_w;
437
- struct ggml_tensor * cross_attn_q_b;
438
-
439
- // decoder.blocks.*.cross_attn.key
440
- struct ggml_tensor * cross_attn_k_w;
441
-
442
- // decoder.blocks.*.cross_attn.value
443
- struct ggml_tensor * cross_attn_v_w;
444
- struct ggml_tensor * cross_attn_v_b;
445
-
446
- // decoder.blocks.*.mlp_ln
447
- struct ggml_tensor * mlp_ln_w;
448
- struct ggml_tensor * mlp_ln_b;
449
-
450
- // decoder.blocks.*.mlp.0
451
- struct ggml_tensor * mlp_0_w;
452
- struct ggml_tensor * mlp_0_b;
453
-
454
- // decoder.blocks.*.mlp.2
455
- struct ggml_tensor * mlp_1_w;
456
- struct ggml_tensor * mlp_1_b;
457
- };
458
-
459
- struct whisper_kv_cache {
460
- struct ggml_tensor * k;
461
- struct ggml_tensor * v;
462
-
463
- struct ggml_context * ctx;
464
-
465
- std::vector<uint8_t> buf;
466
-
467
- int n; // number of tokens currently in the cache
468
- };
469
-
470
- struct whisper_model {
471
- e_model type = MODEL_UNKNOWN;
472
-
473
- whisper_hparams hparams;
474
- whisper_filters filters;
475
-
476
- // encoder.positional_embedding
477
- struct ggml_tensor * e_pe;
478
-
479
- // encoder.conv1
480
- struct ggml_tensor * e_conv_1_w;
481
- struct ggml_tensor * e_conv_1_b;
482
-
483
- // encoder.conv2
484
- struct ggml_tensor * e_conv_2_w;
485
- struct ggml_tensor * e_conv_2_b;
486
-
487
- // encoder.ln_post
488
- struct ggml_tensor * e_ln_w;
489
- struct ggml_tensor * e_ln_b;
490
-
491
- // decoder.positional_embedding
492
- struct ggml_tensor * d_pe;
493
-
494
- // decoder.token_embedding
495
- struct ggml_tensor * d_te;
496
-
497
- // decoder.ln
498
- struct ggml_tensor * d_ln_w;
499
- struct ggml_tensor * d_ln_b;
500
-
501
- std::vector<whisper_layer_encoder> layers_encoder;
502
- std::vector<whisper_layer_decoder> layers_decoder;
503
-
504
- // context
505
- struct ggml_context * ctx;
506
-
507
- // the model memory buffer is read-only and can be shared between processors
508
- std::vector<uint8_t> * buf;
509
-
510
- // tensors
511
- int n_loaded;
512
- std::map<std::string, struct ggml_tensor *> tensors;
513
- };
514
-
515
- struct whisper_sequence {
516
- std::vector<whisper_token_data> tokens;
517
-
518
- // the accumulated transcription in the current interation (used to truncate the tokens array)
519
- int result_len;
520
-
521
- double sum_logprobs_all; // the sum of the log probabilities of the tokens
522
- double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
523
- double avg_logprobs; // the average log probability of the tokens
524
- double entropy; // the entropy of the tokens
525
- double score; // likelihood rank score
526
- };
527
-
528
- // TAGS: WHISPER_DECODER_INIT
529
- struct whisper_decoder {
530
- // each decoders keeps its own KV-cache
531
- whisper_kv_cache kv_self;
532
-
533
- // the currently generated sequence of tokens
534
- whisper_sequence sequence;
535
-
536
- int seek_delta; // the window shift found so far based on the decoded timestamp tokens
537
-
538
- bool failed; // has the current segment failed to decode?
539
- bool completed; // has the decoder completed the current segment?
540
- bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
541
-
542
- // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
543
- std::vector<float> probs;
544
- std::vector<float> logits;
545
- std::vector<float> logprobs;
546
-
547
- std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
548
- };
549
-
550
- struct whisper_context {
551
- int64_t t_load_us = 0;
552
- int64_t t_mel_us = 0;
553
- int64_t t_sample_us = 0;
554
- int64_t t_encode_us = 0;
555
- int64_t t_decode_us = 0;
556
- int64_t t_start_us = 0;
557
-
558
- int32_t n_sample = 0; // number of tokens sampled
559
- int32_t n_encode = 0; // number of encoder calls
560
- int32_t n_decode = 0; // number of decoder calls
561
- int32_t n_fail_p = 0; // number of logprob threshold failures
562
- int32_t n_fail_h = 0; // number of entropy threshold failures
563
-
564
- ggml_type wtype; // weight type (FP32 or FP16)
565
-
566
- whisper_mel mel;
567
-
568
- whisper_model model;
569
- whisper_vocab vocab;
570
-
571
- // cross-attention KV cache for the decoders
572
- // shared between all decoders
573
- whisper_kv_cache kv_cross;
574
-
575
- whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
576
-
577
- // memory buffers used by encode / decode contexts
578
- std::vector<uint8_t> buf_compute;
579
- std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
580
-
581
- int buf_last = 0;
582
- size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
583
-
584
- // decode output (2-dimensional array: [n_tokens][n_vocab])
585
- std::vector<float> logits;
586
-
587
- std::vector<whisper_segment> result_all;
588
- std::vector<whisper_token> prompt_past;
589
-
590
- // work container used to avoid memory allocations
591
- std::vector<std::pair<double, whisper_vocab::id>> logits_id;
592
-
593
- mutable std::mt19937 rng; // used for sampling at t > 0.0
594
-
595
- int lang_id = 0; // english by default
596
-
597
- // [EXPERIMENTAL] token-level timestamps data
598
- int64_t t_beg = 0;
599
- int64_t t_last = 0;
600
- whisper_token tid_last;
601
- std::vector<float> energy; // PCM signal energy
602
-
603
- // [EXPERIMENTAL] speed-up techniques
604
- int32_t exp_n_audio_ctx = 0; // 0 - use default
605
-
606
- // [EXPERIMENTAL] abort handling
607
- bool running = true;
608
-
609
- void use_buf(struct ggml_context * ctx, int i) {
610
- #if defined(WHISPER_USE_SCRATCH)
611
- size_t last_size = 0;
612
-
613
- if (i == -1) {
614
- last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
615
- } else {
616
- auto & buf = buf_scratch[i];
617
- last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
618
- }
619
-
620
- if (buf_last >= 0) {
621
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
622
- }
623
-
624
- buf_last = i;
625
- #else
626
- (void) i;
627
- (void) ctx;
628
- #endif
629
- }
630
-
631
- size_t get_buf_max_mem(int i) const {
632
- #if defined(WHISPER_USE_SCRATCH)
633
- return buf_max_size[i];
634
- #else
635
- (void) i;
636
- return 0;
637
- #endif
638
- }
639
- };
640
-
641
- template<typename T>
642
- static void read_safe(whisper_model_loader * loader, T & dest) {
643
- loader->read(loader->context, &dest, sizeof(T));
644
- BYTESWAP_VALUE(dest);
645
- }
646
-
647
- static bool kv_cache_init(
648
- const struct whisper_hparams & hparams,
649
- const size_t mem_bytes,
650
- struct whisper_kv_cache & cache,
651
- ggml_type wtype,
652
- int n_ctx) {
653
- cache.buf.resize(mem_bytes);
654
-
655
- struct ggml_init_params params;
656
- params.mem_size = cache.buf.size();
657
- params.mem_buffer = cache.buf.data();
658
-
659
- cache.ctx = ggml_init(params);
660
-
661
- if (!cache.ctx) {
662
- fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
663
- return false;
664
- }
665
-
666
- const int n_text_state = hparams.n_text_state;
667
- const int n_text_layer = hparams.n_text_layer;
668
-
669
- const int n_mem = n_text_layer*n_ctx;
670
- const int n_elements = n_text_state*n_mem;
671
-
672
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
673
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
674
-
675
- return true;
676
- }
677
-
678
- static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
679
- WHISPER_ASSERT(cache.ctx);
680
-
681
- const int n_elements = ggml_nelements(cache.k);
682
- WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
683
-
684
- const ggml_type wtype = cache.k->type;
685
- WHISPER_ASSERT(wtype == cache.v->type);
686
-
687
- WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
688
-
689
- struct ggml_init_params params;
690
- params.mem_size = cache.buf.size();
691
- params.mem_buffer = cache.buf.data();
692
-
693
- cache.ctx = ggml_init(params);
694
-
695
- if (!cache.ctx) {
696
- fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
697
- return false;
698
- }
699
-
700
- cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
701
- cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
702
-
703
- return true;
704
- }
705
-
706
- static void kv_cache_free(struct whisper_kv_cache & cache) {
707
- if (cache.ctx) {
708
- ggml_free(cache.ctx);
709
- cache.ctx = nullptr;
710
- }
711
- }
712
-
713
- // load the model from a ggml file
714
- //
715
- // file format:
716
- //
717
- // - hparams
718
- // - pre-computed mel filters
719
- // - vocab
720
- // - weights
721
- //
722
- // see the convert-pt-to-ggml.py script for details
723
- //
724
- static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
725
- fprintf(stderr, "%s: loading model\n", __func__);
726
-
727
- const int64_t t_start_us = ggml_time_us();
728
-
729
- wctx.t_start_us = t_start_us;
730
-
731
- auto & model = wctx.model;
732
- auto & vocab = wctx.vocab;
733
-
734
- // verify magic
735
- {
736
- uint32_t magic;
737
- read_safe(loader, magic);
738
- if (magic != 0x67676d6c) {
739
- fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
740
- return false;
741
- }
742
- }
743
-
744
- //load hparams
745
- {
746
- auto & hparams = model.hparams;
747
-
748
- read_safe(loader, hparams.n_vocab);
749
- read_safe(loader, hparams.n_audio_ctx);
750
- read_safe(loader, hparams.n_audio_state);
751
- read_safe(loader, hparams.n_audio_head);
752
- read_safe(loader, hparams.n_audio_layer);
753
- read_safe(loader, hparams.n_text_ctx);
754
- read_safe(loader, hparams.n_text_state);
755
- read_safe(loader, hparams.n_text_head);
756
- read_safe(loader, hparams.n_text_layer);
757
- read_safe(loader, hparams.n_mels);
758
- read_safe(loader, hparams.f16);
759
-
760
- assert(hparams.n_text_state == hparams.n_audio_state);
761
-
762
- if (hparams.n_audio_layer == 4) {
763
- model.type = e_model::MODEL_TINY;
764
- }
765
-
766
- if (hparams.n_audio_layer == 6) {
767
- model.type = e_model::MODEL_BASE;
768
- }
769
-
770
- if (hparams.n_audio_layer == 12) {
771
- model.type = e_model::MODEL_SMALL;
772
- }
773
-
774
- if (hparams.n_audio_layer == 24) {
775
- model.type = e_model::MODEL_MEDIUM;
776
- }
777
-
778
- if (hparams.n_audio_layer == 32) {
779
- model.type = e_model::MODEL_LARGE;
780
- }
781
-
782
- // for the big tensors, we have the option to store the data in 16-bit floats
783
- // in order to save memory and also to speed up the computation
784
- wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
785
-
786
- const size_t scale = model.hparams.f16 ? 1 : 2;
787
-
788
- fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
789
- fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
790
- fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
791
- fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
792
- fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
793
- fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
794
- fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
795
- fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
796
- fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
797
- fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
798
- fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
799
- fprintf(stderr, "%s: type = %d\n", __func__, model.type);
800
-
801
- // print memory requirements
802
- {
803
- // this is the total memory required to run the inference
804
- const size_t mem_required =
805
- MEM_REQ_SCRATCH0.at (model.type) +
806
- MEM_REQ_SCRATCH1.at (model.type) +
807
- MEM_REQ_SCRATCH2.at (model.type) +
808
- MEM_REQ_SCRATCH3.at (model.type) +
809
- scale*MEM_REQ_MODEL.at (model.type) +
810
- scale*MEM_REQ_KV_CROSS.at(model.type) +
811
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
812
-
813
- // this is the memory required by one decoder
814
- const size_t mem_required_decoder =
815
- scale*MEM_REQ_KV_SELF.at(model.type);
816
-
817
- fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
818
- mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
819
- }
820
-
821
- // initialize all memory buffers
822
- // always have at least one decoder
823
-
824
- wctx.model.buf = new std::vector<uint8_t>();
825
- wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
826
-
827
- if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
828
- fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
829
- return false;
830
- }
831
-
832
- {
833
- const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
834
- fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
835
- }
836
-
837
- if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
838
- fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
839
- return false;
840
- }
841
-
842
- {
843
- const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
844
- fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
845
- }
846
-
847
- wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
848
-
849
- wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
850
- wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
851
- wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
852
- wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
853
- }
854
-
855
- // load mel filters
856
- {
857
- auto & filters = wctx.model.filters;
858
-
859
- read_safe(loader, filters.n_mel);
860
- read_safe(loader, filters.n_fft);
861
-
862
- filters.data.resize(filters.n_mel * filters.n_fft);
863
- loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
864
- BYTESWAP_FILTERS(filters);
865
- }
866
-
867
- // load vocab
868
- {
869
- int32_t n_vocab = 0;
870
- read_safe(loader, n_vocab);
871
-
872
- //if (n_vocab != model.hparams.n_vocab) {
873
- // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
874
- // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
875
- // return false;
876
- //}
877
-
878
- std::string word;
879
- std::vector<char> tmp;
880
-
881
- tmp.reserve(128);
882
-
883
- for (int i = 0; i < n_vocab; i++) {
884
- uint32_t len;
885
- read_safe(loader, len);
886
-
887
- if (len > 0) {
888
- tmp.resize(len);
889
- loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
890
- word.assign(&tmp[0], tmp.size());
891
- } else {
892
- // seems like we have an empty-string token in multi-language models (i = 50256)
893
- //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
894
- word = "";
895
- }
896
-
897
- vocab.token_to_id[word] = i;
898
- vocab.id_to_token[i] = word;
899
-
900
- //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
901
- }
902
-
903
- vocab.n_vocab = model.hparams.n_vocab;
904
- if (vocab.is_multilingual()) {
905
- vocab.token_eot++;
906
- vocab.token_sot++;
907
- vocab.token_prev++;
908
- vocab.token_solm++;
909
- vocab.token_not++;
910
- vocab.token_beg++;
911
- }
912
-
913
- if (n_vocab < model.hparams.n_vocab) {
914
- fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
915
- for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
916
- if (i > vocab.token_beg) {
917
- word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
918
- } else if (i == vocab.token_eot) {
919
- word = "[_EOT_]";
920
- } else if (i == vocab.token_sot) {
921
- word = "[_SOT_]";
922
- } else if (i == vocab.token_prev) {
923
- word = "[_PREV_]";
924
- } else if (i == vocab.token_not) {
925
- word = "[_NOT_]";
926
- } else if (i == vocab.token_beg) {
927
- word = "[_BEG_]";
928
- } else {
929
- word = "[_extra_token_" + std::to_string(i) + "]";
930
- }
931
- vocab.token_to_id[word] = i;
932
- vocab.id_to_token[i] = word;
933
- }
934
- }
935
-
936
- wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
937
-
938
- wctx.logits_id.reserve(n_vocab);
939
-
940
- // TAGS: WHISPER_DECODER_INIT
941
- wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
942
-
943
- wctx.decoders[0].probs.reserve (vocab.n_vocab);
944
- wctx.decoders[0].logits.reserve (vocab.n_vocab);
945
- wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
946
- }
947
-
948
- size_t ctx_size = 0;
949
-
950
- const ggml_type wtype = wctx.wtype;
951
-
952
- {
953
- const auto & hparams = model.hparams;
954
-
955
- const int n_vocab = hparams.n_vocab;
956
-
957
- const int n_audio_ctx = hparams.n_audio_ctx;
958
- const int n_audio_state = hparams.n_audio_state;
959
- const int n_audio_layer = hparams.n_audio_layer;
960
-
961
- const int n_text_ctx = hparams.n_text_ctx;
962
- const int n_text_state = hparams.n_text_state;
963
- const int n_text_layer = hparams.n_text_layer;
964
-
965
- const int n_mels = hparams.n_mels;
966
-
967
- // encoder
968
- {
969
- ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
970
-
971
- ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
972
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
973
-
974
- ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
975
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
976
-
977
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
978
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
979
- }
980
-
981
- // decoder
982
- {
983
- ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
984
-
985
- ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
986
-
987
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
988
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
989
- }
990
-
991
- // encoder layers
992
- {
993
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
994
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
995
-
996
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
997
- ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
998
-
999
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
1000
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
1001
-
1002
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
1003
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
1004
-
1005
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
1006
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
1007
-
1008
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
1009
-
1010
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
1011
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
1012
-
1013
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
1014
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
1015
- }
1016
-
1017
- // decoder layers
1018
- {
1019
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
1020
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
1021
-
1022
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
1023
- ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
1024
-
1025
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
1026
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
1027
-
1028
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
1029
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
1030
-
1031
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
1032
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
1033
-
1034
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
1035
-
1036
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
1037
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
1038
-
1039
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
1040
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
1041
- //
1042
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
1043
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
1044
-
1045
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
1046
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
1047
-
1048
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
1049
-
1050
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
1051
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
1052
-
1053
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
1054
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
1055
- }
1056
-
1057
- ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
1058
-
1059
- fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
1060
- }
1061
-
1062
- // create the ggml context
1063
- {
1064
- struct ggml_init_params params;
1065
- params.mem_size = wctx.model.buf->size();
1066
- params.mem_buffer = wctx.model.buf->data();
1067
-
1068
- model.ctx = ggml_init(params);
1069
- if (!model.ctx) {
1070
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
1071
- return false;
1072
- }
1073
- }
1074
-
1075
- // prepare memory for the weights
1076
- {
1077
- auto & ctx = model.ctx;
1078
-
1079
- const auto & hparams = model.hparams;
1080
-
1081
- const int n_vocab = hparams.n_vocab;
1082
-
1083
- const int n_audio_ctx = hparams.n_audio_ctx;
1084
- const int n_audio_state = hparams.n_audio_state;
1085
- const int n_audio_layer = hparams.n_audio_layer;
1086
-
1087
- const int n_text_ctx = hparams.n_text_ctx;
1088
- const int n_text_state = hparams.n_text_state;
1089
- const int n_text_layer = hparams.n_text_layer;
1090
-
1091
- const int n_mels = hparams.n_mels;
1092
-
1093
- model.layers_encoder.resize(n_audio_layer);
1094
- model.layers_decoder.resize(n_text_layer);
1095
-
1096
- // encoder
1097
- {
1098
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1099
-
1100
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
1101
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1102
-
1103
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
1104
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1105
-
1106
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1107
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1108
-
1109
- // map by name
1110
- model.tensors["encoder.positional_embedding"] = model.e_pe;
1111
-
1112
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
1113
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
1114
-
1115
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
1116
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
1117
-
1118
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
1119
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
1120
-
1121
- for (int i = 0; i < n_audio_layer; ++i) {
1122
- auto & layer = model.layers_encoder[i];
1123
-
1124
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1125
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1126
-
1127
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
1128
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
1129
-
1130
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
1131
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1132
-
1133
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1134
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1135
-
1136
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1137
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1138
-
1139
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1140
-
1141
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1142
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1143
-
1144
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1145
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1146
-
1147
- // map by name
1148
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1149
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1150
-
1151
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1152
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1153
-
1154
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1155
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1156
-
1157
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1158
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1159
-
1160
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1161
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1162
-
1163
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1164
-
1165
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1166
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1167
-
1168
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1169
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1170
- }
1171
- }
1172
-
1173
- // decoder
1174
- {
1175
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
1176
-
1177
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
1178
-
1179
- model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1180
- model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1181
-
1182
- // map by name
1183
- model.tensors["decoder.positional_embedding"] = model.d_pe;
1184
-
1185
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
1186
-
1187
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
1188
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
1189
-
1190
- for (int i = 0; i < n_text_layer; ++i) {
1191
- auto & layer = model.layers_decoder[i];
1192
-
1193
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1194
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1195
-
1196
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
1197
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
1198
-
1199
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
1200
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1201
-
1202
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1203
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1204
-
1205
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1206
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1207
-
1208
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1209
-
1210
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1211
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1212
-
1213
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1214
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1215
-
1216
- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1217
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1218
-
1219
- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1220
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1221
-
1222
- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1223
-
1224
- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1225
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1226
-
1227
- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1228
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1229
-
1230
- // map by name
1231
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1232
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1233
-
1234
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1235
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1236
-
1237
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1238
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1239
-
1240
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1241
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1242
-
1243
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1244
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1245
-
1246
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1247
-
1248
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1249
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1250
-
1251
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1252
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1253
-
1254
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
1255
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
1256
-
1257
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
1258
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
1259
-
1260
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
1261
-
1262
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
1263
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
1264
-
1265
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
1266
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
1267
- }
1268
- }
1269
- }
1270
-
1271
- // load weights
1272
- {
1273
- size_t total_size = 0;
1274
-
1275
- model.n_loaded = 0;
1276
-
1277
- while (true) {
1278
- int32_t n_dims;
1279
- int32_t length;
1280
- int32_t ftype;
1281
-
1282
- read_safe(loader, n_dims);
1283
- read_safe(loader, length);
1284
- read_safe(loader, ftype);
1285
-
1286
- if (loader->eof(loader->context)) {
1287
- break;
1288
- }
1289
-
1290
- int32_t nelements = 1;
1291
- int32_t ne[3] = { 1, 1, 1 };
1292
- for (int i = 0; i < n_dims; ++i) {
1293
- read_safe(loader, ne[i]);
1294
- nelements *= ne[i];
1295
- }
1296
-
1297
- std::string name;
1298
- std::vector<char> tmp(length); // create a buffer
1299
- loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
1300
- name.assign(&tmp[0], tmp.size());
1301
-
1302
- if (model.tensors.find(name) == model.tensors.end()) {
1303
- fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
1304
- return false;
1305
- }
1306
-
1307
- auto tensor = model.tensors[name.data()];
1308
- if (ggml_nelements(tensor) != nelements) {
1309
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1310
- return false;
1311
- }
1312
-
1313
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1314
- fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1315
- __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
1316
- return false;
1317
- }
1318
-
1319
- const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
1320
-
1321
- if (nelements*bpe != ggml_nbytes(tensor)) {
1322
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1323
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1324
- return false;
1325
- }
1326
-
1327
- loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1328
- BYTESWAP_TENSOR(tensor);
1329
-
1330
- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1331
- total_size += ggml_nbytes(tensor);
1332
- model.n_loaded++;
1333
- }
1334
-
1335
- fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
1336
-
1337
- if (model.n_loaded == 0) {
1338
- fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1339
- } else if (model.n_loaded != (int) model.tensors.size()) {
1340
- fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
1341
- return false;
1342
- }
1343
- }
1344
-
1345
- wctx.rng = std::mt19937(0);
1346
-
1347
- wctx.t_load_us = ggml_time_us() - t_start_us;
1348
-
1349
- return true;
1350
- }
1351
-
1352
- // evaluate the encoder
1353
- //
1354
- // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1355
- // part of the transformer model and returns the encoded features
1356
- //
1357
- // - model: the model
1358
- // - n_threads: number of threads to use
1359
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1360
- //
1361
- static bool whisper_encode(
1362
- whisper_context & wctx,
1363
- const int mel_offset,
1364
- const int n_threads) {
1365
- const int64_t t_start_us = ggml_time_us();
1366
-
1367
- const auto & model = wctx.model;
1368
- const auto & mel_inp = wctx.mel;
1369
- const auto & hparams = model.hparams;
1370
-
1371
- const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
1372
- const int n_state = hparams.n_audio_state;
1373
- const int n_head = hparams.n_audio_head;
1374
- const int n_layer = hparams.n_audio_layer;
1375
-
1376
- const int n_mels = hparams.n_mels;
1377
- assert(mel_inp.n_mel == n_mels);
1378
-
1379
- struct ggml_init_params params;
1380
- params.mem_size = wctx.buf_compute.size();
1381
- params.mem_buffer = wctx.buf_compute.data();
1382
-
1383
- struct ggml_context * ctx0 = ggml_init(params);
1384
-
1385
- wctx.use_buf(ctx0, 0);
1386
-
1387
- struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1388
- assert(mel->type == GGML_TYPE_F32);
1389
- {
1390
- float * dst = (float *) mel->data;
1391
- memset(dst, 0, ggml_nbytes(mel));
1392
-
1393
- const int i0 = std::min(mel_offset, mel_inp.n_len);
1394
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1395
-
1396
- for (int j = 0; j < mel_inp.n_mel; ++j) {
1397
- for (int i = i0; i < i1; ++i) {
1398
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1399
- }
1400
- }
1401
- }
1402
-
1403
- struct ggml_tensor * cur;
1404
-
1405
- // convolution + gelu
1406
- {
1407
- wctx.use_buf(ctx0, 1);
1408
-
1409
- cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1410
- cur = ggml_add(ctx0,
1411
- ggml_repeat(ctx0,
1412
- model.e_conv_1_b,
1413
- cur),
1414
- cur);
1415
-
1416
- cur = ggml_gelu(ctx0, cur);
1417
-
1418
- wctx.use_buf(ctx0, 0);
1419
-
1420
- cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1421
- cur = ggml_add(ctx0,
1422
- ggml_repeat(ctx0,
1423
- model.e_conv_2_b,
1424
- cur),
1425
- cur);
1426
-
1427
- cur = ggml_gelu(ctx0, cur);
1428
- }
1429
-
1430
- wctx.use_buf(ctx0, 3);
1431
-
1432
- // ===================================================================
1433
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
1434
- //static int iter = -1;
1435
- //const int n_iter = 1500/n_ctx;
1436
-
1437
- //iter = (iter + 1) % n_iter;
1438
-
1439
- //if (iter == 0) {
1440
- // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1441
- // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1442
- //}
1443
-
1444
- static int iter = 0;
1445
-
1446
- const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1447
- const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1448
-
1449
- struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1450
-
1451
- cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
1452
-
1453
- // ===================================================================
1454
-
1455
- // original:
1456
- //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1457
-
1458
- struct ggml_tensor * inpL = cur;
1459
-
1460
- for (int il = 0; il < n_layer; ++il) {
1461
- const auto & layer = model.layers_encoder[il];
1462
-
1463
- // norm
1464
- {
1465
- wctx.use_buf(ctx0, 0);
1466
-
1467
- cur = ggml_norm(ctx0, inpL);
1468
-
1469
- // cur = ln_0_w*cur + ln_0_b
1470
- cur = ggml_add(ctx0,
1471
- ggml_mul(ctx0,
1472
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1473
- cur),
1474
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1475
- }
1476
-
1477
- // self-attention
1478
- {
1479
- wctx.use_buf(ctx0, 1);
1480
-
1481
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1482
- layer.attn_q_w,
1483
- cur);
1484
-
1485
- Qcur = ggml_add(ctx0,
1486
- ggml_repeat(ctx0,
1487
- layer.attn_q_b,
1488
- Qcur),
1489
- Qcur);
1490
-
1491
- //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1492
-
1493
- // note: no bias for Key
1494
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1495
- layer.attn_k_w,
1496
- cur);
1497
-
1498
- //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1499
-
1500
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1501
- layer.attn_v_w,
1502
- cur);
1503
-
1504
- Vcur = ggml_add(ctx0,
1505
- ggml_repeat(ctx0,
1506
- layer.attn_v_b,
1507
- Vcur),
1508
- Vcur);
1509
-
1510
- // ------
1511
-
1512
- wctx.use_buf(ctx0, 0);
1513
-
1514
- #ifdef WHISPER_USE_FLASH_ATTN
1515
- struct ggml_tensor * Q =
1516
- ggml_permute(ctx0,
1517
- ggml_cpy(ctx0,
1518
- Qcur,
1519
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1520
- 0, 2, 1, 3);
1521
-
1522
- struct ggml_tensor * K =
1523
- ggml_permute(ctx0,
1524
- ggml_cpy(ctx0,
1525
- Kcur,
1526
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1527
- 0, 2, 1, 3);
1528
-
1529
- struct ggml_tensor * V =
1530
- ggml_cpy(ctx0,
1531
- ggml_permute(ctx0,
1532
- ggml_reshape_3d(ctx0,
1533
- Vcur,
1534
- n_state/n_head, n_head, n_ctx),
1535
- 1, 2, 0, 3),
1536
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
1537
- );
1538
-
1539
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
1540
- #else
1541
- struct ggml_tensor * Q =
1542
- ggml_permute(ctx0,
1543
- ggml_cpy(ctx0,
1544
- Qcur,
1545
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1546
- 0, 2, 1, 3);
1547
-
1548
- struct ggml_tensor * K =
1549
- ggml_permute(ctx0,
1550
- ggml_cpy(ctx0,
1551
- Kcur,
1552
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1553
- 0, 2, 1, 3);
1554
-
1555
- // K * Q
1556
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1557
-
1558
- struct ggml_tensor * KQ_scaled =
1559
- ggml_scale(ctx0,
1560
- KQ,
1561
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1562
- );
1563
-
1564
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
1565
-
1566
- //struct ggml_tensor * V_trans =
1567
- // ggml_permute(ctx0,
1568
- // ggml_cpy(ctx0,
1569
- // Vcur,
1570
- // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1571
- // 1, 2, 0, 3);
1572
-
1573
- //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1574
-
1575
- struct ggml_tensor * V =
1576
- ggml_cpy(ctx0,
1577
- ggml_permute(ctx0,
1578
- ggml_reshape_3d(ctx0,
1579
- Vcur,
1580
- n_state/n_head, n_head, n_ctx),
1581
- 0, 2, 1, 3),
1582
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
1583
- );
1584
-
1585
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
1586
- #endif
1587
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1588
-
1589
- wctx.use_buf(ctx0, 1);
1590
-
1591
- cur = ggml_cpy(ctx0,
1592
- KQV_merged,
1593
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1594
- }
1595
-
1596
- // projection
1597
- {
1598
- wctx.use_buf(ctx0, 0);
1599
-
1600
- cur = ggml_mul_mat(ctx0,
1601
- layer.attn_ln_1_w,
1602
- cur);
1603
-
1604
- wctx.use_buf(ctx0, 1);
1605
-
1606
- cur = ggml_add(ctx0,
1607
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1608
- cur);
1609
- }
1610
-
1611
- wctx.use_buf(ctx0, 2);
1612
-
1613
- // add the input
1614
- cur = ggml_add(ctx0, cur, inpL);
1615
-
1616
- struct ggml_tensor * inpFF = cur;
1617
-
1618
- // feed-forward network
1619
- {
1620
- // norm
1621
- {
1622
- wctx.use_buf(ctx0, 0);
1623
-
1624
- cur = ggml_norm(ctx0, inpFF);
1625
-
1626
- wctx.use_buf(ctx0, 1);
1627
-
1628
- // cur = mlp_ln_w*cur + mlp_ln_b
1629
- cur = ggml_add(ctx0,
1630
- ggml_mul(ctx0,
1631
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1632
- cur),
1633
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1634
- }
1635
-
1636
- #ifdef WHISPER_USE_FLASH_FF
1637
- wctx.use_buf(ctx0, 0);
1638
-
1639
- cur = ggml_flash_ff(ctx0,
1640
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
1641
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1642
- #else
1643
- wctx.use_buf(ctx0, 0);
1644
-
1645
- // fully connected
1646
- cur = ggml_mul_mat(ctx0,
1647
- layer.mlp_0_w,
1648
- cur);
1649
-
1650
- wctx.use_buf(ctx0, 1);
1651
-
1652
- cur = ggml_add(ctx0,
1653
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
1654
- cur);
1655
-
1656
- wctx.use_buf(ctx0, 0);
1657
-
1658
- // GELU activation
1659
- cur = ggml_gelu(ctx0, cur);
1660
-
1661
- wctx.use_buf(ctx0, 1);
1662
-
1663
- // projection
1664
- cur = ggml_mul_mat(ctx0,
1665
- layer.mlp_1_w,
1666
- cur);
1667
-
1668
- wctx.use_buf(ctx0, 0);
1669
-
1670
- cur = ggml_add(ctx0,
1671
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
1672
- cur);
1673
- #endif
1674
- }
1675
-
1676
- wctx.use_buf(ctx0, 3);
1677
-
1678
- inpL = ggml_add(ctx0, cur, inpFF);
1679
- }
1680
-
1681
- cur = inpL;
1682
-
1683
- // norm
1684
- {
1685
- wctx.use_buf(ctx0, 0);
1686
-
1687
- cur = ggml_norm(ctx0, cur);
1688
-
1689
- wctx.use_buf(ctx0, 1);
1690
-
1691
- // cur = ln_f_g*cur + ln_f_b
1692
- cur = ggml_add(ctx0,
1693
- ggml_mul(ctx0,
1694
- ggml_repeat(ctx0, model.e_ln_w, cur),
1695
- cur),
1696
- ggml_repeat(ctx0, model.e_ln_b, cur));
1697
- }
1698
-
1699
- wctx.use_buf(ctx0, -1);
1700
-
1701
- // run the computation
1702
- {
1703
- struct ggml_cgraph gf = {};
1704
- gf.n_threads = n_threads;
1705
-
1706
- ggml_build_forward_expand(&gf, cur);
1707
- ggml_graph_compute (ctx0, &gf);
1708
-
1709
- //ggml_graph_print(&gf);
1710
- }
1711
-
1712
- // cur
1713
- //{
1714
- // printf("ne0 = %d\n", cur->ne[0]);
1715
- // printf("ne1 = %d\n", cur->ne[1]);
1716
- // for (int i = 0; i < 10; ++i) {
1717
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1718
- // }
1719
- // printf("... ");
1720
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1721
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1722
- // }
1723
- // printf("\n");
1724
- //}
1725
-
1726
- // pre-compute cross-attention memory
1727
- {
1728
- struct ggml_cgraph gf = {};
1729
- gf.n_threads = n_threads;
1730
-
1731
- // TODO: hack to disconnect the encoded features from the previous graph
1732
- cur->op = GGML_OP_NONE;
1733
- cur->src0 = nullptr;
1734
- cur->src1 = nullptr;
1735
-
1736
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1737
- auto & layer = model.layers_decoder[il];
1738
-
1739
- wctx.use_buf(ctx0, 0);
1740
-
1741
- struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1742
- layer.cross_attn_k_w,
1743
- cur);
1744
-
1745
- Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1746
-
1747
- wctx.use_buf(ctx0, 1);
1748
-
1749
- struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1750
- layer.cross_attn_v_w,
1751
- cur);
1752
-
1753
- Vcross = ggml_add(ctx0,
1754
- ggml_repeat(ctx0,
1755
- layer.cross_attn_v_b,
1756
- Vcross),
1757
- Vcross);
1758
-
1759
- wctx.use_buf(ctx0, -1);
1760
-
1761
- //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1762
- //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1763
- struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
1764
- struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
1765
-
1766
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1767
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1768
- }
1769
-
1770
- ggml_graph_compute(ctx0, &gf);
1771
- //ggml_graph_print(&gf);
1772
- }
1773
-
1774
- ////////////////////////////////////////////////////////////////////////////
1775
-
1776
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1777
- // ggml_used_mem(ctx0)/1024.0/1024.0,
1778
- // wctx.get_buf_max_mem(0)/1024.0/1024.0,
1779
- // wctx.get_buf_max_mem(1)/1024.0/1024.0,
1780
- // wctx.get_buf_max_mem(2)/1024.0/1024.0,
1781
- // wctx.get_buf_max_mem(3)/1024.0/1024.0);
1782
-
1783
- ggml_free(ctx0);
1784
-
1785
- wctx.t_encode_us += ggml_time_us() - t_start_us;
1786
- wctx.n_encode++;
1787
-
1788
- return true;
1789
- }
1790
-
1791
- // evaluate the decoder
1792
- //
1793
- // given text prompt + audio features -> computes the logits for the next token
1794
- //
1795
- // - model: the model
1796
- // - n_threads: number of threads to use
1797
- // - tokens: text prompt
1798
- // - n_tokens: number of tokens in the prompt
1799
- // - n_past: number of past tokens to prefix the prompt with
1800
- //
1801
- static bool whisper_decode(
1802
- whisper_context & wctx,
1803
- whisper_decoder & decoder,
1804
- const whisper_token * tokens,
1805
- const int n_tokens,
1806
- const int n_past,
1807
- const int n_threads) {
1808
- const int64_t t_start_us = ggml_time_us();
1809
-
1810
- const auto & model = wctx.model;
1811
- const auto & hparams = model.hparams;
1812
-
1813
- auto & kv_self = decoder.kv_self;
1814
-
1815
- WHISPER_ASSERT(!!kv_self.ctx);
1816
-
1817
- auto & logits_out = wctx.logits;
1818
-
1819
- const int n_vocab = hparams.n_vocab;
1820
-
1821
- const int n_ctx = hparams.n_text_ctx;
1822
- const int n_state = hparams.n_text_state;
1823
- const int n_head = hparams.n_text_head;
1824
- const int n_layer = hparams.n_text_layer;
1825
-
1826
- const int N = n_tokens;
1827
- const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
1828
-
1829
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1830
-
1831
- struct ggml_init_params params;
1832
- params.mem_size = wctx.buf_compute.size();
1833
- params.mem_buffer = wctx.buf_compute.data();
1834
-
1835
- struct ggml_context * ctx0 = ggml_init(params);
1836
-
1837
- struct ggml_cgraph gf = {};
1838
- gf.n_threads = n_threads;
1839
-
1840
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1841
- memcpy(embd->data, tokens, N*ggml_element_size(embd));
1842
-
1843
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1844
- for (int i = 0; i < N; ++i) {
1845
- ((int32_t *) position->data)[i] = n_past + i;
1846
- }
1847
-
1848
- wctx.use_buf(ctx0, 3);
1849
-
1850
- // token encoding + position encoding
1851
- struct ggml_tensor * cur =
1852
- ggml_add(ctx0,
1853
- ggml_get_rows(ctx0, model.d_te, embd),
1854
- ggml_get_rows(ctx0, model.d_pe, position));
1855
-
1856
- struct ggml_tensor * inpL = cur;
1857
-
1858
- for (int il = 0; il < n_layer; ++il) {
1859
- const auto & layer = model.layers_decoder[il];
1860
-
1861
- // norm
1862
- {
1863
- wctx.use_buf(ctx0, 0);
1864
-
1865
- cur = ggml_norm(ctx0, inpL);
1866
-
1867
- // cur = ln_0_w*cur + ln_0_b
1868
- cur = ggml_add(ctx0,
1869
- ggml_mul(ctx0,
1870
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1871
- cur),
1872
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1873
- }
1874
-
1875
- // self-attention
1876
- {
1877
- wctx.use_buf(ctx0, 1);
1878
-
1879
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1880
- layer.attn_q_w,
1881
- cur);
1882
-
1883
- Qcur = ggml_add(ctx0,
1884
- ggml_repeat(ctx0,
1885
- layer.attn_q_b,
1886
- Qcur),
1887
- Qcur);
1888
-
1889
- Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1890
-
1891
- // note: no bias for Key
1892
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1893
- layer.attn_k_w,
1894
- cur);
1895
-
1896
- Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1897
-
1898
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1899
- layer.attn_v_w,
1900
- cur);
1901
-
1902
- Vcur = ggml_add(ctx0,
1903
- ggml_repeat(ctx0,
1904
- layer.attn_v_b,
1905
- Vcur),
1906
- Vcur);
1907
-
1908
- // store key and value to memory
1909
- {
1910
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
1911
- struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
1912
-
1913
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
1914
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1915
- }
1916
-
1917
- // ------
1918
-
1919
- wctx.use_buf(ctx0, 0);
1920
-
1921
- struct ggml_tensor * Q =
1922
- ggml_permute(ctx0,
1923
- ggml_cpy(ctx0,
1924
- Qcur,
1925
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1926
- 0, 2, 1, 3);
1927
-
1928
- struct ggml_tensor * K =
1929
- ggml_permute(ctx0,
1930
- ggml_reshape_3d(ctx0,
1931
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
1932
- n_state/n_head, n_head, n_past + N),
1933
- 0, 2, 1, 3);
1934
-
1935
- wctx.use_buf(ctx0, 1);
1936
-
1937
- // K * Q
1938
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1939
-
1940
- wctx.use_buf(ctx0, 0);
1941
-
1942
- //struct ggml_tensor * KQ_scaled =
1943
- // ggml_scale(ctx0,
1944
- // KQ,
1945
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1946
- // );
1947
-
1948
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1949
-
1950
- wctx.use_buf(ctx0, 1);
1951
-
1952
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1953
-
1954
- wctx.use_buf(ctx0, 0);
1955
-
1956
- struct ggml_tensor * V_trans =
1957
- ggml_permute(ctx0,
1958
- ggml_reshape_3d(ctx0,
1959
- ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1960
- n_state/n_head, n_head, n_past + N),
1961
- 1, 2, 0, 3);
1962
-
1963
- wctx.use_buf(ctx0, 1);
1964
-
1965
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1966
-
1967
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1968
-
1969
- cur = ggml_cpy(ctx0,
1970
- KQV_merged,
1971
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
1972
- }
1973
-
1974
- // projection
1975
- {
1976
- wctx.use_buf(ctx0, 0);
1977
-
1978
- cur = ggml_mul_mat(ctx0,
1979
- layer.attn_ln_1_w,
1980
- cur);
1981
-
1982
- wctx.use_buf(ctx0, 1);
1983
-
1984
- cur = ggml_add(ctx0,
1985
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1986
- cur);
1987
- }
1988
-
1989
- wctx.use_buf(ctx0, 2);
1990
-
1991
- // add the input
1992
- struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
1993
-
1994
- // norm
1995
- {
1996
- wctx.use_buf(ctx0, 0);
1997
-
1998
- cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1999
-
2000
- wctx.use_buf(ctx0, 1);
2001
-
2002
- // cur = ln_0_w*cur + ln_0_b
2003
- cur = ggml_add(ctx0,
2004
- ggml_mul(ctx0,
2005
- ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
2006
- cur),
2007
- ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
2008
- }
2009
-
2010
- // cross-attention
2011
- {
2012
- wctx.use_buf(ctx0, 0);
2013
-
2014
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
2015
- layer.cross_attn_q_w,
2016
- cur);
2017
-
2018
- Qcur = ggml_add(ctx0,
2019
- ggml_repeat(ctx0,
2020
- layer.cross_attn_q_b,
2021
- Qcur),
2022
- Qcur);
2023
-
2024
- Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2025
-
2026
- // Kcross is already scaled
2027
- struct ggml_tensor * Kcross =
2028
- ggml_reshape_3d(ctx0,
2029
- ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
2030
- n_state/n_head, n_head, M);
2031
-
2032
- struct ggml_tensor * Vcross =
2033
- ggml_reshape_3d(ctx0,
2034
- ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
2035
- n_state/n_head, n_head, M);
2036
-
2037
- struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
2038
-
2039
- // ------
2040
-
2041
- wctx.use_buf(ctx0, 1);
2042
-
2043
- struct ggml_tensor * Q =
2044
- ggml_permute(ctx0,
2045
- ggml_cpy(ctx0,
2046
- Qcur,
2047
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
2048
- 0, 2, 1, 3);
2049
-
2050
- struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2051
-
2052
- wctx.use_buf(ctx0, 0);
2053
-
2054
- // K * Q
2055
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2056
-
2057
- //struct ggml_tensor * KQ_scaled =
2058
- // ggml_scale(ctx0,
2059
- // KQ,
2060
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2061
- // );
2062
-
2063
- // no masking for cross-attention
2064
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2065
-
2066
- wctx.use_buf(ctx0, 1);
2067
-
2068
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2069
-
2070
- wctx.use_buf(ctx0, 0);
2071
-
2072
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2073
-
2074
- wctx.use_buf(ctx0, 1);
2075
-
2076
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2077
-
2078
- // cur = KQV_merged.contiguous().view(n_state, N)
2079
- cur = ggml_cpy(ctx0,
2080
- KQV_merged,
2081
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
2082
- }
2083
-
2084
- // projection
2085
- {
2086
- wctx.use_buf(ctx0, 0);
2087
-
2088
- cur = ggml_mul_mat(ctx0,
2089
- layer.cross_attn_ln_1_w,
2090
- cur);
2091
-
2092
- wctx.use_buf(ctx0, 1);
2093
-
2094
- cur = ggml_add(ctx0,
2095
- ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2096
- cur);
2097
- }
2098
-
2099
- wctx.use_buf(ctx0, 2);
2100
-
2101
- // add the input
2102
- cur = ggml_add(ctx0, cur, inpCA);
2103
-
2104
- struct ggml_tensor * inpFF = cur;
2105
-
2106
- // feed-forward network
2107
- {
2108
- // norm
2109
- {
2110
- wctx.use_buf(ctx0, 0);
2111
-
2112
- cur = ggml_norm(ctx0, inpFF);
2113
-
2114
- wctx.use_buf(ctx0, 1);
2115
-
2116
- // cur = mlp_ln_w*cur + mlp_ln_b
2117
- cur = ggml_add(ctx0,
2118
- ggml_mul(ctx0,
2119
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
2120
- cur),
2121
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2122
- }
2123
-
2124
- wctx.use_buf(ctx0, 0);
2125
-
2126
- // fully connected
2127
- cur = ggml_mul_mat(ctx0,
2128
- layer.mlp_0_w,
2129
- cur);
2130
-
2131
- wctx.use_buf(ctx0, 1);
2132
-
2133
- cur = ggml_add(ctx0,
2134
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
2135
- cur);
2136
-
2137
- wctx.use_buf(ctx0, 0);
2138
-
2139
- // GELU activation
2140
- cur = ggml_gelu(ctx0, cur);
2141
-
2142
- wctx.use_buf(ctx0, 1);
2143
-
2144
- // projection
2145
- cur = ggml_mul_mat(ctx0,
2146
- layer.mlp_1_w,
2147
- cur);
2148
-
2149
- wctx.use_buf(ctx0, 0);
2150
-
2151
- cur = ggml_add(ctx0,
2152
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
2153
- cur);
2154
- }
2155
-
2156
- wctx.use_buf(ctx0, 3);
2157
-
2158
- inpL = ggml_add(ctx0, cur, inpFF);
2159
- }
2160
-
2161
- cur = inpL;
2162
-
2163
- // norm
2164
- {
2165
- wctx.use_buf(ctx0, 0);
2166
-
2167
- cur = ggml_norm(ctx0, cur);
2168
-
2169
- wctx.use_buf(ctx0, 1);
2170
-
2171
- cur = ggml_add(ctx0,
2172
- ggml_mul(ctx0,
2173
- ggml_repeat(ctx0, model.d_ln_w, cur),
2174
- cur),
2175
- ggml_repeat(ctx0, model.d_ln_b, cur));
2176
- }
2177
-
2178
- wctx.use_buf(ctx0, 0);
2179
-
2180
- // compute logits only for the last token
2181
- // comment this line to compute logits for all N tokens
2182
- // might be useful in the future
2183
- cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2184
-
2185
- struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2186
-
2187
- wctx.use_buf(ctx0, -1);
2188
-
2189
- // run the computation
2190
- {
2191
- ggml_build_forward_expand(&gf, logits);
2192
- ggml_graph_compute (ctx0, &gf);
2193
- }
2194
-
2195
- // extract logits for all N tokens
2196
- //logits_out.resize(N*n_vocab);
2197
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
2198
-
2199
- // extract logits only for the last token
2200
- logits_out.resize(n_vocab);
2201
- memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2202
-
2203
- if (N > 1) {
2204
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2205
- // ggml_used_mem(ctx0)/1024.0/1024.0,
2206
- // wctx.get_buf_max_mem(0)/1024.0/1024.0,
2207
- // wctx.get_buf_max_mem(1)/1024.0/1024.0,
2208
- // wctx.get_buf_max_mem(2)/1024.0/1024.0,
2209
- // wctx.get_buf_max_mem(3)/1024.0/1024.0);
2210
- }
2211
-
2212
- ggml_free(ctx0);
2213
-
2214
- wctx.t_decode_us += ggml_time_us() - t_start_us;
2215
- wctx.n_decode++;
2216
-
2217
- return true;
2218
- }
2219
-
2220
- // 500 -> 00:05.000
2221
- // 6000 -> 01:00.000
2222
- static std::string to_timestamp(int64_t t, bool comma = false) {
2223
- int64_t msec = t * 10;
2224
- int64_t hr = msec / (1000 * 60 * 60);
2225
- msec = msec - hr * (1000 * 60 * 60);
2226
- int64_t min = msec / (1000 * 60);
2227
- msec = msec - min * (1000 * 60);
2228
- int64_t sec = msec / 1000;
2229
- msec = msec - sec * 1000;
2230
-
2231
- char buf[32];
2232
- snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
2233
-
2234
- return std::string(buf);
2235
- }
2236
-
2237
- // naive Discrete Fourier Transform
2238
- // input is real-valued
2239
- // output is complex-valued
2240
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
2241
- int N = in.size();
2242
-
2243
- out.resize(N*2);
2244
-
2245
- for (int k = 0; k < N; k++) {
2246
- float re = 0;
2247
- float im = 0;
2248
-
2249
- for (int n = 0; n < N; n++) {
2250
- float angle = 2*M_PI*k*n/N;
2251
- re += in[n]*cos(angle);
2252
- im -= in[n]*sin(angle);
2253
- }
2254
-
2255
- out[k*2 + 0] = re;
2256
- out[k*2 + 1] = im;
2257
- }
2258
- }
2259
-
2260
- // Cooley-Tukey FFT
2261
- // poor man's implementation - use something better
2262
- // input is real-valued
2263
- // output is complex-valued
2264
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
2265
- out.resize(in.size()*2);
2266
-
2267
- int N = in.size();
2268
-
2269
- if (N == 1) {
2270
- out[0] = in[0];
2271
- out[1] = 0;
2272
- return;
2273
- }
2274
-
2275
- if (N%2 == 1) {
2276
- dft(in, out);
2277
- return;
2278
- }
2279
-
2280
- std::vector<float> even;
2281
- std::vector<float> odd;
2282
-
2283
- even.reserve(N/2);
2284
- odd.reserve(N/2);
2285
-
2286
- for (int i = 0; i < N; i++) {
2287
- if (i % 2 == 0) {
2288
- even.push_back(in[i]);
2289
- } else {
2290
- odd.push_back(in[i]);
2291
- }
2292
- }
2293
-
2294
- std::vector<float> even_fft;
2295
- std::vector<float> odd_fft;
2296
-
2297
- fft(even, even_fft);
2298
- fft(odd, odd_fft);
2299
-
2300
- for (int k = 0; k < N/2; k++) {
2301
- float theta = 2*M_PI*k/N;
2302
-
2303
- float re = cos(theta);
2304
- float im = -sin(theta);
2305
-
2306
- float re_odd = odd_fft[2*k + 0];
2307
- float im_odd = odd_fft[2*k + 1];
2308
-
2309
- out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2310
- out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2311
-
2312
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2313
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2314
- }
2315
- }
2316
-
2317
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2318
- static bool log_mel_spectrogram(
2319
- whisper_context & wctx,
2320
- const float * samples,
2321
- const int n_samples,
2322
- const int /*sample_rate*/,
2323
- const int fft_size,
2324
- const int fft_step,
2325
- const int n_mel,
2326
- const int n_threads,
2327
- const whisper_filters & filters,
2328
- const bool speed_up,
2329
- whisper_mel & mel) {
2330
- const int64_t t_start_us = ggml_time_us();
2331
-
2332
- // Hanning window
2333
- std::vector<float> hann;
2334
- hann.resize(fft_size);
2335
- for (int i = 0; i < fft_size; i++) {
2336
- hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2337
- }
2338
-
2339
- mel.n_mel = n_mel;
2340
- mel.n_len = (n_samples)/fft_step;
2341
- mel.data.resize(mel.n_mel*mel.n_len);
2342
-
2343
- const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
2344
-
2345
- //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2346
- //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2347
-
2348
- std::vector<std::thread> workers(n_threads);
2349
- for (int iw = 0; iw < n_threads; ++iw) {
2350
- workers[iw] = std::thread([&](int ith) {
2351
- std::vector<float> fft_in;
2352
- fft_in.resize(fft_size);
2353
- for (int i = 0; i < fft_size; i++) {
2354
- fft_in[i] = 0.0;
2355
- }
2356
-
2357
- std::vector<float> fft_out;
2358
- fft_out.resize(2*fft_size);
2359
-
2360
- for (int i = ith; i < mel.n_len; i += n_threads) {
2361
- const int offset = i*fft_step;
2362
-
2363
- // apply Hanning window
2364
- for (int j = 0; j < fft_size; j++) {
2365
- if (offset + j < n_samples) {
2366
- fft_in[j] = hann[j]*samples[offset + j];
2367
- } else {
2368
- fft_in[j] = 0.0;
2369
- }
2370
- }
2371
-
2372
- // FFT -> mag^2
2373
- fft(fft_in, fft_out);
2374
-
2375
- for (int j = 0; j < fft_size; j++) {
2376
- fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2377
- }
2378
- for (int j = 1; j < fft_size/2; j++) {
2379
- //if (i == 0) {
2380
- // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2381
- //}
2382
- fft_out[j] += fft_out[fft_size - j];
2383
- }
2384
- if (i == 0) {
2385
- //for (int j = 0; j < fft_size; j++) {
2386
- // printf("%d: %e\n", j, fft_out[j]);
2387
- //}
2388
- }
2389
-
2390
- if (speed_up) {
2391
- // scale down in the frequency domain results in a speed up in the time domain
2392
- for (int j = 0; j < n_fft; j++) {
2393
- fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
2394
- }
2395
- }
2396
-
2397
- // mel spectrogram
2398
- for (int j = 0; j < mel.n_mel; j++) {
2399
- double sum = 0.0;
2400
-
2401
- for (int k = 0; k < n_fft; k++) {
2402
- sum += fft_out[k]*filters.data[j*n_fft + k];
2403
- }
2404
- if (sum < 1e-10) {
2405
- sum = 1e-10;
2406
- }
2407
-
2408
- sum = log10(sum);
2409
-
2410
- mel.data[j*mel.n_len + i] = sum;
2411
- }
2412
- }
2413
- }, iw);
2414
- }
2415
-
2416
- for (int iw = 0; iw < n_threads; ++iw) {
2417
- workers[iw].join();
2418
- }
2419
-
2420
- // clamping and normalization
2421
- double mmax = -1e20;
2422
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2423
- if (mel.data[i] > mmax) {
2424
- mmax = mel.data[i];
2425
- }
2426
- }
2427
- //printf("%s: max = %f\n", __func__, mmax);
2428
-
2429
- mmax -= 8.0;
2430
-
2431
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2432
- if (mel.data[i] < mmax) {
2433
- mel.data[i] = mmax;
2434
- }
2435
-
2436
- mel.data[i] = (mel.data[i] + 4.0)/4.0;
2437
- }
2438
-
2439
- wctx.t_mel_us += ggml_time_us() - t_start_us;
2440
-
2441
- return true;
2442
- }
2443
-
2444
- // split text into tokens
2445
- //
2446
- // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
2447
- //
2448
- // Regex (Python):
2449
- // r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
2450
- //
2451
- // Regex (C++):
2452
- // R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
2453
- //
2454
- static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
2455
- std::vector<std::string> words;
2456
-
2457
- // first split the text into words
2458
- {
2459
- std::string str = text;
2460
- std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
2461
-
2462
- std::regex re(pat);
2463
- std::smatch m;
2464
-
2465
- while (std::regex_search(str, m, re)) {
2466
- for (auto x : m) {
2467
- words.push_back(x);
2468
- }
2469
- str = m.suffix();
2470
- }
2471
- }
2472
-
2473
- // find the longest tokens that form the words:
2474
- std::vector<whisper_vocab::id> tokens;
2475
- for (const auto & word : words) {
2476
- if (word.empty()) continue;
2477
-
2478
- int i = 0;
2479
- int n = word.size();
2480
- while (i < n) {
2481
- int j = n;
2482
- while (j > i) {
2483
- auto it = vocab.token_to_id.find(word.substr(i, j-i));
2484
- if (it != vocab.token_to_id.end()) {
2485
- tokens.push_back(it->second);
2486
- i = j;
2487
- break;
2488
- }
2489
- --j;
2490
- }
2491
- if (i == n) {
2492
- break;
2493
- }
2494
- if (j == i) {
2495
- auto sub = word.substr(i, 1);
2496
- if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
2497
- tokens.push_back(vocab.token_to_id.at(sub));
2498
- } else {
2499
- fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
2500
- }
2501
- ++i;
2502
- }
2503
- }
2504
- }
2505
-
2506
- return tokens;
2507
- }
2508
-
2509
- //
2510
- // interface implementation
2511
- //
2512
-
2513
- struct whisper_context * whisper_init_from_file(const char * path_model) {
2514
- whisper_model_loader loader = {};
2515
-
2516
- fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
2517
-
2518
- auto fin = std::ifstream(path_model, std::ios::binary);
2519
- if (!fin) {
2520
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
2521
- return nullptr;
2522
- }
2523
-
2524
- loader.context = &fin;
2525
- loader.read = [](void * ctx, void * output, size_t read_size) {
2526
- std::ifstream * fin = (std::ifstream*)ctx;
2527
- fin->read((char *)output, read_size);
2528
- return read_size;
2529
- };
2530
-
2531
- loader.eof = [](void * ctx) {
2532
- std::ifstream * fin = (std::ifstream*)ctx;
2533
- return fin->eof();
2534
- };
2535
-
2536
- loader.close = [](void * ctx) {
2537
- std::ifstream * fin = (std::ifstream*)ctx;
2538
- fin->close();
2539
- };
2540
-
2541
- return whisper_init(&loader);
2542
- }
2543
-
2544
- struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2545
- struct buf_context {
2546
- uint8_t* buffer;
2547
- size_t size;
2548
- size_t current_offset;
2549
- };
2550
-
2551
- buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
2552
- whisper_model_loader loader = {};
2553
-
2554
- fprintf(stderr, "%s: loading model from buffer\n", __func__);
2555
-
2556
- loader.context = &ctx;
2557
-
2558
- loader.read = [](void * ctx, void * output, size_t read_size) {
2559
- buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2560
-
2561
- size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
2562
-
2563
- memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
2564
- buf->current_offset += size_to_copy;
2565
-
2566
- return size_to_copy;
2567
- };
2568
-
2569
- loader.eof = [](void * ctx) {
2570
- buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2571
-
2572
- return buf->current_offset >= buf->size;
2573
- };
2574
-
2575
- loader.close = [](void * /*ctx*/) { };
2576
-
2577
- return whisper_init(&loader);
2578
- }
2579
-
2580
- struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2581
- ggml_time_init();
2582
-
2583
- whisper_context * ctx = new whisper_context;
2584
-
2585
- if (!whisper_model_load(loader, *ctx)) {
2586
- loader->close(loader->context);
2587
- fprintf(stderr, "%s: failed to load model\n", __func__);
2588
- delete ctx;
2589
- return nullptr;
2590
- }
2591
-
2592
- loader->close(loader->context);
2593
-
2594
- return ctx;
2595
- }
2596
-
2597
- void whisper_free(struct whisper_context * ctx) {
2598
- if (ctx) {
2599
- if (ctx->model.ctx) {
2600
- ggml_free(ctx->model.ctx);
2601
- }
2602
- if (ctx->model.buf) {
2603
- delete ctx->model.buf;
2604
- }
2605
- if (ctx->kv_cross.ctx) {
2606
- ggml_free(ctx->kv_cross.ctx);
2607
- }
2608
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
2609
- if (ctx->decoders[i].kv_self.ctx) {
2610
- ggml_free(ctx->decoders[i].kv_self.ctx);
2611
- }
2612
- }
2613
- delete ctx;
2614
- }
2615
- }
2616
-
2617
- int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2618
- if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
2619
- fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2620
- return -1;
2621
- }
2622
-
2623
- return 0;
2624
- }
2625
-
2626
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2627
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2628
- if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
2629
- fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2630
- return -1;
2631
- }
2632
-
2633
- return 0;
2634
- }
2635
-
2636
- int whisper_set_mel(
2637
- struct whisper_context * ctx,
2638
- const float * data,
2639
- int n_len,
2640
- int n_mel) {
2641
- if (n_mel != WHISPER_N_MEL) {
2642
- fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2643
- return -1;
2644
- }
2645
-
2646
- ctx->mel.n_len = n_len;
2647
- ctx->mel.n_mel = n_mel;
2648
-
2649
- ctx->mel.data.resize(n_len*n_mel);
2650
- memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
2651
-
2652
- return 0;
2653
- }
2654
-
2655
- int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2656
- if (!whisper_encode(*ctx, offset, n_threads)) {
2657
- fprintf(stderr, "%s: failed to eval\n", __func__);
2658
- return -1;
2659
- }
2660
-
2661
- return 0;
2662
- }
2663
-
2664
- int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2665
- // TODO: add selected_decoder_id to context
2666
- const int selected_decoder_id = 0;
2667
-
2668
- if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2669
- fprintf(stderr, "%s: failed to eval\n", __func__);
2670
- return 1;
2671
- }
2672
-
2673
- return 0;
2674
- }
2675
-
2676
- int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
2677
- const auto res = tokenize(ctx->vocab, text);
2678
-
2679
- if (n_max_tokens < (int) res.size()) {
2680
- fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
2681
- return -1;
2682
- }
2683
-
2684
- for (int i = 0; i < (int) res.size(); i++) {
2685
- tokens[i] = res[i];
2686
- }
2687
-
2688
- return res.size();
2689
- }
2690
-
2691
- int whisper_lang_max_id() {
2692
- auto max_id = 0;
2693
- for (const auto & kv : g_lang) {
2694
- max_id = std::max(max_id, kv.second.first);
2695
- }
2696
-
2697
- return max_id;
2698
- }
2699
-
2700
- int whisper_lang_id(const char * lang) {
2701
- if (!g_lang.count(lang)) {
2702
- for (const auto & kv : g_lang) {
2703
- if (kv.second.second == lang) {
2704
- return kv.second.first;
2705
- }
2706
- }
2707
-
2708
- fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2709
- return -1;
2710
- }
2711
-
2712
- return g_lang.at(lang).first;
2713
- }
2714
-
2715
- const char * whisper_lang_str(int id) {
2716
- for (const auto & kv : g_lang) {
2717
- if (kv.second.first == id) {
2718
- return kv.first.c_str();
2719
- }
2720
- }
2721
-
2722
- fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
2723
- return nullptr;
2724
- }
2725
-
2726
- int whisper_lang_auto_detect(
2727
- struct whisper_context * ctx,
2728
- int offset_ms,
2729
- int n_threads,
2730
- float * lang_probs) {
2731
- const int seek = offset_ms/10;
2732
-
2733
- if (seek < 0) {
2734
- fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
2735
- return -1;
2736
- }
2737
-
2738
- if (seek >= ctx->mel.n_len) {
2739
- fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
2740
- return -2;
2741
- }
2742
-
2743
- // run the encoder
2744
- if (whisper_encode(ctx, seek, n_threads) != 0) {
2745
- fprintf(stderr, "%s: failed to encode\n", __func__);
2746
- return -6;
2747
- }
2748
-
2749
- const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
2750
-
2751
- if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2752
- fprintf(stderr, "%s: failed to decode\n", __func__);
2753
- return -7;
2754
- }
2755
-
2756
- auto & logits_id = ctx->logits_id;
2757
- logits_id.clear();
2758
-
2759
- for (const auto & kv : g_lang) {
2760
- const auto token_lang = whisper_token_lang(ctx, kv.second.first);
2761
- logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
2762
- }
2763
-
2764
- // sort descending
2765
- {
2766
- using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
2767
- std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
2768
- return a.first > b.first;
2769
- });
2770
- }
2771
-
2772
- // softmax
2773
- {
2774
- const auto max = logits_id[0].first;
2775
-
2776
- double sum = 0.0f;
2777
- for (auto & kv : logits_id) {
2778
- kv.first = exp(kv.first - max);
2779
- sum += kv.first;
2780
- }
2781
-
2782
- for (auto & kv : logits_id) {
2783
- kv.first /= sum;
2784
- }
2785
- }
2786
-
2787
- {
2788
- for (const auto & prob : logits_id) {
2789
- if (lang_probs) {
2790
- lang_probs[prob.second] = prob.first;
2791
- }
2792
-
2793
- //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
2794
- }
2795
- }
2796
-
2797
- return logits_id[0].second;
2798
- }
2799
-
2800
- int whisper_n_len(struct whisper_context * ctx) {
2801
- return ctx->mel.n_len;
2802
- }
2803
-
2804
- int whisper_n_vocab(struct whisper_context * ctx) {
2805
- return ctx->vocab.n_vocab;
2806
- }
2807
-
2808
- int whisper_n_text_ctx(struct whisper_context * ctx) {
2809
- return ctx->model.hparams.n_text_ctx;
2810
- }
2811
-
2812
- int whisper_n_audio_ctx(struct whisper_context * ctx) {
2813
- return ctx->model.hparams.n_audio_ctx;
2814
- }
2815
-
2816
- int whisper_is_multilingual(struct whisper_context * ctx) {
2817
- return ctx->vocab.is_multilingual() ? 1 : 0;
2818
- }
2819
-
2820
- float * whisper_get_logits(struct whisper_context * ctx) {
2821
- return ctx->logits.data();
2822
- }
2823
-
2824
- const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
2825
- return ctx->vocab.id_to_token.at(token).c_str();
2826
- }
2827
-
2828
- whisper_token whisper_token_eot(struct whisper_context * ctx) {
2829
- return ctx->vocab.token_eot;
2830
- }
2831
-
2832
- whisper_token whisper_token_sot(struct whisper_context * ctx) {
2833
- return ctx->vocab.token_sot;
2834
- }
2835
-
2836
- whisper_token whisper_token_prev(struct whisper_context * ctx) {
2837
- return ctx->vocab.token_prev;
2838
- }
2839
-
2840
- whisper_token whisper_token_solm(struct whisper_context * ctx) {
2841
- return ctx->vocab.token_solm;
2842
- }
2843
-
2844
- whisper_token whisper_token_not(struct whisper_context * ctx) {
2845
- return ctx->vocab.token_not;
2846
- }
2847
-
2848
- whisper_token whisper_token_beg(struct whisper_context * ctx) {
2849
- return ctx->vocab.token_beg;
2850
- }
2851
-
2852
- whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
2853
- return whisper_token_sot(ctx) + 1 + lang_id;
2854
- }
2855
-
2856
- whisper_token whisper_token_translate(void) {
2857
- return whisper_vocab::token_translate;
2858
- }
2859
-
2860
- whisper_token whisper_token_transcribe(void) {
2861
- return whisper_vocab::token_transcribe;
2862
- }
2863
-
2864
- void whisper_print_timings(struct whisper_context * ctx) {
2865
- const int64_t t_end_us = ggml_time_us();
2866
-
2867
- const int32_t n_sample = std::max(1, ctx->n_sample);
2868
- const int32_t n_encode = std::max(1, ctx->n_encode);
2869
- const int32_t n_decode = std::max(1, ctx->n_decode);
2870
-
2871
- fprintf(stderr, "\n");
2872
- fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
2873
- fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
2874
- fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
2875
- fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
2876
- fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
2877
- fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
2878
- fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2879
- }
2880
-
2881
- void whisper_reset_timings(struct whisper_context * ctx) {
2882
- ctx->t_sample_us = 0;
2883
- ctx->t_encode_us = 0;
2884
- ctx->t_decode_us = 0;
2885
- }
2886
-
2887
- const char * whisper_print_system_info(void) {
2888
- static std::string s;
2889
-
2890
- s = "";
2891
- s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
2892
- s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
2893
- s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
2894
- s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
2895
- s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
2896
- s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
2897
- s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
2898
- s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
2899
- s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
2900
- s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
2901
- s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
2902
- s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
2903
-
2904
- return s.c_str();
2905
- }
2906
-
2907
- ////////////////////////////////////////////////////////////////////////////
2908
-
2909
- struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
2910
- struct whisper_full_params result = {
2911
- /*.strategy =*/ strategy,
2912
-
2913
- /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
2914
- /*.n_max_text_ctx =*/ 16384,
2915
- /*.offset_ms =*/ 0,
2916
- /*.duration_ms =*/ 0,
2917
-
2918
- /*.translate =*/ false,
2919
- /*.no_context =*/ false,
2920
- /*.single_segment =*/ false,
2921
- /*.print_special =*/ false,
2922
- /*.print_progress =*/ true,
2923
- /*.print_realtime =*/ false,
2924
- /*.print_timestamps =*/ true,
2925
-
2926
- /*.token_timestamps =*/ false,
2927
- /*.thold_pt =*/ 0.01f,
2928
- /*.thold_ptsum =*/ 0.01f,
2929
- /*.max_len =*/ 0,
2930
- /*.split_on_word =*/ false,
2931
- /*.max_tokens =*/ 0,
2932
-
2933
- /*.speed_up =*/ false,
2934
- /*.audio_ctx =*/ 0,
2935
-
2936
- /*.prompt_tokens =*/ nullptr,
2937
- /*.prompt_n_tokens =*/ 0,
2938
-
2939
- /*.language =*/ "en",
2940
-
2941
- /*.suppress_blank =*/ true,
2942
- /*.suppress_non_speech_tokens =*/ false,
2943
-
2944
- /*.temperature =*/ 0.0f,
2945
- /*.max_initial_ts =*/ 1.0f,
2946
- /*.length_penalty =*/ -1.0f,
2947
-
2948
- /*.temperature_inc =*/ 0.2f,
2949
- /*.entropy_thold =*/ 2.4f,
2950
- /*.logprob_thold =*/ -1.0f,
2951
- /*.no_speech_thold =*/ 0.6f,
2952
-
2953
- /*.greedy =*/ {
2954
- /*.best_of =*/ -1,
2955
- },
2956
-
2957
- /*.beam_search =*/ {
2958
- /*.beam_size =*/ -1,
2959
-
2960
- /*.patience =*/ -1.0f,
2961
- },
2962
-
2963
- /*.new_segment_callback =*/ nullptr,
2964
- /*.new_segment_callback_user_data =*/ nullptr,
2965
-
2966
- /*.encoder_begin_callback =*/ nullptr,
2967
- /*.encoder_begin_callback_user_data =*/ nullptr,
2968
-
2969
- /*.logits_filter_callback =*/ nullptr,
2970
- /*.logits_filter_callback_user_data =*/ nullptr,
2971
- };
2972
-
2973
- switch (strategy) {
2974
- case WHISPER_SAMPLING_GREEDY:
2975
- {
2976
- result.greedy = {
2977
- /*.best_of =*/ 1,
2978
- };
2979
- } break;
2980
- case WHISPER_SAMPLING_BEAM_SEARCH:
2981
- {
2982
- result.beam_search = {
2983
- /*.beam_size =*/ 5,
2984
-
2985
- /*.patience =*/ -1.0f,
2986
- };
2987
- } break;
2988
- }
2989
-
2990
- return result;
2991
- }
2992
-
2993
- // forward declarations
2994
- static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
2995
- static void whisper_exp_compute_token_level_timestamps(
2996
- struct whisper_context & ctx,
2997
- int i_segment,
2998
- float thold_pt,
2999
- float thold_ptsum);
3000
-
3001
- // trim from start (in place)
3002
- static inline void ltrim(std::string &s) {
3003
- s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
3004
- return !std::isspace(ch);
3005
- }));
3006
- }
3007
-
3008
- // trim from end (in place)
3009
- static inline void rtrim(std::string &s) {
3010
- s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
3011
- return !std::isspace(ch);
3012
- }).base(), s.end());
3013
- }
3014
-
3015
- // trim from both ends (in place)
3016
- static inline void trim(std::string &s) {
3017
- rtrim(s);
3018
- ltrim(s);
3019
- }
3020
-
3021
- static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3022
- if (!split_on_word) return true;
3023
-
3024
- return txt[0] == ' ';
3025
- }
3026
-
3027
- // wrap the last segment to max_len characters
3028
- // returns the number of new segments
3029
- static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
3030
- auto segment = ctx.result_all.back();
3031
-
3032
- int res = 1;
3033
- int acc = 0;
3034
-
3035
- std::string text;
3036
-
3037
- for (int i = 0; i < (int) segment.tokens.size(); i++) {
3038
- const auto & token = segment.tokens[i];
3039
- if (token.id >= whisper_token_eot(&ctx)) {
3040
- continue;
3041
- }
3042
-
3043
- const auto txt = whisper_token_to_str(&ctx, token.id);
3044
- const int cur = strlen(txt);
3045
-
3046
- if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
3047
- // split here
3048
- if (split_on_word) {
3049
- trim(text);
3050
- }
3051
-
3052
- ctx.result_all.back().text = std::move(text);
3053
- ctx.result_all.back().t1 = token.t0;
3054
- ctx.result_all.back().tokens.resize(i);
3055
-
3056
- ctx.result_all.push_back({});
3057
- ctx.result_all.back().t0 = token.t0;
3058
- ctx.result_all.back().t1 = segment.t1;
3059
-
3060
- // add tokens [i, end] to the new segment
3061
- ctx.result_all.back().tokens.insert(
3062
- ctx.result_all.back().tokens.end(),
3063
- segment.tokens.begin() + i,
3064
- segment.tokens.end());
3065
-
3066
- acc = 0;
3067
- text = "";
3068
-
3069
- segment = ctx.result_all.back();
3070
- i = -1;
3071
-
3072
- res++;
3073
- } else {
3074
- acc += cur;
3075
- text += txt;
3076
- }
3077
- }
3078
-
3079
- if (split_on_word) {
3080
- trim(text);
3081
- }
3082
- ctx.result_all.back().text = std::move(text);
3083
-
3084
- return res;
3085
- }
3086
-
3087
- static const std::vector<std::string> non_speech_tokens = {
3088
- "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
3089
- "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
3090
- "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
3091
- "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
3092
- };
3093
-
3094
- // process the logits for the selected decoder
3095
- // - applies logit filters
3096
- // - computes logprobs and probs
3097
- static void whisper_process_logits(
3098
- struct whisper_context & ctx,
3099
- const struct whisper_full_params params,
3100
- struct whisper_decoder & decoder,
3101
- float temperature) {
3102
- const auto & vocab = ctx.vocab;
3103
- const auto & tokens_cur = decoder.sequence.tokens;
3104
-
3105
- const bool is_initial = tokens_cur.size() == 0;
3106
- const int n_logits = vocab.id_to_token.size();
3107
-
3108
- WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
3109
-
3110
- // extract the logits for the last token
3111
- // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3112
- auto & probs = decoder.probs;
3113
- auto & logits = decoder.logits;
3114
- auto & logprobs = decoder.logprobs;
3115
- {
3116
- logits.resize(n_logits);
3117
- memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
3118
-
3119
- if (temperature > 0.0f) {
3120
- for (int i = 0; i < n_logits; i++) {
3121
- logits[i] /= temperature;
3122
- }
3123
- }
3124
-
3125
- // will be populated a bit later
3126
- probs.resize(n_logits);
3127
- logprobs.resize(n_logits);
3128
- }
3129
-
3130
- // apply logit filters here
3131
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
3132
- {
3133
- // suppress blank
3134
- // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
3135
- if (params.suppress_blank) {
3136
- if (is_initial) {
3137
- logits[vocab.token_eot] = -INFINITY;
3138
- logits[vocab.token_to_id.at(" ")] = -INFINITY;
3139
- }
3140
- }
3141
-
3142
- // suppress <|notimestamps|> token
3143
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
3144
- logits[vocab.token_not] = -INFINITY;
3145
-
3146
- // suppress sot and solm tokens
3147
- logits[vocab.token_sot] = -INFINITY;
3148
- logits[vocab.token_solm] = -INFINITY;
3149
-
3150
- // suppress task tokens
3151
- logits[vocab.token_translate] = -INFINITY;
3152
- logits[vocab.token_transcribe] = -INFINITY;
3153
-
3154
- if (params.logits_filter_callback) {
3155
- params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3156
- }
3157
-
3158
- // suppress non-speech tokens
3159
- // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3160
- if (params.suppress_non_speech_tokens) {
3161
- for (const std::string & token : non_speech_tokens) {
3162
- const std::string suppress_tokens[] = {token, " " + token};
3163
- for (const std::string & suppress_token : suppress_tokens) {
3164
- if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
3165
- logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
3166
- }
3167
- }
3168
- }
3169
-
3170
- // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3171
- if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
3172
- logits[vocab.token_to_id.at(" -")] = -INFINITY;
3173
- }
3174
- if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
3175
- logits[vocab.token_to_id.at(" '")] = -INFINITY;
3176
- }
3177
- }
3178
-
3179
- // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
3180
- // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
3181
- {
3182
- const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
3183
- const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
3184
-
3185
- //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
3186
-
3187
- if (last_was_timestamp) {
3188
- if (penultimate_was_timestamp) {
3189
- for (int i = vocab.token_beg; i < n_logits; ++i) {
3190
- logits[i] = -INFINITY;
3191
- }
3192
- } else {
3193
- for (int i = 0; i < vocab.token_eot; ++i) {
3194
- logits[i] = -INFINITY;
3195
- }
3196
- }
3197
- }
3198
- }
3199
-
3200
- // the initial timestamp cannot be larger than max_initial_ts
3201
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
3202
- if (is_initial && params.max_initial_ts > 0.0f) {
3203
- const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
3204
- const int tid0 = std::round(params.max_initial_ts/precision);
3205
-
3206
- for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
3207
- logits[i] = -INFINITY;
3208
- }
3209
- }
3210
-
3211
- // condition timestamp tokens to be increasing
3212
- // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
3213
- if (decoder.has_ts) {
3214
- const int tid0 = decoder.seek_delta/2;
3215
-
3216
- for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
3217
- logits[i] = -INFINITY;
3218
- }
3219
- }
3220
-
3221
- // populate the logprobs array (log_softmax)
3222
- {
3223
- const float logit_max = *std::max_element(logits.begin(), logits.end());
3224
- float logsumexp = 0.0f;
3225
- for (int i = 0; i < n_logits; ++i) {
3226
- if (logits[i] > -INFINITY) {
3227
- logsumexp += expf(logits[i] - logit_max);
3228
- }
3229
- }
3230
- logsumexp = logf(logsumexp) + logit_max;
3231
-
3232
- for (int i = 0; i < n_logits; ++i) {
3233
- if (logits[i] > -INFINITY) {
3234
- logprobs[i] = logits[i] - logsumexp;
3235
- } else {
3236
- logprobs[i] = -INFINITY;
3237
- }
3238
- }
3239
- }
3240
-
3241
- // if sum of probability over timestamps is above any other token, sample timestamp
3242
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
3243
- {
3244
- // logsumexp over timestamps
3245
- float timestamp_logprob = -INFINITY;
3246
- {
3247
- float logsumexp = 0.0f;
3248
- const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
3249
- for (int i = vocab.token_beg; i < n_logits; ++i) {
3250
- if (logprobs[i] > -INFINITY) {
3251
- logsumexp += expf(logprobs[i] - logprob_max);
3252
- }
3253
- }
3254
- if (logsumexp > 0.0f) {
3255
- timestamp_logprob = logf(logsumexp) + logprob_max;
3256
- }
3257
- }
3258
-
3259
- const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
3260
-
3261
- //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
3262
-
3263
- if (timestamp_logprob > max_text_token_logprob) {
3264
- for (int i = 0; i < vocab.token_beg; ++i) {
3265
- logits[i] = -INFINITY;
3266
- logprobs[i] = -INFINITY;
3267
- }
3268
- }
3269
- }
3270
- }
3271
-
3272
- // compute probs
3273
- {
3274
- for (int i = 0; i < n_logits; ++i) {
3275
- if (logits[i] == -INFINITY) {
3276
- probs[i] = 0.0f;
3277
- } else {
3278
- probs[i] = expf(logprobs[i]);
3279
- }
3280
- }
3281
- }
3282
-
3283
- #if 0
3284
- // print first 100 logits - token string : logit
3285
- for (int i = 0; i < 100; i++) {
3286
- const auto token = vocab.id_to_token.at(i);
3287
- const auto prob = probs[i];
3288
- const auto logit = logits[i];
3289
- const auto logprob = logprobs[i];
3290
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
3291
- }
3292
-
3293
- // "And", "and", " And", " and"
3294
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
3295
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
3296
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
3297
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
3298
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
3299
-
3300
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
3301
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
3302
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
3303
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
3304
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
3305
-
3306
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
3307
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
3308
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
3309
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
3310
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
3311
- #endif
3312
- }
3313
-
3314
- static whisper_token_data whisper_sample_token(
3315
- whisper_context & ctx,
3316
- const whisper_decoder & decoder,
3317
- bool best) {
3318
- whisper_token_data result = {
3319
- 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
3320
- };
3321
-
3322
- const auto & vocab = ctx.vocab;
3323
-
3324
- const auto & probs = decoder.probs;
3325
- const auto & logprobs = decoder.logprobs;
3326
-
3327
- const int n_logits = vocab.n_vocab;
3328
-
3329
- {
3330
- double sum_ts = 0.0;
3331
- double max_ts = 0.0;
3332
-
3333
- for (int i = vocab.token_beg; i < n_logits; i++) {
3334
- if (probs[i] == -INFINITY) {
3335
- continue;
3336
- }
3337
-
3338
- sum_ts += probs[i];
3339
- if (max_ts < probs[i]) {
3340
- max_ts = probs[i];
3341
- result.tid = i;
3342
- }
3343
- }
3344
-
3345
- result.pt = max_ts/(sum_ts + 1e-10);
3346
- result.ptsum = sum_ts;
3347
- }
3348
-
3349
- if (best) {
3350
- for (int i = 0; i < n_logits; ++i) {
3351
- if (result.p < probs[i]) {
3352
- result.id = i;
3353
- result.p = probs[i];
3354
- result.plog = logprobs[i];
3355
- }
3356
- }
3357
- } else {
3358
- std::discrete_distribution<> dist(probs.begin(), probs.end());
3359
-
3360
- result.id = dist(ctx.rng);
3361
- result.p = probs[result.id];
3362
- result.plog = logprobs[result.id];
3363
- }
3364
-
3365
- if (result.id >= vocab.token_beg) {
3366
- result.tid = result.id;
3367
- result.pt = result.p;
3368
- }
3369
-
3370
- ctx.n_sample++;
3371
-
3372
- return result;
3373
- }
3374
-
3375
- static std::vector<whisper_token_data> whisper_sample_token_topk(
3376
- whisper_context & ctx,
3377
- const whisper_decoder & decoder,
3378
- int k) {
3379
- const auto & vocab = ctx.vocab;
3380
-
3381
- const auto & probs = decoder.probs;
3382
- const auto & logits = decoder.logits;
3383
- const auto & logprobs = decoder.logprobs;
3384
-
3385
- const int n_logits = vocab.n_vocab;
3386
-
3387
- auto & logits_id = ctx.logits_id;
3388
-
3389
- logits_id.clear();
3390
- for (int i = 0; i < n_logits; ++i) {
3391
- logits_id.push_back({ logits[i], i });
3392
- }
3393
-
3394
- std::partial_sort(
3395
- logits_id.begin(),
3396
- logits_id.begin() + k, logits_id.end(),
3397
- [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
3398
- return a.first > b.first;
3399
- });
3400
-
3401
- std::vector<whisper_token_data> result;
3402
- result.reserve(k);
3403
-
3404
- whisper_token tid = vocab.token_beg;
3405
-
3406
- float pt = 0.0;
3407
- float ptsum = 0.0;
3408
-
3409
- {
3410
- double sum_ts = 0.0;
3411
- double max_ts = 0.0;
3412
-
3413
- for (int i = vocab.token_beg; i < n_logits; i++) {
3414
- if (probs[i] == -INFINITY) {
3415
- continue;
3416
- }
3417
-
3418
- sum_ts += probs[i];
3419
- if (max_ts < probs[i]) {
3420
- max_ts = probs[i];
3421
- tid = i;
3422
- }
3423
- }
3424
-
3425
- pt = max_ts/(sum_ts + 1e-10);
3426
- ptsum = sum_ts;
3427
- }
3428
-
3429
- for (int i = 0; i < k; ++i) {
3430
- const auto id = logits_id[i].second;
3431
-
3432
- result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
3433
-
3434
- if (result[i].id >= vocab.token_beg) {
3435
- result[i].tid = result[i].id;
3436
- result[i].pt = result[i].p;
3437
- }
3438
- }
3439
-
3440
- ctx.n_sample++;
3441
-
3442
- return result;
3443
- }
3444
-
3445
- // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
3446
- static void whisper_sequence_score(
3447
- const struct whisper_full_params & params,
3448
- whisper_sequence & sequence) {
3449
- if (sequence.result_len == 0) {
3450
- return;
3451
- }
3452
-
3453
- double result = 0.0f;
3454
-
3455
- for (int i = 0; i < sequence.result_len; ++i) {
3456
- result += sequence.tokens[i].plog;
3457
- }
3458
-
3459
- sequence.sum_logprobs = result;
3460
- sequence.avg_logprobs = result/sequence.result_len;
3461
-
3462
- double penalty = sequence.result_len;
3463
-
3464
- if (params.length_penalty > 0.0f) {
3465
- penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
3466
- }
3467
-
3468
- sequence.score = result/penalty;
3469
-
3470
- // compute the entropy of the sequence of the last 32 tokens
3471
- {
3472
- const int n = 32;
3473
-
3474
- int cnt = 0;
3475
- double entropy = 0.0f;
3476
-
3477
- std::map<whisper_token, int> token_counts;
3478
- for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
3479
- token_counts[sequence.tokens[i].id]++;
3480
- cnt++;
3481
- }
3482
-
3483
- for (const auto & kv : token_counts) {
3484
- const auto p = kv.second/(double)cnt;
3485
- entropy -= p*log(p);
3486
-
3487
- //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
3488
- }
3489
-
3490
- sequence.entropy = entropy;
3491
- }
3492
- }
3493
-
3494
- int whisper_full(
3495
- struct whisper_context * ctx,
3496
- struct whisper_full_params params,
3497
- const float * samples,
3498
- int n_samples) {
3499
- // clear old results
3500
- auto & result_all = ctx->result_all;
3501
-
3502
- result_all.clear();
3503
-
3504
- // compute log mel spectrogram
3505
- if (params.speed_up) {
3506
- if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
3507
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3508
- return -1;
3509
- }
3510
- } else {
3511
- if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
3512
- fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3513
- return -2;
3514
- }
3515
- }
3516
-
3517
- // auto-detect language if not specified
3518
- if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3519
- std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3520
-
3521
- const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
3522
- if (lang_id < 0) {
3523
- fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
3524
- return -3;
3525
- }
3526
- ctx->lang_id = lang_id;
3527
- params.language = whisper_lang_str(lang_id);
3528
-
3529
- fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3530
- }
3531
-
3532
- if (params.token_timestamps) {
3533
- ctx->t_beg = 0;
3534
- ctx->t_last = 0;
3535
- ctx->tid_last = 0;
3536
- ctx->energy = get_signal_energy(samples, n_samples, 32);
3537
- }
3538
-
3539
- const int seek_start = params.offset_ms/10;
3540
- const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
3541
-
3542
- // if length of spectrogram is less than 1s (100 samples), then return
3543
- // basically don't process anything that is less than 1s
3544
- // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
3545
- if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
3546
- return 0;
3547
- }
3548
-
3549
- // a set of temperatures to use
3550
- // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
3551
- std::vector<float> temperatures;
3552
- if (params.temperature_inc > 0.0f) {
3553
- for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
3554
- temperatures.push_back(t);
3555
- }
3556
- } else {
3557
- temperatures.push_back(params.temperature);
3558
- }
3559
-
3560
- // initialize the decoders
3561
- int n_decoders = 1;
3562
-
3563
- switch (params.strategy) {
3564
- case WHISPER_SAMPLING_GREEDY:
3565
- {
3566
- n_decoders = params.greedy.best_of;
3567
- } break;
3568
- case WHISPER_SAMPLING_BEAM_SEARCH:
3569
- {
3570
- n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
3571
- } break;
3572
- };
3573
-
3574
- n_decoders = std::max(1, n_decoders);
3575
-
3576
- // TAGS: WHISPER_DECODER_INIT
3577
- for (int j = 1; j < n_decoders && ctx->running; j++) {
3578
- auto & decoder = ctx->decoders[j];
3579
-
3580
- if (decoder.kv_self.ctx == nullptr) {
3581
- decoder.kv_self = ctx->decoders[0].kv_self;
3582
- if (!kv_cache_reinit(decoder.kv_self)) {
3583
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
3584
- return -4;
3585
- }
3586
-
3587
- WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
3588
-
3589
- decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
3590
-
3591
- decoder.probs.resize (ctx->vocab.n_vocab);
3592
- decoder.logits.resize (ctx->vocab.n_vocab);
3593
- decoder.logprobs.resize(ctx->vocab.n_vocab);
3594
- }
3595
- }
3596
-
3597
- // the accumulated text context so far
3598
- auto & prompt_past = ctx->prompt_past;
3599
- if (params.no_context) {
3600
- prompt_past.clear();
3601
- }
3602
-
3603
- // prepend the prompt tokens to the prompt_past
3604
- if (params.prompt_tokens && params.prompt_n_tokens > 0) {
3605
- // parse tokens from the pointer
3606
- for (int i = 0; i < params.prompt_n_tokens; i++) {
3607
- prompt_past.push_back(params.prompt_tokens[i]);
3608
- }
3609
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
3610
- }
3611
-
3612
- // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
3613
- if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
3614
- fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
3615
- return -5;
3616
- }
3617
- ctx->exp_n_audio_ctx = params.audio_ctx;
3618
-
3619
- // these tokens determine the task that will be performed
3620
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
3621
- if (whisper_is_multilingual(ctx)) {
3622
- const int lang_id = whisper_lang_id(params.language);
3623
- ctx->lang_id = lang_id;
3624
- prompt_init.push_back(whisper_token_lang(ctx, lang_id));
3625
- if (params.translate) {
3626
- prompt_init.push_back(whisper_token_translate());
3627
- } else {
3628
- prompt_init.push_back(whisper_token_transcribe());
3629
- }
3630
- }
3631
-
3632
- int progress_prev = 0;
3633
- int progress_step = 5;
3634
-
3635
- int seek = seek_start;
3636
-
3637
- std::vector<whisper_token> prompt;
3638
- prompt.reserve(whisper_n_text_ctx(ctx));
3639
-
3640
- // beam-search helpers
3641
- struct kv_buf {
3642
- std::vector<uint8_t> k;
3643
- std::vector<uint8_t> v;
3644
- };
3645
-
3646
- std::vector<kv_buf> kv_bufs;
3647
-
3648
- struct beam_candidate {
3649
- int decoder_idx;
3650
- int seek_delta;
3651
-
3652
- bool has_ts;
3653
-
3654
- whisper_sequence sequence;
3655
- };
3656
-
3657
- std::vector<beam_candidate> beam_candidates;
3658
-
3659
- // main loop
3660
- while (ctx->running) {
3661
- const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
3662
- while (progress_cur >= progress_prev + progress_step) {
3663
- progress_prev += progress_step;
3664
- if (params.print_progress) {
3665
- fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
3666
- }
3667
- }
3668
-
3669
- // of only 1 second left, then stop
3670
- if (seek + 100 >= seek_end) {
3671
- break;
3672
- }
3673
-
3674
- if (params.encoder_begin_callback) {
3675
- if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
3676
- fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
3677
- break;
3678
- }
3679
- }
3680
-
3681
- // encode audio features starting at offset seek
3682
- if (!whisper_encode(*ctx, seek, params.n_threads)) {
3683
- fprintf(stderr, "%s: failed to encode\n", __func__);
3684
- return -6;
3685
- }
3686
-
3687
- // if there is a very short audio segment left to process, we remove any past prompt since it tends
3688
- // to confuse the decoder and often make it repeat or hallucinate stuff
3689
- if (seek > seek_start && seek + 500 >= seek_end) {
3690
- prompt_past.clear();
3691
- }
3692
-
3693
- int best_decoder_id = 0;
3694
-
3695
- for (int it = 0; it < (int) temperatures.size(); ++it) {
3696
- const float t_cur = temperatures[it];
3697
-
3698
- int n_decoders_cur = 1;
3699
-
3700
- switch (params.strategy) {
3701
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3702
- {
3703
- if (t_cur > 0.0f) {
3704
- n_decoders_cur = params.greedy.best_of;
3705
- }
3706
- } break;
3707
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3708
- {
3709
- if (t_cur > 0.0f) {
3710
- n_decoders_cur = params.greedy.best_of;
3711
- } else {
3712
- n_decoders_cur = params.beam_search.beam_size;
3713
- }
3714
- } break;
3715
- };
3716
-
3717
- n_decoders_cur = std::max(1, n_decoders_cur);
3718
-
3719
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
3720
-
3721
- // TAGS: WHISPER_DECODER_INIT
3722
- for (int j = 0; j < n_decoders_cur; ++j) {
3723
- auto & decoder = ctx->decoders[j];
3724
-
3725
- decoder.kv_self.n = 0;
3726
-
3727
- decoder.sequence.tokens.clear();
3728
- decoder.sequence.result_len = 0;
3729
- decoder.sequence.sum_logprobs_all = 0.0;
3730
- decoder.sequence.sum_logprobs = -INFINITY;
3731
- decoder.sequence.avg_logprobs = -INFINITY;
3732
- decoder.sequence.entropy = 0.0;
3733
- decoder.sequence.score = -INFINITY;
3734
-
3735
- decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
3736
-
3737
- decoder.failed = false;
3738
- decoder.completed = false;
3739
- decoder.has_ts = false;
3740
- }
3741
-
3742
- // init prompt and kv cache for the current iteration
3743
- // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
3744
- {
3745
- prompt.clear();
3746
-
3747
- // if we have already generated some text, use it as a prompt to condition the next generation
3748
- if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
3749
- int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
3750
-
3751
- prompt = { whisper_token_prev(ctx) };
3752
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
3753
- }
3754
-
3755
- // init new transcription with sot, language (opt) and task tokens
3756
- prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
3757
-
3758
- // print the prompt
3759
- WHISPER_PRINT_DEBUG("\n\n");
3760
- for (int i = 0; i < (int) prompt.size(); i++) {
3761
- WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
3762
- }
3763
- WHISPER_PRINT_DEBUG("\n\n");
3764
-
3765
- if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3766
- fprintf(stderr, "%s: failed to decode\n", __func__);
3767
- return -7;
3768
- }
3769
-
3770
- {
3771
- const int64_t t_start_sample_us = ggml_time_us();
3772
-
3773
- whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
3774
-
3775
- ctx->decoders[0].kv_self.n += prompt.size();
3776
-
3777
- for (int j = 1; j < n_decoders_cur; ++j) {
3778
- auto & decoder = ctx->decoders[j];
3779
-
3780
- memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
3781
- memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
3782
-
3783
- decoder.kv_self.n += prompt.size();
3784
-
3785
- memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
3786
- memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
3787
- memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
3788
- }
3789
-
3790
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
3791
- }
3792
- }
3793
-
3794
- for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
3795
- const int64_t t_start_sample_us = ggml_time_us();
3796
-
3797
- // store the KV caches of all decoders when doing beam-search
3798
- if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3799
- kv_bufs.resize(n_decoders_cur);
3800
- for (int j = 0; j < n_decoders_cur; ++j) {
3801
- auto & decoder = ctx->decoders[j];
3802
-
3803
- if (decoder.completed || decoder.failed) {
3804
- continue;
3805
- }
3806
-
3807
- kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
3808
- kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
3809
-
3810
- memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
3811
- memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
3812
- }
3813
-
3814
- beam_candidates.clear();
3815
- }
3816
-
3817
- // generate new sequence candidates for each decoder
3818
- for (int j = 0; j < n_decoders_cur; ++j) {
3819
- auto & decoder = ctx->decoders[j];
3820
-
3821
- if (decoder.completed || decoder.failed) {
3822
- continue;
3823
- }
3824
-
3825
- switch (params.strategy) {
3826
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3827
- {
3828
- if (t_cur < 1e-6f) {
3829
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
3830
- } else {
3831
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
3832
- }
3833
-
3834
- decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
3835
- } break;
3836
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3837
- {
3838
- const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
3839
-
3840
- for (const auto & token : tokens_new) {
3841
- beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
3842
- beam_candidates.back().sequence.tokens.push_back(token);
3843
- beam_candidates.back().sequence.sum_logprobs_all += token.plog;
3844
-
3845
- //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
3846
- }
3847
- } break;
3848
- };
3849
- }
3850
-
3851
- // for beam-search, choose the top candidates and update the KV caches
3852
- if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3853
- std::sort(
3854
- beam_candidates.begin(),
3855
- beam_candidates.end(),
3856
- [](const beam_candidate & a, const beam_candidate & b) {
3857
- return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
3858
- });
3859
-
3860
- uint32_t cur_c = 0;
3861
-
3862
- for (int j = 0; j < n_decoders_cur; ++j) {
3863
- auto & decoder = ctx->decoders[j];
3864
-
3865
- if (decoder.completed || decoder.failed) {
3866
- continue;
3867
- }
3868
-
3869
- auto & cur = beam_candidates[cur_c++];
3870
-
3871
- while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
3872
- ++cur_c;
3873
- }
3874
-
3875
- decoder.sequence = cur.sequence;
3876
- decoder.seek_delta = cur.seek_delta;
3877
- decoder.has_ts = cur.has_ts;
3878
-
3879
- memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
3880
- memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
3881
-
3882
- WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
3883
- __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
3884
- }
3885
- }
3886
-
3887
- // update the decoder state
3888
- // - check if the sequence is completed
3889
- // - check if the sequence is failed
3890
- // - update sliding window based on timestamp tokens
3891
- for (int j = 0; j < n_decoders_cur; ++j) {
3892
- auto & decoder = ctx->decoders[j];
3893
-
3894
- if (decoder.completed || decoder.failed) {
3895
- continue;
3896
- }
3897
-
3898
- auto & has_ts = decoder.has_ts;
3899
- auto & failed = decoder.failed;
3900
- auto & completed = decoder.completed;
3901
- auto & seek_delta = decoder.seek_delta;
3902
- auto & result_len = decoder.sequence.result_len;
3903
-
3904
- {
3905
- const auto & token = decoder.sequence.tokens.back();
3906
-
3907
- // timestamp token - update sliding window
3908
- if (token.id > whisper_token_beg(ctx)) {
3909
- const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
3910
-
3911
- // do not allow to go back in time
3912
- if (has_ts && seek_delta > seek_delta_new && result_len < i) {
3913
- failed = true; // TODO: maybe this is not a failure ?
3914
- continue;
3915
- }
3916
-
3917
- seek_delta = seek_delta_new;
3918
- result_len = i + 1;
3919
- has_ts = true;
3920
- }
3921
-
3922
- #ifdef WHISPER_DEBUG
3923
- {
3924
- const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
3925
- WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
3926
- __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
3927
- }
3928
- #endif
3929
-
3930
- // end of segment
3931
- if (token.id == whisper_token_eot(ctx) || // end of text token
3932
- (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
3933
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
3934
- ) {
3935
- if (result_len == 0) {
3936
- if (seek + seek_delta + 100 >= seek_end) {
3937
- result_len = i + 1;
3938
- } else {
3939
- failed = true;
3940
- continue;
3941
- }
3942
- }
3943
-
3944
- if (params.single_segment) {
3945
- result_len = i + 1;
3946
- seek_delta = 100*WHISPER_CHUNK_SIZE;
3947
- }
3948
-
3949
- completed = true;
3950
- continue;
3951
- }
3952
-
3953
- // TESTS: if no tensors are loaded, it means we are running tests
3954
- if (ctx->model.n_loaded == 0) {
3955
- seek_delta = 100*WHISPER_CHUNK_SIZE;
3956
- completed = true;
3957
- continue;
3958
- }
3959
- }
3960
-
3961
- // sometimes, the decoding can get stuck in a repetition loop
3962
- // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
3963
- if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
3964
- failed = true;
3965
- continue;
3966
- }
3967
- }
3968
-
3969
- // check if all decoders have finished (i.e. completed or failed)
3970
- {
3971
- bool completed_all = true;
3972
-
3973
- for (int j = 0; j < n_decoders_cur; ++j) {
3974
- auto & decoder = ctx->decoders[j];
3975
-
3976
- if (decoder.completed || decoder.failed) {
3977
- continue;
3978
- }
3979
-
3980
- completed_all = false;
3981
- }
3982
-
3983
- if (completed_all) {
3984
- break;
3985
- }
3986
- }
3987
-
3988
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
3989
-
3990
- // obtain logits for the next token
3991
- for (int j = 0; j < n_decoders_cur; ++j) {
3992
- auto & decoder = ctx->decoders[j];
3993
-
3994
- if (decoder.failed || decoder.completed) {
3995
- continue;
3996
- }
3997
-
3998
- decoder.tokens_tmp.resize(1);
3999
- decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
4000
-
4001
- //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4002
-
4003
- if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4004
- fprintf(stderr, "%s: failed to decode\n", __func__);
4005
- return -8;
4006
- }
4007
-
4008
- {
4009
- const int64_t t_start_sample_us = ggml_time_us();
4010
-
4011
- whisper_process_logits(*ctx, params, decoder, t_cur);
4012
-
4013
- ++decoder.kv_self.n;
4014
-
4015
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
4016
- }
4017
- }
4018
- }
4019
-
4020
- // rank the resulting sequences and select the best one
4021
- {
4022
- double best_score = -INFINITY;
4023
-
4024
- for (int j = 0; j < n_decoders_cur; ++j) {
4025
- auto & decoder = ctx->decoders[j];
4026
-
4027
- if (decoder.failed) {
4028
- continue;
4029
- }
4030
-
4031
- decoder.sequence.tokens.resize(decoder.sequence.result_len);
4032
- whisper_sequence_score(params, decoder.sequence);
4033
-
4034
- WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
4035
- __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
4036
-
4037
- if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
4038
- WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
4039
- __func__, j, decoder.sequence.entropy, params.entropy_thold);
4040
-
4041
- decoder.failed = true;
4042
- ctx->n_fail_h++;
4043
-
4044
- continue;
4045
- }
4046
-
4047
- if (best_score < decoder.sequence.score) {
4048
- best_score = decoder.sequence.score;
4049
- best_decoder_id = j;
4050
- }
4051
- }
4052
-
4053
- WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
4054
- }
4055
-
4056
- // was the decoding successful for the current temperature?
4057
- {
4058
- bool success = true;
4059
-
4060
- const auto & decoder = ctx->decoders[best_decoder_id];
4061
-
4062
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
4063
- success = false;
4064
- ctx->n_fail_p++;
4065
- }
4066
-
4067
- if (success) {
4068
- //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
4069
- // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
4070
- //}
4071
-
4072
- break;
4073
- }
4074
- }
4075
-
4076
- WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
4077
- }
4078
-
4079
- // output results through a user-provided callback
4080
- {
4081
- const auto & best_decoder = ctx->decoders[best_decoder_id];
4082
-
4083
- const auto seek_delta = best_decoder.seek_delta;
4084
- const auto result_len = best_decoder.sequence.result_len;
4085
-
4086
- const auto & tokens_cur = best_decoder.sequence.tokens;
4087
-
4088
- //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
4089
-
4090
- // update prompt_past
4091
- prompt_past.clear();
4092
- if (prompt.front() == whisper_token_prev(ctx)) {
4093
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
4094
- }
4095
-
4096
- for (int i = 0; i < result_len; ++i) {
4097
- prompt_past.push_back(tokens_cur[i].id);
4098
- }
4099
-
4100
- // store the text from this iteration
4101
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
4102
- int i0 = 0;
4103
- auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
4104
-
4105
- std::string text;
4106
-
4107
- for (int i = 0; i < (int) tokens_cur.size(); i++) {
4108
- //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
4109
- // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
4110
- // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
4111
-
4112
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
4113
- } else {
4114
- text += whisper_token_to_str(ctx, tokens_cur[i].id);
4115
- }
4116
-
4117
- if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
4118
- const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
4119
-
4120
- if (!text.empty()) {
4121
- const auto tt0 = params.speed_up ? 2*t0 : t0;
4122
- const auto tt1 = params.speed_up ? 2*t1 : t1;
4123
-
4124
- if (params.print_realtime) {
4125
- if (params.print_timestamps) {
4126
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
4127
- } else {
4128
- printf("%s", text.c_str());
4129
- fflush(stdout);
4130
- }
4131
- }
4132
-
4133
- //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
4134
-
4135
- result_all.push_back({ tt0, tt1, text, {} });
4136
- for (int j = i0; j <= i; j++) {
4137
- result_all.back().tokens.push_back(tokens_cur[j]);
4138
- }
4139
-
4140
- int n_new = 1;
4141
-
4142
- if (params.token_timestamps) {
4143
- whisper_exp_compute_token_level_timestamps(
4144
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4145
-
4146
- if (params.max_len > 0) {
4147
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4148
- }
4149
- }
4150
- if (params.new_segment_callback) {
4151
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4152
- }
4153
- }
4154
- text = "";
4155
- while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
4156
- i++;
4157
- }
4158
- i--;
4159
- t0 = t1;
4160
- i0 = i + 1;
4161
- }
4162
- }
4163
-
4164
- if (!text.empty()) {
4165
- const auto t1 = seek + seek_delta;
4166
-
4167
- const auto tt0 = params.speed_up ? 2*t0 : t0;
4168
- const auto tt1 = params.speed_up ? 2*t1 : t1;
4169
-
4170
- if (params.print_realtime) {
4171
- if (params.print_timestamps) {
4172
- printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
4173
- } else {
4174
- printf("%s", text.c_str());
4175
- fflush(stdout);
4176
- }
4177
- }
4178
-
4179
- result_all.push_back({ tt0, tt1, text, {} });
4180
- for (int j = i0; j < (int) tokens_cur.size(); j++) {
4181
- result_all.back().tokens.push_back(tokens_cur[j]);
4182
- }
4183
-
4184
- int n_new = 1;
4185
-
4186
- if (params.token_timestamps) {
4187
- whisper_exp_compute_token_level_timestamps(
4188
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4189
-
4190
- if (params.max_len > 0) {
4191
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4192
- }
4193
- }
4194
- if (params.new_segment_callback) {
4195
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4196
- }
4197
- }
4198
- }
4199
-
4200
- // update audio window
4201
- seek += seek_delta;
4202
-
4203
- WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
4204
- }
4205
- }
4206
-
4207
- return 0;
4208
- }
4209
-
4210
- void whisper_running_abort(struct whisper_context * ctx) {
4211
- ctx->running = false;
4212
- }
4213
-
4214
- void whisper_running_restore(struct whisper_context * ctx) {
4215
- ctx->running = true;
4216
- }
4217
-
4218
- bool whisper_running_state(struct whisper_context * ctx) {
4219
- return ctx->running;
4220
- }
4221
-
4222
- int whisper_full_parallel(
4223
- struct whisper_context * ctx,
4224
- struct whisper_full_params params,
4225
- const float * samples,
4226
- int n_samples,
4227
- int n_processors) {
4228
- if (n_processors == 1) {
4229
- return whisper_full(ctx, params, samples, n_samples);
4230
- }
4231
-
4232
- int ret = 0;
4233
-
4234
- // prepare separate contexts for each thread
4235
- std::vector<struct whisper_context> ctxs(n_processors - 1);
4236
-
4237
- for (int i = 0; i < n_processors - 1; ++i) {
4238
- auto & ctx_p = ctxs[i];
4239
-
4240
- ctx_p = *ctx;
4241
-
4242
- ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
4243
-
4244
- ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
4245
-
4246
- if (!kv_cache_reinit(ctx_p.kv_cross)) {
4247
- fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
4248
- return false;
4249
- }
4250
-
4251
- // TAGS: WHISPER_DECODER_INIT
4252
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4253
- if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
4254
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
4255
- return false;
4256
- }
4257
-
4258
- ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
4259
-
4260
- ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
4261
- ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
4262
- ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
4263
- }
4264
- }
4265
-
4266
- const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
4267
- const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
4268
-
4269
- // the calling thread will process the first chunk
4270
- // while the other threads will process the remaining chunks
4271
-
4272
- std::vector<std::thread> workers(n_processors - 1);
4273
- for (int i = 0; i < n_processors - 1; ++i) {
4274
- const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
4275
- const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
4276
-
4277
- auto params_cur = params;
4278
-
4279
- params_cur.offset_ms = 0;
4280
- params_cur.print_progress = false;
4281
- params_cur.print_realtime = false;
4282
-
4283
- params_cur.new_segment_callback = nullptr;
4284
- params_cur.new_segment_callback_user_data = nullptr;
4285
-
4286
- workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4287
- }
4288
-
4289
- {
4290
- auto params_cur = params;
4291
-
4292
- ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
4293
- }
4294
-
4295
- for (int i = 0; i < n_processors - 1; ++i) {
4296
- workers[i].join();
4297
- }
4298
-
4299
- const int64_t offset_t = (int64_t) params.offset_ms/10.0;
4300
-
4301
- // combine results into ctx->result_all
4302
- for (int i = 0; i < n_processors - 1; ++i) {
4303
- auto & results_i = ctxs[i].result_all;
4304
-
4305
- for (auto & result : results_i) {
4306
- // correct the segment timestamp taking into account the offset
4307
- result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4308
- result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4309
-
4310
- // make sure that segments are not overlapping
4311
- if (!ctx->result_all.empty()) {
4312
- result.t0 = std::max(result.t0, ctx->result_all.back().t1);
4313
- }
4314
-
4315
- ctx->result_all.push_back(std::move(result));
4316
-
4317
- // call the new_segment_callback for each segment
4318
- if (params.new_segment_callback) {
4319
- params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
4320
- }
4321
- }
4322
-
4323
- ctx->t_mel_us += ctxs[i].t_mel_us;
4324
- ctx->t_sample_us += ctxs[i].t_sample_us;
4325
- ctx->t_encode_us += ctxs[i].t_encode_us;
4326
- ctx->t_decode_us += ctxs[i].t_decode_us;
4327
-
4328
- kv_cache_free(ctx->kv_cross);
4329
-
4330
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4331
- kv_cache_free(ctx->decoders[j].kv_self);
4332
- }
4333
- }
4334
-
4335
- // average the timings
4336
- ctx->t_mel_us /= n_processors;
4337
- ctx->t_sample_us /= n_processors;
4338
- ctx->t_encode_us /= n_processors;
4339
- ctx->t_decode_us /= n_processors;
4340
-
4341
- // print information about the audio boundaries
4342
- fprintf(stderr, "\n");
4343
- fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
4344
- for (int i = 0; i < n_processors - 1; ++i) {
4345
- fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
4346
- }
4347
- fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
4348
-
4349
- return ret;
4350
- }
4351
-
4352
- int whisper_full_n_segments(struct whisper_context * ctx) {
4353
- return ctx->result_all.size();
4354
- }
4355
-
4356
- int whisper_full_lang_id(struct whisper_context * ctx) {
4357
- return ctx->lang_id;
4358
- }
4359
-
4360
- int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
4361
- return ctx->result_all[i_segment].t0;
4362
- }
4363
-
4364
- int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
4365
- return ctx->result_all[i_segment].t1;
4366
- }
4367
-
4368
- const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
4369
- return ctx->result_all[i_segment].text.c_str();
4370
- }
4371
-
4372
- int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
4373
- return ctx->result_all[i_segment].tokens.size();
4374
- }
4375
-
4376
- const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
4377
- return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
4378
- }
4379
-
4380
- whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
4381
- return ctx->result_all[i_segment].tokens[i_token].id;
4382
- }
4383
-
4384
- struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
4385
- return ctx->result_all[i_segment].tokens[i_token];
4386
- }
4387
-
4388
- float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
4389
- return ctx->result_all[i_segment].tokens[i_token].p;
4390
- }
4391
-
4392
- // =================================================================================================
4393
-
4394
- //
4395
- // Temporary interface needed for exposing ggml interface
4396
- // Will be removed in the future when ggml becomes a separate library
4397
- //
4398
-
4399
- WHISPER_API int whisper_bench_memcpy(int n_threads) {
4400
- ggml_time_init();
4401
-
4402
- size_t n = 50;
4403
- size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
4404
-
4405
- // 1 GB array
4406
- const size_t size = arr*1024llu*1024llu;
4407
-
4408
- char * src = (char *) malloc(size);
4409
- char * dst = (char *) malloc(size);
4410
-
4411
- for (size_t i = 0; i < size; i++) src[i] = i;
4412
-
4413
- memcpy(dst, src, size); // heat-up
4414
-
4415
- double tsum = 0.0;
4416
-
4417
- for (size_t i = 0; i < n; i++) {
4418
- const int64_t t0 = ggml_time_us();
4419
-
4420
- memcpy(dst, src, size);
4421
-
4422
- const int64_t t1 = ggml_time_us();
4423
-
4424
- tsum += (t1 - t0)*1e-6;
4425
-
4426
- src[0] = rand();
4427
- }
4428
-
4429
- fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4430
-
4431
- // needed to prevent the compile from optimizing the memcpy away
4432
- {
4433
- double sum = 0.0;
4434
-
4435
- for (size_t i = 0; i < size; i++) sum += dst[i];
4436
-
4437
- fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
4438
- }
4439
-
4440
- free(src);
4441
- free(dst);
4442
-
4443
- return 0;
4444
- }
4445
-
4446
- WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
4447
- ggml_time_init();
4448
-
4449
- const int n_max = 128;
4450
-
4451
- const std::vector<size_t> sizes = {
4452
- 64, 128, 256, 512, 1024, 2048, 4096,
4453
- };
4454
-
4455
- const size_t N_max = sizes.back();
4456
-
4457
- // a: N*N*sizeof(float)
4458
- // b: N*N*sizeof(float)
4459
- // c: N*N*sizeof(float)
4460
- // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
4461
- std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
4462
-
4463
- for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
4464
-
4465
- for (int j = 0; j < (int) sizes.size(); j++) {
4466
- int n_fp16 = 0;
4467
- int n_fp32 = 0;
4468
-
4469
- // GFLOPS/s
4470
- double s_fp16 = 0.0;
4471
- double s_fp32 = 0.0;
4472
-
4473
- const size_t N = sizes[j];
4474
-
4475
- for (int k = 0; k < 2; ++k) {
4476
- const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
4477
-
4478
- double & s = k == 0 ? s_fp16 : s_fp32;
4479
- int & n = k == 0 ? n_fp16 : n_fp32;
4480
-
4481
- struct ggml_init_params gparams = {
4482
- /*.mem_size =*/ buf.size(),
4483
- /*.mem_buffer =*/ buf.data(),
4484
- };
4485
-
4486
- struct ggml_context * ctx0 = ggml_init(gparams);
4487
-
4488
- struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N);
4489
- struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
4490
-
4491
- struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
4492
-
4493
- struct ggml_cgraph gf = ggml_build_forward(c);
4494
-
4495
- gf.n_threads = n_threads;
4496
-
4497
- double tsum = 0.0;
4498
-
4499
- // heat-up
4500
- ggml_graph_compute(ctx0, &gf);
4501
-
4502
- for (int i = 0; i < n_max; ++i) {
4503
- const int64_t t0 = ggml_time_us();
4504
-
4505
- ggml_graph_compute(ctx0, &gf);
4506
-
4507
- const int64_t t1 = ggml_time_us();
4508
-
4509
- tsum += (t1 - t0)*1e-6;
4510
- n++;
4511
-
4512
- if (tsum > 1.0 && n >= 3) {
4513
- break;
4514
- }
4515
- }
4516
-
4517
- ggml_free(ctx0);
4518
-
4519
- s = ((2.0*N*N*N*n)/tsum)*1e-9;
4520
- }
4521
-
4522
- fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
4523
- N, N, s_fp16, n_fp16, s_fp32, n_fp32);
4524
- }
4525
-
4526
- return 0;
4527
- }
4528
-
4529
- // =================================================================================================
4530
-
4531
- // =================================================================================================
4532
-
4533
- //
4534
- // Experimental stuff below
4535
- //
4536
- // Not sure if these should be part of the library at all, because the quality of the results is not
4537
- // guaranteed. Might get removed at some point unless a robust algorithm implementation is found
4538
- //
4539
-
4540
- // =================================================================================================
4541
-
4542
- //
4543
- // token-level timestamps
4544
- //
4545
-
4546
- static int timestamp_to_sample(int64_t t, int n_samples) {
4547
- return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
4548
- }
4549
-
4550
- static int64_t sample_to_timestamp(int i_sample) {
4551
- return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
4552
- }
4553
-
4554
- // a cost-function / heuristic that is high for text that takes longer to pronounce
4555
- // obviously, can be improved
4556
- static float voice_length(const std::string & text) {
4557
- float res = 0.0f;
4558
-
4559
- for (char c : text) {
4560
- if (c == ' ') {
4561
- res += 0.01f;
4562
- } else if (c == ',') {
4563
- res += 2.00f;
4564
- } else if (c == '.') {
4565
- res += 3.00f;
4566
- } else if (c == '!') {
4567
- res += 3.00f;
4568
- } else if (c == '?') {
4569
- res += 3.00f;
4570
- } else if (c >= '0' && c <= '9') {
4571
- res += 3.00f;
4572
- } else {
4573
- res += 1.00f;
4574
- }
4575
- }
4576
-
4577
- return res;
4578
- }
4579
-
4580
- // average the fabs of the signal
4581
- static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
4582
- const int hw = n_samples_per_half_window;
4583
-
4584
- std::vector<float> result(n_samples);
4585
-
4586
- for (int i = 0; i < n_samples; i++) {
4587
- float sum = 0;
4588
- for (int j = -hw; j <= hw; j++) {
4589
- if (i + j >= 0 && i + j < n_samples) {
4590
- sum += fabs(signal[i + j]);
4591
- }
4592
- }
4593
- result[i] = sum/(2*hw + 1);
4594
- }
4595
-
4596
- return result;
4597
- }
4598
-
4599
- static void whisper_exp_compute_token_level_timestamps(
4600
- struct whisper_context & ctx,
4601
- int i_segment,
4602
- float thold_pt,
4603
- float thold_ptsum) {
4604
- auto & segment = ctx.result_all[i_segment];
4605
- auto & tokens = segment.tokens;
4606
-
4607
- const int n_samples = ctx.energy.size();
4608
-
4609
- if (n_samples == 0) {
4610
- fprintf(stderr, "%s: no signal data available\n", __func__);
4611
- return;
4612
- }
4613
-
4614
- const int64_t t0 = segment.t0;
4615
- const int64_t t1 = segment.t1;
4616
-
4617
- const int n = tokens.size();
4618
-
4619
- if (n == 0) {
4620
- return;
4621
- }
4622
-
4623
- if (n == 1) {
4624
- tokens[0].t0 = t0;
4625
- tokens[0].t1 = t1;
4626
-
4627
- return;
4628
- }
4629
-
4630
- auto & t_beg = ctx.t_beg;
4631
- auto & t_last = ctx.t_last;
4632
- auto & tid_last = ctx.tid_last;
4633
-
4634
- for (int j = 0; j < n; ++j) {
4635
- auto & token = tokens[j];
4636
-
4637
- if (j == 0) {
4638
- if (token.id == whisper_token_beg(&ctx)) {
4639
- tokens[j ].t0 = t0;
4640
- tokens[j ].t1 = t0;
4641
- tokens[j + 1].t0 = t0;
4642
-
4643
- t_beg = t0;
4644
- t_last = t0;
4645
- tid_last = whisper_token_beg(&ctx);
4646
- } else {
4647
- tokens[j ].t0 = t_last;
4648
- }
4649
- }
4650
-
4651
- const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
4652
-
4653
- tokens[j].id = token.id;
4654
- tokens[j].tid = token.tid;
4655
- tokens[j].p = token.p;
4656
- tokens[j].pt = token.pt;
4657
- tokens[j].ptsum = token.ptsum;
4658
-
4659
- tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
4660
-
4661
- if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
4662
- if (j > 0) {
4663
- tokens[j - 1].t1 = tt;
4664
- }
4665
- tokens[j].t0 = tt;
4666
- tid_last = token.tid;
4667
- }
4668
- }
4669
-
4670
- tokens[n - 2].t1 = t1;
4671
- tokens[n - 1].t0 = t1;
4672
- tokens[n - 1].t1 = t1;
4673
-
4674
- t_last = t1;
4675
-
4676
- // find intervals of tokens with unknown timestamps
4677
- // fill the timestamps by proportionally splitting the interval based on the token voice lengths
4678
- {
4679
- int p0 = 0;
4680
- int p1 = 0;
4681
-
4682
- while (true) {
4683
- while (p1 < n && tokens[p1].t1 < 0) {
4684
- p1++;
4685
- }
4686
-
4687
- if (p1 >= n) {
4688
- p1--;
4689
- }
4690
-
4691
- //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
4692
-
4693
- if (p1 > p0) {
4694
- double psum = 0.0;
4695
- for (int j = p0; j <= p1; j++) {
4696
- psum += tokens[j].vlen;
4697
- }
4698
-
4699
- //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
4700
-
4701
- const double dt = tokens[p1].t1 - tokens[p0].t0;
4702
-
4703
- // split the time proportionally to the voice length
4704
- for (int j = p0 + 1; j <= p1; j++) {
4705
- const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
4706
-
4707
- tokens[j - 1].t1 = ct;
4708
- tokens[j ].t0 = ct;
4709
- }
4710
- }
4711
-
4712
- p1++;
4713
- p0 = p1;
4714
- if (p1 >= n) {
4715
- break;
4716
- }
4717
- }
4718
- }
4719
-
4720
- // fix up (just in case)
4721
- for (int j = 0; j < n - 1; j++) {
4722
- if (tokens[j].t1 < 0) {
4723
- tokens[j + 1].t0 = tokens[j].t1;
4724
- }
4725
-
4726
- if (j > 0) {
4727
- if (tokens[j - 1].t1 > tokens[j].t0) {
4728
- tokens[j].t0 = tokens[j - 1].t1;
4729
- tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
4730
- }
4731
- }
4732
- }
4733
-
4734
- // VAD
4735
- // expand or contract tokens based on voice activity
4736
- {
4737
- const int hw = WHISPER_SAMPLE_RATE/8;
4738
-
4739
- for (int j = 0; j < n; j++) {
4740
- if (tokens[j].id >= whisper_token_eot(&ctx)) {
4741
- continue;
4742
- }
4743
-
4744
- int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
4745
- int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
4746
-
4747
- const int ss0 = std::max(s0 - hw, 0);
4748
- const int ss1 = std::min(s1 + hw, n_samples);
4749
-
4750
- const int ns = ss1 - ss0;
4751
-
4752
- float sum = 0.0f;
4753
-
4754
- for (int k = ss0; k < ss1; k++) {
4755
- sum += ctx.energy[k];
4756
- }
4757
-
4758
- const float thold = 0.5*sum/ns;
4759
-
4760
- {
4761
- int k = s0;
4762
- if (ctx.energy[k] > thold && j > 0) {
4763
- while (k > 0 && ctx.energy[k] > thold) {
4764
- k--;
4765
- }
4766
- tokens[j].t0 = sample_to_timestamp(k);
4767
- if (tokens[j].t0 < tokens[j - 1].t1) {
4768
- tokens[j].t0 = tokens[j - 1].t1;
4769
- } else {
4770
- s0 = k;
4771
- }
4772
- } else {
4773
- while (ctx.energy[k] < thold && k < s1) {
4774
- k++;
4775
- }
4776
- s0 = k;
4777
- tokens[j].t0 = sample_to_timestamp(k);
4778
- }
4779
- }
4780
-
4781
- {
4782
- int k = s1;
4783
- if (ctx.energy[k] > thold) {
4784
- while (k < n_samples - 1 && ctx.energy[k] > thold) {
4785
- k++;
4786
- }
4787
- tokens[j].t1 = sample_to_timestamp(k);
4788
- if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
4789
- tokens[j].t1 = tokens[j + 1].t0;
4790
- } else {
4791
- s1 = k;
4792
- }
4793
- } else {
4794
- while (ctx.energy[k] < thold && k > s0) {
4795
- k--;
4796
- }
4797
- s1 = k;
4798
- tokens[j].t1 = sample_to_timestamp(k);
4799
- }
4800
- }
4801
- }
4802
- }
4803
-
4804
- // fixed token expand (optional)
4805
- //{
4806
- // const int t_expand = 0;
4807
-
4808
- // for (int j = 0; j < n; j++) {
4809
- // if (j > 0) {
4810
- // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
4811
- // }
4812
- // if (j < n - 1) {
4813
- // tokens[j].t1 = tokens[j].t1 + t_expand;
4814
- // }
4815
- // }
4816
- //}
4817
-
4818
- // debug info
4819
- //for (int j = 0; j < n; ++j) {
4820
- // const auto & token = tokens[j];
4821
- // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
4822
- // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
4823
- // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
4824
-
4825
- // if (tokens[j].id >= whisper_token_eot(&ctx)) {
4826
- // continue;
4827
- // }
4828
- //}
4829
- }