whisper.rn 0.3.9 → 0.4.0-rc.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/jni.cpp +7 -1
  3. package/cpp/coreml/whisper-encoder.mm +7 -1
  4. package/cpp/ggml-alloc.c +633 -0
  5. package/cpp/ggml-alloc.h +26 -0
  6. package/cpp/ggml-metal.h +85 -0
  7. package/cpp/ggml-metal.m +1283 -0
  8. package/cpp/ggml-metal.metal +2353 -0
  9. package/cpp/ggml.c +5024 -2924
  10. package/cpp/ggml.h +569 -95
  11. package/cpp/whisper.cpp +1014 -667
  12. package/cpp/whisper.h +13 -0
  13. package/ios/RNWhisper.mm +2 -0
  14. package/ios/RNWhisperContext.h +1 -1
  15. package/ios/RNWhisperContext.mm +18 -4
  16. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  17. package/lib/commonjs/index.js +3 -1
  18. package/lib/commonjs/index.js.map +1 -1
  19. package/lib/module/NativeRNWhisper.js.map +1 -1
  20. package/lib/module/index.js +3 -1
  21. package/lib/module/index.js.map +1 -1
  22. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  23. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  24. package/lib/typescript/index.d.ts +3 -1
  25. package/lib/typescript/index.d.ts.map +1 -1
  26. package/package.json +1 -1
  27. package/src/NativeRNWhisper.ts +1 -0
  28. package/src/index.ts +4 -0
  29. package/whisper-rn.podspec +8 -2
  30. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  31. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  32. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  33. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/whisper.cpp CHANGED
@@ -3,11 +3,16 @@
3
3
  #include "coreml/whisper-encoder.h"
4
4
  #endif
5
5
 
6
- #if WHISPER_USE_OPENVINO
6
+ #ifdef WSP_GGML_USE_METAL
7
+ # include "ggml-metal.h"
8
+ #endif
9
+
10
+ #ifdef WHISPER_USE_OPENVINO
7
11
  #include "openvino/whisper-openvino-encoder.h"
8
12
  #endif
9
13
 
10
14
  #include "ggml.h"
15
+ #include "ggml-alloc.h"
11
16
 
12
17
  #include <algorithm>
13
18
  #include <cassert>
@@ -18,11 +23,13 @@
18
23
  #include <cstring>
19
24
  #include <fstream>
20
25
  #include <map>
26
+ #include <set>
21
27
  #include <string>
22
28
  #include <thread>
23
29
  #include <vector>
24
30
  #include <regex>
25
31
  #include <random>
32
+ #include <functional>
26
33
 
27
34
  #if defined(_MSC_VER)
28
35
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -114,8 +121,66 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
114
121
  //#define WHISPER_USE_FLASH_FF
115
122
  #define WHISPER_MAX_DECODERS 16
116
123
 
117
- #define WHISPER_USE_SCRATCH
118
- #define WHISPER_MAX_SCRATCH_BUFFERS 16
124
+ //
125
+ // ggml helpers
126
+ //
127
+
128
+ static void wsp_ggml_graph_compute_helper(
129
+ std::vector<uint8_t> & buf,
130
+ wsp_ggml_cgraph * graph,
131
+ int n_threads,
132
+ whisper_abort_callback abort_callback,
133
+ void * abort_callback_data) {
134
+ struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
135
+
136
+ plan.abort_callback = abort_callback;
137
+ plan.abort_callback_data = abort_callback_data;
138
+
139
+ if (plan.work_size > 0) {
140
+ buf.resize(plan.work_size);
141
+ plan.work_data = buf.data();
142
+ }
143
+
144
+ wsp_ggml_graph_compute(graph, &plan);
145
+ }
146
+
147
+ // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
148
+ // the idea is to represent the original matrix multiplication:
149
+ //
150
+ // Z = X @ Y
151
+ //
152
+ // with the sum of two matrix multiplications:
153
+ //
154
+ // Z = (X_0 @ Y_0) + (X_1 @ Y_1)
155
+ //
156
+ // here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
157
+ // and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
158
+ // general-purpose kernels
159
+ //
160
+ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * x, struct wsp_ggml_tensor * y, int pad = 32) {
161
+ // use padding only if dimension 0 is at least 8 times larger than the padding
162
+ // else we won't get much benefit from the optimization
163
+ const int n_pad_req = 8;
164
+
165
+ if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
166
+ return wsp_ggml_mul_mat(ctx, x, y);
167
+ }
168
+
169
+ struct wsp_ggml_tensor * x_0 = wsp_ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
170
+ struct wsp_ggml_tensor * x_1 = wsp_ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
171
+
172
+ struct wsp_ggml_tensor * y_0 = wsp_ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
173
+ struct wsp_ggml_tensor * y_1 = wsp_ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
174
+
175
+ return wsp_ggml_add(ctx,
176
+ wsp_ggml_mul_mat(ctx, x_0, y_0),
177
+ wsp_ggml_mul_mat(ctx, x_1, y_1));
178
+ }
179
+
180
+ // TODO: check if other platforms can benefit from this optimization
181
+ #if defined(WSP_GGML_USE_METAL)
182
+ #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
183
+ #endif
119
184
 
120
185
  // available whisper models
121
186
  enum e_model {
@@ -231,38 +296,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
231
296
 
232
297
  static const size_t MB = 1ull*1024*1024;
233
298
 
234
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
235
- { MODEL_TINY, 62ull*MB },
236
- { MODEL_BASE, 80ull*MB },
237
- { MODEL_SMALL, 120ull*MB },
238
- { MODEL_MEDIUM, 158ull*MB },
239
- { MODEL_LARGE, 198ull*MB },
240
- };
241
-
242
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
243
- { MODEL_TINY, 18ull*MB },
244
- { MODEL_BASE, 24ull*MB },
245
- { MODEL_SMALL, 36ull*MB },
246
- { MODEL_MEDIUM, 48ull*MB },
247
- { MODEL_LARGE, 60ull*MB },
248
- };
249
-
250
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
251
- { MODEL_TINY, 4ull*MB },
252
- { MODEL_BASE, 4ull*MB },
253
- { MODEL_SMALL, 6ull*MB },
254
- { MODEL_MEDIUM, 7ull*MB },
255
- { MODEL_LARGE, 9ull*MB },
256
- };
257
-
258
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
259
- { MODEL_TINY, 4ull*MB },
260
- { MODEL_BASE, 4ull*MB },
261
- { MODEL_SMALL, 6ull*MB },
262
- { MODEL_MEDIUM, 7ull*MB },
263
- { MODEL_LARGE, 9ull*MB },
264
- };
265
-
299
+ // TODO: avoid using GGUF
266
300
  static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
267
301
  { WSP_GGML_TYPE_F32,
268
302
  {
@@ -329,38 +363,6 @@ static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL =
329
363
  },
330
364
  };
331
365
 
