whisper.rn 0.3.9 → 0.4.0-rc.0

This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
package/cpp/whisper.cpp CHANGED
@@ -3,11 +3,16 @@
3
3
  #include "coreml/whisper-encoder.h"
4
4
  #endif
5
5
 
6
- #if WHISPER_USE_OPENVINO
6
+ #ifdef WSP_GGML_USE_METAL
7
+ # include "ggml-metal.h"
8
+ #endif
9
+
10
+ #ifdef WHISPER_USE_OPENVINO
7
11
  #include "openvino/whisper-openvino-encoder.h"
8
12
  #endif
9
13
 
10
14
  #include "ggml.h"
15
+ #include "ggml-alloc.h"
11
16
 
12
17
  #include <algorithm>
13
18
  #include <cassert>
@@ -18,11 +23,13 @@
18
23
  #include <cstring>
19
24
  #include <fstream>
20
25
  #include <map>
26
+ #include <set>
21
27
  #include <string>
22
28
  #include <thread>
23
29
  #include <vector>
24
30
  #include <regex>
25
31
  #include <random>
32
+ #include <functional>
26
33
 
27
34
  #if defined(_MSC_VER)
28
35
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -114,8 +121,66 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
114
121
  //#define WHISPER_USE_FLASH_FF
115
122
  #define WHISPER_MAX_DECODERS 16
116
123
 
117
- #define WHISPER_USE_SCRATCH
118
- #define WHISPER_MAX_SCRATCH_BUFFERS 16
124
+ //
125
+ // ggml helpers
126
+ //
127
+
128
+ static void wsp_ggml_graph_compute_helper(
129
+ std::vector<uint8_t> & buf,
130
+ wsp_ggml_cgraph * graph,
131
+ int n_threads,
132
+ whisper_abort_callback abort_callback,
133
+ void * abort_callback_data) {
134
+ struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads);
135
+
136
+ plan.abort_callback = abort_callback;
137
+ plan.abort_callback_data = abort_callback_data;
138
+
139
+ if (plan.work_size > 0) {
140
+ buf.resize(plan.work_size);
141
+ plan.work_data = buf.data();
142
+ }
143
+
144
+ wsp_ggml_graph_compute(graph, &plan);
145
+ }
146
+
147
+ // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
148
+ // the idea is to represent the original matrix multiplication:
149
+ //
150
+ // Z = X @ Y
151
+ //
152
+ // with the sum of two matrix multiplications:
153
+ //
154
+ // Z = (X_0 @ Y_0) + (X_1 @ Y_1)
155
+ //
156
+ // here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
157
+ // and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
158
+ // general-purpose kernels
159
+ //
160
+ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * x, struct wsp_ggml_tensor * y, int pad = 32) {
161
+ // use padding only if dimension 0 is at least 8 times larger than the padding
162
+ // else we won't get much benefit from the optimization
163
+ const int n_pad_req = 8;
164
+
165
+ if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
166
+ return wsp_ggml_mul_mat(ctx, x, y);
167
+ }
168
+
169
+ struct wsp_ggml_tensor * x_0 = wsp_ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
170
+ struct wsp_ggml_tensor * x_1 = wsp_ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
171
+
172
+ struct wsp_ggml_tensor * y_0 = wsp_ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
173
+ struct wsp_ggml_tensor * y_1 = wsp_ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
174
+
175
+ return wsp_ggml_add(ctx,
176
+ wsp_ggml_mul_mat(ctx, x_0, y_0),
177
+ wsp_ggml_mul_mat(ctx, x_1, y_1));
178
+ }
179
+
180
+ // TODO: check if other platforms can benefit from this optimization
181
+ #if defined(WSP_GGML_USE_METAL)
182
+ #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
183
+ #endif
119
184
 
120
185
  // available whisper models
121
186
  enum e_model {
@@ -231,38 +296,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
231
296
 
232
297
  static const size_t MB = 1ull*1024*1024;
233
298
 
234
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
235
- { MODEL_TINY, 62ull*MB },
236
- { MODEL_BASE, 80ull*MB },
237
- { MODEL_SMALL, 120ull*MB },
238
- { MODEL_MEDIUM, 158ull*MB },
239
- { MODEL_LARGE, 198ull*MB },
240
- };
241
-
242
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
243
- { MODEL_TINY, 18ull*MB },
244
- { MODEL_BASE, 24ull*MB },
245
- { MODEL_SMALL, 36ull*MB },
246
- { MODEL_MEDIUM, 48ull*MB },
247
- { MODEL_LARGE, 60ull*MB },
248
- };
249
-
250
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
251
- { MODEL_TINY, 4ull*MB },
252
- { MODEL_BASE, 4ull*MB },
253
- { MODEL_SMALL, 6ull*MB },
254
- { MODEL_MEDIUM, 7ull*MB },
255
- { MODEL_LARGE, 9ull*MB },
256
- };
257
-
258
- static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
259
- { MODEL_TINY, 4ull*MB },
260
- { MODEL_BASE, 4ull*MB },
261
- { MODEL_SMALL, 6ull*MB },
262
- { MODEL_MEDIUM, 7ull*MB },
263
- { MODEL_LARGE, 9ull*MB },
264
- };
265
-
299
+ // TODO: avoid using GGUF
266
300
  static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
267
301
  { WSP_GGML_TYPE_F32,
268
302
  {
@@ -329,38 +363,6 @@ static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL =
329
363
  },
330
364
  };
331
365
 
