whisper.rn 0.2.5 → 0.3.0-rc.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/whisper.cpp CHANGED
@@ -1,5 +1,7 @@
1
- #define WHISPER_BUILD
2
1
  #include "whisper.h"
2
+ #if WHISPER_USE_COREML
3
+ #include "coreml/whisper-encoder.h"
4
+ #endif
3
5
 
4
6
  #include "ggml.h"
5
7
 
@@ -99,7 +101,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
99
101
  #define WHISPER_PRINT_DEBUG(...)
100
102
  #endif
101
103
 
102
- #define WHISPER_USE_FLASH_ATTN
104
+ //#define WHISPER_USE_FLASH_ATTN
103
105
  //#define WHISPER_USE_FLASH_FF
104
106
  #define WHISPER_MAX_DECODERS 16
105
107
 
@@ -218,14 +220,14 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
218
220
  { "su", { 98, "sundanese", } },
219
221
  };
220
222
 
221
- static const size_t MB = 1024*1024;
223
+ static const size_t MB = 1ull*1024*1024;
222
224
 
223
225
  static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
224
- { MODEL_TINY, 12ull*MB },
225
- { MODEL_BASE, 15ull*MB },
226
- { MODEL_SMALL, 23ull*MB },
227
- { MODEL_MEDIUM, 31ull*MB },
228
- { MODEL_LARGE, 38ull*MB },
226
+ { MODEL_TINY, 62ull*MB },
227
+ { MODEL_BASE, 80ull*MB },
228
+ { MODEL_SMALL, 120ull*MB },
229
+ { MODEL_MEDIUM, 158ull*MB },
230
+ { MODEL_LARGE, 198ull*MB },
229
231
  };
230
232
 
231
233
  static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
@@ -252,12 +254,79 @@ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
252
254
  { MODEL_LARGE, 9ull*MB },
253
255
  };
254
256
 
255
- static const std::map<e_model, size_t> MEM_REQ_MODEL = {
256
- { MODEL_TINY, 74ull*MB },
257
- { MODEL_BASE, 142ull*MB },
258
- { MODEL_SMALL, 466ull*MB },
259
- { MODEL_MEDIUM, 1464ull*MB },
260
- { MODEL_LARGE, 2952ull*MB },
257
+ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
258
+ { GGML_TYPE_F32,
259
+ {
260
+ { MODEL_TINY, 74ull*MB },
261
+ { MODEL_BASE, 142ull*MB },
262
+ { MODEL_SMALL, 466ull*MB },
263
+ { MODEL_MEDIUM, 1464ull*MB },
264
+ { MODEL_LARGE, 2952ull*MB },
265
+ },
266
+ },
267
+ { GGML_TYPE_F16,
268
+ {
269
+ { MODEL_TINY, 74ull*MB },
270
+ { MODEL_BASE, 142ull*MB },
271
+ { MODEL_SMALL, 466ull*MB },
272
+ { MODEL_MEDIUM, 1464ull*MB },
273
+ { MODEL_LARGE, 2952ull*MB },
274
+ },
275
+ },
276
+ { GGML_TYPE_Q4_0,
277
+ {
278
+ { MODEL_TINY, 26ull*MB },
279
+ { MODEL_BASE, 50ull*MB },
280
+ { MODEL_SMALL, 154ull*MB },
281
+ { MODEL_MEDIUM, 470ull*MB },
282
+ { MODEL_LARGE, 940ull*MB },
283
+ },
284
+ },
285
+ { GGML_TYPE_Q4_1,
286
+ {
287
+ { MODEL_TINY, 32ull*MB },
288
+ { MODEL_BASE, 58ull*MB },
289
+ { MODEL_SMALL, 182ull*MB },
290
+ { MODEL_MEDIUM, 562ull*MB },
291
+ { MODEL_LARGE, 1124ull*MB },
292
+ },
293
+ },
294
+ { GGML_TYPE_Q4_2,
295
+ {
296
+ { MODEL_TINY, 26ull*MB },
297
+ { MODEL_BASE, 50ull*MB },
298
+ { MODEL_SMALL, 154ull*MB },
299
+ { MODEL_MEDIUM, 470ull*MB },
300
+ { MODEL_LARGE, 940ull*MB },
301
+ },
302
+ },
303
+ { GGML_TYPE_Q5_0,
304
+ {
305
+ { MODEL_TINY, 30ull*MB },
306
+ { MODEL_BASE, 54ull*MB },
307
+ { MODEL_SMALL, 170ull*MB },
308
+ { MODEL_MEDIUM, 516ull*MB },
309
+ { MODEL_LARGE, 1034ull*MB },
310
+ },
311
+ },
312
+ { GGML_TYPE_Q5_1,
313
+ {
314
+ { MODEL_TINY, 32ull*MB },
315
+ { MODEL_BASE, 58ull*MB },
316
+ { MODEL_SMALL, 182ull*MB },
317
+ { MODEL_MEDIUM, 562ull*MB },
318
+ { MODEL_LARGE, 1124ull*MB },
319
+ },
320
+ },
321
+ { GGML_TYPE_Q8_0,
322
+ {
323
+ { MODEL_TINY, 45ull*MB },
324
+ { MODEL_BASE, 84ull*MB },
325
+ { MODEL_SMALL, 268ull*MB },
326
+ { MODEL_MEDIUM, 834ull*MB },
327
+ { MODEL_LARGE, 1674ull*MB },
328
+ },
329
+ },
261
330
  };
262
331
 
263
332
  static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
@@ -277,11 +346,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
277
346
  };
278
347
 
279
348
  static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
280
- { MODEL_TINY, 6ull*MB },
281
- { MODEL_BASE, 8ull*MB },
282
- { MODEL_SMALL, 13ull*MB },
283
- { MODEL_MEDIUM, 22ull*MB },
284
- { MODEL_LARGE, 33ull*MB },
349
+ { MODEL_TINY, 30ull*MB },
350
+ { MODEL_BASE, 38ull*MB },
351
+ { MODEL_SMALL, 56ull*MB },
352
+ { MODEL_MEDIUM, 74ull*MB },
353
+ { MODEL_LARGE, 94ull*MB },
285
354
  };
286
355
 