332
- static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
333
- { MODEL_TINY, 3ull*MB },
334
- { MODEL_BASE, 6ull*MB },
335
- { MODEL_SMALL, 16ull*MB },
336
- { MODEL_MEDIUM, 43ull*MB },
337
- { MODEL_LARGE, 71ull*MB },
338
- };
339
-
340
- static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
341
- { MODEL_TINY, 9ull*MB },
342
- { MODEL_BASE, 18ull*MB },
343
- { MODEL_SMALL, 53ull*MB },
344
- { MODEL_MEDIUM, 141ull*MB },
345
- { MODEL_LARGE, 235ull*MB },
346
- };
347
-
348
- static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
349
- { MODEL_TINY, 30ull*MB },
350
- { MODEL_BASE, 38ull*MB },
351
- { MODEL_SMALL, 56ull*MB },
352
- { MODEL_MEDIUM, 74ull*MB },
353
- { MODEL_LARGE, 94ull*MB },
354
- };
355
-
356
- static const std::map<e_model, size_t> MEM_REQ_DECODE = {
357
- { MODEL_TINY, 3ull*MB },
358
- { MODEL_BASE, 5ull*MB },
359
- { MODEL_SMALL, 10ull*MB },
360
- { MODEL_MEDIUM, 18ull*MB },
361
- { MODEL_LARGE, 27ull*MB },
362
- };
363
-
364
366
  struct whisper_mel {
365
367
  int n_len;
366
368
  int n_len_org;
@@ -441,6 +443,7 @@ struct whisper_hparams {
441
443
  int32_t n_text_layer = 4;
442
444
  int32_t n_mels = 80;
443
445
  int32_t ftype = 1;
446
+ float eps = 1e-5f;
444
447
  };
445
448
 
446
449
  // audio encoding layer
@@ -536,6 +539,7 @@ struct whisper_kv_cache {
536
539
 
537
540
  struct wsp_ggml_context * ctx;
538
541
 
542
+ // buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
539
543
  std::vector<uint8_t> buf;
540
544
 
541
545
  int n; // number of tokens currently in the cache
@@ -601,7 +605,7 @@ struct whisper_sequence {
601
605
 
602
606
  // TAGS: WHISPER_DECODER_INIT
603
607
  struct whisper_decoder {
604
- // each decoders keeps its own KV-cache
608
+ // each decoder keeps its own KV-cache
605
609
  whisper_kv_cache kv_self;
606
610
 
607
611
  // the currently generated sequence of tokens
@@ -621,15 +625,75 @@ struct whisper_decoder {
621
625
  std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
622
626
  };
623
627
 
628
+ // replace std::pair by using customized pair struct (reason: std::pair is very slow)
629
+ template<typename A, typename B>
630
+ struct whisper_pair {
631
+ A first;
632
+ B second;
633
+
634
+ // Define a constructor that takes two arguments.
635
+ whisper_pair(const A& a, const B& b) : first(a), second(b) {}
636
+ // Define a constructor that takes no argument.
637
+ whisper_pair() : first(A()), second(B()) {}
638
+ };
639
+
640
+ // beam-search helpers
641
+ struct kv_buf {
642
+ std::vector<uint8_t> k;
643
+ std::vector<uint8_t> v;
644
+ };
645
+
646
+ // wsp_ggml_allocr wrapper for whisper usage
647
+ struct whisper_allocr {
648
+ wsp_ggml_allocr * alloc = nullptr;
649
+
650
+ std::vector<uint8_t> meta;
651
+ std::vector<uint8_t> data;
652
+ };
653
+
654
+ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
655
+ return allocr.meta.size() + allocr.data.size();
656
+ }
657
+
658
+ // measure the memory usage of a graph and prepare the allocr's internal data buffer
659
+ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
660
+ const int tensor_alignment = 32;
661
+
662
+ auto & alloc = allocr.alloc;
663
+ auto & meta = allocr.meta;
664
+ auto & data = allocr.data;
665
+
666
+ meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
667
+
668
+ alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
669
+
670
+ const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
671
+
672
+ wsp_ggml_allocr_free(alloc);
673
+
674
+ data.resize(alloc_size);
675
+
676
+ alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
677
+ }
678
+
679
+ static void whisper_allocr_free(struct whisper_allocr & allocr) {
680
+ if (allocr.alloc) {
681
+ wsp_ggml_allocr_free(allocr.alloc);
682
+ allocr.alloc = nullptr;
683
+ }
684
+ }
685
+
624
686
  struct whisper_state {
625
687
  int64_t t_sample_us = 0;
626
688
  int64_t t_encode_us = 0;
627
689
  int64_t t_decode_us = 0;
690
+ int64_t t_prompt_us = 0;
628
691
  int64_t t_mel_us = 0;
629
692
 
630
693
  int32_t n_sample = 0; // number of tokens sampled
631
694
  int32_t n_encode = 0; // number of encoder calls
632
- int32_t n_decode = 0; // number of decoder calls
695
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
696
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
633
697
  int32_t n_fail_p = 0; // number of logprob threshold failures
634
698
  int32_t n_fail_h = 0; // number of entropy threshold failures
635
699
 
@@ -640,12 +704,23 @@ struct whisper_state {
640
704
 
641
705
  whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
642
706
 
643
- // memory buffers used by encode / decode contexts
644
- std::vector<uint8_t> buf_compute;
645
- std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
707
+ // buffer for swapping KV caches between decoders during beam-search
708
+ std::vector<kv_buf> kv_swap_bufs;
646
709
 
647
- int buf_last = 0;
648
- size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
710
+ // reusable buffer for `struct wsp_ggml_graph_plan.work_data`
711
+ std::vector<uint8_t> work_buffer;
712
+
713
+ // ggml-alloc:
714
+ // - stores meta info about the intermediate tensors into the `meta` buffers
715
+ // - stores the actual tensor data into the `data` buffers
716
+ whisper_allocr alloc_conv;
717
+ whisper_allocr alloc_encode;
718
+ whisper_allocr alloc_cross;
719
+ whisper_allocr alloc_decode;
720
+
721
+ // result of the encoder
722
+ struct wsp_ggml_tensor * embd_conv = nullptr;
723
+ struct wsp_ggml_tensor * embd_enc = nullptr;
649
724
 
650
725
  // decode output (2-dimensional array: [n_tokens][n_vocab])
651
726
  std::vector<float> logits;
@@ -654,7 +729,7 @@ struct whisper_state {
654
729
  std::vector<whisper_token> prompt_past;
655
730
 
656
731
  // work container used to avoid memory allocations
657
- std::vector<std::pair<double, whisper_vocab::id>> logits_id;
732
+ std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
658
733
 
659
734
  mutable std::mt19937 rng; // used for sampling at t > 0.0
660
735
 
@@ -665,6 +740,10 @@ struct whisper_state {
665
740
  whisper_coreml_context * ctx_coreml = nullptr;
666
741
  #endif
667
742
 
743
+ #ifdef WSP_GGML_USE_METAL
744
+ wsp_ggml_metal_context * ctx_metal = nullptr;
745
+ #endif
746
+
668
747
  #ifdef WHISPER_USE_OPENVINO
669
748
  whisper_openvino_context * ctx_openvino = nullptr;
670
749
  #endif
@@ -677,37 +756,6 @@ struct whisper_state {
677
756
 
678
757
  // [EXPERIMENTAL] speed-up techniques
679
758
  int32_t exp_n_audio_ctx = 0; // 0 - use default
680
-
681
- void use_buf(struct wsp_ggml_context * ctx, int i) {
682
- #if defined(WHISPER_USE_SCRATCH)
683
- size_t last_size = 0;
684
-
685
- if (i == -1) {
686
- last_size = wsp_ggml_set_scratch(ctx, { 0, 0, nullptr, });
687
- } else {
688
- auto & buf = buf_scratch[i];
689
- last_size = wsp_ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
690
- }
691
-
692
- if (buf_last >= 0) {
693
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
694
- }
695
-
696
- buf_last = i;
697
- #else
698
- (void) i;
699
- (void) ctx;
700
- #endif
701
- }
702
-
703
- size_t get_buf_max_mem(int i) const {
704
- #if defined(WHISPER_USE_SCRATCH)
705
- return buf_max_size[i];
706
- #else
707
- (void) i;
708
- return 0;
709
- #endif
710
- }
711
759
  };
712
760
 
713
761
  struct whisper_context {
@@ -722,6 +770,9 @@ struct whisper_context {
722
770
  whisper_state * state = nullptr;
723
771
 
724
772
  std::string path_model; // populated by whisper_init_from_file()
773
+ #ifdef WHISPER_USE_COREML
774
+ bool load_coreml = true;
775
+ #endif
725
776
  };
726
777
 
727
778
  static void whisper_default_log(const char * text) {
@@ -730,6 +781,13 @@ static void whisper_default_log(const char * text) {
730
781
 
731
782
  static whisper_log_callback whisper_log = whisper_default_log;
732
783
 
784
+ #ifdef __GNUC__
785
+ #ifdef __MINGW32__
786
+ __attribute__((gnu_format(printf, 1, 2)))
787
+ #else
788
+ __attribute__((format(printf, 1, 2)))
789
+ #endif
790
+ #endif
733
791
  static void log(const char * fmt, ...) {
734
792
  if (!whisper_log) return;
735
793
  char buf[1024];
@@ -747,10 +805,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
747
805
 
748
806
  static bool kv_cache_init(
749
807
  const struct whisper_hparams & hparams,
750
- const size_t mem_bytes,
751
808
  struct whisper_kv_cache & cache,
752
809
  wsp_ggml_type wtype,
753
810
  int n_ctx) {
811
+ const int64_t n_text_state = hparams.n_text_state;
812
+ const int64_t n_text_layer = hparams.n_text_layer;
813
+
814
+ const int64_t n_mem = n_text_layer*n_ctx;
815
+ const int64_t n_elements = n_text_state*n_mem;
816
+
817
+ const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
818
+
754
819
  cache.buf.resize(mem_bytes);
755
820
 
756
821
  struct wsp_ggml_init_params params = {
@@ -766,12 +831,6 @@ static bool kv_cache_init(
766
831
  return false;
767
832
  }
768
833
 
769
- const int n_text_state = hparams.n_text_state;
770
- const int n_text_layer = hparams.n_text_layer;
771
-
772
- const int n_mem = n_text_layer*n_ctx;
773
- const int n_elements = n_text_state*n_mem;
774
-
775
834
  cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
776
835
  cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
777
836
 
@@ -914,22 +973,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
914
973
 
915
974
  // print memory requirements
916
975
  {
917
- // this is the total memory required to run the inference
918
- const size_t mem_required =
919
- MEM_REQ_SCRATCH0.at(model.type) +
920
- MEM_REQ_SCRATCH1.at(model.type) +
921
- MEM_REQ_SCRATCH2.at(model.type) +
922
- MEM_REQ_SCRATCH3.at(model.type) +
923
- scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
924
- scale*MEM_REQ_KV_CROSS.at(model.type) +
925
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
926
-
927
- // this is the memory required by one decoder
928
- const size_t mem_required_decoder =
929
- scale*MEM_REQ_KV_SELF.at(model.type);
930
-
931
- log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
932
- mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
976
+ // TODO
977
+ //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
978
+ // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
933
979
  }
934
980
 
935
981
  // initialize all memory buffers
@@ -1438,49 +1484,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1438
1484
  return true;
1439
1485
  }
1440
1486
 
1441
- // evaluate the encoder with the given state
1442
- //
1443
- // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1444
- // part of the transformer model and returns the encoded features
1445
- //
1446
- // - wctx: the model
1447
- // - wstate: the state of the encoder
1448
- // - n_threads: number of threads to use
1449
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1450
- //
1451
- static bool whisper_encode_internal(
1452
- whisper_context & wctx,
1453
- whisper_state & wstate,
1454
- const int mel_offset,
1455
- const int n_threads){
1487
+ static bool whisper_encode_external(const whisper_state & wstate) {
1488
+ WSP_GGML_UNUSED(wstate);
1456
1489
 
1457
- const int64_t t_start_us = wsp_ggml_time_us();
1490
+ #ifndef WHISPER_USE_COREML
1491
+ const bool use_coreml = false;
1492
+ #else
1493
+ const bool use_coreml = wstate.ctx_coreml != nullptr;
1494
+ #endif
1495
+
1496
+ #ifndef WHISPER_USE_OPENVINO
1497
+ const bool use_openvino = false;
1498
+ #else
1499
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
1500
+ #endif
1501
+
1502
+ return use_coreml || use_openvino;
1503
+ }
1458
1504
 
1505
+ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1506
+ whisper_context & wctx,
1507
+ whisper_state & wstate,
1508
+ const int mel_offset) {
1459
1509
  const auto & model = wctx.model;
1460
1510
  const auto & mel_inp = wstate.mel;
1461
1511
  const auto & hparams = model.hparams;
1462
1512
 
1463
1513
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1464
- const int n_state = hparams.n_audio_state;
1465
- const int n_head = hparams.n_audio_head;
1466
- const int n_layer = hparams.n_audio_layer;
1514
+ const int n_state = hparams.n_audio_state; WSP_GGML_UNUSED(n_state);
1467
1515
 
1468
1516
  const int n_mels = hparams.n_mels;
1469
- assert(mel_inp.n_mel == n_mels);
1470
1517
 
1471
1518
  struct wsp_ggml_init_params params = {
1472
- /*.mem_size =*/ wstate.buf_compute.size(),
1473
- /*.mem_buffer =*/ wstate.buf_compute.data(),
1474
- /*.no_alloc =*/ false,
1519
+ /*.mem_size =*/ wstate.alloc_conv.meta.size(),
1520
+ /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
1521
+ /*.no_alloc =*/ true,
1475
1522
  };
1476
1523
 
1477
1524
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1478
1525
 
1479
- wstate.use_buf(ctx0, 0);
1526
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1527
+
1528
+ wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
1480
1529
 
1481
1530
  struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
1531
+ wsp_ggml_allocr_alloc(alloc, mel);
1532
+
1482
1533
  assert(mel->type == WSP_GGML_TYPE_F32);
1483
- {
1534
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1535
+ assert(mel_inp.n_mel == n_mels);
1536
+
1484
1537
  float * dst = (float *) mel->data;
1485
1538
  memset(dst, 0, wsp_ggml_nbytes(mel));
1486
1539
 
@@ -1494,25 +1547,11 @@ static bool whisper_encode_internal(
1494
1547
  }
1495
1548
  }
1496
1549
 
1497
- struct wsp_ggml_tensor * cur;
1498
-
1499
- #ifndef WHISPER_USE_COREML
1500
- const bool use_coreml = false;
1501
- #else
1502
- const bool use_coreml = wstate.ctx_coreml != nullptr;
1503
- #endif
1504
-
1505
- #ifndef WHISPER_USE_OPENVINO
1506
- const bool use_openvino = false;
1507
- #else
1508
- const bool use_openvino = wstate.ctx_openvino != nullptr;
1509
- #endif
1550
+ struct wsp_ggml_tensor * cur = nullptr;
1510
1551
 
1511
- if (!use_coreml && !use_openvino) {
1552
+ if (!whisper_encode_external(wstate)) {
1512
1553
  // convolution + gelu
1513
1554
  {
1514
- wstate.use_buf(ctx0, 1);
1515
-
1516
1555
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1517
1556
  cur = wsp_ggml_add(ctx0,
1518
1557
  wsp_ggml_repeat(ctx0,
@@ -1522,8 +1561,6 @@ static bool whisper_encode_internal(
1522
1561
 
1523
1562
  cur = wsp_ggml_gelu(ctx0, cur);
1524
1563
 
1525
- wstate.use_buf(ctx0, 0);
1526
-
1527
1564
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1528
1565
  cur = wsp_ggml_add(ctx0,
1529
1566
  wsp_ggml_repeat(ctx0,
@@ -1534,373 +1571,433 @@ static bool whisper_encode_internal(
1534
1571
  cur = wsp_ggml_gelu(ctx0, cur);
1535
1572
  }
1536
1573
 
1537
- wstate.use_buf(ctx0, 3);
1574
+ wstate.embd_conv = cur;
1575
+ } else {
1576
+ #ifdef WHISPER_USE_COREML
1577
+ cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1578
+ wsp_ggml_allocr_alloc(alloc, cur);
1538
1579
 
1539
- // ===================================================================
1540
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
1541
- //static int iter = -1;
1542
- //const int n_iter = 1500/n_ctx;
1580
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1581
+ whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1582
+ }
1583
+ #endif
1584
+ #ifdef WHISPER_USE_OPENVINO
1585
+ cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1586
+ wsp_ggml_allocr_alloc(alloc, cur);
1543
1587
 
1544
- //iter = (iter + 1) % n_iter;
1588
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1589
+ whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
1590
+ }
1591
+ #endif
1545
1592
 
1546
- //if (iter == 0) {
1547
- // memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
1548
- // memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
1549
- //}
1593
+ wstate.embd_enc = cur;
1594
+ }
1550
1595
 
1551
- static int iter = 0;
1596
+ wsp_ggml_build_forward_expand(gf, cur);
1552
1597
 
1553
- const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
1554
- const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1598
+ wsp_ggml_free(ctx0);
1555
1599
 
1556
- struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1600
+ return gf;
1601
+ }
1557
1602
 
1558
- cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_transpose(ctx0, cur));
1603
+ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1604
+ whisper_context & wctx,
1605
+ whisper_state & wstate) {
1606
+ const auto & model = wctx.model;
1607
+ const auto & hparams = model.hparams;
1559
1608
 
1560
- // ===================================================================
1609
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1610
+ const int n_state = hparams.n_audio_state;
1611
+ const int n_head = hparams.n_audio_head;
1612
+ const int n_layer = hparams.n_audio_layer;
1561
1613
 
1562
- // original:
1563
- //cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
1614
+ struct wsp_ggml_init_params params = {
1615
+ /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1616
+ /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
1617
+ /*.no_alloc =*/ true,
1618
+ };
1564
1619
 
1565
- struct wsp_ggml_tensor * inpL = cur;
1620
+ struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1566
1621
 
1567
- for (int il = 0; il < n_layer; ++il) {
1568
- const auto & layer = model.layers_encoder[il];
1622
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1569
1623
 
1570
- // norm
1571
- {
1572
- wstate.use_buf(ctx0, 0);
1624
+ wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
1573
1625
 
1574
- cur = wsp_ggml_norm(ctx0, inpL);
1626
+ struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1627
+ wsp_ggml_allocr_alloc(alloc, KQscale);
1575
1628
 
1576
- // cur = ln_0_w*cur + ln_0_b
1577
- cur = wsp_ggml_add(ctx0,
1578
- wsp_ggml_mul(ctx0,
1579
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1580
- cur),
1581
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1582
- }
1629
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1630
+ wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
1631
+ }
1583
1632
 
1584
- // self-attention
1585
- {
1586
- wstate.use_buf(ctx0, 1);
1633
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1587
1634
 
1588
- struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
1589
- layer.attn_q_w,
1590
- cur);
1635
+ // ===================================================================
1636
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
1637
+ //static int iter = -1;
1638
+ //const int n_iter = 1500/n_ctx;
1591
1639
 
1592
- Qcur = wsp_ggml_add(ctx0,
1593
- wsp_ggml_repeat(ctx0,
1594
- layer.attn_q_b,
1595
- Qcur),
1596
- Qcur);
1640
+ //iter = (iter + 1) % n_iter;
1597
1641
 
1598
- //Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1642
+ //if (iter == 0) {
1643
+ // memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
1644
+ // memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
1645
+ //}
1599
1646
 
1600
- // note: no bias for Key
1601
- struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
1602
- layer.attn_k_w,
1603
- cur);
1647
+ static int iter = 0;
1604
1648
 
1605
- //Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1649
+ const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
1650
+ const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1606
1651
 
1607
- struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
1608
- layer.attn_v_w,
1609
- cur);
1652
+ struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1610
1653
 
1611
- Vcur = wsp_ggml_add(ctx0,
1612
- wsp_ggml_repeat(ctx0,
1613
- layer.attn_v_b,
1614
- Vcur),
1615
- Vcur);
1616
-
1617
- // ------
1654
+ cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
1618
1655
 
1619
- wstate.use_buf(ctx0, 0);
1656
+ // ===================================================================
1620
1657
 
1621
- #ifdef WHISPER_USE_FLASH_ATTN
1622
- struct wsp_ggml_tensor * Q =
1623
- wsp_ggml_permute(ctx0,
1624
- wsp_ggml_cpy(ctx0,
1625
- Qcur,
1626
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1627
- 0, 2, 1, 3);
1628
-
1629
- struct wsp_ggml_tensor * K =
1630
- wsp_ggml_permute(ctx0,
1631
- wsp_ggml_cpy(ctx0,
1632
- Kcur,
1633
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1634
- 0, 2, 1, 3);
1635
-
1636
- struct wsp_ggml_tensor * V =
1637
- wsp_ggml_cpy(ctx0,
1638
- wsp_ggml_permute(ctx0,
1639
- wsp_ggml_reshape_3d(ctx0,
1640
- Vcur,
1641
- n_state/n_head, n_head, n_ctx),
1642
- 1, 2, 0, 3),
1643
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1644
-
1645
- struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
1646
- #else
1647
- struct wsp_ggml_tensor * Q =
1648
- wsp_ggml_permute(ctx0,
1649
- wsp_ggml_cpy(ctx0,
1650
- Qcur,
1651
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1652
- 0, 2, 1, 3);
1653
-
1654
- struct wsp_ggml_tensor * K =
1655
- wsp_ggml_permute(ctx0,
1656
- wsp_ggml_cpy(ctx0,
1657
- Kcur,
1658
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1659
- 0, 2, 1, 3);
1660
-
1661
- // K * Q
1662
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
1663
-
1664
- struct wsp_ggml_tensor * KQ_scaled =
1665
- wsp_ggml_scale_inplace(ctx0,
1666
- KQ,
1667
- wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1668
- );
1669
-
1670
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_scaled);
1671
-
1672
- struct wsp_ggml_tensor * V =
1673
- wsp_ggml_cpy(ctx0,
1674
- wsp_ggml_permute(ctx0,
1675
- wsp_ggml_reshape_3d(ctx0,
1676
- Vcur,
1677
- n_state/n_head, n_head, n_ctx),
1678
- 1, 2, 0, 3),
1679
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1680
- );
1681
-
1682
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1683
- #endif
1684
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1658
+ // original:
1659
+ //cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
1685
1660
 
1686
- wstate.use_buf(ctx0, 1);
1661
+ struct wsp_ggml_tensor * inpL = cur;
1687
1662
 
1688
- cur = wsp_ggml_cpy(ctx0,
1689
- KQV_merged,
1690
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
1691
- }
1663
+ for (int il = 0; il < n_layer; ++il) {
1664
+ const auto & layer = model.layers_encoder[il];
1692
1665
 
1693
- // projection
1694
- {
1695
- wstate.use_buf(ctx0, 0);
1666
+ // norm
1667
+ {
1668
+ cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
1696
1669
 
1697
- cur = wsp_ggml_mul_mat(ctx0,
1698
- layer.attn_ln_1_w,
1699
- cur);
1670
+ // cur = ln_0_w*cur + ln_0_b
1671
+ cur = wsp_ggml_add(ctx0,
1672
+ wsp_ggml_mul(ctx0, cur, layer.attn_ln_0_w),
1673
+ layer.attn_ln_0_b);
1674
+ }
1700
1675
 
1701
- wstate.use_buf(ctx0, 1);
1676
+ // self-attention
1677
+ {
1678
+ struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
1679
+ layer.attn_q_w,
1680
+ cur);
1702
1681
 
1703
- cur = wsp_ggml_add(ctx0,
1704
- wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1705
- cur);
1706
- }
1682
+ Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
1707
1683
 
1708
- wstate.use_buf(ctx0, 2);
1684
+ //Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1709
1685
 
1710
- // add the input
1711
- cur = wsp_ggml_add(ctx0, cur, inpL);
1686
+ // note: no bias for Key
1687
+ struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
1688
+ layer.attn_k_w,
1689
+ cur);
1712
1690
 
1713
- struct wsp_ggml_tensor * inpFF = cur;
1691
+ //Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1714
1692
 
1715
- // feed-forward network
1716
- {
1717
- // norm
1718
- {
1719
- wstate.use_buf(ctx0, 0);
1693
+ struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
1694
+ layer.attn_v_w,
1695
+ cur);
1720
1696
 
1721
- cur = wsp_ggml_norm(ctx0, inpFF);
1697
+ Vcur = wsp_ggml_add(ctx0, Vcur, layer.attn_v_b);
1722
1698
 
1723
- wstate.use_buf(ctx0, 1);
1699
+ // ------
1724
1700
 
1725
- // cur = mlp_ln_w*cur + mlp_ln_b
1726
- cur = wsp_ggml_add(ctx0,
1727
- wsp_ggml_mul(ctx0,
1728
- wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1729
- cur),
1730
- wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1731
- }
1701
+ #ifdef WHISPER_USE_FLASH_ATTN
1702
+ struct wsp_ggml_tensor * Q =
1703
+ wsp_ggml_permute(ctx0,
1704
+ wsp_ggml_cpy(ctx0,
1705
+ Qcur,
1706
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1707
+ 0, 2, 1, 3);
1732
1708
 
1733
- #ifdef WHISPER_USE_FLASH_FF
1734
- wstate.use_buf(ctx0, 0);
1709
+ struct wsp_ggml_tensor * K =
1710
+ wsp_ggml_permute(ctx0,
1711
+ wsp_ggml_cpy(ctx0,
1712
+ Kcur,
1713
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1714
+ 0, 2, 1, 3);
1735
1715
 
1736
- cur = wsp_ggml_flash_ff(ctx0,
1737
- wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1738
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1716
+ struct wsp_ggml_tensor * V =
1717
+ wsp_ggml_cpy(ctx0,
1718
+ wsp_ggml_permute(ctx0,
1719
+ wsp_ggml_reshape_3d(ctx0,
1720
+ Vcur,
1721
+ n_state/n_head, n_head, n_ctx),
1722
+ 1, 2, 0, 3),
1723
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1724
+
1725
+ struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
1739
1726
  #else
1740
- wstate.use_buf(ctx0, 0);
1727
+ struct wsp_ggml_tensor * Q =
1728
+ wsp_ggml_permute(ctx0,
1729
+ wsp_ggml_cpy(ctx0,
1730
+ Qcur,
1731
+ wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1732
+ 0, 2, 1, 3);
1741
1733
 
1742
- // fully connected
1743
- cur = wsp_ggml_mul_mat(ctx0,
1744
- layer.mlp_0_w,
1745
- cur);
1734
+ struct wsp_ggml_tensor * K =
1735
+ wsp_ggml_permute(ctx0,
1736
+ wsp_ggml_cpy(ctx0,
1737
+ Kcur,
1738
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1739
+ 0, 2, 1, 3);
1746
1740
 
1747
- wstate.use_buf(ctx0, 1);
1741
+ // K * Q
1742
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
1748
1743
 
1749
- cur = wsp_ggml_add(ctx0,
1750
- wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
1751
- cur);
1744
+ struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
1752
1745
 
1753
- wstate.use_buf(ctx0, 0);
1746
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
1754
1747
 
1755
- // GELU activation
1756
- cur = wsp_ggml_gelu(ctx0, cur);
1748
+ struct wsp_ggml_tensor * V =
1749
+ wsp_ggml_cpy(ctx0,
1750
+ wsp_ggml_permute(ctx0,
1751
+ wsp_ggml_reshape_3d(ctx0,
1752
+ Vcur,
1753
+ n_state/n_head, n_head, n_ctx),
1754
+ 1, 2, 0, 3),
1755
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1756
+ );
1757
1757
 
1758
- wstate.use_buf(ctx0, 1);
1758
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1759
+ #endif
1760
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1759
1761
 
1760
- // projection
1761
- cur = wsp_ggml_mul_mat(ctx0,
1762
- layer.mlp_1_w,
1763
- cur);
1762
+ cur = wsp_ggml_cpy(ctx0,
1763
+ KQV_merged,
1764
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
1765
+ }
1766
+
1767
+ // projection
1768
+ {
1769
+ cur = wsp_ggml_mul_mat(ctx0,
1770
+ layer.attn_ln_1_w,
1771
+ cur);
1772
+
1773
+ cur = wsp_ggml_add(ctx0, cur, layer.attn_ln_1_b);
1774
+ }
1775
+
1776
+ // add the input
1777
+ cur = wsp_ggml_add(ctx0, cur, inpL);
1764
1778
 
1765
- wstate.use_buf(ctx0, 0);
1779
+ struct wsp_ggml_tensor * inpFF = cur;
1780
+
1781
+ // feed-forward network
1782
+ {
1783
+ // norm
1784
+ {
1785
+ cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
1766
1786
 
1787
+ // cur = mlp_ln_w*cur + mlp_ln_b
1767
1788
  cur = wsp_ggml_add(ctx0,
1768
- wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur),
1769
- cur);
1770
- #endif
1789
+ wsp_ggml_mul(ctx0, cur, layer.mlp_ln_w),
1790
+ layer.mlp_ln_b);
1771
1791
  }
1772
1792
 
1773
- wstate.use_buf(ctx0, 3);
1793
+ #ifdef WHISPER_USE_FLASH_FF
1794
+ cur = wsp_ggml_flash_ff(ctx0,
1795
+ wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1796
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1797
+ #else
1798
+ // fully connected
1799
+ cur = wsp_ggml_mul_mat(ctx0,
1800
+ layer.mlp_0_w,
1801
+ cur);
1802
+
1803
+ cur = wsp_ggml_add(ctx0, cur, layer.mlp_0_b);
1804
+
1805
+ // GELU activation
1806
+ cur = wsp_ggml_gelu(ctx0, cur);
1807
+
1808
+ // projection
1809
+ cur = wsp_ggml_mul_mat(ctx0,
1810
+ layer.mlp_1_w,
1811
+ cur);
1774
1812
 
1775
- inpL = wsp_ggml_add(ctx0, cur, inpFF);
1813
+ cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
1814
+ #endif
1776
1815
  }
1777
1816
 
1778
- cur = inpL;
1817
+ inpL = wsp_ggml_add(ctx0, cur, inpFF);
1818
+ }
1779
1819
 
1780
- // norm
1781
- {
1782
- wstate.use_buf(ctx0, 0);
1820
+ cur = inpL;
1783
1821
 
1784
- cur = wsp_ggml_norm(ctx0, cur);
1822
+ // norm
1823
+ {
1824
+ cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
1785
1825
 
1786
- wstate.use_buf(ctx0, 1);
1826
+ // cur = ln_f_g*cur + ln_f_b
1827
+ cur = wsp_ggml_add(ctx0,
1828
+ wsp_ggml_mul(ctx0, cur, model.e_ln_w),
1829
+ model.e_ln_b);
1830
+ }
1787
1831
 
1788
- // cur = ln_f_g*cur + ln_f_b
1789
- cur = wsp_ggml_add(ctx0,
1790
- wsp_ggml_mul(ctx0,
1791
- wsp_ggml_repeat(ctx0, model.e_ln_w, cur),
1792
- cur),
1793
- wsp_ggml_repeat(ctx0, model.e_ln_b, cur));
1794
- }
1832
+ wsp_ggml_build_forward_expand(gf, cur);
1795
1833
 
1796
- wstate.use_buf(ctx0, -1);
1834
+ wstate.embd_enc = cur;
1797
1835
 
1798
- // run the computation
1799
- {
1800
- struct wsp_ggml_cgraph gf = {};
1801
- gf.n_threads = n_threads;
1836
+ //wsp_ggml_graph_print(gf);
1802
1837
 
1803
- wsp_ggml_build_forward_expand(&gf, cur);
1804
- wsp_ggml_graph_compute(ctx0, &gf);
1838
+ ////////////////////////////////////////////////////////////////////////////
1805
1839
 
1806
- //wsp_ggml_graph_print(&gf);
1807
- }
1808
- }
1809
- #ifdef WHISPER_USE_COREML
1810
- else if (use_coreml) {
1811
- wstate.use_buf(ctx0, -1);
1840
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1841
+ // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1842
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1843
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1844
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1845
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1812
1846
 
1813
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1847
+ wsp_ggml_free(ctx0);
1814
1848
 
1815
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1816
- }
1817
- #endif
1818
- #ifdef WHISPER_USE_OPENVINO
1819
- else if (use_openvino) {
1820
- wstate.use_buf(ctx0, -1);
1849
+ return gf;
1850
+ }
1821
1851
 
1822
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1852
+ // pre-compute cross-attention memory
1853
+ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
1854
+ whisper_context & wctx,
1855
+ whisper_state & wstate) {
1856
+ const auto & model = wctx.model;
1857
+ const auto & hparams = model.hparams;
1823
1858
 
1824
- if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
1825
- return false;
1826
- }
1827
- }
1828
- #endif
1859
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1860
+ const int n_state = hparams.n_audio_state;
1861
+ const int n_head = hparams.n_audio_head;
1829
1862
 
1830
- // cur
1831
- //{
1832
- // printf("ne0 = %d\n", cur->ne[0]);
1833
- // printf("ne1 = %d\n", cur->ne[1]);
1834
- // for (int i = 0; i < 10; ++i) {
1835
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1836
- // }
1837
- // printf("... ");
1838
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1839
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1840
- // }
1841
- // printf("\n");
1842
- //}
1863
+ struct wsp_ggml_init_params params = {
1864
+ /*.mem_size =*/ wstate.alloc_cross.meta.size(),
1865
+ /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
1866
+ /*.no_alloc =*/ true,
1867
+ };
1843
1868
 
1844
- // pre-compute cross-attention memory
1845
- {
1846
- struct wsp_ggml_cgraph gf = {};
1847
- gf.n_threads = n_threads;
1869
+ struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1870
+
1871
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1848
1872
 
1849
- // TODO: hack to disconnect the encoded features from the previous graph
1850
- cur->op = WSP_GGML_OP_NONE;
1851
- cur->src0 = nullptr;
1852
- cur->src1 = nullptr;
1873
+ wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
1853
1874
 
1854
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1855
- auto& layer = model.layers_decoder[il];
1875
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
1876
+
1877
+ struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1878
+ wsp_ggml_allocr_alloc(alloc, Kscale);
1879
+
1880
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1881
+ wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
1882
+ }
1856
1883
 
1857
- wstate.use_buf(ctx0, 0);
1884
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1885
+ auto & layer = model.layers_decoder[il];
1858
1886
 
1859
- struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
1887
+ struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
1860
1888
  layer.cross_attn_k_w,
1861
1889
  cur);
1862
1890
 
1863
- Kcross = wsp_ggml_scale_inplace(ctx0, Kcross, wsp_ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
1891
+ Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
1864
1892
 
1865
- wstate.use_buf(ctx0, 1);
1866
-
1867
- struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
1893
+ struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
1868
1894
  layer.cross_attn_v_w,
1869
1895
  cur);
1870
1896
 
1871
- Vcross = wsp_ggml_add(ctx0,
1872
- wsp_ggml_repeat(ctx0,
1873
- layer.cross_attn_v_b,
1874
- Vcross),
1875
- Vcross);
1897
+ Vcross = wsp_ggml_add(ctx0,
1898
+ Vcross,
1899
+ layer.cross_attn_v_b);
1900
+
1901
+ Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1902
+
1903
+ struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
1904
+ n_state*n_ctx,
1905
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1906
+
1907
+ struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1908
+ ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
1909
+ (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
1910
+
1911
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
1912
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
1913
+ }
1876
1914
 
1877
- wstate.use_buf(ctx0, -1);
1915
+ //wsp_ggml_graph_print(gf);
1878
1916
 
1879
- Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1917
+ wsp_ggml_free(ctx0);
1880
1918
 
1881
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1882
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1883
- ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
1884
- (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
1919
+ return gf;
1920
+ }
1885
1921
 
1886
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcross, k));
1887
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcross, v));
1922
+ // evaluate the encoder with the given state
1923
+ //
1924
+ // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1925
+ // part of the transformer model and returns the encoded features
1926
+ //
1927
+ // - wctx: the model
1928
+ // - wstate: the state of the encoder
1929
+ // - n_threads: number of threads to use
1930
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1931
+ //
1932
+ static bool whisper_encode_internal(
1933
+ whisper_context & wctx,
1934
+ whisper_state & wstate,
1935
+ const int mel_offset,
1936
+ const int n_threads,
1937
+ whisper_abort_callback abort_callback,
1938
+ void * abort_callback_data) {
1939
+ const int64_t t_start_us = wsp_ggml_time_us();
1940
+
1941
+ // conv
1942
+ {
1943
+ auto & alloc = wstate.alloc_conv.alloc;
1944
+
1945
+ wsp_ggml_allocr_reset(alloc);
1946
+
1947
+ wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
1948
+
1949
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1950
+
1951
+ if (!whisper_encode_external(wstate)) {
1952
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1888
1953
  }
1954
+ }
1955
+
1956
+ // encoder
1957
+ if (!whisper_encode_external(wstate)) {
1958
+ auto & alloc = wstate.alloc_encode.alloc;
1959
+
1960
+ wsp_ggml_allocr_reset(alloc);
1889
1961
 
1890
- wsp_ggml_graph_compute(ctx0, &gf);
1891
- //wsp_ggml_graph_print(&gf);
1962
+ wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
1963
+
1964
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1965
+
1966
+ #ifdef WSP_GGML_USE_METAL
1967
+ if (wstate.ctx_metal) {
1968
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1969
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1970
+ } else {
1971
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1972
+ }
1973
+ #else
1974
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1975
+ #endif
1892
1976
  }
1893
1977
 
1894
- ////////////////////////////////////////////////////////////////////////////
1978
+ // cross
1979
+ {
1980
+ auto & alloc = wstate.alloc_cross.alloc;
1895
1981
 
1896
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1897
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1898
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1899
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1900
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1901
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1982
+ wsp_ggml_allocr_reset(alloc);
1902
1983
 
1903
- wsp_ggml_free(ctx0);
1984
+ wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
1985
+
1986
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1987
+
1988
+ #ifdef WSP_GGML_USE_METAL
1989
+ if (wstate.ctx_metal) {
1990
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1991
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1992
+ } else {
1993
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1994
+ }
1995
+ #else
1996
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1997
+ #endif
1998
+ }
1999
+
2000
+ // wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
1904
2001
 
1905
2002
  wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
1906
2003
  wstate.n_encode++;
@@ -1908,26 +2005,13 @@ static bool whisper_encode_internal(
1908
2005
  return true;
1909
2006
  }
1910
2007
 
1911
- // evaluate the decoder
1912
- //
1913
- // given text prompt + audio features -> computes the logits for the next token
1914
- //
1915
- // - model: the model
1916
- // - n_threads: number of threads to use
1917
- // - tokens: text prompt
1918
- // - n_tokens: number of tokens in the prompt
1919
- // - n_past: number of past tokens to prefix the prompt with
1920
- //
1921
- static bool whisper_decode_internal(
1922
- whisper_context & wctx,
1923
- whisper_state & wstate,
1924
- whisper_decoder & decoder,
1925
- const whisper_token * tokens,
1926
- const int n_tokens,
1927
- const int n_past,
1928
- const int n_threads) {
1929
- const int64_t t_start_us = wsp_ggml_time_us();
1930
-
2008
+ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2009
+ whisper_context & wctx,
2010
+ whisper_state & wstate,
2011
+ whisper_decoder & decoder,
2012
+ const whisper_token * tokens,
2013
+ int n_tokens,
2014
+ int n_past) {
1931
2015
  const auto & model = wctx.model;
1932
2016
  const auto & hparams = model.hparams;
1933
2017
 
@@ -1935,10 +2019,6 @@ static bool whisper_decode_internal(
1935
2019
 
1936
2020
  WHISPER_ASSERT(!!kv_self.ctx);
1937
2021
 
1938
- auto & logits_out = wstate.logits;
1939
-
1940
- const int n_vocab = hparams.n_vocab;
1941
-
1942
2022
  const int n_ctx = hparams.n_text_ctx;
1943
2023
  const int n_state = hparams.n_text_state;
1944
2024
  const int n_head = hparams.n_text_head;
@@ -1950,25 +2030,39 @@ static bool whisper_decode_internal(
1950
2030
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1951
2031
 
1952
2032
  struct wsp_ggml_init_params params = {
1953
- /*.mem_size =*/ wstate.buf_compute.size(),
1954
- /*.mem_buffer =*/ wstate.buf_compute.data(),
1955
- /*.no_alloc =*/ false,
2033
+ /*.mem_size =*/ wstate.alloc_decode.meta.size(),
2034
+ /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
2035
+ /*.no_alloc =*/ true,
1956
2036
  };
1957
2037
 
1958
2038
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1959
2039
 
1960
- struct wsp_ggml_cgraph gf = {};
1961
- gf.n_threads = n_threads;
2040
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
2041
+
2042
+ wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
1962
2043
 
1963
2044
  struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
1964
- memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2045
+ wsp_ggml_allocr_alloc(alloc, embd);
2046
+
2047
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2048
+ memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2049
+ }
1965
2050
 
1966
2051
  struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
1967
- for (int i = 0; i < N; ++i) {
1968
- ((int32_t *) position->data)[i] = n_past + i;
2052
+ wsp_ggml_allocr_alloc(alloc, position);
2053
+
2054
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2055
+ for (int i = 0; i < N; ++i) {
2056
+ ((int32_t *) position->data)[i] = n_past + i;
2057
+ }
1969
2058
  }
1970
2059
 
1971
- wstate.use_buf(ctx0, 3);
2060
+ struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
2061
+ wsp_ggml_allocr_alloc(alloc, KQscale);
2062
+
2063
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2064
+ wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
2065
+ }
1972
2066
 
1973
2067
  // token encoding + position encoding
1974
2068
  struct wsp_ggml_tensor * cur =
@@ -1983,16 +2077,14 @@ static bool whisper_decode_internal(
1983
2077
 
1984
2078
  // norm
1985
2079
  {
1986
- wstate.use_buf(ctx0, 0);
1987
-
1988
- cur = wsp_ggml_norm(ctx0, inpL);
2080
+ cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
1989
2081
 
1990
2082
  // cur = ln_0_w*cur + ln_0_b
1991
2083
  cur = wsp_ggml_add(ctx0,
1992
2084
  wsp_ggml_mul(ctx0,
1993
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1994
- cur),
1995
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
2085
+ cur,
2086
+ layer.attn_ln_0_w),
2087
+ layer.attn_ln_0_b);
1996
2088
  }
1997
2089
 
1998
2090
  // self-attention
@@ -2002,19 +2094,17 @@ static bool whisper_decode_internal(
2002
2094
  cur);
2003
2095
 
2004
2096
  Qcur = wsp_ggml_add(ctx0,
2005
- wsp_ggml_repeat(ctx0,
2006
- layer.attn_q_b,
2007
- Qcur),
2008
- Qcur);
2097
+ Qcur,
2098
+ layer.attn_q_b);
2009
2099
 
2010
- Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2100
+ Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2011
2101
 
2012
2102
  // note: no bias for Key
2013
2103
  struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
2014
2104
  layer.attn_k_w,
2015
2105
  cur);
2016
2106
 
2017
- Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2107
+ Kcur = wsp_ggml_scale(ctx0, Kcur, KQscale);
2018
2108
 
2019
2109
  // store key and value to memory
2020
2110
  {
@@ -2023,10 +2113,8 @@ static bool whisper_decode_internal(
2023
2113
  cur);
2024
2114
 
2025
2115
  Vcur = wsp_ggml_add(ctx0,
2026
- wsp_ggml_repeat(ctx0,
2027
- layer.attn_v_b,
2028
- Vcur),
2029
- Vcur);
2116
+ Vcur,
2117
+ layer.attn_v_b);
2030
2118
 
2031
2119
  Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
2032
2120
 
@@ -2035,42 +2123,32 @@ static bool whisper_decode_internal(
2035
2123
  ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2036
2124
  (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
2037
2125
 
2038
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcur, k));
2039
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcur, v));
2126
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2127
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
2040
2128
  }
2041
2129
 
2042
2130
  // ------
2043
2131
 
2044
- wstate.use_buf(ctx0, 0);
2045
-
2046
2132
  struct wsp_ggml_tensor * Q =
2047
2133
  wsp_ggml_permute(ctx0,
2048
- wsp_ggml_cpy(ctx0,
2049
- Qcur,
2050
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
2134
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2051
2135
  0, 2, 1, 3);
2052
2136
 
2053
2137
  struct wsp_ggml_tensor * K =
2054
- wsp_ggml_permute(ctx0,
2055
- wsp_ggml_reshape_3d(ctx0,
2056
- wsp_ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*wsp_ggml_element_size(kv_self.k)*n_state),
2057
- n_state/n_head, n_head, n_past + N),
2058
- 0, 2, 1, 3);
2059
-
2060
- wstate.use_buf(ctx0, 1);
2138
+ wsp_ggml_view_3d(ctx0, kv_self.k,
2139
+ n_state/n_head, n_past + N, n_head,
2140
+ wsp_ggml_element_size(kv_self.k)*n_state,
2141
+ wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2142
+ wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2061
2143
 
2062
2144
  // K * Q
2063
2145
  struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2064
2146
 
2065
- //struct wsp_ggml_tensor * KQ_scaled =
2066
- // wsp_ggml_scale_inplace(ctx0,
2067
- // KQ,
2068
- // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2069
- // );
2147
+ //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2070
2148
 
2071
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ, n_past);
2149
+ struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2072
2150
 
2073
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_masked);
2151
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2074
2152
 
2075
2153
  struct wsp_ggml_tensor * V =
2076
2154
  wsp_ggml_view_3d(ctx0, kv_self.v,
@@ -2090,36 +2168,28 @@ static bool whisper_decode_internal(
2090
2168
 
2091
2169
  // projection
2092
2170
  {
2093
- wstate.use_buf(ctx0, 0);
2094
-
2095
2171
  cur = wsp_ggml_mul_mat(ctx0,
2096
2172
  layer.attn_ln_1_w,
2097
2173
  cur);
2098
2174
 
2099
- wstate.use_buf(ctx0, 1);
2100
-
2101
2175
  cur = wsp_ggml_add(ctx0,
2102
- wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
2103
- cur);
2176
+ cur,
2177
+ layer.attn_ln_1_b);
2104
2178
  }
2105
2179
 
2106
- wstate.use_buf(ctx0, 2);
2107
-
2108
2180
  // add the input
2109
2181
  struct wsp_ggml_tensor * inpCA = wsp_ggml_add(ctx0, cur, inpL);
2110
2182
 
2111
2183
  // norm
2112
2184
  {
2113
- wstate.use_buf(ctx0, 0);
2114
-
2115
- cur = wsp_ggml_norm(ctx0, inpCA); // note: we use inpCA here
2185
+ cur = wsp_ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
2116
2186
 
2117
2187
  // cur = ln_0_w*cur + ln_0_b
2118
2188
  cur = wsp_ggml_add(ctx0,
2119
2189
  wsp_ggml_mul(ctx0,
2120
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
2121
- cur),
2122
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
2190
+ cur,
2191
+ layer.cross_attn_ln_0_w),
2192
+ layer.cross_attn_ln_0_b);
2123
2193
  }
2124
2194
 
2125
2195
  // cross-attention
@@ -2129,18 +2199,18 @@ static bool whisper_decode_internal(
2129
2199
  cur);
2130
2200
 
2131
2201
  Qcur = wsp_ggml_add(ctx0,
2132
- wsp_ggml_repeat(ctx0,
2133
- layer.cross_attn_q_b,
2134
- Qcur),
2135
- Qcur);
2202
+ Qcur,
2203
+ layer.cross_attn_q_b);
2136
2204
 
2137
- Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2205
+ Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2138
2206
 
2139
2207
  // Kcross is already scaled
2140
2208
  struct wsp_ggml_tensor * Kcross =
2141
- wsp_ggml_reshape_3d(ctx0,
2142
- wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.k)*n_state),
2143
- n_state/n_head, n_head, M);
2209
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2210
+ n_state/n_head, M, n_head,
2211
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2212
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2213
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2144
2214
 
2145
2215
  //struct wsp_ggml_tensor * Vcross =
2146
2216
  // wsp_ggml_reshape_3d(ctx0,
@@ -2163,26 +2233,22 @@ static bool whisper_decode_internal(
2163
2233
 
2164
2234
  struct wsp_ggml_tensor * Q =
2165
2235
  wsp_ggml_permute(ctx0,
2166
- wsp_ggml_cpy(ctx0,
2167
- Qcur,
2168
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
2236
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2169
2237
  0, 2, 1, 3);
2170
2238
 
2171
- struct wsp_ggml_tensor * K = wsp_ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2172
-
2173
2239
  // K * Q
2174
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2240
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2175
2241
 
2176
2242
  //struct wsp_ggml_tensor * KQ_scaled =
2177
- // wsp_ggml_scale_inplace(ctx0,
2243
+ // wsp_ggml_scale(ctx0,
2178
2244
  // KQ,
2179
2245
  // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2180
2246
  // );
2181
2247
 
2182
2248
  // no masking for cross-attention
2183
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2249
+ //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2184
2250
 
2185
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ);
2251
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
2186
2252
 
2187
2253
  struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2188
2254
 
@@ -2196,21 +2262,15 @@ static bool whisper_decode_internal(
2196
2262
 
2197
2263
  // projection
2198
2264
  {
2199
- wstate.use_buf(ctx0, 0);
2200
-
2201
2265
  cur = wsp_ggml_mul_mat(ctx0,
2202
2266
  layer.cross_attn_ln_1_w,
2203
2267
  cur);
2204
2268
 
2205
- wstate.use_buf(ctx0, 1);
2206
-
2207
2269
  cur = wsp_ggml_add(ctx0,
2208
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2209
- cur);
2270
+ cur,
2271
+ layer.cross_attn_ln_1_b);
2210
2272
  }
2211
2273
 
2212
- wstate.use_buf(ctx0, 2);
2213
-
2214
2274
  // add the input
2215
2275
  cur = wsp_ggml_add(ctx0, cur, inpCA);
2216
2276
 
@@ -2220,54 +2280,38 @@ static bool whisper_decode_internal(
2220
2280
  {
2221
2281
  // norm
2222
2282
  {
2223
- wstate.use_buf(ctx0, 0);
2224
-
2225
- cur = wsp_ggml_norm(ctx0, inpFF);
2226
-
2227
- wstate.use_buf(ctx0, 1);
2283
+ cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
2228
2284
 
2229
2285
  // cur = mlp_ln_w*cur + mlp_ln_b
2230
2286
  cur = wsp_ggml_add(ctx0,
2231
2287
  wsp_ggml_mul(ctx0,
2232
- wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur),
2233
- cur),
2234
- wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2288
+ cur,
2289
+ layer.mlp_ln_w),
2290
+ layer.mlp_ln_b);
2235
2291
  }
2236
2292
 
2237
- wstate.use_buf(ctx0, 0);
2238
-
2239
2293
  // fully connected
2240
2294
  cur = wsp_ggml_mul_mat(ctx0,
2241
2295
  layer.mlp_0_w,
2242
2296
  cur);
2243
2297
 
2244
- wstate.use_buf(ctx0, 1);
2245
-
2246
2298
  cur = wsp_ggml_add(ctx0,
2247
- wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
2248
- cur);
2249
-
2250
- wstate.use_buf(ctx0, 0);
2299
+ cur,
2300
+ layer.mlp_0_b);
2251
2301
 
2252
2302
  // GELU activation
2253
2303
  cur = wsp_ggml_gelu(ctx0, cur);
2254
2304
 
2255
- wstate.use_buf(ctx0, 1);
2256
-
2257
2305
  // projection
2258
2306
  cur = wsp_ggml_mul_mat(ctx0,
2259
2307
  layer.mlp_1_w,
2260
2308
  cur);
2261
2309
 
2262
- wstate.use_buf(ctx0, 0);
2263
-
2264
2310
  cur = wsp_ggml_add(ctx0,
2265
- wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur),
2266
- cur);
2311
+ cur,
2312
+ layer.mlp_1_b);
2267
2313
  }
2268
2314
 
2269
- wstate.use_buf(ctx0, 3);
2270
-
2271
2315
  inpL = wsp_ggml_add(ctx0, cur, inpFF);
2272
2316
  }
2273
2317
 
@@ -2275,21 +2319,15 @@ static bool whisper_decode_internal(
2275
2319
 
2276
2320
  // norm
2277
2321
  {
2278
- wstate.use_buf(ctx0, 0);
2279
-
2280
- cur = wsp_ggml_norm(ctx0, cur);
2281
-
2282
- wstate.use_buf(ctx0, 1);
2322
+ cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
2283
2323
 
2284
2324
  cur = wsp_ggml_add(ctx0,
2285
2325
  wsp_ggml_mul(ctx0,
2286
- wsp_ggml_repeat(ctx0, model.d_ln_w, cur),
2287
- cur),
2288
- wsp_ggml_repeat(ctx0, model.d_ln_b, cur));
2326
+ cur,
2327
+ model.d_ln_w),
2328
+ model.d_ln_b);
2289
2329
  }
2290
2330
 
2291
- wstate.use_buf(ctx0, 0);
2292
-
2293
2331
  // compute logits only for the last token
2294
2332
  // comment this line to compute logits for all N tokens
2295
2333
  // might be useful in the future
@@ -2297,23 +2335,77 @@ static bool whisper_decode_internal(
2297
2335
 
2298
2336
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2299
2337
 
2300
- wstate.use_buf(ctx0, -1);
2338
+ wsp_ggml_build_forward_expand(gf, logits);
2339
+
2340
+ wsp_ggml_free(ctx0);
2341
+
2342
+ return gf;
2343
+ }
2344
+
2345
+ // evaluate the decoder
2346
+ //
2347
+ // given text prompt + audio features -> computes the logits for the next token
2348
+ //
2349
+ // - model: the model
2350
+ // - n_threads: number of threads to use
2351
+ // - tokens: text prompt
2352
+ // - n_tokens: number of tokens in the prompt
2353
+ // - n_past: number of past tokens to prefix the prompt with
2354
+ //
2355
+ static bool whisper_decode_internal(
2356
+ whisper_context & wctx,
2357
+ whisper_state & wstate,
2358
+ whisper_decoder & decoder,
2359
+ const whisper_token * tokens,
2360
+ const int n_tokens,
2361
+ const int n_past,
2362
+ const int n_threads,
2363
+ whisper_abort_callback abort_callback,
2364
+ void * abort_callback_data) {
2365
+ const int64_t t_start_us = wsp_ggml_time_us();
2366
+
2367
+ const auto & model = wctx.model;
2368
+ const auto & hparams = model.hparams;
2369
+
2370
+ const int n_vocab = hparams.n_vocab;
2371
+
2372
+ auto & logits_out = wstate.logits;
2373
+
2374
+ struct wsp_ggml_tensor * logits;
2301
2375
 
2302
- // run the computation
2376
+ // decoder
2303
2377
  {
2304
- wsp_ggml_build_forward_expand(&gf, logits);
2305
- wsp_ggml_graph_compute (ctx0, &gf);
2378
+ auto & alloc = wstate.alloc_decode.alloc;
2379
+
2380
+ wsp_ggml_allocr_reset(alloc);
2381
+
2382
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2383
+
2384
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
2385
+
2386
+ logits = gf->nodes[gf->n_nodes - 1];
2387
+
2388
+ #ifdef WSP_GGML_USE_METAL
2389
+ if (wstate.ctx_metal) {
2390
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2391
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
2392
+ } else {
2393
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2394
+ }
2395
+ #else
2396
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2397
+ #endif
2306
2398
  }
2307
2399
 
2308
2400
  // extract logits for all N tokens
2309
- //logits_out.resize(N*n_vocab);
2310
- //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*N*n_vocab);
2401
+ //logits_out.resize(n_tokens*n_vocab);
2402
+ //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2311
2403
 
2312
2404
  // extract logits only for the last token
2313
2405
  logits_out.resize(n_vocab);
2314
2406
  memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
2315
2407
 
2316
- if (N > 1) {
2408
+ if (n_tokens > 1) {
2317
2409
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2318
2410
  // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
2319
2411
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2322,14 +2414,18 @@ static bool whisper_decode_internal(
2322
2414
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2323
2415
  }
2324
2416
 
2325
- wsp_ggml_free(ctx0);
2326
-
2327
- wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2328
- wstate.n_decode++;
2417
+ if (n_tokens == 1) {
2418
+ wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2419
+ wstate.n_decode++;
2420
+ } else {
2421
+ wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
2422
+ wstate.n_prompt++;
2423
+ }
2329
2424
 
2330
2425
  return true;
2331
2426
  }
2332
2427
 
2428
+
2333
2429
  // 500 -> 00:05.000
2334
2430
  // 6000 -> 01:00.000
2335
2431
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2351,7 +2447,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2351
2447
  static float sin_vals[SIN_COS_N_COUNT];
2352
2448
  static float cos_vals[SIN_COS_N_COUNT];
2353
2449
 
2354
- // In FFT, we frequently use sine and cosine operations with the same values.
2450
+ // In FFT, we frequently use sine and cosine operations with the same values.
2355
2451
  // We can use precalculated values to speed up the process.
2356
2452
  static void fill_sin_cos_table() {
2357
2453
  static bool is_filled = false;
@@ -2446,7 +2542,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2446
2542
  }
2447
2543
 
2448
2544
  static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2449
- if (output.size() < length) {
2545
+ if (output.size() < static_cast<size_t>(length)) {
2450
2546
  output.resize(length);
2451
2547
  }
2452
2548
  int offset = -1;
@@ -2738,9 +2834,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2738
2834
  fill_sin_cos_table();
2739
2835
  whisper_state * state = new whisper_state;
2740
2836
 
2741
- const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
2742
-
2743
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2837
+ if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2744
2838
  log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2745
2839
  delete state;
2746
2840
  return nullptr;
@@ -2751,7 +2845,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2751
2845
  log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2752
2846
  }
2753
2847
 
2754
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2848
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2755
2849
  log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2756
2850
  delete state;
2757
2851
  return nullptr;
@@ -2763,6 +2857,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2763
2857
  }
2764
2858
 
2765
2859
  #ifdef WHISPER_USE_COREML
2860
+ if (ctx->load_coreml) { // Not in correct layer for easy patch
2766
2861
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2767
2862
 
2768
2863
  log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
@@ -2772,11 +2867,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2772
2867
  if (!state->ctx_coreml) {
2773
2868
  log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2774
2869
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2870
+ delete state;
2775
2871
  return nullptr;
2776
2872
  #endif
2777
2873
  } else {
2778
2874
  log("%s: Core ML model loaded\n", __func__);
2779
2875
  }
2876
+ }
2780
2877
  #endif
2781
2878
 
2782
2879
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2786,21 +2883,134 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2786
2883
  // TAGS: WHISPER_DECODER_INIT
2787
2884
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2788
2885
 
2789
- state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
2790
- state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
2886
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2887
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2791
2888
  state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
2792
- state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
2793
2889
 
2794
- state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
2795
- state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
2796
- state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
2797
- state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
2890
+ // conv allocator
2891
+ {
2892
+ whisper_allocr_graph_init(state->alloc_conv,
2893
+ [&]() {
2894
+ return whisper_build_graph_conv(*ctx, *state, 0);
2895
+ });
2896
+
2897
+ log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
2898
+ }
2899
+
2900
+ // encoder allocator
2901
+ if (!whisper_encode_external(*state)) {
2902
+ whisper_allocr_graph_init(state->alloc_encode,
2903
+ [&]() {
2904
+ return whisper_build_graph_encoder(*ctx, *state);
2905
+ });
2906
+
2907
+ log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
2908
+ }
2909
+
2910
+ // cross allocator
2911
+ {
2912
+ whisper_allocr_graph_init(state->alloc_cross,
2913
+ [&]() {
2914
+ return whisper_build_graph_cross(*ctx, *state);
2915
+ });
2916
+
2917
+ log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
2918
+ }
2919
+
2920
+ // decoder allocator
2921
+ {
2922
+ whisper_allocr_graph_init(state->alloc_decode,
2923
+ [&]() {
2924
+ const auto & hparams = ctx->model.hparams;
2925
+
2926
+ // TODO: make sure this is the worst-case scenario
2927
+ const int n_tokens = hparams.n_text_ctx;
2928
+ const int n_past = 0;
2929
+
2930
+ return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2931
+ });
2932
+
2933
+ log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2934
+ }
2935
+
2936
+ #ifdef WSP_GGML_USE_METAL
2937
+ state->ctx_metal = wsp_ggml_metal_init(1);
2938
+ if (!state->ctx_metal) {
2939
+ log("%s: wsp_ggml_metal_init() failed\n", __func__);
2940
+ delete state;
2941
+ return nullptr;
2942
+ }
2943
+
2944
+ log("%s: Metal context initialized\n", __func__);
2945
+
2946
+ // this allocates all Metal resources and memory buffers
2947
+
2948
+ void * data_ptr = NULL;
2949
+ size_t data_size = 0;
2950
+
2951
+ // TODO: add mmap support
2952
+ //if (params.use_mmap) {
2953
+ // data_ptr = ctx->model.mapping->addr;
2954
+ // data_size = ctx->model.mapping->size;
2955
+ //} else {
2956
+ // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2957
+ // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2958
+ //}
2959
+
2960
+ data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2961
+ data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2962
+
2963
+ const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
2964
+
2965
+ log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2966
+
2967
+ #define WHISPER_METAL_CHECK_BUF(result) \
2968
+ if (!(result)) { \
2969
+ log("%s: failed to add metal buffer\n", __func__); \
2970
+ delete state; \
2971
+ return nullptr; \
2972
+ }
2973
+
2974
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
2975
+
2976
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2977
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2978
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2979
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
2980
+
2981
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2982
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2983
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2984
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
2985
+
2986
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
2987
+
2988
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
2989
+ #undef WHISPER_METAL_CHECK_BUF
2990
+ #endif
2798
2991
 
2799
2992
  state->rng = std::mt19937(0);
2800
2993
 
2801
2994
  return state;
2802
2995
  }
2803
2996
 
2997
+ #ifdef WHISPER_USE_COREML
2998
+ struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
2999
+ whisper_context * ctx = whisper_init_from_file_no_state(path_model);
3000
+ if (!ctx) {
3001
+ return nullptr;
3002
+ }
3003
+ ctx->load_coreml = false;
3004
+ ctx->state = whisper_init_state(ctx);
3005
+ if (!ctx->state) {
3006
+ whisper_free(ctx);
3007
+ return nullptr;
3008
+ }
3009
+
3010
+ return ctx;
3011
+ }
3012
+ #endif
3013
+
2804
3014
  int whisper_ctx_init_openvino_encoder(
2805
3015
  struct whisper_context * ctx,
2806
3016
  const char * model_path,
@@ -2851,7 +3061,6 @@ int whisper_ctx_init_openvino_encoder(
2851
3061
  }
2852
3062
 
2853
3063
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2854
-
2855
3064
  log("%s: loading model from '%s'\n", __func__, path_model);
2856
3065
 
2857
3066
  auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3004,6 +3213,13 @@ void whisper_free_state(struct whisper_state * state)
3004
3213
  }
3005
3214
  #endif
3006
3215
 
3216
+ #ifdef WSP_GGML_USE_METAL
3217
+ if (state->ctx_metal) {
3218
+ wsp_ggml_metal_free(state->ctx_metal);
3219
+ state->ctx_metal = nullptr;
3220
+ }
3221
+ #endif
3222
+
3007
3223
  #ifdef WHISPER_USE_OPENVINO
3008
3224
  if (state->ctx_openvino != nullptr) {
3009
3225
  whisper_openvino_free(state->ctx_openvino);
@@ -3011,6 +3227,11 @@ void whisper_free_state(struct whisper_state * state)
3011
3227
  }
3012
3228
  #endif
3013
3229
 
3230
+ whisper_allocr_free(state->alloc_conv);
3231
+ whisper_allocr_free(state->alloc_decode);
3232
+ whisper_allocr_free(state->alloc_cross);
3233
+ whisper_allocr_free(state->alloc_encode);
3234
+
3014
3235
  delete state;
3015
3236
  }
3016
3237
  }