332
- static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
333
- { MODEL_TINY, 3ull*MB },
334
- { MODEL_BASE, 6ull*MB },
335
- { MODEL_SMALL, 16ull*MB },
336
- { MODEL_MEDIUM, 43ull*MB },
337
- { MODEL_LARGE, 71ull*MB },
338
- };
339
-
340
- static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
341
- { MODEL_TINY, 9ull*MB },
342
- { MODEL_BASE, 18ull*MB },
343
- { MODEL_SMALL, 53ull*MB },
344
- { MODEL_MEDIUM, 141ull*MB },
345
- { MODEL_LARGE, 235ull*MB },
346
- };
347
-
348
- static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
349
- { MODEL_TINY, 30ull*MB },
350
- { MODEL_BASE, 38ull*MB },
351
- { MODEL_SMALL, 56ull*MB },
352
- { MODEL_MEDIUM, 74ull*MB },
353
- { MODEL_LARGE, 94ull*MB },
354
- };
355
-
356
- static const std::map<e_model, size_t> MEM_REQ_DECODE = {
357
- { MODEL_TINY, 3ull*MB },
358
- { MODEL_BASE, 5ull*MB },
359
- { MODEL_SMALL, 10ull*MB },
360
- { MODEL_MEDIUM, 18ull*MB },
361
- { MODEL_LARGE, 27ull*MB },
362
- };
363
-
364
366
  struct whisper_mel {
365
367
  int n_len;
366
368
  int n_len_org;
@@ -441,6 +443,7 @@ struct whisper_hparams {
441
443
  int32_t n_text_layer = 4;
442
444
  int32_t n_mels = 80;
443
445
  int32_t ftype = 1;
446
+ float eps = 1e-5f;
444
447
  };
445
448
 
446
449
  // audio encoding layer
@@ -536,6 +539,7 @@ struct whisper_kv_cache {
536
539
 
537
540
  struct wsp_ggml_context * ctx;
538
541
 
542
+ // buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
539
543
  std::vector<uint8_t> buf;
540
544
 
541
545
  int n; // number of tokens currently in the cache
@@ -601,7 +605,7 @@ struct whisper_sequence {
601
605
 
602
606
  // TAGS: WHISPER_DECODER_INIT
603
607
  struct whisper_decoder {
604
- // each decoders keeps its own KV-cache
608
+ // each decoder keeps its own KV-cache
605
609
  whisper_kv_cache kv_self;
606
610
 
607
611
  // the currently generated sequence of tokens
@@ -621,15 +625,75 @@ struct whisper_decoder {
621
625
  std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
622
626
  };
623
627
 
628
+ // replace std::pair by using customized pair struct (reason: std::pair is very slow)
629
+ template<typename A, typename B>
630
+ struct whisper_pair {
631
+ A first;
632
+ B second;
633
+
634
+ // Define a constructor that takes two arguments.
635
+ whisper_pair(const A& a, const B& b) : first(a), second(b) {}
636
+ // Define a constructor that takes no argument.
637
+ whisper_pair() : first(A()), second(B()) {}
638
+ };
639
+
640
+ // beam-search helpers
641
+ struct kv_buf {
642
+ std::vector<uint8_t> k;
643
+ std::vector<uint8_t> v;
644
+ };
645
+
646
+ // wsp_ggml_allocr wrapper for whisper usage
647
+ struct whisper_allocr {
648
+ wsp_ggml_allocr * alloc = nullptr;
649
+
650
+ std::vector<uint8_t> meta;
651
+ std::vector<uint8_t> data;
652
+ };
653
+
654
+ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
655
+ return allocr.meta.size() + allocr.data.size();
656
+ }
657
+
658
+ // measure the memory usage of a graph and prepare the allocr's internal data buffer
659
+ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
660
+ const int tensor_alignment = 32;
661
+
662
+ auto & alloc = allocr.alloc;
663
+ auto & meta = allocr.meta;
664
+ auto & data = allocr.data;
665
+
666
+ meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
667
+
668
+ alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
669
+
670
+ const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
671
+
672
+ wsp_ggml_allocr_free(alloc);
673
+
674
+ data.resize(alloc_size);
675
+
676
+ alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
677
+ }
678
+
679
+ static void whisper_allocr_free(struct whisper_allocr & allocr) {
680
+ if (allocr.alloc) {
681
+ wsp_ggml_allocr_free(allocr.alloc);
682
+ allocr.alloc = nullptr;
683
+ }
684
+ }
685
+
624
686
  struct whisper_state {
625
687
  int64_t t_sample_us = 0;
626
688
  int64_t t_encode_us = 0;
627
689
  int64_t t_decode_us = 0;
690
+ int64_t t_prompt_us = 0;
628
691
  int64_t t_mel_us = 0;
629
692
 
630
693
  int32_t n_sample = 0; // number of tokens sampled
631
694
  int32_t n_encode = 0; // number of encoder calls
632
- int32_t n_decode = 0; // number of decoder calls
695
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
696
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
633
697
  int32_t n_fail_p = 0; // number of logprob threshold failures
634
698
  int32_t n_fail_h = 0; // number of entropy threshold failures
635
699
 
@@ -640,12 +704,23 @@ struct whisper_state {
640
704
 
641
705
  whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
642
706
 
643
- // memory buffers used by encode / decode contexts
644
- std::vector<uint8_t> buf_compute;
645
- std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
707
+ // buffer for swapping KV caches between decoders during beam-search
708
+ std::vector<kv_buf> kv_swap_bufs;
709
+
710
+ // reusable buffer for `struct wsp_ggml_graph_plan.work_data`
711
+ std::vector<uint8_t> work_buffer;
646
712
 
647
- int buf_last = 0;
648
- size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
713
+ // ggml-alloc:
714
+ // - stores meta info about the intermediate tensors into the `meta` buffers
715
+ // - stores the actual tensor data into the `data` buffers
716
+ whisper_allocr alloc_conv;
717
+ whisper_allocr alloc_encode;
718
+ whisper_allocr alloc_cross;
719
+ whisper_allocr alloc_decode;
720
+
721
+ // result of the encoder
722
+ struct wsp_ggml_tensor * embd_conv = nullptr;
723
+ struct wsp_ggml_tensor * embd_enc = nullptr;
649
724
 
650
725
  // decode output (2-dimensional array: [n_tokens][n_vocab])
651
726
  std::vector<float> logits;
@@ -654,7 +729,7 @@ struct whisper_state {
654
729
  std::vector<whisper_token> prompt_past;
655
730
 
656
731
  // work container used to avoid memory allocations
657
- std::vector<std::pair<double, whisper_vocab::id>> logits_id;
732
+ std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
658
733
 
659
734
  mutable std::mt19937 rng; // used for sampling at t > 0.0
660
735
 
@@ -665,6 +740,10 @@ struct whisper_state {
665
740
  whisper_coreml_context * ctx_coreml = nullptr;
666
741
  #endif
667
742
 
743
+ #ifdef WSP_GGML_USE_METAL
744
+ wsp_ggml_metal_context * ctx_metal = nullptr;
745
+ #endif
746
+
668
747
  #ifdef WHISPER_USE_OPENVINO
669
748
  whisper_openvino_context * ctx_openvino = nullptr;
670
749
  #endif
@@ -677,37 +756,6 @@ struct whisper_state {
677
756
 
678
757
  // [EXPERIMENTAL] speed-up techniques
679
758
  int32_t exp_n_audio_ctx = 0; // 0 - use default
680
-
681
- void use_buf(struct wsp_ggml_context * ctx, int i) {
682
- #if defined(WHISPER_USE_SCRATCH)
683
- size_t last_size = 0;
684
-
685
- if (i == -1) {
686
- last_size = wsp_ggml_set_scratch(ctx, { 0, 0, nullptr, });
687
- } else {
688
- auto & buf = buf_scratch[i];
689
- last_size = wsp_ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
690
- }
691
-
692
- if (buf_last >= 0) {
693
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
694
- }
695
-
696
- buf_last = i;
697
- #else
698
- (void) i;
699
- (void) ctx;
700
- #endif
701
- }
702
-
703
- size_t get_buf_max_mem(int i) const {
704
- #if defined(WHISPER_USE_SCRATCH)
705
- return buf_max_size[i];
706
- #else
707
- (void) i;
708
- return 0;
709
- #endif
710
- }
711
759
  };
712
760
 
713
761
  struct whisper_context {
@@ -730,6 +778,13 @@ static void whisper_default_log(const char * text) {
730
778
 
731
779
  static whisper_log_callback whisper_log = whisper_default_log;
732
780
 
781
+ #ifdef __GNUC__
782
+ #ifdef __MINGW32__
783
+ __attribute__((gnu_format(printf, 1, 2)))
784
+ #else
785
+ __attribute__((format(printf, 1, 2)))
786
+ #endif
787
+ #endif
733
788
  static void log(const char * fmt, ...) {
734
789
  if (!whisper_log) return;
735
790
  char buf[1024];
@@ -747,10 +802,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
747
802
 
748
803
  static bool kv_cache_init(
749
804
  const struct whisper_hparams & hparams,
750
- const size_t mem_bytes,
751
805
  struct whisper_kv_cache & cache,
752
806
  wsp_ggml_type wtype,
753
807
  int n_ctx) {
808
+ const int64_t n_text_state = hparams.n_text_state;
809
+ const int64_t n_text_layer = hparams.n_text_layer;
810
+
811
+ const int64_t n_mem = n_text_layer*n_ctx;
812
+ const int64_t n_elements = n_text_state*n_mem;
813
+
814
+ const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
815
+
754
816
  cache.buf.resize(mem_bytes);
755
817
 
756
818
  struct wsp_ggml_init_params params = {
@@ -766,12 +828,6 @@ static bool kv_cache_init(
766
828
  return false;
767
829
  }
768
830
 
769
- const int n_text_state = hparams.n_text_state;
770
- const int n_text_layer = hparams.n_text_layer;
771
-
772
- const int n_mem = n_text_layer*n_ctx;
773
- const int n_elements = n_text_state*n_mem;
774
-
775
831
  cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
776
832
  cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
777
833
 
@@ -914,22 +970,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
914
970
 
915
971
  // print memory requirements
916
972
  {
917
- // this is the total memory required to run the inference
918
- const size_t mem_required =
919
- MEM_REQ_SCRATCH0.at(model.type) +
920
- MEM_REQ_SCRATCH1.at(model.type) +
921
- MEM_REQ_SCRATCH2.at(model.type) +
922
- MEM_REQ_SCRATCH3.at(model.type) +
923
- scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
924
- scale*MEM_REQ_KV_CROSS.at(model.type) +
925
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
926
-
927
- // this is the memory required by one decoder
928
- const size_t mem_required_decoder =
929
- scale*MEM_REQ_KV_SELF.at(model.type);
930
-
931
- log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
932
- mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
973
+ // TODO
974
+ //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
975
+ // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
933
976
  }
934
977
 
935
978
  // initialize all memory buffers
@@ -1438,49 +1481,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1438
1481
  return true;
1439
1482
  }
1440
1483
 
1441
- // evaluate the encoder with the given state
1442
- //
1443
- // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1444
- // part of the transformer model and returns the encoded features
1445
- //
1446
- // - wctx: the model
1447
- // - wstate: the state of the encoder
1448
- // - n_threads: number of threads to use
1449
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1450
- //
1451
- static bool whisper_encode_internal(
1452
- whisper_context & wctx,
1453
- whisper_state & wstate,
1454
- const int mel_offset,
1455
- const int n_threads){
1484
+ static bool whisper_encode_external(const whisper_state & wstate) {
1485
+ WSP_GGML_UNUSED(wstate);
1456
1486
 
1457
- const int64_t t_start_us = wsp_ggml_time_us();
1487
+ #ifndef WHISPER_USE_COREML
1488
+ const bool use_coreml = false;
1489
+ #else
1490
+ const bool use_coreml = wstate.ctx_coreml != nullptr;
1491
+ #endif
1492
+
1493
+ #ifndef WHISPER_USE_OPENVINO
1494
+ const bool use_openvino = false;
1495
+ #else
1496
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
1497
+ #endif
1498
+
1499
+ return use_coreml || use_openvino;
1500
+ }
1458
1501
 
1502
+ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1503
+ whisper_context & wctx,
1504
+ whisper_state & wstate,
1505
+ const int mel_offset) {
1459
1506
  const auto & model = wctx.model;
1460
1507
  const auto & mel_inp = wstate.mel;
1461
1508
  const auto & hparams = model.hparams;
1462
1509
 
1463
1510
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1464
- const int n_state = hparams.n_audio_state;
1465
- const int n_head = hparams.n_audio_head;
1466
- const int n_layer = hparams.n_audio_layer;
1511
+ const int n_state = hparams.n_audio_state; WSP_GGML_UNUSED(n_state);
1467
1512
 
1468
1513
  const int n_mels = hparams.n_mels;
1469
- assert(mel_inp.n_mel == n_mels);
1470
1514
 
1471
1515
  struct wsp_ggml_init_params params = {
1472
- /*.mem_size =*/ wstate.buf_compute.size(),
1473
- /*.mem_buffer =*/ wstate.buf_compute.data(),
1474
- /*.no_alloc =*/ false,
1516
+ /*.mem_size =*/ wstate.alloc_conv.meta.size(),
1517
+ /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
1518
+ /*.no_alloc =*/ true,
1475
1519
  };
1476
1520
 
1477
1521
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1478
1522
 
1479
- wstate.use_buf(ctx0, 0);
1523
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1524
+
1525
+ wsp_ggml_allocr * alloc = wstate.alloc_conv.alloc;
1480
1526
 
1481
1527
  struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels);
1528
+ wsp_ggml_allocr_alloc(alloc, mel);
1529
+
1482
1530
  assert(mel->type == WSP_GGML_TYPE_F32);
1483
- {
1531
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1532
+ assert(mel_inp.n_mel == n_mels);
1533
+
1484
1534
  float * dst = (float *) mel->data;
1485
1535
  memset(dst, 0, wsp_ggml_nbytes(mel));
1486
1536
 
@@ -1494,25 +1544,11 @@ static bool whisper_encode_internal(
1494
1544
  }
1495
1545
  }
1496
1546
 
1497
- struct wsp_ggml_tensor * cur;
1547
+ struct wsp_ggml_tensor * cur = nullptr;
1498
1548
 
1499
- #ifndef WHISPER_USE_COREML
1500
- const bool use_coreml = false;
1501
- #else
1502
- const bool use_coreml = wstate.ctx_coreml != nullptr;
1503
- #endif
1504
-
1505
- #ifndef WHISPER_USE_OPENVINO
1506
- const bool use_openvino = false;
1507
- #else
1508
- const bool use_openvino = wstate.ctx_openvino != nullptr;
1509
- #endif
1510
-
1511
- if (!use_coreml && !use_openvino) {
1549
+ if (!whisper_encode_external(wstate)) {
1512
1550
  // convolution + gelu
1513
1551
  {
1514
- wstate.use_buf(ctx0, 1);
1515
-
1516
1552
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1517
1553
  cur = wsp_ggml_add(ctx0,
1518
1554
  wsp_ggml_repeat(ctx0,
@@ -1522,8 +1558,6 @@ static bool whisper_encode_internal(
1522
1558
 
1523
1559
  cur = wsp_ggml_gelu(ctx0, cur);
1524
1560
 
1525
- wstate.use_buf(ctx0, 0);
1526
-
1527
1561
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1528
1562
  cur = wsp_ggml_add(ctx0,
1529
1563
  wsp_ggml_repeat(ctx0,
@@ -1534,373 +1568,433 @@ static bool whisper_encode_internal(
1534
1568
  cur = wsp_ggml_gelu(ctx0, cur);
1535
1569
  }
1536
1570
 
1537
- wstate.use_buf(ctx0, 3);
1571
+ wstate.embd_conv = cur;
1572
+ } else {
1573
+ #ifdef WHISPER_USE_COREML
1574
+ cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1575
+ wsp_ggml_allocr_alloc(alloc, cur);
1538
1576
 
1539
- // ===================================================================
1540
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
1541
- //static int iter = -1;
1542
- //const int n_iter = 1500/n_ctx;
1577
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1578
+ whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1579
+ }
1580
+ #endif
1581
+ #ifdef WHISPER_USE_OPENVINO
1582
+ cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1583
+ wsp_ggml_allocr_alloc(alloc, cur);
1543
1584
 
1544
- //iter = (iter + 1) % n_iter;
1585
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1586
+ whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
1587
+ }
1588
+ #endif
1545
1589
 
1546
- //if (iter == 0) {
1547
- // memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
1548
- // memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
1549
- //}
1590
+ wstate.embd_enc = cur;
1591
+ }
1550
1592
 
1551
- static int iter = 0;
1593
+ wsp_ggml_build_forward_expand(gf, cur);
1552
1594
 
1553
- const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
1554
- const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1595
+ wsp_ggml_free(ctx0);
1555
1596
 
1556
- struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1597
+ return gf;
1598
+ }
1557
1599
 
1558
- cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_transpose(ctx0, cur));
1600
+ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1601
+ whisper_context & wctx,
1602
+ whisper_state & wstate) {
1603
+ const auto & model = wctx.model;
1604
+ const auto & hparams = model.hparams;
1559
1605
 
1560
- // ===================================================================
1606
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1607
+ const int n_state = hparams.n_audio_state;
1608
+ const int n_head = hparams.n_audio_head;
1609
+ const int n_layer = hparams.n_audio_layer;
1561
1610
 
1562
- // original:
1563
- //cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
1611
+ struct wsp_ggml_init_params params = {
1612
+ /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1613
+ /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
1614
+ /*.no_alloc =*/ true,
1615
+ };
1564
1616
 
1565
- struct wsp_ggml_tensor * inpL = cur;
1617
+ struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1566
1618
 
1567
- for (int il = 0; il < n_layer; ++il) {
1568
- const auto & layer = model.layers_encoder[il];
1619
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1569
1620
 
1570
- // norm
1571
- {
1572
- wstate.use_buf(ctx0, 0);
1621
+ wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
1573
1622
 
1574
- cur = wsp_ggml_norm(ctx0, inpL);
1623
+ struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1624
+ wsp_ggml_allocr_alloc(alloc, KQscale);
1575
1625
 
1576
- // cur = ln_0_w*cur + ln_0_b
1577
- cur = wsp_ggml_add(ctx0,
1578
- wsp_ggml_mul(ctx0,
1579
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1580
- cur),
1581
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1582
- }
1626
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1627
+ wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
1628
+ }
1583
1629
 
1584
- // self-attention
1585
- {
1586
- wstate.use_buf(ctx0, 1);
1630
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1587
1631
 
1588
- struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
1589
- layer.attn_q_w,
1590
- cur);
1632
+ // ===================================================================
1633
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
1634
+ //static int iter = -1;
1635
+ //const int n_iter = 1500/n_ctx;
1591
1636
 
1592
- Qcur = wsp_ggml_add(ctx0,
1593
- wsp_ggml_repeat(ctx0,
1594
- layer.attn_q_b,
1595
- Qcur),
1596
- Qcur);
1637
+ //iter = (iter + 1) % n_iter;
1597
1638
 
1598
- //Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1639
+ //if (iter == 0) {
1640
+ // memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k));
1641
+ // memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v));
1642
+ //}
1599
1643
 
1600
- // note: no bias for Key
1601
- struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
1602
- layer.attn_k_w,
1603
- cur);
1644
+ static int iter = 0;
1604
1645
 
1605
- //Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1646
+ const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe);
1647
+ const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1606
1648
 
1607
- struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
1608
- layer.attn_v_w,
1609
- cur);
1649
+ struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1610
1650
 
1611
- Vcur = wsp_ggml_add(ctx0,
1612
- wsp_ggml_repeat(ctx0,
1613
- layer.attn_v_b,
1614
- Vcur),
1615
- Vcur);
1651
+ cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
1652
+
1653
+ // ===================================================================
1654
+
1655
+ // original:
1656
+ //cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur));
1657
+
1658
+ struct wsp_ggml_tensor * inpL = cur;
1616
1659
 
1617
- // ------
1660
+ for (int il = 0; il < n_layer; ++il) {
1661
+ const auto & layer = model.layers_encoder[il];
1662
+
1663
+ // norm
1664
+ {
1665
+ cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
1666
+
1667
+ // cur = ln_0_w*cur + ln_0_b
1668
+ cur = wsp_ggml_add(ctx0,
1669
+ wsp_ggml_mul(ctx0, cur, layer.attn_ln_0_w),
1670
+ layer.attn_ln_0_b);
1671
+ }
1672
+
1673
+ // self-attention
1674
+ {
1675
+ struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0,
1676
+ layer.attn_q_w,
1677
+ cur);
1678
+
1679
+ Qcur = wsp_ggml_add(ctx0, Qcur, layer.attn_q_b);
1680
+
1681
+ //Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1682
+
1683
+ // note: no bias for Key
1684
+ struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
1685
+ layer.attn_k_w,
1686
+ cur);
1687
+
1688
+ //Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1689
+
1690
+ struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0,
1691
+ layer.attn_v_w,
1692
+ cur);
1693
+
1694
+ Vcur = wsp_ggml_add(ctx0, Vcur, layer.attn_v_b);
1618
1695
 
1619
- wstate.use_buf(ctx0, 0);
1696
+ // ------
1620
1697
 
1621
1698
  #ifdef WHISPER_USE_FLASH_ATTN
1622
- struct wsp_ggml_tensor * Q =
1623
- wsp_ggml_permute(ctx0,
1624
- wsp_ggml_cpy(ctx0,
1625
- Qcur,
1626
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1627
- 0, 2, 1, 3);
1628
-
1629
- struct wsp_ggml_tensor * K =
1630
- wsp_ggml_permute(ctx0,
1631
- wsp_ggml_cpy(ctx0,
1632
- Kcur,
1633
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1634
- 0, 2, 1, 3);
1635
-
1636
- struct wsp_ggml_tensor * V =
1637
- wsp_ggml_cpy(ctx0,
1638
- wsp_ggml_permute(ctx0,
1639
- wsp_ggml_reshape_3d(ctx0,
1640
- Vcur,
1641
- n_state/n_head, n_head, n_ctx),
1642
- 1, 2, 0, 3),
1643
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1644
-
1645
- struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
1699
+ struct wsp_ggml_tensor * Q =
1700
+ wsp_ggml_permute(ctx0,
1701
+ wsp_ggml_cpy(ctx0,
1702
+ Qcur,
1703
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1704
+ 0, 2, 1, 3);
1705
+
1706
+ struct wsp_ggml_tensor * K =
1707
+ wsp_ggml_permute(ctx0,
1708
+ wsp_ggml_cpy(ctx0,
1709
+ Kcur,
1710
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1711
+ 0, 2, 1, 3);
1712
+
1713
+ struct wsp_ggml_tensor * V =
1714
+ wsp_ggml_cpy(ctx0,
1715
+ wsp_ggml_permute(ctx0,
1716
+ wsp_ggml_reshape_3d(ctx0,
1717
+ Vcur,
1718
+ n_state/n_head, n_head, n_ctx),
1719
+ 1, 2, 0, 3),
1720
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
1721
+
1722
+ struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false);
1646
1723
  #else
1647
- struct wsp_ggml_tensor * Q =
1648
- wsp_ggml_permute(ctx0,
1649
- wsp_ggml_cpy(ctx0,
1650
- Qcur,
1651
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1652
- 0, 2, 1, 3);
1653
-
1654
- struct wsp_ggml_tensor * K =
1655
- wsp_ggml_permute(ctx0,
1656
- wsp_ggml_cpy(ctx0,
1657
- Kcur,
1658
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1659
- 0, 2, 1, 3);
1660
-
1661
- // K * Q
1662
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
1663
-
1664
- struct wsp_ggml_tensor * KQ_scaled =
1665
- wsp_ggml_scale_inplace(ctx0,
1666
- KQ,
1667
- wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
1668
- );
1669
-
1670
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_scaled);
1671
-
1672
- struct wsp_ggml_tensor * V =
1673
- wsp_ggml_cpy(ctx0,
1674
- wsp_ggml_permute(ctx0,
1675
- wsp_ggml_reshape_3d(ctx0,
1676
- Vcur,
1677
- n_state/n_head, n_head, n_ctx),
1678
- 1, 2, 0, 3),
1679
- wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1680
- );
1681
-
1682
- struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1724
+ struct wsp_ggml_tensor * Q =
1725
+ wsp_ggml_permute(ctx0,
1726
+ wsp_ggml_cpy(ctx0,
1727
+ Qcur,
1728
+ wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1729
+ 0, 2, 1, 3);
1730
+
1731
+ struct wsp_ggml_tensor * K =
1732
+ wsp_ggml_permute(ctx0,
1733
+ wsp_ggml_cpy(ctx0,
1734
+ Kcur,
1735
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1736
+ 0, 2, 1, 3);
1737
+
1738
+ // K * Q
1739
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
1740
+
1741
+ struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQscale);
1742
+
1743
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_scaled);
1744
+
1745
+ struct wsp_ggml_tensor * V =
1746
+ wsp_ggml_cpy(ctx0,
1747
+ wsp_ggml_permute(ctx0,
1748
+ wsp_ggml_reshape_3d(ctx0,
1749
+ Vcur,
1750
+ n_state/n_head, n_head, n_ctx),
1751
+ 1, 2, 0, 3),
1752
+ wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1753
+ );
1754
+
1755
+ struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
1683
1756
  #endif
1684
- struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1757
+ struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1685
1758
 
1686
- wstate.use_buf(ctx0, 1);
1759
+ cur = wsp_ggml_cpy(ctx0,
1760
+ KQV_merged,
1761
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
1762
+ }
1687
1763
 
1688
- cur = wsp_ggml_cpy(ctx0,
1689
- KQV_merged,
1690
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx));
1691
- }
1764
+ // projection
1765
+ {
1766
+ cur = wsp_ggml_mul_mat(ctx0,
1767
+ layer.attn_ln_1_w,
1768
+ cur);
1692
1769
 
1693
- // projection
1694
- {
1695
- wstate.use_buf(ctx0, 0);
1770
+ cur = wsp_ggml_add(ctx0, cur, layer.attn_ln_1_b);
1771
+ }
1696
1772
 
1697
- cur = wsp_ggml_mul_mat(ctx0,
1698
- layer.attn_ln_1_w,
1699
- cur);
1773
+ // add the input
1774
+ cur = wsp_ggml_add(ctx0, cur, inpL);
1700
1775
 
1701
- wstate.use_buf(ctx0, 1);
1776
+ struct wsp_ggml_tensor * inpFF = cur;
1777
+
1778
+ // feed-forward network
1779
+ {
1780
+ // norm
1781
+ {
1782
+ cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
1702
1783
 
1784
+ // cur = mlp_ln_w*cur + mlp_ln_b
1703
1785
  cur = wsp_ggml_add(ctx0,
1704
- wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1705
- cur);
1786
+ wsp_ggml_mul(ctx0, cur, layer.mlp_ln_w),
1787
+ layer.mlp_ln_b);
1706
1788
  }
1707
1789
 
1708
- wstate.use_buf(ctx0, 2);
1790
+ #ifdef WHISPER_USE_FLASH_FF
1791
+ cur = wsp_ggml_flash_ff(ctx0,
1792
+ wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1793
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1794
+ #else
1795
+ // fully connected
1796
+ cur = wsp_ggml_mul_mat(ctx0,
1797
+ layer.mlp_0_w,
1798
+ cur);
1709
1799
 
1710
- // add the input
1711
- cur = wsp_ggml_add(ctx0, cur, inpL);
1800
+ cur = wsp_ggml_add(ctx0, cur, layer.mlp_0_b);
1712
1801
 
1713
- struct wsp_ggml_tensor * inpFF = cur;
1802
+ // GELU activation
1803
+ cur = wsp_ggml_gelu(ctx0, cur);
1714
1804
 
1715
- // feed-forward network
1716
- {
1717
- // norm
1718
- {
1719
- wstate.use_buf(ctx0, 0);
1805
+ // projection
1806
+ cur = wsp_ggml_mul_mat(ctx0,
1807
+ layer.mlp_1_w,
1808
+ cur);
1720
1809
 
1721
- cur = wsp_ggml_norm(ctx0, inpFF);
1810
+ cur = wsp_ggml_add(ctx0, cur, layer.mlp_1_b);
1811
+ #endif
1812
+ }
1722
1813
 
1723
- wstate.use_buf(ctx0, 1);
1814
+ inpL = wsp_ggml_add(ctx0, cur, inpFF);
1815
+ }
1724
1816
 
1725
- // cur = mlp_ln_w*cur + mlp_ln_b
1726
- cur = wsp_ggml_add(ctx0,
1727
- wsp_ggml_mul(ctx0,
1728
- wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1729
- cur),
1730
- wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1731
- }
1817
+ cur = inpL;
1732
1818
 
1733
- #ifdef WHISPER_USE_FLASH_FF
1734
- wstate.use_buf(ctx0, 0);
1819
+ // norm
1820
+ {
1821
+ cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
1735
1822
 
1736
- cur = wsp_ggml_flash_ff(ctx0,
1737
- wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
1738
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1739
- #else
1740
- wstate.use_buf(ctx0, 0);
1823
+ // cur = ln_f_g*cur + ln_f_b
1824
+ cur = wsp_ggml_add(ctx0,
1825
+ wsp_ggml_mul(ctx0, cur, model.e_ln_w),
1826
+ model.e_ln_b);
1827
+ }
1741
1828
 
1742
- // fully connected
1743
- cur = wsp_ggml_mul_mat(ctx0,
1744
- layer.mlp_0_w,
1745
- cur);
1829
+ wsp_ggml_build_forward_expand(gf, cur);
1746
1830
 
1747
- wstate.use_buf(ctx0, 1);
1831
+ wstate.embd_enc = cur;
1748
1832
 
1749
- cur = wsp_ggml_add(ctx0,
1750
- wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
1751
- cur);
1833
+ //wsp_ggml_graph_print(gf);
1752
1834
 
1753
- wstate.use_buf(ctx0, 0);
1835
+ ////////////////////////////////////////////////////////////////////////////
1754
1836
 
1755
- // GELU activation
1756
- cur = wsp_ggml_gelu(ctx0, cur);
1837
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1838
+ // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1839
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1840
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1841
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1842
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1757
1843
 
1758
- wstate.use_buf(ctx0, 1);
1844
+ wsp_ggml_free(ctx0);
1759
1845
 
1760
- // projection
1761
- cur = wsp_ggml_mul_mat(ctx0,
1762
- layer.mlp_1_w,
1763
- cur);
1846
+ return gf;
1847
+ }
1764
1848
 
1765
- wstate.use_buf(ctx0, 0);
1849
+ // pre-compute cross-attention memory
1850
+ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
1851
+ whisper_context & wctx,
1852
+ whisper_state & wstate) {
1853
+ const auto & model = wctx.model;
1854
+ const auto & hparams = model.hparams;
1766
1855
 
1767
- cur = wsp_ggml_add(ctx0,
1768
- wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur),
1769
- cur);
1770
- #endif
1771
- }
1856
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1857
+ const int n_state = hparams.n_audio_state;
1858
+ const int n_head = hparams.n_audio_head;
1772
1859
 
1773
- wstate.use_buf(ctx0, 3);
1860
+ struct wsp_ggml_init_params params = {
1861
+ /*.mem_size =*/ wstate.alloc_cross.meta.size(),
1862
+ /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
1863
+ /*.no_alloc =*/ true,
1864
+ };
1774
1865
 
1775
- inpL = wsp_ggml_add(ctx0, cur, inpFF);
1776
- }
1866
+ struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1777
1867
 
1778
- cur = inpL;
1868
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1779
1869
 
1780
- // norm
1781
- {
1782
- wstate.use_buf(ctx0, 0);
1870
+ wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
1783
1871
 
1784
- cur = wsp_ggml_norm(ctx0, cur);
1872
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
1785
1873
 
1786
- wstate.use_buf(ctx0, 1);
1874
+ struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1875
+ wsp_ggml_allocr_alloc(alloc, Kscale);
1787
1876
 
1788
- // cur = ln_f_g*cur + ln_f_b
1789
- cur = wsp_ggml_add(ctx0,
1790
- wsp_ggml_mul(ctx0,
1791
- wsp_ggml_repeat(ctx0, model.e_ln_w, cur),
1792
- cur),
1793
- wsp_ggml_repeat(ctx0, model.e_ln_b, cur));
1794
- }
1877
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
1878
+ wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
1879
+ }
1795
1880
 
1796
- wstate.use_buf(ctx0, -1);
1881
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1882
+ auto & layer = model.layers_decoder[il];
1797
1883
 
1798
- // run the computation
1799
- {
1800
- struct wsp_ggml_cgraph gf = {};
1801
- gf.n_threads = n_threads;
1884
+ struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
1885
+ layer.cross_attn_k_w,
1886
+ cur);
1802
1887
 
1803
- wsp_ggml_build_forward_expand(&gf, cur);
1804
- wsp_ggml_graph_compute(ctx0, &gf);
1888
+ Kcross = wsp_ggml_scale(ctx0, Kcross, Kscale);
1805
1889
 
1806
- //wsp_ggml_graph_print(&gf);
1807
- }
1808
- }
1809
- #ifdef WHISPER_USE_COREML
1810
- else if (use_coreml) {
1811
- wstate.use_buf(ctx0, -1);
1890
+ struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
1891
+ layer.cross_attn_v_w,
1892
+ cur);
1812
1893
 
1813
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1894
+ Vcross = wsp_ggml_add(ctx0,
1895
+ Vcross,
1896
+ layer.cross_attn_v_b);
1814
1897
 
1815
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1816
- }
1817
- #endif
1818
- #ifdef WHISPER_USE_OPENVINO
1819
- else if (use_openvino) {
1820
- wstate.use_buf(ctx0, -1);
1898
+ Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1821
1899
 
1822
- cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
1900
+ struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k,
1901
+ n_state*n_ctx,
1902
+ (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1823
1903
 
1824
- if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
1825
- return false;
1826
- }
1904
+ struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1905
+ ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
1906
+ (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
1907
+
1908
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcross, k));
1909
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcross, v));
1827
1910
  }
1828
- #endif
1829
1911
 
1830
- // cur
1831
- //{
1832
- // printf("ne0 = %d\n", cur->ne[0]);
1833
- // printf("ne1 = %d\n", cur->ne[1]);
1834
- // for (int i = 0; i < 10; ++i) {
1835
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1836
- // }
1837
- // printf("... ");
1838
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1839
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1840
- // }
1841
- // printf("\n");
1842
- //}
1912
+ //wsp_ggml_graph_print(gf);
1843
1913
 
1844
- // pre-compute cross-attention memory
1845
- {
1846
- struct wsp_ggml_cgraph gf = {};
1847
- gf.n_threads = n_threads;
1914
+ wsp_ggml_free(ctx0);
1848
1915
 
1849
- // TODO: hack to disconnect the encoded features from the previous graph
1850
- cur->op = WSP_GGML_OP_NONE;
1851
- cur->src0 = nullptr;
1852
- cur->src1 = nullptr;
1916
+ return gf;
1917
+ }
1853
1918
 
1854
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1855
- auto& layer = model.layers_decoder[il];
1919
+ // evaluate the encoder with the given state
1920
+ //
1921
+ // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1922
+ // part of the transformer model and returns the encoded features
1923
+ //
1924
+ // - wctx: the model
1925
+ // - wstate: the state of the encoder
1926
+ // - n_threads: number of threads to use
1927
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1928
+ //
1929
+ static bool whisper_encode_internal(
1930
+ whisper_context & wctx,
1931
+ whisper_state & wstate,
1932
+ const int mel_offset,
1933
+ const int n_threads,
1934
+ whisper_abort_callback abort_callback,
1935
+ void * abort_callback_data) {
1936
+ const int64_t t_start_us = wsp_ggml_time_us();
1856
1937
 
1857
- wstate.use_buf(ctx0, 0);
1938
+ // conv
1939
+ {
1940
+ auto & alloc = wstate.alloc_conv.alloc;
1858
1941
 
1859
- struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0,
1860
- layer.cross_attn_k_w,
1861
- cur);
1942
+ wsp_ggml_allocr_reset(alloc);
1862
1943
 
1863
- Kcross = wsp_ggml_scale_inplace(ctx0, Kcross, wsp_ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
1944
+ wsp_ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
1864
1945
 
1865
- wstate.use_buf(ctx0, 1);
1946
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1866
1947
 
1867
- struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0,
1868
- layer.cross_attn_v_w,
1869
- cur);
1948
+ if (!whisper_encode_external(wstate)) {
1949
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1950
+ }
1951
+ }
1870
1952
 
1871
- Vcross = wsp_ggml_add(ctx0,
1872
- wsp_ggml_repeat(ctx0,
1873
- layer.cross_attn_v_b,
1874
- Vcross),
1875
- Vcross);
1953
+ // encoder
1954
+ if (!whisper_encode_external(wstate)) {
1955
+ auto & alloc = wstate.alloc_encode.alloc;
1876
1956
 
1877
- wstate.use_buf(ctx0, -1);
1957
+ wsp_ggml_allocr_reset(alloc);
1878
1958
 
1879
- Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
1959
+ wsp_ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
1880
1960
 
1881
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1882
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
1883
- ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v),
1884
- (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
1961
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1885
1962
 
1886
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcross, k));
1887
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcross, v));
1963
+ #ifdef WSP_GGML_USE_METAL
1964
+ if (wstate.ctx_metal) {
1965
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1966
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1967
+ } else {
1968
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1888
1969
  }
1889
-
1890
- wsp_ggml_graph_compute(ctx0, &gf);
1891
- //wsp_ggml_graph_print(&gf);
1970
+ #else
1971
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1972
+ #endif
1892
1973
  }
1893
1974
 
1894
- ////////////////////////////////////////////////////////////////////////////
1975
+ // cross
1976
+ {
1977
+ auto & alloc = wstate.alloc_cross.alloc;
1895
1978
 
1896
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1897
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1898
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1899
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1900
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1901
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
1979
+ wsp_ggml_allocr_reset(alloc);
1902
1980
 
1903
- wsp_ggml_free(ctx0);
1981
+ wsp_ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
1982
+
1983
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
1984
+
1985
+ #ifdef WSP_GGML_USE_METAL
1986
+ if (wstate.ctx_metal) {
1987
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1988
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1989
+ } else {
1990
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1991
+ }
1992
+ #else
1993
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1994
+ #endif
1995
+ }
1996
+
1997
+ // wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
1904
1998
 
1905
1999
  wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
1906
2000
  wstate.n_encode++;
@@ -1908,26 +2002,13 @@ static bool whisper_encode_internal(
1908
2002
  return true;
1909
2003
  }
1910
2004
 
1911
- // evaluate the decoder
1912
- //
1913
- // given text prompt + audio features -> computes the logits for the next token
1914
- //
1915
- // - model: the model
1916
- // - n_threads: number of threads to use
1917
- // - tokens: text prompt
1918
- // - n_tokens: number of tokens in the prompt
1919
- // - n_past: number of past tokens to prefix the prompt with
1920
- //
1921
- static bool whisper_decode_internal(
1922
- whisper_context & wctx,
1923
- whisper_state & wstate,
1924
- whisper_decoder & decoder,
1925
- const whisper_token * tokens,
1926
- const int n_tokens,
1927
- const int n_past,
1928
- const int n_threads) {
1929
- const int64_t t_start_us = wsp_ggml_time_us();
1930
-
2005
+ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2006
+ whisper_context & wctx,
2007
+ whisper_state & wstate,
2008
+ whisper_decoder & decoder,
2009
+ const whisper_token * tokens,
2010
+ int n_tokens,
2011
+ int n_past) {
1931
2012
  const auto & model = wctx.model;
1932
2013
  const auto & hparams = model.hparams;
1933
2014
 
@@ -1935,10 +2016,6 @@ static bool whisper_decode_internal(
1935
2016
 
1936
2017
  WHISPER_ASSERT(!!kv_self.ctx);
1937
2018
 
1938
- auto & logits_out = wstate.logits;
1939
-
1940
- const int n_vocab = hparams.n_vocab;
1941
-
1942
2019
  const int n_ctx = hparams.n_text_ctx;
1943
2020
  const int n_state = hparams.n_text_state;
1944
2021
  const int n_head = hparams.n_text_head;
@@ -1950,25 +2027,39 @@ static bool whisper_decode_internal(
1950
2027
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1951
2028
 
1952
2029
  struct wsp_ggml_init_params params = {
1953
- /*.mem_size =*/ wstate.buf_compute.size(),
1954
- /*.mem_buffer =*/ wstate.buf_compute.data(),
1955
- /*.no_alloc =*/ false,
2030
+ /*.mem_size =*/ wstate.alloc_decode.meta.size(),
2031
+ /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
2032
+ /*.no_alloc =*/ true,
1956
2033
  };
1957
2034
 
1958
2035
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1959
2036
 
1960
- struct wsp_ggml_cgraph gf = {};
1961
- gf.n_threads = n_threads;
2037
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
2038
+
2039
+ wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
1962
2040
 
1963
2041
  struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
1964
- memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2042
+ wsp_ggml_allocr_alloc(alloc, embd);
2043
+
2044
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2045
+ memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2046
+ }
1965
2047
 
1966
2048
  struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
1967
- for (int i = 0; i < N; ++i) {
1968
- ((int32_t *) position->data)[i] = n_past + i;
2049
+ wsp_ggml_allocr_alloc(alloc, position);
2050
+
2051
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2052
+ for (int i = 0; i < N; ++i) {
2053
+ ((int32_t *) position->data)[i] = n_past + i;
2054
+ }
1969
2055
  }
1970
2056
 
1971
- wstate.use_buf(ctx0, 3);
2057
+ struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
2058
+ wsp_ggml_allocr_alloc(alloc, KQscale);
2059
+
2060
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2061
+ wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
2062
+ }
1972
2063
 
1973
2064
  // token encoding + position encoding
1974
2065
  struct wsp_ggml_tensor * cur =
@@ -1983,16 +2074,14 @@ static bool whisper_decode_internal(
1983
2074
 
1984
2075
  // norm
1985
2076
  {
1986
- wstate.use_buf(ctx0, 0);
1987
-
1988
- cur = wsp_ggml_norm(ctx0, inpL);
2077
+ cur = wsp_ggml_norm(ctx0, inpL, hparams.eps);
1989
2078
 
1990
2079
  // cur = ln_0_w*cur + ln_0_b
1991
2080
  cur = wsp_ggml_add(ctx0,
1992
2081
  wsp_ggml_mul(ctx0,
1993
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1994
- cur),
1995
- wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
2082
+ cur,
2083
+ layer.attn_ln_0_w),
2084
+ layer.attn_ln_0_b);
1996
2085
  }
1997
2086
 
1998
2087
  // self-attention
@@ -2002,19 +2091,17 @@ static bool whisper_decode_internal(
2002
2091
  cur);
2003
2092
 
2004
2093
  Qcur = wsp_ggml_add(ctx0,
2005
- wsp_ggml_repeat(ctx0,
2006
- layer.attn_q_b,
2007
- Qcur),
2008
- Qcur);
2094
+ Qcur,
2095
+ layer.attn_q_b);
2009
2096
 
2010
- Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2097
+ Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2011
2098
 
2012
2099
  // note: no bias for Key
2013
2100
  struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0,
2014
2101
  layer.attn_k_w,
2015
2102
  cur);
2016
2103
 
2017
- Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2104
+ Kcur = wsp_ggml_scale(ctx0, Kcur, KQscale);
2018
2105
 
2019
2106
  // store key and value to memory
2020
2107
  {
@@ -2023,10 +2110,8 @@ static bool whisper_decode_internal(
2023
2110
  cur);
2024
2111
 
2025
2112
  Vcur = wsp_ggml_add(ctx0,
2026
- wsp_ggml_repeat(ctx0,
2027
- layer.attn_v_b,
2028
- Vcur),
2029
- Vcur);
2113
+ Vcur,
2114
+ layer.attn_v_b);
2030
2115
 
2031
2116
  Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
2032
2117
 
@@ -2035,42 +2120,32 @@ static bool whisper_decode_internal(
2035
2120
  ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2036
2121
  (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
2037
2122
 
2038
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcur, k));
2039
- wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcur, v));
2123
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2124
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
2040
2125
  }
