whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6

Files changed (49)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +188 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +8 -1
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +2444 -359
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +1105 -197
  21. package/cpp/ggml-quants.c +66 -61
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +1040 -1590
  24. package/cpp/ggml.h +109 -30
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +143 -59
  29. package/cpp/rn-whisper.h +48 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +68 -137
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/version.json +1 -1
  41. package/lib/typescript/index.d.ts +5 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/package.json +6 -5
  44. package/src/index.ts +5 -0
  45. package/src/version.json +1 -1
  46. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  47. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  48. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  49. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/whisper.cpp CHANGED
@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef WSP_GGML_USE_METAL
-# include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef WSP_GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
@@ -13,7 +18,9 @@
 
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ -97,10 +104,32 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (wsp_ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
            abort(); \
        } \
    } while (0)
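The level-aware WHISPER_LOG_* macros above replace the old free-standing log() helper and route everything through a user-settable callback (stored in g_state.log_callback further down). A minimal sketch of installing a custom sink from client code — this assumes the whisper_log_set() setter exposed by whisper.h in this release; verify the name against the shipped header:

    // Sketch (assumption: whisper.h declares whisper_log_set(wsp_ggml_log_callback, void *)).
    #include <cstdio>

    static void my_log_cb(wsp_ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        // `text` arrives already newline-terminated by the internal formatter
        fprintf(stderr, level == WSP_GGML_LOG_LEVEL_ERROR ? "[whisper:err] %s" : "[whisper] %s", text);
    }

    // during app start-up:
    //   whisper_log_set(my_log_cb, nullptr);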
@@ -119,7 +148,7 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
 //
@@ -127,8 +156,8 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
 //
 
 static void wsp_ggml_graph_compute_helper(
+        struct wsp_ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-        wsp_ggml_cgraph * graph,
         int n_threads,
         whisper_abort_callback abort_callback,
         void * abort_callback_data) {
@@ -145,6 +174,21 @@ static void wsp_ggml_graph_compute_helper(
     wsp_ggml_graph_compute(graph, &plan);
 }
 
+static void wsp_ggml_graph_compute_helper(
+        struct wsp_ggml_backend * backend,
+        struct wsp_ggml_cgraph * graph,
+        int n_threads) {
+    if (wsp_ggml_backend_is_cpu(backend)) {
+        wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef WSP_GGML_USE_METAL
+    if (wsp_ggml_backend_is_metal(backend)) {
+        wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    wsp_ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
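The new overload is the single compute entry point for every backend: n_threads maps to CPU worker threads on the CPU backend and to the command-buffer count on Metal, and wsp_ggml_backend_graph_compute() performs the dispatch. A sketch of the call pattern — get_graph and alloc are placeholders here; the real call sites appear in the encode/decode hunks further down:

    // Sketch: backend-agnostic graph execution.
    struct wsp_ggml_cgraph * gf = get_graph();        // placeholder graph builder
    wsp_ggml_allocr_alloc_graph(alloc, gf);           // materialize intermediate tensors
    wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);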
@@ -179,6 +223,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * c
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly
 #if defined(WSP_GGML_USE_METAL)
 #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
 #endif
@@ -305,75 +350,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99, "cantonese", } },
 };
 
-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { WSP_GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
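The deleted MEM_REQ_MODEL table was a hand-maintained size estimate per (type, model) pair. With the ggml-backend port, buffer sizes are computed from the tensors themselves, so new model sizes and quantization types need no table entry. In short — this mirrors the loader code in the @@ -1428 hunk below:

    // Sketch: size the weight buffer from tensor metadata instead of a lookup table.
    size_t size = 0;
    for (const auto & t : model.tensors) {
        size += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
    }
    model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size);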
@@ -431,6 +407,121 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id;   // null terminated
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * (n_tokens));
+    batch.pos      = (whisper_pos *)     malloc(sizeof(whisper_pos)      * (n_tokens));
+    batch.n_seq_id = (int32_t *)         malloc(sizeof(int32_t)          * (n_tokens));
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
+    }
+    batch.seq_id[n_tokens] = nullptr;
+    batch.logits   = (int8_t *)          malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i]; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = seq_id;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// wsp_ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    wsp_ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    wsp_ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+
+    alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
+
+    wsp_ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = wsp_ggml_allocr_max_size(alloc);
+
+    wsp_ggml_allocr_free(alloc);
+
+    buffer = wsp_ggml_backend_alloc_buffer(backend, size);
+    alloc  = wsp_ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        wsp_ggml_allocr_free(allocr.alloc);
+        wsp_ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
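whisper_batch follows the llama_batch design from llama.cpp: every token carries a position, one or more sequence ids, and a flag marking whether its logits are wanted — whisper_batch_prep_legacy() requests logits only for the last token. A minimal lifecycle sketch (the decode call itself is elided; variable names are illustrative):

    // Sketch: submit n_tokens prompt tokens for sequence 0, starting at position n_past.
    whisper_batch batch = whisper_batch_init(n_tokens, /*n_seq_max =*/ 1);
    whisper_batch_prep_legacy(batch, tokens, n_tokens, n_past, /*seq_id =*/ 0);
    // ... run the decoder on `batch`, read logits for the final token ...
    whisper_batch_free(batch);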
@@ -548,16 +639,31 @@ struct whisper_layer_decoder {
     struct wsp_ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct wsp_ggml_tensor * k;
     struct wsp_ggml_tensor * v;
 
     struct wsp_ggml_context * ctx;
 
-    // buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
-    std::vector<uint8_t> buf;
-
-    int n; // number of tokens currently in the cache
+    wsp_ggml_backend_buffer_t buffer;
 };
 
 struct whisper_model {
@@ -594,17 +700,36 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // context
+    // ggml context that contains all the meta information about the model tensors
     struct wsp_ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct wsp_ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct wsp_ggml_tensor *> tensors;
 };
 
+struct whisper_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+    whisper_token id;
+    const uint32_t * code_points;
+    whisper_partial_utf8 partial_utf8;
+};
+
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
@@ -620,12 +745,13 @@
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar grammar;
+
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
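whisper_partial_utf8 (added in the previous hunk) lets the grammar accept tokens whose bytes split a multi-byte code point: value accumulates the code-point bits seen so far and n_remain counts the continuation bytes still owed, with -1 marking an invalid sequence. The accumulation rule this implies, as an illustrative sketch (not the file's actual decoder):

    // Sketch: fold one byte into a partial UTF-8 state.
    static void accept_byte(whisper_partial_utf8 & s, uint8_t b) {
        if (s.n_remain > 0) {                                    // continuation byte expected
            if ((b & 0xC0) != 0x80) { s.n_remain = -1; return; } // not 10xxxxxx -> invalid
            s.value = (s.value << 6) | (b & 0x3F);
            s.n_remain--;
        } else if ((b & 0x80) == 0x00) { s.value = b;        s.n_remain = 0; } // ASCII
        else if ((b & 0xE0) == 0xC0)   { s.value = b & 0x1F; s.n_remain = 1; } // 2-byte lead
        else if ((b & 0xF0) == 0xE0)   { s.value = b & 0x0F; s.n_remain = 2; } // 3-byte lead
        else if ((b & 0xF8) == 0xF0)   { s.value = b & 0x07; s.n_remain = 3; } // 4-byte lead
        else                           { s.n_remain = -1; }                    // invalid lead byte
    }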
@@ -637,93 +763,42 @@ struct whisper_decoder {
     std::vector<float> logits;
     std::vector<float> logprobs;
 
-    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
-};
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
-// wsp_ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    wsp_ggml_allocr * alloc = nullptr;
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
-    std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
-
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
-
-    meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
-
-    alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
-
-    const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
-
-    wsp_ggml_allocr_free(alloc);
-
-    data.resize(alloc_size);
-
-    alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        wsp_ggml_allocr_free(allocr.alloc);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
-    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1 (prompt encoding)
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1  (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens <  16 (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1  (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
-    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+    whisper_batch batch;
 
-    // buffer for swapping KV caches between decoders during beam-search
-    std::vector<kv_buf> kv_swap_bufs;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    // reusable buffer for `struct wsp_ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    wsp_ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -737,36 +812,34 @@ struct whisper_state {
     struct wsp_ggml_tensor * embd_conv = nullptr;
     struct wsp_ggml_tensor * embd_enc  = nullptr;
 
+    // helpers for GPU offloading
+    std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef WSP_GGML_USE_METAL
-    wsp_ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +853,25 @@ struct whisper_context {
     wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
+    wsp_ggml_backend_t backend = nullptr;
+
     std::string path_model; // populated by whisper_init_from_file_with_params()
-    whisper_context_params params;
 };
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
-
-static whisper_log_callback whisper_log = whisper_default_log;
+struct whisper_global {
+    // We save the log callback globally
+    wsp_ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +882,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
         struct whisper_kv_cache & cache,
+        wsp_ggml_backend_t backend,
         wsp_ggml_type wtype,
         int n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -827,66 +891,206 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*wsp_ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = wsp_ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
+
+    cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
+
+        wsp_ggml_allocr_alloc(alloc, cache.k);
+        wsp_ggml_allocr_alloc(alloc, cache.v);
+
+        wsp_ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
-    WHISPER_ASSERT(cache.ctx);
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        wsp_ggml_free(cache.ctx);
+        wsp_ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
 
-    const int n_elements = wsp_ggml_nelements(cache.k);
-    WHISPER_ASSERT(n_elements == wsp_ggml_nelements(cache.v));
+static bool whisper_kv_cache_find_slot(
+        struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
 
-    const wsp_ggml_type wtype = cache.k->type;
-    WHISPER_ASSERT(wtype == cache.v->type);
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*wsp_ggml_type_sizef(wtype));
+    uint32_t n_tested = 0;
 
-    struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
-    };
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
 
-    cache.ctx = wsp_ggml_init(params);
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
 
-    if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
-        return false;
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
     }
 
-    cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
+
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
 
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    if (cache.ctx) {
-        wsp_ggml_free(cache.ctx);
-        cache.ctx = nullptr;
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
+
+    return 1;
+}
+
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
+    }
+    cache.head = 0;
+}
+
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+        whisper_seq_id seq_id,
+        whisper_pos p0,
+        whisper_pos p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
+}
+
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+        whisper_seq_id seq_id_src,
+        whisper_seq_id seq_id_dst,
+        whisper_pos p0,
+        whisper_pos p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
     }
 }
 
+static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    wsp_ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef WSP_GGML_USE_CUBLAS
+    if (params.use_gpu && wsp_ggml_cublas_loaded()) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = wsp_ggml_backend_cuda_init(0);
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef WSP_GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = wsp_ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
+        } else if (!wsp_ggml_backend_metal_supports_family(backend_gpu, 7)) {
+            WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
+            wsp_ggml_backend_free(backend_gpu);
+            backend_gpu = NULL;
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return wsp_ggml_backend_cpu_init();
+}
+
 // load the model from a ggml file
 //
 // file format:
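These primitives replace the old per-decoder caches and the kv_swap_bufs copying: all decoders now share one cache partitioned by sequence id, find_slot reserves a contiguous run of cells, seq_cp shares already-computed cells between sequences, and seq_rm evicts one. The beam-search pattern they enable, sketched with illustrative names:

    // Sketch: unified-cache bookkeeping during beam search.
    whisper_kv_cache_clear(kv_self);
    whisper_batch_prep_legacy(batch, prompt, n_prompt, 0, /*seq_id =*/ 0);
    whisper_kv_cache_find_slot(kv_self, batch);          // reserve cells for the prompt
    for (int b = 1; b < n_beams; ++b) {
        // beams reuse the prompt's KV entries instead of copying tensor data
        whisper_kv_cache_seq_cp(kv_self, /*seq_id_src =*/ 0, /*seq_id_dst =*/ b, -1, -1);
    }
    // a beam that drops out of the search frees its cells:
    whisper_kv_cache_seq_rm(kv_self, /*seq_id =*/ 1, -1, -1);

(p0/p1 of -1 select the whole position range, per the clamping at the top of both functions.)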
@@ -899,7 +1103,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    log("%s: loading model\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
 
     const int64_t t_start_us = wsp_ggml_time_us();
 
@@ -913,7 +1117,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != WSP_GGML_FILE_MAGIC) {
-            log("%s: invalid model data (bad magic)\n", __func__);
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -970,41 +1174,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         // in order to save memory and also to speed up the computation
         wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == WSP_GGML_TYPE_COUNT) {
-            log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
            return false;
        }
 
-        const size_t scale = model.hparams.ftype ? 1 : 2;
-
-        log("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        log("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        log("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        log("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        log("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        log("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        log("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        log("%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        log("%s: ftype         = %d\n", __func__, model.hparams.ftype);
-        log("%s: qntvr         = %d\n", __func__, qntvr);
-        log("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
-
-        // print memory requirements
-        {
-            // TODO
-            //log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-            //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-        }
-
-        // initialize all memory buffers
-        // always have at least one decoder
-
-        wctx.model.buf = new std::vector<uint8_t>();
-        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
-
-        // we skip initialization of the state until it is needed
-        // because it might be that state will always be provided externally.
+        WHISPER_LOG_INFO("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        WHISPER_LOG_INFO("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        WHISPER_LOG_INFO("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        WHISPER_LOG_INFO("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        WHISPER_LOG_INFO("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        WHISPER_LOG_INFO("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        WHISPER_LOG_INFO("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        WHISPER_LOG_INFO("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        WHISPER_LOG_INFO("%s: qntvr         = %d\n", __func__, qntvr);
+        WHISPER_LOG_INFO("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
    }
 
    // load mel filters
@@ -1025,7 +1211,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
-        //    log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
        //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
        //    return false;
        //}
@@ -1045,7 +1231,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                word = "";
            }
 
@@ -1073,7 +1259,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1081,6 +1267,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_EOT_]";
                 } else if (i == vocab.token_sot) {
                     word = "[_SOT_]";
+                } else if (i == vocab.token_translate) {
+                    word = "[_TRANSLATE_]";
+                } else if (i == vocab.token_transcribe) {
+                    word = "[_TRANSCRIBE_]";
                 } else if (i == vocab.token_solm) {
                     word = "[_SOLM_]";
                 } else if (i == vocab.token_prev) {
@@ -1091,6 +1281,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_NOT_]";
                 } else if (i == vocab.token_beg) {
                     word = "[_BEG_]";
+                } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                    word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
                 } else {
                     word = "[_extra_token_" + std::to_string(i) + "]";
                 }
@@ -1099,140 +1291,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
         }
 
-        log("%s: n_langs       = %d\n", __func__, vocab.num_languages());
+        WHISPER_LOG_INFO("%s: n_langs       = %d\n", __func__, vocab.num_languages());
    }
 
-    size_t ctx_size = 0;
-
     const wsp_ggml_type wtype = wctx.wtype;
     const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
 
+    // create the ggml context
     {
         const auto & hparams = model.hparams;
 
-        const int n_vocab = hparams.n_vocab;
-
-        const int n_audio_ctx = hparams.n_audio_ctx;
-        const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
+        const int n_text_layer  = hparams.n_text_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-
-        const int n_mels = hparams.n_mels;
-
-        // encoder
-        {
-            ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe;
-
-            ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b
-
-            ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b
-
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w;
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b;
-        }
-
-        // decoder
-        {
-            ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe;
-
-            ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te;
-
-            ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w;
-            ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b;
-        }
-
-        // encoder layers
-        {
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_audio_layer*(  4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_audio_layer*(    n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
-        }
-
-        // decoder layers
-        {
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_text_layer*(  4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_text_layer*(    n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
+        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
-            //
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b
-        }
-
-        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
-
-        log("%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-    }
-
-    // create the ggml context
-    {
         struct wsp_ggml_init_params params = {
-            /*.mem_size   =*/ wctx.model.buf->size(),
-            /*.mem_buffer =*/ wctx.model.buf->data(),
-            /*.no_alloc   =*/ false,
+            /*.mem_size   =*/ n_tensors*wsp_ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
        };
 
        model.ctx = wsp_ggml_init(params);
        if (!model.ctx) {
-            log("%s: wsp_ggml_init() failed\n", __func__);
+            WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__);
            return false;
        }
    }
 
-    // prepare memory for the weights
+    // prepare tensors for the weights
    {
        auto & ctx = model.ctx;
 
@@ -1255,16 +1342,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         // encoder
         {
-            model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+            model.e_pe       = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
-            model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
-            model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype,         3, n_mels, n_audio_state);
+            model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32,      1, n_audio_state);
 
-            model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
-            model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
+            model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32,      1, n_audio_state);
 
-            model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
+            model.e_ln_w     = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
+            model.e_ln_b     = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1428,12 +1515,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
+        }
+
+        model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
+    }
+
+    wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            wsp_ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1460,20 +1572,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                return false;
            }
 
            auto tensor = model.tensors[name.data()];
+
            if (wsp_ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                return false;
            }
 
            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                return false;
            }
@@ -1481,29 +1594,49 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
 
             if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
-                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
                return false;
            }
 
-            loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
-            BYTESWAP_TENSOR(tensor);
+            wsp_ggml_backend_t backend = wctx.backend;
+
+            //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
+
+            if ((wsp_ggml_backend_is_cpu(backend)
+#ifdef WSP_GGML_USE_METAL
+                || wsp_ggml_backend_is_metal(backend)
+#endif
+                )) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(wsp_ggml_nbytes(tensor));
+
+                loader->read(loader->context, read_buf.data(), read_buf.size());
+
+                wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
+            }
 
-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1024.0/1024.0);
+            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
            total_size += wsp_ggml_nbytes(tensor);
            model.n_loaded++;
        }
 
-        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1e6);
 
        if (model.n_loaded == 0) {
-            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
        } else if (model.n_loaded != (int) model.tensors.size()) {
-            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
            return false;
        }
    }
 
+    wsp_ggml_allocr_free(alloc);
+
    wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
 
    return true;
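Weight loading now distinguishes host-visible buffers (CPU, and Metal with unified memory) from device-only ones: the former are read straight into tensor->data, the latter staged through read_buf and uploaded with wsp_ggml_backend_tensor_set(). The same staging pattern applies to any tensor that must be filled after backend allocation, e.g. (fill_host_data is a hypothetical producer):

    // Sketch: filling a backend-allocated tensor `t` from host data.
    std::vector<float> host(wsp_ggml_nelements(t));
    fill_host_data(host.data(), host.size());   // hypothetical: produce values on the host
    wsp_ggml_backend_tensor_set(t, host.data(), 0, wsp_ggml_nbytes(t));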
@@ -1559,10 +1692,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
     if (!wsp_ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
-        float * dst = (float *) mel->data;
+        wstate.inp_mel.resize(wsp_ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, wsp_ggml_nbytes(mel));
 
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i0 = std::min(mel_offset,           mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1570,6 +1705,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
     }
 
     struct wsp_ggml_tensor * cur = nullptr;
@@ -1577,25 +1714,18 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
     if (!whisper_encode_external(wstate)) {
         // convolution + gelu
         {
-            cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            cur = wsp_ggml_add(ctx0,
-                    wsp_ggml_repeat(ctx0,
-                        model.e_conv_1_b,
-                        cur),
-                    cur);
+            cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
+            cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);
 
             cur = wsp_ggml_gelu(ctx0, cur);
 
             cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            cur = wsp_ggml_add(ctx0,
-                    wsp_ggml_repeat(ctx0,
-                        model.e_conv_2_b,
-                        cur),
-                    cur);
+            cur = wsp_ggml_add(ctx0, cur, model.e_conv_2_b);
 
             cur = wsp_ggml_gelu(ctx0, cur);
         }
 
+        wsp_ggml_set_name(cur, "embd_conv");
         wstate.embd_conv = cur;
     } else {
 #ifdef WHISPER_USE_COREML
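The conv bias no longer needs an explicit materialization: ggml's add broadcasts a smaller second operand across the first, so the repeat node (and the copy it implied) drops out of the graph. Side by side:

    // before: bias expanded to cur's shape with repeat, then added
    cur = wsp_ggml_add(ctx0, wsp_ggml_repeat(ctx0, model.e_conv_1_b, cur), cur);
    // after: wsp_ggml_add broadcasts the (1 x n_audio_state) bias across cur
    cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);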
@@ -1603,7 +1733,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1603
1733
  wsp_ggml_allocr_alloc(alloc, cur);
1604
1734
 
1605
1735
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1606
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1736
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
1607
1737
  }
1608
1738
  #endif
1609
1739
  #ifdef WHISPER_USE_OPENVINO
@@ -1615,6 +1745,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1615
1745
  }
1616
1746
  #endif
1617
1747
 
1748
+ wsp_ggml_set_name(cur, "embd_enc");
1618
1749
  wstate.embd_enc = cur;
1619
1750
  }
1620
1751
 
@@ -1648,15 +1779,22 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1648
1779
 
1649
1780
  wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
1650
1781
 
1782
+ //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
1783
+ //wsp_ggml_allocr_alloc(alloc, cur);
1784
+
1785
+ //if (!wsp_ggml_allocr_is_measure(alloc)) {
1786
+ // wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
1787
+ //}
1788
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1789
+
1651
1790
  struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1652
1791
  wsp_ggml_allocr_alloc(alloc, KQscale);
1653
1792
 
1654
1793
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1655
- wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
1794
+ const float val = 1.0f/sqrtf(float(n_state)/n_head);
1795
+ wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
1656
1796
  }
1657
1797
 
1658
- struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1659
-
1660
1798
  // ===================================================================
1661
1799
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
1662
1800
  //static int iter = -1;
@@ -1675,7 +1813,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1675
1813
  const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1676
1814
 
1677
1815
  struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1678
-
1679
1816
  cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
1680
1817
 
1681
1818
  // ===================================================================
@@ -1863,11 +2000,11 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1863
2000
  ////////////////////////////////////////////////////////////////////////////
1864
2001
 
1865
2002
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1866
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1867
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1868
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1869
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1870
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2003
+ // wsp_ggml_used_mem(ctx0)/1e6,
2004
+ // wstate.get_buf_max_mem(0)/1e6,
2005
+ // wstate.get_buf_max_mem(1)/1e6,
2006
+ // wstate.get_buf_max_mem(2)/1e6,
2007
+ // wstate.get_buf_max_mem(3)/1e6);
1871
2008
 
1872
2009
  wsp_ggml_free(ctx0);
1873
2010
 
@@ -1897,13 +2034,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
1897
2034
 
1898
2035
  wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
1899
2036
 
2037
+ //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
2038
+ //wsp_ggml_allocr_alloc(alloc, cur);
2039
+
2040
+ //if (!wsp_ggml_allocr_is_measure(alloc)) {
2041
+ // wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
2042
+ //}
1900
2043
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
1901
2044
 
1902
2045
  struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1903
2046
  wsp_ggml_allocr_alloc(alloc, Kscale);
1904
2047
 
1905
2048
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1906
- wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
2049
+ const float val = pow(float(n_state) / n_head, -0.25);
2050
+ wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
1907
2051
  }
1908
2052
 
1909
2053
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1974,7 +2118,7 @@ static bool whisper_encode_internal(
1974
2118
  wsp_ggml_allocr_alloc_graph(alloc, gf);
1975
2119
 
1976
2120
  if (!whisper_encode_external(wstate)) {
1977
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2121
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1978
2122
  }
1979
2123
  }
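
Every per-backend dispatch block in this function collapses into the same call: wsp_ggml_graph_compute_helper now receives the state's backend handle and lets the backend object decide how the graph runs. A plausible shape for the helper, assuming the wsp_-prefixed backend API mirrors upstream ggml-backend; treat the body as a sketch, not this file's verbatim implementation:

    #include "ggml-backend.h"

    // n_threads only means something on the CPU backend; GPU backends
    // schedule the graph themselves.
    static void graph_compute_sketch(wsp_ggml_backend_t backend,
                                     struct wsp_ggml_cgraph * graph,
                                     int n_threads) {
        if (wsp_ggml_backend_is_cpu(backend)) {
            wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
        }
        wsp_ggml_backend_graph_compute(backend, graph);
    }
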
1980
2124
 
@@ -1988,16 +2132,7 @@ static bool whisper_encode_internal(
1988
2132
 
1989
2133
  wsp_ggml_allocr_alloc_graph(alloc, gf);
1990
2134
 
1991
- #ifdef WSP_GGML_USE_METAL
1992
- if (wstate.ctx_metal) {
1993
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1994
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1995
- } else {
1996
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1997
- }
1998
- #else
1999
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2000
- #endif
2135
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2001
2136
  }
2002
2137
 
2003
2138
  // cross
@@ -2010,49 +2145,40 @@ static bool whisper_encode_internal(
2010
2145
 
2011
2146
  wsp_ggml_allocr_alloc_graph(alloc, gf);
2012
2147
 
2013
- #ifdef WSP_GGML_USE_METAL
2014
- if (wstate.ctx_metal) {
2015
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2016
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
2017
- } else {
2018
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2019
- }
2020
- #else
2021
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2022
- #endif
2148
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2023
2149
  }
2024
2150
 
2025
- // wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
2026
-
2027
2151
  wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
2028
2152
  wstate.n_encode++;
2029
2153
 
2030
- return true;
2154
+ return !(abort_callback && abort_callback(abort_callback_data));
2031
2155
  }
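
Note the changed return contract: instead of returning true unconditionally, the encoder polls the abort callback once at the end and returns false when cancellation was requested (a callback returning true means abort). A caller-side sketch with an atomic flag; the flag and wiring are illustrative, only the callback signature comes from the code above:

    #include <atomic>

    static std::atomic<bool> g_abort{false};   // set from another thread to cancel

    static bool should_abort(void * /*user_data*/) {
        return g_abort.load();                 // true => whisper aborts the run
    }
    // pass `should_abort` as the abort_callback argument; whisper_encode_internal
    // then reports failure (false) once the callback fires
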
2032
2156
 
2033
2157
  static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2034
2158
  whisper_context & wctx,
2035
2159
  whisper_state & wstate,
2036
- whisper_decoder & decoder,
2037
- const whisper_token * tokens,
2038
- int n_tokens,
2039
- int n_past) {
2160
+ const whisper_batch & batch) {
2040
2161
  const auto & model = wctx.model;
2041
2162
  const auto & hparams = model.hparams;
2042
2163
 
2043
- auto & kv_self = decoder.kv_self;
2164
+ auto & kv_self = wstate.kv_self;
2044
2165
 
2045
2166
  WHISPER_ASSERT(!!kv_self.ctx);
2046
2167
 
2047
- const int n_ctx = hparams.n_text_ctx;
2168
+ wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
2169
+
2170
+ const int n_ctx = kv_self.size;
2048
2171
  const int n_state = hparams.n_text_state;
2049
2172
  const int n_head = hparams.n_text_head;
2050
2173
  const int n_layer = hparams.n_text_layer;
2051
2174
 
2052
- const int N = n_tokens;
2053
- const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2175
+ const int n_tokens = batch.n_tokens;
2176
+ const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2054
2177
 
2055
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
2178
+ const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
2179
+ const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
2180
+
2181
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2056
2182
 
2057
2183
  struct wsp_ggml_init_params params = {
2058
2184
  /*.mem_size =*/ wstate.alloc_decode.meta.size(),
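
The decoder graph is now driven by a whisper_batch instead of a raw (tokens, n_tokens, n_past) triple, mirroring llama_batch in llama.cpp, and all decoders share wstate.kv_self, distinguished by sequence ids. The fields this function actually reads are batch.n_tokens, batch.token[i], batch.pos[i], batch.seq_id[i][0] and batch.logits[i]; an outline of what such a batch carries, inferred from those uses rather than copied from the package's definition:

    #include <cstdint>

    struct whisper_batch_sketch {
        int32_t    n_tokens;
        int32_t *  token;    // token[i]     : token id to decode
        int32_t *  pos;      // pos[i]       : position of token i in its sequence
        int32_t *  n_seq_id; // n_seq_id[i]  : (assumed) sequences per token
        int32_t ** seq_id;   // seq_id[i][s] : owning decoder/sequence ids
        int8_t  *  logits;   // logits[i]    : nonzero => return logits for token i
    };
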
@@ -2064,21 +2190,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2064
2190
 
2065
2191
  wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2066
2192
 
2067
- wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
2068
-
2069
- struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
2193
+ struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2070
2194
  wsp_ggml_allocr_alloc(alloc, embd);
2071
2195
 
2072
2196
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2073
- memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2197
+ wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
2074
2198
  }
2075
2199
 
2076
- struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
2200
+ struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2077
2201
  wsp_ggml_allocr_alloc(alloc, position);
2078
2202
 
2079
2203
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2080
- for (int i = 0; i < N; ++i) {
2081
- ((int32_t *) position->data)[i] = n_past + i;
2204
+ for (int i = 0; i < n_tokens; ++i) {
2205
+ const int32_t val = batch.pos[i];
2206
+ wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2082
2207
  }
2083
2208
  }
2084
2209
 
@@ -2086,7 +2211,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2086
2211
  wsp_ggml_allocr_alloc(alloc, KQscale);
2087
2212
 
2088
2213
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2089
- wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
2214
+ const float val = pow(float(n_state)/n_head, -0.25);
2215
+ wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
2216
+ }
2217
+
2218
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
2219
+ wsp_ggml_allocr_alloc(alloc, KQ_mask);
2220
+
2221
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2222
+ wstate.inp_mask.resize(n_kv*n_tokens);
2223
+
2224
+ float * data = wstate.inp_mask.data();
2225
+ memset(data, 0, wsp_ggml_nbytes(KQ_mask));
2226
+
2227
+ for (int h = 0; h < 1; ++h) {
2228
+ for (int j = 0; j < n_tokens; ++j) {
2229
+ const whisper_pos pos = batch.pos[j];
2230
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
2231
+
2232
+ for (int i = 0; i < n_kv; ++i) {
2233
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2234
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2235
+ }
2236
+ }
2237
+ }
2238
+ }
2239
+
2240
+ wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
2090
2241
  }
2091
2242
 
2092
2243
  // token encoding + position encoding
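
This KQ_mask replaces wsp_ggml_diag_mask_inf further down: instead of masking everything past a single n_past, each (token j, cache cell i) pair is tested against the token's sequence id and position, and disallowed pairs get -INFINITY. Because the mask is added to KQ before the softmax, exp(-inf) = 0 removes those cells from the attention distribution exactly. A self-contained illustration of the additive-mask arithmetic:

    #include <cmath>
    #include <cstdio>

    int main() {
        float kq[2]   = { 1.0f, 2.0f };        // raw attention logits
        float mask[2] = { 0.0f, -INFINITY };   // second cell disallowed

        float e0 = std::exp(kq[0] + mask[0]);  // exp(1)
        float e1 = std::exp(kq[1] + mask[1]);  // exp(-inf) == 0
        std::printf("p0 = %f, p1 = %f\n", e0/(e0 + e1), e1/(e0 + e1));
        return 0;                              // prints p0 = 1, p1 = 0
    }
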
@@ -2141,12 +2292,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2141
2292
  Vcur,
2142
2293
  layer.attn_v_b);
2143
2294
 
2144
- Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
2295
+ Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2145
2296
 
2146
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, N*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
2147
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, N, n_state,
2297
+ struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2298
+ struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2148
2299
  ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2149
- (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
2300
+ (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2150
2301
 
2151
2302
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2152
2303
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
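
New K and V rows are written at kv_head, the start of the slot reserved for this batch, rather than at n_past, and n_ctx here is kv_self.size (the overallocated unified cache), no longer hparams.n_text_ctx. The two view offsets above, restated as plain byte arithmetic:

    #include <cstddef>

    // K rows are contiguous per token within a layer:
    static size_t k_row_offset(size_t elem, int il, int n_ctx, int n_state, int kv_head) {
        return elem * n_state * ((size_t) il*n_ctx + kv_head);
    }
    // V is stored transposed: consecutive tokens are one element apart,
    // layers are n_ctx*n_state elements apart:
    static size_t v_col_offset(size_t elem, int il, int n_ctx, int n_state, int kv_head) {
        return elem * ((size_t) il*n_ctx*n_state + kv_head);
    }
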
@@ -2156,12 +2307,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2156
2307
 
2157
2308
  struct wsp_ggml_tensor * Q =
2158
2309
  wsp_ggml_permute(ctx0,
2159
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2310
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2160
2311
  0, 2, 1, 3);
2161
2312
 
2162
2313
  struct wsp_ggml_tensor * K =
2163
2314
  wsp_ggml_view_3d(ctx0, kv_self.k,
2164
- n_state/n_head, n_past + N, n_head,
2315
+ n_state/n_head, n_kv, n_head,
2165
2316
  wsp_ggml_element_size(kv_self.k)*n_state,
2166
2317
  wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2167
2318
  wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,16 +2322,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2171
2322
 
2172
2323
  //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2173
2324
 
2174
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2325
+ //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2326
+ struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
2175
2327
 
2176
2328
  struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2177
2329
 
2178
2330
  struct wsp_ggml_tensor * V =
2179
2331
  wsp_ggml_view_3d(ctx0, kv_self.v,
2180
- n_past + N, n_state/n_head, n_head,
2332
+ n_kv, n_state/n_head, n_head,
2181
2333
  n_ctx*wsp_ggml_element_size(kv_self.v),
2182
2334
  n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
2183
- il*n_ctx*wsp_ggml_element_size(kv_self.v)*n_state);
2335
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2184
2336
 
2185
2337
  struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2186
2338
 
@@ -2188,7 +2340,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2188
2340
 
2189
2341
  cur = wsp_ggml_cpy(ctx0,
2190
2342
  KQV_merged,
2191
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N));
2343
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2192
2344
  }
2193
2345
 
2194
2346
  // projection
@@ -2232,33 +2384,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2232
2384
  // Kcross is already scaled
2233
2385
  struct wsp_ggml_tensor * Kcross =
2234
2386
  wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2235
- n_state/n_head, M, n_head,
2387
+ n_state/n_head, n_audio_ctx, n_head,
2236
2388
  wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2237
2389
  wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2238
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2390
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2239
2391
 
2240
2392
  //struct wsp_ggml_tensor * Vcross =
2241
2393
  // wsp_ggml_reshape_3d(ctx0,
2242
- // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2243
- // n_state/n_head, n_head, M);
2394
+ // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2395
+ // n_state/n_head, n_head, n_audio_ctx);
2244
2396
 
2245
2397
  //struct wsp_ggml_tensor * V_trans =
2246
2398
  // wsp_ggml_cpy(ctx0,
2247
2399
  // wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2248
- // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
2400
+ // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2249
2401
 
2250
2402
  struct wsp_ggml_tensor * V =
2251
2403
  wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2252
- M, n_state/n_head, n_head,
2253
- M*wsp_ggml_element_size(wstate.kv_cross.v),
2254
- M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2255
- il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
2404
+ n_audio_ctx, n_state/n_head, n_head,
2405
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2406
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2407
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2256
2408
 
2257
2409
  // ------
2258
2410
 
2259
2411
  struct wsp_ggml_tensor * Q =
2260
2412
  wsp_ggml_permute(ctx0,
2261
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2413
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2262
2414
  0, 2, 1, 3);
2263
2415
 
2264
2416
  // K * Q
@@ -2279,10 +2431,10 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2279
2431
 
2280
2432
  struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2281
2433
 
2282
- // cur = KQV_merged.contiguous().view(n_state, N)
2434
+ // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2283
2435
  cur = wsp_ggml_cpy(ctx0,
2284
2436
  KQV_merged,
2285
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N));
2437
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2286
2438
  }
2287
2439
 
2288
2440
  // projection
@@ -2354,9 +2506,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2354
2506
  }
2355
2507
 
2356
2508
  // compute logits only for the last token
2357
- // comment this line to compute logits for all N tokens
2509
+ // comment this line to compute logits for all n_tokens
2358
2510
  // might be useful in the future
2359
- cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2511
+ //cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2360
2512
 
2361
2513
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2362
2514
 
@@ -2380,10 +2532,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2380
2532
  static bool whisper_decode_internal(
2381
2533
  whisper_context & wctx,
2382
2534
  whisper_state & wstate,
2383
- whisper_decoder & decoder,
2384
- const whisper_token * tokens,
2385
- const int n_tokens,
2386
- const int n_past,
2535
+ const whisper_batch & batch,
2387
2536
  const int n_threads,
2388
2537
  whisper_abort_callback abort_callback,
2389
2538
  void * abort_callback_data) {
@@ -2392,65 +2541,72 @@ static bool whisper_decode_internal(
2392
2541
  const auto & model = wctx.model;
2393
2542
  const auto & hparams = model.hparams;
2394
2543
 
2395
- const int n_vocab = hparams.n_vocab;
2544
+ const int n_vocab = hparams.n_vocab;
2545
+ const int n_tokens = batch.n_tokens;
2396
2546
 
2397
2547
  auto & logits_out = wstate.logits;
2398
2548
 
2399
2549
  struct wsp_ggml_tensor * logits;
2400
2550
 
2551
+ // find KV slot for the batch
2552
+ {
2553
+ auto & kv_self = wstate.kv_self;
2554
+
2555
+ if (!whisper_kv_cache_find_slot(kv_self, batch)) {
2556
+ return false;
2557
+ }
2558
+
2559
+ kv_self.n = whisper_kv_cache_cell_max(kv_self);
2560
+ //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2561
+ //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2562
+ }
2563
+
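
Placement into the unified cache happens before the graph is built: whisper_kv_cache_find_slot has to locate n_tokens contiguous free cells (leaving kv_self.head at the slot start, which the graph reads as kv_head), and whisper_kv_cache_cell_max then clamps kv_self.n to the highest occupied cell so attention does not scan the whole overallocated buffer. A sketch of the slot search, modeled on the equivalent llama.cpp routine; the rn-whisper implementation may differ in detail:

    #include <cstdint>
    #include <vector>

    struct kv_cell_sketch  { int32_t pos = -1; };           // pos < 0 => free
    struct kv_cache_sketch {
        int32_t size = 0, head = 0;
        std::vector<kv_cell_sketch> cells;                  // cells.size() == size
    };

    static bool find_slot_sketch(kv_cache_sketch & cache, int32_t n_tokens) {
        const int32_t n_ctx = cache.size;
        if (n_tokens > n_ctx) return false;

        int32_t n_tested = 0;
        while (true) {
            if (cache.head + n_tokens > n_ctx) {   // no room before the end: wrap
                n_tested += n_ctx - cache.head;
                cache.head = 0;
                continue;
            }
            bool found = true;
            for (int32_t i = 0; i < n_tokens; i++) {
                if (cache.cells[cache.head + i].pos >= 0) { // occupied: skip past it
                    found = false;
                    cache.head += i + 1;
                    n_tested   += i + 1;
                    break;
                }
            }
            if (found)             return true;    // [head, head + n_tokens) is free
            if (n_tested >= n_ctx) return false;   // scanned everything, no fit
        }
    }
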
2401
2564
  // decoder
2402
2565
  {
2403
2566
  auto & alloc = wstate.alloc_decode.alloc;
2404
2567
 
2405
2568
  wsp_ggml_allocr_reset(alloc);
2406
2569
 
2407
- wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2570
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
2408
2571
 
2409
2572
  wsp_ggml_allocr_alloc_graph(alloc, gf);
2410
2573
 
2411
2574
  logits = gf->nodes[gf->n_nodes - 1];
2412
2575
 
2413
- #ifdef WSP_GGML_USE_METAL
2414
- if (wstate.ctx_metal) {
2415
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2416
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
2417
- } else {
2418
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2419
- }
2420
- #else
2421
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2422
- #endif
2576
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2423
2577
  }
2424
2578
 
2425
- // extract logits for all N tokens
2426
- //logits_out.resize(n_tokens*n_vocab);
2427
- //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2428
-
2429
- // extract logits only for the last token
2430
- logits_out.resize(n_vocab);
2431
- memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
2579
+ logits_out.resize(n_tokens*n_vocab);
2580
+ for (int i = 0; i < n_tokens; i++) {
2581
+ if (batch.logits[i] == 0) {
2582
+ continue;
2583
+ }
2584
+ wsp_ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
2585
+ }
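
Logits extraction is now selective and batch-aware: the graph keeps logits for all n_tokens (the last-token view in the build function is commented out above, which also leaves its "compute logits only for the last token" comment stale), and the host copies back only the rows whose batch.logits[i] flag is set, one row-sized wsp_ggml_backend_tensor_get per flagged token. A caller-side sketch with a hypothetical helper that requests just the final token's row:

    // hypothetical helper; `whisper_batch` fields as used above
    static void flag_last_token_logits(whisper_batch & batch) {
        for (int i = 0; i < batch.n_tokens; ++i) {
            batch.logits[i] = 0;               // row i is not copied back
        }
        batch.logits[batch.n_tokens - 1] = 1;  // only the last row is extracted
    }
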
2432
2586
 
2433
- if (n_tokens > 1) {
2587
+ if (batch.n_tokens > 1) {
2434
2588
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2435
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
2436
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
2437
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
2438
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
2439
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2589
+ // wsp_ggml_used_mem(ctx0)/1e6,
2590
+ // wstate.get_buf_max_mem(0)/1e6,
2591
+ // wstate.get_buf_max_mem(1)/1e6,
2592
+ // wstate.get_buf_max_mem(2)/1e6,
2593
+ // wstate.get_buf_max_mem(3)/1e6);
2440
2594
  }
2441
2595
 
2442
- if (n_tokens == 1) {
2596
+ if (batch.n_tokens == 1) {
2443
2597
  wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2444
2598
  wstate.n_decode++;
2599
+ } else if (batch.n_tokens < 16) {
2600
+ wstate.t_batchd_us += wsp_ggml_time_us() - t_start_us;
2601
+ wstate.n_batchd += n_tokens;
2445
2602
  } else {
2446
2603
  wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
2447
- wstate.n_prompt++;
2604
+ wstate.n_prompt += n_tokens;
2448
2605
  }
2449
2606
 
2450
- return true;
2607
+ return !(abort_callback && abort_callback(abort_callback_data));
2451
2608
  }
2452
2609
 
2453
-
2454
2610
  // 500 -> 00:05.000
2455
2611
  // 6000 -> 01:00.000
2456
2612
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2794,7 +2950,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2794
2950
  --j;
2795
2951
  }
2796
2952
  if (!found) {
2797
- log("unknown token\n");
2953
+ WHISPER_LOG_ERROR("unknown token\n");
2798
2954
  ++i;
2799
2955
  }
2800
2956
  }
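
From here on, every bare log(...) call in the file becomes a leveled WHISPER_LOG_INFO / WHISPER_LOG_ERROR macro; the definitions live in a separate logging header in this package and are not part of this hunk. A conventional shape for such macros, offered purely as an assumption:

    #include <cstdarg>
    #include <cstdio>

    enum wsp_log_level { WSP_LOG_INFO, WSP_LOG_WARN, WSP_LOG_ERROR };

    static void whisper_log_sketch(wsp_log_level level, const char * fmt, ...) {
        (void) level;                  // a real sink could filter or tag by level
        va_list args;
        va_start(args, fmt);
        std::vfprintf(stderr, fmt, args);
        va_end(args);
    }

    #define WHISPER_LOG_INFO(...)  whisper_log_sketch(WSP_LOG_INFO,  __VA_ARGS__)
    #define WHISPER_LOG_ERROR(...) whisper_log_sketch(WSP_LOG_ERROR, __VA_ARGS__)
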
@@ -2857,95 +3013,105 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
2857
3013
 
2858
3014
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2859
3015
  fill_sin_cos_table();
3016
+
2860
3017
  whisper_state * state = new whisper_state;
2861
3018
 
2862
- if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2863
- log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3019
+ state->backend = whisper_backend_init(ctx->params);
3020
+
3021
+ // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3022
+ // in theory, there can be a case where this is not enough, but in practice it should always be enough
3023
+ const int factor = 3;
3024
+
3025
+ if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
3026
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2864
3027
  delete state;
2865
3028
  return nullptr;
2866
3029
  }
2867
3030
 
2868
3031
  {
2869
- const size_t memory_size = wsp_ggml_nbytes(state->decoders[0].kv_self.k) + wsp_ggml_nbytes(state->decoders[0].kv_self.v);
2870
- log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
3032
+ const size_t memory_size = wsp_ggml_nbytes(state->kv_self.k) + wsp_ggml_nbytes(state->kv_self.v);
3033
+ WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
2871
3034
  }
2872
3035
 
2873
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2874
- log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3036
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
3037
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2875
3038
  delete state;
2876
3039
  return nullptr;
2877
3040
  }
2878
3041
 
2879
3042
  {
2880
3043
  const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v);
2881
- log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
3044
+ WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
2882
3045
  }
2883
3046
 
2884
-
3047
+
2885
3048
  #ifdef WHISPER_USE_COREML
2886
3049
  if (ctx->params.use_coreml) {
2887
3050
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2888
3051
 
2889
- log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2890
- log("%s: first run on a device may take a while ...\n", __func__);
3052
+ WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
3053
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
2891
3054
 
2892
3055
  state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2893
3056
  if (!state->ctx_coreml) {
2894
- log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
3057
+ WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2895
3058
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2896
3059
  delete state;
2897
3060
  return nullptr;
2898
3061
  #endif
2899
3062
  } else {
2900
- log("%s: Core ML model loaded\n", __func__);
3063
+ WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
2901
3064
  }
2902
3065
  }
2903
3066
  #endif
2904
3067
 
2905
3068
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2906
3069
 
2907
- state->logits_id.reserve(ctx->model.hparams.n_vocab);
3070
+ state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
2908
3071
 
2909
3072
  // TAGS: WHISPER_DECODER_INIT
2910
3073
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2911
3074
 
2912
- state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2913
- state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2914
- state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
3075
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
3076
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
3077
+ state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
3078
+ state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
3079
+
3080
+ state->decoders[0].rng = std::mt19937(0);
2915
3081
 
2916
3082
  // conv allocator
2917
3083
  {
2918
- whisper_allocr_graph_init(state->alloc_conv,
3084
+ whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
2919
3085
  [&]() {
2920
3086
  return whisper_build_graph_conv(*ctx, *state, 0);
2921
3087
  });
2922
3088
 
2923
- log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
3089
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
2924
3090
  }
2925
3091
 
2926
3092
  // encoder allocator
2927
3093
  if (!whisper_encode_external(*state)) {
2928
- whisper_allocr_graph_init(state->alloc_encode,
3094
+ whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
2929
3095
  [&]() {
2930
3096
  return whisper_build_graph_encoder(*ctx, *state);
2931
3097
  });
2932
3098
 
2933
- log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
3099
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
2934
3100
  }
2935
3101
 
2936
3102
  // cross allocator
2937
3103
  {
2938
- whisper_allocr_graph_init(state->alloc_cross,
3104
+ whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
2939
3105
  [&]() {
2940
3106
  return whisper_build_graph_cross(*ctx, *state);
2941
3107
  });
2942
3108
 
2943
- log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
3109
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
2944
3110
  }
2945
3111
 
2946
3112
  // decoder allocator
2947
3113
  {
2948
- whisper_allocr_graph_init(state->alloc_decode,
3114
+ whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
2949
3115
  [&]() {
2950
3116
  const auto & hparams = ctx->model.hparams;
2951
3117
 
@@ -2953,74 +3119,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2953
3119
  const int n_tokens = hparams.n_text_ctx;
2954
3120
  const int n_past = 0;
2955
3121
 
2956
- return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2957
- });
2958
-
2959
- log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2960
- }
2961
-
2962
- #ifdef WSP_GGML_USE_METAL
2963
- if (ctx->params.use_gpu) {
2964
- state->ctx_metal = wsp_ggml_metal_init(1);
2965
- if (!state->ctx_metal) {
2966
- log("%s: wsp_ggml_metal_init() failed\n", __func__);
2967
- delete state;
2968
- return nullptr;
2969
- }
2970
- }
2971
-
2972
- if (state->ctx_metal) {
2973
- log("%s: Metal context initialized\n", __func__);
2974
-
2975
- // this allocates all Metal resources and memory buffers
2976
-
2977
- void * data_ptr = NULL;
2978
- size_t data_size = 0;
2979
-
2980
- // TODO: add mmap support
2981
- //if (params.use_mmap) {
2982
- // data_ptr = ctx->model.mapping->addr;
2983
- // data_size = ctx->model.mapping->size;
2984
- //} else {
2985
- // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2986
- // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2987
- //}
2988
-
2989
- data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2990
- data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2991
-
2992
- const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
2993
-
2994
- log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2995
-
2996
- #define WHISPER_METAL_CHECK_BUF(result) \
2997
- if (!(result)) { \
2998
- log("%s: failed to add metal buffer\n", __func__); \
2999
- delete state; \
3000
- return nullptr; \
3001
- }
3002
-
3003
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
3004
-
3005
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
3006
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
3007
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
3008
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
3009
-
3010
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
3011
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
3012
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
3013
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
3122
+ whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
3014
3123
 
3015
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
3016
-
3017
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
3018
- #undef WHISPER_METAL_CHECK_BUF
3124
+ return whisper_build_graph_decoder(*ctx, *state, state->batch);
3125
+ });
3019
3126
 
3127
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
3020
3128
  }
3021
- #endif
3022
3129
 
3023
- state->rng = std::mt19937(0);
3130
+ whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
3131
+ whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
3132
+ whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
3133
+ whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
3024
3134
 
3025
3135
  return state;
3026
3136
  }
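
The allocator lifecycle above is what replaced the deleted Metal buffer registration: each graph is first built under a measure allocator (wsp_ggml_allocr_is_measure(alloc) is true, so none of the tensor_set calls run), whisper_allocr_graph_init records the worst-case graph size, and once all four graphs have been measured, whisper_allocr_graph_realloc commits real backend buffers. The measure decode graph is sized with a full-context legacy batch (n_tokens = n_text_ctx, n_past = 0) so no later real batch can outgrow it. The flow for one graph, compressed from the code above:

    // 1) measure: build once with a dummy worst-case batch; no data is written
    whisper_allocr_graph_init(state->alloc_decode, ctx->backend, [&]() {
        whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
        return whisper_build_graph_decoder(*ctx, *state, state->batch);
    });
    // 2) commit: swap the measure allocator for a real buffer of the recorded size
    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
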
@@ -3039,7 +3149,7 @@ int whisper_ctx_init_openvino_encoder(
3039
3149
  return 1;
3040
3150
  #else
3041
3151
  if (!model_path && ctx->path_model.empty()) {
3042
- log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
3152
+ WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
3043
3153
  return 1;
3044
3154
  }
3045
3155
 
@@ -3059,15 +3169,15 @@ int whisper_ctx_init_openvino_encoder(
3059
3169
  path_cache = cache_dir;
3060
3170
  }
3061
3171
 
3062
- log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3063
- log("%s: first run on a device may take a while ...\n", __func__);
3172
+ WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3173
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
3064
3174
 
3065
3175
  ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3066
3176
  if (!ctx->state->ctx_openvino) {
3067
- log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3177
+ WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3068
3178
  return 1;
3069
3179
  } else {
3070
- log("%s: OpenVINO model loaded\n", __func__);
3180
+ WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
3071
3181
  }
3072
3182
 
3073
3183
  return 0;
@@ -3083,11 +3193,11 @@ struct whisper_context_params whisper_context_default_params() {
3083
3193
  }
3084
3194
 
3085
3195
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3086
- log("%s: loading model from '%s'\n", __func__, path_model);
3196
+ WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3087
3197
 
3088
3198
  auto fin = std::ifstream(path_model, std::ios::binary);
3089
3199
  if (!fin) {
3090
- log("%s: failed to open '%s'\n", __func__, path_model);
3200
+ WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3091
3201
  return nullptr;
3092
3202
  }
3093
3203
 
@@ -3129,7 +3239,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
3129
3239
 
3130
3240
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
3131
3241
 
3132
- log("%s: loading model from buffer\n", __func__);
3242
+ WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
3133
3243
 
3134
3244
  whisper_model_loader loader = {};
3135
3245
 
@@ -3165,7 +3275,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
3165
3275
 
3166
3276
  if (!whisper_model_load(loader, *ctx)) {
3167
3277
  loader->close(loader->context);
3168
- log("%s: failed to load model\n", __func__);
3278
+ WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
3169
3279
  delete ctx;
3170
3280
  return nullptr;
3171
3281
  }
@@ -3247,12 +3357,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
3247
3357
  void whisper_free_state(struct whisper_state * state)
3248
3358
  {
3249
3359
  if (state) {
3360
+ kv_cache_free(state->kv_self);
3250
3361
  kv_cache_free(state->kv_cross);
3251
3362
 
3252
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
3253
- kv_cache_free(state->decoders[i].kv_self);
3254
- }
3255
-
3256
3363
  #ifdef WHISPER_USE_COREML
3257
3364
  if (state->ctx_coreml != nullptr) {
3258
3365
  whisper_coreml_free(state->ctx_coreml);
@@ -3260,13 +3367,6 @@ void whisper_free_state(struct whisper_state * state)
3260
3367
  }
3261
3368
  #endif
3262
3369
 
3263
- #ifdef WSP_GGML_USE_METAL
3264
- if (state->ctx_metal) {
3265
- wsp_ggml_metal_free(state->ctx_metal);
3266
- state->ctx_metal = nullptr;
3267
- }
3268
- #endif
3269
-
3270
3370
  #ifdef WHISPER_USE_OPENVINO
3271
3371
  if (state->ctx_openvino != nullptr) {
3272
3372
  whisper_openvino_free(state->ctx_openvino);
@@ -3274,10 +3374,14 @@ void whisper_free_state(struct whisper_state * state)
3274
3374
  }
3275
3375
  #endif
3276
3376
 
3377
+ whisper_batch_free(state->batch);
3378
+
3277
3379
  whisper_allocr_free(state->alloc_conv);
3278
- whisper_allocr_free(state->alloc_decode);
3279
- whisper_allocr_free(state->alloc_cross);
3280
3380
  whisper_allocr_free(state->alloc_encode);
3381
+ whisper_allocr_free(state->alloc_cross);
3382
+ whisper_allocr_free(state->alloc_decode);
3383
+
3384
+ wsp_ggml_backend_free(state->backend);
3281
3385
 
3282
3386
  delete state;
3283
3387
  }
@@ -3288,12 +3392,15 @@ void whisper_free(struct whisper_context * ctx) {
3288
3392
  if (ctx->model.ctx) {
3289
3393
  wsp_ggml_free(ctx->model.ctx);
3290
3394
  }
3291
- if (ctx->model.buf) {
3292
- delete ctx->model.buf;
3395
+
3396
+ if (ctx->model.buffer) {
3397
+ wsp_ggml_backend_buffer_free(ctx->model.buffer);
3293
3398
  }
3294
3399
 
3295
3400
  whisper_free_state(ctx->state);
3296
3401
 
3402
+ wsp_ggml_backend_free(ctx->backend);
3403
+
3297
3404
  delete ctx;
3298
3405
  }
3299
3406
  }
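
Teardown mirrors the new ownership model: the weights now sit in a wsp_ggml_backend_buffer (model.buffer) instead of a plain byte buffer, the state frees its own compute backend inside whisper_free_state, and the context releases the model's backend last. The order, restated:

    // wsp_ggml_free(ctx->model.ctx);                   // tensor metadata
    // wsp_ggml_backend_buffer_free(ctx->model.buffer); // weight storage (host or device)
    // whisper_free_state(ctx->state);                  // frees state->backend inside
    // wsp_ggml_backend_free(ctx->backend);             // model's backend handle last
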
@@ -3312,7 +3419,7 @@ void whisper_free_params(struct whisper_full_params * params) {
3312
3419
 
3313
3420
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3314
3421
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3315
- log("%s: failed to compute mel spectrogram\n", __func__);
3422
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3316
3423
  return -1;
3317
3424
  }
3318
3425
 
@@ -3326,7 +3433,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3326
3433
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3327
3434
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3328
3435
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3329
- log("%s: failed to compute mel spectrogram\n", __func__);
3436
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3330
3437
  return -1;
3331
3438
  }
3332
3439
 
@@ -3354,7 +3461,7 @@ int whisper_set_mel_with_state(
3354
3461
  int n_len,
3355
3462
  int n_mel) {
3356
3463
  if (n_mel != ctx->model.filters.n_mel) {
3357
- log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3464
+ WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3358
3465
  return -1;
3359
3466
  }
3360
3467
 
@@ -3378,7 +3485,7 @@ int whisper_set_mel(
3378
3485
 
3379
3486
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3380
3487
  if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3381
- log("%s: failed to eval\n", __func__);
3488
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3382
3489
  return -1;
3383
3490
  }
3384
3491
 
@@ -3387,7 +3494,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3387
3494
 
3388
3495
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3389
3496
  if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3390
- log("%s: failed to eval\n", __func__);
3497
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3391
3498
  return -1;
3392
3499
  }
3393
3500
 
@@ -3395,10 +3502,12 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3395
3502
  }
3396
3503
 
3397
3504
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3398
- const int selected_decoder_id = 0;
3505
+ whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
3506
+
3507
+ whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
3399
3508
 
3400
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3401
- log("%s: failed to eval\n", __func__);
3509
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
3510
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3402
3511
  return 1;
3403
3512
  }
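
The legacy public API keeps its shape by bridging into the batch path: whisper_batch_prep_legacy builds a single-sequence batch (seq_id 0) at positions n_past .. n_past + n_tokens - 1, and whisper_kv_cache_seq_rm first evicts that sequence's cells at position >= n_past (p1 = -1 reading as "to the end"), which reproduces the old overwrite-from-n_past behaviour. Callers see no difference:

    #include <vector>
    #include "whisper.h"

    // unchanged usage: decode one SOT token at cache position 0
    static int decode_sot(struct whisper_context * ctx) {
        std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
        return whisper_decode(ctx, prompt.data(), (int) prompt.size(),
                              /*n_past =*/ 0, /*n_threads =*/ 4);
    }
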
3404
3513
 
@@ -3406,27 +3515,19 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3406
3515
  }
3407
3516
 
3408
3517
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3409
- // TODO: add selected_decoder_id to state
3410
- const int selected_decoder_id = 0;
3411
-
3412
3518
  if (ctx->state == nullptr) {
3413
- log("%s: ERROR state was not loaded.\n", __func__);
3414
- return false;
3415
- }
3416
-
3417
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3418
- log("%s: failed to eval\n", __func__);
3419
- return 1;
3519
+ WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
3520
+ return -1;
3420
3521
  }
3421
3522
 
3422
- return 0;
3523
+ return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
3423
3524
  }
3424
3525
 
3425
3526
  int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
3426
3527
  const auto res = tokenize(ctx->vocab, text);
3427
3528
 
3428
3529
  if (n_max_tokens < (int) res.size()) {
3429
- log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3530
+ WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3430
3531
  return -1;
3431
3532
  }
3432
3533
 
@@ -3454,7 +3555,7 @@ int whisper_lang_id(const char * lang) {
3454
3555
  }
3455
3556
  }
3456
3557
 
3457
- log("%s: unknown language '%s'\n", __func__, lang);
3558
+ WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
3458
3559
  return -1;
3459
3560
  }
3460
3561
  return g_lang.at(lang).first;
@@ -3467,7 +3568,18 @@ const char * whisper_lang_str(int id) {
3467
3568
  }
3468
3569
  }
3469
3570
 
3470
- log("%s: unknown language id %d\n", __func__, id);
3571
+ WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
3572
+ return nullptr;
3573
+ }
3574
+
3575
+ const char * whisper_lang_str_full(int id) {
3576
+ for (const auto & kv : g_lang) {
3577
+ if (kv.second.first == id) {
3578
+ return kv.second.second.c_str();
3579
+ }
3580
+ }
3581
+
3582
+ WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
3471
3583
  return nullptr;
3472
3584
  }
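
whisper_lang_str_full is new public API next to whisper_lang_str: the existing function maps a language id back to its short code, while the new one returns the full name stored as the second member of g_lang's value. Expected usage; the concrete strings follow g_lang's code -> { id, full name } mapping:

    const int id = whisper_lang_id("en");           // -1 if unknown
    const char * code = whisper_lang_str(id);       // expected: "en"
    const char * name = whisper_lang_str_full(id);  // expected: "english"
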
3473
3585
 
@@ -3480,29 +3592,29 @@ int whisper_lang_auto_detect_with_state(
3480
3592
  const int seek = offset_ms/10;
3481
3593
 
3482
3594
  if (seek < 0) {
3483
- log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3595
+ WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3484
3596
  return -1;
3485
3597
  }
3486
3598
 
3487
3599
  if (seek >= state->mel.n_len_org) {
3488
- log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3600
+ WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3489
3601
  return -2;
3490
3602
  }
3491
3603
 
3492
3604
  // run the encoder
3493
3605
  if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
3494
- log("%s: failed to encode\n", __func__);
3606
+ WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
3495
3607
  return -6;
3496
3608
  }
3497
3609
 
3498
3610
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
3499
3611
 
3500
3612
  if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
3501
- log("%s: failed to decode\n", __func__);
3613
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
3502
3614
  return -7;
3503
3615
  }
3504
3616
 
3505
- auto & logits_id = state->logits_id;
3617
+ auto & logits_id = state->decoders[0].logits_id;
3506
3618
  logits_id.clear();
3507
3619
 
3508
3620
  for (const auto & kv : g_lang) {
@@ -3698,28 +3810,31 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3698
3810
  void whisper_print_timings(struct whisper_context * ctx) {
3699
3811
  const int64_t t_end_us = wsp_ggml_time_us();
3700
3812
 
3701
- log("\n");
3702
- log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3813
+ WHISPER_LOG_INFO("\n");
3814
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3703
3815
  if (ctx->state != nullptr) {
3704
3816
 
3705
3817
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
3706
3818
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
3707
3819
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3820
+ const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
3708
3821
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3709
3822
 
3710
- log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3711
- log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3712
- log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3713
- log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3714
- log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3715
- log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3823
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3824
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3825
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3826
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3827
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3828
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
3829
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3716
3830
  }
3717
- log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3831
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3718
3832
  }
3719
3833
 
3720
3834
  void whisper_reset_timings(struct whisper_context * ctx) {
3721
3835
  ctx->t_start_us = wsp_ggml_time_us();
3722
3836
  if (ctx->state != nullptr) {
3837
+ ctx->state->t_mel_us = 0;
3723
3838
  ctx->state->t_sample_us = 0;
3724
3839
  ctx->state->t_encode_us = 0;
3725
3840
  ctx->state->t_decode_us = 0;
@@ -3727,6 +3842,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
3727
3842
  ctx->state->n_sample = 0;
3728
3843
  ctx->state->n_encode = 0;
3729
3844
  ctx->state->n_decode = 0;
3845
+ ctx->state->n_batchd = 0;
3730
3846
  ctx->state->n_prompt = 0;
3731
3847
  }
3732
3848
  }
@@ -3765,12 +3881,431 @@ const char * whisper_print_system_info(void) {
3765
3881
  s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
3766
3882
  s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
3767
3883
  s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
3884
+ s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cublas()) + " | ";
3768
3885
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3769
3886
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3770
3887
 
3771
- return s.c_str();
3888
+ return s.c_str();
3889
+ }
3890
+
3891
+ //////////////////////////////////
3892
+ // Grammar - ported from llama.cpp
3893
+ //////////////////////////////////
3894
+
3895
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as a
3896
+ // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
3897
+ std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
3898
+ const char * src,
3899
+ whisper_partial_utf8 partial_start) {
3900
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
3901
+ const char * pos = src;
3902
+ std::vector<uint32_t> code_points;
3903
+ uint32_t value = partial_start.value;
3904
+ int n_remain = partial_start.n_remain;
3905
+
3906
+ // continue previous decode, if applicable
3907
+ while (*pos != 0 && n_remain > 0) {
3908
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
3909
+ if ((next_byte >> 6) != 2) {
3910
+ // invalid sequence, abort
3911
+ code_points.push_back(0);
3912
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
3913
+ }
3914
+ value = (value << 6) + (next_byte & 0x3F);
3915
+ ++pos;
3916
+ --n_remain;
3917
+ }
3918
+
3919
+ if (partial_start.n_remain > 0 && n_remain == 0) {
3920
+ code_points.push_back(value);
3921
+ }
3922
+
3923
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
3924
+ while (*pos != 0) {
3925
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
3926
+ uint8_t highbits = first_byte >> 4;
3927
+ n_remain = lookup[highbits] - 1;
3928
+
3929
+ if (n_remain < 0) {
3930
+ // invalid sequence, abort
3931
+ code_points.clear();
3932
+ code_points.push_back(0);
3933
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
3934
+ }
3935
+
3936
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
3937
+ value = first_byte & mask;
3938
+ ++pos;
3939
+ while (*pos != 0 && n_remain > 0) {
3940
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
3941
+ ++pos;
3942
+ --n_remain;
3943
+ }
3944
+ if (n_remain == 0) {
3945
+ code_points.push_back(value);
3946
+ }
3947
+ }
3948
+ code_points.push_back(0);
3949
+
3950
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
3951
+ }
3952
+
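
decode_utf8 exists because grammar matching operates on code points while token boundaries can split a multi-byte UTF-8 character: the whisper_partial_utf8 threaded between calls holds the bits accumulated so far (value) and the number of continuation bytes still owed (n_remain). A worked two-call example; "é" is the two bytes 0xC3 0xA9, split here at a token boundary, and the commented results follow from tracing the function above:

    whisper_partial_utf8 p0 = { 0, 0 };     // clean start: nothing pending
    auto r1 = decode_utf8("\xC3", p0);
    // r1.first  == { 0 }                   (terminator only; no code point yet)
    // r1.second == { value = 0x03, n_remain = 1 }  (high bits parked, 1 byte owed)

    auto r2 = decode_utf8("\xA9", r1.second);
    // r2.first  == { 0xE9, 0 }             (U+00E9 completed, then terminator)
    // r2.second == { value = 0xE9, n_remain = 0 }
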
3953
+ // returns true iff pos points to the end of one of the definitions of a rule
3954
+ static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
3955
+ switch (pos->type) {
3956
+ case WHISPER_GRETYPE_END: return true; // NOLINT
3957
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
3958
+ default: return false;
3959
+ }
3960
+ }
3961
+
3962
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
3963
+ // asserts that pos is pointing to a char range element
3964
+ static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
3965
+ const whisper_grammar_element * pos,
3966
+ const uint32_t chr) {
3967
+
3968
+ bool found = false;
3969
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
3970
+
3971
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
3972
+
3973
+ do {
3974
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
3975
+ // inclusive range, e.g. [a-z]
3976
+ found = found || (pos->value <= chr && chr <= pos[1].value);
3977
+ pos += 2;
3978
+ } else {
3979
+ // exact char match, e.g. [a] or "a"
3980
+ found = found || pos->value == chr;
3981
+ pos += 1;
3982
+ }
3983
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
3984
+
3985
+ return std::make_pair(found == is_positive_char, pos);
3986
+ }
3987
+
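
whisper_grammar_match_char walks a compact encoding: a CHAR (or CHAR_NOT) element optionally followed by CHAR_RNG_UPPER forms an inclusive range, and CHAR_ALT chains further alternatives; the match succeeds when any alternative hits, inverted for CHAR_NOT. For example, a GBNF character class like [a-zA-Z0] is laid out as follows (layout per the llama.cpp grammar port this code comes from):

    // [a-zA-Z0] as a whisper_grammar_element sequence:
    const whisper_grammar_element cls[] = {
        { WHISPER_GRETYPE_CHAR,           'a' },  // range start: a-z
        { WHISPER_GRETYPE_CHAR_RNG_UPPER, 'z' },
        { WHISPER_GRETYPE_CHAR_ALT,       'A' },  // next alternative: A-Z
        { WHISPER_GRETYPE_CHAR_RNG_UPPER, 'Z' },
        { WHISPER_GRETYPE_CHAR_ALT,       '0' },  // final alternative: exact '0'
        { WHISPER_GRETYPE_END,            0   },
    };
    // whisper_grammar_match_char(cls, 'Q').first == true
    // whisper_grammar_match_char(cls, '!').first == false
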
3988
+ // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
3989
+ // range at pos (regular or inverse range)
3990
+ // asserts that pos is pointing to a char range element
3991
+ static bool whisper_grammar_match_partial_char(
3992
+ const whisper_grammar_element * pos,
3993
+ const whisper_partial_utf8 partial_utf8) {
3994
+
3995
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
3996
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
3997
+
3998
+ uint32_t partial_value = partial_utf8.value;
3999
+ int n_remain = partial_utf8.n_remain;
4000
+
4001
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
4002
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
4003
+ return false;
4004
+ }
4005
+
4006
+ // range of possible code points this partial UTF-8 sequence could complete to
4007
+ uint32_t low = partial_value << (n_remain * 6);
4008
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
4009
+
4010
+ if (low == 0) {
4011
+ if (n_remain == 2) {
4012
+ low = 1 << 11;
4013
+ } else if (n_remain == 3) {
4014
+ low = 1 << 16;
4015
+ }
4016
+ }
4017
+
4018
+ do {
4019
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
4020
+ // inclusive range, e.g. [a-z]
4021
+ if (pos->value <= high && low <= pos[1].value) {
4022
+ return is_positive_char;
4023
+ }
4024
+ pos += 2;
4025
+ } else {
4026
+ // exact char match, e.g. [a] or "a"
4027
+ if (low <= pos->value && pos->value <= high) {
4028
+ return is_positive_char;
4029
+ }
4030
+ pos += 1;
4031
+ }
4032
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
4033
+
4034
+ return !is_positive_char;
4035
+ }
4036
+
4037
+
4038
+ // transforms a grammar pushdown stack into N possible stacks, all ending
4039
+ // at a character range (terminal element)
4040
+ static void whisper_grammar_advance_stack(
4041
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4042
+ const std::vector<const whisper_grammar_element *> & stack,
4043
+ std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
4044
+
4045
+ if (stack.empty()) {
4046
+ new_stacks.push_back(stack);
4047
+ return;
4048
+ }
4049
+
4050
+ const whisper_grammar_element * pos = stack.back();
4051
+
4052
+ switch (pos->type) {
4053
+ case WHISPER_GRETYPE_RULE_REF: {
4054
+ const size_t rule_id = static_cast<size_t>(pos->value);
4055
+ const whisper_grammar_element * subpos = rules[rule_id].data();
4056
+ do {
4057
+ // init new stack without the top (pos)
4058
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
4059
+ if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
4060
+ // if this rule ref is followed by another element, add that to stack
4061
+ new_stack.push_back(pos + 1);
4062
+ }
4063
+ if (!whisper_grammar_is_end_of_sequence(subpos)) {
4064
+ // if alternate is nonempty, add to stack
4065
+ new_stack.push_back(subpos);
4066
+ }
4067
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
4068
+ while (!whisper_grammar_is_end_of_sequence(subpos)) {
4069
+ // scan to end of alternate def
4070
+ subpos++;
4071
+ }
4072
+ if (subpos->type == WHISPER_GRETYPE_ALT) {
4073
+ // there's another alternate def of this rule to process
4074
+ subpos++;
4075
+ } else {
4076
+ break;
4077
+ }
4078
+ } while (true);
4079
+ break;
4080
+ }
4081
+ case WHISPER_GRETYPE_CHAR:
4082
+ case WHISPER_GRETYPE_CHAR_NOT:
4083
+ new_stacks.push_back(stack);
4084
+ break;
4085
+ default:
4086
+ // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
4087
+ // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
4088
+ // those
4089
+ WHISPER_ASSERT(false);
4090
+ }
4091
+ }
4092
+
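
whisper_grammar_advance_stack is the expansion step of the pushdown automaton: while the top of a stack is a RULE_REF, the stack is forked once per alternate of the referenced rule, recursively, until every surviving stack ends at a terminal (CHAR / CHAR_NOT) or is empty, which means accept. Continuations are carried by the element pointers themselves, not stored on the stack:

    // root ::= "a" | "b" root
    // advancing one stack positioned at `root` yields two, each ending at a terminal:
    //   stack 1: [ &CHAR('a') ]  -- after 'a' matches, the sequence ends: accept
    //   stack 2: [ &CHAR('b') ]  -- after 'b' matches, the walker reaches the
    //                               RULE_REF back to root and expands it again
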
4093
+ // takes a set of possible pushdown stacks on a grammar, which are required to
4094
+ // be positioned at a character range (see `whisper_grammar_advance_stack`), and
4095
+ // produces the N possible stacks if the given char is accepted at those
4096
+ // positions
4097
+ static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
4098
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4099
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
4100
+ const uint32_t chr) {
4101
+
4102
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
4103
+
4104
+ for (const auto & stack : stacks) {
4105
+ if (stack.empty()) {
4106
+ continue;
4107
+ }
4108
+
4109
+ auto match = whisper_grammar_match_char(stack.back(), chr);
4110
+ if (match.first) {
4111
+ const whisper_grammar_element * pos = match.second;
4112
+
4113
+ // update top of stack to next element, if any
4114
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
4115
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
4116
+ new_stack.push_back(pos);
4117
+ }
4118
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
4119
+ }
4120
+ }
4121
+
4122
+ return new_stacks;
4123
+ }
4124
+
4125
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
4126
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4127
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
4128
+ const std::vector<whisper_grammar_candidate> & candidates);
4129
+
4130
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
4131
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4132
+ const std::vector<const whisper_grammar_element *> & stack,
4133
+ const std::vector<whisper_grammar_candidate> & candidates) {
4134
+
4135
+ std::vector<whisper_grammar_candidate> rejects;
4136
+
4137
+ if (stack.empty()) {
4138
+ for (auto tok : candidates) {
4139
+ if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
4140
+ rejects.push_back(tok);
4141
+ }
4142
+ }
4143
+ return rejects;
4144
+ }
4145
+
4146
+ const whisper_grammar_element * stack_pos = stack.back();
4147
+
4148
+ std::vector<whisper_grammar_candidate> next_candidates;
4149
+ for (auto tok : candidates) {
4150
+ if (*tok.code_points == 0) {
4151
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
4152
+ // that cannot satisfy this position in grammar
4153
+ if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
4154
+ rejects.push_back(tok);
4155
+ }
4156
+ } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
4157
+ next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
4158
+ } else {
4159
+ rejects.push_back(tok);
4160
+ }
4161
+ }
4162
+
4163
+ const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
4164
+
4165
+ // update top of stack to next element, if any
4166
+ std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
4167
+ if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
4168
+ stack_after.push_back(stack_pos_after);
4169
+ }
4170
+ std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
4171
+ whisper_grammar_advance_stack(rules, stack_after, next_stacks);
4172
+
4173
+ auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
4174
+ for (auto tok : next_rejects) {
4175
+ rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
4176
+ }
4177
+
4178
+ return rejects;
4179
+ }
4180
+
4181
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
4182
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4183
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
4184
+ const std::vector<whisper_grammar_candidate> & candidates) {
4185
+ if (candidates.empty() || stacks.empty()) {
4186
+ return std::vector<whisper_grammar_candidate>();
4187
+ }
4188
+
4189
+ auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
4190
+
4191
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
4192
+ rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
4193
+ }
4194
+ return rejects;
4195
+ }
4196
+
4197
+ static struct whisper_grammar whisper_grammar_init(
4198
+ const whisper_grammar_element ** rules,
4199
+ size_t n_rules,
4200
+ size_t i_start_rule) {
4201
+ const whisper_grammar_element * pos;
4202
+
4203
+ // copy rule definitions into vectors
4204
+ std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
4205
+ for (size_t i = 0; i < n_rules; i++) {
4206
+ for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
4207
+ vec_rules[i].push_back(*pos);
4208
+ }
4209
+ vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
4210
+ }
4211
+
4212
+ // loop over alternates of start rule to build initial stacks
4213
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
4214
+ pos = rules[i_start_rule];
4215
+ do {
4216
+ std::vector<const whisper_grammar_element *> stack;
4217
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
4218
+ // if alternate is nonempty, add to stack
4219
+ stack.push_back(pos);
4220
+ }
4221
+ whisper_grammar_advance_stack(vec_rules, stack, stacks);
4222
+ while (!whisper_grammar_is_end_of_sequence(pos)) {
4223
+ // scan to end of alternate def
4224
+ pos++;
4225
+ }
4226
+ if (pos->type == WHISPER_GRETYPE_ALT) {
4227
+ // there's another alternate def of this rule to process
4228
+ pos++;
4229
+ } else {
4230
+ break;
4231
+ }
4232
+ } while (true);
4233
+
4234
+ return { std::move(vec_rules), std::move(stacks), {} };
4235
+ }
4236
+
4237
+ static void whisper_suppress_invalid_grammar(
4238
+ whisper_context & ctx,
4239
+ const whisper_full_params & params,
4240
+ std::vector<float> & logits,
4241
+ const whisper_grammar & grammar) {
4242
+
4243
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
4244
+ return;
4245
+ }
4246
+
4247
+ //bool allow_eot = false;
4248
+ //for (const auto & stack : grammar.stacks) {
4249
+ // if (stack.empty()) {
4250
+ // allow_eot = true;
4251
+ // break;
4252
+ // }
4253
+ //}
4254
+
4255
+ const whisper_token eot = whisper_token_eot(&ctx);
4256
+
4257
+ std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
4258
+ std::vector<whisper_grammar_candidate> candidates_grammar;
4259
+
4260
+ for (whisper_token id = 0; id < eot; ++id) {
4261
+ const std::string & text = ctx.vocab.id_to_token[id];
4262
+ if (!text.empty()) {
4263
+ candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
4264
+ candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
4265
+ }
4266
+ }
4267
+
4268
+ const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
4269
+
4270
+ for (const auto & reject : rejects) {
4271
+ logits[reject.id] -= params.grammar_penalty;
4272
+ }
4273
+
4274
+ // when the grammar allows a continuation, we penalize the end-of-text token
4275
+ //if (!allow_eot) {
4276
+ // logits[eot] -= params.grammar_penalty;
4277
+ //}
4278
+ //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
4279
+ }
4280
+
4281
+ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
4282
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
4283
+ return;
4284
+ }
4285
+
4286
+ //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
4287
+
4288
+ const std::string & text = ctx.vocab.id_to_token[token];
4289
+
4290
+ if (text.rfind("[_", 0) == 0) {
4291
+ // fprintf(stderr, " (skipped)\n");
4292
+ return;
4293
+ }
4294
+ // fprintf(stderr, "\n");
4295
+
4296
+ // Note terminating 0 in decoded string
4297
+ const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
4298
+ const auto & code_points = decoded.first;
4299
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
4300
+ grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
4301
+ }
4302
+ grammar.partial_utf8 = decoded.second;
3772
4303
  }
3773
4304
 
4305
+ //////////////
4306
+ // END grammar
4307
+ //////////////
4308
+
3774
4309
  ////////////////////////////////////////////////////////////////////////////
3775
4310
 
3776
4311
  struct whisper_context_params * whisper_context_default_params_by_ref() {
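To make the grammar machinery above concrete, here is a minimal sketch of how a caller could wire a grammar into decoding. The element layout (a CHAR element followed by CHAR_RNG_UPPER to form an inclusive range, END as terminator) follows the conventions the functions above consume, and the params fields are the ones added in this version; treat the exact element encoding as an assumption rather than a documented contract.

    // illustrative only - restrict decoded text to the digits 0-9
    static const whisper_grammar_element digit_rule[] = {
        { WHISPER_GRETYPE_CHAR,           '0' }, // start of character range
        { WHISPER_GRETYPE_CHAR_RNG_UPPER, '9' }, // inclusive upper bound of range
        { WHISPER_GRETYPE_END,            0   }, // terminates this rule
    };

    static const whisper_grammar_element * rules[] = { digit_rule };

    whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    params.grammar_rules   = rules;
    params.n_grammar_rules = 1;
    params.i_start_rule    = 0;      // index of the root rule
    params.grammar_penalty = 100.0f; // subtracted from logits of rejected tokens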
@@ -3800,6 +4335,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
  /*.translate =*/ false,
  /*.no_context =*/ true,
+ /*.no_timestamps =*/ false,
  /*.single_segment =*/ false,
  /*.print_special =*/ false,
  /*.print_progress =*/ true,
@@ -3833,7 +4369,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
  /*.max_initial_ts =*/ 1.0f,
  /*.length_penalty =*/ -1.0f,
 
- /*.temperature_inc =*/ 0.4f,
+ /*.temperature_inc =*/ 0.2f,
  /*.entropy_thold =*/ 2.4f,
  /*.logprob_thold =*/ -1.0f,
  /*.no_speech_thold =*/ 0.6f,
@@ -3862,19 +4398,24 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
  /*.logits_filter_callback =*/ nullptr,
  /*.logits_filter_callback_user_data =*/ nullptr,
+
+ /*.grammar_rules =*/ nullptr,
+ /*.n_grammar_rules =*/ 0,
+ /*.i_start_rule =*/ 0,
+ /*.grammar_penalty =*/ 100.0f,
  };
 
  switch (strategy) {
  case WHISPER_SAMPLING_GREEDY:
  {
  result.greedy = {
- /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+ /*.best_of =*/ 5,
  };
  } break;
  case WHISPER_SAMPLING_BEAM_SEARCH:
  {
  result.beam_search = {
- /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+ /*.beam_size =*/ 5,
 
  /*.patience =*/ -1.0f,
  };
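One practical consequence of the temperature_inc change: elsewhere in whisper_full_with_state the fallback temperature ladder is built from these two fields, roughly as in the sketch below (a paraphrase of the upstream loop, not a verbatim quote). With an increment of 0.2f the decoder gets 6 attempts (0.0, 0.2, ..., 1.0) before giving up, instead of 3 with the old 0.4f step.

    // sketch of the fallback ladder implied by these defaults
    std::vector<float> temperatures;
    if (params.temperature_inc > 0.0f) {
        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
            temperatures.push_back(t); // 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
        }
    } else {
        temperatures.push_back(params.temperature);
    }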
@@ -3964,11 +4505,12 @@ static const std::vector<std::string> non_speech_tokens = {
  // process the logits for the selected decoder
  // - applies logit filters
  // - computes logprobs and probs
+ // TODO: optimize
  static void whisper_process_logits(
  struct whisper_context & ctx,
  struct whisper_state & state,
- const struct whisper_full_params params,
  struct whisper_decoder & decoder,
+ const struct whisper_full_params params,
  float temperature) {
  const auto & vocab = ctx.vocab;
  const auto & tokens_cur = decoder.sequence.tokens;
@@ -3985,7 +4527,7 @@ static void whisper_process_logits(
  auto & logprobs = decoder.logprobs;
  {
  logits.resize(n_logits);
- memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+ memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
  if (temperature > 0.0f) {
  for (int i = 0; i < n_logits; i++) {
@@ -4013,6 +4555,11 @@ static void whisper_process_logits(
  // suppress <|notimestamps|> token
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
  logits[vocab.token_not] = -INFINITY;
+ if (params.no_timestamps) {
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
+ logits[i] = -INFINITY;
+ }
+ }
 
  // suppress sot and nosp tokens
  logits[vocab.token_sot] = -INFINITY;
@@ -4028,6 +4575,14 @@ static void whisper_process_logits(
  logits[vocab.token_transcribe] = -INFINITY;
  logits[vocab.token_prev] = -INFINITY;
 
+ // suppress lang tokens
+ for (size_t i = 0; i < g_lang.size(); ++i) {
+ logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+ }
+
+ // suppress prev token
+ logits[vocab.token_prev] = -INFINITY;
+
  if (params.logits_filter_callback) {
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
  }
@@ -4059,7 +4614,7 @@ static void whisper_process_logits(
  const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
  const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
- //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+ //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
  if (last_was_timestamp) {
  if (penultimate_was_timestamp) {
@@ -4135,13 +4690,37 @@ static void whisper_process_logits(
 
  const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
- //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+ //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
  if (timestamp_logprob > max_text_token_logprob) {
  for (int i = 0; i < vocab.token_beg; ++i) {
  logits[i] = -INFINITY;
  logprobs[i] = -INFINITY;
  }
+ } else {
+ if (params.n_grammar_rules > 0) {
+ whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+ // populate the logprobs array (log_softmax)
+ {
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
+ float logsumexp = 0.0f;
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logsumexp += expf(logits[i] - logit_max);
+ }
+ }
+ logsumexp = logf(logsumexp) + logit_max;
+
+ for (int i = 0; i < n_logits; ++i) {
+ if (logits[i] > -INFINITY) {
+ logprobs[i] = logits[i] - logsumexp;
+ } else {
+ logprobs[i] = -INFINITY;
+ }
+ }
+ }
+ }
  }
  }
  }
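For reference, the log_softmax block added above computes the standard numerically stable form

    logprob[i] = logits[i] - (m + log(sum_j exp(logits[j] - m))),  where m = max_j logits[j]

Subtracting the maximum before exponentiating keeps expf from overflowing, and entries already forced to -INFINITY by the grammar penalty and the earlier filters are simply excluded from the sum, so they stay at -INFINITY in the logprobs as well.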
@@ -4159,38 +4738,60 @@ static void whisper_process_logits(
 
  #if 0
  // print first 100 logits - token string : logit
- for (int i = 0; i < 100; i++) {
- const auto token = vocab.id_to_token.at(i);
- const auto prob = probs[i];
- const auto logit = logits[i];
- const auto logprob = logprobs[i];
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //for (int i = 0; i < 10; i++) {
+ // const auto token = vocab.id_to_token.at(i);
+ // const auto prob = probs[i];
+ // const auto logit = logits[i];
+ // const auto logprob = logprobs[i];
+ // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+ //}
+
+ // print sorted
+ {
+ std::vector<std::pair<float, int>> pairs;
+
+ for (int i = 0; i < n_logits; ++i) {
+ pairs.push_back(std::make_pair(probs[i], i));
+ }
+
+ std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+ return a.first > b.first;
+ });
+
+ for (int i = 0; i < 10; i++) {
+ const auto token = vocab.id_to_token.at(pairs[i].second);
+ const auto prob = pairs[i].first;
+ const auto logit = logits[pairs[i].second];
+ const auto logprob = logprobs[pairs[i].second];
+ printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+ }
+
+ printf("----------------\n");
  }
 
  // "And", "and", " And", " and"
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
-
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+ //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+ //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+ //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+ //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+ //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+ //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+ //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+ //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+ //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+ //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+ //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+ //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+ //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+ //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+ //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
  #endif
  }
 
  static whisper_token_data whisper_sample_token(
  whisper_context & ctx,
- whisper_state & state,
  const whisper_decoder & decoder,
  bool best) {
  whisper_token_data result = {
@@ -4235,7 +4836,7 @@ static whisper_token_data whisper_sample_token(
  } else {
  std::discrete_distribution<> dist(probs.begin(), probs.end());
 
- result.id = dist(state.rng);
+ result.id = dist(decoder.rng);
  result.p = probs[result.id];
  result.plog = logprobs[result.id];
  }
@@ -4245,15 +4846,12 @@ static whisper_token_data whisper_sample_token(
  result.pt = result.p;
  }
 
- state.n_sample++;
-
  return result;
  }
 
  static std::vector<whisper_token_data> whisper_sample_token_topk(
  whisper_context & ctx,
- whisper_state & state,
- const whisper_decoder & decoder,
+ whisper_decoder & decoder,
  int k) {
  const auto & vocab = ctx.vocab;
 
@@ -4263,7 +4861,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
  const int n_logits = vocab.n_vocab;
 
- auto & logits_id = state.logits_id;
+ auto & logits_id = decoder.logits_id;
 
  logits_id.resize(n_logits);
  for (int i = 0; i < n_logits; ++i) {
@@ -4309,8 +4907,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
  ptsum = sum_ts;
  }
 
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
  for (int i = 0; i < k; ++i) {
- const auto id = logits_id[i].second;
+ const auto id = dist(decoder.rng);
+ //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
  result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
@@ -4320,8 +4921,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
  }
  }
 
- state.n_sample++;
-
  return result;
  }
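The sampling primitive both functions now rely on is std::discrete_distribution driven by a per-decoder std::mt19937. The following self-contained sketch (not part of the package) shows the mechanism in isolation: the distribution draws index i with probability proportional to probs[i].

    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        std::mt19937 rng(0); // each decoder now carries its own seeded rng
        std::vector<float> probs = { 0.1f, 0.6f, 0.3f };

        std::discrete_distribution<> dist(probs.begin(), probs.end());

        int counts[3] = { 0, 0, 0 };
        for (int i = 0; i < 10000; ++i) {
            counts[dist(rng)]++; // draws index i with probability probs[i]
        }
        printf("%d %d %d\n", counts[0], counts[1], counts[2]); // roughly 1000 6000 3000
        return 0;
    }

Giving every decoder its own generator (instead of the shared state.rng removed here) keeps the parallel sampling paths below deterministic and free of data races.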
 
@@ -4374,115 +4973,6 @@ static void whisper_sequence_score(
  }
  }
 
- static bool whisper_kv_swap_fast(
- std::vector<int> & view,
- whisper_decoder src[],
- std::vector<kv_buf> & kv_swap_bufs,
- const int & n_decoders) {
- WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
- // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
- std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
- // (buffer->decoder or decoder->decoder)
- std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
- // (decoder<->decoder)
- std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
- std::vector<whisper_pair<int, int>> p_swap_vec;
- p_swap_vec.reserve(n_decoders);
-
- // see https://github.com/ggerganov/whisper.cpp/wiki
- for (int i = 0; i < n_decoders; i++) {
- // zero-copy (no modification)
- if (i == view[i] || view[i] < 0) {
- continue;
- }
-
- bool is_one_copy = true;
- // since we modify data sequentially, we only consider decoder indices after current index
- for (int j = i + 1; j < n_decoders; j++) {
- if (i == view[j]) {
- // detect symmetric diagram
- if (j == view[i]) {
- p_swap_set.insert(i);
- p_swap_set.insert(j);
- p_swap_vec.emplace_back(i, j);
- } else {
- two_copy.insert(i);
- is_one_copy = false;
- }
- break;
- }
- }
- if (is_one_copy) {
- one_copy.insert(i);
- }
- }
-
- kv_swap_bufs.resize(n_decoders);
-
- for (int i = 0; i < n_decoders; i++) {
- kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
- kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
- }
-
- for (auto & i : two_copy) {
- // make a copy of KV caches
- WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
- memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
- memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
- }
-
- // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
- for (auto & i : two_copy) {
- // skip the decoder indices that require pointer swapping
- if (p_swap_set.find(i) != p_swap_set.end()) {
- continue;
- }
-
- if (two_copy.find(view[i]) != two_copy.end()) {
- // modify KV caches of decoder using data from kv_swap_bufs
- WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
- memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
- memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
- } else {
- // modify KV caches of decoder using data from correspond decoder KV caches directly
- WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
- memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
- memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
- }
- }
-
- // then modify one-copy decoder KV caches
- for (auto & i : one_copy) {
- // skip the decoder indices that require pointer swapping
- if (p_swap_set.find(i) != p_swap_set.end()) {
- continue;
- }
-
- if (two_copy.find(view[i]) != two_copy.end()) {
- // modify KV caches of decoder using data from kv_swap_bufs
- WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
- memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
- memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
- } else {
- // modify KV caches of decoder using data from correspond decoder KV caches directly
- WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
- memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
- memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
- }
- }
-
- // swap the pointers
- for (auto & i : p_swap_vec) {
- WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
- std::swap(src[i.first].kv_self, src[i.second].kv_self);
- }
-
- return true;
- }
-
  int whisper_full_with_state(
  struct whisper_context * ctx,
  struct whisper_state * state,
@@ -4498,11 +4988,11 @@ int whisper_full_with_state(
  // compute log mel spectrogram
  if (params.speed_up) {
  // TODO: Replace PV with more advanced algorithm
- log("%s: failed to compute log mel spectrogram\n", __func__);
+ WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
  return -1;
  } else {
  if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
- log("%s: failed to compute log mel spectrogram\n", __func__);
+ WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
  return -2;
  }
  }
@@ -4514,13 +5004,13 @@ int whisper_full_with_state(
 
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
  if (lang_id < 0) {
- log("%s: failed to auto-detect language\n", __func__);
+ WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
  return -3;
  }
  state->lang_id = lang_id;
  params.language = whisper_lang_str(lang_id);
 
- log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+ WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
  if (params.detect_language) {
  return 0;
  }
@@ -4542,6 +5032,7 @@ int whisper_full_with_state(
  // basically don't process anything that is less than 1.0s
  // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
  if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
+ WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
  return 0;
  }
 
@@ -4572,42 +5063,23 @@ int whisper_full_with_state(
 
  n_decoders = std::max(1, n_decoders);
 
+ if (n_decoders > WHISPER_MAX_DECODERS) {
+ WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+ return -4;
+ }
+
  // TAGS: WHISPER_DECODER_INIT
  for (int j = 1; j < n_decoders; j++) {
  auto & decoder = state->decoders[j];
 
- if (decoder.kv_self.ctx == nullptr) {
- decoder.kv_self = state->decoders[0].kv_self;
- if (!kv_cache_reinit(decoder.kv_self)) {
- log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
- return -4;
- }
-
- WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
-
- decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
-
- decoder.probs.resize (ctx->vocab.n_vocab);
- decoder.logits.resize (ctx->vocab.n_vocab);
- decoder.logprobs.resize(ctx->vocab.n_vocab);
+ decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
- // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
- #ifdef WSP_GGML_USE_METAL
- if (state->ctx_metal) {
- #define WHISPER_METAL_CHECK_BUF(result) \
- if (!(result)) { \
- log("%s: failed to add metal buffer\n", __func__); \
- return 0; \
- }
+ decoder.probs.resize (ctx->vocab.n_vocab);
+ decoder.logits.resize (ctx->vocab.n_vocab);
+ decoder.logprobs.resize(ctx->vocab.n_vocab);
+ decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
- const std::string kv_name = "kv_self_" + std::to_string(j);
- auto & kv_self = decoder.kv_self;
-
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
- #undef WHISPER_METAL_CHECK_BUF
- }
- #endif
- }
+ decoder.rng = std::mt19937(0);
  }
 
  // the accumulated text context so far
@@ -4640,13 +5112,13 @@ int whisper_full_with_state(
 
  // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
  if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
- log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+ WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
  return -5;
  }
  state->exp_n_audio_ctx = params.audio_ctx;
 
  // these tokens determine the task that will be performed
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
 
  if (whisper_is_multilingual(ctx)) {
  const int lang_id = whisper_lang_id(params.language);
@@ -4659,17 +5131,19 @@ int whisper_full_with_state(
  }
  }
 
+ // distilled models require the "no_timestamps" token
  {
  const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
- // distilled models require the "no_timestamps" token
- // TODO: add input parameter (#1229)
- if (is_distil) {
- log("%s: using distilled model - forcing no_timestamps\n", __func__);
- prompt_init.push_back(whisper_token_not(ctx));
+ if (is_distil && !params.no_timestamps) {
+ WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
+ params.no_timestamps = true;
  }
  }
 
+ if (params.no_timestamps) {
+ prompt_init.push_back(whisper_token_not(ctx));
+ }
+
  int seek = seek_start;
 
  std::vector<whisper_token> prompt;
@@ -4682,8 +5156,10 @@ int whisper_full_with_state(
  bool has_ts;
 
  whisper_sequence sequence;
+ whisper_grammar grammar;
  };
 
+ std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
  std::vector<beam_candidate> beam_candidates;
 
  // main loop
@@ -4692,24 +5168,24 @@ int whisper_full_with_state(
  const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
 
  params.progress_callback(
- ctx, ctx->state, progress_cur, params.progress_callback_user_data);
+ ctx, state, progress_cur, params.progress_callback_user_data);
  }
 
- // of only 1 second left, then stop
+ // if only 1 second left, then stop
  if (seek + 100 >= seek_end) {
  break;
  }
 
  if (params.encoder_begin_callback) {
  if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
- log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+ WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
  break;
  }
  }
 
  // encode audio features starting at offset seek
  if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
- log("%s: failed to encode\n", __func__);
+ WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
  return -6;
  }
 
@@ -4745,14 +5221,12 @@ int whisper_full_with_state(
 
  n_decoders_cur = std::max(1, n_decoders_cur);
 
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+ WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
 
  // TAGS: WHISPER_DECODER_INIT
  for (int j = 0; j < n_decoders_cur; ++j) {
  auto & decoder = state->decoders[j];
 
- decoder.kv_self.n = 0;
-
  decoder.sequence.tokens.clear();
  decoder.sequence.result_len = 0;
  decoder.sequence.sum_logprobs_all = 0.0;
@@ -4766,10 +5240,16 @@ int whisper_full_with_state(
  decoder.failed = false;
  decoder.completed = false;
  decoder.has_ts = false;
+
+ if (params.grammar_rules != nullptr) {
+ decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+ } else {
+ decoder.grammar = {};
+ }
  }
 
  // init prompt and kv cache for the current iteration
- // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+ // TODO: do not recompute the prompt if it is the same as previous time
  {
  prompt.clear();
 
@@ -4791,25 +5271,26 @@ int whisper_full_with_state(
  }
  WHISPER_PRINT_DEBUG("\n\n");
 
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
- log("%s: failed to decode\n", __func__);
+ whisper_kv_cache_clear(state->kv_self);
+
+ whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
  return -7;
  }
 
  {
  const int64_t t_start_sample_us = wsp_ggml_time_us();
 
- whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+ state->decoders[0].i_batch = prompt.size() - 1;
 
- state->decoders[0].kv_self.n += prompt.size();
+ whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
  for (int j = 1; j < n_decoders_cur; ++j) {
  auto & decoder = state->decoders[j];
 
- memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, wsp_ggml_nbytes(decoder.kv_self.k));
- memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
-
- decoder.kv_self.n += prompt.size();
+ whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
  memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
  memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -4824,41 +5305,81 @@ int whisper_full_with_state(
  const int64_t t_start_sample_us = wsp_ggml_time_us();
 
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
- beam_candidates.clear();
+ for (auto & bc : bc_per_dec) {
+ bc.clear();
+ }
  }
 
- // generate new sequence candidates for each decoder
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = state->decoders[j];
+ // sampling
+ // TODO: avoid memory allocations, optimize, avoid threads?
+ {
+ std::atomic<int> j_cur(0);
 
- if (decoder.completed || decoder.failed) {
- continue;
- }
+ auto process = [&]() {
+ while (true) {
+ const int j = j_cur.fetch_add(1);
 
- switch (params.strategy) {
- case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
- {
- if (t_cur < 1e-6f) {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
- } else {
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
- }
+ if (j >= n_decoders_cur) {
+ break;
+ }
 
- decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
- } break;
- case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
- {
- const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+ auto & decoder = state->decoders[j];
 
- for (const auto & token : tokens_new) {
- beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
- beam_candidates.back().sequence.tokens.push_back(token);
- beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
 
- //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
- }
- } break;
+ switch (params.strategy) {
+ case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+ {
+ if (t_cur < 1e-6f) {
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+ } else {
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+ }
+
+ decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+ } break;
+ case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+ for (const auto & token : tokens_new) {
+ bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+ bc_per_dec[j].back().sequence.tokens.push_back(token);
+ bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+ }
+ } break;
+ };
+ }
  };
+
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+ if (n_threads == 1) {
+ process();
+ } else {
+ std::vector<std::thread> threads(n_threads - 1);
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t] = std::thread(process);
+ }
+
+ process();
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t].join();
+ }
+ }
+ }
+
+ beam_candidates.clear();
+ for (const auto & bc : bc_per_dec) {
+ beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+ if (!bc.empty()) {
+ state->n_sample += 1;
+ }
  }
 
  // for beam-search, choose the top candidates and update the KV caches
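The work-distribution pattern introduced above (and reused again for logit processing further down) is an atomic counter rather than a fixed partition: each worker repeatedly claims the next decoder index with fetch_add until the range is exhausted, and the calling thread doubles as the last worker so no thread sits idle. A self-contained sketch of the same pattern, outside the library:

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
        const int n_jobs    = 8;
        const int n_threads = 4;

        std::atomic<int> j_cur(0);

        auto process = [&]() {
            while (true) {
                const int j = j_cur.fetch_add(1); // atomically claim the next job index
                if (j >= n_jobs) {
                    break;
                }
                printf("job %d\n", j); // stands in for per-decoder sampling work
            }
        };

        std::vector<std::thread> threads(n_threads - 1);
        for (int t = 0; t < n_threads - 1; ++t) {
            threads[t] = std::thread(process);
        }
        process(); // the current thread participates as well
        for (int t = 0; t < n_threads - 1; ++t) {
            threads[t].join();
        }
        return 0;
    }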
@@ -4871,7 +5392,6 @@ int whisper_full_with_state(
  });
 
  uint32_t cur_c = 0;
- std::vector<int> decoder_idx(n_decoders_cur, -1);
 
  for (int j = 0; j < n_decoders_cur; ++j) {
  auto & decoder = state->decoders[j];
@@ -4880,23 +5400,38 @@ int whisper_full_with_state(
  continue;
  }
 
+ if (cur_c >= beam_candidates.size()) {
+ cur_c = 0;
+ }
+
  auto & cur = beam_candidates[cur_c++];
 
  while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
  ++cur_c;
  }
 
- decoder.sequence = cur.sequence;
  decoder.seek_delta = cur.seek_delta;
  decoder.has_ts = cur.has_ts;
+ decoder.sequence = cur.sequence;
+ decoder.grammar = cur.grammar;
+
+ whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
- decoder_idx[j] = cur.decoder_idx;
  WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
  __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
  }
 
- // update KV caches
- whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = state->decoders[j];
+
+ if (decoder.completed || decoder.failed) {
+ continue;
+ }
+
+ whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
+ whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+ whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
+ }
  }
 
  // update the decoder state
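The two-phase shuffle above replaces whisper_kv_swap_fast entirely: because several beam candidates can originate from the same source decoder, writing dst = src in place could read a cache that was already overwritten. Staging every source into a scratch sequence id (WHISPER_MAX_DECODERS + j) first makes the copy order irrelevant. The toy program below (illustrative only, with plain ints standing in for KV-cache sequences) shows the aliasing problem the staging step avoids:

    #include <cstdio>

    int main() {
        int cache[4]   = { 10, 11, 12, 13 }; // per-decoder "KV caches"
        int view[4]    = { 1, 0, 0, 2 };     // decoder j continues from cache[view[j]]
        int scratch[4];

        for (int j = 0; j < 4; ++j) {
            scratch[j] = cache[view[j]];     // stage: cp(view[j] -> scratch j)
        }
        for (int j = 0; j < 4; ++j) {
            cache[j] = scratch[j];           // commit: rm(j), cp(scratch j -> j)
        }

        printf("%d %d %d %d\n", cache[0], cache[1], cache[2], cache[3]); // 11 10 10 12
        return 0;
    }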
@@ -4925,6 +5460,7 @@ int whisper_full_with_state(
 
  // do not allow to go back in time
  if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+ WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
  failed = true; // TODO: maybe this is not a failure ?
  continue;
  }
@@ -4934,6 +5470,8 @@ int whisper_full_with_state(
  has_ts = true;
  }
 
+ whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
  #ifdef WHISPER_DEBUG
  {
  const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
@@ -4951,6 +5489,7 @@ int whisper_full_with_state(
  if (seek + seek_delta + 100 >= seek_end) {
  result_len = i + 1;
  } else {
+ WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
  failed = true;
  continue;
  }
@@ -4961,6 +5500,7 @@ int whisper_full_with_state(
  seek_delta = 100*WHISPER_CHUNK_SIZE;
  }
 
+ WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
  completed = true;
  continue;
  }
@@ -4976,6 +5516,7 @@ int whisper_full_with_state(
  // sometimes, the decoding can get stuck in a repetition loop
  // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
  if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+ WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
  failed = true;
  continue;
  }
@@ -5003,32 +5544,83 @@ int whisper_full_with_state(
  state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
 
  // obtain logits for the next token
- for (int j = 0; j < n_decoders_cur; ++j) {
- auto & decoder = state->decoders[j];
+ {
+ auto & batch = state->batch;
 
- if (decoder.failed || decoder.completed) {
- continue;
- }
+ batch.n_tokens = 0;
+
+ const int n_past = prompt.size() + i;
+
+ for (int j = 0; j < n_decoders_cur; ++j) {
+ auto & decoder = state->decoders[j];
+
+ if (decoder.failed || decoder.completed) {
+ continue;
+ }
+
+ //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
 
- decoder.tokens_tmp.resize(1);
- decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+ decoder.i_batch = batch.n_tokens;
+
+ batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
+ batch.pos [batch.n_tokens] = n_past;
+ batch.n_seq_id[batch.n_tokens] = 1;
+ batch.seq_id [batch.n_tokens][0] = j;
+ batch.logits [batch.n_tokens] = 1;
+ batch.n_tokens++;
+ }
 
- //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+ assert(batch.n_tokens > 0);
 
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
- log("%s: failed to decode\n", __func__);
+ if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
  return -8;
  }
 
+ const int64_t t_start_sample_us = wsp_ggml_time_us();
+
+ // TODO: avoid memory allocations, optimize, avoid threads?
  {
- const int64_t t_start_sample_us = wsp_ggml_time_us();
+ std::atomic<int> j_cur(0);
+
+ auto process = [&]() {
+ while (true) {
+ const int j = j_cur.fetch_add(1);
+
+ if (j >= n_decoders_cur) {
+ break;
+ }
+
+ auto & decoder = state->decoders[j];
+
+ if (decoder.failed || decoder.completed) {
+ continue;
+ }
+
+ whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+ }
+ };
+
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
 
- whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+ if (n_threads == 1) {
+ process();
+ } else {
+ std::vector<std::thread> threads(n_threads - 1);
+
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t] = std::thread(process);
+ }
 
- ++decoder.kv_self.n;
+ process();
 
- state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
+ for (int t = 0; t < n_threads - 1; ++t) {
+ threads[t].join();
+ }
+ }
  }
+
+ state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
  }
  }
 
@@ -5068,28 +5660,27 @@ int whisper_full_with_state(
  WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
  }
 
+ bool success = true;
+
  // was the decoding successful for the current temperature?
  // do fallback only if:
  // - we are not at the last temperature
- // - we are not at the end of the audio (3 sec)
- if (it != (int) temperatures.size() - 1 &&
- seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
- bool success = true;
-
+ if (it != (int) temperatures.size() - 1) {
  const auto & decoder = state->decoders[best_decoder_id];
 
  if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
+ WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
  success = false;
  state->n_fail_p++;
  }
+ }
 
- if (success) {
- //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
- // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
- //}
+ if (success) {
+ //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+ // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+ //}
 
- break;
- }
+ break;
  }
 
  WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
@@ -5325,11 +5916,13 @@ int whisper_full_parallel(
  ctx->state->t_sample_us += states[i]->t_sample_us;
  ctx->state->t_encode_us += states[i]->t_encode_us;
  ctx->state->t_decode_us += states[i]->t_decode_us;
+ ctx->state->t_batchd_us += states[i]->t_batchd_us;
  ctx->state->t_prompt_us += states[i]->t_prompt_us;
 
  ctx->state->n_sample += states[i]->n_sample;
  ctx->state->n_encode += states[i]->n_encode;
  ctx->state->n_decode += states[i]->n_decode;
+ ctx->state->n_batchd += states[i]->n_batchd;
  ctx->state->n_prompt += states[i]->n_prompt;
 
  whisper_free_state(states[i]);
@@ -5342,12 +5935,12 @@ int whisper_full_parallel(
  ctx->state->t_decode_us /= n_processors;
 
  // print information about the audio boundaries
- log("\n");
- log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+ WHISPER_LOG_WARN("\n");
+ WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
  for (int i = 0; i < n_processors - 1; ++i) {
- log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+ WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
  }
- log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+ WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
  return ret;
  }
@@ -5462,8 +6055,45 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  size_t n = 20;
  size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
 
- // 1GB MB array
- const size_t size = arr*1024llu*1024llu;
+ // 1GB array
+ const size_t size = arr*1e6;
+
+ double sum = 0.0;
+
+ // heat-up
+ {
+ char * src = (char *) malloc(size);
+ char * dst = (char *) malloc(size);
+
+ for (size_t i = 0; i < size; i++) src[i] = i;
+
+ memcpy(dst, src, size); // heat-up
+
+ double tsum = 0.0;
+
+ for (size_t i = 0; i < n; i++) {
+ const int64_t t0 = wsp_ggml_time_us();
+
+ memcpy(dst, src, size);
+
+ const int64_t t1 = wsp_ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+
+ src[rand() % size] = rand() % 256;
+ }
+
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
+ s += strbuf;
+
+ // needed to prevent the compiler from optimizing the memcpy away
+ {
+ for (size_t i = 0; i < size; i++) sum += dst[i];
+ }
+
+ free(src);
+ free(dst);
+ }
 
  // single-thread
  {
@@ -5475,7 +6105,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  memcpy(dst, src, size); // heat-up
 
  double tsum = 0.0;
- double sum = 0.0;
 
  for (size_t i = 0; i < n; i++) {
  const int64_t t0 = wsp_ggml_time_us();
@@ -5489,21 +6118,73 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  src[rand() % size] = rand() % 256;
  }
 
- snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
  s += strbuf;
 
  // needed to prevent the compiler from optimizing the memcpy away
  {
  for (size_t i = 0; i < size; i++) sum += dst[i];
+ }
+
+ free(src);
+ free(dst);
+ }
+
+ // multi-thread
+
+ for (uint32_t k = 1; k <= n_threads; k++) {
+ char * src = (char *) malloc(size);
+ char * dst = (char *) malloc(size);
 
- snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
- s += strbuf;
+ for (size_t i = 0; i < size; i++) src[i] = i;
+
+ memcpy(dst, src, size); // heat-up
+
+ double tsum = 0.0;
+
+ auto helper = [&](int th) {
+ const int64_t i0 = (th + 0)*size/k;
+ const int64_t i1 = (th + 1)*size/k;
+
+ for (size_t i = 0; i < n; i++) {
+ memcpy(dst + i0, src + i0, i1 - i0);
+
+ src[i0 + rand() % (i1 - i0)] = rand() % 256;
+ };
+ };
+
+ const int64_t t0 = wsp_ggml_time_us();
+
+ std::vector<std::thread> threads(k - 1);
+ for (uint32_t th = 0; th < k - 1; ++th) {
+ threads[th] = std::thread(helper, th);
+ }
+
+ helper(k - 1);
+
+ for (uint32_t th = 0; th < k - 1; ++th) {
+ threads[th].join();
+ }
+
+ const int64_t t1 = wsp_ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
+ s += strbuf;
+
+ // needed to prevent the compiler from optimizing the memcpy away
+ {
+ for (size_t i = 0; i < size; i++) sum += dst[i];
  }
 
  free(src);
  free(dst);
  }
 
+ snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
+ s += strbuf;
+
  return s.c_str();
  }
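A quick check of the bandwidth arithmetic in the rewritten benchmark, assuming the default arr = 1024: the buffer is

    size = arr * 1e6 = 1.024e9 bytes (decimal GB, despite the "1GB" comment)
    bytes moved = n * size = 20 * 1.024e9 = 2.048e10
    reported figure = (n * size) / (tsum * 1e9)  ->  GB/s in decimal units

The old code allocated arr*1024*1024 = 1 GiB and divided by 1024^3, so it actually printed GiB/s under a "GB/s" label; the new version uses decimal units consistently on both sides of the division.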
 
@@ -5589,12 +6270,12 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
  double tsum = 0.0;
 
  // heat-up
- wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+ wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
  for (int i = 0; i < n_max; ++i) {
  const int64_t t0 = wsp_ggml_time_us();
 
- wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+ wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
  const int64_t t1 = wsp_ggml_time_us();
 
@@ -5712,7 +6393,7 @@ static void whisper_exp_compute_token_level_timestamps(
  const int n_samples = state.energy.size();
 
  if (n_samples == 0) {
- log("%s: no signal data available\n", __func__);
+ WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
  return;
  }
 
@@ -5933,6 +6614,32 @@ static void whisper_exp_compute_token_level_timestamps(
  //}
  }
 
- void whisper_set_log_callback(whisper_log_callback callback) {
- whisper_log = callback;
+ void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
+ g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+ g_state.log_callback_user_data = user_data;
+ }
+
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+ static void whisper_log_internal(wsp_ggml_log_level level, const char * format, ...) {
+ va_list args;
+ va_start(args, format);
+ char buffer[1024];
+ int len = vsnprintf(buffer, 1024, format, args);
+ if (len < 1024) {
+ g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+ } else {
+ char* buffer2 = new char[len+1];
+ vsnprintf(buffer2, len+1, format, args);
+ buffer2[len] = 0;
+ g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+ delete[] buffer2;
+ }
+ va_end(args);
+ }
+
+ static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
+ (void) level;
+ (void) user_data;
+ fputs(text, stderr);
+ fflush(stderr);
  }
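A hedged usage sketch for the new logging API above, replacing the removed whisper_set_log_callback: the callback signature mirrors whisper_log_callback_default, and per the implementation shown, passing a null callback restores the default stderr sink. The file-handle plumbing here is an illustration, not part of the package.

    #include <cstdio>

    static void my_log_cb(wsp_ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        FILE * f = (FILE *) user_data; // e.g. an app-owned log file
        fputs(text, f);
        fflush(f);
    }

    // during app init:
    //   whisper_log_set(my_log_cb, log_file);
    // to restore the default stderr sink:
    //   whisper_log_set(nullptr, nullptr);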