@@ -3103,7 +3324,7 @@ int whisper_set_mel(
3103
3324
  }
3104
3325
 
3105
3326
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3106
- if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
3327
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3107
3328
  log("%s: failed to eval\n", __func__);
3108
3329
  return -1;
3109
3330
  }
@@ -3112,7 +3333,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3112
3333
  }
3113
3334
 
3114
3335
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3115
- if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
3336
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3116
3337
  log("%s: failed to eval\n", __func__);
3117
3338
  return -1;
3118
3339
  }
@@ -3123,7 +3344,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3123
3344
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3124
3345
  const int selected_decoder_id = 0;
3125
3346
 
3126
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3347
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3127
3348
  log("%s: failed to eval\n", __func__);
3128
3349
  return 1;
3129
3350
  }
@@ -3140,7 +3361,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
3140
3361
  return false;
3141
3362
  }
3142
3363
 
3143
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3364
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3144
3365
  log("%s: failed to eval\n", __func__);
3145
3366
  return 1;
3146
3367
  }
@@ -3431,12 +3652,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
3431
3652
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3432
3653
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3433
3654
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3655
+ const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3434
3656
 
3435
3657
  log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3436
3658
  log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3437
3659
  log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3438
3660
  log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3439
3661
  log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3662
+ log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3440
3663
  }