2041
2126
 
2042
2127
  // ------
2043
2128
 
2044
- wstate.use_buf(ctx0, 0);
2045
-
2046
2129
  struct wsp_ggml_tensor * Q =
2047
2130
  wsp_ggml_permute(ctx0,
2048
- wsp_ggml_cpy(ctx0,
2049
- Qcur,
2050
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
2131
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2051
2132
  0, 2, 1, 3);
2052
2133
 
2053
2134
  struct wsp_ggml_tensor * K =
2054
- wsp_ggml_permute(ctx0,
2055
- wsp_ggml_reshape_3d(ctx0,
2056
- wsp_ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*wsp_ggml_element_size(kv_self.k)*n_state),
2057
- n_state/n_head, n_head, n_past + N),
2058
- 0, 2, 1, 3);
2059
-
2060
- wstate.use_buf(ctx0, 1);
2135
+ wsp_ggml_view_3d(ctx0, kv_self.k,
2136
+ n_state/n_head, n_past + N, n_head,
2137
+ wsp_ggml_element_size(kv_self.k)*n_state,
2138
+ wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2139
+ wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2061
2140
 
2062
2141
  // K * Q
2063
2142
  struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2064
2143
 
2065
- //struct wsp_ggml_tensor * KQ_scaled =
2066
- // wsp_ggml_scale_inplace(ctx0,
2067
- // KQ,
2068
- // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2069
- // );
2144
+ //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2070
2145
 
2071
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ, n_past);
2146
+ struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2072
2147
 
