whisper.rn 0.2.4 → 0.3.0-rc.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +34 -5
- package/android/src/main/java/com/rnwhisper/RNWhisperModule.java +7 -2
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +7 -6
- package/android/src/main/jni/whisper/jni.cpp +54 -7
- package/cpp/ggml.c +6339 -1662
- package/cpp/ggml.h +741 -554
- package/cpp/rn-whisper.cpp +0 -23
- package/cpp/rn-whisper.h +0 -6
- package/cpp/whisper.cpp +928 -625
- package/cpp/whisper.h +26 -2
- package/ios/RNWhisper.mm +19 -1
- package/ios/RNWhisperContext.mm +8 -10
- package/lib/commonjs/index.js +12 -2
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +9 -2
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/index.d.ts +7 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/{index.tsx → index.ts} +10 -4
- package/whisper-rn.podspec +9 -3
package/cpp/whisper.cpp
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
|
-
#define WHISPER_BUILD
|
|
2
1
|
#include "whisper.h"
|
|
2
|
+
#if WHISPER_USE_COREML
|
|
3
|
+
#include "coreml/whisper-encoder.h"
|
|
4
|
+
#endif
|
|
3
5
|
|
|
4
6
|
#include "ggml.h"
|
|
5
7
|
|
|
@@ -99,7 +101,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|
|
99
101
|
#define WHISPER_PRINT_DEBUG(...)
|
|
100
102
|
#endif
|
|
101
103
|
|
|
102
|
-
|
|
104
|
+
//#define WHISPER_USE_FLASH_ATTN
|
|
103
105
|
//#define WHISPER_USE_FLASH_FF
|
|
104
106
|
#define WHISPER_MAX_DECODERS 16
|
|
105
107
|
|
|
@@ -218,14 +220,14 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|
|
218
220
|
{ "su", { 98, "sundanese", } },
|
|
219
221
|
};
|
|
220
222
|
|
|
221
|
-
static const size_t MB = 1024*1024;
|
|
223
|
+
static const size_t MB = 1ull*1024*1024;
|
|
222
224
|
|
|
223
225
|
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
|
|
224
|
-
{ MODEL_TINY,
|
|
225
|
-
{ MODEL_BASE,
|
|
226
|
-
{ MODEL_SMALL,
|
|
227
|
-
{ MODEL_MEDIUM,
|
|
228
|
-
{ MODEL_LARGE,
|
|
226
|
+
{ MODEL_TINY, 62ull*MB },
|
|
227
|
+
{ MODEL_BASE, 80ull*MB },
|
|
228
|
+
{ MODEL_SMALL, 120ull*MB },
|
|
229
|
+
{ MODEL_MEDIUM, 158ull*MB },
|
|
230
|
+
{ MODEL_LARGE, 198ull*MB },
|
|
229
231
|
};
|
|
230
232
|
|
|
231
233
|
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
|
@@ -252,12 +254,79 @@ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
|
|
252
254
|
{ MODEL_LARGE, 9ull*MB },
|
|
253
255
|
};
|
|
254
256
|
|
|
255
|
-
static const std::map<e_model, size_t
|
|
256
|
-
{
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
257
|
+
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
258
|
+
{ GGML_TYPE_F32,
|
|
259
|
+
{
|
|
260
|
+
{ MODEL_TINY, 74ull*MB },
|
|
261
|
+
{ MODEL_BASE, 142ull*MB },
|
|
262
|
+
{ MODEL_SMALL, 466ull*MB },
|
|
263
|
+
{ MODEL_MEDIUM, 1464ull*MB },
|
|
264
|
+
{ MODEL_LARGE, 2952ull*MB },
|
|
265
|
+
},
|
|
266
|
+
},
|
|
267
|
+
{ GGML_TYPE_F16,
|
|
268
|
+
{
|
|
269
|
+
{ MODEL_TINY, 74ull*MB },
|
|
270
|
+
{ MODEL_BASE, 142ull*MB },
|
|
271
|
+
{ MODEL_SMALL, 466ull*MB },
|
|
272
|
+
{ MODEL_MEDIUM, 1464ull*MB },
|
|
273
|
+
{ MODEL_LARGE, 2952ull*MB },
|
|
274
|
+
},
|
|
275
|
+
},
|
|
276
|
+
{ GGML_TYPE_Q4_0,
|
|
277
|
+
{
|
|
278
|
+
{ MODEL_TINY, 26ull*MB },
|
|
279
|
+
{ MODEL_BASE, 50ull*MB },
|
|
280
|
+
{ MODEL_SMALL, 154ull*MB },
|
|
281
|
+
{ MODEL_MEDIUM, 470ull*MB },
|
|
282
|
+
{ MODEL_LARGE, 940ull*MB },
|
|
283
|
+
},
|
|
284
|
+
},
|
|
285
|
+
{ GGML_TYPE_Q4_1,
|
|
286
|
+
{
|
|
287
|
+
{ MODEL_TINY, 32ull*MB },
|
|
288
|
+
{ MODEL_BASE, 58ull*MB },
|
|
289
|
+
{ MODEL_SMALL, 182ull*MB },
|
|
290
|
+
{ MODEL_MEDIUM, 562ull*MB },
|
|
291
|
+
{ MODEL_LARGE, 1124ull*MB },
|
|
292
|
+
},
|
|
293
|
+
},
|
|
294
|
+
{ GGML_TYPE_Q4_2,
|
|
295
|
+
{
|
|
296
|
+
{ MODEL_TINY, 26ull*MB },
|
|
297
|
+
{ MODEL_BASE, 50ull*MB },
|
|
298
|
+
{ MODEL_SMALL, 154ull*MB },
|
|
299
|
+
{ MODEL_MEDIUM, 470ull*MB },
|
|
300
|
+
{ MODEL_LARGE, 940ull*MB },
|
|
301
|
+
},
|
|
302
|
+
},
|
|
303
|
+
{ GGML_TYPE_Q5_0,
|
|
304
|
+
{
|
|
305
|
+
{ MODEL_TINY, 30ull*MB },
|
|
306
|
+
{ MODEL_BASE, 54ull*MB },
|
|
307
|
+
{ MODEL_SMALL, 170ull*MB },
|
|
308
|
+
{ MODEL_MEDIUM, 516ull*MB },
|
|
309
|
+
{ MODEL_LARGE, 1034ull*MB },
|
|
310
|
+
},
|
|
311
|
+
},
|
|
312
|
+
{ GGML_TYPE_Q5_1,
|
|
313
|
+
{
|
|
314
|
+
{ MODEL_TINY, 32ull*MB },
|
|
315
|
+
{ MODEL_BASE, 58ull*MB },
|
|
316
|
+
{ MODEL_SMALL, 182ull*MB },
|
|
317
|
+
{ MODEL_MEDIUM, 562ull*MB },
|
|
318
|
+
{ MODEL_LARGE, 1124ull*MB },
|
|
319
|
+
},
|
|
320
|
+
},
|
|
321
|
+
{ GGML_TYPE_Q8_0,
|
|
322
|
+
{
|
|
323
|
+
{ MODEL_TINY, 45ull*MB },
|
|
324
|
+
{ MODEL_BASE, 84ull*MB },
|
|
325
|
+
{ MODEL_SMALL, 268ull*MB },
|
|
326
|
+
{ MODEL_MEDIUM, 834ull*MB },
|
|
327
|
+
{ MODEL_LARGE, 1674ull*MB },
|
|
328
|
+
},
|
|
329
|
+
},
|
|
261
330
|
};
|
|
262
331
|
|
|
263
332
|
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
|
@@ -277,11 +346,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
|
|
277
346
|
};
|
|
278
347
|
|
|
279
348
|
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
280
|
-
{ MODEL_TINY,
|
|
281
|
-
{ MODEL_BASE,
|
|
282
|
-
{ MODEL_SMALL,
|
|
283
|
-
{ MODEL_MEDIUM,
|
|
284
|
-
{ MODEL_LARGE,
|
|
349
|
+
{ MODEL_TINY, 30ull*MB },
|
|
350
|
+
{ MODEL_BASE, 38ull*MB },
|
|
351
|
+
{ MODEL_SMALL, 56ull*MB },
|
|
352
|
+
{ MODEL_MEDIUM, 74ull*MB },
|
|
353
|
+
{ MODEL_LARGE, 94ull*MB },
|
|
285
354
|
};
|
|
286
355
|
|
|
287
356
|
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
@@ -294,6 +363,7 @@ static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
|
294
363
|
|
|
295
364
|
struct whisper_mel {
|
|
296
365
|
int n_len;
|
|
366
|
+
int n_len_org;
|
|
297
367
|
int n_mel;
|
|
298
368
|
|
|
299
369
|
std::vector<float> data;
|
|
@@ -366,7 +436,7 @@ struct whisper_hparams {
|
|
|
366
436
|
int32_t n_text_head = 6;
|
|
367
437
|
int32_t n_text_layer = 4;
|
|
368
438
|
int32_t n_mels = 80;
|
|
369
|
-
int32_t
|
|
439
|
+
int32_t ftype = 1;
|
|
370
440
|
};
|
|
371
441
|
|
|
372
442
|
// audio encoding layer
|
|
@@ -586,6 +656,11 @@ struct whisper_state {
|
|
|
586
656
|
|
|
587
657
|
int lang_id = 0; // english by default
|
|
588
658
|
|
|
659
|
+
std::string path_model; // populated by whisper_init_from_file()
|
|
660
|
+
#ifdef WHISPER_USE_COREML
|
|
661
|
+
whisper_coreml_context * ctx_coreml = nullptr;
|
|
662
|
+
#endif
|
|
663
|
+
|
|
589
664
|
// [EXPERIMENTAL] token-level timestamps data
|
|
590
665
|
int64_t t_beg = 0;
|
|
591
666
|
int64_t t_last = 0;
|
|
@@ -628,15 +703,17 @@ struct whisper_state {
|
|
|
628
703
|
};
|
|
629
704
|
|
|
630
705
|
struct whisper_context {
|
|
631
|
-
int64_t t_load_us
|
|
706
|
+
int64_t t_load_us = 0;
|
|
632
707
|
int64_t t_start_us = 0;
|
|
633
708
|
|
|
634
|
-
|
|
635
|
-
ggml_type
|
|
709
|
+
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
|
|
710
|
+
ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
|
|
636
711
|
|
|
637
712
|
whisper_model model;
|
|
638
713
|
whisper_vocab vocab;
|
|
639
714
|
whisper_state * state = nullptr;
|
|
715
|
+
|
|
716
|
+
std::string path_model; // populated by whisper_init_from_file()
|
|
640
717
|
};
|
|
641
718
|
|
|
642
719
|
template<typename T>
|
|
@@ -653,9 +730,11 @@ static bool kv_cache_init(
|
|
|
653
730
|
int n_ctx) {
|
|
654
731
|
cache.buf.resize(mem_bytes);
|
|
655
732
|
|
|
656
|
-
struct ggml_init_params params
|
|
657
|
-
|
|
658
|
-
|
|
733
|
+
struct ggml_init_params params = {
|
|
734
|
+
/*.mem_size =*/ cache.buf.size(),
|
|
735
|
+
/*.mem_buffer =*/ cache.buf.data(),
|
|
736
|
+
/*.no_alloc =*/ false,
|
|
737
|
+
};
|
|
659
738
|
|
|
660
739
|
cache.ctx = ggml_init(params);
|
|
661
740
|
|
|
@@ -685,11 +764,13 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
|
|
|
685
764
|
const ggml_type wtype = cache.k->type;
|
|
686
765
|
WHISPER_ASSERT(wtype == cache.v->type);
|
|
687
766
|
|
|
688
|
-
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*
|
|
767
|
+
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
|
|
689
768
|
|
|
690
|
-
struct ggml_init_params params
|
|
691
|
-
|
|
692
|
-
|
|
769
|
+
struct ggml_init_params params = {
|
|
770
|
+
/*.mem_size =*/ cache.buf.size(),
|
|
771
|
+
/*.mem_buffer =*/ cache.buf.data(),
|
|
772
|
+
/*.no_alloc =*/ false,
|
|
773
|
+
};
|
|
693
774
|
|
|
694
775
|
cache.ctx = ggml_init(params);
|
|
695
776
|
|
|
@@ -756,7 +837,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
756
837
|
read_safe(loader, hparams.n_text_head);
|
|
757
838
|
read_safe(loader, hparams.n_text_layer);
|
|
758
839
|
read_safe(loader, hparams.n_mels);
|
|
759
|
-
read_safe(loader, hparams.
|
|
840
|
+
read_safe(loader, hparams.ftype);
|
|
760
841
|
|
|
761
842
|
assert(hparams.n_text_state == hparams.n_audio_state);
|
|
762
843
|
|
|
@@ -780,11 +861,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
780
861
|
model.type = e_model::MODEL_LARGE;
|
|
781
862
|
}
|
|
782
863
|
|
|
783
|
-
// for the big tensors, we have the option to store the data in 16-bit floats
|
|
864
|
+
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
|
784
865
|
// in order to save memory and also to speed up the computation
|
|
785
|
-
wctx.wtype = model.hparams.
|
|
866
|
+
wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
|
|
867
|
+
if (wctx.wtype == GGML_TYPE_COUNT) {
|
|
868
|
+
fprintf(stderr, "%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
|
|
869
|
+
return false;
|
|
870
|
+
}
|
|
786
871
|
|
|
787
|
-
const size_t scale = model.hparams.
|
|
872
|
+
const size_t scale = model.hparams.ftype ? 1 : 2;
|
|
788
873
|
|
|
789
874
|
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
|
790
875
|
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
|
@@ -796,18 +881,18 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
796
881
|
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
|
797
882
|
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
|
798
883
|
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
|
799
|
-
fprintf(stderr, "%s:
|
|
884
|
+
fprintf(stderr, "%s: ftype = %d\n", __func__, model.hparams.ftype);
|
|
800
885
|
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
|
|
801
886
|
|
|
802
887
|
// print memory requirements
|
|
803
888
|
{
|
|
804
889
|
// this is the total memory required to run the inference
|
|
805
890
|
const size_t mem_required =
|
|
806
|
-
MEM_REQ_SCRATCH0.at
|
|
807
|
-
MEM_REQ_SCRATCH1.at
|
|
808
|
-
MEM_REQ_SCRATCH2.at
|
|
809
|
-
MEM_REQ_SCRATCH3.at
|
|
810
|
-
scale*MEM_REQ_MODEL.at
|
|
891
|
+
MEM_REQ_SCRATCH0.at(model.type) +
|
|
892
|
+
MEM_REQ_SCRATCH1.at(model.type) +
|
|
893
|
+
MEM_REQ_SCRATCH2.at(model.type) +
|
|
894
|
+
MEM_REQ_SCRATCH3.at(model.type) +
|
|
895
|
+
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
|
811
896
|
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
|
812
897
|
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
|
813
898
|
|
|
@@ -823,7 +908,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
823
908
|
// always have at least one decoder
|
|
824
909
|
|
|
825
910
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
826
|
-
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
911
|
+
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
|
|
827
912
|
|
|
828
913
|
// we skip initialization of the state until it is needed
|
|
829
914
|
// because it might be that state will always be provided externally.
|
|
@@ -914,6 +999,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
914
999
|
size_t ctx_size = 0;
|
|
915
1000
|
|
|
916
1001
|
const ggml_type wtype = wctx.wtype;
|
|
1002
|
+
const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
|
|
917
1003
|
|
|
918
1004
|
{
|
|
919
1005
|
const auto & hparams = model.hparams;
|
|
@@ -932,92 +1018,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
932
1018
|
|
|
933
1019
|
// encoder
|
|
934
1020
|
{
|
|
935
|
-
ctx_size += n_audio_ctx*n_audio_state*
|
|
1021
|
+
ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
|
|
936
1022
|
|
|
937
|
-
ctx_size += 3*n_mels*n_audio_state*
|
|
938
|
-
ctx_size += n_audio_state*
|
|
1023
|
+
ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w
|
|
1024
|
+
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
|
|
939
1025
|
|
|
940
|
-
ctx_size += 3*n_audio_state*n_audio_state*
|
|
941
|
-
ctx_size += n_audio_state*
|
|
1026
|
+
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w
|
|
1027
|
+
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
|
|
942
1028
|
|
|
943
|
-
ctx_size += n_audio_state*
|
|
944
|
-
ctx_size += n_audio_state*
|
|
1029
|
+
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
|
|
1030
|
+
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
|
|
945
1031
|
}
|
|
946
1032
|
|
|
947
1033
|
// decoder
|
|
948
1034
|
{
|
|
949
|
-
ctx_size += n_text_ctx*n_text_state*
|
|
1035
|
+
ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
|
|
950
1036
|
|
|
951
|
-
ctx_size += n_vocab*n_text_state*
|
|
1037
|
+
ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
|
|
952
1038
|
|
|
953
|
-
ctx_size += n_text_state*
|
|
954
|
-
ctx_size += n_text_state*
|
|
1039
|
+
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
|
|
1040
|
+
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
|
|
955
1041
|
}
|
|
956
1042
|
|
|
957
1043
|
// encoder layers
|
|
958
1044
|
{
|
|
959
|
-
ctx_size += n_audio_layer*(n_audio_state*
|
|
960
|
-
ctx_size += n_audio_layer*(n_audio_state*
|
|
1045
|
+
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
|
|
1046
|
+
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
|
|
961
1047
|
|
|
962
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*
|
|
963
|
-
ctx_size += n_audio_layer*( 4*n_audio_state*
|
|
1048
|
+
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
|
|
1049
|
+
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
|
|
964
1050
|
|
|
965
|
-
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*
|
|
966
|
-
ctx_size += n_audio_layer*( n_audio_state*
|
|
1051
|
+
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
|
|
1052
|
+
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
|
|
967
1053
|
|
|
968
|
-
ctx_size += n_audio_layer*(n_audio_state*
|
|
969
|
-
ctx_size += n_audio_layer*(n_audio_state*
|
|
1054
|
+
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
|
|
1055
|
+
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
|
|
970
1056
|
|
|
971
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*
|
|
972
|
-
ctx_size += n_audio_layer*( n_audio_state*
|
|
1057
|
+
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
|
|
1058
|
+
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
|
|
973
1059
|
|
|
974
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*
|
|
1060
|
+
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
|
|
975
1061
|
|
|
976
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*
|
|
977
|
-
ctx_size += n_audio_layer*( n_audio_state*
|
|
1062
|
+
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
|
|
1063
|
+
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
|
|
978
1064
|
|
|
979
|
-
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*
|
|
980
|
-
ctx_size += n_audio_layer*( n_audio_state*
|
|
1065
|
+
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1066
|
+
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
|
|
981
1067
|
}
|
|
982
1068
|
|
|
983
1069
|
// decoder layers
|
|
984
1070
|
{
|
|
985
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
986
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
1071
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
|
|
1072
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
|
|
987
1073
|
|
|
988
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*
|
|
989
|
-
ctx_size += n_text_layer*( 4*n_text_state*
|
|
1074
|
+
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
|
|
1075
|
+
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
|
|
990
1076
|
|
|
991
|
-
ctx_size += n_text_layer*(4*n_text_state*n_text_state*
|
|
992
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1077
|
+
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
|
|
1078
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
|
|
993
1079
|
|
|
994
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
995
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
1080
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
|
|
1081
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
|
|
996
1082
|
|
|
997
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
998
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1083
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
|
|
1084
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
|
|
999
1085
|
|
|
1000
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1086
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
|
|
1001
1087
|
|
|
1002
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1003
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1088
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
|
|
1089
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
|
|
1004
1090
|
|
|
1005
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1006
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1091
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
1092
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
|
|
1007
1093
|
//
|
|
1008
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
1009
|
-
ctx_size += n_text_layer*(n_text_state*
|
|
1094
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
|
|
1095
|
+
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
|
|
1010
1096
|
|
|
1011
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1012
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1097
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
|
|
1098
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
|
|
1013
1099
|
|
|
1014
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1100
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
|
|
1015
1101
|
|
|
1016
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1017
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1102
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
|
|
1103
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
|
|
1018
1104
|
|
|
1019
|
-
ctx_size += n_text_layer*(n_text_state*n_text_state*
|
|
1020
|
-
ctx_size += n_text_layer*( n_text_state*
|
|
1105
|
+
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
|
|
1106
|
+
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
1021
1107
|
}
|
|
1022
1108
|
|
|
1023
1109
|
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
|
@@ -1027,9 +1113,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1027
1113
|
|
|
1028
1114
|
// create the ggml context
|
|
1029
1115
|
{
|
|
1030
|
-
struct ggml_init_params params
|
|
1031
|
-
|
|
1032
|
-
|
|
1116
|
+
struct ggml_init_params params = {
|
|
1117
|
+
/*.mem_size =*/ wctx.model.buf->size(),
|
|
1118
|
+
/*.mem_buffer =*/ wctx.model.buf->data(),
|
|
1119
|
+
/*.no_alloc =*/ false,
|
|
1120
|
+
};
|
|
1033
1121
|
|
|
1034
1122
|
model.ctx = ggml_init(params);
|
|
1035
1123
|
if (!model.ctx) {
|
|
@@ -1061,175 +1149,175 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1061
1149
|
|
|
1062
1150
|
// encoder
|
|
1063
1151
|
{
|
|
1064
|
-
model.e_pe
|
|
1152
|
+
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
1065
1153
|
|
|
1066
|
-
model.e_conv_1_w = ggml_new_tensor_3d(ctx,
|
|
1154
|
+
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
|
1067
1155
|
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
|
1068
1156
|
|
|
1069
|
-
model.e_conv_2_w = ggml_new_tensor_3d(ctx,
|
|
1157
|
+
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
|
1070
1158
|
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
|
1071
1159
|
|
|
1072
|
-
model.e_ln_w
|
|
1073
|
-
model.e_ln_b
|
|
1160
|
+
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1161
|
+
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1074
1162
|
|
|
1075
1163
|
// map by name
|
|
1076
1164
|
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
|
1077
1165
|
|
|
1078
|
-
model.tensors["encoder.conv1.weight"]
|
|
1079
|
-
model.tensors["encoder.conv1.bias"]
|
|
1166
|
+
model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
|
|
1167
|
+
model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
|
|
1080
1168
|
|
|
1081
|
-
model.tensors["encoder.conv2.weight"]
|
|
1082
|
-
model.tensors["encoder.conv2.bias"]
|
|
1169
|
+
model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
|
|
1170
|
+
model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
|
|
1083
1171
|
|
|
1084
|
-
model.tensors["encoder.ln_post.weight"]
|
|
1085
|
-
model.tensors["encoder.ln_post.bias"]
|
|
1172
|
+
model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
|
|
1173
|
+
model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
|
|
1086
1174
|
|
|
1087
1175
|
for (int i = 0; i < n_audio_layer; ++i) {
|
|
1088
1176
|
auto & layer = model.layers_encoder[i];
|
|
1089
1177
|
|
|
1090
|
-
layer.mlp_ln_w
|
|
1091
|
-
layer.mlp_ln_b
|
|
1178
|
+
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1179
|
+
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1092
1180
|
|
|
1093
|
-
layer.mlp_0_w
|
|
1094
|
-
layer.mlp_0_b
|
|
1181
|
+
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
|
|
1182
|
+
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
|
|
1095
1183
|
|
|
1096
|
-
layer.mlp_1_w
|
|
1097
|
-
layer.mlp_1_b
|
|
1184
|
+
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
|
|
1185
|
+
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1098
1186
|
|
|
1099
|
-
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1100
|
-
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1187
|
+
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1188
|
+
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1101
1189
|
|
|
1102
|
-
layer.attn_q_w
|
|
1103
|
-
layer.attn_q_b
|
|
1190
|
+
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1191
|
+
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1104
1192
|
|
|
1105
|
-
layer.attn_k_w
|
|
1193
|
+
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1106
1194
|
|
|
1107
|
-
layer.attn_v_w
|
|
1108
|
-
layer.attn_v_b
|
|
1195
|
+
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1196
|
+
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1109
1197
|
|
|
1110
|
-
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,
|
|
1111
|
-
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1198
|
+
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1199
|
+
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
1112
1200
|
|
|
1113
1201
|
// map by name
|
|
1114
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"]
|
|
1115
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]
|
|
1202
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
|
1203
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
|
1116
1204
|
|
|
1117
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"]
|
|
1118
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]
|
|
1205
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
|
1206
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
|
1119
1207
|
|
|
1120
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"]
|
|
1121
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]
|
|
1208
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
|
1209
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
|
1122
1210
|
|
|
1123
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"]
|
|
1124
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]
|
|
1211
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
|
1212
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
|
1125
1213
|
|
|
1126
1214
|
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
|
1127
1215
|
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
|
1128
1216
|
|
|
1129
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"]
|
|
1217
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
|
1130
1218
|
|
|
1131
1219
|
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
|
1132
1220
|
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
|
1133
1221
|
|
|
1134
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"]
|
|
1135
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]
|
|
1222
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
|
1223
|
+
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
|
1136
1224
|
}
|
|
1137
1225
|
}
|
|
1138
1226
|
|
|
1139
1227
|
// decoder
|
|
1140
1228
|
{
|
|
1141
|
-
model.d_pe
|
|
1229
|
+
model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
|
|
1142
1230
|
|
|
1143
|
-
model.d_te
|
|
1231
|
+
model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
|
|
1144
1232
|
|
|
1145
1233
|
model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1146
1234
|
model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1147
1235
|
|
|
1148
1236
|
// map by name
|
|
1149
|
-
model.tensors["decoder.positional_embedding"]
|
|
1237
|
+
model.tensors["decoder.positional_embedding"] = model.d_pe;
|
|
1150
1238
|
|
|
1151
1239
|
model.tensors["decoder.token_embedding.weight"] = model.d_te;
|
|
1152
1240
|
|
|
1153
|
-
model.tensors["decoder.ln.weight"]
|
|
1154
|
-
model.tensors["decoder.ln.bias"]
|
|
1241
|
+
model.tensors["decoder.ln.weight"] = model.d_ln_w;
|
|
1242
|
+
model.tensors["decoder.ln.bias"] = model.d_ln_b;
|
|
1155
1243
|
|
|
1156
1244
|
for (int i = 0; i < n_text_layer; ++i) {
|
|
1157
1245
|
auto & layer = model.layers_decoder[i];
|
|
1158
1246
|
|
|
1159
|
-
layer.mlp_ln_w
|
|
1160
|
-
layer.mlp_ln_b
|
|
1247
|
+
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1248
|
+
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1161
1249
|
|
|
1162
|
-
layer.mlp_0_w
|
|
1163
|
-
layer.mlp_0_b
|
|
1250
|
+
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
|
|
1251
|
+
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
|
|
1164
1252
|
|
|
1165
|
-
layer.mlp_1_w
|
|
1166
|
-
layer.mlp_1_b
|
|
1253
|
+
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
|
|
1254
|
+
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1167
1255
|
|
|
1168
|
-
layer.attn_ln_0_w
|
|
1169
|
-
layer.attn_ln_0_b
|
|
1256
|
+
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1257
|
+
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1170
1258
|
|
|
1171
|
-
layer.attn_q_w
|
|
1172
|
-
layer.attn_q_b
|
|
1259
|
+
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1260
|
+
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1173
1261
|
|
|
1174
|
-
layer.attn_k_w
|
|
1262
|
+
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1175
1263
|
|
|
1176
|
-
layer.attn_v_w
|
|
1177
|
-
layer.attn_v_b
|
|
1264
|
+
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1265
|
+
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1178
1266
|
|
|
1179
|
-
layer.attn_ln_1_w
|
|
1180
|
-
layer.attn_ln_1_b
|
|
1267
|
+
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1268
|
+
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1181
1269
|
|
|
1182
|
-
layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1183
|
-
layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1270
|
+
layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1271
|
+
layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1184
1272
|
|
|
1185
|
-
layer.cross_attn_q_w
|
|
1186
|
-
layer.cross_attn_q_b
|
|
1273
|
+
layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1274
|
+
layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1187
1275
|
|
|
1188
|
-
layer.cross_attn_k_w
|
|
1276
|
+
layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1189
1277
|
|
|
1190
|
-
layer.cross_attn_v_w
|
|
1191
|
-
layer.cross_attn_v_b
|
|
1278
|
+
layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1279
|
+
layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1192
1280
|
|
|
1193
|
-
layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,
|
|
1194
|
-
layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,
|
|
1281
|
+
layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
|
1282
|
+
layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
|
1195
1283
|
|
|
1196
1284
|
// map by name
|
|
1197
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"]
|
|
1198
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]
|
|
1285
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
|
1286
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
|
1199
1287
|
|
|
1200
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"]
|
|
1201
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]
|
|
1288
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
|
1289
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
|
1202
1290
|
|
|
1203
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"]
|
|
1204
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]
|
|
1291
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
|
1292
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
|
1205
1293
|
|
|
1206
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"]
|
|
1207
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]
|
|
1294
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
|
1295
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
|
1208
1296
|
|
|
1209
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"]
|
|
1210
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]
|
|
1297
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
|
1298
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
|
1211
1299
|
|
|
1212
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"]
|
|
1300
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
|
1213
1301
|
|
|
1214
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"]
|
|
1215
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]
|
|
1302
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
|
1303
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
|
1216
1304
|
|
|
1217
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"]
|
|
1218
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]
|
|
1305
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
|
1306
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
|
1219
1307
|
|
|
1220
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"]
|
|
1221
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]
|
|
1308
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
|
|
1309
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
|
|
1222
1310
|
|
|
1223
1311
|
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
|
|
1224
1312
|
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
|
|
1225
1313
|
|
|
1226
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"]
|
|
1314
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
|
|
1227
1315
|
|
|
1228
1316
|
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
|
|
1229
1317
|
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
|
|
1230
1318
|
|
|
1231
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"]
|
|
1232
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]
|
|
1319
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
|
|
1320
|
+
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
|
|
1233
1321
|
}
|
|
1234
1322
|
}
|
|
1235
1323
|
}
|
|
@@ -1243,18 +1331,18 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1243
1331
|
while (true) {
|
|
1244
1332
|
int32_t n_dims;
|
|
1245
1333
|
int32_t length;
|
|
1246
|
-
int32_t
|
|
1334
|
+
int32_t ttype;
|
|
1247
1335
|
|
|
1248
1336
|
read_safe(loader, n_dims);
|
|
1249
1337
|
read_safe(loader, length);
|
|
1250
|
-
read_safe(loader,
|
|
1338
|
+
read_safe(loader, ttype);
|
|
1251
1339
|
|
|
1252
1340
|
if (loader->eof(loader->context)) {
|
|
1253
1341
|
break;
|
|
1254
1342
|
}
|
|
1255
1343
|
|
|
1256
1344
|
int32_t nelements = 1;
|
|
1257
|
-
int32_t ne[
|
|
1345
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
|
1258
1346
|
for (int i = 0; i < n_dims; ++i) {
|
|
1259
1347
|
read_safe(loader, ne[i]);
|
|
1260
1348
|
nelements *= ne[i];
|
|
@@ -1273,18 +1361,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1273
1361
|
auto tensor = model.tensors[name.data()];
|
|
1274
1362
|
if (ggml_nelements(tensor) != nelements) {
|
|
1275
1363
|
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
|
1364
|
+
fprintf(stderr, "%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
|
1365
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
|
1276
1366
|
return false;
|
|
1277
1367
|
}
|
|
1278
1368
|
|
|
1279
1369
|
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
|
1280
1370
|
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
|
1281
|
-
__func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
|
|
1371
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
|
1282
1372
|
return false;
|
|
1283
1373
|
}
|
|
1284
1374
|
|
|
1285
|
-
const size_t bpe = (
|
|
1375
|
+
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
|
1286
1376
|
|
|
1287
|
-
if (nelements*bpe != ggml_nbytes(tensor)) {
|
|
1377
|
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
|
1288
1378
|
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
1289
1379
|
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
|
1290
1380
|
return false;
|
|
@@ -1293,7 +1383,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1293
1383
|
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
|
1294
1384
|
BYTESWAP_TENSOR(tensor);
|
|
1295
1385
|
|
|
1296
|
-
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2],
|
|
1386
|
+
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
|
|
1297
1387
|
total_size += ggml_nbytes(tensor);
|
|
1298
1388
|
model.n_loaded++;
|
|
1299
1389
|
}
|
|
@@ -1343,9 +1433,11 @@ static bool whisper_encode_internal(
|
|
|
1343
1433
|
const int n_mels = hparams.n_mels;
|
|
1344
1434
|
assert(mel_inp.n_mel == n_mels);
|
|
1345
1435
|
|
|
1346
|
-
struct ggml_init_params params
|
|
1347
|
-
|
|
1348
|
-
|
|
1436
|
+
struct ggml_init_params params = {
|
|
1437
|
+
/*.mem_size =*/ wstate.buf_compute.size(),
|
|
1438
|
+
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
|
1439
|
+
/*.no_alloc =*/ false,
|
|
1440
|
+
};
|
|
1349
1441
|
|
|
1350
1442
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
1351
1443
|
|
|
@@ -1369,312 +1461,320 @@ static bool whisper_encode_internal(
|
|
|
1369
1461
|
|
|
1370
1462
|
struct ggml_tensor * cur;
|
|
1371
1463
|
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1464
|
+
#ifndef WHISPER_USE_COREML
|
|
1465
|
+
const bool use_coreml = false;
|
|
1466
|
+
#else
|
|
1467
|
+
const bool use_coreml = wstate.ctx_coreml != nullptr;
|
|
1468
|
+
#endif
|
|
1375
1469
|
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
cur),
|
|
1381
|
-
cur);
|
|
1470
|
+
if (!use_coreml) {
|
|
1471
|
+
// convolution + gelu
|
|
1472
|
+
{
|
|
1473
|
+
wstate.use_buf(ctx0, 1);
|
|
1382
1474
|
|
|
1383
|
-
|
|
1475
|
+
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
|
|
1476
|
+
cur = ggml_add(ctx0,
|
|
1477
|
+
ggml_repeat(ctx0,
|
|
1478
|
+
model.e_conv_1_b,
|
|
1479
|
+
cur),
|
|
1480
|
+
cur);
|
|
1384
1481
|
|
|
1385
|
-
|
|
1482
|
+
cur = ggml_gelu(ctx0, cur);
|
|
1386
1483
|
|
|
1387
|
-
|
|
1388
|
-
cur = ggml_add(ctx0,
|
|
1389
|
-
ggml_repeat(ctx0,
|
|
1390
|
-
model.e_conv_2_b,
|
|
1391
|
-
cur),
|
|
1392
|
-
cur);
|
|
1484
|
+
wstate.use_buf(ctx0, 0);
|
|
1393
1485
|
|
|
1394
|
-
|
|
1395
|
-
|
|
1486
|
+
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
|
|
1487
|
+
cur = ggml_add(ctx0,
|
|
1488
|
+
ggml_repeat(ctx0,
|
|
1489
|
+
model.e_conv_2_b,
|
|
1490
|
+
cur),
|
|
1491
|
+
cur);
|
|
1396
1492
|
|
|
1397
|
-
|
|
1493
|
+
cur = ggml_gelu(ctx0, cur);
|
|
1494
|
+
}
|
|
1398
1495
|
|
|
1399
|
-
|
|
1400
|
-
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1401
|
-
//static int iter = -1;
|
|
1402
|
-
//const int n_iter = 1500/n_ctx;
|
|
1496
|
+
wstate.use_buf(ctx0, 3);
|
|
1403
1497
|
|
|
1404
|
-
|
|
1498
|
+
// ===================================================================
|
|
1499
|
+
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
1500
|
+
//static int iter = -1;
|
|
1501
|
+
//const int n_iter = 1500/n_ctx;
|
|
1405
1502
|
|
|
1406
|
-
|
|
1407
|
-
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
|
|
1408
|
-
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
|
|
1409
|
-
//}
|
|
1503
|
+
//iter = (iter + 1) % n_iter;
|
|
1410
1504
|
|
|
1411
|
-
|
|
1412
|
-
|
|
1413
|
-
|
|
1414
|
-
|
|
1505
|
+
//if (iter == 0) {
|
|
1506
|
+
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
|
|
1507
|
+
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
|
|
1508
|
+
//}
|
|
1415
1509
|
|
|
1416
|
-
|
|
1510
|
+
static int iter = 0;
|
|
1417
1511
|
|
|
1418
|
-
|
|
1512
|
+
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
|
|
1513
|
+
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1419
1514
|
|
|
1420
|
-
|
|
1515
|
+
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
|
|
1421
1516
|
|
|
1422
|
-
|
|
1423
|
-
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
|
|
1517
|
+
cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
|
|
1424
1518
|
|
|
1425
|
-
|
|
1519
|
+
// ===================================================================
|
|
1426
1520
|
|
|
1427
|
-
|
|
1428
|
-
|
|
1521
|
+
// original:
|
|
1522
|
+
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
|
|
1429
1523
|
|
|
1430
|
-
|
|
1431
|
-
{
|
|
1432
|
-
wstate.use_buf(ctx0, 0);
|
|
1524
|
+
struct ggml_tensor * inpL = cur;
|
|
1433
1525
|
|
|
1434
|
-
|
|
1526
|
+
for (int il = 0; il < n_layer; ++il) {
|
|
1527
|
+
const auto & layer = model.layers_encoder[il];
|
|
1435
1528
|
|
|
1436
|
-
//
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
|
1440
|
-
cur),
|
|
1441
|
-
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
|
1442
|
-
}
|
|
1529
|
+
// norm
|
|
1530
|
+
{
|
|
1531
|
+
wstate.use_buf(ctx0, 0);
|
|
1443
1532
|
|
|
1444
|
-
|
|
1445
|
-
{
|
|
1446
|
-
wstate.use_buf(ctx0, 1);
|
|
1533
|
+
cur = ggml_norm(ctx0, inpL);
|
|
1447
1534
|
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1535
|
+
// cur = ln_0_w*cur + ln_0_b
|
|
1536
|
+
cur = ggml_add(ctx0,
|
|
1537
|
+
ggml_mul(ctx0,
|
|
1538
|
+
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
|
1539
|
+
cur),
|
|
1540
|
+
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
|
1541
|
+
}
|
|
1451
1542
|
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
Qcur),
|
|
1456
|
-
Qcur);
|
|
1543
|
+
// self-attention
|
|
1544
|
+
{
|
|
1545
|
+
wstate.use_buf(ctx0, 1);
|
|
1457
1546
|
|
|
1458
|
-
|
|
1547
|
+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
1548
|
+
layer.attn_q_w,
|
|
1549
|
+
cur);
|
|
1459
1550
|
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1551
|
+
Qcur = ggml_add(ctx0,
|
|
1552
|
+
ggml_repeat(ctx0,
|
|
1553
|
+
layer.attn_q_b,
|
|
1554
|
+
Qcur),
|
|
1555
|
+
Qcur);
|
|
1464
1556
|
|
|
1465
|
-
|
|
1557
|
+
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1466
1558
|
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1559
|
+
// note: no bias for Key
|
|
1560
|
+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
|
1561
|
+
layer.attn_k_w,
|
|
1562
|
+
cur);
|
|
1470
1563
|
|
|
1471
|
-
|
|
1472
|
-
ggml_repeat(ctx0,
|
|
1473
|
-
layer.attn_v_b,
|
|
1474
|
-
Vcur),
|
|
1475
|
-
Vcur);
|
|
1564
|
+
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1476
1565
|
|
|
1477
|
-
|
|
1566
|
+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
|
1567
|
+
layer.attn_v_w,
|
|
1568
|
+
cur);
|
|
1478
1569
|
|
|
1479
|
-
|
|
1570
|
+
Vcur = ggml_add(ctx0,
|
|
1571
|
+
ggml_repeat(ctx0,
|
|
1572
|
+
layer.attn_v_b,
|
|
1573
|
+
Vcur),
|
|
1574
|
+
Vcur);
|
|
1480
1575
|
|
|
1481
|
-
|
|
1482
|
-
struct ggml_tensor * Q =
|
|
1483
|
-
ggml_permute(ctx0,
|
|
1484
|
-
ggml_cpy(ctx0,
|
|
1485
|
-
Qcur,
|
|
1486
|
-
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
1487
|
-
0, 2, 1, 3);
|
|
1576
|
+
// ------
|
|
1488
1577
|
|
|
1489
|
-
|
|
1490
|
-
ggml_permute(ctx0,
|
|
1491
|
-
ggml_cpy(ctx0,
|
|
1492
|
-
Kcur,
|
|
1493
|
-
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
1494
|
-
0, 2, 1, 3);
|
|
1578
|
+
wstate.use_buf(ctx0, 0);
|
|
1495
1579
|
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
n_state/n_head, n_head, n_ctx),
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1580
|
+
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1581
|
+
struct ggml_tensor * Q =
|
|
1582
|
+
ggml_permute(ctx0,
|
|
1583
|
+
ggml_cpy(ctx0,
|
|
1584
|
+
Qcur,
|
|
1585
|
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1586
|
+
0, 2, 1, 3);
|
|
1587
|
+
|
|
1588
|
+
struct ggml_tensor * K =
|
|
1589
|
+
ggml_permute(ctx0,
|
|
1590
|
+
ggml_cpy(ctx0,
|
|
1591
|
+
Kcur,
|
|
1592
|
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1593
|
+
0, 2, 1, 3);
|
|
1594
|
+
|
|
1595
|
+
struct ggml_tensor * V =
|
|
1596
|
+
ggml_cpy(ctx0,
|
|
1597
|
+
ggml_permute(ctx0,
|
|
1598
|
+
ggml_reshape_3d(ctx0,
|
|
1599
|
+
Vcur,
|
|
1600
|
+
n_state/n_head, n_head, n_ctx),
|
|
1601
|
+
1, 2, 0, 3),
|
|
1602
|
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
|
1603
|
+
|
|
1604
|
+
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
|
|
1507
1605
|
#else
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1606
|
+
struct ggml_tensor * Q =
|
|
1607
|
+
ggml_permute(ctx0,
|
|
1608
|
+
ggml_cpy(ctx0,
|
|
1609
|
+
Qcur,
|
|
1610
|
+
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
|
1611
|
+
0, 2, 1, 3);
|
|
1612
|
+
|
|
1613
|
+
struct ggml_tensor * K =
|
|
1614
|
+
ggml_permute(ctx0,
|
|
1615
|
+
ggml_cpy(ctx0,
|
|
1616
|
+
Kcur,
|
|
1617
|
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
|
1618
|
+
0, 2, 1, 3);
|
|
1619
|
+
|
|
1620
|
+
// K * Q
|
|
1621
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
1622
|
+
|
|
1623
|
+
struct ggml_tensor * KQ_scaled =
|
|
1624
|
+
ggml_scale(ctx0,
|
|
1625
|
+
KQ,
|
|
1626
|
+
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
1627
|
+
);
|
|
1628
|
+
|
|
1629
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
|
|
1630
|
+
|
|
1631
|
+
struct ggml_tensor * V =
|
|
1632
|
+
ggml_cpy(ctx0,
|
|
1633
|
+
ggml_permute(ctx0,
|
|
1634
|
+
ggml_reshape_3d(ctx0,
|
|
1635
|
+
Vcur,
|
|
1636
|
+
n_state/n_head, n_head, n_ctx),
|
|
1637
|
+
1, 2, 0, 3),
|
|
1638
|
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
|
1639
|
+
);
|
|
1640
|
+
|
|
1641
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1642
|
+
#endif
|
|
1643
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1514
1644
|
|
|
1515
|
-
|
|
1516
|
-
ggml_permute(ctx0,
|
|
1517
|
-
ggml_cpy(ctx0,
|
|
1518
|
-
Kcur,
|
|
1519
|
-
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
1520
|
-
0, 2, 1, 3);
|
|
1645
|
+
wstate.use_buf(ctx0, 1);
|
|
1521
1646
|
|
|
1522
|
-
|
|
1523
|
-
|
|
1647
|
+
cur = ggml_cpy(ctx0,
|
|
1648
|
+
KQV_merged,
|
|
1649
|
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
|
1650
|
+
}
|
|
1524
1651
|
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
|
1529
|
-
);
|
|
1652
|
+
// projection
|
|
1653
|
+
{
|
|
1654
|
+
wstate.use_buf(ctx0, 0);
|
|
1530
1655
|
|
|
1531
|
-
|
|
1656
|
+
cur = ggml_mul_mat(ctx0,
|
|
1657
|
+
layer.attn_ln_1_w,
|
|
1658
|
+
cur);
|
|
1532
1659
|
|
|
1533
|
-
|
|
1534
|
-
// ggml_permute(ctx0,
|
|
1535
|
-
// ggml_cpy(ctx0,
|
|
1536
|
-
// Vcur,
|
|
1537
|
-
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
1538
|
-
// 1, 2, 0, 3);
|
|
1660
|
+
wstate.use_buf(ctx0, 1);
|
|
1539
1661
|
|
|
1540
|
-
|
|
1662
|
+
cur = ggml_add(ctx0,
|
|
1663
|
+
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
|
1664
|
+
cur);
|
|
1665
|
+
}
|
|
1541
1666
|
|
|
1542
|
-
|
|
1543
|
-
ggml_cpy(ctx0,
|
|
1544
|
-
ggml_permute(ctx0,
|
|
1545
|
-
ggml_reshape_3d(ctx0,
|
|
1546
|
-
Vcur,
|
|
1547
|
-
n_state/n_head, n_head, n_ctx),
|
|
1548
|
-
0, 2, 1, 3),
|
|
1549
|
-
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
|
|
1550
|
-
);
|
|
1551
|
-
|
|
1552
|
-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
|
|
1553
|
-
#endif
|
|
1554
|
-
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1667
|
+
wstate.use_buf(ctx0, 2);
|
|
1555
1668
|
|
|
1556
|
-
|
|
1669
|
+
// add the input
|
|
1670
|
+
cur = ggml_add(ctx0, cur, inpL);
|
|
1557
1671
|
|
|
1558
|
-
|
|
1559
|
-
KQV_merged,
|
|
1560
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
|
1561
|
-
}
|
|
1672
|
+
struct ggml_tensor * inpFF = cur;
|
|
1562
1673
|
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
layer.attn_ln_1_w,
|
|
1569
|
-
cur);
|
|
1570
|
-
|
|
1571
|
-
wstate.use_buf(ctx0, 1);
|
|
1674
|
+
// feed-forward network
|
|
1675
|
+
{
|
|
1676
|
+
// norm
|
|
1677
|
+
{
|
|
1678
|
+
wstate.use_buf(ctx0, 0);
|
|
1572
1679
|
|
|
1573
|
-
|
|
1574
|
-
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
|
1575
|
-
cur);
|
|
1576
|
-
}
|
|
1680
|
+
cur = ggml_norm(ctx0, inpFF);
|
|
1577
1681
|
|
|
1578
|
-
|
|
1682
|
+
wstate.use_buf(ctx0, 1);
|
|
1579
1683
|
|
|
1580
|
-
|
|
1581
|
-
|
|
1684
|
+
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
1685
|
+
cur = ggml_add(ctx0,
|
|
1686
|
+
ggml_mul(ctx0,
|
|
1687
|
+
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
|
1688
|
+
cur),
|
|
1689
|
+
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
|
1690
|
+
}
|
|
1582
1691
|
|
|
1583
|
-
|
|
1692
|
+
#ifdef WHISPER_USE_FLASH_FF
|
|
1693
|
+
wstate.use_buf(ctx0, 0);
|
|
1584
1694
|
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1695
|
+
cur = ggml_flash_ff(ctx0,
|
|
1696
|
+
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
1697
|
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1698
|
+
#else
|
|
1589
1699
|
wstate.use_buf(ctx0, 0);
|
|
1590
1700
|
|
|
1591
|
-
|
|
1701
|
+
// fully connected
|
|
1702
|
+
cur = ggml_mul_mat(ctx0,
|
|
1703
|
+
layer.mlp_0_w,
|
|
1704
|
+
cur);
|
|
1592
1705
|
|
|
1593
1706
|
wstate.use_buf(ctx0, 1);
|
|
1594
1707
|
|
|
1595
|
-
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
1596
1708
|
cur = ggml_add(ctx0,
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
cur),
|
|
1600
|
-
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
|
1601
|
-
}
|
|
1709
|
+
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
1710
|
+
cur);
|
|
1602
1711
|
|
|
1603
|
-
|
|
1604
|
-
wstate.use_buf(ctx0, 0);
|
|
1712
|
+
wstate.use_buf(ctx0, 0);
|
|
1605
1713
|
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1609
|
-
#else
|
|
1610
|
-
wstate.use_buf(ctx0, 0);
|
|
1714
|
+
// GELU activation
|
|
1715
|
+
cur = ggml_gelu(ctx0, cur);
|
|
1611
1716
|
|
|
1612
|
-
|
|
1613
|
-
cur = ggml_mul_mat(ctx0,
|
|
1614
|
-
layer.mlp_0_w,
|
|
1615
|
-
cur);
|
|
1717
|
+
wstate.use_buf(ctx0, 1);
|
|
1616
1718
|
|
|
1617
|
-
|
|
1719
|
+
// projection
|
|
1720
|
+
cur = ggml_mul_mat(ctx0,
|
|
1721
|
+
layer.mlp_1_w,
|
|
1722
|
+
cur);
|
|
1618
1723
|
|
|
1619
|
-
|
|
1620
|
-
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
1621
|
-
cur);
|
|
1724
|
+
wstate.use_buf(ctx0, 0);
|
|
1622
1725
|
|
|
1623
|
-
|
|
1726
|
+
cur = ggml_add(ctx0,
|
|
1727
|
+
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
|
1728
|
+
cur);
|
|
1729
|
+
#endif
|
|
1730
|
+
}
|
|
1624
1731
|
|
|
1625
|
-
|
|
1626
|
-
cur = ggml_gelu(ctx0, cur);
|
|
1732
|
+
wstate.use_buf(ctx0, 3);
|
|
1627
1733
|
|
|
1628
|
-
|
|
1734
|
+
inpL = ggml_add(ctx0, cur, inpFF);
|
|
1735
|
+
}
|
|
1629
1736
|
|
|
1630
|
-
|
|
1631
|
-
cur = ggml_mul_mat(ctx0,
|
|
1632
|
-
layer.mlp_1_w,
|
|
1633
|
-
cur);
|
|
1737
|
+
cur = inpL;
|
|
1634
1738
|
|
|
1739
|
+
// norm
|
|
1740
|
+
{
|
|
1635
1741
|
wstate.use_buf(ctx0, 0);
|
|
1636
1742
|
|
|
1637
|
-
cur =
|
|
1638
|
-
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
|
1639
|
-
cur);
|
|
1640
|
-
#endif
|
|
1641
|
-
}
|
|
1642
|
-
|
|
1643
|
-
wstate.use_buf(ctx0, 3);
|
|
1743
|
+
cur = ggml_norm(ctx0, cur);
|
|
1644
1744
|
|
|
1645
|
-
|
|
1646
|
-
}
|
|
1745
|
+
wstate.use_buf(ctx0, 1);
|
|
1647
1746
|
|
|
1648
|
-
|
|
1747
|
+
// cur = ln_f_g*cur + ln_f_b
|
|
1748
|
+
cur = ggml_add(ctx0,
|
|
1749
|
+
ggml_mul(ctx0,
|
|
1750
|
+
ggml_repeat(ctx0, model.e_ln_w, cur),
|
|
1751
|
+
cur),
|
|
1752
|
+
ggml_repeat(ctx0, model.e_ln_b, cur));
|
|
1753
|
+
}
|
|
1649
1754
|
|
|
1650
|
-
|
|
1651
|
-
{
|
|
1652
|
-
wstate.use_buf(ctx0, 0);
|
|
1755
|
+
wstate.use_buf(ctx0, -1);
|
|
1653
1756
|
|
|
1654
|
-
|
|
1757
|
+
// run the computation
|
|
1758
|
+
{
|
|
1759
|
+
struct ggml_cgraph gf = {};
|
|
1760
|
+
gf.n_threads = n_threads;
|
|
1655
1761
|
|
|
1656
|
-
|
|
1762
|
+
ggml_build_forward_expand(&gf, cur);
|
|
1763
|
+
ggml_graph_compute(ctx0, &gf);
|
|
1657
1764
|
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
ggml_mul(ctx0,
|
|
1661
|
-
ggml_repeat(ctx0, model.e_ln_w, cur),
|
|
1662
|
-
cur),
|
|
1663
|
-
ggml_repeat(ctx0, model.e_ln_b, cur));
|
|
1765
|
+
//ggml_graph_print(&gf);
|
|
1766
|
+
}
|
|
1664
1767
|
}
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
// run the computation
|
|
1768
|
+
#ifdef WHISPER_USE_COREML
|
|
1769
|
+
else
|
|
1669
1770
|
{
|
|
1670
|
-
|
|
1671
|
-
gf.n_threads = n_threads;
|
|
1771
|
+
wstate.use_buf(ctx0, -1);
|
|
1672
1772
|
|
|
1673
|
-
|
|
1674
|
-
ggml_graph_compute(ctx0, &gf);
|
|
1773
|
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
|
1675
1774
|
|
|
1676
|
-
|
|
1775
|
+
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
|
1677
1776
|
}
|
|
1777
|
+
#endif
|
|
1678
1778
|
|
|
1679
1779
|
// cur
|
|
1680
1780
|
//{
|
|
@@ -1725,10 +1825,12 @@ static bool whisper_encode_internal(
|
|
|
1725
1825
|
|
|
1726
1826
|
wstate.use_buf(ctx0, -1);
|
|
1727
1827
|
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
1731
|
-
struct ggml_tensor* v =
|
|
1828
|
+
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
1829
|
+
|
|
1830
|
+
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
1831
|
+
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
1832
|
+
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
|
1833
|
+
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
1732
1834
|
|
|
1733
1835
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
|
1734
1836
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
|
@@ -1742,10 +1844,10 @@ static bool whisper_encode_internal(
|
|
|
1742
1844
|
|
|
1743
1845
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
1744
1846
|
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
1745
|
-
//
|
|
1746
|
-
//
|
|
1747
|
-
//
|
|
1748
|
-
//
|
|
1847
|
+
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
1848
|
+
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
1849
|
+
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
1850
|
+
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
1749
1851
|
|
|
1750
1852
|
ggml_free(ctx0);
|
|
1751
1853
|
|
|
@@ -1796,9 +1898,11 @@ static bool whisper_decode_internal(
|
|
|
1796
1898
|
|
|
1797
1899
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
1798
1900
|
|
|
1799
|
-
struct ggml_init_params params
|
|
1800
|
-
|
|
1801
|
-
|
|
1901
|
+
struct ggml_init_params params = {
|
|
1902
|
+
/*.mem_size =*/ wstate.buf_compute.size(),
|
|
1903
|
+
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
|
1904
|
+
/*.no_alloc =*/ false,
|
|
1905
|
+
};
|
|
1802
1906
|
|
|
1803
1907
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
1804
1908
|
|
|
@@ -1842,8 +1946,6 @@ static bool whisper_decode_internal(
|
|
|
1842
1946
|
|
|
1843
1947
|
// self-attention
|
|
1844
1948
|
{
|
|
1845
|
-
wstate.use_buf(ctx0, 1);
|
|
1846
|
-
|
|
1847
1949
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
1848
1950
|
layer.attn_q_w,
|
|
1849
1951
|
cur);
|
|
@@ -1863,20 +1965,24 @@ static bool whisper_decode_internal(
|
|
|
1863
1965
|
|
|
1864
1966
|
Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1865
1967
|
|
|
1866
|
-
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
|
1867
|
-
layer.attn_v_w,
|
|
1868
|
-
cur);
|
|
1869
|
-
|
|
1870
|
-
Vcur = ggml_add(ctx0,
|
|
1871
|
-
ggml_repeat(ctx0,
|
|
1872
|
-
layer.attn_v_b,
|
|
1873
|
-
Vcur),
|
|
1874
|
-
Vcur);
|
|
1875
|
-
|
|
1876
1968
|
// store key and value to memory
|
|
1877
1969
|
{
|
|
1970
|
+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
|
1971
|
+
layer.attn_v_w,
|
|
1972
|
+
cur);
|
|
1973
|
+
|
|
1974
|
+
Vcur = ggml_add(ctx0,
|
|
1975
|
+
ggml_repeat(ctx0,
|
|
1976
|
+
layer.attn_v_b,
|
|
1977
|
+
Vcur),
|
|
1978
|
+
Vcur);
|
|
1979
|
+
|
|
1980
|
+
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
|
1981
|
+
|
|
1878
1982
|
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
|
|
1879
|
-
struct ggml_tensor * v =
|
|
1983
|
+
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
|
|
1984
|
+
( n_ctx)*ggml_element_size(kv_self.v),
|
|
1985
|
+
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
|
|
1880
1986
|
|
|
1881
1987
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
|
1882
1988
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
|
@@ -1905,8 +2011,6 @@ static bool whisper_decode_internal(
|
|
|
1905
2011
|
// K * Q
|
|
1906
2012
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
1907
2013
|
|
|
1908
|
-
wstate.use_buf(ctx0, 0);
|
|
1909
|
-
|
|
1910
2014
|
//struct ggml_tensor * KQ_scaled =
|
|
1911
2015
|
// ggml_scale(ctx0,
|
|
1912
2016
|
// KQ,
|
|
@@ -1915,22 +2019,16 @@ static bool whisper_decode_internal(
|
|
|
1915
2019
|
|
|
1916
2020
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
1917
2021
|
|
|
1918
|
-
wstate.use_buf(ctx0, 1);
|
|
1919
|
-
|
|
1920
2022
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
|
1921
2023
|
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
n_state/n_head, n_head, n_past + N),
|
|
1929
|
-
1, 2, 0, 3);
|
|
1930
|
-
|
|
1931
|
-
wstate.use_buf(ctx0, 1);
|
|
2024
|
+
struct ggml_tensor * V =
|
|
2025
|
+
ggml_view_3d(ctx0, kv_self.v,
|
|
2026
|
+
n_past + N, n_state/n_head, n_head,
|
|
2027
|
+
n_ctx*ggml_element_size(kv_self.v),
|
|
2028
|
+
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
|
2029
|
+
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
|
|
1932
2030
|
|
|
1933
|
-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0,
|
|
2031
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
1934
2032
|
|
|
1935
2033
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1936
2034
|
|
|
@@ -1965,8 +2063,6 @@ static bool whisper_decode_internal(
|
|
|
1965
2063
|
|
|
1966
2064
|
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
|
1967
2065
|
|
|
1968
|
-
wstate.use_buf(ctx0, 1);
|
|
1969
|
-
|
|
1970
2066
|
// cur = ln_0_w*cur + ln_0_b
|
|
1971
2067
|
cur = ggml_add(ctx0,
|
|
1972
2068
|
ggml_mul(ctx0,
|
|
@@ -1977,8 +2073,6 @@ static bool whisper_decode_internal(
|
|
|
1977
2073
|
|
|
1978
2074
|
// cross-attention
|
|
1979
2075
|
{
|
|
1980
|
-
wstate.use_buf(ctx0, 0);
|
|
1981
|
-
|
|
1982
2076
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
1983
2077
|
layer.cross_attn_q_w,
|
|
1984
2078
|
cur);
|
|
@@ -1997,16 +2091,24 @@ static bool whisper_decode_internal(
|
|
|
1997
2091
|
ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
|
|
1998
2092
|
n_state/n_head, n_head, M);
|
|
1999
2093
|
|
|
2000
|
-
struct ggml_tensor * Vcross =
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2094
|
+
//struct ggml_tensor * Vcross =
|
|
2095
|
+
// ggml_reshape_3d(ctx0,
|
|
2096
|
+
// ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2097
|
+
// n_state/n_head, n_head, M);
|
|
2004
2098
|
|
|
2005
|
-
struct ggml_tensor * V_trans =
|
|
2099
|
+
//struct ggml_tensor * V_trans =
|
|
2100
|
+
// ggml_cpy(ctx0,
|
|
2101
|
+
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
|
2102
|
+
// ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
|
|
2006
2103
|
|
|
2007
|
-
|
|
2104
|
+
struct ggml_tensor * V =
|
|
2105
|
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
|
2106
|
+
M, n_state/n_head, n_head,
|
|
2107
|
+
M*ggml_element_size(wstate.kv_cross.v),
|
|
2108
|
+
M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
|
2109
|
+
il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
2008
2110
|
|
|
2009
|
-
|
|
2111
|
+
// ------
|
|
2010
2112
|
|
|
2011
2113
|
struct ggml_tensor * Q =
|
|
2012
2114
|
ggml_permute(ctx0,
|
|
@@ -2017,8 +2119,6 @@ static bool whisper_decode_internal(
|
|
|
2017
2119
|
|
|
2018
2120
|
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
|
2019
2121
|
|
|
2020
|
-
wstate.use_buf(ctx0, 0);
|
|
2021
|
-
|
|
2022
2122
|
// K * Q
|
|
2023
2123
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
2024
2124
|
|
|
@@ -2031,15 +2131,9 @@ static bool whisper_decode_internal(
|
|
|
2031
2131
|
// no masking for cross-attention
|
|
2032
2132
|
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2033
2133
|
|
|
2034
|
-
wstate.use_buf(ctx0, 1);
|
|
2035
|
-
|
|
2036
2134
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
|
2037
2135
|
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
|
2041
|
-
|
|
2042
|
-
wstate.use_buf(ctx0, 1);
|
|
2136
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
2043
2137
|
|
|
2044
2138
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2045
2139
|
|
|
@@ -2171,10 +2265,10 @@ static bool whisper_decode_internal(
|
|
|
2171
2265
|
if (N > 1) {
|
|
2172
2266
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
2173
2267
|
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
2174
|
-
//
|
|
2175
|
-
//
|
|
2176
|
-
//
|
|
2177
|
-
//
|
|
2268
|
+
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
2269
|
+
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
2270
|
+
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
2271
|
+
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
2178
2272
|
}
|
|
2179
2273
|
|
|
2180
2274
|
ggml_free(ctx0);
|
|
@@ -2282,6 +2376,68 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2282
2376
|
}
|
|
2283
2377
|
}
|
|
2284
2378
|
|
|
2379
|
+
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
|
|
2380
|
+
int n_samples, int fft_size, int fft_step, int n_threads,
|
|
2381
|
+
const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
|
|
2382
|
+
std::vector<float> fft_in(fft_size, 0.0);
|
|
2383
|
+
std::vector<float> fft_out(2 * fft_size);
|
|
2384
|
+
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
|
|
2385
|
+
|
|
2386
|
+
for (int i = ith; i < mel.n_len; i += n_threads) {
|
|
2387
|
+
const int offset = i * fft_step;
|
|
2388
|
+
|
|
2389
|
+
// apply Hanning window
|
|
2390
|
+
for (int j = 0; j < fft_size; j++) {
|
|
2391
|
+
if (offset + j < n_samples) {
|
|
2392
|
+
fft_in[j] = hann[j] * samples[offset + j];
|
|
2393
|
+
} else {
|
|
2394
|
+
fft_in[j] = 0.0;
|
|
2395
|
+
}
|
|
2396
|
+
}
|
|
2397
|
+
|
|
2398
|
+
// FFT -> mag^2
|
|
2399
|
+
fft(fft_in, fft_out);
|
|
2400
|
+
|
|
2401
|
+
for (int j = 0; j < fft_size; j++) {
|
|
2402
|
+
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
|
2403
|
+
}
|
|
2404
|
+
for (int j = 1; j < fft_size / 2; j++) {
|
|
2405
|
+
fft_out[j] += fft_out[fft_size - j];
|
|
2406
|
+
}
|
|
2407
|
+
|
|
2408
|
+
if (speed_up) {
|
|
2409
|
+
// scale down in the frequency domain results in a speed up in the time domain
|
|
2410
|
+
for (int j = 0; j < n_fft; j++) {
|
|
2411
|
+
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
|
|
2412
|
+
}
|
|
2413
|
+
}
|
|
2414
|
+
|
|
2415
|
+
// mel spectrogram
|
|
2416
|
+
for (int j = 0; j < mel.n_mel; j++) {
|
|
2417
|
+
double sum = 0.0;
|
|
2418
|
+
|
|
2419
|
+
// unroll loop (suggested by GH user @lunixbochs)
|
|
2420
|
+
int k = 0;
|
|
2421
|
+
for (k = 0; k < n_fft - 3; k += 4) {
|
|
2422
|
+
sum +=
|
|
2423
|
+
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
|
|
2424
|
+
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
|
|
2425
|
+
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
|
|
2426
|
+
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
|
|
2427
|
+
}
|
|
2428
|
+
|
|
2429
|
+
// handle n_fft remainder
|
|
2430
|
+
for (; k < n_fft; k++) {
|
|
2431
|
+
sum += fft_out[k] * filters.data[j * n_fft + k];
|
|
2432
|
+
}
|
|
2433
|
+
|
|
2434
|
+
sum = log10(std::max(sum, 1e-10));
|
|
2435
|
+
|
|
2436
|
+
mel.data[j * mel.n_len + i] = sum;
|
|
2437
|
+
}
|
|
2438
|
+
}
|
|
2439
|
+
}
|
|
2440
|
+
|
|
2285
2441
|
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
|
|
2286
2442
|
static bool log_mel_spectrogram(
|
|
2287
2443
|
whisper_state & wstate,
|
|
@@ -2304,85 +2460,48 @@ static bool log_mel_spectrogram(
|
|
|
2304
2460
|
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
|
|
2305
2461
|
}
|
|
2306
2462
|
|
|
2307
|
-
mel.n_mel
|
|
2308
|
-
mel.n_len
|
|
2309
|
-
mel.
|
|
2310
|
-
|
|
2311
|
-
const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
|
|
2312
|
-
|
|
2313
|
-
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
|
|
2314
|
-
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
|
|
2315
|
-
|
|
2316
|
-
std::vector<std::thread> workers(n_threads);
|
|
2317
|
-
for (int iw = 0; iw < n_threads; ++iw) {
|
|
2318
|
-
workers[iw] = std::thread([&](int ith) {
|
|
2319
|
-
std::vector<float> fft_in;
|
|
2320
|
-
fft_in.resize(fft_size);
|
|
2321
|
-
for (int i = 0; i < fft_size; i++) {
|
|
2322
|
-
fft_in[i] = 0.0;
|
|
2323
|
-
}
|
|
2463
|
+
mel.n_mel = n_mel;
|
|
2464
|
+
mel.n_len = n_samples/fft_step;
|
|
2465
|
+
mel.n_len_org = mel.n_len;
|
|
2324
2466
|
|
|
2325
|
-
|
|
2326
|
-
fft_out.resize(2*fft_size);
|
|
2467
|
+
std::vector<float> samples_padded;
|
|
2327
2468
|
|
|
2328
|
-
|
|
2329
|
-
|
|
2469
|
+
// pad audio with at least one extra chunk of zeros
|
|
2470
|
+
{
|
|
2471
|
+
const int pad = (100*WHISPER_CHUNK_SIZE)/2;
|
|
2330
2472
|
|
|
2331
|
-
|
|
2332
|
-
|
|
2333
|
-
|
|
2334
|
-
|
|
2335
|
-
} else {
|
|
2336
|
-
fft_in[j] = 0.0;
|
|
2337
|
-
}
|
|
2338
|
-
}
|
|
2473
|
+
if (mel.n_len % pad != 0) {
|
|
2474
|
+
mel.n_len = (mel.n_len/pad + 1)*pad;
|
|
2475
|
+
}
|
|
2476
|
+
mel.n_len += pad;
|
|
2339
2477
|
|
|
2340
|
-
|
|
2341
|
-
|
|
2478
|
+
samples_padded.resize(mel.n_len*fft_step);
|
|
2479
|
+
memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
|
|
2480
|
+
memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
|
|
2342
2481
|
|
|
2343
|
-
|
|
2344
|
-
|
|
2345
|
-
}
|
|
2346
|
-
for (int j = 1; j < fft_size/2; j++) {
|
|
2347
|
-
//if (i == 0) {
|
|
2348
|
-
// printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
|
|
2349
|
-
//}
|
|
2350
|
-
fft_out[j] += fft_out[fft_size - j];
|
|
2351
|
-
}
|
|
2352
|
-
if (i == 0) {
|
|
2353
|
-
//for (int j = 0; j < fft_size; j++) {
|
|
2354
|
-
// printf("%d: %e\n", j, fft_out[j]);
|
|
2355
|
-
//}
|
|
2356
|
-
}
|
|
2357
|
-
|
|
2358
|
-
if (speed_up) {
|
|
2359
|
-
// scale down in the frequency domain results in a speed up in the time domain
|
|
2360
|
-
for (int j = 0; j < n_fft; j++) {
|
|
2361
|
-
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
|
|
2362
|
-
}
|
|
2363
|
-
}
|
|
2482
|
+
samples = samples_padded.data();
|
|
2483
|
+
}
|
|
2364
2484
|
|
|
2365
|
-
|
|
2366
|
-
for (int j = 0; j < mel.n_mel; j++) {
|
|
2367
|
-
double sum = 0.0;
|
|
2485
|
+
mel.data.resize(mel.n_mel*mel.n_len);
|
|
2368
2486
|
|
|
2369
|
-
|
|
2370
|
-
|
|
2371
|
-
}
|
|
2372
|
-
if (sum < 1e-10) {
|
|
2373
|
-
sum = 1e-10;
|
|
2374
|
-
}
|
|
2487
|
+
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
|
|
2488
|
+
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
|
|
2375
2489
|
|
|
2376
|
-
|
|
2490
|
+
{
|
|
2491
|
+
std::vector<std::thread> workers(n_threads - 1);
|
|
2492
|
+
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
|
2493
|
+
workers[iw] = std::thread(
|
|
2494
|
+
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
|
|
2495
|
+
n_samples, fft_size, fft_step, n_threads,
|
|
2496
|
+
std::cref(filters), speed_up, std::ref(mel));
|
|
2497
|
+
}
|
|
2377
2498
|
|
|
2378
|
-
|
|
2379
|
-
|
|
2380
|
-
}
|
|
2381
|
-
}, iw);
|
|
2382
|
-
}
|
|
2499
|
+
// main thread
|
|
2500
|
+
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
|
|
2383
2501
|
|
|
2384
|
-
|
|
2385
|
-
|
|
2502
|
+
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
|
2503
|
+
workers[iw].join();
|
|
2504
|
+
}
|
|
2386
2505
|
}
|
|
2387
2506
|
|
|
2388
2507
|
// clamping and normalization
|
|
@@ -2406,6 +2525,8 @@ static bool log_mel_spectrogram(
|
|
|
2406
2525
|
|
|
2407
2526
|
wstate.t_mel_us += ggml_time_us() - t_start_us;
|
|
2408
2527
|
|
|
2528
|
+
//printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
|
|
2529
|
+
|
|
2409
2530
|
return true;
|
|
2410
2531
|
}
|
|
2411
2532
|
|
|
@@ -2447,25 +2568,20 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
|
2447
2568
|
int n = word.size();
|
|
2448
2569
|
while (i < n) {
|
|
2449
2570
|
int j = n;
|
|
2571
|
+
bool found = false;
|
|
2450
2572
|
while (j > i) {
|
|
2451
|
-
auto
|
|
2573
|
+
auto sub = word.substr(i, j-i);
|
|
2574
|
+
auto it = vocab.token_to_id.find(sub);
|
|
2452
2575
|
if (it != vocab.token_to_id.end()) {
|
|
2453
2576
|
tokens.push_back(it->second);
|
|
2454
2577
|
i = j;
|
|
2578
|
+
found = true;
|
|
2455
2579
|
break;
|
|
2456
2580
|
}
|
|
2457
2581
|
--j;
|
|
2458
2582
|
}
|
|
2459
|
-
if (
|
|
2460
|
-
|
|
2461
|
-
}
|
|
2462
|
-
if (j == i) {
|
|
2463
|
-
auto sub = word.substr(i, 1);
|
|
2464
|
-
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
|
2465
|
-
tokens.push_back(vocab.token_to_id.at(sub));
|
|
2466
|
-
} else {
|
|
2467
|
-
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
|
2468
|
-
}
|
|
2583
|
+
if (!found) {
|
|
2584
|
+
fprintf(stderr, "unknown token \n");
|
|
2469
2585
|
++i;
|
|
2470
2586
|
}
|
|
2471
2587
|
}
|
|
@@ -2478,14 +2594,28 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
|
2478
2594
|
// interface implementation
|
|
2479
2595
|
//
|
|
2480
2596
|
|
|
2597
|
+
#ifdef WHISPER_USE_COREML
|
|
2598
|
+
// replace .bin with -encoder.mlmodelc
|
|
2599
|
+
static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
|
|
2600
|
+
auto pos = path_bin.rfind('.');
|
|
2601
|
+
if (pos != std::string::npos) {
|
|
2602
|
+
path_bin = path_bin.substr(0, pos);
|
|
2603
|
+
}
|
|
2604
|
+
|
|
2605
|
+
path_bin += "-encoder.mlmodelc";
|
|
2606
|
+
|
|
2607
|
+
return path_bin;
|
|
2608
|
+
}
|
|
2609
|
+
#endif
|
|
2610
|
+
|
|
2481
2611
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
2482
2612
|
whisper_state * state = new whisper_state;
|
|
2483
2613
|
|
|
2484
|
-
const size_t scale = ctx->model.hparams.
|
|
2485
|
-
|
|
2614
|
+
const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
|
|
2486
2615
|
|
|
2487
|
-
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->
|
|
2616
|
+
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
2488
2617
|
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
2618
|
+
delete state;
|
|
2489
2619
|
return nullptr;
|
|
2490
2620
|
}
|
|
2491
2621
|
|
|
@@ -2494,8 +2624,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2494
2624
|
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
|
2495
2625
|
}
|
|
2496
2626
|
|
|
2497
|
-
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->
|
|
2627
|
+
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
|
2498
2628
|
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
2629
|
+
delete state;
|
|
2499
2630
|
return nullptr;
|
|
2500
2631
|
}
|
|
2501
2632
|
|
|
@@ -2504,6 +2635,22 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2504
2635
|
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
|
2505
2636
|
}
|
|
2506
2637
|
|
|
2638
|
+
#ifdef WHISPER_USE_COREML
|
|
2639
|
+
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
|
2640
|
+
|
|
2641
|
+
fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2642
|
+
fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
|
|
2643
|
+
|
|
2644
|
+
state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
|
|
2645
|
+
if (!state->ctx_coreml) {
|
|
2646
|
+
fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
|
2647
|
+
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
|
2648
|
+
return nullptr;
|
|
2649
|
+
#endif
|
|
2650
|
+
} else {
|
|
2651
|
+
fprintf(stderr, "%s: Core ML model loaded\n", __func__);
|
|
2652
|
+
}
|
|
2653
|
+
#endif
|
|
2507
2654
|
|
|
2508
2655
|
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
|
2509
2656
|
|
|
@@ -2528,7 +2675,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
|
2528
2675
|
}
|
|
2529
2676
|
|
|
2530
2677
|
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
|
2531
|
-
whisper_model_loader loader = {};
|
|
2532
2678
|
|
|
2533
2679
|
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
|
2534
2680
|
|
|
@@ -2538,7 +2684,10 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
|
|
|
2538
2684
|
return nullptr;
|
|
2539
2685
|
}
|
|
2540
2686
|
|
|
2687
|
+
whisper_model_loader loader = {};
|
|
2688
|
+
|
|
2541
2689
|
loader.context = &fin;
|
|
2690
|
+
|
|
2542
2691
|
loader.read = [](void * ctx, void * output, size_t read_size) {
|
|
2543
2692
|
std::ifstream * fin = (std::ifstream*)ctx;
|
|
2544
2693
|
fin->read((char *)output, read_size);
|
|
@@ -2555,7 +2704,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
|
|
|
2555
2704
|
fin->close();
|
|
2556
2705
|
};
|
|
2557
2706
|
|
|
2558
|
-
|
|
2707
|
+
auto ctx = whisper_init_no_state(&loader);
|
|
2708
|
+
|
|
2709
|
+
if (ctx) {
|
|
2710
|
+
ctx->path_model = path_model;
|
|
2711
|
+
}
|
|
2712
|
+
|
|
2713
|
+
return ctx;
|
|
2559
2714
|
}
|
|
2560
2715
|
|
|
2561
2716
|
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
|
|
@@ -2566,10 +2721,11 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
|
|
|
2566
2721
|
};
|
|
2567
2722
|
|
|
2568
2723
|
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
|
2569
|
-
whisper_model_loader loader = {};
|
|
2570
2724
|
|
|
2571
2725
|
fprintf(stderr, "%s: loading model from buffer\n", __func__);
|
|
2572
2726
|
|
|
2727
|
+
whisper_model_loader loader = {};
|
|
2728
|
+
|
|
2573
2729
|
loader.context = &ctx;
|
|
2574
2730
|
|
|
2575
2731
|
loader.read = [](void * ctx, void * output, size_t read_size) {
|
|
@@ -2665,6 +2821,13 @@ void whisper_free_state(struct whisper_state * state)
|
|
|
2665
2821
|
kv_cache_free(state->decoders[i].kv_self);
|
|
2666
2822
|
}
|
|
2667
2823
|
|
|
2824
|
+
#ifdef WHISPER_USE_COREML
|
|
2825
|
+
if (state->ctx_coreml != nullptr) {
|
|
2826
|
+
whisper_coreml_free(state->ctx_coreml);
|
|
2827
|
+
state->ctx_coreml = nullptr;
|
|
2828
|
+
}
|
|
2829
|
+
#endif
|
|
2830
|
+
|
|
2668
2831
|
delete state;
|
|
2669
2832
|
}
|
|
2670
2833
|
}
|
|
@@ -2723,8 +2886,9 @@ int whisper_set_mel_with_state(
|
|
|
2723
2886
|
return -1;
|
|
2724
2887
|
}
|
|
2725
2888
|
|
|
2726
|
-
state->mel.n_len
|
|
2727
|
-
state->mel.
|
|
2889
|
+
state->mel.n_len = n_len;
|
|
2890
|
+
state->mel.n_len_org = n_len;
|
|
2891
|
+
state->mel.n_mel = n_mel;
|
|
2728
2892
|
|
|
2729
2893
|
state->mel.data.resize(n_len*n_mel);
|
|
2730
2894
|
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
|
@@ -2822,7 +2986,6 @@ int whisper_lang_id(const char * lang) {
|
|
|
2822
2986
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
2823
2987
|
return -1;
|
|
2824
2988
|
}
|
|
2825
|
-
|
|
2826
2989
|
return g_lang.at(lang).first;
|
|
2827
2990
|
}
|
|
2828
2991
|
|
|
@@ -2850,13 +3013,13 @@ int whisper_lang_auto_detect_with_state(
|
|
|
2850
3013
|
return -1;
|
|
2851
3014
|
}
|
|
2852
3015
|
|
|
2853
|
-
if (seek >= state->mel.
|
|
2854
|
-
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.
|
|
3016
|
+
if (seek >= state->mel.n_len_org) {
|
|
3017
|
+
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
|
|
2855
3018
|
return -2;
|
|
2856
3019
|
}
|
|
2857
3020
|
|
|
2858
3021
|
// run the encoder
|
|
2859
|
-
if (
|
|
3022
|
+
if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
|
|
2860
3023
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
2861
3024
|
return -6;
|
|
2862
3025
|
}
|
|
@@ -2920,12 +3083,77 @@ int whisper_lang_auto_detect(
|
|
|
2920
3083
|
return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
|
|
2921
3084
|
}
|
|
2922
3085
|
|
|
3086
|
+
int whisper_model_n_vocab(struct whisper_context * ctx) {
|
|
3087
|
+
return ctx->model.hparams.n_vocab;
|
|
3088
|
+
}
|
|
3089
|
+
|
|
3090
|
+
int whisper_model_n_audio_ctx(struct whisper_context * ctx) {
|
|
3091
|
+
return ctx->model.hparams.n_audio_ctx;
|
|
3092
|
+
}
|
|
3093
|
+
|
|
3094
|
+
int whisper_model_n_audio_state(struct whisper_context * ctx) {
|
|
3095
|
+
return ctx->model.hparams.n_audio_state;
|
|
3096
|
+
}
|
|
3097
|
+
|
|
3098
|
+
int whisper_model_n_audio_head(struct whisper_context * ctx) {
|
|
3099
|
+
return ctx->model.hparams.n_audio_head;
|
|
3100
|
+
}
|
|
3101
|
+
|
|
3102
|
+
int whisper_model_n_audio_layer(struct whisper_context * ctx) {
|
|
3103
|
+
return ctx->model.hparams.n_audio_layer;
|
|
3104
|
+
}
|
|
3105
|
+
|
|
3106
|
+
int whisper_model_n_text_ctx(struct whisper_context * ctx) {
|
|
3107
|
+
return ctx->model.hparams.n_text_ctx;
|
|
3108
|
+
}
|
|
3109
|
+
|
|
3110
|
+
int whisper_model_n_text_state(struct whisper_context * ctx) {
|
|
3111
|
+
return ctx->model.hparams.n_text_state;
|
|
3112
|
+
}
|
|
3113
|
+
|
|
3114
|
+
int whisper_model_n_text_head(struct whisper_context * ctx) {
|
|
3115
|
+
return ctx->model.hparams.n_text_head;
|
|
3116
|
+
}
|
|
3117
|
+
|
|
3118
|
+
int whisper_model_n_text_layer(struct whisper_context * ctx) {
|
|
3119
|
+
return ctx->model.hparams.n_text_layer;
|
|
3120
|
+
}
|
|
3121
|
+
|
|
3122
|
+
int whisper_model_n_mels(struct whisper_context * ctx) {
|
|
3123
|
+
return ctx->model.hparams.n_mels;
|
|
3124
|
+
}
|
|
3125
|
+
|
|
3126
|
+
int whisper_model_ftype(struct whisper_context * ctx) {
|
|
3127
|
+
return ctx->model.hparams.ftype;
|
|
3128
|
+
}
|
|
3129
|
+
|
|
3130
|
+
int whisper_model_type(struct whisper_context * ctx) {
|
|
3131
|
+
return ctx->model.type;
|
|
3132
|
+
}
|
|
3133
|
+
|
|
3134
|
+
const char *whisper_model_type_readable(struct whisper_context * ctx) {
|
|
3135
|
+
switch (ctx->model.type) {
|
|
3136
|
+
case e_model::MODEL_TINY:
|
|
3137
|
+
return "tiny";
|
|
3138
|
+
case e_model::MODEL_BASE:
|
|
3139
|
+
return "base";
|
|
3140
|
+
case e_model::MODEL_SMALL:
|
|
3141
|
+
return "small";
|
|
3142
|
+
case e_model::MODEL_MEDIUM:
|
|
3143
|
+
return "medium";
|
|
3144
|
+
case e_model::MODEL_LARGE:
|
|
3145
|
+
return "large";
|
|
3146
|
+
default:
|
|
3147
|
+
return "unknown";
|
|
3148
|
+
}
|
|
3149
|
+
}
|
|
3150
|
+
|
|
2923
3151
|
int whisper_n_len_from_state(struct whisper_state * state) {
|
|
2924
|
-
return state->mel.
|
|
3152
|
+
return state->mel.n_len_org;
|
|
2925
3153
|
}
|
|
2926
3154
|
|
|
2927
3155
|
int whisper_n_len(struct whisper_context * ctx) {
|
|
2928
|
-
return ctx->state->mel.
|
|
3156
|
+
return ctx->state->mel.n_len_org;
|
|
2929
3157
|
}
|
|
2930
3158
|
|
|
2931
3159
|
int whisper_n_vocab(struct whisper_context * ctx) {
|
|
@@ -3021,6 +3249,14 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
|
|
3021
3249
|
}
|
|
3022
3250
|
}
|
|
3023
3251
|
|
|
3252
|
+
static int whisper_has_coreml(void) {
|
|
3253
|
+
#ifdef WHISPER_USE_COREML
|
|
3254
|
+
return 1;
|
|
3255
|
+
#else
|
|
3256
|
+
return 0;
|
|
3257
|
+
#endif
|
|
3258
|
+
}
|
|
3259
|
+
|
|
3024
3260
|
const char * whisper_print_system_info(void) {
|
|
3025
3261
|
static std::string s;
|
|
3026
3262
|
|
|
@@ -3037,6 +3273,7 @@ const char * whisper_print_system_info(void) {
|
|
|
3037
3273
|
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
|
3038
3274
|
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
|
3039
3275
|
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
|
3276
|
+
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
3040
3277
|
|
|
3041
3278
|
return s.c_str();
|
|
3042
3279
|
}
|
|
@@ -3070,10 +3307,12 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3070
3307
|
/*.speed_up =*/ false,
|
|
3071
3308
|
/*.audio_ctx =*/ 0,
|
|
3072
3309
|
|
|
3310
|
+
/*.initial_prompt =*/ nullptr,
|
|
3073
3311
|
/*.prompt_tokens =*/ nullptr,
|
|
3074
3312
|
/*.prompt_n_tokens =*/ 0,
|
|
3075
3313
|
|
|
3076
3314
|
/*.language =*/ "en",
|
|
3315
|
+
/*.detect_language =*/ false,
|
|
3077
3316
|
|
|
3078
3317
|
/*.suppress_blank =*/ true,
|
|
3079
3318
|
/*.suppress_non_speech_tokens =*/ false,
|
|
@@ -3082,7 +3321,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3082
3321
|
/*.max_initial_ts =*/ 1.0f,
|
|
3083
3322
|
/*.length_penalty =*/ -1.0f,
|
|
3084
3323
|
|
|
3085
|
-
/*.temperature_inc =*/ 0.
|
|
3324
|
+
/*.temperature_inc =*/ 0.4f,
|
|
3086
3325
|
/*.entropy_thold =*/ 2.4f,
|
|
3087
3326
|
/*.logprob_thold =*/ -1.0f,
|
|
3088
3327
|
/*.no_speech_thold =*/ 0.6f,
|
|
@@ -3100,6 +3339,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3100
3339
|
/*.new_segment_callback =*/ nullptr,
|
|
3101
3340
|
/*.new_segment_callback_user_data =*/ nullptr,
|
|
3102
3341
|
|
|
3342
|
+
/*.progress_callback =*/ nullptr,
|
|
3343
|
+
/*.progress_callback_user_data =*/ nullptr,
|
|
3344
|
+
|
|
3103
3345
|
/*.encoder_begin_callback =*/ nullptr,
|
|
3104
3346
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
3105
3347
|
|
|
@@ -3111,13 +3353,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
3111
3353
|
case WHISPER_SAMPLING_GREEDY:
|
|
3112
3354
|
{
|
|
3113
3355
|
result.greedy = {
|
|
3114
|
-
/*.best_of =*/
|
|
3356
|
+
/*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
|
|
3115
3357
|
};
|
|
3116
3358
|
} break;
|
|
3117
3359
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
3118
3360
|
{
|
|
3119
3361
|
result.beam_search = {
|
|
3120
|
-
/*.beam_size =*/ 5
|
|
3362
|
+
/*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
|
|
3121
3363
|
|
|
3122
3364
|
/*.patience =*/ -1.0f,
|
|
3123
3365
|
};
|
|
@@ -3138,15 +3380,15 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
3138
3380
|
|
|
3139
3381
|
// trim from start (in place)
|
|
3140
3382
|
static inline void ltrim(std::string &s) {
|
|
3141
|
-
s.erase(s.begin(), std::
|
|
3142
|
-
return
|
|
3383
|
+
s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
|
|
3384
|
+
return std::isspace(ch);
|
|
3143
3385
|
}));
|
|
3144
3386
|
}
|
|
3145
3387
|
|
|
3146
3388
|
// trim from end (in place)
|
|
3147
3389
|
static inline void rtrim(std::string &s) {
|
|
3148
|
-
s.erase(std::
|
|
3149
|
-
return
|
|
3390
|
+
s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
|
|
3391
|
+
return std::isspace(ch);
|
|
3150
3392
|
}).base(), s.end());
|
|
3151
3393
|
}
|
|
3152
3394
|
|
|
@@ -3657,7 +3899,7 @@ int whisper_full_with_state(
|
|
|
3657
3899
|
}
|
|
3658
3900
|
|
|
3659
3901
|
// auto-detect language if not specified
|
|
3660
|
-
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
|
3902
|
+
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
|
|
3661
3903
|
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
|
3662
3904
|
|
|
3663
3905
|
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
|
@@ -3669,6 +3911,9 @@ int whisper_full_with_state(
|
|
|
3669
3911
|
params.language = whisper_lang_str(lang_id);
|
|
3670
3912
|
|
|
3671
3913
|
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
3914
|
+
if (params.detect_language) {
|
|
3915
|
+
return 0;
|
|
3916
|
+
}
|
|
3672
3917
|
}
|
|
3673
3918
|
|
|
3674
3919
|
if (params.token_timestamps) {
|
|
@@ -3679,7 +3924,7 @@ int whisper_full_with_state(
|
|
|
3679
3924
|
}
|
|
3680
3925
|
|
|
3681
3926
|
const int seek_start = params.offset_ms/10;
|
|
3682
|
-
const int seek_end =
|
|
3927
|
+
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
|
|
3683
3928
|
|
|
3684
3929
|
// if length of spectrogram is less than 1s (100 samples), then return
|
|
3685
3930
|
// basically don't process anything that is less than 1s
|
|
@@ -3742,13 +3987,26 @@ int whisper_full_with_state(
|
|
|
3742
3987
|
prompt_past.clear();
|
|
3743
3988
|
}
|
|
3744
3989
|
|
|
3745
|
-
//
|
|
3746
|
-
|
|
3747
|
-
|
|
3748
|
-
|
|
3749
|
-
|
|
3990
|
+
// prepare prompt
|
|
3991
|
+
{
|
|
3992
|
+
std::vector<whisper_token> prompt_tokens;
|
|
3993
|
+
|
|
3994
|
+
// initial prompt
|
|
3995
|
+
if (!params.prompt_tokens && params.initial_prompt) {
|
|
3996
|
+
prompt_tokens.resize(1024);
|
|
3997
|
+
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
|
3998
|
+
params.prompt_tokens = prompt_tokens.data();
|
|
3999
|
+
params.prompt_n_tokens = prompt_tokens.size();
|
|
4000
|
+
}
|
|
4001
|
+
|
|
4002
|
+
// prepend the prompt tokens to the prompt_past
|
|
4003
|
+
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
|
|
4004
|
+
// parse tokens from the pointer
|
|
4005
|
+
for (int i = 0; i < params.prompt_n_tokens; i++) {
|
|
4006
|
+
prompt_past.push_back(params.prompt_tokens[i]);
|
|
4007
|
+
}
|
|
4008
|
+
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
|
3750
4009
|
}
|
|
3751
|
-
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
|
3752
4010
|
}
|
|
3753
4011
|
|
|
3754
4012
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
|
@@ -3807,6 +4065,10 @@ int whisper_full_with_state(
|
|
|
3807
4065
|
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
|
|
3808
4066
|
}
|
|
3809
4067
|
}
|
|
4068
|
+
if (params.progress_callback) {
|
|
4069
|
+
params.progress_callback(
|
|
4070
|
+
ctx, ctx->state, progress_prev, params.progress_callback_user_data);
|
|
4071
|
+
}
|
|
3810
4072
|
|
|
3811
4073
|
// of only 1 second left, then stop
|
|
3812
4074
|
if (seek + 100 >= seek_end) {
|
|
@@ -4196,7 +4458,11 @@ int whisper_full_with_state(
|
|
|
4196
4458
|
}
|
|
4197
4459
|
|
|
4198
4460
|
// was the decoding successful for the current temperature?
|
|
4199
|
-
|
|
4461
|
+
// do fallback only if:
|
|
4462
|
+
// - we are not at the last temperature
|
|
4463
|
+
// - we are not at the end of the audio (3 sec)
|
|
4464
|
+
if (it != (int) temperatures.size() - 1 &&
|
|
4465
|
+
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
|
|
4200
4466
|
bool success = true;
|
|
4201
4467
|
|
|
4202
4468
|
const auto & decoder = state->decoders[best_decoder_id];
|
|
@@ -4395,6 +4661,9 @@ int whisper_full_parallel(
|
|
|
4395
4661
|
params_cur.new_segment_callback = nullptr;
|
|
4396
4662
|
params_cur.new_segment_callback_user_data = nullptr;
|
|
4397
4663
|
|
|
4664
|
+
params_cur.progress_callback = nullptr;
|
|
4665
|
+
params_cur.progress_callback_user_data = nullptr;
|
|
4666
|
+
|
|
4398
4667
|
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
|
|
4399
4668
|
}
|
|
4400
4669
|
|
|
@@ -4562,49 +4831,51 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
|
4562
4831
|
|
|
4563
4832
|
ggml_time_init();
|
|
4564
4833
|
|
|
4565
|
-
size_t n =
|
|
4566
|
-
size_t arr = n_threads > 0 ?
|
|
4834
|
+
size_t n = 20;
|
|
4835
|
+
size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
|
|
4567
4836
|
|
|
4568
|
-
//
|
|
4837
|
+
// 1GB MB array
|
|
4569
4838
|
const size_t size = arr*1024llu*1024llu;
|
|
4570
4839
|
|
|
4571
|
-
|
|
4572
|
-
|
|
4840
|
+
// single-thread
|
|
4841
|
+
{
|
|
4842
|
+
char * src = (char *) malloc(size);
|
|
4843
|
+
char * dst = (char *) malloc(size);
|
|
4573
4844
|
|
|
4574
|
-
|
|
4845
|
+
for (size_t i = 0; i < size; i++) src[i] = i;
|
|
4575
4846
|
|
|
4576
|
-
|
|
4847
|
+
memcpy(dst, src, size); // heat-up
|
|
4577
4848
|
|
|
4578
|
-
|
|
4849
|
+
double tsum = 0.0;
|
|
4850
|
+
double sum = 0.0;
|
|
4579
4851
|
|
|
4580
|
-
|
|
4581
|
-
|
|
4852
|
+
for (size_t i = 0; i < n; i++) {
|
|
4853
|
+
const int64_t t0 = ggml_time_us();
|
|
4582
4854
|
|
|
4583
|
-
|
|
4855
|
+
memcpy(dst, src, size);
|
|
4584
4856
|
|
|
4585
|
-
|
|
4857
|
+
const int64_t t1 = ggml_time_us();
|
|
4586
4858
|
|
|
4587
|
-
|
|
4859
|
+
tsum += (t1 - t0)*1e-6;
|
|
4588
4860
|
|
|
4589
|
-
|
|
4590
|
-
|
|
4861
|
+
src[rand() % size] = rand() % 256;
|
|
4862
|
+
}
|
|
4591
4863
|
|
|
4592
|
-
|
|
4593
|
-
|
|
4864
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
|
|
4865
|
+
s += strbuf;
|
|
4594
4866
|
|
|
4595
|
-
|
|
4596
|
-
|
|
4597
|
-
|
|
4867
|
+
// needed to prevent the compiler from optimizing the memcpy away
|
|
4868
|
+
{
|
|
4869
|
+
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
4598
4870
|
|
|
4599
|
-
|
|
4871
|
+
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
|
|
4872
|
+
s += strbuf;
|
|
4873
|
+
}
|
|
4600
4874
|
|
|
4601
|
-
|
|
4602
|
-
|
|
4875
|
+
free(src);
|
|
4876
|
+
free(dst);
|
|
4603
4877
|
}
|
|
4604
4878
|
|
|
4605
|
-
free(src);
|
|
4606
|
-
free(dst);
|
|
4607
|
-
|
|
4608
4879
|
return s.c_str();
|
|
4609
4880
|
}
|
|
4610
4881
|
|
|
@@ -4634,27 +4905,48 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
|
4634
4905
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
|
4635
4906
|
std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
|
|
4636
4907
|
|
|
4908
|
+
// put a bunch of random data in the buffer
|
|
4637
4909
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
|
4638
4910
|
|
|
4639
4911
|
for (int j = 0; j < (int) sizes.size(); j++) {
|
|
4912
|
+
int n_q4_0 = 0;
|
|
4913
|
+
int n_q4_1 = 0;
|
|
4914
|
+
int n_q4_2 = 0;
|
|
4915
|
+
int n_q5_0 = 0;
|
|
4916
|
+
int n_q5_1 = 0;
|
|
4917
|
+
int n_q8_0 = 0;
|
|
4640
4918
|
int n_fp16 = 0;
|
|
4641
4919
|
int n_fp32 = 0;
|
|
4642
4920
|
|
|
4643
4921
|
// GFLOPS/s
|
|
4922
|
+
double s_q4_0 = 0.0;
|
|
4923
|
+
double s_q4_1 = 0.0;
|
|
4924
|
+
double s_q4_2 = 0.0;
|
|
4925
|
+
double s_q5_0 = 0.0;
|
|
4926
|
+
double s_q5_1 = 0.0;
|
|
4927
|
+
double s_q8_0 = 0.0;
|
|
4644
4928
|
double s_fp16 = 0.0;
|
|
4645
4929
|
double s_fp32 = 0.0;
|
|
4646
4930
|
|
|
4647
4931
|
const size_t N = sizes[j];
|
|
4648
4932
|
|
|
4649
|
-
for (int k = 0; k <
|
|
4650
|
-
const ggml_type wtype =
|
|
4933
|
+
for (int k = 0; k < 8; ++k) {
|
|
4934
|
+
const ggml_type wtype =
|
|
4935
|
+
k == 0 ? GGML_TYPE_Q4_0 :
|
|
4936
|
+
k == 1 ? GGML_TYPE_Q4_1 :
|
|
4937
|
+
k == 2 ? GGML_TYPE_Q4_2 :
|
|
4938
|
+
k == 3 ? GGML_TYPE_Q5_0 :
|
|
4939
|
+
k == 4 ? GGML_TYPE_Q5_1 :
|
|
4940
|
+
k == 5 ? GGML_TYPE_Q8_0 :
|
|
4941
|
+
k == 6 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
|
4651
4942
|
|
|
4652
|
-
double & s = k == 0 ? s_fp16 : s_fp32;
|
|
4653
|
-
int & n = k == 0 ? n_fp16
|
|
4943
|
+
double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_q4_2 : k == 3 ? s_q5_0 : k == 4 ? s_q5_1 : k == 5 ? s_q8_0 : k == 6 ? s_fp16 : /*k == 7*/ s_fp32;
|
|
4944
|
+
int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_q4_2 : k == 3 ? n_q5_0 : k == 4 ? n_q5_1 : k == 5 ? n_q8_0 : k == 6 ? n_fp16 : /*k == 7*/ n_fp32;
|
|
4654
4945
|
|
|
4655
4946
|
struct ggml_init_params gparams = {
|
|
4656
4947
|
/*.mem_size =*/ buf.size(),
|
|
4657
4948
|
/*.mem_buffer =*/ buf.data(),
|
|
4949
|
+
/*.no_alloc =*/ false,
|
|
4658
4950
|
};
|
|
4659
4951
|
|
|
4660
4952
|
struct ggml_context * ctx0 = ggml_init(gparams);
|
|
@@ -4693,8 +4985,19 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
|
4693
4985
|
s = ((2.0*N*N*N*n)/tsum)*1e-9;
|
|
4694
4986
|
}
|
|
4695
4987
|
|
|
4696
|
-
|
|
4697
|
-
|
|
4988
|
+
// Q4_0 | Q4_1 | Q4_2
|
|
4989
|
+
snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs) | Q4_2 %7.1f GFLOPS (%3d runs)\n",
|
|
4990
|
+
N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1, s_q4_2, n_q4_2);
|
|
4991
|
+
s += strbuf;
|
|
4992
|
+
|
|
4993
|
+
// Q5_0 | Q5_1 | Q8_0
|
|
4994
|
+
snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n",
|
|
4995
|
+
N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0);
|
|
4996
|
+
s += strbuf;
|
|
4997
|
+
|
|
4998
|
+
// F16 | F32
|
|
4999
|
+
snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n",
|
|
5000
|
+
N, N, s_fp16, n_fp16, s_fp32, n_fp32);
|
|
4698
5001
|
s += strbuf;
|
|
4699
5002
|
}
|
|
4700
5003
|
|