3441
3664
  log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3442
3665
  }
@@ -3446,6 +3669,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3446
3669
  ctx->state->t_sample_us = 0;
3447
3670
  ctx->state->t_encode_us = 0;
3448
3671
  ctx->state->t_decode_us = 0;
3672
+ ctx->state->t_prompt_us = 0;
3673
+ ctx->state->n_sample = 0;
3674
+ ctx->state->n_encode = 0;
3675
+ ctx->state->n_decode = 0;
3676
+ ctx->state->n_prompt = 0;
3449
3677
  }
3450
3678
  }
3451
3679
 
@@ -3475,6 +3703,7 @@ const char * whisper_print_system_info(void) {
3475
3703
  s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
3476
3704
  s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
3477
3705
  s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
3706
+ s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
3478
3707
  s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
3479
3708
  s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
3480
3709
  s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
@@ -3566,6 +3795,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3566
3795
  /*.encoder_begin_callback =*/ nullptr,
3567
3796
  /*.encoder_begin_callback_user_data =*/ nullptr,
3568
3797
 
3798
+ /*.abort_callback =*/ nullptr,
3799
+ /*.abort_callback_user_data =*/ nullptr,
3800
+
3569
3801
  /*.logits_filter_callback =*/ nullptr,
3570
3802
  /*.logits_filter_callback_user_data =*/ nullptr,
3571
3803
  };