2073
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_masked);
2148
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2074
2149
 
2075
2150
  struct wsp_ggml_tensor * V =
2076
2151
  wsp_ggml_view_3d(ctx0, kv_self.v,
@@ -2090,36 +2165,28 @@ static bool whisper_decode_internal(
2090
2165
 
2091
2166
  // projection
2092
2167
  {
2093
- wstate.use_buf(ctx0, 0);
2094
-
2095
2168
  cur = wsp_ggml_mul_mat(ctx0,
2096
2169
  layer.attn_ln_1_w,
2097
2170
  cur);
2098
2171
 
2099
- wstate.use_buf(ctx0, 1);
2100
-
2101
2172
  cur = wsp_ggml_add(ctx0,
2102
- wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
2103
- cur);
2173
+ cur,
2174
+ layer.attn_ln_1_b);
2104
2175
  }
2105
2176
 
2106
- wstate.use_buf(ctx0, 2);
2107
-
2108
2177
  // add the input
2109
2178
  struct wsp_ggml_tensor * inpCA = wsp_ggml_add(ctx0, cur, inpL);
2110
2179
 
2111
2180
  // norm
2112
2181
  {
2113
- wstate.use_buf(ctx0, 0);
2114
-
2115
- cur = wsp_ggml_norm(ctx0, inpCA); // note: we use inpCA here
2182
+ cur = wsp_ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
2116
2183
 
2117
2184
  // cur = ln_0_w*cur + ln_0_b
2118
2185
  cur = wsp_ggml_add(ctx0,
2119
2186
  wsp_ggml_mul(ctx0,
2120
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
2121
- cur),
2122
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
2187
+ cur,
2188
+ layer.cross_attn_ln_0_w),
2189
+ layer.cross_attn_ln_0_b);
2123
2190
  }
