whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
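
The headline change in this range is the migration of the vendored whisper.cpp/ggml to the ggml-backend API (new ggml-backend.c/h, ggml-quants.c/h), which routes tensor storage through pluggable CPU/Metal/CUDA backends. A minimal sketch of how a consumer of the vendored native code would opt in to the new GPU path — assuming the whisper_context_params and whisper_init_from_file_with_params() additions visible in the package/cpp/whisper.h diff (+97 -15); the model path is illustrative:

    // Hedged sketch, not part of the diff itself.
    #include "whisper.h"

    int main() {
        // new in this release range: context params with a use_gpu flag
        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.use_gpu = true; // falls back to CPU when no usable backend exists

        struct whisper_context * ctx =
            whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
        if (ctx == nullptr) {
            return 1;
        }

        whisper_free(ctx);
        return 0;
    }
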
package/cpp/whisper.cpp CHANGED
@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef WSP_GGML_USE_METAL
-# include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef WSP_GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
@@ -13,7 +18,9 @@
 
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ -97,10 +104,32 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (wsp_ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
            abort(); \
         } \
     } while (0)
@@ -119,15 +148,16 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_DECODERS 8
+#define WHISPER_MAX_NODES 4096
 
 //
 // ggml helpers
 //
 
 static void wsp_ggml_graph_compute_helper(
+        struct wsp_ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-        wsp_ggml_cgraph * graph,
         int n_threads,
         whisper_abort_callback abort_callback,
         void * abort_callback_data) {
@@ -144,6 +174,21 @@ static void wsp_ggml_graph_compute_helper(
     wsp_ggml_graph_compute(graph, &plan);
 }
 
+static void wsp_ggml_graph_compute_helper(
+        struct wsp_ggml_backend * backend,
+        struct wsp_ggml_cgraph * graph,
+        int n_threads) {
+    if (wsp_ggml_backend_is_cpu(backend)) {
+        wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef WSP_GGML_USE_METAL
+    if (wsp_ggml_backend_is_metal(backend)) {
+        wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    wsp_ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
@@ -178,6 +223,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * c
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly
 #if defined(WSP_GGML_USE_METAL)
 #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
 #endif
@@ -192,6 +238,15 @@ enum e_model {
     MODEL_LARGE,
 };
 
+static const std::map<e_model, std::string> g_model_name = {
+    { MODEL_UNKNOWN,  "unknown"  },
+    { MODEL_TINY,     "tiny"     },
+    { MODEL_BASE,     "base"     },
+    { MODEL_SMALL,    "small"    },
+    { MODEL_MEDIUM,   "medium"   },
+    { MODEL_LARGE,    "large"    },
+};
+
 static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "en",  { 0,  "english",    } },
     { "zh",  { 1,  "chinese",    } },
@@ -292,75 +347,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "ba",  { 96, "bashkir",    } },
     { "jw",  { 97, "javanese",   } },
     { "su",  { 98, "sundanese",  } },
-};
-
-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<wsp_ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { WSP_GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { WSP_GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
+    { "yue", { 99, "cantonese",  } },
 };
 
 struct whisper_mel {
@@ -401,7 +388,11 @@ struct whisper_vocab {
     id token_beg  = 50363; // begin timestamps
 
     bool is_multilingual() const {
-        return n_vocab == 51865;
+        return n_vocab >= 51865;
+    }
+
+    int num_languages() const {
+        return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
     }
 };
 
@@ -416,6 +407,121 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id; // null terminated
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * (n_tokens));
+    batch.pos      = (whisper_pos *)     malloc(sizeof(whisper_pos)      * (n_tokens));
+    batch.n_seq_id = (int32_t *)         malloc(sizeof(int32_t)          * (n_tokens));
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
+    }
+    batch.seq_id[n_tokens] = nullptr;
+    batch.logits   = (int8_t *)          malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i]; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = seq_id;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// wsp_ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    wsp_ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    wsp_ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + wsp_ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+
+    alloc = wsp_ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
+
+    wsp_ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, wsp_ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = wsp_ggml_allocr_max_size(alloc);
+
+    wsp_ggml_allocr_free(alloc);
+
+    buffer = wsp_ggml_backend_alloc_buffer(backend, size);
+    alloc  = wsp_ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        wsp_ggml_allocr_free(allocr.alloc);
+        wsp_ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
@@ -533,16 +639,31 @@ struct whisper_layer_decoder {
     struct wsp_ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct wsp_ggml_tensor * k;
     struct wsp_ggml_tensor * v;
 
     struct wsp_ggml_context * ctx;
 
-    // buf points to the memory allocated for both wsp_ggml_tensor 'k' and 'v' (see kv_cache_init)
-    std::vector<uint8_t> buf;
-
-    int n; // number of tokens currently in the cache
+    wsp_ggml_backend_buffer_t buffer;
 };
 
 struct whisper_model {
@@ -579,17 +700,36 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // context
+    // ggml context that contains all the meta information about the model tensors
     struct wsp_ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct wsp_ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
     std::map<std::string, struct wsp_ggml_tensor *> tensors;
 };
 
+struct whisper_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+    whisper_token id;
+    const uint32_t * code_points;
+    whisper_partial_utf8 partial_utf8;
+};
+
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
@@ -605,12 +745,13 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar grammar;
+
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
@@ -622,93 +763,42 @@ struct whisper_decoder {
     std::vector<float> logits;
     std::vector<float> logprobs;
 
-    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
-};
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
-// wsp_ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    wsp_ggml_allocr * alloc = nullptr;
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
-    std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct wsp_ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
-
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
-
-    meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead());
-
-    alloc = wsp_ggml_allocr_new_measure(tensor_alignment);
-
-    const size_t alloc_size = wsp_ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
-
-    wsp_ggml_allocr_free(alloc);
-
-    data.resize(alloc_size);
-
-    alloc = wsp_ggml_allocr_new(data.data(), data.size(), tensor_alignment);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        wsp_ggml_allocr_free(allocr.alloc);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
-    int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1  (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16  (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1  (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
-    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+    whisper_batch batch;
 
-    // buffer for swapping KV caches between decoders during beam-search
-    std::vector<kv_buf> kv_swap_bufs;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    // reusable buffer for `struct wsp_ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    wsp_ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -722,36 +812,34 @@ struct whisper_state {
     struct wsp_ggml_tensor * embd_conv = nullptr;
     struct wsp_ggml_tensor * embd_enc  = nullptr;
 
+    // helpers for GPU offloading
+    std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
-    std::string path_model; // populated by whisper_init_from_file()
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef WSP_GGML_USE_METAL
-    wsp_ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -765,37 +853,25 @@ struct whisper_context {
     wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
-    std::string path_model; // populated by whisper_init_from_file()
-#ifdef WHISPER_USE_COREML
-    bool load_coreml = true;
-#endif
-};
+    wsp_ggml_backend_t backend = nullptr;
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+};
 
-static whisper_log_callback whisper_log = whisper_default_log;
+struct whisper_global {
+    // We save the log callback globally
+    wsp_ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -806,6 +882,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
         struct whisper_kv_cache & cache,
+        wsp_ggml_backend_t backend,
         wsp_ggml_type wtype,
         int n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -814,64 +891,204 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(wsp_ggml_type_size(wtype)*n_elements + wsp_ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*wsp_ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = wsp_ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = wsp_ggml_nbytes(cache.k) + wsp_ggml_nbytes(cache.v);
+
+    cache.buffer = wsp_ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(cache.buffer);
+
+        wsp_ggml_allocr_alloc(alloc, cache.k);
+        wsp_ggml_allocr_alloc(alloc, cache.v);
+
+        wsp_ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
-    WHISPER_ASSERT(cache.ctx);
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        wsp_ggml_free(cache.ctx);
+        wsp_ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
+
+static bool whisper_kv_cache_find_slot(
+           struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
 
-    const int n_elements = wsp_ggml_nelements(cache.k);
-    WHISPER_ASSERT(n_elements == wsp_ggml_nelements(cache.v));
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
 
-    const wsp_ggml_type wtype = cache.k->type;
-    WHISPER_ASSERT(wtype == cache.v->type);
+    uint32_t n_tested = 0;
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*wsp_ggml_type_sizef(wtype));
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
 
-    struct wsp_ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
-    };
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
 
-    cache.ctx = wsp_ggml_init(params);
+        if (found) {
+            break;
+        }
 
-    if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
-        return false;
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
     }
 
-    cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
+
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
 
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    if (cache.ctx) {
-        wsp_ggml_free(cache.ctx);
-        cache.ctx = nullptr;
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
+
+    return 1;
+}
+
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
     }
+    cache.head = 0;
+}
+
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+        whisper_seq_id seq_id,
+        whisper_pos p0,
+        whisper_pos p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
+}
+
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+        whisper_seq_id seq_id_src,
+        whisper_seq_id seq_id_dst,
+        whisper_pos p0,
+        whisper_pos p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
+
+static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    wsp_ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef WSP_GGML_USE_CUBLAS
+    if (params.use_gpu && wsp_ggml_cublas_loaded()) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = wsp_ggml_backend_cuda_init(0);
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef WSP_GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = wsp_ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
+        } else if (!wsp_ggml_backend_metal_supports_family(backend_gpu, 7)) {
+            WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
+            wsp_ggml_backend_free(backend_gpu);
+            backend_gpu = NULL;
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return wsp_ggml_backend_cpu_init();
 }
 
 // load the model from a ggml file
@@ -886,7 +1103,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    log("%s: loading model\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
 
     const int64_t t_start_us = wsp_ggml_time_us();
 
@@ -900,7 +1117,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != WSP_GGML_FILE_MAGIC) {
-            log("%s: invalid model data (bad magic)\n", __func__);
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -923,6 +1140,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
+        std::string mver = "";
+
         if (hparams.n_audio_layer == 4) {
             model.type = e_model::MODEL_TINY;
         }
@@ -941,6 +1160,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         if (hparams.n_audio_layer == 32) {
             model.type = e_model::MODEL_LARGE;
+
+            if (hparams.n_vocab == 51866) {
+                mver = " v3";
+            }
         }
 
         const int32_t qntvr = hparams.ftype / WSP_GGML_QNT_VERSION_FACTOR;
@@ -951,41 +1174,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         // in order to save memory and also to speed up the computation
         wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == WSP_GGML_TYPE_COUNT) {
-            log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
            return false;
         }
 
-        const size_t scale = model.hparams.ftype ? 1 : 2;
-
-        log("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        log("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        log("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        log("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        log("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        log("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        log("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        log("%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        log("%s: ftype         = %d\n", __func__, model.hparams.ftype);
-        log("%s: qntvr         = %d\n", __func__, qntvr);
-        log("%s: type          = %d\n", __func__, model.type);
-
-        // print memory requirements
-        {
-            // TODO
-            //log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-            //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-        }
-
-        // initialize all memory buffers
-        // always have at least one decoder
-
-        wctx.model.buf = new std::vector<uint8_t>();
-        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
-
-        // we skip initialization of the state until it is needed
-        // because it might be that state will always be provided externally.
+        WHISPER_LOG_INFO("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        WHISPER_LOG_INFO("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        WHISPER_LOG_INFO("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        WHISPER_LOG_INFO("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        WHISPER_LOG_INFO("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        WHISPER_LOG_INFO("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        WHISPER_LOG_INFO("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        WHISPER_LOG_INFO("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        WHISPER_LOG_INFO("%s: qntvr         = %d\n", __func__, qntvr);
+        WHISPER_LOG_INFO("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
     }
 
     // load mel filters
@@ -1006,7 +1211,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
-        //    log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}
@@ -1026,7 +1231,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                 word = "";
             }
 
@@ -1040,17 +1245,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         if (vocab.is_multilingual()) {
             vocab.token_eot++;
             vocab.token_sot++;
-            vocab.token_translate++;
-            vocab.token_transcribe++;
-            vocab.token_solm++;
-            vocab.token_prev++;
-            vocab.token_nosp++;
-            vocab.token_not++;
-            vocab.token_beg++;
+
+            // account for variable number of language tokens
+            const int dt = vocab.num_languages() - 98;
+
+            vocab.token_translate  += dt;
+            vocab.token_transcribe += dt;
+            vocab.token_solm       += dt;
+            vocab.token_prev       += dt;
+            vocab.token_nosp       += dt;
+            vocab.token_not        += dt;
+            vocab.token_beg        += dt;
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1058,6 +1267,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_EOT_]";
                 } else if (i == vocab.token_sot) {
                     word = "[_SOT_]";
+                } else if (i == vocab.token_translate) {
+                    word = "[_TRANSLATE_]";
+                } else if (i == vocab.token_transcribe) {
+                    word = "[_TRANSCRIBE_]";
                 } else if (i == vocab.token_solm) {
                     word = "[_SOLM_]";
                 } else if (i == vocab.token_prev) {
@@ -1068,6 +1281,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_NOT_]";
                 } else if (i == vocab.token_beg) {
                     word = "[_BEG_]";
+                } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                    word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
                 } else {
                     word = "[_extra_token_" + std::to_string(i) + "]";
                 }
@@ -1075,139 +1290,36 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 vocab.id_to_token[i] = word;
             }
         }
-    }
 
-    size_t ctx_size = 0;
+        WHISPER_LOG_INFO("%s: n_langs       = %d\n", __func__, vocab.num_languages());
+    }
 
     const wsp_ggml_type wtype = wctx.wtype;
     const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
 
+    // create the ggml context
     {
         const auto & hparams = model.hparams;
 
-        const int n_vocab = hparams.n_vocab;
-
-        const int n_audio_ctx = hparams.n_audio_ctx;
-        const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
+        const int n_text_layer  = hparams.n_text_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-
-        const int n_mels = hparams.n_mels;
-
-        // encoder
-        {
-            ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe;
-
-            ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b
-
-            ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b
-
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w;
-            ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b;
-        }
-
-        // decoder
-        {
-            ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe;
-
-            ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te;
-
-            ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w;
-            ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b;
-        }
-
-        // encoder layers
-        {
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_audio_layer*(  4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_audio_layer*(    n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
+        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_audio_layer*(  n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
-        }
-
-        // decoder layers
-        {
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_text_layer*(  4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_text_layer*(    n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b
-            //
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w
-            ctx_size += n_text_layer*(  n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b
-        }
-
-        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
-
-        log("%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-    }
-
-    // create the ggml context
-    {
         struct wsp_ggml_init_params params = {
-            /*.mem_size   =*/ wctx.model.buf->size(),
-            /*.mem_buffer =*/ wctx.model.buf->data(),
-            /*.no_alloc   =*/ false,
+            /*.mem_size   =*/ n_tensors*wsp_ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
         };
 
         model.ctx = wsp_ggml_init(params);
         if (!model.ctx) {
-            log("%s: wsp_ggml_init() failed\n", __func__);
+            WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__);
             return false;
         }
     }
 
-    // prepare memory for the weights
+    // prepare tensors for the weights
     {
         auto & ctx = model.ctx;
 
@@ -1230,16 +1342,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         // encoder
         {
-            model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+            model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
-            model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
-            model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
+            model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
 
-            model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
-            model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
+            model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
 
-            model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
+            model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
+            model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1403,12 +1515,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
+        }
+
+        model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6);
+    }
+
+    wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            wsp_ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1435,20 +1572,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
+
             if (wsp_ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                         __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                 return false;
             }
 
             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                         __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
@@ -1456,29 +1594,49 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
 
             if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
-                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                         __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
                 return false;
             }
 
-            loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
-            BYTESWAP_TENSOR(tensor);
+            wsp_ggml_backend_t backend = wctx.backend;
 
-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1024.0/1024.0);
+            //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
+
+            if ((wsp_ggml_backend_is_cpu(backend)
+#ifdef WSP_GGML_USE_METAL
+                || wsp_ggml_backend_is_metal(backend)
+#endif
+                )) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(wsp_ggml_nbytes(tensor));
+
+                loader->read(loader->context, read_buf.data(), read_buf.size());
+
+                wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
+            }
+
+            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
             total_size += wsp_ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1e6);
 
         if (model.n_loaded == 0) {
-            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
         } else if (model.n_loaded != (int) model.tensors.size()) {
-            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
 
+    wsp_ggml_allocr_free(alloc);
+
     wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
 
     return true;
@@ -1534,10 +1692,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
     if (!wsp_ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
-        float * dst = (float *) mel->data;
+        wstate.inp_mel.resize(wsp_ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, wsp_ggml_nbytes(mel));
 
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i0 = std::min(mel_offset,           mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1545,6 +1705,8 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        wsp_ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, wsp_ggml_nelements(mel)*sizeof(float));
     }
 
     struct wsp_ggml_tensor * cur = nullptr;
@@ -1553,24 +1715,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1553
1715
  // convolution + gelu
1554
1716
  {
1555
1717
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1556
- cur = wsp_ggml_add(ctx0,
1557
- wsp_ggml_repeat(ctx0,
1558
- model.e_conv_1_b,
1559
- cur),
1560
- cur);
1718
+ cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);
1561
1719
 
1562
1720
  cur = wsp_ggml_gelu(ctx0, cur);
1563
1721
 
1564
1722
  cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1565
- cur = wsp_ggml_add(ctx0,
1566
- wsp_ggml_repeat(ctx0,
1567
- model.e_conv_2_b,
1568
- cur),
1569
- cur);
1723
+ cur = wsp_ggml_add(ctx0, cur, model.e_conv_2_b);
1570
1724
 
1571
1725
  cur = wsp_ggml_gelu(ctx0, cur);
1572
1726
  }
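The explicit `wsp_ggml_repeat` of the conv biases is gone because `wsp_ggml_add` in this ggml revision broadcasts its second operand over the first when the shapes are repeatable, so the bias add collapses to a single op:

    // before: cur = add(repeat(bias, cur), cur)  -- materializes a full-size bias tensor
    // after:  cur = add(cur, bias)               -- bias is broadcast, no extra allocation
    cur = wsp_ggml_add(ctx0, cur, model.e_conv_1_b);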
1573
1727
 
1728
+ wsp_ggml_set_name(cur, "embd_conv");
1574
1729
  wstate.embd_conv = cur;
1575
1730
  } else {
1576
1731
  #ifdef WHISPER_USE_COREML
@@ -1578,7 +1733,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1578
1733
  wsp_ggml_allocr_alloc(alloc, cur);
1579
1734
 
1580
1735
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1581
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1736
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
1582
1737
  }
1583
1738
  #endif
1584
1739
  #ifdef WHISPER_USE_OPENVINO
@@ -1590,6 +1745,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_conv(
1590
1745
  }
1591
1746
  #endif
1592
1747
 
1748
+ wsp_ggml_set_name(cur, "embd_enc");
1593
1749
  wstate.embd_enc = cur;
1594
1750
  }
1595
1751
 
@@ -1619,19 +1775,26 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1619
1775
 
1620
1776
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
1621
1777
 
1622
- wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
1778
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
1623
1779
 
1624
1780
  wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc;
1625
1781
 
1782
+ //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state);
1783
+ //wsp_ggml_allocr_alloc(alloc, cur);
1784
+
1785
+ //if (!wsp_ggml_allocr_is_measure(alloc)) {
1786
+ // wsp_ggml_backend_tensor_copy(wstate.embd_conv, cur);
1787
+ //}
1788
+ struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1789
+
1626
1790
  struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1627
1791
  wsp_ggml_allocr_alloc(alloc, KQscale);
1628
1792
 
1629
1793
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1630
- wsp_ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
1794
+ const float val = 1.0f/sqrtf(float(n_state)/n_head);
1795
+ wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
1631
1796
  }
1632
1797
 
1633
- struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv);
1634
-
1635
1798
  // ===================================================================
1636
1799
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
1637
1800
  //static int iter = -1;
@@ -1650,7 +1813,6 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1650
1813
  const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter;
1651
1814
 
1652
1815
  struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1653
-
1654
1816
  cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_cont(ctx0, wsp_ggml_transpose(ctx0, cur)));
1655
1817
 
1656
1818
  // ===================================================================
@@ -1838,11 +2000,11 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder(
1838
2000
  ////////////////////////////////////////////////////////////////////////////
1839
2001
 
1840
2002
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
1841
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
1842
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
1843
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
1844
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
1845
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2003
+ // wsp_ggml_used_mem(ctx0)/1e6,
2004
+ // wstate.get_buf_max_mem(0)/1e6,
2005
+ // wstate.get_buf_max_mem(1)/1e6,
2006
+ // wstate.get_buf_max_mem(2)/1e6,
2007
+ // wstate.get_buf_max_mem(3)/1e6);
1846
2008
 
1847
2009
  wsp_ggml_free(ctx0);
1848
2010
 
@@ -1872,13 +2034,20 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross(
1872
2034
 
1873
2035
  wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc;
1874
2036
 
2037
+ //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx);
2038
+ //wsp_ggml_allocr_alloc(alloc, cur);
2039
+
2040
+ //if (!wsp_ggml_allocr_is_measure(alloc)) {
2041
+ // wsp_ggml_backend_tensor_copy(wstate.embd_enc, cur);
2042
+ //}
1875
2043
  struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc);
1876
2044
 
1877
2045
  struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1);
1878
2046
  wsp_ggml_allocr_alloc(alloc, Kscale);
1879
2047
 
1880
2048
  if (!wsp_ggml_allocr_is_measure(alloc)) {
1881
- wsp_ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
2049
+ const float val = pow(float(n_state) / n_head, -0.25);
2050
+ wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
1882
2051
  }
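The exponent -0.25 (rather than -0.5) is deliberate: standard attention computes softmax(Q·K^T / sqrt(d)) with d = n_state/n_head, and here K is pre-scaled once by d^(-1/4) when the cross cache is filled, while the decoder scales Q by the same d^(-1/4) (its `KQscale`; the `// Kcross is already scaled` comment in a later hunk refers to this). The two quarter-powers multiply out to the full factor:

    (Q * d^-1/4) . (K * d^-1/4)^T  ==  (Q . K^T) / sqrt(d),   d = n_state/n_head

so the decode graph needs no separate scaling node for cross-attention.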
1883
2052
 
1884
2053
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1949,7 +2118,7 @@ static bool whisper_encode_internal(
1949
2118
  wsp_ggml_allocr_alloc_graph(alloc, gf);
1950
2119
 
1951
2120
  if (!whisper_encode_external(wstate)) {
1952
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2121
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1953
2122
  }
1954
2123
  }
1955
2124
 
@@ -1963,16 +2132,7 @@ static bool whisper_encode_internal(
1963
2132
 
1964
2133
  wsp_ggml_allocr_alloc_graph(alloc, gf);
1965
2134
 
1966
- #ifdef WSP_GGML_USE_METAL
1967
- if (wstate.ctx_metal) {
1968
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1969
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1970
- } else {
1971
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1972
- }
1973
- #else
1974
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1975
- #endif
2135
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1976
2136
  }
1977
2137
 
1978
2138
  // cross
@@ -1985,49 +2145,40 @@ static bool whisper_encode_internal(
1985
2145
 
1986
2146
  wsp_ggml_allocr_alloc_graph(alloc, gf);
1987
2147
 
1988
- #ifdef WSP_GGML_USE_METAL
1989
- if (wstate.ctx_metal) {
1990
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1991
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
1992
- } else {
1993
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1994
- }
1995
- #else
1996
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1997
- #endif
2148
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1998
2149
  }
1999
2150
 
2000
- // wsp_ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
2001
-
2002
2151
  wstate.t_encode_us += wsp_ggml_time_us() - t_start_us;
2003
2152
  wstate.n_encode++;
2004
2153
 
2005
- return true;
2154
+ return !(abort_callback && abort_callback(abort_callback_data));
2006
2155
  }
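All of the per-backend dispatch removed above (Metal set_n_cb vs. CPU work buffer) now funnels through a single `wsp_ggml_graph_compute_helper(backend, gf, n_threads)`. Its definition sits outside this hunk; a plausible minimal shape, assuming the wsp_ggml backend API, would be a sketch like:

    static void wsp_ggml_graph_compute_helper(
            wsp_ggml_backend_t       backend,
            struct wsp_ggml_cgraph * graph,
            int                      n_threads) {
        if (wsp_ggml_backend_is_cpu(backend)) {
            wsp_ggml_backend_cpu_set_n_threads(backend, n_threads); // only the CPU backend is threaded
        }
    #ifdef WSP_GGML_USE_METAL
        if (wsp_ggml_backend_is_metal(backend)) {
            wsp_ggml_backend_metal_set_n_cb(backend, n_threads);
        }
    #endif
        wsp_ggml_backend_graph_compute(backend, graph); // blocks until the graph has run
    }

Note the return-value change just above: success now also reflects the abort callback, so a user-requested abort reads as a failed encode.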
2007
2156
 
2008
2157
  static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2009
2158
  whisper_context & wctx,
2010
2159
  whisper_state & wstate,
2011
- whisper_decoder & decoder,
2012
- const whisper_token * tokens,
2013
- int n_tokens,
2014
- int n_past) {
2160
+ const whisper_batch & batch) {
2015
2161
  const auto & model = wctx.model;
2016
2162
  const auto & hparams = model.hparams;
2017
2163
 
2018
- auto & kv_self = decoder.kv_self;
2164
+ auto & kv_self = wstate.kv_self;
2019
2165
 
2020
2166
  WHISPER_ASSERT(!!kv_self.ctx);
2021
2167
 
2022
- const int n_ctx = hparams.n_text_ctx;
2168
+ wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
2169
+
2170
+ const int n_ctx = kv_self.size;
2023
2171
  const int n_state = hparams.n_text_state;
2024
2172
  const int n_head = hparams.n_text_head;
2025
2173
  const int n_layer = hparams.n_text_layer;
2026
2174
 
2027
- const int N = n_tokens;
2028
- const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2175
+ const int n_tokens = batch.n_tokens;
2176
+ const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2029
2177
 
2030
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
2178
+ const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
2179
+ const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
2180
+
2181
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2031
2182
 
2032
2183
  struct wsp_ggml_init_params params = {
2033
2184
  /*.mem_size =*/ wstate.alloc_decode.meta.size(),
@@ -2037,23 +2188,22 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2037
2188
 
2038
2189
  struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
2039
2190
 
2040
- wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
2191
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2041
2192
 
2042
- wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc;
2043
-
2044
- struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
2193
+ struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2045
2194
  wsp_ggml_allocr_alloc(alloc, embd);
2046
2195
 
2047
2196
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2048
- memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd));
2197
+ wsp_ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*wsp_ggml_element_size(embd));
2049
2198
  }
2050
2199
 
2051
- struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N);
2200
+ struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, n_tokens);
2052
2201
  wsp_ggml_allocr_alloc(alloc, position);
2053
2202
 
2054
2203
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2055
- for (int i = 0; i < N; ++i) {
2056
- ((int32_t *) position->data)[i] = n_past + i;
2204
+ for (int i = 0; i < n_tokens; ++i) {
2205
+ const int32_t val = batch.pos[i];
2206
+ wsp_ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2057
2207
  }
2058
2208
  }
2059
2209
 
@@ -2061,7 +2211,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2061
2211
  wsp_ggml_allocr_alloc(alloc, KQscale);
2062
2212
 
2063
2213
  if (!wsp_ggml_allocr_is_measure(alloc)) {
2064
- wsp_ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
2214
+ const float val = pow(float(n_state)/n_head, -0.25);
2215
+ wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
2216
+ }
2217
+
2218
+ struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1);
2219
+ wsp_ggml_allocr_alloc(alloc, KQ_mask);
2220
+
2221
+ if (!wsp_ggml_allocr_is_measure(alloc)) {
2222
+ wstate.inp_mask.resize(n_kv*n_tokens);
2223
+
2224
+ float * data = wstate.inp_mask.data();
2225
+ memset(data, 0, wsp_ggml_nbytes(KQ_mask));
2226
+
2227
+ for (int h = 0; h < 1; ++h) {
2228
+ for (int j = 0; j < n_tokens; ++j) {
2229
+ const whisper_pos pos = batch.pos[j];
2230
+ const whisper_seq_id seq_id = batch.seq_id[j][0];
2231
+
2232
+ for (int i = 0; i < n_kv; ++i) {
2233
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
2234
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
2235
+ }
2236
+ }
2237
+ }
2238
+ }
2239
+
2240
+ wsp_ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, wsp_ggml_nelements(KQ_mask)*sizeof(float));
2065
2241
  }
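The old `wsp_ggml_diag_mask_inf` causal mask assumed one decoder writing a contiguous cache; with the slotted, shared `kv_self` the mask has to be explicit. The loop above implements, for query token j and cache cell i:

    mask[j][i] =  0    if cell i carries j's seq_id and cell_pos(i) <= pos(j)
               = -inf  otherwise

The mask is added to KQ before the softmax (see `KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask)` in a later hunk), so masked cells end up with zero attention weight. This generalizes causal masking to several decoder sequences sharing one cache.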
2066
2242
 
2067
2243
  // token encoding + position encoding
@@ -2116,12 +2292,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2116
2292
  Vcur,
2117
2293
  layer.attn_v_b);
2118
2294
 
2119
- Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N));
2295
+ Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2120
2296
 
2121
- struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, N*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
2122
- struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, N, n_state,
2297
+ struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2298
+ struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2123
2299
  ( n_ctx)*wsp_ggml_element_size(kv_self.v),
2124
- (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v));
2300
+ (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + kv_head*wsp_ggml_element_size(kv_self.v));
2125
2301
 
2126
2302
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Kcur, k));
2127
2303
  wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, Vcur, v));
@@ -2131,12 +2307,12 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2131
2307
 
2132
2308
  struct wsp_ggml_tensor * Q =
2133
2309
  wsp_ggml_permute(ctx0,
2134
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2310
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2135
2311
  0, 2, 1, 3);
2136
2312
 
2137
2313
  struct wsp_ggml_tensor * K =
2138
2314
  wsp_ggml_view_3d(ctx0, kv_self.k,
2139
- n_state/n_head, n_past + N, n_head,
2315
+ n_state/n_head, n_kv, n_head,
2140
2316
  wsp_ggml_element_size(kv_self.k)*n_state,
2141
2317
  wsp_ggml_element_size(kv_self.k)*n_state/n_head,
2142
2318
  wsp_ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2146,16 +2322,17 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2146
2322
 
2147
2323
  //struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, KQ_scale);
2148
2324
 
2149
- struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2325
+ //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf(ctx0, KQ, n_past);
2326
+ struct wsp_ggml_tensor * KQ_masked = wsp_ggml_add(ctx0, KQ, KQ_mask);
2150
2327
 
2151
2328
  struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max(ctx0, KQ_masked);
2152
2329
 
2153
2330
  struct wsp_ggml_tensor * V =
2154
2331
  wsp_ggml_view_3d(ctx0, kv_self.v,
2155
- n_past + N, n_state/n_head, n_head,
2332
+ n_kv, n_state/n_head, n_head,
2156
2333
  n_ctx*wsp_ggml_element_size(kv_self.v),
2157
2334
  n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head,
2158
- il*n_ctx*wsp_ggml_element_size(kv_self.v)*n_state);
2335
+ n_ctx*wsp_ggml_element_size(kv_self.v)*n_state*il);
2159
2336
 
2160
2337
  struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max);
2161
2338
 
@@ -2163,7 +2340,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2163
2340
 
2164
2341
  cur = wsp_ggml_cpy(ctx0,
2165
2342
  KQV_merged,
2166
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N));
2343
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2167
2344
  }
2168
2345
 
2169
2346
  // projection
@@ -2207,33 +2384,33 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2207
2384
  // Kcross is already scaled
2208
2385
  struct wsp_ggml_tensor * Kcross =
2209
2386
  wsp_ggml_view_3d(ctx0, wstate.kv_cross.k,
2210
- n_state/n_head, M, n_head,
2387
+ n_state/n_head, n_audio_ctx, n_head,
2211
2388
  wsp_ggml_element_size(wstate.kv_cross.k)*n_state,
2212
2389
  wsp_ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2213
- wsp_ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
2390
+ wsp_ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2214
2391
 
2215
2392
  //struct wsp_ggml_tensor * Vcross =
2216
2393
  // wsp_ggml_reshape_3d(ctx0,
2217
- // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2218
- // n_state/n_head, n_head, M);
2394
+ // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state),
2395
+ // n_state/n_head, n_head, n_audio_ctx);
2219
2396
 
2220
2397
  //struct wsp_ggml_tensor * V_trans =
2221
2398
  // wsp_ggml_cpy(ctx0,
2222
2399
  // wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2223
- // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
2400
+ // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2224
2401
 
2225
2402
  struct wsp_ggml_tensor * V =
2226
2403
  wsp_ggml_view_3d(ctx0, wstate.kv_cross.v,
2227
- M, n_state/n_head, n_head,
2228
- M*wsp_ggml_element_size(wstate.kv_cross.v),
2229
- M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2230
- il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state);
2404
+ n_audio_ctx, n_state/n_head, n_head,
2405
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v),
2406
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2407
+ n_audio_ctx*wsp_ggml_element_size(wstate.kv_cross.v)*n_state*il);
2231
2408
 
2232
2409
  // ------
2233
2410
 
2234
2411
  struct wsp_ggml_tensor * Q =
2235
2412
  wsp_ggml_permute(ctx0,
2236
- wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
2413
+ wsp_ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2237
2414
  0, 2, 1, 3);
2238
2415
 
2239
2416
  // K * Q
@@ -2254,10 +2431,10 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2254
2431
 
2255
2432
  struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2256
2433
 
2257
- // cur = KQV_merged.contiguous().view(n_state, N)
2434
+ // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2258
2435
  cur = wsp_ggml_cpy(ctx0,
2259
2436
  KQV_merged,
2260
- wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N));
2437
+ wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_tokens));
2261
2438
  }
2262
2439
 
2263
2440
  // projection
@@ -2329,9 +2506,9 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2329
2506
  }
2330
2507
 
2331
2508
  // compute logits only for the last token
2332
- // comment this line to compute logits for all N tokens
2509
+ // comment this line to compute logits for all n_tokens
2333
2510
  // might be useful in the future
2334
- cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2511
+ //cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
2335
2512
 
2336
2513
  struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur);
2337
2514
 
@@ -2355,10 +2532,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder(
2355
2532
  static bool whisper_decode_internal(
2356
2533
  whisper_context & wctx,
2357
2534
  whisper_state & wstate,
2358
- whisper_decoder & decoder,
2359
- const whisper_token * tokens,
2360
- const int n_tokens,
2361
- const int n_past,
2535
+ const whisper_batch & batch,
2362
2536
  const int n_threads,
2363
2537
  whisper_abort_callback abort_callback,
2364
2538
  void * abort_callback_data) {
@@ -2367,65 +2541,72 @@ static bool whisper_decode_internal(
2367
2541
  const auto & model = wctx.model;
2368
2542
  const auto & hparams = model.hparams;
2369
2543
 
2370
- const int n_vocab = hparams.n_vocab;
2544
+ const int n_vocab = hparams.n_vocab;
2545
+ const int n_tokens = batch.n_tokens;
2371
2546
 
2372
2547
  auto & logits_out = wstate.logits;
2373
2548
 
2374
2549
  struct wsp_ggml_tensor * logits;
2375
2550
 
2551
+ // find KV slot for the batch
2552
+ {
2553
+ auto & kv_self = wstate.kv_self;
2554
+
2555
+ if (!whisper_kv_cache_find_slot(kv_self, batch)) {
2556
+ return false;
2557
+ }
2558
+
2559
+ kv_self.n = whisper_kv_cache_cell_max(kv_self);
2560
+ //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2561
+ //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2562
+ }
2563
+
2376
2564
  // decoder
2377
2565
  {
2378
2566
  auto & alloc = wstate.alloc_decode.alloc;
2379
2567
 
2380
2568
  wsp_ggml_allocr_reset(alloc);
2381
2569
 
2382
- wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
2570
+ wsp_ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
2383
2571
 
2384
2572
  wsp_ggml_allocr_alloc_graph(alloc, gf);
2385
2573
 
2386
2574
  logits = gf->nodes[gf->n_nodes - 1];
2387
2575
 
2388
- #ifdef WSP_GGML_USE_METAL
2389
- if (wstate.ctx_metal) {
2390
- wsp_ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2391
- wsp_ggml_metal_graph_compute(wstate.ctx_metal, gf);
2392
- } else {
2393
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2394
- }
2395
- #else
2396
- wsp_ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2397
- #endif
2576
+ wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2398
2577
  }
2399
2578
 
2400
- // extract logits for all N tokens
2401
- //logits_out.resize(n_tokens*n_vocab);
2402
- //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2403
-
2404
- // extract logits only for the last token
2405
- logits_out.resize(n_vocab);
2406
- memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab);
2579
+ logits_out.resize(n_tokens*n_vocab);
2580
+ for (int i = 0; i < n_tokens; i++) {
2581
+ if (batch.logits[i] == 0) {
2582
+ continue;
2583
+ }
2584
+ wsp_ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
2585
+ }
2407
2586
 
2408
- if (n_tokens > 1) {
2587
+ if (batch.n_tokens > 1) {
2409
2588
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
2410
- // wsp_ggml_used_mem(ctx0)/1024.0/1024.0,
2411
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
2412
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
2413
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
2414
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
2589
+ // wsp_ggml_used_mem(ctx0)/1e6,
2590
+ // wstate.get_buf_max_mem(0)/1e6,
2591
+ // wstate.get_buf_max_mem(1)/1e6,
2592
+ // wstate.get_buf_max_mem(2)/1e6,
2593
+ // wstate.get_buf_max_mem(3)/1e6);
2415
2594
  }
2416
2595
 
2417
- if (n_tokens == 1) {
2596
+ if (batch.n_tokens == 1) {
2418
2597
  wstate.t_decode_us += wsp_ggml_time_us() - t_start_us;
2419
2598
  wstate.n_decode++;
2599
+ } else if (batch.n_tokens < 16) {
2600
+ wstate.t_batchd_us += wsp_ggml_time_us() - t_start_us;
2601
+ wstate.n_batchd += n_tokens;
2420
2602
  } else {
2421
2603
  wstate.t_prompt_us += wsp_ggml_time_us() - t_start_us;
2422
- wstate.n_prompt++;
2604
+ wstate.n_prompt += n_tokens;
2423
2605
  }
2424
2606
 
2425
- return true;
2607
+ return !(abort_callback && abort_callback(abort_callback_data));
2426
2608
  }
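Two behavioral changes land in this function: the graph now produces logits for every batch position (the last-token-only view was commented out in an earlier hunk), and only rows flagged in `batch.logits` are copied back to the host. The gather, annotated (same variables as in the function above):

    for (int i = 0; i < n_tokens; i++) {
        if (batch.logits[i] == 0) {
            continue;                                  // caller did not ask for this row
        }
        wsp_ggml_backend_tensor_get(logits,
            logits_out.data() + (size_t) n_vocab*i,    // dst: row i of the host buffer
            sizeof(float)*(size_t) n_vocab*i,          // src: byte offset of row i in the tensor
            sizeof(float)*n_vocab);                    // copy exactly one row of n_vocab floats
    }

The timing buckets at the end follow the same batching split: a single token counts as decode, 2-15 tokens as batchd, and 16 or more as prompt processing.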
2427
2609
 
2428
-
2429
2610
  // 500 -> 00:05.000
2430
2611
  // 6000 -> 01:00.000
2431
2612
  static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2769,7 +2950,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2769
2950
  --j;
2770
2951
  }
2771
2952
  if (!found) {
2772
- log("unknown token\n");
2953
+ WHISPER_LOG_ERROR("unknown token\n");
2773
2954
  ++i;
2774
2955
  }
2775
2956
  }
@@ -2832,94 +3013,105 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
2832
3013
 
2833
3014
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2834
3015
  fill_sin_cos_table();
3016
+
2835
3017
  whisper_state * state = new whisper_state;
2836
3018
 
2837
- if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2838
- log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3019
+ state->backend = whisper_backend_init(ctx->params);
3020
+
3021
+ // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3022
+ // in theory, there can be a case where this is not enough, but in practice it should always be enough
3023
+ const int factor = 3;
3024
+
3025
+ if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
3026
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2839
3027
  delete state;
2840
3028
  return nullptr;
2841
3029
  }
2842
3030
 
2843
3031
  {
2844
- const size_t memory_size = wsp_ggml_nbytes(state->decoders[0].kv_self.k) + wsp_ggml_nbytes(state->decoders[0].kv_self.v);
2845
- log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
3032
+ const size_t memory_size = wsp_ggml_nbytes(state->kv_self.k) + wsp_ggml_nbytes(state->kv_self.v);
3033
+ WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
2846
3034
  }
2847
3035
 
2848
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2849
- log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3036
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
3037
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2850
3038
  delete state;
2851
3039
  return nullptr;
2852
3040
  }
2853
3041
 
2854
3042
  {
2855
3043
  const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v);
2856
- log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
3044
+ WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
2857
3045
  }
2858
3046
 
3047
+
2859
3048
  #ifdef WHISPER_USE_COREML
2860
- if (ctx->load_coreml) { // Not in correct layer for easy patch
3049
+ if (ctx->params.use_coreml) {
2861
3050
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2862
3051
 
2863
- log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2864
- log("%s: first run on a device may take a while ...\n", __func__);
3052
+ WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
3053
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
2865
3054
 
2866
3055
  state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2867
3056
  if (!state->ctx_coreml) {
2868
- log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
3057
+ WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2869
3058
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2870
3059
  delete state;
2871
3060
  return nullptr;
2872
3061
  #endif
2873
3062
  } else {
2874
- log("%s: Core ML model loaded\n", __func__);
3063
+ WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
3064
+ }
2875
3065
  }
2876
- }
2877
3066
  #endif
2878
3067
 
2879
3068
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2880
3069
 
2881
- state->logits_id.reserve(ctx->model.hparams.n_vocab);
3070
+ state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
2882
3071
 
2883
3072
  // TAGS: WHISPER_DECODER_INIT
2884
3073
  state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2885
3074
 
2886
- state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
2887
- state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
2888
- state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
3075
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
3076
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
3077
+ state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
3078
+ state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
3079
+
3080
+ state->decoders[0].rng = std::mt19937(0);
2889
3081
 
2890
3082
  // conv allocator
2891
3083
  {
2892
- whisper_allocr_graph_init(state->alloc_conv,
3084
+ whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
2893
3085
  [&]() {
2894
3086
  return whisper_build_graph_conv(*ctx, *state, 0);
2895
3087
  });
2896
3088
 
2897
- log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
3089
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
2898
3090
  }
2899
3091
 
2900
3092
  // encoder allocator
2901
3093
  if (!whisper_encode_external(*state)) {
2902
- whisper_allocr_graph_init(state->alloc_encode,
3094
+ whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
2903
3095
  [&]() {
2904
3096
  return whisper_build_graph_encoder(*ctx, *state);
2905
3097
  });
2906
3098
 
2907
- log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
3099
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
2908
3100
  }
2909
3101
 
2910
3102
  // cross allocator
2911
3103
  {
2912
- whisper_allocr_graph_init(state->alloc_cross,
3104
+ whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
2913
3105
  [&]() {
2914
3106
  return whisper_build_graph_cross(*ctx, *state);
2915
3107
  });
2916
3108
 
2917
- log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
3109
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
2918
3110
  }
2919
3111
 
2920
3112
  // decoder allocator
2921
3113
  {
2922
- whisper_allocr_graph_init(state->alloc_decode,
3114
+ whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
2923
3115
  [&]() {
2924
3116
  const auto & hparams = ctx->model.hparams;
2925
3117
 
@@ -2927,90 +3119,22 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch
2927
3119
  const int n_tokens = hparams.n_text_ctx;
2928
3120
  const int n_past = 0;
2929
3121
 
2930
- return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2931
- });
2932
-
2933
- log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2934
- }
2935
-
2936
- #ifdef WSP_GGML_USE_METAL
2937
- state->ctx_metal = wsp_ggml_metal_init(1);
2938
- if (!state->ctx_metal) {
2939
- log("%s: wsp_ggml_metal_init() failed\n", __func__);
2940
- delete state;
2941
- return nullptr;
2942
- }
2943
-
2944
- log("%s: Metal context initialized\n", __func__);
2945
-
2946
- // this allocates all Metal resources and memory buffers
2947
-
2948
- void * data_ptr = NULL;
2949
- size_t data_size = 0;
2950
-
2951
- // TODO: add mmap support
2952
- //if (params.use_mmap) {
2953
- // data_ptr = ctx->model.mapping->addr;
2954
- // data_size = ctx->model.mapping->size;
2955
- //} else {
2956
- // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2957
- // data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
2958
- //}
2959
-
2960
- data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx);
2961
- data_size = wsp_ggml_get_mem_size (ctx->model.ctx);
3122
+ whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
2962
3123
 
2963
- const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx);
2964
-
2965
- log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
3124
+ return whisper_build_graph_decoder(*ctx, *state, state->batch);
3125
+ });
2966
3126
 
2967
- #define WHISPER_METAL_CHECK_BUF(result) \
2968
- if (!(result)) { \
2969
- log("%s: failed to add metal buffer\n", __func__); \
2970
- delete state; \
2971
- return nullptr; \
3127
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
2972
3128
  }
2973
3129
 
2974
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
2975
-
2976
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2977
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2978
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2979
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
2980
-
2981
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2982
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2983
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2984
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
2985
-
2986
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
2987
-
2988
- WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
2989
- #undef WHISPER_METAL_CHECK_BUF
2990
- #endif
2991
-
2992
- state->rng = std::mt19937(0);
3130
+ whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
3131
+ whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
3132
+ whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
3133
+ whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
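The four `whisper_allocr_graph_init`/`whisper_allocr_graph_realloc` pairs are the ggml measure-allocator pattern: each graph is first built once under a measure allocator that only records the worst-case buffer size (this is what the `wsp_ggml_allocr_is_measure(alloc)` guards in the graph builders are for; no tensor data exists yet at that point), and the realloc step then backs each allocator with a real backend buffer of the measured size. Every eval afterwards repeats, schematically:

    // per eval, for each of conv / encode / cross / decode:
    wsp_ggml_allocr_reset(alloc);              // recycle the compute buffer
    wsp_ggml_cgraph * gf = build_graph(...);   // rebuild the (cheap) graph description
    wsp_ggml_allocr_alloc_graph(alloc, gf);    // place tensors into the buffer
    wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads);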
2993
3134
 
2994
3135
  return state;
2995
3136
  }
2996
3137
 
2997
- #ifdef WHISPER_USE_COREML
2998
- struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
2999
- whisper_context * ctx = whisper_init_from_file_no_state(path_model);
3000
- if (!ctx) {
3001
- return nullptr;
3002
- }
3003
- ctx->load_coreml = false;
3004
- ctx->state = whisper_init_state(ctx);
3005
- if (!ctx->state) {
3006
- whisper_free(ctx);
3007
- return nullptr;
3008
- }
3009
-
3010
- return ctx;
3011
- }
3012
- #endif
3013
-
3014
3138
  int whisper_ctx_init_openvino_encoder(
3015
3139
  struct whisper_context * ctx,
3016
3140
  const char * model_path,
@@ -3025,7 +3149,7 @@ int whisper_ctx_init_openvino_encoder(
3025
3149
  return 1;
3026
3150
  #else
3027
3151
  if (!model_path && ctx->path_model.empty()) {
3028
- log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
3152
+ WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
3029
3153
  return 1;
3030
3154
  }
3031
3155
 
@@ -3045,27 +3169,35 @@ int whisper_ctx_init_openvino_encoder(
3045
3169
  path_cache = cache_dir;
3046
3170
  }
3047
3171
 
3048
- log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3049
- log("%s: first run on a device may take a while ...\n", __func__);
3172
+ WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3173
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
3050
3174
 
3051
3175
  ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3052
3176
  if (!ctx->state->ctx_openvino) {
3053
- log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3177
+ WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3054
3178
  return 1;
3055
3179
  } else {
3056
- log("%s: OpenVINO model loaded\n", __func__);
3180
+ WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
3057
3181
  }
3058
3182
 
3059
3183
  return 0;
3060
3184
  #endif
3061
3185
  }
3062
3186
 
3063
- struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
3064
- log("%s: loading model from '%s'\n", __func__, path_model);
3187
+ struct whisper_context_params whisper_context_default_params() {
3188
+ struct whisper_context_params result = {
3189
+ /*.use_gpu =*/ true,
3190
+ /*.use_coreml =*/ false,
3191
+ };
3192
+ return result;
3193
+ }
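With `whisper_context_params` exposed, callers choose backend behavior at init time instead of via build flags alone. A minimal usage sketch (the model path is illustrative):

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu    = false; // e.g. force the CPU backend
    cparams.use_coreml = false;

    struct whisper_context * wctx =
        whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
    if (wctx == nullptr) {
        // model failed to load
    }

The old `whisper_init_*` entry points remain as thin wrappers over these defaults (see the compatibility shims in a later hunk).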
3194
+
3195
+ struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3196
+ WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3065
3197
 
3066
3198
  auto fin = std::ifstream(path_model, std::ios::binary);
3067
3199
  if (!fin) {
3068
- log("%s: failed to open '%s'\n", __func__, path_model);
3200
+ WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3069
3201
  return nullptr;
3070
3202
  }
3071
3203
 
@@ -3089,7 +3221,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
3089
3221
  fin->close();
3090
3222
  };
3091
3223
 
3092
- auto ctx = whisper_init_no_state(&loader);
3224
+ auto ctx = whisper_init_with_params_no_state(&loader, params);
3093
3225
 
3094
3226
  if (ctx) {
3095
3227
  ctx->path_model = path_model;
@@ -3098,7 +3230,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
3098
3230
  return ctx;
3099
3231
  }
3100
3232
 
3101
- struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
3233
+ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
3102
3234
  struct buf_context {
3103
3235
  uint8_t* buffer;
3104
3236
  size_t size;
@@ -3107,7 +3239,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
3107
3239
 
3108
3240
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
3109
3241
 
3110
- log("%s: loading model from buffer\n", __func__);
3242
+ WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
3111
3243
 
3112
3244
  whisper_model_loader loader = {};
3113
3245
 
@@ -3132,17 +3264,18 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
3132
3264
 
3133
3265
  loader.close = [](void * /*ctx*/) { };
3134
3266
 
3135
- return whisper_init_no_state(&loader);
3267
+ return whisper_init_with_params_no_state(&loader, params);
3136
3268
  }
3137
3269
 
3138
- struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
3270
+ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
3139
3271
  wsp_ggml_time_init();
3140
3272
 
3141
3273
  whisper_context * ctx = new whisper_context;
3274
+ ctx->params = params;
3142
3275
 
3143
3276
  if (!whisper_model_load(loader, *ctx)) {
3144
3277
  loader->close(loader->context);
3145
- log("%s: failed to load model\n", __func__);
3278
+ WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
3146
3279
  delete ctx;
3147
3280
  return nullptr;
3148
3281
  }
@@ -3152,8 +3285,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
3152
3285
  return ctx;
3153
3286
  }
3154
3287
 
3155
- struct whisper_context * whisper_init_from_file(const char * path_model) {
3156
- whisper_context * ctx = whisper_init_from_file_no_state(path_model);
3288
+ struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
3289
+ whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
3157
3290
  if (!ctx) {
3158
3291
  return nullptr;
3159
3292
  }
@@ -3167,8 +3300,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
3167
3300
  return ctx;
3168
3301
  }
3169
3302
 
3170
- struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
3171
- whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
3303
+ struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
3304
+ whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
3172
3305
  if (!ctx) {
3173
3306
  return nullptr;
3174
3307
  }
@@ -3182,8 +3315,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
3182
3315
  return ctx;
3183
3316
  }
3184
3317
 
3185
- struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
3186
- whisper_context * ctx = whisper_init_no_state(loader);
3318
+ struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
3319
+ whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
3187
3320
  if (!ctx) {
3188
3321
  return nullptr;
3189
3322
  }
@@ -3197,15 +3330,36 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
3197
3330
  return ctx;
3198
3331
  }
3199
3332
 
3333
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
3334
+ return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
3335
+ }
3336
+
3337
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
3338
+ return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
3339
+ }
3340
+
3341
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
3342
+ return whisper_init_with_params(loader, whisper_context_default_params());
3343
+ }
3344
+
3345
+ struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
3346
+ return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
3347
+ }
3348
+
3349
+ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
3350
+ return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
3351
+ }
3352
+
3353
+ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
3354
+ return whisper_init_with_params_no_state(loader, whisper_context_default_params());
3355
+ }
3356
+
3200
3357
  void whisper_free_state(struct whisper_state * state)
3201
3358
  {
3202
3359
  if (state) {
3360
+ kv_cache_free(state->kv_self);
3203
3361
  kv_cache_free(state->kv_cross);
3204
3362
 
3205
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
3206
- kv_cache_free(state->decoders[i].kv_self);
3207
- }
3208
-
3209
3363
  #ifdef WHISPER_USE_COREML
3210
3364
  if (state->ctx_coreml != nullptr) {
3211
3365
  whisper_coreml_free(state->ctx_coreml);
@@ -3213,13 +3367,6 @@ void whisper_free_state(struct whisper_state * state)
3213
3367
  }
3214
3368
  #endif
3215
3369
 
3216
- #ifdef WSP_GGML_USE_METAL
3217
- if (state->ctx_metal) {
3218
- wsp_ggml_metal_free(state->ctx_metal);
3219
- state->ctx_metal = nullptr;
3220
- }
3221
- #endif
3222
-
3223
3370
  #ifdef WHISPER_USE_OPENVINO
3224
3371
  if (state->ctx_openvino != nullptr) {
3225
3372
  whisper_openvino_free(state->ctx_openvino);
@@ -3227,10 +3374,14 @@ void whisper_free_state(struct whisper_state * state)
3227
3374
  }
3228
3375
  #endif
3229
3376
 
3377
+ whisper_batch_free(state->batch);
3378
+
3230
3379
  whisper_allocr_free(state->alloc_conv);
3231
- whisper_allocr_free(state->alloc_decode);
3232
- whisper_allocr_free(state->alloc_cross);
3233
3380
  whisper_allocr_free(state->alloc_encode);
3381
+ whisper_allocr_free(state->alloc_cross);
3382
+ whisper_allocr_free(state->alloc_decode);
3383
+
3384
+ wsp_ggml_backend_free(state->backend);
3234
3385
 
3235
3386
  delete state;
3236
3387
  }
@@ -3241,16 +3392,25 @@ void whisper_free(struct whisper_context * ctx) {
3241
3392
  if (ctx->model.ctx) {
3242
3393
  wsp_ggml_free(ctx->model.ctx);
3243
3394
  }
3244
- if (ctx->model.buf) {
3245
- delete ctx->model.buf;
3395
+
3396
+ if (ctx->model.buffer) {
3397
+ wsp_ggml_backend_buffer_free(ctx->model.buffer);
3246
3398
  }
3247
3399
 
3248
3400
  whisper_free_state(ctx->state);
3249
3401
 
3402
+ wsp_ggml_backend_free(ctx->backend);
3403
+
3250
3404
  delete ctx;
3251
3405
  }
3252
3406
  }
3253
3407
 
3408
+ void whisper_free_context_params(struct whisper_context_params * params) {
3409
+ if (params) {
3410
+ delete params;
3411
+ }
3412
+ }
3413
+
3254
3414
  void whisper_free_params(struct whisper_full_params * params) {
3255
3415
  if (params) {
3256
3416
  delete params;
@@ -3258,8 +3418,8 @@ void whisper_free_params(struct whisper_full_params * params) {
3258
3418
  }
3259
3419
 
3260
3420
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3261
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
3262
- log("%s: failed to compute mel spectrogram\n", __func__);
3421
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3422
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3263
3423
  return -1;
3264
3424
  }
3265
3425
 
@@ -3272,8 +3432,8 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3272
3432
 
3273
3433
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3274
3434
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3275
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
3276
- log("%s: failed to compute mel spectrogram\n", __func__);
3435
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3436
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3277
3437
  return -1;
3278
3438
  }
3279
3439
 
@@ -3295,13 +3455,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float *
3295
3455
  // TODO
3296
3456
 
3297
3457
  int whisper_set_mel_with_state(
3298
- struct whisper_context * /*ctx*/,
3458
+ struct whisper_context * ctx,
3299
3459
  struct whisper_state * state,
3300
3460
  const float * data,
3301
3461
  int n_len,
3302
3462
  int n_mel) {
3303
- if (n_mel != WHISPER_N_MEL) {
3304
- log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
3463
+ if (n_mel != ctx->model.filters.n_mel) {
3464
+ WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3305
3465
  return -1;
3306
3466
  }
3307
3467
 
@@ -3325,7 +3485,7 @@ int whisper_set_mel(
3325
3485
 
3326
3486
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3327
3487
  if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3328
- log("%s: failed to eval\n", __func__);
3488
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3329
3489
  return -1;
3330
3490
  }
3331
3491
 
@@ -3334,7 +3494,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3334
3494
 
3335
3495
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3336
3496
  if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3337
- log("%s: failed to eval\n", __func__);
3497
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3338
3498
  return -1;
3339
3499
  }
3340
3500
 
@@ -3342,10 +3502,12 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3342
3502
  }
3343
3503
 
3344
3504
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3345
- const int selected_decoder_id = 0;
3505
+ whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
3346
3506
 
3347
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3348
- log("%s: failed to eval\n", __func__);
3507
+ whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
3508
+
3509
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
3510
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3349
3511
  return 1;
3350
3512
  }
3351
3513
 
@@ -3353,27 +3515,19 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3353
3515
  }
3354
3516
 
3355
3517
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3356
- // TODO: add selected_decoder_id to state
3357
- const int selected_decoder_id = 0;
3358
-
3359
3518
  if (ctx->state == nullptr) {
3360
- log("%s: ERROR state was not loaded.\n", __func__);
3361
- return false;
3362
- }
3363
-
3364
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3365
- log("%s: failed to eval\n", __func__);
3366
- return 1;
3519
+ WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
3520
+ return -1;
3367
3521
  }
3368
3522
 
3369
- return 0;
3523
+ return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
3370
3524
  }
3371
3525
 
3372
3526
  int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
3373
3527
  const auto res = tokenize(ctx->vocab, text);
3374
3528
 
3375
3529
  if (n_max_tokens < (int) res.size()) {
3376
- log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3530
+ WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3377
3531
  return -1;
3378
3532
  }
3379
3533
 
@@ -3401,7 +3555,7 @@ int whisper_lang_id(const char * lang) {
3401
3555
  }
3402
3556
  }
3403
3557
 
3404
- log("%s: unknown language '%s'\n", __func__, lang);
3558
+ WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
3405
3559
  return -1;
3406
3560
  }
3407
3561
  return g_lang.at(lang).first;
@@ -3414,7 +3568,18 @@ const char * whisper_lang_str(int id) {
3414
3568
  }
3415
3569
  }
3416
3570
 
3417
- log("%s: unknown language id %d\n", __func__, id);
3571
+ WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
3572
+ return nullptr;
3573
+ }
3574
+
3575
+ const char * whisper_lang_str_full(int id) {
3576
+ for (const auto & kv : g_lang) {
3577
+ if (kv.second.first == id) {
3578
+ return kv.second.second.c_str();
3579
+ }
3580
+ }
3581
+
3582
+ WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
3418
3583
  return nullptr;
3419
3584
  }
3420
3585
 
@@ -3427,29 +3592,29 @@ int whisper_lang_auto_detect_with_state(
3427
3592
  const int seek = offset_ms/10;
3428
3593
 
3429
3594
  if (seek < 0) {
3430
- log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3595
+ WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3431
3596
  return -1;
3432
3597
  }
3433
3598
 
3434
3599
  if (seek >= state->mel.n_len_org) {
3435
- log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3600
+ WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3436
3601
  return -2;
3437
3602
  }
3438
3603
 
3439
3604
  // run the encoder
3440
3605
  if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
3441
- log("%s: failed to encode\n", __func__);
3606
+ WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
3442
3607
  return -6;
3443
3608
  }
3444
3609
 
3445
3610
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
3446
3611
 
3447
3612
  if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
3448
- log("%s: failed to decode\n", __func__);
3613
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
3449
3614
  return -7;
3450
3615
  }
3451
3616
 
3452
- auto & logits_id = state->logits_id;
3617
+ auto & logits_id = state->decoders[0].logits_id;
3453
3618
  logits_id.clear();
3454
3619
 
3455
3620
  for (const auto & kv : g_lang) {
@@ -3645,27 +3810,31 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3645
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = wsp_ggml_time_us();
 
-    log("\n");
-    log("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    WHISPER_LOG_INFO("\n");
+    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
     if (ctx->state != nullptr) {
 
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
-        log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        log("%s:     mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        log("%s:  sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        log("%s:  encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        log("%s:  decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        log("%s:  prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        WHISPER_LOG_INFO("%s:     mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s:  sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s:  encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s:  decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:  batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
+        WHISPER_LOG_INFO("%s:  prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
-    log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
+    ctx->t_start_us = wsp_ggml_time_us();
     if (ctx->state != nullptr) {
+        ctx->state->t_mel_us = 0;
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
@@ -3673,6 +3842,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
+        ctx->state->n_batchd = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ -3711,14 +3881,441 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "     + std::to_string(wsp_ggml_cpu_has_sse3())     + " | ";
     s += "SSSE3 = "    + std::to_string(wsp_ggml_cpu_has_ssse3())    + " | ";
     s += "VSX = "      + std::to_string(wsp_ggml_cpu_has_vsx())      + " | ";
+    s += "CUDA = "     + std::to_string(wsp_ggml_cpu_has_cublas())   + " | ";
     s += "COREML = "   + std::to_string(whisper_has_coreml())        + " | ";
     s += "OPENVINO = " + std::to_string(whisper_has_openvino())      + " | ";
 
     return s.c_str();
 }
 
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+        const char         * src,
+        whisper_partial_utf8 partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src;
+    std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t first_byte = static_cast<uint8_t>(*pos);
+        uint8_t highbits   = first_byte >> 4;
+        n_remain = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t mask = (1 << (7 - n_remain)) - 1;
+        value = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
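// A quick illustration of how the incremental decoder above is meant to be driven
// across a token boundary (a sketch; the two-call harness below is not part of the
// package, only `decode_utf8` and `whisper_partial_utf8` are):
//
//     whisper_partial_utf8 carry = { 0, 0 };
//     auto first  = decode_utf8("\xC3", carry);    // incomplete 2-byte sequence
//     carry       = first.second;                  // carry.n_remain == 1
//     auto second = decode_utf8("\xA9", carry);    // completes U+00E9
//     // second.first == { 0xE9, 0 }  (terminating 0 included)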
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+    switch (pos->type) {
+        case WHISPER_GRETYPE_END: return true; // NOLINT
+        case WHISPER_GRETYPE_ALT: return true; // NOLINT
+        default:                  return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+        const whisper_grammar_element * pos,
+        const uint32_t                  chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+        const whisper_grammar_element * pos,
+        const whisper_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+        const std::vector<std::vector<whisper_grammar_element>>   & rules,
+        const std::vector<const whisper_grammar_element *>        & stack,
+        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const whisper_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case WHISPER_GRETYPE_RULE_REF: {
+            const size_t                    rule_id = static_cast<size_t>(pos->value);
+            const whisper_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == WHISPER_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case WHISPER_GRETYPE_CHAR:
+        case WHISPER_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            WHISPER_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const uint32_t                                                    chr) {
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const whisper_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!whisper_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<whisper_grammar_element>> & rules,
+        const std::vector<const whisper_grammar_element *>      & stack,
+        const std::vector<whisper_grammar_candidate>            & candidates) {
+
+    std::vector<whisper_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        for (auto tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const whisper_grammar_element * stack_pos = stack.back();
+
+    std::vector<whisper_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+    whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates) {
+    if (candidates.empty() || stacks.empty()) {
+        return std::vector<whisper_grammar_candidate>();
+    }
+
+    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+static struct whisper_grammar whisper_grammar_init(
+        const whisper_grammar_element ** rules,
+        size_t                           n_rules,
+        size_t                           i_start_rule) {
+    const whisper_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    pos = rules[i_start_rule];
+    do {
+        std::vector<const whisper_grammar_element *> stack;
+        if (!whisper_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        whisper_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!whisper_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == WHISPER_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    return { std::move(vec_rules), std::move(stacks), {} };
+}
+
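// For orientation, a hand-written rule set for the grammar `root ::= "yes" | "no"`
// would look roughly like this (a sketch; callers normally generate the elements
// from a GBNF file rather than writing them by hand):
//
//     static const whisper_grammar_element rule_root[] = {
//         { WHISPER_GRETYPE_CHAR, 'y' }, { WHISPER_GRETYPE_CHAR, 'e' }, { WHISPER_GRETYPE_CHAR, 's' },
//         { WHISPER_GRETYPE_ALT,  0   },
//         { WHISPER_GRETYPE_CHAR, 'n' }, { WHISPER_GRETYPE_CHAR, 'o' },
//         { WHISPER_GRETYPE_END,  0   },
//     };
//     static const whisper_grammar_element * rules[] = { rule_root };
//     whisper_grammar grammar = whisper_grammar_init(rules, 1, 0);  // 1 rule, start at rule 0
//
// whisper_grammar_init then expands the two alternates into two initial pushdown
// stacks, one positioned at 'y' and one at 'n'.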
+static void whisper_suppress_invalid_grammar(
+             whisper_context  & ctx,
+    const whisper_full_params & params,
+          std::vector<float>  & logits,
+    const     whisper_grammar & grammar) {
+
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //bool allow_eot = false;
+    //for (const auto & stack : grammar.stacks) {
+    //    if (stack.empty()) {
+    //        allow_eot = true;
+    //        break;
+    //    }
+    //}
+
+    const whisper_token eot = whisper_token_eot(&ctx);
+
+    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+    std::vector<whisper_grammar_candidate>                              candidates_grammar;
+
+    for (whisper_token id = 0; id < eot; ++id) {
+        const std::string & text = ctx.vocab.id_to_token[id];
+        if (!text.empty()) {
+            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+    for (const auto & reject : rejects) {
+        logits[reject.id] -= params.grammar_penalty;
+    }
+
+    // when the grammar allows a continuation, we penalize the end-of-text token
+    //if (!allow_eot) {
+    //    logits[eot] -= params.grammar_penalty;
+    //}
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+    const std::string & text = ctx.vocab.id_to_token[token];
+
+    if (text.rfind("[_", 0) == 0) {
+        // fprintf(stderr, " (skipped)\n");
+        return;
+    }
+    // fprintf(stderr, "\n");
+
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(text.c_str(), grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+    }
+    grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
 ////////////////////////////////////////////////////////////////////////////
 
+struct whisper_context_params * whisper_context_default_params_by_ref() {
+    struct whisper_context_params params = whisper_context_default_params();
+
+    struct whisper_context_params* result = new whisper_context_params();
+    *result = params;
+    return result;
+}
+
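// The by-ref variant exists for bindings that cannot pass C structs by value
// (e.g. a React Native bridge); the caller receives a heap allocation made with
// `new` and owns it. A usage sketch (the `use_gpu` field is an assumption based
// on whisper.h in this release):
//
//     whisper_context_params * cparams = whisper_context_default_params_by_ref();
//     cparams->use_gpu = false;    // mutate across the FFI boundary
//     // ... pass to context init, then free on the binding side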
 struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params params = whisper_full_default_params(strategy);
 
@@ -3738,6 +4335,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.translate        =*/ false,
         /*.no_context       =*/ true,
+        /*.no_timestamps    =*/ false,
         /*.single_segment   =*/ false,
         /*.print_special    =*/ false,
         /*.print_progress   =*/ true,
@@ -3771,7 +4369,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts   =*/  1.0f,
         /*.length_penalty   =*/ -1.0f,
 
-        /*.temperature_inc  =*/  0.4f,
+        /*.temperature_inc  =*/  0.2f,
         /*.entropy_thold    =*/  2.4f,
         /*.logprob_thold    =*/ -1.0f,
         /*.no_speech_thold  =*/  0.6f,
@@ -3795,24 +4393,29 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
 
-        /*.abort_callback                   =*/ nullptr,
-        /*.abort_callback_user_data         =*/ nullptr,
+        /*.abort_callback                   =*/ nullptr,
+        /*.abort_callback_user_data         =*/ nullptr,
 
         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
+
+        /*.grammar_rules   =*/ nullptr,
+        /*.n_grammar_rules =*/ 0,
+        /*.i_start_rule    =*/ 0,
+        /*.grammar_penalty =*/ 100.0f,
     };
 
     switch (strategy) {
        case WHISPER_SAMPLING_GREEDY:
            {
                result.greedy = {
-                    /*.best_of   =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.best_of   =*/ 5,
                };
            } break;
        case WHISPER_SAMPLING_BEAM_SEARCH:
            {
                result.beam_search = {
-                    /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.beam_size =*/ 5,
 
                    /*.patience  =*/ -1.0f,
                };
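// With the default temperature = 0.0f, halving temperature_inc to 0.2f doubles the
// decode-fallback ladder; a sketch mirroring the schedule-building loop in
// whisper_full_with_state (not a verbatim excerpt):
//
//     std::vector<float> temperatures;
//     for (float t = 0.0f; t < 1.0f + 1e-6f; t += 0.2f) {
//         temperatures.push_back(t);    // {0.0, 0.2, 0.4, 0.6, 0.8, 1.0}
//     }
//
// best_of/beam_size can now default to 5 because the decoders share one batched
// KV cache, so extra candidates no longer cost a full per-decoder decode pass.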
@@ -3902,11 +4505,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
+// TODO: optimize
 static void whisper_process_logits(
         struct whisper_context & ctx,
         struct whisper_state   & state,
-        const struct whisper_full_params params,
         struct whisper_decoder & decoder,
+        const struct whisper_full_params params,
         float temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -3923,7 +4527,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -3951,6 +4555,11 @@ static void whisper_process_logits(
         // suppress <|notimestamps|> token
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
+        if (params.no_timestamps) {
+            for (int i = vocab.token_beg; i < n_logits; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }
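// (The timestamp tokens occupy the tail of the vocabulary starting at
// vocab.token_beg, so masking the whole range [token_beg, n_logits) is
// sufficient to disable timestamp emission when no_timestamps is set.)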
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot] = -INFINITY;
@@ -3964,6 +4573,15 @@ static void whisper_process_logits(
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev]       = -INFINITY;
+
+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
 
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -3996,7 +4614,7 @@ static void whisper_process_logits(
             const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
             const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
-            //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+            //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
             if (last_was_timestamp) {
                 if (penultimate_was_timestamp) {
@@ -4072,13 +4690,37 @@ static void whisper_process_logits(
 
             const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
-            //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+            //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else {
+                if (params.n_grammar_rules > 0) {
+                    whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+                    // populate the logprobs array (log_softmax)
+                    {
+                        const float logit_max = *std::max_element(logits.begin(), logits.end());
+                        float logsumexp = 0.0f;
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logsumexp += expf(logits[i] - logit_max);
+                            }
+                        }
+                        logsumexp = logf(logsumexp) + logit_max;
+
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logprobs[i] = logits[i] - logsumexp;
+                            } else {
+                                logprobs[i] = -INFINITY;
+                            }
+                        }
+                    }
+                }
            }
        }
    }
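// The recomputation above is the usual max-shifted log-softmax,
// logprob_i = x_i - (logit_max + log(sum_j exp(x_j - logit_max))),
// re-run because subtracting grammar_penalty from rejected logits invalidates
// the logprobs computed earlier; shifting by logit_max keeps expf() from
// overflowing.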
@@ -4096,38 +4738,60 @@ static void whisper_process_logits(
 
 #if 0
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token   = vocab.id_to_token.at(i);
-        const auto prob    = probs[i];
-        const auto logit   = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token   = vocab.id_to_token.at(i);
+    //    const auto prob    = probs[i];
+    //    const auto logit   = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token   = vocab.id_to_token.at(pairs[i].second);
+            const auto prob    = pairs[i].first;
+            const auto logit   = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
     }
 
     // "And", "and", " And", " and"
-    printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
-
-    printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
-    printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    //printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    //printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
-              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -4172,7 +4836,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(state.rng);
+        result.id   = dist(decoder.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -4182,15 +4846,12 @@ static whisper_token_data whisper_sample_token(
         result.pt = result.p;
     }
 
-    state.n_sample++;
-
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
-              whisper_state & state,
-      const whisper_decoder & decoder,
+            whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
 
@@ -4200,7 +4861,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = state.logits_id;
+    auto & logits_id = decoder.logits_id;
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
@@ -4246,8 +4907,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         ptsum = sum_ts;
     }
 
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
     for (int i = 0; i < k; ++i) {
-        const auto id = logits_id[i].second;
+        const auto id = dist(decoder.rng);
+        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
@@ -4257,8 +4921,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    state.n_sample++;
-
     return result;
 }
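// Note the behavioral change above: beam candidates are now k draws (with
// replacement) from the full post-filter distribution via
// std::discrete_distribution and a per-decoder RNG, rather than the k
// highest-probability tokens, so identically-initialized decoders can still
// diverge during beam search.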
 
@@ -4311,115 +4973,6 @@ static void whisper_sequence_score(
     }
 }
 
-static bool whisper_kv_swap_fast(
-                   std::vector<int> & view,
-                    whisper_decoder   src[],
-                std::vector<kv_buf> & kv_swap_bufs,
-                          const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(wsp_ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(wsp_ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, wsp_ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, wsp_ggml_nbytes(src[view[i]].kv_self.v));
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -4435,11 +4988,11 @@ int whisper_full_with_state(
     // compute log mel spectrogram
     if (params.speed_up) {
         // TODO: Replace PV with more advanced algorithm
-        log("%s: failed to compute log mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
         return -1;
     } else {
         if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-            log("%s: failed to compute log mel spectrogram\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
     }
@@ -4451,13 +5004,13 @@ int whisper_full_with_state(
 
         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
-            log("%s: failed to auto-detect language\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
             return -3;
         }
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
-        log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -4479,6 +5032,7 @@ int whisper_full_with_state(
     // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
     if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
+        WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
        return 0;
    }
 
@@ -4509,40 +5063,23 @@ int whisper_full_with_state(
 
     n_decoders = std::max(1, n_decoders);
 
+    if (n_decoders > WHISPER_MAX_DECODERS) {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self)) {
-                log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
-
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
-
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
-
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
-            // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
-#ifdef WSP_GGML_USE_METAL
-#define WHISPER_METAL_CHECK_BUF(result) \
-            if (!(result)) { \
-                log("%s: failed to add metal buffer\n", __func__); \
-                return 0; \
-            }
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
-            const std::string kv_name = "kv_self_" + std::to_string(j);
-            auto & kv_self = decoder.kv_self;
-
-            WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-#endif
-        }
+        decoder.rng = std::mt19937(0);
     }
 
     // the accumulated text context so far
@@ -4557,7 +5094,7 @@ int whisper_full_with_state(
 
     // initial prompt
     if (!params.prompt_tokens && params.initial_prompt) {
-        prompt_tokens.resize(2048);
+        prompt_tokens.resize(1024);
         prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
         params.prompt_tokens   = prompt_tokens.data();
         params.prompt_n_tokens = prompt_tokens.size();
@@ -4575,13 +5112,14 @@ int whisper_full_with_state(
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
-    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
+
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
@@ -4593,6 +5131,19 @@ int whisper_full_with_state(
        }
    }
 
+    // distilled models require the "no_timestamps" token
+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+        if (is_distil && !params.no_timestamps) {
+            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
+            params.no_timestamps = true;
+        }
+    }
+
+    if (params.no_timestamps) {
+        prompt_init.push_back(whisper_token_not(ctx));
+    }
+
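// For a multilingual model with no_timestamps in effect, prompt_init thus ends up
// as [sot, lang, task, no_timestamps], which is the decoder conditioning that
// distil-whisper checkpoints (detected above via n_text_layer == 2) were trained with.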
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4605,8 +5156,10 @@ int whisper_full_with_state(
         bool has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar grammar;
     };
 
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
     std::vector<beam_candidate> beam_candidates;
 
     // main loop
@@ -4615,24 +5168,24 @@ int whisper_full_with_state(
             const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
 
             params.progress_callback(
-                ctx, ctx->state, progress_cur, params.progress_callback_user_data);
+                ctx, state, progress_cur, params.progress_callback_user_data);
        }
 
-        // of only 1 second left, then stop
+        // if only 1 second left, then stop
        if (seek + 100 >= seek_end) {
            break;
        }
 
        if (params.encoder_begin_callback) {
            if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
                break;
            }
        }
 
        // encode audio features starting at offset seek
        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to encode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
            return -6;
        }
 
@@ -4668,14 +5221,12 @@ int whisper_full_with_state(
 
             n_decoders_cur = std::max(1, n_decoders_cur);
 
-            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+            WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
 
            // TAGS: WHISPER_DECODER_INIT
            for (int j = 0; j < n_decoders_cur; ++j) {
                auto & decoder = state->decoders[j];
 
-                decoder.kv_self.n = 0;
-
                decoder.sequence.tokens.clear();
                decoder.sequence.result_len       = 0;
                decoder.sequence.sum_logprobs_all = 0.0;
@@ -4689,10 +5240,16 @@ int whisper_full_with_state(
                decoder.failed    = false;
                decoder.completed = false;
                decoder.has_ts    = false;
+
+                if (params.grammar_rules != nullptr) {
+                    decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                } else {
+                    decoder.grammar = {};
+                }
            }
 
            // init prompt and kv cache for the current iteration
-            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+            // TODO: do not recompute the prompt if it is the same as previous time
            {
                prompt.clear();
 
@@ -4714,25 +5271,26 @@ int whisper_full_with_state(
                }
                WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                    log("%s: failed to decode\n", __func__);
+                whisper_kv_cache_clear(state->kv_self);
+
+                whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                    return -7;
                }
 
                {
                    const int64_t t_start_sample_us = wsp_ggml_time_us();
 
-                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+                    state->decoders[0].i_batch = prompt.size() - 1;
 
-                    state->decoders[0].kv_self.n += prompt.size();
+                    whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
                    for (int j = 1; j < n_decoders_cur; ++j) {
                        auto & decoder = state->decoders[j];
 
-                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, wsp_ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v));
-
-                        decoder.kv_self.n += prompt.size();
+                        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                        memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
                        memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -4747,41 +5305,81 @@ int whisper_full_with_state(
                 const int64_t t_start_sample_us = wsp_ggml_time_us();
 
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    beam_candidates.clear();
+                    for (auto & bc : bc_per_dec) {
+                        bc.clear();
+                    }
                }
 
-                // generate new sequence candidates for each decoder
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                // sampling
+                // TODO: avoid memory allocations, optimize, avoid threads?
+                {
+                    std::atomic<int> j_cur(0);
 
-                    if (decoder.completed || decoder.failed) {
-                        continue;
-                    }
+                    auto process = [&]() {
+                        while (true) {
+                            const int j = j_cur.fetch_add(1);
 
-                    switch (params.strategy) {
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                            {
-                                if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                                } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                                }
+                            if (j >= n_decoders_cur) {
+                                break;
+                            }
 
-                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                            } break;
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                            {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+                            auto & decoder = state->decoders[j];
 
-                                for (const auto & token : tokens_new) {
-                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
-                                    beam_candidates.back().sequence.tokens.push_back(token);
-                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+                            if (decoder.completed || decoder.failed) {
+                                continue;
+                            }
 
-                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
-                                }
-                            } break;
+                            switch (params.strategy) {
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                    {
+                                        if (t_cur < 1e-6f) {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                        } else {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                        }
+
+                                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                    } break;
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                    {
+                                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                        for (const auto & token : tokens_new) {
+                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                        }
+                                    } break;
+                            };
+                        }
                    };
+
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    if (n_threads == 1) {
+                        process();
+                    } else {
+                        std::vector<std::thread> threads(n_threads - 1);
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t].join();
+                        }
+                    }
+                }
+
+                beam_candidates.clear();
+                for (const auto & bc : bc_per_dec) {
+                    beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                    if (!bc.empty()) {
+                        state->n_sample += 1;
+                    }
                }
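// The `j_cur.fetch_add(1)` pattern above is a lock-free work queue: each thread
// atomically claims the next decoder index until none remain. In isolation it is
// just (a self-contained sketch; `n_jobs`/`do_job` are placeholders, not symbols
// from this file):
//
//     std::atomic<int> next(0);
//     auto worker = [&]() {
//         for (int j = next.fetch_add(1); j < n_jobs; j = next.fetch_add(1)) {
//             do_job(j);    // each index is processed by exactly one thread
//         }
//     };
//
// Writing per-decoder results into bc_per_dec[j] keeps the workers free of
// shared-state races; the candidates are merged into beam_candidates afterwards.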
 
             // for beam-search, choose the top candidates and update the KV caches
@@ -4794,7 +5392,6 @@ int whisper_full_with_state(
                 });
 
                 uint32_t cur_c = 0;
-                std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                for (int j = 0; j < n_decoders_cur; ++j) {
                    auto & decoder = state->decoders[j];
@@ -4803,23 +5400,38 @@ int whisper_full_with_state(
                         continue;
                    }
 
+                    if (cur_c >= beam_candidates.size()) {
+                        cur_c = 0;
+                    }
+
                    auto & cur = beam_candidates[cur_c++];
 
                    while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
                        ++cur_c;
                    }
 
-                    decoder.sequence   = cur.sequence;
                    decoder.seek_delta = cur.seek_delta;
                    decoder.has_ts     = cur.has_ts;
+                    decoder.sequence   = cur.sequence;
+                    decoder.grammar    = cur.grammar;
+
+                    whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
-                    decoder_idx[j] = cur.decoder_idx;
                    WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                            __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                }
 
-                // update KV caches
-                whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = state->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    whisper_kv_cache_seq_rm(state->kv_self, j,                            -1, -1);
+                    whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j,  -1, -1);
+                    whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j,     -1, -1);
+                }
            }
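// The scratch sequence ids (WHISPER_MAX_DECODERS + j) make the beam reshuffle a
// safe two-phase copy on the shared KV cache:
//
//     phase 1: kv[MAX + j] <- kv[src_j]     (copy every surviving beam aside)
//     phase 2: kv[j]       <- kv[MAX + j],  then drop kv[MAX + j]
//
// Copying kv[src_j] -> kv[j] directly could clobber a sequence that another beam
// still needs as its source; this replaces the removed whisper_kv_swap_fast()
// copy/swap planner with plain sequence operations.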
 
             // update the decoder state
@@ -4848,6 +5460,7 @@ int whisper_full_with_state(
 
                     // do not allow to go back in time
                    if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+                        WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
                        failed = true; // TODO: maybe this is not a failure ?
                        continue;
                    }
@@ -4857,6 +5470,8 @@ int whisper_full_with_state(
                         has_ts = true;
                    }
 
+                    whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
 #ifdef WHISPER_DEBUG
                    {
                        const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
@@ -4874,6 +5489,7 @@ int whisper_full_with_state(
                         if (seek + seek_delta + 100 >= seek_end) {
                            result_len = i + 1;
                        } else {
+                            WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
                            failed = true;
                            continue;
                        }
@@ -4884,6 +5500,7 @@ int whisper_full_with_state(
                             seek_delta = 100*WHISPER_CHUNK_SIZE;
                        }
 
+                        WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
                        completed = true;
                        continue;
                    }
@@ -4899,6 +5516,7 @@ int whisper_full_with_state(
                     // sometimes, the decoding can get stuck in a repetition loop
                    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+                        WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
                        failed = true;
                        continue;
                    }
@@ -4926,32 +5544,83 @@ int whisper_full_with_state(
                 state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
 
                // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
-                    }
+                    batch.n_tokens = 0;
+
+                    const int n_past = prompt.size() + i;
+
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
 
-                    decoder.tokens_tmp.resize(1);
-                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+                        decoder.i_batch = batch.n_tokens;
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                        log("%s: failed to decode\n", __func__);
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
+
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                        return -8;
                    }
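// Shape of the batch built above with, say, three live decoders at position
// n_past (illustrative values): this single decode call now serves all beams,
// replacing the old per-decoder whisper_decode_internal loop:
//
//     token    = { tok_d0, tok_d1, tok_d2 }    // one row per live decoder
//     pos      = { n_past, n_past, n_past }
//     n_seq_id = { 1, 1, 1 }
//     seq_id   = { {0}, {1}, {2} }             // KV-cache sequence per decoder
//     logits   = { 1, 1, 1 }                   // request logits for every row
//
// decoder.i_batch records each decoder's row so whisper_process_logits can read
// its slice of state->logits.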
4945
5579
 
5580
+ const int64_t t_start_sample_us = wsp_ggml_time_us();
5581
+
5582
+ // TODO: avoid memory allocations, optimize, avoid threads?
4946
5583
  {
4947
- const int64_t t_start_sample_us = wsp_ggml_time_us();
5584
+ std::atomic<int> j_cur(0);
5585
+
5586
+ auto process = [&]() {
5587
+ while (true) {
5588
+ const int j = j_cur.fetch_add(1);
5589
+
5590
+ if (j >= n_decoders_cur) {
5591
+ break;
5592
+ }
4948
5593
 
4949
- whisper_process_logits(*ctx, *state, params, decoder, t_cur);
5594
+ auto & decoder = state->decoders[j];
4950
5595
 
4951
- ++decoder.kv_self.n;
5596
+ if (decoder.failed || decoder.completed) {
5597
+ continue;
5598
+ }
5599
+
5600
+ whisper_process_logits(*ctx, *state, decoder, params, t_cur);
5601
+ }
5602
+ };
5603
+
5604
+ const int n_threads = std::min(params.n_threads, n_decoders_cur);
5605
+
5606
+ if (n_threads == 1) {
5607
+ process();
5608
+ } else {
5609
+ std::vector<std::thread> threads(n_threads - 1);
5610
+
5611
+ for (int t = 0; t < n_threads - 1; ++t) {
5612
+ threads[t] = std::thread(process);
5613
+ }
5614
+
5615
+ process();
4952
5616
 
4953
- state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
5617
+ for (int t = 0; t < n_threads - 1; ++t) {
5618
+ threads[t].join();
5619
+ }
5620
+ }
4954
5621
  }
5622
+
5623
+ state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us;
4955
5624
  }
4956
5625
  }
4957
5626
 
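Note: this hunk replaces the per-decoder whisper_decode_internal calls with a single batched call: every live decoder contributes one token to state->batch, tagged with its sequence id (seq_id) and position (n_past), and decoder.i_batch remembers which row of the batched logits belongs to it. The logits are then processed across threads with an atomic work counter. A self-contained sketch of that counter pattern (n_jobs and the printf body are illustrative stand-ins for n_decoders_cur and whisper_process_logits):

    #include <algorithm>
    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
        const int n_jobs    = 5;                  // stand-in for n_decoders_cur
        const int n_threads = std::min(4, n_jobs);

        std::atomic<int> j_cur(0);

        auto process = [&]() {
            while (true) {
                const int j = j_cur.fetch_add(1); // atomically claim the next job
                if (j >= n_jobs) {
                    break;
                }
                printf("processing job %d\n", j); // stand-in for whisper_process_logits
            }
        };

        // spawn n_threads - 1 workers; the calling thread participates as well
        std::vector<std::thread> threads(n_threads - 1);
        for (auto & t : threads) t = std::thread(process);
        process();
        for (auto & t : threads) t.join();
    }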
@@ -4991,28 +5660,27 @@ int whisper_full_with_state(
  WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
  }

+ bool success = true;
+
  // was the decoding successful for the current temperature?
  // do fallback only if:
  // - we are not at the last temperature
- // - we are not at the end of the audio (3 sec)
- if (it != (int) temperatures.size() - 1 &&
- seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
- bool success = true;
-
+ if (it != (int) temperatures.size() - 1) {
  const auto & decoder = state->decoders[best_decoder_id];

  if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
+ WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
  success = false;
  state->n_fail_p++;
  }
+ }

- if (success) {
- //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
- // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
- //}
+ if (success) {
+ //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+ // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+ //}

- break;
- }
+ break;
  }

  WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
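Note: success is hoisted out of the conditional, and the old shortcut that skipped fallback near the end of the audio is dropped; a decode is now re-tried at the next temperature whenever it is not already at the last temperature and the best decoder failed or scored below logprob_thold. A standalone sketch of that loop shape (the temperature schedule and avg_logprobs values are made up for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> temperatures = { 0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f };
        const float logprob_thold = -1.0f;

        for (size_t it = 0; it < temperatures.size(); ++it) {
            const float avg_logprobs = -2.5f + 0.5f*it; // pretend each retry improves

            bool success = true;
            if (it != temperatures.size() - 1 && avg_logprobs < logprob_thold) {
                success = false; // retry at the next (higher) temperature
            }

            if (success) {
                printf("accepted at temperature %.2f\n", temperatures[it]);
                break;
            }
            printf("fallback: t = %.2f rejected (avg_logprobs %.2f)\n", temperatures[it], avg_logprobs);
        }
    }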
@@ -5248,11 +5916,13 @@ int whisper_full_parallel(
  ctx->state->t_sample_us += states[i]->t_sample_us;
  ctx->state->t_encode_us += states[i]->t_encode_us;
  ctx->state->t_decode_us += states[i]->t_decode_us;
+ ctx->state->t_batchd_us += states[i]->t_batchd_us;
  ctx->state->t_prompt_us += states[i]->t_prompt_us;

  ctx->state->n_sample += states[i]->n_sample;
  ctx->state->n_encode += states[i]->n_encode;
  ctx->state->n_decode += states[i]->n_decode;
+ ctx->state->n_batchd += states[i]->n_batchd;
  ctx->state->n_prompt += states[i]->n_prompt;

  whisper_free_state(states[i]);
@@ -5265,12 +5935,12 @@ int whisper_full_parallel(
  ctx->state->t_decode_us /= n_processors;

  // print information about the audio boundaries
- log("\n");
- log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+ WHISPER_LOG_WARN("\n");
+ WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
  for (int i = 0; i < n_processors - 1; ++i) {
- log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+ WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
  }
- log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+ WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);

  return ret;
  }
@@ -5385,8 +6055,45 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  size_t n = 20;
  size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations

- // 1GB MB array
- const size_t size = arr*1024llu*1024llu;
+ // 1GB array
+ const size_t size = arr*1e6;
+
+ double sum = 0.0;
+
+ // heat-up
+ {
+ char * src = (char *) malloc(size);
+ char * dst = (char *) malloc(size);
+
+ for (size_t i = 0; i < size; i++) src[i] = i;
+
+ memcpy(dst, src, size); // heat-up
+
+ double tsum = 0.0;
+
+ for (size_t i = 0; i < n; i++) {
+ const int64_t t0 = wsp_ggml_time_us();
+
+ memcpy(dst, src, size);
+
+ const int64_t t1 = wsp_ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+
+ src[rand() % size] = rand() % 256;
+ }
+
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
+ s += strbuf;
+
+ // needed to prevent the compiler from optimizing the memcpy away
+ {
+ for (size_t i = 0; i < size; i++) sum += dst[i];
+ }
+
+ free(src);
+ free(dst);
+ }

  // single-thread
  {
@@ -5398,7 +6105,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  memcpy(dst, src, size); // heat-up

  double tsum = 0.0;
- double sum = 0.0;

  for (size_t i = 0; i < n; i++) {
  const int64_t t0 = wsp_ggml_time_us();
@@ -5412,21 +6118,73 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
  src[rand() % size] = rand() % 256;
  }

- snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
  s += strbuf;

  // needed to prevent the compiler from optimizing the memcpy away
  {
  for (size_t i = 0; i < size; i++) sum += dst[i];
+ }
+
+ free(src);
+ free(dst);
+ }
+
+ // multi-thread
+
+ for (uint32_t k = 1; k <= n_threads; k++) {
+ char * src = (char *) malloc(size);
+ char * dst = (char *) malloc(size);
+
+ for (size_t i = 0; i < size; i++) src[i] = i;
+
+ memcpy(dst, src, size); // heat-up
+
+ double tsum = 0.0;
+
+ auto helper = [&](int th) {
+ const int64_t i0 = (th + 0)*size/k;
+ const int64_t i1 = (th + 1)*size/k;
+
+ for (size_t i = 0; i < n; i++) {
+ memcpy(dst + i0, src + i0, i1 - i0);
+
+ src[i0 + rand() % (i1 - i0)] = rand() % 256;
+ };
+ };
+
+ const int64_t t0 = wsp_ggml_time_us();
+
+ std::vector<std::thread> threads(k - 1);
+ for (uint32_t th = 0; th < k - 1; ++th) {
+ threads[th] = std::thread(helper, th);
+ }
+
+ helper(k - 1);
+
+ for (uint32_t th = 0; th < k - 1; ++th) {
+ threads[th].join();
+ }

- snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
- s += strbuf;
+ const int64_t t1 = wsp_ggml_time_us();
+
+ tsum += (t1 - t0)*1e-6;
+
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
+ s += strbuf;
+
+ // needed to prevent the compiler from optimizing the memcpy away
+ {
+ for (size_t i = 0; i < size; i++) sum += dst[i];
  }

  free(src);
  free(dst);
  }

+ snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
+ s += strbuf;
+
  return s.c_str();
  }

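Note: the rewritten benchmark adds a dedicated heat-up pass and a per-thread-count sweep, and bandwidth is now computed as n*size bytes over tsum seconds, divided by 1e9 to get GB/s (previously the divisor was 1024^3). A self-contained sketch of the measurement core under those conventions (the 256 MiB size is illustrative, and std::chrono stands in for wsp_ggml_time_us):

    #include <chrono>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    int main() {
        const size_t size = 256*1024*1024; // smaller than the ~1 GB used above
        const size_t n    = 20;

        char * src = (char *) malloc(size);
        char * dst = (char *) malloc(size);
        memset(src, 1, size);

        const auto t0 = std::chrono::high_resolution_clock::now();
        for (size_t i = 0; i < n; i++) {
            memcpy(dst, src, size);
            src[rand() % size] = rand() % 256; // mutate the source between runs
        }
        const auto t1 = std::chrono::high_resolution_clock::now();

        const double tsum = std::chrono::duration<double>(t1 - t0).count();
        printf("memcpy: %7.2f GB/s\n", (double) (n*size)/(tsum*1e9));

        // touch dst so the compiler cannot drop the copies entirely
        double sum = 0.0;
        for (size_t i = 0; i < size; i++) sum += dst[i];
        printf("sum: %f\n", sum);

        free(src);
        free(dst);
    }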
@@ -5454,7 +6212,7 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
  // b: N*N*sizeof(float)
  // c: N*N*sizeof(float)
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
- std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead());
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());
  std::vector<uint8_t> work;

  // put a bunch of random data in the buffer
@@ -5505,17 +6263,19 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {

  struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b);

- struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
+ struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
+
+ wsp_ggml_build_forward_expand(gf, c);

  double tsum = 0.0;

  // heat-up
- wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
+ wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);

  for (int i = 0; i < n_max; ++i) {
  const int64_t t0 = wsp_ggml_time_us();

- wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
+ wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);

  const int64_t t1 = wsp_ggml_time_us();

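Note: this ggml update removes wsp_ggml_build_forward (which returned a wsp_ggml_cgraph by value); graphs are now allocated inside the context with wsp_ggml_new_graph and outputs are attached with wsp_ggml_build_forward_expand, which is why the context buffer above also reserves wsp_ggml_graph_overhead(). A sketch of the new pattern using only calls visible in this diff (buffer size and N are illustrative):

    #include <cstdint>
    #include <vector>
    #include "ggml.h" // vendored in this package with the wsp_ggml_* prefix

    int main() {
        const int N = 128;

        // tensors a, b, c plus the graph all live inside this one buffer
        std::vector<uint8_t> buf(3llu*N*N*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());

        struct wsp_ggml_init_params gparams = {
            /*.mem_size   =*/ buf.size(),
            /*.mem_buffer =*/ buf.data(),
            /*.no_alloc   =*/ false,
        };
        struct wsp_ggml_context * ctx0 = wsp_ggml_init(gparams);

        struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, N, N);
        struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, N, N);
        struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b);

        // old: struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c);
        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0); // graph allocated in ctx0
        wsp_ggml_build_forward_expand(gf, c);                   // add c (and its deps) as an output

        // ... compute gf, e.g. via the internal wsp_ggml_graph_compute_helper ...

        wsp_ggml_free(ctx0);
    }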
@@ -5633,7 +6393,7 @@ static void whisper_exp_compute_token_level_timestamps(
  const int n_samples = state.energy.size();

  if (n_samples == 0) {
- log("%s: no signal data available\n", __func__);
+ WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
  return;
  }

@@ -5854,6 +6614,32 @@ static void whisper_exp_compute_token_level_timestamps(
  //}
  }

- void whisper_set_log_callback(whisper_log_callback callback) {
- whisper_log = callback;
+ void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
+ g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+ g_state.log_callback_user_data = user_data;
+ }
+
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+ static void whisper_log_internal(wsp_ggml_log_level level, const char * format, ...) {
+ va_list args;
+ va_start(args, format);
+ char buffer[1024];
+ int len = vsnprintf(buffer, 1024, format, args);
+ if (len < 1024) {
+ g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+ } else {
+ char* buffer2 = new char[len+1];
+ vsnprintf(buffer2, len+1, format, args);
+ buffer2[len] = 0;
+ g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+ delete[] buffer2;
+ }
+ va_end(args);
+ }
+
+ static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
+ (void) level;
+ (void) user_data;
+ fputs(text, stderr);
+ fflush(stderr);
  }
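Note: whisper_set_log_callback is replaced by whisper_log_set, which takes a ggml-style callback receiving a log level, the message text, and a user-data pointer; passing nullptr restores the stderr default shown above. A sketch of installing a filtering callback (the level constants follow the wsp_ggml_log_level enum naming assumed from the vendored ggml headers; the filtering policy itself is illustrative):

    #include <cstdio>
    #include "whisper.h" // vendored header declaring whisper_log_set

    // matches the wsp_ggml_log_callback signature used by the default above
    static void my_log(wsp_ggml_log_level level, const char * text, void * user_data) {
        FILE * out = (FILE *) user_data;
        if (level == WSP_GGML_LOG_LEVEL_ERROR || level == WSP_GGML_LOG_LEVEL_WARN) {
            fputs(text, out); // forward only warnings and errors
        }
    }

    int main() {
        whisper_log_set(my_log, stderr);      // install the callback with user data
        // whisper_log_set(nullptr, nullptr); // restore the stderr default
    }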