@@ -3970,17 +4202,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
3970
4202
 
3971
4203
  auto & logits_id = state.logits_id;
3972
4204
 
3973
- logits_id.clear();
4205
+ logits_id.resize(n_logits);
3974
4206
  for (int i = 0; i < n_logits; ++i) {
3975
- logits_id.push_back({ logits[i], i });
4207
+ logits_id[i].first = logits[i];
4208
+ logits_id[i].second = i;
3976
4209
  }
3977
4210
 
3978
- std::partial_sort(
3979
- logits_id.begin(),
3980
- logits_id.begin() + k, logits_id.end(),
3981
- [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
3982
- return a.first > b.first;
3983
- });
4211
+ {
4212
+ using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
4213
+ std::partial_sort(
4214
+ logits_id.begin(),
4215
+ logits_id.begin() + k, logits_id.end(),
4216
+ [](const pair_type & a, const pair_type & b) {
4217
+ return a.first > b.first;
4218
+ });
4219
+ }
3984
4220
 
3985
4221
  std::vector<whisper_token_data> result;
3986
4222
  result.reserve(k);
@@ -4075,6 +4311,115 @@ static void whisper_sequence_score(
4075
4311
  }
4076
4312
  }
4077
4313
 
4314
+ static bool whisper_kv_swap_fast(
4315
+ std::vector<int> & view,
4316
+ whisper_decoder src[],
4317
+ std::vector<kv_buf> & kv_swap_bufs,
4318
+ const int & n_decoders) {
4319
+ WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
4320
+
4321
+ // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
4322
+ std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
4323
+
4324
+ // (buffer->decoder or decoder->decoder)
4325
+ std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
4326
+
4327
+ // (decoder<->decoder)
4328
+ std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
4329
+ std::vector<whisper_pair<int, int>> p_swap_vec;
4330
+ p_swap_vec.reserve(n_decoders);
4331
+
4332
+ // see https://github.com/ggerganov/whisper.cpp/wiki
4333
+ for (int i = 0; i < n_decoders; i++) {
4334
+ // zero-copy (no modification)
4335
+ if (i == view[i] || view[i] < 0) {
4336
+ continue;
4337
+ }
4338
+
4339
+ bool is_one_copy = true;
4340
+ // since we modify data sequentially, we only consider decoder indices after current index
4341
+ for (int j = i + 1; j < n_decoders; j++) {
4342
+ if (i == view[j]) {
4343
+ // detect symmetric diagram
4344
+ if (j == view[i]) {
4345
+ p_swap_set.insert(i);
4346
+ p_swap_set.insert(j);
4347
+ p_swap_vec.emplace_back(i, j);
4348
+ } else {
4349
+ two_copy.insert(i);
4350
+ is_one_copy = false;
4351
+ }
4352
+ break;
4353
+ }
4354
+ }
4355
+ if (is_one_copy) {
4356
+ one_copy.insert(i);
4357
+ }
4358
+ }
4359
+
4360
+ kv_swap_bufs.resize(n_decoders);
4361
+
4362
+ for (int i = 0; i < n_decoders; i++) {
4363
+ kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
4364
+ kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
4365
+ }
4366
+
4367
+ for (auto & i : two_copy) {
4368
+ // make a copy of KV caches
4369
+ WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4370
+ memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4371
+ memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
4372
+ }
4373
+
4374
+ // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
4375
+ for (auto & i : two_copy) {
4376
+ // skip the decoder indices that require pointer swapping
4377
+ if (p_swap_set.find(i) != p_swap_set.end()) {
4378
+ continue;
4379
+ }
4380
+
4381
+ if (two_copy.find(view[i]) != two_copy.end()) {
4382
+ // modify KV caches of decoder using data from kv_swap_bufs
4383
+ WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4384
+ memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4385
+ memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4386
+ } else {
4387
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4388
+ WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4389
+ memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
4390
+ memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
4391
+ }
4392
+ }
4393
+
4394
+ // then modify one-copy decoder KV caches
4395
+ for (auto & i : one_copy) {
4396
+ // skip the decoder indices that require pointer swapping
4397
+ if (p_swap_set.find(i) != p_swap_set.end()) {
4398
+ continue;
4399
+ }
4400
+
4401
+ if (two_copy.find(view[i]) != two_copy.end()) {
4402
+ // modify KV caches of decoder using data from kv_swap_bufs
4403
+ WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4404
+ memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4405
+ memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4406
+ } else {
4407
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4408
+ WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4409
+ memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
4410
+ memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
4411
+ }
4412
+ }
4413
+
4414
+ // swap the pointers
4415
+ for (auto & i : p_swap_vec) {
4416
+ WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
4417
+ std::swap(src[i.first].kv_self, src[i.second].kv_self);
4418
+ }
4419
+
4420
+ return true;
4421
+ }
4422
+
4078
4423
  int whisper_full_with_state(
4079
4424
  struct whisper_context * ctx,
4080
4425
  struct whisper_state * state,
@@ -4182,6 +4527,21 @@ int whisper_full_with_state(
4182
4527
  decoder.probs.resize (ctx->vocab.n_vocab);
4183
4528
  decoder.logits.resize (ctx->vocab.n_vocab);
4184
4529
  decoder.logprobs.resize(ctx->vocab.n_vocab);
4530
+
4531
+ // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
4532
+ #ifdef WSP_GGML_USE_METAL
4533
+ #define WHISPER_METAL_CHECK_BUF(result) \
4534
+ if (!(result)) { \
4535
+ log("%s: failed to add metal buffer\n", __func__); \
4536
+ return 0; \
4537
+ }
4538
+
4539
+ const std::string kv_name = "kv_self_" + std::to_string(j);
4540
+ auto & kv_self = decoder.kv_self;
4541
+
4542
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
4543
+ #undef WHISPER_METAL_CHECK_BUF
4544
+ #endif
4185
4545
  }
4186
4546
  }