2124
2191
 
2125
2192
  // cross-attention
@@ -2129,18 +2196,18 @@ static bool whisper_decode_internal(
2129
2196
  cur);
2130
2197
 
2131
2198
  Qcur = wsp_ggml_add(ctx0,
2132
- wsp_ggml_repeat(ctx0,
2133
- layer.cross_attn_q_b,
2134
- Qcur),
2135
- Qcur);
2199
+ Qcur,
2200
+ layer.cross_attn_q_b);
2136
2201
 
2137
- Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
2202
+ Qcur = wsp_ggml_scale(ctx0, Qcur, KQscale);
2138
2203
 
2139
2204
  // Kcross is already scaled
2140
2205
  struct wsp_ggml_tensor * Kcross =
2141
- wsp_ggml_reshape_3d(ctx0,
2142
- wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.k)*n_state),
2143
- n_state/n_head, n_head, M);
2206
+ wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2207
+ n_state/n_head, M, n_head,
2208
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2209
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2210
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2144
2211
 
2145
2212
  //struct wsp_ggml_tensor * Vcross =
2146
2213
  // wsp_ggml_reshape_3d(ctx0,
@@ -2163,26 +2230,22 @@ static bool whisper_decode_internal(
2163
2230
 
2164
2231
  struct wsp_ggml_tensor * Q =
2165
2232
  wsp_ggml_permute(ctx0,
2166
- wsp_ggml_cpy(ctx0,
2167
- Qcur,
2168
- wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)),
2233
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2169
2234
  0, 2, 1, 3);