287
356
  static const std::map<e_model, size_t> MEM_REQ_DECODE = {
@@ -294,6 +363,7 @@ static const std::map<e_model, size_t> MEM_REQ_DECODE = {
294
363
 
295
364
  struct whisper_mel {
296
365
  int n_len;
366
+ int n_len_org;
297
367
  int n_mel;
298
368
 
299
369
  std::vector<float> data;
@@ -366,7 +436,7 @@ struct whisper_hparams {
366
436
  int32_t n_text_head = 6;
367
437
  int32_t n_text_layer = 4;
368
438
  int32_t n_mels = 80;
369
- int32_t f16 = 1;
439
+ int32_t ftype = 1;
370
440
  };
371
441
 
372
442
  // audio encoding layer
@@ -586,6 +656,11 @@ struct whisper_state {
586
656
 
587
657
  int lang_id = 0; // english by default
588
658
 
659
+ std::string path_model; // populated by whisper_init_from_file()
660
+ #ifdef WHISPER_USE_COREML
661
+ whisper_coreml_context * ctx_coreml = nullptr;
662
+ #endif
663
+
589
664
  // [EXPERIMENTAL] token-level timestamps data
590
665
  int64_t t_beg = 0;
591
666
  int64_t t_last = 0;
@@ -628,15 +703,17 @@ struct whisper_state {
628
703
  };
629
704
 
630
705
  struct whisper_context {
631
- int64_t t_load_us = 0;
706
+ int64_t t_load_us = 0;
632
707
  int64_t t_start_us = 0;
633
708
 
634
-
635
- ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
709
+ ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
710
+ ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
636
711
 
637
712
  whisper_model model;
638
713
  whisper_vocab vocab;
639
714
  whisper_state * state = nullptr;
715
+
716
+ std::string path_model; // populated by whisper_init_from_file()
640
717
  };
641
718
 
642
719
  template<typename T>
@@ -653,9 +730,11 @@ static bool kv_cache_init(
653
730
  int n_ctx) {
654
731
  cache.buf.resize(mem_bytes);
655
732
 
656
- struct ggml_init_params params;
657
- params.mem_size = cache.buf.size();
658
- params.mem_buffer = cache.buf.data();
733
+ struct ggml_init_params params = {
734
+ /*.mem_size =*/ cache.buf.size(),
735
+ /*.mem_buffer =*/ cache.buf.data(),
736
+ /*.no_alloc =*/ false,
737
+ };
659
738
 
660
739
  cache.ctx = ggml_init(params);
661
740
 
@@ -685,11 +764,13 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
685
764
  const ggml_type wtype = cache.k->type;
686
765
  WHISPER_ASSERT(wtype == cache.v->type);
687
766
 
688
- WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
767
+ WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
689
768
 
690
- struct ggml_init_params params;
691
- params.mem_size = cache.buf.size();
692
- params.mem_buffer = cache.buf.data();
769
+ struct ggml_init_params params = {
770
+ /*.mem_size =*/ cache.buf.size(),
771
+ /*.mem_buffer =*/ cache.buf.data(),
772
+ /*.no_alloc =*/ false,
773
+ };
693
774
 
694
775
  cache.ctx = ggml_init(params);
695
776
 
@@ -756,7 +837,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
756
837
  read_safe(loader, hparams.n_text_head);
757
838
  read_safe(loader, hparams.n_text_layer);
758
839
  read_safe(loader, hparams.n_mels);
759
- read_safe(loader, hparams.f16);
840
+ read_safe(loader, hparams.ftype);
760
841
 
761
842
  assert(hparams.n_text_state == hparams.n_audio_state);
762
843
 
@@ -780,11 +861,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
780
861
  model.type = e_model::MODEL_LARGE;
781
862
  }
782
863
 
783
- // for the big tensors, we have the option to store the data in 16-bit floats
864
+ // for the big tensors, we have the option to store the data in 16-bit floats or quantized
784
865
  // in order to save memory and also to speed up the computation
785
- wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
866
+ wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
867
+ if (wctx.wtype == GGML_TYPE_COUNT) {
868
+ fprintf(stderr, "%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
869
+ return false;
870
+ }
786
871
 
787
- const size_t scale = model.hparams.f16 ? 1 : 2;
872
+ const size_t scale = model.hparams.ftype ? 1 : 2;
788
873
 
789
874
  fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
790
875
  fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
@@ -796,18 +881,18 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
796
881
  fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
797
882
  fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
798
883
  fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
799
- fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
884
+ fprintf(stderr, "%s: ftype = %d\n", __func__, model.hparams.ftype);
800
885
  fprintf(stderr, "%s: type = %d\n", __func__, model.type);
801
886
 
802
887
  // print memory requirements
803
888
  {
804
889
  // this is the total memory required to run the inference
805
890
  const size_t mem_required =
806
- MEM_REQ_SCRATCH0.at (model.type) +
807
- MEM_REQ_SCRATCH1.at (model.type) +
808
- MEM_REQ_SCRATCH2.at (model.type) +
809
- MEM_REQ_SCRATCH3.at (model.type) +
810
- scale*MEM_REQ_MODEL.at (model.type) +
891
+ MEM_REQ_SCRATCH0.at(model.type) +
892
+ MEM_REQ_SCRATCH1.at(model.type) +
893
+ MEM_REQ_SCRATCH2.at(model.type) +
894
+ MEM_REQ_SCRATCH3.at(model.type) +
895
+ scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
811
896
  scale*MEM_REQ_KV_CROSS.at(model.type) +
812
897
  scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
813
898
 
@@ -823,7 +908,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
823
908
  // always have at least one decoder
824
909
 
825
910
  wctx.model.buf = new std::vector<uint8_t>();
826
- wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
911
+ wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
827
912
 
828
913
  // we skip initialization of the state until it is needed
829
914
  // because it might be that state will always be provided externally.
@@ -914,6 +999,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
914
999
  size_t ctx_size = 0;
915
1000
 
916
1001
  const ggml_type wtype = wctx.wtype;
1002
+ const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
917
1003
 
918
1004
  {
919
1005
  const auto & hparams = model.hparams;
@@ -932,92 +1018,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
932
1018
 
933
1019
  // encoder
934
1020
  {
935
- ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
1021
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
936
1022
 
937
- ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
938
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
1023
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w
1024
+ ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
939
1025
 
940
- ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
941
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
1026
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w
1027
+ ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
942
1028
 
943
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
944
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
1029
+ ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
1030
+ ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
945
1031
  }
946
1032
 
947
1033
  // decoder
948
1034
  {
949
- ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
1035
+ ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
950
1036
 
951
- ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
1037
+ ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
952
1038
 
953
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
954
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
1039
+ ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
1040
+ ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
955
1041
  }
956
1042
 
957
1043
  // encoder layers
958
1044
  {
959
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
960
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
1045
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
1046
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
961
1047
 
962
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
963
- ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
1048
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
1049
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
964
1050
 
965
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
966
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
1051
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
1052
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
967
1053
 
968
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
969
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
1054
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
1055
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
970
1056
 
971
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
972
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
1057
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
1058
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
973
1059
 
974
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
1060
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
975
1061
 
976
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
977
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
1062
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
1063
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
978
1064
 
979
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
980
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
1065
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
1066
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
981
1067
  }
982
1068
 
983
1069
  // decoder layers
984
1070
  {
985
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
986
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
1071
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
1072
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
987
1073
 
988
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
989
- ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
1074
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
1075
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
990
1076
 
991
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
992
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
1077
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
1078
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
993
1079
 
994
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
995
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
1080
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
1081
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
996
1082
 
997
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
998
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
1083
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
1084
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
999
1085
 
1000
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
1086
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
1001
1087
 
1002
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
1003
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
1088
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
1089
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
1004
1090
 
1005
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
1006
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
1091
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
1092
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
1007
1093
  //
1008
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
1009
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
1094
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
1095
+ ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
1010
1096
 
1011
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
1012
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
1097
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
1098
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
1013
1099
 
1014
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
1100
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
1015
1101
 
1016
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
1017
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
1102
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
1103
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
1018
1104
 
1019
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
1020
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
1105
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
1106
+ ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
1021
1107
  }
1022
1108
 
1023
1109
  ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@@ -1027,9 +1113,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1027
1113
 
1028
1114
  // create the ggml context
1029
1115
  {
1030
- struct ggml_init_params params;
1031
- params.mem_size = wctx.model.buf->size();
1032
- params.mem_buffer = wctx.model.buf->data();
1116
+ struct ggml_init_params params = {
1117
+ /*.mem_size =*/ wctx.model.buf->size(),
1118
+ /*.mem_buffer =*/ wctx.model.buf->data(),
1119
+ /*.no_alloc =*/ false,
1120
+ };
1033
1121
 
1034
1122
  model.ctx = ggml_init(params);
1035
1123
  if (!model.ctx) {
@@ -1061,175 +1149,175 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1061
1149
 
1062
1150
  // encoder
1063
1151
  {
1064
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1152
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1065
1153
 
1066
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
1154
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1067
1155
  model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1068
1156
 
1069
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
1157
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1070
1158
  model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1071
1159
 
1072
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1073
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1160
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1161
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1074
1162
 
1075
1163
  // map by name
1076
1164
  model.tensors["encoder.positional_embedding"] = model.e_pe;
1077
1165
 
1078
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
1079
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
1166
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
1167
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
1080
1168
 
1081
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
1082
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
1169
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
1170
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
1083
1171
 
1084
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
1085
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
1172
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
1173
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
1086
1174
 
1087
1175
  for (int i = 0; i < n_audio_layer; ++i) {
1088
1176
  auto & layer = model.layers_encoder[i];
1089
1177
 
1090
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1091
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1178
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1179
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1092
1180
 
1093
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
1094
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
1181
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
1182
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
1095
1183
 
1096
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
1097
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1184
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
1185
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1098
1186
 
1099
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1100
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1187
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1188
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1101
1189
 
1102
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1103
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1190
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1191
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1104
1192
 
1105
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1193
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1106
1194
 
1107
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1108
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1195
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1196
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1109
1197
 
1110
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1111
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1198
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1199
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1112
1200
 
1113
1201
  // map by name
1114
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1115
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1202
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1203
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1116
1204
 
1117
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1118
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1205
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1206
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1119
1207
 
1120
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1121
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1208
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1209
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1122
1210
 
1123
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1124
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1211
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1212
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1125
1213
 
1126
1214
  model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1127
1215
  model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1128
1216
 
1129
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1217
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1130
1218
 
1131
1219
  model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1132
1220
  model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1133
1221
 
1134
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1135
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1222
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1223
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1136
1224
  }
1137
1225
  }
1138
1226
 
1139
1227
  // decoder
1140
1228
  {
1141
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
1229
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
1142
1230
 
1143
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
1231
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
1144
1232
 
1145
1233
  model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1146
1234
  model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1147
1235
 
1148
1236
  // map by name
1149
- model.tensors["decoder.positional_embedding"] = model.d_pe;
1237
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
1150
1238
 
1151
1239
  model.tensors["decoder.token_embedding.weight"] = model.d_te;
1152
1240
 
1153
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
1154
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
1241
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
1242
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
1155
1243
 
1156
1244
  for (int i = 0; i < n_text_layer; ++i) {
1157
1245
  auto & layer = model.layers_decoder[i];
1158
1246
 
1159
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1160
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1247
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1248
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1161
1249
 
1162
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
1163
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
1250
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
1251
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
1164
1252
 
1165
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
1166
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1253
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
1254
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1167
1255
 
1168
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1169
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1256
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1257
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1170
1258
 
1171
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1172
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1259
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1260
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1173
1261
 
1174
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1262
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1175
1263
 
1176
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1177
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1264
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1265
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1178
1266
 
1179
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1180
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1267
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1268
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1181
1269
 
1182
- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1183
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1270
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1271
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1184
1272
 
1185
- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1186
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1273
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1274
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1187
1275
 
1188
- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1276
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1189
1277
 
1190
- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1191
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1278
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1279
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1192
1280
 
1193
- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1194
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1281
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1282
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1195
1283
 
1196
1284
  // map by name
1197
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1198
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1285
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1286
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1199
1287
 
1200
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1201
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1288
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1289
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1202
1290
 
1203
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1204
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1291
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1292
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1205
1293
 
1206
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1207
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1294
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1295
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1208
1296
 
1209
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1210
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1297
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1298
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1211
1299
 
1212
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1300
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1213
1301
 
1214
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1215
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1302
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1303
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1216
1304
 
1217
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1218
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1305
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1306
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1219
1307
 
1220
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
1221
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
1308
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
1309
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
1222
1310
 
1223
1311
  model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
1224
1312
  model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
1225
1313
 
1226
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
1314
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
1227
1315
 
1228
1316
  model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
1229
1317
  model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
1230
1318
 
1231
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
1232
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
1319
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
1320
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
1233
1321
  }
1234
1322
  }
1235
1323
  }
@@ -1243,18 +1331,18 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1243
1331
  while (true) {
1244
1332
  int32_t n_dims;
1245
1333
  int32_t length;
1246
- int32_t ftype;
1334
+ int32_t ttype;
1247
1335
 
1248
1336
  read_safe(loader, n_dims);
1249
1337
  read_safe(loader, length);
1250
- read_safe(loader, ftype);
1338
+ read_safe(loader, ttype);
1251
1339
 
1252
1340
  if (loader->eof(loader->context)) {
1253
1341
  break;
1254
1342
  }
1255
1343
 
1256
1344
  int32_t nelements = 1;
1257
- int32_t ne[3] = { 1, 1, 1 };
1345
+ int32_t ne[4] = { 1, 1, 1, 1 };
1258
1346
  for (int i = 0; i < n_dims; ++i) {
1259
1347
  read_safe(loader, ne[i]);
1260
1348
  nelements *= ne[i];
@@ -1273,18 +1361,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1273
1361
  auto tensor = model.tensors[name.data()];
1274
1362
  if (ggml_nelements(tensor) != nelements) {
1275
1363
  fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1364
+ fprintf(stderr, "%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1365
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1276
1366
  return false;
1277
1367
  }
1278
1368
 
1279
1369
  if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1280
1370
  fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1281
- __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
1371
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1282
1372
  return false;
1283
1373
  }
1284
1374
 
1285
- const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
1375
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
1286
1376
 
1287
- if (nelements*bpe != ggml_nbytes(tensor)) {
1377
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1288
1378
  fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1289
1379
  __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1290
1380
  return false;
@@ -1293,7 +1383,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1293
1383
  loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1294
1384
  BYTESWAP_TENSOR(tensor);
1295
1385
 
1296
- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1386
+ //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
1297
1387
  total_size += ggml_nbytes(tensor);
1298
1388
  model.n_loaded++;
1299
1389
  }
@@ -1343,9 +1433,11 @@ static bool whisper_encode_internal(
1343
1433
  const int n_mels = hparams.n_mels;
1344
1434
  assert(mel_inp.n_mel == n_mels);
1345
1435
 
1346
- struct ggml_init_params params;
1347
- params.mem_size = wstate.buf_compute.size();
1348
- params.mem_buffer = wstate.buf_compute.data();
1436
+ struct ggml_init_params params = {
1437
+ /*.mem_size =*/ wstate.buf_compute.size(),
1438
+ /*.mem_buffer =*/ wstate.buf_compute.data(),
1439
+ /*.no_alloc =*/ false,
1440
+ };
1349
1441
 
1350
1442
  struct ggml_context * ctx0 = ggml_init(params);
1351
1443
 
@@ -1369,312 +1461,320 @@ static bool whisper_encode_internal(
1369
1461
 
1370
1462
  struct ggml_tensor * cur;
1371
1463
 
1372
- // convolution + gelu
1373
- {
1374
- wstate.use_buf(ctx0, 1);
1464
+ #ifndef WHISPER_USE_COREML
1465
+ const bool use_coreml = false;
1466
+ #else
1467
+ const bool use_coreml = wstate.ctx_coreml != nullptr;
1468
+ #endif
1375
1469
 
1376
- cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1377
- cur = ggml_add(ctx0,
1378
- ggml_repeat(ctx0,
1379
- model.e_conv_1_b,
1380
- cur),
1381
- cur);
1470
+ if (!use_coreml) {
1471
+ // convolution + gelu
1472
+ {
1473
+ wstate.use_buf(ctx0, 1);
1382
1474
 
1383
- cur = ggml_gelu(ctx0, cur);
1475
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1476
+ cur = ggml_add(ctx0,
1477
+ ggml_repeat(ctx0,
1478
+ model.e_conv_1_b,
1479
+ cur),
1480
+ cur);
1384
1481
 
1385
- wstate.use_buf(ctx0, 0);
1482
+ cur = ggml_gelu(ctx0, cur);
1386
1483
 
1387
- cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1388
- cur = ggml_add(ctx0,
1389
- ggml_repeat(ctx0,
1390
- model.e_conv_2_b,
1391
- cur),
1392
- cur);
1484
+ wstate.use_buf(ctx0, 0);
1393
1485
 
1394
- cur = ggml_gelu(ctx0, cur);
1395
- }
1486
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1487
+ cur = ggml_add(ctx0,
1488
+ ggml_repeat(ctx0,
1489
+ model.e_conv_2_b,
1490
+ cur),
1491
+ cur);
1396
1492
 
1397
- wstate.use_buf(ctx0, 3);
1493
+ cur = ggml_gelu(ctx0, cur);
1494
+ }
1398
1495
 
1399
- // ===================================================================
1400
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
1401
- //static int iter = -1;
1402
- //const int n_iter = 1500/n_ctx;
1496
+ wstate.use_buf(ctx0, 3);
1403
1497
 
1404
- //iter = (iter + 1) % n_iter;
1498
+ // ===================================================================
1499
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
1500
+ //static int iter = -1;
1501
+ //const int n_iter = 1500/n_ctx;
1405
1502
 
1406
- //if (iter == 0) {
1407
- // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1408
- // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1409
- //}
1503
+ //iter = (iter + 1) % n_iter;
1410
1504
 
1411
- static int iter = 0;
1412
-
1413
- const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1414
- const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1505
+ //if (iter == 0) {
1506
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
1507
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
1508
+ //}
1415
1509
 
1416
- struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1510
+ static int iter = 0;
1417
1511
 
1418
- cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
1512
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1513
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1419
1514
 
1420
- // ===================================================================
1515
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1421
1516
 
1422
- // original:
1423
- //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1517
+ cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
1424
1518
 
1425
- struct ggml_tensor * inpL = cur;
1519
+ // ===================================================================
1426
1520
 
1427
- for (int il = 0; il < n_layer; ++il) {
1428
- const auto & layer = model.layers_encoder[il];
1521
+ // original:
1522
+ //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1429
1523
 
1430
- // norm
1431
- {
1432
- wstate.use_buf(ctx0, 0);
1524
+ struct ggml_tensor * inpL = cur;
1433
1525
 
1434
- cur = ggml_norm(ctx0, inpL);
1526
+ for (int il = 0; il < n_layer; ++il) {
1527
+ const auto & layer = model.layers_encoder[il];
1435
1528
 
1436
- // cur = ln_0_w*cur + ln_0_b
1437
- cur = ggml_add(ctx0,
1438
- ggml_mul(ctx0,
1439
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1440
- cur),
1441
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1442
- }
1529
+ // norm
1530
+ {
1531
+ wstate.use_buf(ctx0, 0);
1443
1532
 
1444
- // self-attention
1445
- {
1446
- wstate.use_buf(ctx0, 1);
1533
+ cur = ggml_norm(ctx0, inpL);
1447
1534
 
1448
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1449
- layer.attn_q_w,
1450
- cur);
1535
+ // cur = ln_0_w*cur + ln_0_b
1536
+ cur = ggml_add(ctx0,
1537
+ ggml_mul(ctx0,
1538
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1539
+ cur),
1540
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1541
+ }
1451
1542
 
1452
- Qcur = ggml_add(ctx0,
1453
- ggml_repeat(ctx0,
1454
- layer.attn_q_b,
1455
- Qcur),
1456
- Qcur);
1543
+ // self-attention
1544
+ {
1545
+ wstate.use_buf(ctx0, 1);
1457
1546
 
1458
- //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1547
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1548
+ layer.attn_q_w,
1549
+ cur);
1459
1550
 
1460
- // note: no bias for Key
1461
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1462
- layer.attn_k_w,
1463
- cur);
1551
+ Qcur = ggml_add(ctx0,
1552
+ ggml_repeat(ctx0,
1553
+ layer.attn_q_b,
1554
+ Qcur),
1555
+ Qcur);
1464
1556
 
1465
- //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1557
+ //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1466
1558
 
1467
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1468
- layer.attn_v_w,
1469
- cur);
1559
+ // note: no bias for Key
1560
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1561
+ layer.attn_k_w,
1562
+ cur);
1470
1563
 
1471
- Vcur = ggml_add(ctx0,
1472
- ggml_repeat(ctx0,
1473
- layer.attn_v_b,
1474
- Vcur),
1475
- Vcur);
1564
+ //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1476
1565
 
1477
- // ------
1566
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1567
+ layer.attn_v_w,
1568
+ cur);
1478
1569
 
1479
- wstate.use_buf(ctx0, 0);
1570
+ Vcur = ggml_add(ctx0,
1571
+ ggml_repeat(ctx0,
1572
+ layer.attn_v_b,
1573
+ Vcur),
1574
+ Vcur);
1480
1575
 
1481
- #ifdef WHISPER_USE_FLASH_ATTN
1482
- struct ggml_tensor * Q =
1483
- ggml_permute(ctx0,
1484
- ggml_cpy(ctx0,
1485
- Qcur,
1486
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1487
- 0, 2, 1, 3);
1576
+ // ------
1488
1577
 
1489
- struct ggml_tensor * K =
1490
- ggml_permute(ctx0,
1491
- ggml_cpy(ctx0,
1492
- Kcur,
1493
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1494
- 0, 2, 1, 3);
1578
+ wstate.use_buf(ctx0, 0);
1495
1579
 
1496
- struct ggml_tensor * V =
1497
- ggml_cpy(ctx0,
1498
- ggml_permute(ctx0,
1499
- ggml_reshape_3d(ctx0,
1500
- Vcur,
1501
- n_state/n_head, n_head, n_ctx),
1502
- 1, 2, 0, 3),
1503
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
1504
- );
1505
-
1506
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
1580
+ #ifdef WHISPER_USE_FLASH_ATTN
1581
+ struct ggml_tensor * Q =
1582
+ ggml_permute(ctx0,
1583
+ ggml_cpy(ctx0,
1584
+ Qcur,
1585
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1586
+ 0, 2, 1, 3);
1587
+
1588
+ struct ggml_tensor * K =
1589
+ ggml_permute(ctx0,
1590
+ ggml_cpy(ctx0,
1591
+ Kcur,
1592
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1593
+ 0, 2, 1, 3);
1594
+
1595
+ struct ggml_tensor * V =
1596
+ ggml_cpy(ctx0,
1597
+ ggml_permute(ctx0,
1598
+ ggml_reshape_3d(ctx0,
1599
+ Vcur,
1600
+ n_state/n_head, n_head, n_ctx),
1601
+ 1, 2, 0, 3),
1602
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1603
+
1604
+ struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
1507
1605
  #else
1508
- struct ggml_tensor * Q =
1509
- ggml_permute(ctx0,
1510
- ggml_cpy(ctx0,
1511
- Qcur,
1512
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1513
- 0, 2, 1, 3);
1606
+ struct ggml_tensor * Q =
1607
+ ggml_permute(ctx0,
1608
+ ggml_cpy(ctx0,
1609
+ Qcur,
1610
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1611
+ 0, 2, 1, 3);
1612
+
1613
+ struct ggml_tensor * K =
1614
+ ggml_permute(ctx0,
1615
+ ggml_cpy(ctx0,
1616
+ Kcur,
1617
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1618
+ 0, 2, 1, 3);
1619
+
1620
+ // K * Q
1621
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1622
+
1623
+ struct ggml_tensor * KQ_scaled =
1624
+ ggml_scale(ctx0,
1625
+ KQ,
1626
+ ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1627
+ );
1628
+
1629
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
1630
+
1631
+ struct ggml_tensor * V =
1632
+ ggml_cpy(ctx0,
1633
+ ggml_permute(ctx0,
1634
+ ggml_reshape_3d(ctx0,
1635
+ Vcur,
1636
+ n_state/n_head, n_head, n_ctx),
1637
+ 1, 2, 0, 3),
1638
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1639
+ );
1640
+
1641
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
1642
+ #endif
1643
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1514
1644
 
1515
- struct ggml_tensor * K =
1516
- ggml_permute(ctx0,
1517
- ggml_cpy(ctx0,
1518
- Kcur,
1519
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1520
- 0, 2, 1, 3);
1645
+ wstate.use_buf(ctx0, 1);
1521
1646
 
1522
- // K * Q
1523
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1647
+ cur = ggml_cpy(ctx0,
1648
+ KQV_merged,
1649
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1650
+ }
1524
1651
 
1525
- struct ggml_tensor * KQ_scaled =
1526
- ggml_scale(ctx0,
1527
- KQ,
1528
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1529
- );
1652
+ // projection
1653
+ {
1654
+ wstate.use_buf(ctx0, 0);
1530
1655
 
1531
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
1656
+ cur = ggml_mul_mat(ctx0,
1657
+ layer.attn_ln_1_w,
1658
+ cur);
1532
1659
 
1533
- //struct ggml_tensor * V_trans =
1534
- // ggml_permute(ctx0,
1535
- // ggml_cpy(ctx0,
1536
- // Vcur,
1537
- // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1538
- // 1, 2, 0, 3);
1660
+ wstate.use_buf(ctx0, 1);
1539
1661
 
1540
- //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1662
+ cur = ggml_add(ctx0,
1663
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1664
+ cur);
1665
+ }
1541
1666
 
1542
- struct ggml_tensor * V =
1543
- ggml_cpy(ctx0,
1544
- ggml_permute(ctx0,
1545
- ggml_reshape_3d(ctx0,
1546
- Vcur,
1547
- n_state/n_head, n_head, n_ctx),
1548
- 0, 2, 1, 3),
1549
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
1550
- );
1551
-
1552
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
1553
- #endif
1554
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1667
+ wstate.use_buf(ctx0, 2);
1555
1668
 
1556
- wstate.use_buf(ctx0, 1);
1669
+ // add the input
1670
+ cur = ggml_add(ctx0, cur, inpL);
1557
1671
 
1558
- cur = ggml_cpy(ctx0,
1559
- KQV_merged,
1560
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1561
- }
1672
+ struct ggml_tensor * inpFF = cur;
1562
1673
 
1563
- // projection
1564
- {
1565
- wstate.use_buf(ctx0, 0);
1566
-
1567
- cur = ggml_mul_mat(ctx0,
1568
- layer.attn_ln_1_w,
1569
- cur);
1570
-
1571
- wstate.use_buf(ctx0, 1);
1674
+ // feed-forward network
1675
+ {
1676
+ // norm
1677
+ {
1678
+ wstate.use_buf(ctx0, 0);
1572
1679
 
1573
- cur = ggml_add(ctx0,
1574
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1575
- cur);
1576
- }
1680
+ cur = ggml_norm(ctx0, inpFF);
1577
1681
 
1578
- wstate.use_buf(ctx0, 2);
1682
+ wstate.use_buf(ctx0, 1);
1579
1683
 
1580
- // add the input
1581
- cur = ggml_add(ctx0, cur, inpL);
1684
+ // cur = mlp_ln_w*cur + mlp_ln_b
1685
+ cur = ggml_add(ctx0,
1686
+ ggml_mul(ctx0,
1687
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1688
+ cur),
1689
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1690
+ }
1582
1691
 
1583
- struct ggml_tensor * inpFF = cur;
1692
+ #ifdef WHISPER_USE_FLASH_FF
1693
+ wstate.use_buf(ctx0, 0);
1584
1694
 
1585
- // feed-forward network
1586
- {
1587
- // norm
1588
- {
1695
+ cur = ggml_flash_ff(ctx0,
1696
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1697
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1698
+ #else
1589
1699
  wstate.use_buf(ctx0, 0);
1590
1700
 
1591
- cur = ggml_norm(ctx0, inpFF);
1701
+ // fully connected
1702
+ cur = ggml_mul_mat(ctx0,
1703
+ layer.mlp_0_w,
1704
+ cur);
1592
1705
 
1593
1706
  wstate.use_buf(ctx0, 1);
1594
1707
 
1595
- // cur = mlp_ln_w*cur + mlp_ln_b
1596
1708
  cur = ggml_add(ctx0,
1597
- ggml_mul(ctx0,
1598
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1599
- cur),
1600
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1601
- }
1709
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
1710
+ cur);
1602
1711
 
1603
- #ifdef WHISPER_USE_FLASH_FF
1604
- wstate.use_buf(ctx0, 0);
1712
+ wstate.use_buf(ctx0, 0);
1605
1713
 
1606
- cur = ggml_flash_ff(ctx0,
1607
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
1608
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1609
- #else
1610
- wstate.use_buf(ctx0, 0);
1714
+ // GELU activation
1715
+ cur = ggml_gelu(ctx0, cur);
1611
1716
 
1612
- // fully connected
1613
- cur = ggml_mul_mat(ctx0,
1614
- layer.mlp_0_w,
1615
- cur);
1717
+ wstate.use_buf(ctx0, 1);
1616
1718
 
1617
- wstate.use_buf(ctx0, 1);
1719
+ // projection
1720
+ cur = ggml_mul_mat(ctx0,
1721
+ layer.mlp_1_w,
1722
+ cur);
1618
1723
 
1619
- cur = ggml_add(ctx0,
1620
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
1621
- cur);
1724
+ wstate.use_buf(ctx0, 0);
1622
1725
 
1623
- wstate.use_buf(ctx0, 0);
1726
+ cur = ggml_add(ctx0,
1727
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
1728
+ cur);
1729
+ #endif
1730
+ }
1624
1731
 
1625
- // GELU activation
1626
- cur = ggml_gelu(ctx0, cur);
1732
+ wstate.use_buf(ctx0, 3);
1627
1733
 
1628
- wstate.use_buf(ctx0, 1);
1734
+ inpL = ggml_add(ctx0, cur, inpFF);
1735
+ }
1629
1736
 
1630
- // projection
1631
- cur = ggml_mul_mat(ctx0,
1632
- layer.mlp_1_w,
1633
- cur);
1737
+ cur = inpL;
1634
1738
 
1739
+ // norm
1740
+ {
1635
1741
  wstate.use_buf(ctx0, 0);
1636
1742
 
1637
- cur = ggml_add(ctx0,
1638
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
1639
- cur);
1640
- #endif
1641
- }
1642
-
1643
- wstate.use_buf(ctx0, 3);
1743
+ cur = ggml_norm(ctx0, cur);
1644
1744
 
1645
- inpL = ggml_add(ctx0, cur, inpFF);
1646
- }
1745
+ wstate.use_buf(ctx0, 1);
1647
1746
 
1648
- cur = inpL;
1747
+ // cur = ln_f_g*cur + ln_f_b
1748
+ cur = ggml_add(ctx0,
1749
+ ggml_mul(ctx0,
1750
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1751
+ cur),
1752
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1753
+ }
1649
1754
 
1650
- // norm
1651
- {
1652
- wstate.use_buf(ctx0, 0);
1755
+ wstate.use_buf(ctx0, -1);
1653
1756
 
1654
- cur = ggml_norm(ctx0, cur);
1757
+ // run the computation
1758
+ {
1759
+ struct ggml_cgraph gf = {};
1760
+ gf.n_threads = n_threads;
1655
1761
 
1656
- wstate.use_buf(ctx0, 1);
1762
+ ggml_build_forward_expand(&gf, cur);
1763
+ ggml_graph_compute(ctx0, &gf);
1657
1764
 
1658
- // cur = ln_f_g*cur + ln_f_b
1659
- cur = ggml_add(ctx0,
1660
- ggml_mul(ctx0,
1661
- ggml_repeat(ctx0, model.e_ln_w, cur),
1662
- cur),
1663
- ggml_repeat(ctx0, model.e_ln_b, cur));
1765
+ //ggml_graph_print(&gf);
1766
+ }
1664
1767
  }
1665
-
1666
- wstate.use_buf(ctx0, -1);
1667
-
1668
- // run the computation
1768
+ #ifdef WHISPER_USE_COREML
1769
+ else
1669
1770
  {
1670
- struct ggml_cgraph gf = {};
1671
- gf.n_threads = n_threads;
1771
+ wstate.use_buf(ctx0, -1);
1672
1772
 
1673
- ggml_build_forward_expand(&gf, cur);
1674
- ggml_graph_compute(ctx0, &gf);
1773
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1675
1774
 
1676
- //ggml_graph_print(&gf);
1775
+ whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1677
1776
  }
1777
+ #endif
1678
1778
 
1679
1779
  // cur
1680
1780
  //{
@@ -1725,10 +1825,12 @@ static bool whisper_encode_internal(
1725
1825
 
1726
1826
  wstate.use_buf(ctx0, -1);
1727
1827
 
1728
- //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1729
- //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1730
- struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1731
- struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
1828
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1829
+
1830
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1831
+ struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1832
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
1833
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
1732
1834
 
1733
1835
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1734
1836
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1742,10 +1844,10 @@ static bool whisper_encode_internal(
1742
1844
 
1743
1845
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1744
1846
  // ggml_used_mem(ctx0)/1024.0/1024.0,
1745
- // wctx.get_buf_max_mem(0)/1024.0/1024.0,
1746
- // wctx.get_buf_max_mem(1)/1024.0/1024.0,
1747
- // wctx.get_buf_max_mem(2)/1024.0/1024.0,
1748
- // wctx.get_buf_max_mem(3)/1024.0/1024.0);
1847
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1848
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1849
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1850
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1749
1851
 
1750
1852
  ggml_free(ctx0);
1751
1853
 
@@ -1796,9 +1898,11 @@ static bool whisper_decode_internal(
1796
1898
 
1797
1899
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1798
1900
 
1799
- struct ggml_init_params params;
1800
- params.mem_size = wstate.buf_compute.size();
1801
- params.mem_buffer = wstate.buf_compute.data();
1901
+ struct ggml_init_params params = {
1902
+ /*.mem_size =*/ wstate.buf_compute.size(),
1903
+ /*.mem_buffer =*/ wstate.buf_compute.data(),
1904
+ /*.no_alloc =*/ false,
1905
+ };
1802
1906
 
1803
1907
  struct ggml_context * ctx0 = ggml_init(params);
1804
1908
 
@@ -1842,8 +1946,6 @@ static bool whisper_decode_internal(
1842
1946
 
1843
1947
  // self-attention
1844
1948
  {
1845
- wstate.use_buf(ctx0, 1);
1846
-
1847
1949
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1848
1950
  layer.attn_q_w,
1849
1951
  cur);
@@ -1863,20 +1965,24 @@ static bool whisper_decode_internal(
1863
1965
 
1864
1966
  Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1865
1967
 
1866
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1867
- layer.attn_v_w,
1868
- cur);
1869
-
1870
- Vcur = ggml_add(ctx0,
1871
- ggml_repeat(ctx0,
1872
- layer.attn_v_b,
1873
- Vcur),
1874
- Vcur);
1875
-
1876
1968
  // store key and value to memory
1877
1969
  {
1970
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1971
+ layer.attn_v_w,
1972
+ cur);
1973
+
1974
+ Vcur = ggml_add(ctx0,
1975
+ ggml_repeat(ctx0,
1976
+ layer.attn_v_b,
1977
+ Vcur),
1978
+ Vcur);
1979
+
1980
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
1981
+
1878
1982
  struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
1879
- struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
1983
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
1984
+ ( n_ctx)*ggml_element_size(kv_self.v),
1985
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
1880
1986
 
1881
1987
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
1882
1988
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -1905,8 +2011,6 @@ static bool whisper_decode_internal(
1905
2011
  // K * Q
1906
2012
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1907
2013
 
1908
- wstate.use_buf(ctx0, 0);
1909
-
1910
2014
  //struct ggml_tensor * KQ_scaled =
1911
2015
  // ggml_scale(ctx0,
1912
2016
  // KQ,
@@ -1915,22 +2019,16 @@ static bool whisper_decode_internal(
1915
2019
 
1916
2020
  struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1917
2021
 
1918
- wstate.use_buf(ctx0, 1);
1919
-
1920
2022
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1921
2023
 
1922
- wstate.use_buf(ctx0, 0);
1923
-
1924
- struct ggml_tensor * V_trans =
1925
- ggml_permute(ctx0,
1926
- ggml_reshape_3d(ctx0,
1927
- ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1928
- n_state/n_head, n_head, n_past + N),
1929
- 1, 2, 0, 3);
1930
-
1931
- wstate.use_buf(ctx0, 1);
2024
+ struct ggml_tensor * V =
2025
+ ggml_view_3d(ctx0, kv_self.v,
2026
+ n_past + N, n_state/n_head, n_head,
2027
+ n_ctx*ggml_element_size(kv_self.v),
2028
+ n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
2029
+ il*n_ctx*ggml_element_size(kv_self.v)*n_state);
1932
2030
 
1933
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2031
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
1934
2032
 
1935
2033
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1936
2034
 
@@ -1965,8 +2063,6 @@ static bool whisper_decode_internal(
1965
2063
 
1966
2064
  cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1967
2065
 
1968
- wstate.use_buf(ctx0, 1);
1969
-
1970
2066
  // cur = ln_0_w*cur + ln_0_b
1971
2067
  cur = ggml_add(ctx0,
1972
2068
  ggml_mul(ctx0,
@@ -1977,8 +2073,6 @@ static bool whisper_decode_internal(
1977
2073
 
1978
2074
  // cross-attention
1979
2075
  {
1980
- wstate.use_buf(ctx0, 0);
1981
-
1982
2076
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1983
2077
  layer.cross_attn_q_w,
1984
2078
  cur);
@@ -1997,16 +2091,24 @@ static bool whisper_decode_internal(
1997
2091
  ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
1998
2092
  n_state/n_head, n_head, M);
1999
2093
 
2000
- struct ggml_tensor * Vcross =
2001
- ggml_reshape_3d(ctx0,
2002
- ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2003
- n_state/n_head, n_head, M);
2094
+ //struct ggml_tensor * Vcross =
2095
+ // ggml_reshape_3d(ctx0,
2096
+ // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2097
+ // n_state/n_head, n_head, M);
2004
2098
 
2005
- struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
2099
+ //struct ggml_tensor * V_trans =
2100
+ // ggml_cpy(ctx0,
2101
+ // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2102
+ // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
2006
2103
 
2007
- // ------
2104
+ struct ggml_tensor * V =
2105
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
2106
+ M, n_state/n_head, n_head,
2107
+ M*ggml_element_size(wstate.kv_cross.v),
2108
+ M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2109
+ il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
2008
2110
 
2009
- wstate.use_buf(ctx0, 1);
2111
+ // ------
2010
2112
 
2011
2113
  struct ggml_tensor * Q =
2012
2114
  ggml_permute(ctx0,
@@ -2017,8 +2119,6 @@ static bool whisper_decode_internal(
2017
2119
 
2018
2120
  struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2019
2121
 
2020
- wstate.use_buf(ctx0, 0);
2021
-
2022
2122
  // K * Q
2023
2123
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2024
2124
 
@@ -2031,15 +2131,9 @@ static bool whisper_decode_internal(
2031
2131
  // no masking for cross-attention
2032
2132
  //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2033
2133
 
2034
- wstate.use_buf(ctx0, 1);
2035
-
2036
2134
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2037
2135
 
2038
- wstate.use_buf(ctx0, 0);
2039
-
2040
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2041
-
2042
- wstate.use_buf(ctx0, 1);
2136
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2043
2137
 
2044
2138
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2045
2139
 
@@ -2171,10 +2265,10 @@ static bool whisper_decode_internal(
2171
2265
  if (N > 1) {
2172
2266
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2173
2267
  // ggml_used_mem(ctx0)/1024.0/1024.0,
2174
- // wctx.get_buf_max_mem(0)/1024.0/1024.0,
2175
- // wctx.get_buf_max_mem(1)/1024.0/1024.0,
2176
- // wctx.get_buf_max_mem(2)/1024.0/1024.0,
2177
- // wctx.get_buf_max_mem(3)/1024.0/1024.0);
2268
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
2269
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
2270
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
2271
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2178
2272
  }
2179
2273
 
2180
2274
  ggml_free(ctx0);
@@ -2282,6 +2376,68 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2282
2376
  }
2283
2377
  }
2284
2378
 
2379
+ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
2380
+ int n_samples, int fft_size, int fft_step, int n_threads,
2381
+ const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2382
+ std::vector<float> fft_in(fft_size, 0.0);
2383
+ std::vector<float> fft_out(2 * fft_size);
2384
+ int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
2385
+
2386
+ for (int i = ith; i < mel.n_len; i += n_threads) {
2387
+ const int offset = i * fft_step;
2388
+
2389
+ // apply Hanning window
2390
+ for (int j = 0; j < fft_size; j++) {
2391
+ if (offset + j < n_samples) {
2392
+ fft_in[j] = hann[j] * samples[offset + j];
2393
+ } else {
2394
+ fft_in[j] = 0.0;
2395
+ }
2396
+ }
2397
+
2398
+ // FFT -> mag^2
2399
+ fft(fft_in, fft_out);
2400
+
2401
+ for (int j = 0; j < fft_size; j++) {
2402
+ fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2403
+ }
2404
+ for (int j = 1; j < fft_size / 2; j++) {
2405
+ fft_out[j] += fft_out[fft_size - j];
2406
+ }
2407
+
2408
+ if (speed_up) {
2409
+ // scale down in the frequency domain results in a speed up in the time domain
2410
+ for (int j = 0; j < n_fft; j++) {
2411
+ fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
2412
+ }
2413
+ }
2414
+
2415
+ // mel spectrogram
2416
+ for (int j = 0; j < mel.n_mel; j++) {
2417
+ double sum = 0.0;
2418
+
2419
+ // unroll loop (suggested by GH user @lunixbochs)
2420
+ int k = 0;
2421
+ for (k = 0; k < n_fft - 3; k += 4) {
2422
+ sum +=
2423
+ fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
2424
+ fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
2425
+ fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
2426
+ fft_out[k + 3] * filters.data[j*n_fft + k + 3];
2427
+ }
2428
+
2429
+ // handle n_fft remainder
2430
+ for (; k < n_fft; k++) {
2431
+ sum += fft_out[k] * filters.data[j * n_fft + k];
2432
+ }
2433
+
2434
+ sum = log10(std::max(sum, 1e-10));
2435
+
2436
+ mel.data[j * mel.n_len + i] = sum;
2437
+ }
2438
+ }
2439
+ }
2440
+
2285
2441
  // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2286
2442
  static bool log_mel_spectrogram(
2287
2443
  whisper_state & wstate,
@@ -2304,85 +2460,48 @@ static bool log_mel_spectrogram(
2304
2460
  hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2305
2461
  }
2306
2462
 
2307
- mel.n_mel = n_mel;
2308
- mel.n_len = (n_samples)/fft_step;
2309
- mel.data.resize(mel.n_mel*mel.n_len);
2310
-
2311
- const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
2312
-
2313
- //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2314
- //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2315
-
2316
- std::vector<std::thread> workers(n_threads);
2317
- for (int iw = 0; iw < n_threads; ++iw) {
2318
- workers[iw] = std::thread([&](int ith) {
2319
- std::vector<float> fft_in;
2320
- fft_in.resize(fft_size);
2321
- for (int i = 0; i < fft_size; i++) {
2322
- fft_in[i] = 0.0;
2323
- }
2463
+ mel.n_mel = n_mel;
2464
+ mel.n_len = n_samples/fft_step;
2465
+ mel.n_len_org = mel.n_len;
2324
2466
 
2325
- std::vector<float> fft_out;
2326
- fft_out.resize(2*fft_size);
2467
+ std::vector<float> samples_padded;
2327
2468
 
2328
- for (int i = ith; i < mel.n_len; i += n_threads) {
2329
- const int offset = i*fft_step;
2469
+ // pad audio with at least one extra chunk of zeros
2470
+ {
2471
+ const int pad = (100*WHISPER_CHUNK_SIZE)/2;
2330
2472
 
2331
- // apply Hanning window
2332
- for (int j = 0; j < fft_size; j++) {
2333
- if (offset + j < n_samples) {
2334
- fft_in[j] = hann[j]*samples[offset + j];
2335
- } else {
2336
- fft_in[j] = 0.0;
2337
- }
2338
- }
2473
+ if (mel.n_len % pad != 0) {
2474
+ mel.n_len = (mel.n_len/pad + 1)*pad;
2475
+ }
2476
+ mel.n_len += pad;
2339
2477
 
2340
- // FFT -> mag^2
2341
- fft(fft_in, fft_out);
2478
+ samples_padded.resize(mel.n_len*fft_step);
2479
+ memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
2480
+ memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
2342
2481
 
2343
- for (int j = 0; j < fft_size; j++) {
2344
- fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2345
- }
2346
- for (int j = 1; j < fft_size/2; j++) {
2347
- //if (i == 0) {
2348
- // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2349
- //}
2350
- fft_out[j] += fft_out[fft_size - j];
2351
- }
2352
- if (i == 0) {
2353
- //for (int j = 0; j < fft_size; j++) {
2354
- // printf("%d: %e\n", j, fft_out[j]);
2355
- //}
2356
- }
2357
-
2358
- if (speed_up) {
2359
- // scale down in the frequency domain results in a speed up in the time domain
2360
- for (int j = 0; j < n_fft; j++) {
2361
- fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
2362
- }
2363
- }
2482
+ samples = samples_padded.data();
2483
+ }
2364
2484
 
2365
- // mel spectrogram
2366
- for (int j = 0; j < mel.n_mel; j++) {
2367
- double sum = 0.0;
2485
+ mel.data.resize(mel.n_mel*mel.n_len);
2368
2486
 
2369
- for (int k = 0; k < n_fft; k++) {
2370
- sum += fft_out[k]*filters.data[j*n_fft + k];
2371
- }
2372
- if (sum < 1e-10) {
2373
- sum = 1e-10;
2374
- }
2487
+ //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2488
+ //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2375
2489
 
2376
- sum = log10(sum);
2490
+ {
2491
+ std::vector<std::thread> workers(n_threads - 1);
2492
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2493
+ workers[iw] = std::thread(
2494
+ log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
2495
+ n_samples, fft_size, fft_step, n_threads,
2496
+ std::cref(filters), speed_up, std::ref(mel));
2497
+ }
2377
2498
 
2378
- mel.data[j*mel.n_len + i] = sum;
2379
- }
2380
- }
2381
- }, iw);
2382
- }
2499
+ // main thread
2500
+ log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2383
2501
 
2384
- for (int iw = 0; iw < n_threads; ++iw) {
2385
- workers[iw].join();
2502
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2503
+ workers[iw].join();
2504
+ }
2386
2505
  }
2387
2506
 
2388
2507
  // clamping and normalization
@@ -2406,6 +2525,8 @@ static bool log_mel_spectrogram(
2406
2525
 
2407
2526
  wstate.t_mel_us += ggml_time_us() - t_start_us;
2408
2527
 
2528
+ //printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
2529
+
2409
2530
  return true;
2410
2531
  }
2411
2532
 
@@ -2447,25 +2568,20 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2447
2568
  int n = word.size();
2448
2569
  while (i < n) {
2449
2570
  int j = n;
2571
+ bool found = false;
2450
2572
  while (j > i) {
2451
- auto it = vocab.token_to_id.find(word.substr(i, j-i));
2573
+ auto sub = word.substr(i, j-i);
2574
+ auto it = vocab.token_to_id.find(sub);
2452
2575
  if (it != vocab.token_to_id.end()) {
2453
2576
  tokens.push_back(it->second);
2454
2577
  i = j;
2578
+ found = true;
2455
2579
  break;
2456
2580
  }
2457
2581
  --j;
2458
2582
  }
2459
- if (i == n) {
2460
- break;
2461
- }
2462
- if (j == i) {
2463
- auto sub = word.substr(i, 1);
2464
- if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
2465
- tokens.push_back(vocab.token_to_id.at(sub));
2466
- } else {
2467
- fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
2468
- }
2583
+ if (!found) {
2584
+ fprintf(stderr, "unknown token \n");
2469
2585
  ++i;
2470
2586
  }
2471
2587
  }
@@ -2478,14 +2594,28 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2478
2594
  // interface implementation
2479
2595
  //
2480
2596
 
2597
+ #ifdef WHISPER_USE_COREML
2598
// Derive the Core ML encoder path from the ggml model path:
//   "<dir>/<name>.bin" -> "<dir>/<name>-encoder.mlmodelc"
//
// Only an extension that belongs to the final path component is stripped. A
// dot that appears in a directory name (e.g. "models.v1/ggml-tiny") must not
// truncate the path, so the last '.' is honoured only when it comes after the
// last '/'. A file name without any extension simply gets the suffix appended.
static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
    const auto pos_dot = path_bin.rfind('.');
    const auto pos_sep = path_bin.rfind('/');

    const bool dot_in_file_name =
        pos_dot != std::string::npos &&
        (pos_sep == std::string::npos || pos_dot > pos_sep);

    if (dot_in_file_name) {
        path_bin.erase(pos_dot);
    }

    path_bin += "-encoder.mlmodelc";

    return path_bin;
}
2609
+ #endif
2610
+
2481
2611
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2482
2612
  whisper_state * state = new whisper_state;
2483
2613
 
2484
- const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
2485
-
2614
+ const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
2486
2615
 
2487
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
2616
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2488
2617
  fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2618
+ delete state;
2489
2619
  return nullptr;
2490
2620
  }
2491
2621
 
@@ -2494,8 +2624,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2494
2624
  fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2495
2625
  }
2496
2626
 
2497
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
2627
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2498
2628
  fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2629
+ delete state;
2499
2630
  return nullptr;
2500
2631
  }
2501
2632
 
@@ -2504,6 +2635,22 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2504
2635
  fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2505
2636
  }
2506
2637
 
2638
+ #ifdef WHISPER_USE_COREML
2639
+ const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2640
+
2641
+ fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2642
+ fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
2643
+
2644
+ state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2645
+ if (!state->ctx_coreml) {
2646
+ fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2647
+ #ifndef WHISPER_COREML_ALLOW_FALLBACK
2648
+ return nullptr;
2649
+ #endif
2650
+ } else {
2651
+ fprintf(stderr, "%s: Core ML model loaded\n", __func__);
2652
+ }
2653
+ #endif
2507
2654
 
2508
2655
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2509
2656
 
@@ -2528,7 +2675,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2528
2675
  }
2529
2676
 
2530
2677
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2531
- whisper_model_loader loader = {};
2532
2678
 
2533
2679
  fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
2534
2680
 
@@ -2538,7 +2684,10 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
2538
2684
  return nullptr;
2539
2685
  }
2540
2686
 
2687
+ whisper_model_loader loader = {};
2688
+
2541
2689
  loader.context = &fin;
2690
+
2542
2691
  loader.read = [](void * ctx, void * output, size_t read_size) {
2543
2692
  std::ifstream * fin = (std::ifstream*)ctx;
2544
2693
  fin->read((char *)output, read_size);
@@ -2555,7 +2704,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
2555
2704
  fin->close();
2556
2705
  };
2557
2706
 
2558
- return whisper_init_no_state(&loader);
2707
+ auto ctx = whisper_init_no_state(&loader);
2708
+
2709
+ if (ctx) {
2710
+ ctx->path_model = path_model;
2711
+ }
2712
+
2713
+ return ctx;
2559
2714
  }
2560
2715
 
2561
2716
  struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
@@ -2566,10 +2721,11 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
2566
2721
  };
2567
2722
 
2568
2723
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
2569
- whisper_model_loader loader = {};
2570
2724
 
2571
2725
  fprintf(stderr, "%s: loading model from buffer\n", __func__);
2572
2726
 
2727
+ whisper_model_loader loader = {};
2728
+
2573
2729
  loader.context = &ctx;
2574
2730
 
2575
2731
  loader.read = [](void * ctx, void * output, size_t read_size) {
@@ -2665,6 +2821,13 @@ void whisper_free_state(struct whisper_state * state)
2665
2821
  kv_cache_free(state->decoders[i].kv_self);
2666
2822
  }
2667
2823
 
2824
+ #ifdef WHISPER_USE_COREML
2825
+ if (state->ctx_coreml != nullptr) {
2826
+ whisper_coreml_free(state->ctx_coreml);
2827
+ state->ctx_coreml = nullptr;
2828
+ }
2829
+ #endif
2830
+
2668
2831
  delete state;
2669
2832
  }
2670
2833
  }
@@ -2723,8 +2886,9 @@ int whisper_set_mel_with_state(
2723
2886
  return -1;
2724
2887
  }
2725
2888
 
2726
- state->mel.n_len = n_len;
2727
- state->mel.n_mel = n_mel;
2889
+ state->mel.n_len = n_len;
2890
+ state->mel.n_len_org = n_len;
2891
+ state->mel.n_mel = n_mel;
2728
2892
 
2729
2893
  state->mel.data.resize(n_len*n_mel);
2730
2894
  memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
@@ -2822,7 +2986,6 @@ int whisper_lang_id(const char * lang) {
2822
2986
  fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2823
2987
  return -1;
2824
2988
  }
2825
-
2826
2989
  return g_lang.at(lang).first;
2827
2990
  }
2828
2991
 
@@ -2850,13 +3013,13 @@ int whisper_lang_auto_detect_with_state(
2850
3013
  return -1;
2851
3014
  }
2852
3015
 
2853
- if (seek >= state->mel.n_len) {
2854
- fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
3016
+ if (seek >= state->mel.n_len_org) {
3017
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
2855
3018
  return -2;
2856
3019
  }
2857
3020
 
2858
3021
  // run the encoder
2859
- if (whisper_encode(ctx, seek, n_threads) != 0) {
3022
+ if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
2860
3023
  fprintf(stderr, "%s: failed to encode\n", __func__);
2861
3024
  return -6;
2862
3025
  }
@@ -2920,12 +3083,77 @@ int whisper_lang_auto_detect(
2920
3083
  return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
2921
3084
  }
2922
3085
 
3086
+ int whisper_model_n_vocab(struct whisper_context * ctx) {
3087
+ return ctx->model.hparams.n_vocab;
3088
+ }
3089
+
3090
+ int whisper_model_n_audio_ctx(struct whisper_context * ctx) {
3091
+ return ctx->model.hparams.n_audio_ctx;
3092
+ }
3093
+
3094
+ int whisper_model_n_audio_state(struct whisper_context * ctx) {
3095
+ return ctx->model.hparams.n_audio_state;
3096
+ }
3097
+
3098
+ int whisper_model_n_audio_head(struct whisper_context * ctx) {
3099
+ return ctx->model.hparams.n_audio_head;
3100
+ }
3101
+
3102
+ int whisper_model_n_audio_layer(struct whisper_context * ctx) {
3103
+ return ctx->model.hparams.n_audio_layer;
3104
+ }
3105
+
3106
+ int whisper_model_n_text_ctx(struct whisper_context * ctx) {
3107
+ return ctx->model.hparams.n_text_ctx;
3108
+ }
3109
+
3110
+ int whisper_model_n_text_state(struct whisper_context * ctx) {
3111
+ return ctx->model.hparams.n_text_state;
3112
+ }
3113
+
3114
+ int whisper_model_n_text_head(struct whisper_context * ctx) {
3115
+ return ctx->model.hparams.n_text_head;
3116
+ }
3117
+
3118
+ int whisper_model_n_text_layer(struct whisper_context * ctx) {
3119
+ return ctx->model.hparams.n_text_layer;
3120
+ }
3121
+
3122
+ int whisper_model_n_mels(struct whisper_context * ctx) {
3123
+ return ctx->model.hparams.n_mels;
3124
+ }
3125
+
3126
+ int whisper_model_ftype(struct whisper_context * ctx) {
3127
+ return ctx->model.hparams.ftype;
3128
+ }
3129
+
3130
+ int whisper_model_type(struct whisper_context * ctx) {
3131
+ return ctx->model.type;
3132
+ }
3133
+
3134
+ const char *whisper_model_type_readable(struct whisper_context * ctx) {
3135
+ switch (ctx->model.type) {
3136
+ case e_model::MODEL_TINY:
3137
+ return "tiny";
3138
+ case e_model::MODEL_BASE:
3139
+ return "base";
3140
+ case e_model::MODEL_SMALL:
3141
+ return "small";
3142
+ case e_model::MODEL_MEDIUM:
3143
+ return "medium";
3144
+ case e_model::MODEL_LARGE:
3145
+ return "large";
3146
+ default:
3147
+ return "unknown";
3148
+ }
3149
+ }
3150
+
2923
3151
  int whisper_n_len_from_state(struct whisper_state * state) {
2924
- return state->mel.n_len;
3152
+ return state->mel.n_len_org;
2925
3153
  }
2926
3154
 
2927
3155
  int whisper_n_len(struct whisper_context * ctx) {
2928
- return ctx->state->mel.n_len;
3156
+ return ctx->state->mel.n_len_org;
2929
3157
  }
2930
3158
 
2931
3159
  int whisper_n_vocab(struct whisper_context * ctx) {
@@ -3021,6 +3249,14 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3021
3249
  }
3022
3250
  }
3023
3251
 
3252
// Reports whether WHISPER_USE_COREML was defined at compile time:
// 1 when it was, 0 otherwise.
static int whisper_has_coreml(void) {
#ifdef WHISPER_USE_COREML
    const int has_coreml = 1;
#else
    const int has_coreml = 0;
#endif

    return has_coreml;
}
3259
+
3024
3260
  const char * whisper_print_system_info(void) {
3025
3261
  static std::string s;
3026
3262
 
@@ -3037,6 +3273,7 @@ const char * whisper_print_system_info(void) {
3037
3273
  s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
3038
3274
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
3039
3275
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
3276
+ s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3040
3277
 
3041
3278
  return s.c_str();
3042
3279
  }
@@ -3070,10 +3307,12 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3070
3307
  /*.speed_up =*/ false,
3071
3308
  /*.audio_ctx =*/ 0,
3072
3309
 
3310
+ /*.initial_prompt =*/ nullptr,
3073
3311
  /*.prompt_tokens =*/ nullptr,
3074
3312
  /*.prompt_n_tokens =*/ 0,
3075
3313
 
3076
3314
  /*.language =*/ "en",
3315
+ /*.detect_language =*/ false,
3077
3316
 
3078
3317
  /*.suppress_blank =*/ true,
3079
3318
  /*.suppress_non_speech_tokens =*/ false,
@@ -3082,7 +3321,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3082
3321
  /*.max_initial_ts =*/ 1.0f,
3083
3322
  /*.length_penalty =*/ -1.0f,
3084
3323
 
3085
- /*.temperature_inc =*/ 0.2f,
3324
+ /*.temperature_inc =*/ 0.4f,
3086
3325
  /*.entropy_thold =*/ 2.4f,
3087
3326
  /*.logprob_thold =*/ -1.0f,
3088
3327
  /*.no_speech_thold =*/ 0.6f,
@@ -3100,6 +3339,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3100
3339
  /*.new_segment_callback =*/ nullptr,
3101
3340
  /*.new_segment_callback_user_data =*/ nullptr,
3102
3341
 
3342
+ /*.progress_callback =*/ nullptr,
3343
+ /*.progress_callback_user_data =*/ nullptr,
3344
+
3103
3345
  /*.encoder_begin_callback =*/ nullptr,
3104
3346
  /*.encoder_begin_callback_user_data =*/ nullptr,
3105
3347
 
@@ -3111,13 +3353,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3111
3353
  case WHISPER_SAMPLING_GREEDY:
3112
3354
  {
3113
3355
  result.greedy = {
3114
- /*.best_of =*/ 1,
3356
+ /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
3115
3357
  };
3116
3358
  } break;
3117
3359
  case WHISPER_SAMPLING_BEAM_SEARCH:
3118
3360
  {
3119
3361
  result.beam_search = {
3120
- /*.beam_size =*/ 5,
3362
+ /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
3121
3363
 
3122
3364
  /*.patience =*/ -1.0f,
3123
3365
  };
@@ -3138,15 +3380,15 @@ static void whisper_exp_compute_token_level_timestamps(
3138
3380
 
3139
3381
  // trim from start (in place)
3140
3382
  static inline void ltrim(std::string &s) {
3141
- s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
3142
- return !std::isspace(ch);
3383
+ s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
3384
+ return std::isspace(ch);
3143
3385
  }));
3144
3386
  }
3145
3387
 
3146
3388
  // trim from end (in place)
3147
3389
  static inline void rtrim(std::string &s) {
3148
- s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
3149
- return !std::isspace(ch);
3390
+ s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
3391
+ return std::isspace(ch);
3150
3392
  }).base(), s.end());
3151
3393
  }
3152
3394
 
@@ -3657,7 +3899,7 @@ int whisper_full_with_state(
3657
3899
  }
3658
3900
 
3659
3901
  // auto-detect language if not specified
3660
- if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3902
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
3661
3903
  std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3662
3904
 
3663
3905
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
@@ -3669,6 +3911,9 @@ int whisper_full_with_state(
3669
3911
  params.language = whisper_lang_str(lang_id);
3670
3912
 
3671
3913
  fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3914
+ if (params.detect_language) {
3915
+ return 0;
3916
+ }
3672
3917
  }
3673
3918
 
3674
3919
  if (params.token_timestamps) {
@@ -3679,7 +3924,7 @@ int whisper_full_with_state(
3679
3924
  }
3680
3925
 
3681
3926
  const int seek_start = params.offset_ms/10;
3682
- const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
3927
+ const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
3683
3928
 
3684
3929
  // if length of spectrogram is less than 1s (100 samples), then return
3685
3930
  // basically don't process anything that is less than 1s
@@ -3742,13 +3987,26 @@ int whisper_full_with_state(
3742
3987
  prompt_past.clear();
3743
3988
  }
3744
3989
 
3745
- // prepend the prompt tokens to the prompt_past
3746
- if (params.prompt_tokens && params.prompt_n_tokens > 0) {
3747
- // parse tokens from the pointer
3748
- for (int i = 0; i < params.prompt_n_tokens; i++) {
3749
- prompt_past.push_back(params.prompt_tokens[i]);
3990
+ // prepare prompt
3991
+ {
3992
+ std::vector<whisper_token> prompt_tokens;
3993
+
3994
+ // initial prompt
3995
+ if (!params.prompt_tokens && params.initial_prompt) {
3996
+ prompt_tokens.resize(1024);
3997
+ prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
3998
+ params.prompt_tokens = prompt_tokens.data();
3999
+ params.prompt_n_tokens = prompt_tokens.size();
4000
+ }
4001
+
4002
+ // prepend the prompt tokens to the prompt_past
4003
+ if (params.prompt_tokens && params.prompt_n_tokens > 0) {
4004
+ // parse tokens from the pointer
4005
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
4006
+ prompt_past.push_back(params.prompt_tokens[i]);
4007
+ }
4008
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
3750
4009
  }
3751
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
3752
4010
  }
3753
4011
 
3754
4012
  // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
@@ -3807,6 +4065,10 @@ int whisper_full_with_state(
3807
4065
  fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
3808
4066
  }
3809
4067
  }
4068
+ if (params.progress_callback) {
4069
+ params.progress_callback(
4070
+ ctx, ctx->state, progress_prev, params.progress_callback_user_data);
4071
+ }
3810
4072
 
3811
4073
  // of only 1 second left, then stop
3812
4074
  if (seek + 100 >= seek_end) {
@@ -4196,7 +4458,11 @@ int whisper_full_with_state(
4196
4458
  }
4197
4459
 
4198
4460
  // was the decoding successful for the current temperature?
4199
- {
4461
+ // do fallback only if:
4462
+ // - we are not at the last temperature
4463
+ // - we are not at the end of the audio (3 sec)
4464
+ if (it != (int) temperatures.size() - 1 &&
4465
+ seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
4200
4466
  bool success = true;
4201
4467
 
4202
4468
  const auto & decoder = state->decoders[best_decoder_id];
@@ -4395,6 +4661,9 @@ int whisper_full_parallel(
4395
4661
  params_cur.new_segment_callback = nullptr;
4396
4662
  params_cur.new_segment_callback_user_data = nullptr;
4397
4663
 
4664
+ params_cur.progress_callback = nullptr;
4665
+ params_cur.progress_callback_user_data = nullptr;
4666
+
4398
4667
  workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4399
4668
  }
4400
4669
 
@@ -4562,49 +4831,51 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
4562
4831
 
4563
4832
  ggml_time_init();
4564
4833
 
4565
- size_t n = 50;
4566
- size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
4834
+ size_t n = 20;
4835
+ size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
4567
4836
 
4568
- // 1 GB array
4837
+ // 1GB MB array
4569
4838
  const size_t size = arr*1024llu*1024llu;
4570
4839
 
4571
- char * src = (char *) malloc(size);
4572
- char * dst = (char *) malloc(size);
4840
+ // single-thread
4841
+ {
4842
+ char * src = (char *) malloc(size);
4843
+ char * dst = (char *) malloc(size);
4573
4844
 
4574
- for (size_t i = 0; i < size; i++) src[i] = i;
4845
+ for (size_t i = 0; i < size; i++) src[i] = i;
4575
4846
 
4576
- memcpy(dst, src, size); // heat-up
4847
+ memcpy(dst, src, size); // heat-up
4577
4848
 
4578
- double tsum = 0.0;
4849
+ double tsum = 0.0;
4850
+ double sum = 0.0;
4579
4851
 
4580
- for (size_t i = 0; i < n; i++) {
4581
- const int64_t t0 = ggml_time_us();
4852
+ for (size_t i = 0; i < n; i++) {
4853
+ const int64_t t0 = ggml_time_us();
4582
4854
 
4583
- memcpy(dst, src, size);
4855
+ memcpy(dst, src, size);
4584
4856
 
4585
- const int64_t t1 = ggml_time_us();
4857
+ const int64_t t1 = ggml_time_us();
4586
4858
 
4587
- tsum += (t1 - t0)*1e-6;
4859
+ tsum += (t1 - t0)*1e-6;
4588
4860
 
4589
- src[0] = rand();
4590
- }
4861
+ src[rand() % size] = rand() % 256;
4862
+ }
4591
4863
 
4592
- snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4593
- s += strbuf;
4864
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4865
+ s += strbuf;
4594
4866
 
4595
- // needed to prevent the compile from optimizing the memcpy away
4596
- {
4597
- double sum = 0.0;
4867
+ // needed to prevent the compiler from optimizing the memcpy away
4868
+ {
4869
+ for (size_t i = 0; i < size; i++) sum += dst[i];
4598
4870
 
4599
- for (size_t i = 0; i < size; i++) sum += dst[i];
4871
+ snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
4872
+ s += strbuf;
4873
+ }
4600
4874
 
4601
- snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
4602
- s += strbuf;
4875
+ free(src);
4876
+ free(dst);
4603
4877
  }
4604
4878
 
4605
- free(src);
4606
- free(dst);
4607
-
4608
4879
  return s.c_str();
4609
4880
  }
4610
4881
 
@@ -4634,27 +4905,48 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
4634
4905
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
4635
4906
  std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
4636
4907
 
4908
+ // put a bunch of random data in the buffer
4637
4909
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
4638
4910
 
4639
4911
  for (int j = 0; j < (int) sizes.size(); j++) {
4912
+ int n_q4_0 = 0;
4913
+ int n_q4_1 = 0;
4914
+ int n_q4_2 = 0;
4915
+ int n_q5_0 = 0;
4916
+ int n_q5_1 = 0;
4917
+ int n_q8_0 = 0;
4640
4918
  int n_fp16 = 0;
4641
4919
  int n_fp32 = 0;
4642
4920
 
4643
4921
  // GFLOPS/s
4922
+ double s_q4_0 = 0.0;
4923
+ double s_q4_1 = 0.0;
4924
+ double s_q4_2 = 0.0;
4925
+ double s_q5_0 = 0.0;
4926
+ double s_q5_1 = 0.0;
4927
+ double s_q8_0 = 0.0;
4644
4928
  double s_fp16 = 0.0;
4645
4929
  double s_fp32 = 0.0;
4646
4930
 
4647
4931
  const size_t N = sizes[j];
4648
4932
 
4649
- for (int k = 0; k < 2; ++k) {
4650
- const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
4933
+ for (int k = 0; k < 8; ++k) {
4934
+ const ggml_type wtype =
4935
+ k == 0 ? GGML_TYPE_Q4_0 :
4936
+ k == 1 ? GGML_TYPE_Q4_1 :
4937
+ k == 2 ? GGML_TYPE_Q4_2 :
4938
+ k == 3 ? GGML_TYPE_Q5_0 :
4939
+ k == 4 ? GGML_TYPE_Q5_1 :
4940
+ k == 5 ? GGML_TYPE_Q8_0 :
4941
+ k == 6 ? GGML_TYPE_F16 : GGML_TYPE_F32;
4651
4942
 
4652
- double & s = k == 0 ? s_fp16 : s_fp32;
4653
- int & n = k == 0 ? n_fp16 : n_fp32;
4943
+ double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_q4_2 : k == 3 ? s_q5_0 : k == 4 ? s_q5_1 : k == 5 ? s_q8_0 : k == 6 ? s_fp16 : /*k == 7*/ s_fp32;
4944
+ int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_q4_2 : k == 3 ? n_q5_0 : k == 4 ? n_q5_1 : k == 5 ? n_q8_0 : k == 6 ? n_fp16 : /*k == 7*/ n_fp32;
4654
4945
 
4655
4946
  struct ggml_init_params gparams = {
4656
4947
  /*.mem_size =*/ buf.size(),
4657
4948
  /*.mem_buffer =*/ buf.data(),
4949
+ /*.no_alloc =*/ false,
4658
4950
  };
4659
4951
 
4660
4952
  struct ggml_context * ctx0 = ggml_init(gparams);
@@ -4693,8 +4985,19 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
4693
4985
  s = ((2.0*N*N*N*n)/tsum)*1e-9;
4694
4986
  }
4695
4987
 
4696
- snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
4697
- N, N, s_fp16, n_fp16, s_fp32, n_fp32);
4988
+ // Q4_0 | Q4_1 | Q4_2
4989
+ snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs) | Q4_2 %7.1f GFLOPS (%3d runs)\n",
4990
+ N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1, s_q4_2, n_q4_2);
4991
+ s += strbuf;
4992
+
4993
+ // Q5_0 | Q5_1 | Q8_0
4994
+ snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n",
4995
+ N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0);
4996
+ s += strbuf;
4997
+
4998
+ // F16 | F32
4999
+ snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n",
5000
+ N, N, s_fp16, n_fp16, s_fp32, n_fp32);
4698
5001
  s += strbuf;
4699
5002
  }
4700
5003