4187
4547
 
@@ -4197,7 +4557,7 @@ int whisper_full_with_state(
4197
4557
 
4198
4558
  // initial prompt
4199
4559
  if (!params.prompt_tokens && params.initial_prompt) {
4200
- prompt_tokens.resize(1024);
4560
+ prompt_tokens.resize(2048);
4201
4561
  prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
4202
4562
  params.prompt_tokens = prompt_tokens.data();
4203
4563
  params.prompt_n_tokens = prompt_tokens.size();
@@ -4238,14 +4598,6 @@ int whisper_full_with_state(
4238
4598
  std::vector<whisper_token> prompt;
4239
4599
  prompt.reserve(whisper_n_text_ctx(ctx));
4240
4600
 
4241
- // beam-search helpers
4242
- struct kv_buf {
4243
- std::vector<uint8_t> k;
4244
- std::vector<uint8_t> v;
4245
- };
4246
-
4247
- std::vector<kv_buf> kv_bufs;
4248
-
4249
4601
  struct beam_candidate {
4250
4602
  int decoder_idx;
4251
4603
  int seek_delta;
@@ -4279,7 +4631,7 @@ int whisper_full_with_state(
4279
4631
  }
4280
4632
 
4281
4633
  // encode audio features starting at offset seek
4282
- if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
4634
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4283
4635
  log("%s: failed to encode\n", __func__);
4284
4636
  return -6;
4285
4637
  }
@@ -4362,7 +4714,7 @@ int whisper_full_with_state(
4362
4714
  }
4363
4715
  WHISPER_PRINT_DEBUG("\n\n");
4364
4716
 
4365
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
4717
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4366
4718
  log("%s: failed to decode\n", __func__);
4367
4719
  return -7;
4368
4720
  }
@@ -4382,8 +4734,8 @@ int whisper_full_with_state(
4382
4734
 
4383
4735
  decoder.kv_self.n += prompt.size();
4384
4736
 
4385
- memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4386
- memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4737
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4738
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4387
4739
  memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
4388
4740
  }
4389
4741
 
@@ -4394,23 +4746,7 @@ int whisper_full_with_state(
4394
4746
  for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
4395
4747
  const int64_t t_start_sample_us = wsp_ggml_time_us();
4396
4748
 
4397
- // store the KV caches of all decoders when doing beam-search
4398
4749
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
4399
- kv_bufs.resize(n_decoders_cur);
4400
- for (int j = 0; j < n_decoders_cur; ++j) {
4401
- auto & decoder = state->decoders[j];
4402
-
4403
- if (decoder.completed || decoder.failed) {
4404
- continue;
4405
- }
4406
-
4407
- kv_bufs[j].k.resize(wsp_ggml_nbytes(decoder.kv_self.k));
4408
- kv_bufs[j].v.resize(wsp_ggml_nbytes(decoder.kv_self.v));
4409
-
4410
- memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
4411
- memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
4412
- }
4413
-
4414
4750
  beam_candidates.clear();
4415
4751
  }
4416
4752
 
@@ -4458,6 +4794,7 @@ int whisper_full_with_state(
4458
4794
  });
4459
4795
 
4460
4796
  uint32_t cur_c = 0;
4797
+ std::vector<int> decoder_idx(n_decoders_cur, -1);
4461
4798
 
4462
4799
  for (int j = 0; j < n_decoders_cur; ++j) {
4463
4800
  auto & decoder = state->decoders[j];
@@ -4476,12 +4813,13 @@ int whisper_full_with_state(
4476
4813
  decoder.seek_delta = cur.seek_delta;
4477
4814
  decoder.has_ts = cur.has_ts;
4478
4815
 
4479
- memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
4480
- memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
4481
-
4816
+ decoder_idx[j] = cur.decoder_idx;
4482
4817
  WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
4483
4818
  __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
4484
4819
  }
4820
+
4821
+ // update KV caches
4822
+ whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
4485
4823
  }
4486
4824
 
4487
4825
  // update the decoder state
@@ -4600,7 +4938,7 @@ int whisper_full_with_state(
4600
4938
 
4601
4939
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4602
4940
 
4603
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4941
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4604
4942
  log("%s: failed to decode\n", __func__);
4605
4943
  return -8;
4606
4944
  }
@@ -4910,6 +5248,12 @@ int whisper_full_parallel(
4910
5248
  ctx->state->t_sample_us += states[i]->t_sample_us;
4911
5249
  ctx->state->t_encode_us += states[i]->t_encode_us;
4912
5250
  ctx->state->t_decode_us += states[i]->t_decode_us;
5251
+ ctx->state->t_prompt_us += states[i]->t_prompt_us;
5252
+
5253
+ ctx->state->n_sample += states[i]->n_sample;
5254
+ ctx->state->n_encode += states[i]->n_encode;
5255
+ ctx->state->n_decode += states[i]->n_decode;
5256
+ ctx->state->n_prompt += states[i]->n_prompt;
4913
5257
 
4914
5258
  whisper_free_state(states[i]);
4915
5259
  }
@@ -4963,6 +5307,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
4963
5307
  return ctx->state->result_all[i_segment].t1;
4964
5308
  }