2170
2235
 
2171
- struct wsp_ggml_tensor * K = wsp_ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2172
-
2173
2236
  // K * Q
2174
- struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
2237
+ struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, Kcross, Q);
2175
2238
 
2176
2239
  //struct wsp_ggml_tensor * KQ_scaled =
2177
- // wsp_ggml_scale_inplace(ctx0,
2240
+ // wsp_ggml_scale(ctx0,
2178
2241
  // KQ,
2179
2242
  // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2180
2243
  // );
2181
2244
 
2182
2245
  // no masking for cross-attention
2183
- //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2246
+ //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2184
2247
 
2185
- struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ);
2248
+ struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ);
2186
2249
 
2187
2250
  struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2188
2251
 
@@ -2196,21 +2259,15 @@ static bool whisper_decode_internal(
2196
2259
 
2197
2260
  // projection
2198
2261
  {
2199
- wstate.use_buf(ctx0, 0);
2200
-
2201
2262
  cur = wsp_ggml_mul_mat(ctx0,
2202
2263
  layer.cross_attn_ln_1_w,
2203
2264
  cur);
2204
2265
 
2205
- wstate.use_buf(ctx0, 1);
2206
-
2207
2266
  cur = wsp_ggml_add(ctx0,
2208
- wsp_ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2209
- cur);
2267
+ cur,
2268
+ layer.cross_attn_ln_1_b);
2210
2269
  }
2211
2270
 
2212
- wstate.use_buf(ctx0, 2);
2213
-
2214
2271
  // add the input
2215
2272
  cur = wsp_ggml_add(ctx0, cur, inpCA);
2216
2273
 
@@ -2220,54 +2277,38 @@ static bool whisper_decode_internal(
2220
2277
  {
2221
2278
  // norm
2222
2279
  {
2223
- wstate.use_buf(ctx0, 0);
2224
-
2225
- cur = wsp_ggml_norm(ctx0, inpFF);
2226
-
2227
- wstate.use_buf(ctx0, 1);
2280
+ cur = wsp_ggml_norm(ctx0, inpFF, hparams.eps);
2228
2281
 
2229
2282
  // cur = mlp_ln_w*cur + mlp_ln_b
2230
2283
  cur = wsp_ggml_add(ctx0,
2231
2284
  wsp_ggml_mul(ctx0,
2232
- wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur),
2233
- cur),
2234
- wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2285
+ cur,
2286
+ layer.mlp_ln_w),
2287
+ layer.mlp_ln_b);
2235
2288
  }
2236
2289
 
2237
- wstate.use_buf(ctx0, 0);
2238
-
2239
2290
  // fully connected
2240
2291
  cur = wsp_ggml_mul_mat(ctx0,
2241
2292
  layer.mlp_0_w,
2242
2293
  cur);
2243
2294
 
2244
- wstate.use_buf(ctx0, 1);
2245
-
2246
2295
  cur = wsp_ggml_add(ctx0,
2247
- wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur),
2248
- cur);
2249
-
2250
- wstate.use_buf(ctx0, 0);
2296
+ cur,
2297
+ layer.mlp_0_b);
2251
2298
 
2252
2299
  // GELU activation
2253
2300
  cur = wsp_ggml_gelu(ctx0, cur);
2254
2301
 
2255
- wstate.use_buf(ctx0, 1);
2256
-
2257
2302
  // projection
2258
2303
  cur = wsp_ggml_mul_mat(ctx0,
2259
2304
  layer.mlp_1_w,
2260
2305
  cur);
2261
2306
 
2262
- wstate.use_buf(ctx0, 0);
2263
-
2264
2307
  cur = wsp_ggml_add(ctx0,
2265
- wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur),
2266
- cur);
2308
+ cur,
2309
+ layer.mlp_1_b);
2267
2310
  }
2268
2311
 
2269
- wstate.use_buf(ctx0, 3);
2270
-
2271
2312
  inpL = wsp_ggml_add(ctx0, cur, inpFF);
2272
2313
  }
2273
2314
 
@@ -2275,21 +2316,15 @@ static bool whisper_decode_internal(
2275
2316
 
2276
2317
  // norm
2277
2318
  {
2278
- wstate.use_buf(ctx0, 0);
2279
-
2280
- cur = wsp_ggml_norm(ctx0, cur);
2281
-
2282
- wstate.use_buf(ctx0, 1);
2319
+ cur = wsp_ggml_norm(ctx0, cur, hparams.eps);
2283
2320
 
2284
2321
  cur = wsp_ggml_add(ctx0,
2285
2322
  wsp_ggml_mul(ctx0,
2286
- wsp_ggml_repeat(ctx0, model.d_ln_w, cur),
2287
- cur),
2288
- wsp_ggml_repeat(ctx0, model.d_ln_b, cur));
2323
+ cur,
2324
+ model.d_ln_w),
2325
+ model.d_ln_b);
2289
2326
  }
2290
2327
 
2291
- wstate.use_buf(ctx0, 0);
2292
-
2293
2328
  // compute logits only for the last token
2294
2329
  // comment this line to compute logits for all N tokens
2295
2330
  // might be useful in the future
@@ -2297,23 +2332,77 @@ static bool whisper_decode_internal(
2297
2332
 
2298
2333
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2299
2334
 
2300
- wstate.use_buf(ctx0, -1);
2335
+ wsp_ggml_build_forward_expand(gf, logits);
2336
+
2337
+ wsp_ggml_free(ctx0);
2338
+
2339
+ return gf;
2340
+ }
2341
+
2342
+ // evaluate the decoder
2343
+ //
2344
+ // given text prompt + audio features -> computes the logits for the next token
2345
+ //
2346
+ // - model: the model
2347
+ // - n_threads: number of threads to use
2348
+ // - tokens: text prompt
2349
+ // - n_tokens: number of tokens in the prompt
2350
+ // - n_past: number of past tokens to prefix the prompt with
2351
+ //
2352
+ static bool whisper_decode_internal(
2353
+ whisper_context & wctx,
2354
+ whisper_state & wstate,
2355
+ whisper_decoder & decoder,
2356
+ const whisper_token * tokens,
2357
+ const int n_tokens,
2358
+ const int n_past,
2359
+ const int n_threads,
2360
+ whisper_abort_callback abort_callback,
2361
+ void * abort_callback_data) {
2362
+ const int64_t t_start_us = wsp_ggml_time_us();
2363
+
2364
+ const auto & model = wctx.model;
2365
+ const auto & hparams = model.hparams;
2366
+
2367
+ const int n_vocab = hparams.n_vocab;
2368
+
2369
+ auto & logits_out = wstate.logits;
2370
+
2371
+ struct wsp_ggml_tensor * logits;
2301
2372
 
2302
- // run the computation
2373
+ // decoder
2303
2374
  {
2304
- wsp_ggml_build_forward_expand(&gf, logits);
2305
- wsp_ggml_graph_compute (ctx0, &gf);
2375
+ auto & alloc = wstate.alloc_decode.alloc;
2376
+
2377
+ wsp_ggml_allocr_reset(alloc);
2378
+
2379
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2380
+
2381
+ wsp_ggml_allocr_alloc_graph(alloc, gf);
2382
+
2383
+ logits = gf->nodes[gf->n_nodes - 1];
2384
+
2385
+ #ifdef WSP_GGML_USE_METAL
2386
+ if (wstate.ctx_metal) {
2387
+ wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2388
+ wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
2389
+ } else {
2390
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2391
+ }
2392
+ #else
2393
+ wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2394
+ #endif
2306
2395
  }
2307
2396
 
2308
2397
  // extract logits for all N tokens
2309
- //logits_out.resize(N*n_vocab);
2310
- //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*N*n_vocab);
2398
+ //logits_out.resize(n_tokens*n_vocab);
2399
+ //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2311
2400
 
2312
2401
  // extract logits only for the last token
2313
2402
  logits_out.resize(n_vocab);
2314
2403
  memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
2315
2404
 
2316
- if (N > 1) {
2405
+ if (n_tokens > 1) {
2317
2406
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2318
2407
  // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
2319
2408
  // wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2322,14 +2411,18 @@ static bool whisper_decode_internal(
2322
2411
  // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2323
2412
  }
2324
2413
 
2325
- wsp_ggml_free(ctx0);
2326
-
2327
- wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2328
- wstate.n_decode++;
2414
+ if (n_tokens == 1) {
2415
+ wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2416
+ wstate.n_decode++;
2417
+ } else {
2418
+ wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
2419
+ wstate.n_prompt++;
2420
+ }
2329
2421
 
2330
2422
  return true;
2331
2423
  }
2332
2424
 
2425
+
2333
2426
  // 500 -> 00:05.000
2334
2427
  // 6000 -> 01:00.000
2335
2428
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2351,7 +2444,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2351
2444
  static float sin_vals[SIN_COS_N_COUNT];
2352
2445
  static float cos_vals[SIN_COS_N_COUNT];
2353
2446
 
2354
- // In FFT, we frequently use sine and cosine operations with the same values.
2447
+ // In FFT, we frequently use sine and cosine operations with the same values.
2355
2448
  // We can use precalculated values to speed up the process.
2356
2449
  static void fill_sin_cos_table() {
2357
2450
  static bool is_filled = false;
@@ -2446,7 +2539,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2446
2539
  }
2447
2540
 
2448
2541
  static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2449
- if (output.size() < length) {
2542
+ if (output.size() < static_cast<size_t>(length)) {
2450
2543
  output.resize(length);
2451
2544
  }
2452
2545
  int offset = -1;
@@ -2738,9 +2831,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2738
2831
  fill_sin_cos_table();
2739
2832
  whisper_state * state = new whisper_state;
2740
2833
 
2741
- const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
2742
-
2743
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2834
+ if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2744
2835
  log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2745
2836
  delete state;
2746
2837
  return nullptr;
@@ -2751,7 +2842,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2751
2842
  log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2752
2843
  }
2753
2844
 
2754
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2845
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2755
2846
  log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2756
2847
  delete state;
2757
2848
  return nullptr;
@@ -2772,6 +2863,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2772
2863
  if (!state->ctx_coreml) {
2773
2864
  log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2774
2865
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2866
+ delete state;
2775
2867
  return nullptr;
2776
2868
  #endif
2777
2869
  } else {
@@ -2786,15 +2878,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2786
2878
  // TAGS: WHISPER_DECODER_INIT
2787
2879
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2788
2880
 
2789
- state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
2790
- state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
2881
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2882
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2791
2883
  state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
2792
- state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
2793
2884
 
2794
- state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
2795
- state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
2796
- state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
2797
- state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
2885
+ // conv allocator
2886
+ {
2887
+ whisper_allocr_graph_init(state->alloc_conv,
2888
+ [&]() {
2889
+ return whisper_build_graph_conv(*ctx, *state, 0);
2890
+ });
2891
+
2892
+ log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
2893
+ }
2894
+
2895
+ // encoder allocator
2896
+ if (!whisper_encode_external(*state)) {
2897
+ whisper_allocr_graph_init(state->alloc_encode,
2898
+ [&]() {
2899
+ return whisper_build_graph_encoder(*ctx, *state);
2900
+ });
2901
+
2902
+ log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
2903
+ }
2904
+
2905
+ // cross allocator
2906
+ {
2907
+ whisper_allocr_graph_init(state->alloc_cross,
2908
+ [&]() {
2909
+ return whisper_build_graph_cross(*ctx, *state);
2910
+ });
2911
+
2912
+ log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
2913
+ }
2914
+
2915
+ // decoder allocator
2916
+ {
2917
+ whisper_allocr_graph_init(state->alloc_decode,
2918
+ [&]() {
2919
+ const auto & hparams = ctx->model.hparams;
2920
+
2921
+ // TODO: make sure this is the worst-case scenario
2922
+ const int n_tokens = hparams.n_text_ctx;
2923
+ const int n_past = 0;
2924
+
2925
+ return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2926
+ });
2927
+
2928
+ log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2929
+ }
2930
+
2931
+ #ifdef WSP_GGML_USE_METAL
2932
+ state->ctx_metal = wsp_ggml_metal_init(1);
2933
+ if (!state->ctx_metal) {
2934
+ log("%s: wsp_ggml_metal_init() failed\n", __func__);
2935
+ delete state;
2936
+ return nullptr;
2937
+ }
2938
+
2939
+ log("%s: Metal context initialized\n", __func__);
2940
+
2941
+ // this allocates all Metal resources and memory buffers
2942
+
2943
+ void * data_ptr = NULL;
2944
+ size_t data_size = 0;
2945
+
2946
+ // TODO: add mmap support
2947
+ //if (params.use_mmap) {
2948
+ // data_ptr = ctx->model.mapping->addr;
2949
+ // data_size = ctx->model.mapping->size;
2950
+ //} else {
2951
+ // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2952
+ // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2953
+ //}
2954
+
2955
+ data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2956
+ data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2957
+
2958
+ const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
2959
+
2960
+ log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2961
+
2962
+ #define WHISPER_METAL_CHECK_BUF(result) \
2963
+ if (!(result)) { \
2964
+ log("%s: failed to add metal buffer\n", __func__); \
2965
+ delete state; \
2966
+ return nullptr; \
2967
+ }
2968
+
2969
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
2970
+
2971
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2972
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2973
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2974
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
2975
+
2976
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2977
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2978
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2979
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
2980
+
2981
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
2982
+
2983
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
2984
+ #undef WHISPER_METAL_CHECK_BUF
2985
+ #endif
2798
2986
 
2799
2987
  state->rng = std::mt19937(0);
2800
2988
 
@@ -2851,7 +3039,6 @@ int whisper_ctx_init_openvino_encoder(
2851
3039
  }
2852
3040
 
2853
3041
  struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2854
-
2855
3042
  log("%s: loading model from '%s'\n", __func__, path_model);
2856
3043
 
2857
3044
  auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3004,6 +3191,13 @@ void whisper_free_state(struct whisper_state * state)
3004
3191
  }
3005
3192
  #endif
3006
3193
 
3194
+ #ifdef WSP_GGML_USE_METAL
3195
+ if (state->ctx_metal) {
3196
+ wsp_ggml_metal_free(state->ctx_metal);
3197
+ state->ctx_metal = nullptr;
3198
+ }
3199
+ #endif
3200
+
3007
3201
  #ifdef WHISPER_USE_OPENVINO
3008
3202
  if (state->ctx_openvino != nullptr) {
3009
3203
  whisper_openvino_free(state->ctx_openvino);
@@ -3011,6 +3205,11 @@ void whisper_free_state(struct whisper_state * state)
3011
3205
  }
3012
3206
  #endif
3013
3207
 
3208
+ whisper_allocr_free(state->alloc_conv);
3209
+ whisper_allocr_free(state->alloc_decode);
3210
+ whisper_allocr_free(state->alloc_cross);
3211
+ whisper_allocr_free(state->alloc_encode);
3212
+
3014
3213
  delete state;
3015
3214
  }
3016
3215
  }
@@ -3103,7 +3302,7 @@ int whisper_set_mel(
3103
3302
  }
3104
3303
 
3105
3304
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3106
- if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
3305
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3107
3306
  log("%s: failed to eval\n", __func__);
3108
3307
  return -1;
3109
3308
  }
@@ -3112,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3112
3311
  }
3113
3312
 
3114
3313
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3115
- if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
3314
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3116
3315
  log("%s: failed to eval\n", __func__);
3117
3316
  return -1;
3118
3317
  }
@@ -3123,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3123
3322
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3124
3323
  const int selected_decoder_id = 0;
3125
3324
 
3126
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3325
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3127
3326
  log("%s: failed to eval\n", __func__);
3128
3327
  return 1;
3129
3328
  }
@@ -3140,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
3140
3339
  return false;
3141
3340
  }
3142
3341
 
3143
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3342
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3144
3343
  log("%s: failed to eval\n", __func__);
3145
3344
  return 1;
3146
3345
  }
@@ -3431,12 +3630,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
3431
3630
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3432
3631
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3433
3632
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3633
+ const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3434
3634
 
3435
3635
  log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3436
3636
  log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3437
3637
  log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3438
3638
  log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3439
3639
  log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3640
+ log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3440
3641
  }
3441
3642
  log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3442
3643
  }
@@ -3446,6 +3647,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3446
3647
  ctx->state->t_sample_us = 0;
3447
3648
  ctx->state->t_encode_us = 0;
3448
3649
  ctx->state->t_decode_us = 0;
3650
+ ctx->state->t_prompt_us = 0;
3651
+ ctx->state->n_sample = 0;
3652
+ ctx->state->n_encode = 0;
3653
+ ctx->state->n_decode = 0;
3654
+ ctx->state->n_prompt = 0;
3449
3655
  }
3450
3656
  }
3451
3657
 
@@ -3475,6 +3681,7 @@ const char * whisper_print_system_info(void) {
3475
3681
  s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
3476
3682
  s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
3477
3683
  s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
3684
+ s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
3478
3685
  s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
3479
3686
  s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
3480
3687
  s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
@@ -3566,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3566
3773
  /*.encoder_begin_callback =*/ nullptr,
3567
3774
  /*.encoder_begin_callback_user_data =*/ nullptr,
3568
3775
 
3776
+ /*.abort_callback =*/ nullptr,
3777
+ /*.abort_callback_user_data =*/ nullptr,
3778
+
3569
3779
  /*.logits_filter_callback =*/ nullptr,
3570
3780
  /*.logits_filter_callback_user_data =*/ nullptr,
3571
3781
  };
@@ -3970,17 +4180,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
3970
4180
 
3971
4181
  auto & logits_id = state.logits_id;
3972
4182
 
3973
- logits_id.clear();
4183
+ logits_id.resize(n_logits);
3974
4184
  for (int i = 0; i < n_logits; ++i) {
3975
- logits_id.push_back({ logits[i], i });
4185
+ logits_id[i].first = logits[i];
4186
+ logits_id[i].second = i;
3976
4187
  }
3977
4188
 
3978
- std::partial_sort(
3979
- logits_id.begin(),
3980
- logits_id.begin() + k, logits_id.end(),
3981
- [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
3982
- return a.first > b.first;
3983
- });
4189
+ {
4190
+ using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
4191
+ std::partial_sort(
4192
+ logits_id.begin(),
4193
+ logits_id.begin() + k, logits_id.end(),
4194
+ [](const pair_type & a, const pair_type & b) {
4195
+ return a.first > b.first;
4196
+ });
4197
+ }
3984
4198
 
3985
4199
  std::vector<whisper_token_data> result;
3986
4200
  result.reserve(k);
@@ -4075,6 +4289,115 @@ static void whisper_sequence_score(
4075
4289
  }
4076
4290
  }
4077
4291
 