4965
5309
 
5310
+ bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
5311
+ return state->result_all[i_segment].speaker_turn_next;
5312
+ }
5313
+
4966
5314
  bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
4967
5315
  return ctx->state->result_all[i_segment].speaker_turn_next;
4968
5316
  }
@@ -5106,7 +5454,8 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
5106
5454
  // b: N*N*sizeof(float)
5107
5455
  // c: N*N*sizeof(float)
5108
5456
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5109
- std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512);
5457
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
5458
+ std::vector<uint8_t> work;
5110
5459
 
5111
5460
  // put a bunch of random data in the buffer
5112
5461
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -5158,17 +5507,15 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
5158
5507
 
5159
5508
  struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
5160
5509
 
5161
- gf.n_threads = n_threads;
5162
-
5163
5510
  double tsum = 0.0;
5164
5511
 
5165
5512
  // heat-up
5166
- wsp_ggml_graph_compute(ctx0, &gf);
5513
+ wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
5167
5514
 
5168
5515
  for (int i = 0; i < n_max; ++i) {
5169
5516
  const int64_t t0 = wsp_ggml_time_us();
5170
5517
 
5171
- wsp_ggml_graph_compute(ctx0, &gf);
5518
+ wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
5172
5519
 
5173
5520
  const int64_t t1 = wsp_ggml_time_us();
5174
5521