4292
+ static bool whisper_kv_swap_fast(
4293
+ std::vector<int> & view,
4294
+ whisper_decoder src[],
4295
+ std::vector<kv_buf> & kv_swap_bufs,
4296
+ const int & n_decoders) {
4297
+ WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
4298
+
4299
+ // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
4300
+ std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
4301
+
4302
+ // (buffer->decoder or decoder->decoder)
4303
+ std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
4304
+
4305
+ // (decoder<->decoder)
4306
+ std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
4307
+ std::vector<whisper_pair<int, int>> p_swap_vec;
4308
+ p_swap_vec.reserve(n_decoders);
4309
+
4310
+ // see https://github.com/ggerganov/whisper.cpp/wiki
4311
+ for (int i = 0; i < n_decoders; i++) {
4312
+ // zero-copy (no modification)
4313
+ if (i == view[i] || view[i] < 0) {
4314
+ continue;
4315
+ }
4316
+
4317
+ bool is_one_copy = true;
4318
+ // since we modify data sequentially, we only consider decoder indices after current index
4319
+ for (int j = i + 1; j < n_decoders; j++) {
4320
+ if (i == view[j]) {
4321
+ // detect symmetric diagram
4322
+ if (j == view[i]) {
4323
+ p_swap_set.insert(i);
4324
+ p_swap_set.insert(j);
4325
+ p_swap_vec.emplace_back(i, j);
4326
+ } else {
4327
+ two_copy.insert(i);
4328
+ is_one_copy = false;
4329
+ }
4330
+ break;
4331
+ }
4332
+ }
4333
+ if (is_one_copy) {
4334
+ one_copy.insert(i);
4335
+ }
4336
+ }
4337
+
4338
+ kv_swap_bufs.resize(n_decoders);
4339
+
4340
+ for (int i = 0; i < n_decoders; i++) {
4341
+ kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
4342
+ kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
4343
+ }
4344
+
4345
+ for (auto & i : two_copy) {
4346
+ // make a copy of KV caches
4347
+ WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4348
+ memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4349
+ memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
4350
+ }
4351
+
4352
+ // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
4353
+ for (auto & i : two_copy) {
4354
+ // skip the decoder indices that require pointer swapping
4355
+ if (p_swap_set.find(i) != p_swap_set.end()) {
4356
+ continue;
4357
+ }
4358
+
4359
+ if (two_copy.find(view[i]) != two_copy.end()) {
4360
+ // modify KV caches of decoder using data from kv_swap_bufs
4361
+ WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4362
+ memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4363
+ memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4364
+ } else {
4365
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4366
+ WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4367
+ memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
4368
+ memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
4369
+ }
4370
+ }
4371
+
4372
+ // then modify one-copy decoder KV caches
4373
+ for (auto & i : one_copy) {
4374
+ // skip the decoder indices that require pointer swapping
4375
+ if (p_swap_set.find(i) != p_swap_set.end()) {
4376
+ continue;
4377
+ }
4378
+
4379
+ if (two_copy.find(view[i]) != two_copy.end()) {
4380
+ // modify KV caches of decoder using data from kv_swap_bufs
4381
+ WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4382
+ memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4383
+ memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4384
+ } else {
4385
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4386
+ WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4387
+ memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
4388
+ memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
4389
+ }
4390
+ }
4391
+
4392
+ // swap the pointers
4393
+ for (auto & i : p_swap_vec) {
4394
+ WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
4395
+ std::swap(src[i.first].kv_self, src[i.second].kv_self);
4396
+ }
4397
+
4398
+ return true;
4399
+ }
4400
+
4078
4401
  int whisper_full_with_state(
4079
4402
  struct whisper_context * ctx,
4080
4403
  struct whisper_state * state,
@@ -4182,6 +4505,21 @@ int whisper_full_with_state(
4182
4505
  decoder.probs.resize (ctx->vocab.n_vocab);
4183
4506
  decoder.logits.resize (ctx->vocab.n_vocab);
4184
4507
  decoder.logprobs.resize(ctx->vocab.n_vocab);
4508
+
4509
+ // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
4510
+ #ifdef WSP_GGML_USE_METAL
4511
+ #define WHISPER_METAL_CHECK_BUF(result) \
4512
+ if (!(result)) { \
4513
+ log("%s: failed to add metal buffer\n", __func__); \
4514
+ return 0; \
4515
+ }
4516
+
4517
+ const std::string kv_name = "kv_self_" + std::to_string(j);
4518
+ auto & kv_self = decoder.kv_self;
4519
+
4520
+ WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
4521
+ #undef WHISPER_METAL_CHECK_BUF
4522
+ #endif
4185
4523
  }
4186
4524
  }
4187
4525
 
@@ -4197,7 +4535,7 @@ int whisper_full_with_state(
4197
4535
 
4198
4536
  // initial prompt
4199
4537
  if (!params.prompt_tokens && params.initial_prompt) {
4200
- prompt_tokens.resize(1024);
4538
+ prompt_tokens.resize(2048);
4201
4539
  prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
4202
4540
  params.prompt_tokens = prompt_tokens.data();
4203
4541
  params.prompt_n_tokens = prompt_tokens.size();
@@ -4238,14 +4576,6 @@ int whisper_full_with_state(
4238
4576
  std::vector<whisper_token> prompt;
4239
4577
  prompt.reserve(whisper_n_text_ctx(ctx));
4240
4578
 
4241
- // beam-search helpers
4242
- struct kv_buf {
4243
- std::vector<uint8_t> k;
4244
- std::vector<uint8_t> v;
4245
- };
4246
-
4247
- std::vector<kv_buf> kv_bufs;
4248
-
4249
4579
  struct beam_candidate {
4250
4580
  int decoder_idx;
4251
4581
  int seek_delta;
@@ -4279,7 +4609,7 @@ int whisper_full_with_state(
4279
4609
  }
4280
4610
 
4281
4611
  // encode audio features starting at offset seek
4282
- if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
4612
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4283
4613
  log("%s: failed to encode\n", __func__);
4284
4614
  return -6;
4285
4615
  }
@@ -4362,7 +4692,7 @@ int whisper_full_with_state(
4362
4692
  }
4363
4693
  WHISPER_PRINT_DEBUG("\n\n");
4364
4694
 
4365
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
4695
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4366
4696
  log("%s: failed to decode\n", __func__);
4367
4697
  return -7;
4368
4698
  }
@@ -4382,8 +4712,8 @@ int whisper_full_with_state(
4382
4712
 
4383
4713
  decoder.kv_self.n += prompt.size();
4384
4714
 
4385
- memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4386
- memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4715
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
4716
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
4387
4717
  memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
4388
4718
  }
4389
4719
 
@@ -4394,23 +4724,7 @@ int whisper_full_with_state(
4394
4724
  for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
4395
4725
  const int64_t t_start_sample_us = wsp_ggml_time_us();
4396
4726
 
4397
- // store the KV caches of all decoders when doing beam-search
4398
4727
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
4399
- kv_bufs.resize(n_decoders_cur);
4400
- for (int j = 0; j < n_decoders_cur; ++j) {
4401
- auto & decoder = state->decoders[j];
4402
-
4403
- if (decoder.completed || decoder.failed) {
4404
- continue;
4405
- }
4406
-
4407
- kv_bufs[j].k.resize(wsp_ggml_nbytes(decoder.kv_self.k));
4408
- kv_bufs[j].v.resize(wsp_ggml_nbytes(decoder.kv_self.v));
4409
-
4410
- memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
4411
- memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
4412
- }
4413
-
4414
4728
  beam_candidates.clear();
4415
4729
  }
4416
4730
 
@@ -4458,6 +4772,7 @@ int whisper_full_with_state(
4458
4772
  });
4459
4773
 
4460
4774
  uint32_t cur_c = 0;
4775
+ std::vector<int> decoder_idx(n_decoders_cur, -1);
4461
4776
 
4462
4777
  for (int j = 0; j < n_decoders_cur; ++j) {
4463
4778
  auto & decoder = state->decoders[j];
@@ -4476,12 +4791,13 @@ int whisper_full_with_state(
4476
4791
  decoder.seek_delta = cur.seek_delta;
4477
4792
  decoder.has_ts = cur.has_ts;
4478
4793
 
4479
- memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
4480
- memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
4481
-
4794
+ decoder_idx[j] = cur.decoder_idx;
4482
4795
  WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
4483
4796
  __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
4484
4797
  }
4798
+
4799
+ // update KV caches
4800
+ whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
4485
4801
  }
4486
4802
 
4487
4803
  // update the decoder state
@@ -4600,7 +4916,7 @@ int whisper_full_with_state(
4600
4916
 
4601
4917
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4602
4918
 
4603
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4919
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4604
4920
  log("%s: failed to decode\n", __func__);
4605
4921
  return -8;
4606
4922
  }
@@ -4910,6 +5226,12 @@ int whisper_full_parallel(
4910
5226
  ctx->state->t_sample_us += states[i]->t_sample_us;
4911
5227
  ctx->state->t_encode_us += states[i]->t_encode_us;
4912
5228
  ctx->state->t_decode_us += states[i]->t_decode_us;
5229
+ ctx->state->t_prompt_us += states[i]->t_prompt_us;
5230
+
5231
+ ctx->state->n_sample += states[i]->n_sample;
5232
+ ctx->state->n_encode += states[i]->n_encode;
5233
+ ctx->state->n_decode += states[i]->n_decode;
5234
+ ctx->state->n_prompt += states[i]->n_prompt;
4913
5235
 
4914
5236
  whisper_free_state(states[i]);
4915
5237
  }
@@ -4963,6 +5285,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
4963
5285
  return ctx->state->result_all[i_segment].t1;
4964
5286
  }
4965
5287
 
5288
+ bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
5289
+ return state->result_all[i_segment].speaker_turn_next;
5290
+ }
5291
+
4966
5292
  bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
4967
5293
  return ctx->state->result_all[i_segment].speaker_turn_next;
4968
5294
  }
@@ -5106,7 +5432,8 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
5106
5432
  // b: N*N*sizeof(float)
5107
5433
  // c: N*N*sizeof(float)
5108
5434
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5109
- std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512);
5435
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
5436
+ std::vector<uint8_t> work;
5110
5437
 
5111
5438
  // put a bunch of random data in the buffer
5112
5439
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -5158,17 +5485,15 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
5158
5485
 
5159
5486
  struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
5160
5487
 
5161
- gf.n_threads = n_threads;
5162
-
5163
5488
  double tsum = 0.0;
5164
5489
 
5165
5490
  // heat-up
5166
- wsp_ggml_graph_compute(ctx0, &gf);
5491
+ wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
5167
5492
 
5168
5493
  for (int i = 0; i < n_max; ++i) {
5169
5494
  const int64_t t0 = wsp_ggml_time_us();
5170
5495
 
5171
- wsp_ggml_graph_compute(ctx0, &gf);
5496
+ wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
5172
5497
 
5173
5498
  const int64_t t1 = wsp_ggml_time_us();
5174
5499