whispercpp 1.3.0 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (132) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ruby_whisper.cpp CHANGED
@@ -1,4 +1,5 @@
1
1
  #include <ruby.h>
2
+ #include <ruby/memory_view.h>
2
3
  #include "ruby_whisper.h"
3
4
  #define DR_WAV_IMPLEMENTATION
4
5
  #include "dr_wav.h"
@@ -35,6 +36,102 @@ extern "C" {
35
36
  VALUE mWhisper;
36
37
  VALUE cContext;
37
38
  VALUE cParams;
39
+ VALUE eError;
40
+
41
+ VALUE cSegment;
42
+ VALUE cModel;
43
+
44
+ static ID id_to_s;
45
+ static ID id_call;
46
+ static ID id___method__;
47
+ static ID id_to_enum;
48
+ static ID id_length;
49
+ static ID id_next;
50
+ static ID id_new;
51
+ static ID id_to_path;
52
+ static ID id_pre_converted_models;
53
+
54
+ static bool is_log_callback_finalized = false;
55
+
56
+ /*
57
+ * call-seq:
58
+ * lang_max_id -> Integer
59
+ */
60
+ static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
61
+ return INT2NUM(whisper_lang_max_id());
62
+ }
63
+
64
+ /*
65
+ * call-seq:
66
+ * lang_id(lang_name) -> Integer
67
+ */
68
+ static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
69
+ const char * lang_str = StringValueCStr(lang);
70
+ const int id = whisper_lang_id(lang_str);
71
+ if (-1 == id) {
72
+ rb_raise(rb_eArgError, "language not found: %s", lang_str);
73
+ }
74
+ return INT2NUM(id);
75
+ }
76
+
77
+ /*
78
+ * call-seq:
79
+ * lang_str(lang_id) -> String
80
+ */
81
+ static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
82
+ const int lang_id = NUM2INT(id);
83
+ const char * str = whisper_lang_str(lang_id);
84
+ if (nullptr == str) {
85
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
86
+ }
87
+ return rb_str_new2(str);
88
+ }
89
+
90
+ /*
91
+ * call-seq:
92
+ * lang_str_full(lang_id) -> String
93
+ */
94
+ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
95
+ const int lang_id = NUM2INT(id);
96
+ const char * str_full = whisper_lang_str_full(lang_id);
97
+ if (nullptr == str_full) {
98
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
99
+ }
100
+ return rb_str_new2(str_full);
101
+ }
102
+
103
+ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
104
+ is_log_callback_finalized = true;
105
+ return Qnil;
106
+ }
107
+
108
+ /*
109
+ * call-seq:
110
+ * log_set ->(level, buffer, user_data) { ... }, user_data -> nil
111
+ */
112
+ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
113
+ VALUE old_callback = rb_iv_get(self, "log_callback");
114
+ if (!NIL_P(old_callback)) {
115
+ rb_undefine_finalizer(old_callback);
116
+ }
117
+
118
+ rb_iv_set(self, "log_callback", log_callback);
119
+ rb_iv_set(self, "user_data", user_data);
120
+
121
+ VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
122
+ rb_define_finalizer(log_callback, finalize_log_callback);
123
+
124
+ whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) {
125
+ if (is_log_callback_finalized) {
126
+ return;
127
+ }
128
+ VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
129
+ VALUE udata = rb_iv_get(mWhisper, "user_data");
130
+ rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
131
+ }, nullptr);
132
+
133
+ return Qnil;
134
+ }
38
135
 
39
136
  static void ruby_whisper_free(ruby_whisper *rw) {
40
137
  if (rw->context) {
@@ -42,6 +139,7 @@ static void ruby_whisper_free(ruby_whisper *rw) {
42
139
  rw->context = NULL;
43
140
  }
44
141
  }
142
+
45
143
  static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
46
144
  }
47
145
 
@@ -54,10 +152,20 @@ void rb_whisper_free(ruby_whisper *rw) {
54
152
  free(rw);
55
153
  }
56
154
 
155
+ void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
156
+ rb_gc_mark(rwc->user_data);
157
+ rb_gc_mark(rwc->callback);
158
+ rb_gc_mark(rwc->callbacks);
159
+ }
160
+
57
161
  void rb_whisper_params_mark(ruby_whisper_params *rwp) {
162
+ rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
163
+ rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
164
+ rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
58
165
  }
59
166
 
60
167
  void rb_whisper_params_free(ruby_whisper_params *rwp) {
168
+ // How to free user_data and callback only when not referred to by others?
61
169
  ruby_whisper_params_free(rwp);
62
170
  free(rwp);
63
171
  }
@@ -69,13 +177,33 @@ static VALUE ruby_whisper_allocate(VALUE klass) {
69
177
  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
70
178
  }
71
179
 
180
+ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() {
181
+ ruby_whisper_callback_container *container;
182
+ container = ALLOC(ruby_whisper_callback_container);
183
+ container->context = nullptr;
184
+ container->user_data = Qnil;
185
+ container->callback = Qnil;
186
+ container->callbacks = rb_ary_new();
187
+ return container;
188
+ }
189
+
72
190
  static VALUE ruby_whisper_params_allocate(VALUE klass) {
73
191
  ruby_whisper_params *rwp;
74
192
  rwp = ALLOC(ruby_whisper_params);
75
193
  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
194
+ rwp->diarize = false;
195
+ rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
196
+ rwp->progress_callback_container = rb_whisper_callback_container_allocate();
197
+ rwp->abort_callback_container = rb_whisper_callback_container_allocate();
76
198
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
77
199
  }
78
200
 
201
+ /*
202
+ * call-seq:
203
+ * new("base.en") -> Whisper::Context
204
+ * new("path/to/model.bin") -> Whisper::Context
205
+ * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
206
+ */
79
207
  static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
80
208
  ruby_whisper *rw;
81
209
  VALUE whisper_model_file_path;
@@ -84,7 +212,15 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
84
212
  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
85
213
  Data_Get_Struct(self, ruby_whisper, rw);
86
214
 
87
- if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
215
+ VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
216
+ VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
217
+ if (!NIL_P(pre_converted_model)) {
218
+ whisper_model_file_path = pre_converted_model;
219
+ }
220
+ if (rb_respond_to(whisper_model_file_path, id_to_path)) {
221
+ whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
222
+ }
223
+ if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
88
224
  rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
89
225
  }
90
226
  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
@@ -94,10 +230,21 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
94
230
  return self;
95
231
  }
96
232
 
233
+ // High level API
234
+ static VALUE rb_whisper_segment_initialize(VALUE context, int index);
235
+
97
236
  /*
98
237
  * transcribe a single file
99
238
  * can emit results to a block
100
239
  *
240
+ * params = Whisper::Params.new
241
+ * params.duration = 60_000
242
+ * whisper.transcribe "path/to/audio.wav", params do |text|
243
+ * puts text
244
+ * end
245
+ *
246
+ * call-seq:
247
+ * transcribe(path_to_audio, params) {|text| ...}
101
248
  **/
102
249
  static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
103
250
  ruby_whisper *rw;
@@ -108,7 +255,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
108
255
  Data_Get_Struct(self, ruby_whisper, rw);
109
256
  Data_Get_Struct(params, ruby_whisper_params, rwp);
110
257
 
111
- if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
258
+ if (!rb_respond_to(wave_file_path, id_to_s)) {
112
259
  rb_raise(rb_eRuntimeError, "Expected file path to wave file");
113
260
  }
114
261
 
@@ -206,6 +353,81 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
206
353
  rwp->params.encoder_begin_callback_user_data = &is_aborted;
207
354
  }
208
355
 
356
+ if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
357
+ rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
358
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
359
+
360
+ // Currently, doesn't support state because
361
+ // those require to resolve GC-related problems.
362
+ if (!NIL_P(container->callback)) {
363
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
364
+ }
365
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
366
+ if (0 == callbacks_len) {
367
+ return;
368
+ }
369
+ const int n_segments = whisper_full_n_segments_from_state(state);
370
+ for (int i = n_new; i > 0; i--) {
371
+ int i_segment = n_segments - i;
372
+ VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
373
+ for (int j = 0; j < callbacks_len; j++) {
374
+ VALUE cb = rb_ary_entry(container->callbacks, j);
375
+ rb_funcall(cb, id_call, 1, segment);
376
+ }
377
+ }
378
+ };
379
+ rwp->new_segment_callback_container->context = &self;
380
+ rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
381
+ }
382
+
383
+ if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
384
+ rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
385
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
386
+ const VALUE progress = INT2NUM(progress_cur);
387
+ // Currently, doesn't support state because
388
+ // those require to resolve GC-related problems.
389
+ if (!NIL_P(container->callback)) {
390
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
391
+ }
392
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
393
+ if (0 == callbacks_len) {
394
+ return;
395
+ }
396
+ for (int j = 0; j < callbacks_len; j++) {
397
+ VALUE cb = rb_ary_entry(container->callbacks, j);
398
+ rb_funcall(cb, id_call, 1, progress);
399
+ }
400
+ };
401
+ rwp->progress_callback_container->context = &self;
402
+ rwp->params.progress_callback_user_data = rwp->progress_callback_container;
403
+ }
404
+
405
+ if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
406
+ rwp->params.abort_callback = [](void * user_data) {
407
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
408
+ if (!NIL_P(container->callback)) {
409
+ VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
410
+ if (!NIL_P(result) && Qfalse != result) {
411
+ return true;
412
+ }
413
+ }
414
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
415
+ if (0 == callbacks_len) {
416
+ return false;
417
+ }
418
+ for (int j = 0; j < callbacks_len; j++) {
419
+ VALUE cb = rb_ary_entry(container->callbacks, j);
420
+ VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
421
+ if (!NIL_P(result) && Qfalse != result) {
422
+ return true;
423
+ }
424
+ }
425
+ return false;
426
+ };
427
+ rwp->abort_callback_container->context = &self;
428
+ rwp->params.abort_callback_user_data = rwp->abort_callback_container;
429
+ }
430
+
209
431
  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
210
432
  fprintf(stderr, "failed to process audio\n");
211
433
  return self;
@@ -216,15 +438,396 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
216
438
  const char * text = whisper_full_get_segment_text(rw->context, i);
217
439
  output = rb_str_concat(output, rb_str_new2(text));
218
440
  }
219
- VALUE idCall = rb_intern("call");
441
+ VALUE idCall = id_call;
220
442
  if (blk != Qnil) {
221
443
  rb_funcall(blk, idCall, 1, output);
222
444
  }
223
445
  return self;
224
446
  }
225
447
 
448
+ /*
449
+ * call-seq:
450
+ * model_n_vocab -> Integer
451
+ */
452
+ VALUE ruby_whisper_model_n_vocab(VALUE self) {
453
+ ruby_whisper *rw;
454
+ Data_Get_Struct(self, ruby_whisper, rw);
455
+ return INT2NUM(whisper_model_n_vocab(rw->context));
456
+ }
457
+
458
+ /*
459
+ * call-seq:
460
+ * model_n_audio_ctx -> Integer
461
+ */
462
+ VALUE ruby_whisper_model_n_audio_ctx(VALUE self) {
463
+ ruby_whisper *rw;
464
+ Data_Get_Struct(self, ruby_whisper, rw);
465
+ return INT2NUM(whisper_model_n_audio_ctx(rw->context));
466
+ }
467
+
468
+ /*
469
+ * call-seq:
470
+ * model_n_audio_state -> Integer
471
+ */
472
+ VALUE ruby_whisper_model_n_audio_state(VALUE self) {
473
+ ruby_whisper *rw;
474
+ Data_Get_Struct(self, ruby_whisper, rw);
475
+ return INT2NUM(whisper_model_n_audio_state(rw->context));
476
+ }
477
+
478
+ /*
479
+ * call-seq:
480
+ * model_n_audio_head -> Integer
481
+ */
482
+ VALUE ruby_whisper_model_n_audio_head(VALUE self) {
483
+ ruby_whisper *rw;
484
+ Data_Get_Struct(self, ruby_whisper, rw);
485
+ return INT2NUM(whisper_model_n_audio_head(rw->context));
486
+ }
487
+
488
+ /*
489
+ * call-seq:
490
+ * model_n_audio_layer -> Integer
491
+ */
492
+ VALUE ruby_whisper_model_n_audio_layer(VALUE self) {
493
+ ruby_whisper *rw;
494
+ Data_Get_Struct(self, ruby_whisper, rw);
495
+ return INT2NUM(whisper_model_n_audio_layer(rw->context));
496
+ }
497
+
498
+ /*
499
+ * call-seq:
500
+ * model_n_text_ctx -> Integer
501
+ */
502
+ VALUE ruby_whisper_model_n_text_ctx(VALUE self) {
503
+ ruby_whisper *rw;
504
+ Data_Get_Struct(self, ruby_whisper, rw);
505
+ return INT2NUM(whisper_model_n_text_ctx(rw->context));
506
+ }
507
+
508
+ /*
509
+ * call-seq:
510
+ * model_n_text_state -> Integer
511
+ */
512
+ VALUE ruby_whisper_model_n_text_state(VALUE self) {
513
+ ruby_whisper *rw;
514
+ Data_Get_Struct(self, ruby_whisper, rw);
515
+ return INT2NUM(whisper_model_n_text_state(rw->context));
516
+ }
517
+
518
+ /*
519
+ * call-seq:
520
+ * model_n_text_head -> Integer
521
+ */
522
+ VALUE ruby_whisper_model_n_text_head(VALUE self) {
523
+ ruby_whisper *rw;
524
+ Data_Get_Struct(self, ruby_whisper, rw);
525
+ return INT2NUM(whisper_model_n_text_head(rw->context));
526
+ }
527
+
528
+ /*
529
+ * call-seq:
530
+ * model_n_text_layer -> Integer
531
+ */
532
+ VALUE ruby_whisper_model_n_text_layer(VALUE self) {
533
+ ruby_whisper *rw;
534
+ Data_Get_Struct(self, ruby_whisper, rw);
535
+ return INT2NUM(whisper_model_n_text_layer(rw->context));
536
+ }
537
+
538
+ /*
539
+ * call-seq:
540
+ * model_n_mels -> Integer
541
+ */
542
+ VALUE ruby_whisper_model_n_mels(VALUE self) {
543
+ ruby_whisper *rw;
544
+ Data_Get_Struct(self, ruby_whisper, rw);
545
+ return INT2NUM(whisper_model_n_mels(rw->context));
546
+ }
547
+
548
+ /*
549
+ * call-seq:
550
+ * model_ftype -> Integer
551
+ */
552
+ VALUE ruby_whisper_model_ftype(VALUE self) {
553
+ ruby_whisper *rw;
554
+ Data_Get_Struct(self, ruby_whisper, rw);
555
+ return INT2NUM(whisper_model_ftype(rw->context));
556
+ }
557
+
558
+ /*
559
+ * call-seq:
560
+ * model_type -> String
561
+ */
562
+ VALUE ruby_whisper_model_type(VALUE self) {
563
+ ruby_whisper *rw;
564
+ Data_Get_Struct(self, ruby_whisper, rw);
565
+ return rb_str_new2(whisper_model_type_readable(rw->context));
566
+ }
567
+
568
+ /*
569
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
570
+ * Not thread safe for same context
571
+ * Uses the specified decoding strategy to obtain the text.
572
+ *
573
+ * call-seq:
574
+ * full(params, samples, n_samples) -> nil
575
+ * full(params, samples) -> nil
576
+ *
577
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
578
+ */
579
+ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
580
+ if (argc < 2 || argc > 3) {
581
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
582
+ }
583
+
584
+ ruby_whisper *rw;
585
+ ruby_whisper_params *rwp;
586
+ Data_Get_Struct(self, ruby_whisper, rw);
587
+ VALUE params = argv[0];
588
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
589
+ VALUE samples = argv[1];
590
+ int n_samples;
591
+ rb_memory_view_t view;
592
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
593
+ if (argc == 3) {
594
+ n_samples = NUM2INT(argv[2]);
595
+ if (TYPE(samples) == T_ARRAY) {
596
+ if (RARRAY_LEN(samples) < n_samples) {
597
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
598
+ }
599
+ }
600
+ // Should check when samples.respond_to?(:length)?
601
+ } else {
602
+ if (TYPE(samples) == T_ARRAY) {
603
+ n_samples = RARRAY_LEN(samples);
604
+ } else if (memory_view_available_p) {
605
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
606
+ view.obj = Qnil;
607
+ rb_raise(rb_eArgError, "unable to get a memory view");
608
+ }
609
+ n_samples = view.byte_size / view.item_size;
610
+ } else if (rb_respond_to(samples, id_length)) {
611
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
612
+ } else {
613
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
614
+ }
615
+ }
616
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
617
+ if (memory_view_available_p) {
618
+ c_samples = (float *)view.data;
619
+ } else {
620
+ if (TYPE(samples) == T_ARRAY) {
621
+ for (int i = 0; i < n_samples; i++) {
622
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
623
+ }
624
+ } else {
625
+ // TODO: use rb_block_call
626
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
627
+ for (int i = 0; i < n_samples; i++) {
628
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
629
+ VALUE sample = rb_funcall(iter, id_next, 0);
630
+ c_samples[i] = RFLOAT_VALUE(sample);
631
+ }
632
+ }
633
+ }
634
+ const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
635
+ if (0 == result) {
636
+ return Qnil;
637
+ } else {
638
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
639
+ }
640
+ }
641
+
642
+ /*
643
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
644
+ * Result is stored in the default state of the context
645
+ * Not thread safe if executed in parallel on the same context.
646
+ * It seems this approach can offer some speedup in some cases.
647
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
648
+ *
649
+ * call-seq:
650
+ * full_parallel(params, samples) -> nil
651
+ * full_parallel(params, samples, n_samples) -> nil
652
+ * full_parallel(params, samples, n_samples, n_processors) -> nil
653
+ * full_parallel(params, samples, nil, n_processors) -> nil
654
+ */
655
+ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
656
+ if (argc < 2 || argc > 4) {
657
+ rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
658
+ }
659
+
660
+ ruby_whisper *rw;
661
+ ruby_whisper_params *rwp;
662
+ Data_Get_Struct(self, ruby_whisper, rw);
663
+ VALUE params = argv[0];
664
+ Data_Get_Struct(params, ruby_whisper_params, rwp);
665
+ VALUE samples = argv[1];
666
+ int n_samples;
667
+ int n_processors;
668
+ rb_memory_view_t view;
669
+ const bool memory_view_available_p = rb_memory_view_available_p(samples);
670
+ switch (argc) {
671
+ case 2:
672
+ n_processors = 1;
673
+ break;
674
+ case 3:
675
+ n_processors = 1;
676
+ break;
677
+ case 4:
678
+ n_processors = NUM2INT(argv[3]);
679
+ break;
680
+ }
681
+ if (argc >= 3 && !NIL_P(argv[2])) {
682
+ n_samples = NUM2INT(argv[2]);
683
+ if (TYPE(samples) == T_ARRAY) {
684
+ if (RARRAY_LEN(samples) < n_samples) {
685
+ rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
686
+ }
687
+ }
688
+ // Should check when samples.respond_to?(:length)?
689
+ } else if (memory_view_available_p) {
690
+ if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
691
+ view.obj = Qnil;
692
+ rb_raise(rb_eArgError, "unable to get a memory view");
693
+ }
694
+ n_samples = view.byte_size / view.item_size;
695
+ } else {
696
+ if (TYPE(samples) == T_ARRAY) {
697
+ n_samples = RARRAY_LEN(samples);
698
+ } else if (rb_respond_to(samples, id_length)) {
699
+ n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
700
+ } else {
701
+ rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
702
+ }
703
+ }
704
+ float * c_samples = (float *)malloc(n_samples * sizeof(float));
705
+ if (memory_view_available_p) {
706
+ c_samples = (float *)view.data;
707
+ } else {
708
+ if (TYPE(samples) == T_ARRAY) {
709
+ for (int i = 0; i < n_samples; i++) {
710
+ c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
711
+ }
712
+ } else {
713
+ // FIXME: use rb_block_call
714
+ VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
715
+ for (int i = 0; i < n_samples; i++) {
716
+ // TODO: check if iter is exhausted and raise ArgumentError
717
+ VALUE sample = rb_funcall(iter, id_next, 0);
718
+ c_samples[i] = RFLOAT_VALUE(sample);
719
+ }
720
+ }
721
+ }
722
+ const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
723
+ if (0 == result) {
724
+ return Qnil;
725
+ } else {
726
+ rb_exc_raise(rb_funcall(eError, id_new, 1, result));
727
+ }
728
+ }
729
+
730
+ /*
731
+ * Number of segments.
732
+ *
733
+ * call-seq:
734
+ * full_n_segments -> Integer
735
+ */
736
+ static VALUE ruby_whisper_full_n_segments(VALUE self) {
737
+ ruby_whisper *rw;
738
+ Data_Get_Struct(self, ruby_whisper, rw);
739
+ return INT2NUM(whisper_full_n_segments(rw->context));
740
+ }
741
+
742
+ /*
743
+ * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
744
+ *
745
+ * call-seq:
746
+ * full_lang_id -> Integer
747
+ */
748
+ static VALUE ruby_whisper_full_lang_id(VALUE self) {
749
+ ruby_whisper *rw;
750
+ Data_Get_Struct(self, ruby_whisper, rw);
751
+ return INT2NUM(whisper_full_lang_id(rw->context));
752
+ }
753
+
754
+ static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
755
+ const int c_i_segment = NUM2INT(i_segment);
756
+ if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
757
+ rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
758
+ }
759
+ return c_i_segment;
760
+ }
761
+
762
+ /*
763
+ * Start time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
764
+ *
765
+ * full_get_segment_t0(3) # => 1668 (16680 ms)
766
+ *
767
+ * call-seq:
768
+ * full_get_segment_t0(segment_index) -> Integer
769
+ */
770
+ static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
771
+ ruby_whisper *rw;
772
+ Data_Get_Struct(self, ruby_whisper, rw);
773
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
774
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
775
+ return INT2NUM(t0);
776
+ }
777
+
778
+ /*
779
+ * End time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
780
+ *
781
+ * full_get_segment_t1(3) # => 1668 (16680 ms)
782
+ *
783
+ * call-seq:
784
+ * full_get_segment_t1(segment_index) -> Integer
785
+ */
786
+ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
787
+ ruby_whisper *rw;
788
+ Data_Get_Struct(self, ruby_whisper, rw);
789
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
790
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
791
+ return INT2NUM(t1);
792
+ }
793
+
794
+ /*
795
+ * Whether the next segment indexed by +segment_index+ is predicted as a speaker turn.
796
+ *
797
+ * full_get_segment_speaker_turn_next(3) # => true
798
+ *
799
+ * call-seq:
800
+ * full_get_segment_speaker_turn_next(segment_index) -> bool
801
+ */
802
+ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
803
+ ruby_whisper *rw;
804
+ Data_Get_Struct(self, ruby_whisper, rw);
805
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
806
+ const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
807
+ return speaker_turn_next ? Qtrue : Qfalse;
808
+ }
809
+
810
+ /*
811
+ * Text of a segment indexed by +segment_index+.
812
+ *
813
+ * full_get_segment_text(3) # => "ask not what your country can do for you, ..."
814
+ *
815
+ * call-seq:
816
+ * full_get_segment_text(segment_index) -> String
817
+ */
818
+ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
819
+ ruby_whisper *rw;
820
+ Data_Get_Struct(self, ruby_whisper, rw);
821
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
822
+ const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
823
+ return rb_str_new2(text);
824
+ }
825
+
226
826
  /*
227
827
  * params.language = "auto" | "en", etc...
828
+ *
829
+ * call-seq:
830
+ * language = lang_name -> lang_name
228
831
  */
229
832
  static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
230
833
  ruby_whisper_params *rwp;
@@ -236,6 +839,10 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
236
839
  }
237
840
  return value;
238
841
  }
842
+ /*
843
+ * call-seq:
844
+ * language -> String
845
+ */
239
846
  static VALUE ruby_whisper_params_get_language(VALUE self) {
240
847
  ruby_whisper_params *rwp;
241
848
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -245,78 +852,209 @@ static VALUE ruby_whisper_params_get_language(VALUE self) {
245
852
  return rb_str_new2("auto");
246
853
  }
247
854
  }
855
+ /*
856
+ * call-seq:
857
+ * translate = do_translate -> do_translate
858
+ */
248
859
  static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
249
860
  BOOL_PARAMS_SETTER(self, translate, value)
250
861
  }
862
+ /*
863
+ * call-seq:
864
+ * translate -> bool
865
+ */
251
866
  static VALUE ruby_whisper_params_get_translate(VALUE self) {
252
867
  BOOL_PARAMS_GETTER(self, translate)
253
868
  }
869
+ /*
870
+ * call-seq:
871
+ * no_context = dont_use_context -> dont_use_context
872
+ */
254
873
  static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
255
874
  BOOL_PARAMS_SETTER(self, no_context, value)
256
875
  }
876
+ /*
877
+ * If true, does not use past transcription (if any) as initial prompt for the decoder.
878
+ *
879
+ * call-seq:
880
+ * no_context -> bool
881
+ */
257
882
  static VALUE ruby_whisper_params_get_no_context(VALUE self) {
258
883
  BOOL_PARAMS_GETTER(self, no_context)
259
884
  }
885
+ /*
886
+ * call-seq:
887
+ * single_segment = force_single -> force_single
888
+ */
260
889
  static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
261
890
  BOOL_PARAMS_SETTER(self, single_segment, value)
262
891
  }
892
+ /*
893
+ * If true, forces single segment output (useful for streaming).
894
+ *
895
+ * call-seq:
896
+ * single_segment -> bool
897
+ */
263
898
  static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
264
899
  BOOL_PARAMS_GETTER(self, single_segment)
265
900
  }
901
+ /*
902
+ * call-seq:
903
+ * print_special = force_print -> force_print
904
+ */
266
905
  static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
267
906
  BOOL_PARAMS_SETTER(self, print_special, value)
268
907
  }
908
+ /*
909
+ * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
910
+ *
911
+ * call-seq:
912
+ * print_special -> bool
913
+ */
269
914
  static VALUE ruby_whisper_params_get_print_special(VALUE self) {
270
915
  BOOL_PARAMS_GETTER(self, print_special)
271
916
  }
917
+ /*
918
+ * call-seq:
919
+ * print_progress = force_print -> force_print
920
+ */
272
921
  static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
273
922
  BOOL_PARAMS_SETTER(self, print_progress, value)
274
923
  }
924
+ /*
925
+ * If true, prints progress information.
926
+ *
927
+ * call-seq:
928
+ * print_progress -> bool
929
+ */
275
930
  static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
276
931
  BOOL_PARAMS_GETTER(self, print_progress)
277
932
  }
933
+ /*
934
+ * call-seq:
935
+ * print_realtime = force_print -> force_print
936
+ */
278
937
  static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
279
938
  BOOL_PARAMS_SETTER(self, print_realtime, value)
280
939
  }
940
+ /*
941
+ * If true, prints results from within whisper.cpp. (avoid it, use callback instead)
942
+ * call-seq:
943
+ * print_realtime -> bool
944
+ */
281
945
  static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
282
946
  BOOL_PARAMS_GETTER(self, print_realtime)
283
947
  }
948
+ /*
949
+ * call-seq:
950
+ * print_timestamps = force_print -> force_print
951
+ */
284
952
  static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
285
953
  BOOL_PARAMS_SETTER(self, print_timestamps, value)
286
954
  }
955
+ /*
956
+ * If true, prints timestamps for each text segment when printing realtime.
957
+ *
958
+ * call-seq:
959
+ * print_timestamps -> bool
960
+ */
287
961
  static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
288
962
  BOOL_PARAMS_GETTER(self, print_timestamps)
289
963
  }
964
+ /*
965
+ * call-seq:
966
+ * suppress_blank = force_suppress -> force_suppress
967
+ */
290
968
  static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
291
969
  BOOL_PARAMS_SETTER(self, suppress_blank, value)
292
970
  }
971
+ /*
972
+ * If true, suppresses blank outputs.
973
+ *
974
+ * call-seq:
975
+ * suppress_blank -> bool
976
+ */
293
977
  static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
294
978
  BOOL_PARAMS_GETTER(self, suppress_blank)
295
979
  }
980
+ /*
981
+ * call-seq:
982
+ * suppress_non_speech_tokens = force_suppress -> force_suppress
983
+ */
296
984
  static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
297
985
  BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
298
986
  }
987
+ /*
988
+ * If true, suppresses non-speech-tokens.
989
+ *
990
+ * call-seq:
991
+ * suppress_non_speech_tokens -> bool
992
+ */
299
993
  static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
300
994
  BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
301
995
  }
996
+ /*
997
+ * If true, enables token-level timestamps.
998
+ *
999
+ * call-seq:
1000
+ * token_timestamps -> bool
1001
+ */
302
1002
  static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
303
1003
  BOOL_PARAMS_GETTER(self, token_timestamps)
304
1004
  }
1005
+ /*
1006
+ * call-seq:
1007
+ * token_timestamps = force_timestamps -> force_timestamps
1008
+ */
305
1009
  static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
306
1010
  BOOL_PARAMS_SETTER(self, token_timestamps, value)
307
1011
  }
1012
+ /*
1013
+ * If true, split on word rather than on token (when used with max_len).
1014
+ *
1015
+ * call-seq:
1016
+ * split_on_word -> bool
1017
+ */
308
1018
  static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
309
1019
  BOOL_PARAMS_GETTER(self, split_on_word)
310
1020
  }
1021
+ /*
1022
+ * call-seq:
1023
+ * split_on_word = force_split -> force_split
1024
+ */
311
1025
  static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
312
1026
  BOOL_PARAMS_SETTER(self, split_on_word, value)
313
1027
  }
314
- static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
315
- BOOL_PARAMS_GETTER(self, speed_up)
1028
+ /*
1029
+ * Tokens to provide to the whisper decoder as initial prompt
1030
+ * these are prepended to any existing text context from a previous call
1031
+ * use whisper_tokenize() to convert text to tokens.
1032
+ * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
1033
+ *
1034
+ * call-seq:
1035
+ * initial_prompt -> String
1036
+ */
1037
+ static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) {
1038
+ ruby_whisper_params *rwp;
1039
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1040
+ return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt);
316
1041
  }
317
- static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
318
- BOOL_PARAMS_SETTER(self, speed_up, value)
1042
+ /*
1043
+ * call-seq:
1044
+ * initial_prompt = prompt -> prompt
1045
+ */
1046
+ static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) {
1047
+ ruby_whisper_params *rwp;
1048
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1049
+ rwp->params.initial_prompt = StringValueCStr(value);
1050
+ return value;
319
1051
  }
1052
+ /*
1053
+ * If true, enables diarization.
1054
+ *
1055
+ * call-seq:
1056
+ * diarize -> bool
1057
+ */
320
1058
  static VALUE ruby_whisper_params_get_diarize(VALUE self) {
321
1059
  ruby_whisper_params *rwp;
322
1060
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -326,6 +1064,10 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) {
326
1064
  return Qfalse;
327
1065
  }
328
1066
  }
1067
+ /*
1068
+ * call-seq:
1069
+ * diarize = force_diarize -> force_diarize
1070
+ */
329
1071
  static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
330
1072
  ruby_whisper_params *rwp;
331
1073
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -337,22 +1079,42 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
337
1079
  return value;
338
1080
  }
339
1081
 
1082
+ /*
1083
+ * Start offset in ms.
1084
+ *
1085
+ * call-seq:
1086
+ * offset -> Integer
1087
+ */
340
1088
  static VALUE ruby_whisper_params_get_offset(VALUE self) {
341
1089
  ruby_whisper_params *rwp;
342
1090
  Data_Get_Struct(self, ruby_whisper_params, rwp);
343
1091
  return INT2NUM(rwp->params.offset_ms);
344
1092
  }
1093
+ /*
1094
+ * call-seq:
1095
+ * offset = offset_ms -> offset_ms
1096
+ */
345
1097
  static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
346
1098
  ruby_whisper_params *rwp;
347
1099
  Data_Get_Struct(self, ruby_whisper_params, rwp);
348
1100
  rwp->params.offset_ms = NUM2INT(value);
349
1101
  return value;
350
1102
  }
1103
+ /*
1104
+ * Audio duration to process in ms.
1105
+ *
1106
+ * call-seq:
1107
+ * duration -> Integer
1108
+ */
351
1109
  static VALUE ruby_whisper_params_get_duration(VALUE self) {
352
1110
  ruby_whisper_params *rwp;
353
1111
  Data_Get_Struct(self, ruby_whisper_params, rwp);
354
1112
  return INT2NUM(rwp->params.duration_ms);
355
1113
  }
1114
+ /*
1115
+ * call-seq:
1116
+ * duration = duration_ms -> duration_ms
1117
+ */
356
1118
  static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
357
1119
  ruby_whisper_params *rwp;
358
1120
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -360,27 +1122,695 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
360
1122
  return value;
361
1123
  }
362
1124
 
1125
+ /*
1126
+ * Max tokens to use from past text as prompt for the decoder.
1127
+ *
1128
+ * call-seq:
1129
+ * max_text_tokens -> Integer
1130
+ */
363
1131
  static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
364
1132
  ruby_whisper_params *rwp;
365
1133
  Data_Get_Struct(self, ruby_whisper_params, rwp);
366
1134
  return INT2NUM(rwp->params.n_max_text_ctx);
367
1135
  }
1136
+ /*
1137
+ * call-seq:
1138
+ * max_text_tokens = n_tokens -> n_tokens
1139
+ */
368
1140
  static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
369
1141
  ruby_whisper_params *rwp;
370
1142
  Data_Get_Struct(self, ruby_whisper_params, rwp);
371
1143
  rwp->params.n_max_text_ctx = NUM2INT(value);
372
1144
  return value;
373
1145
  }
1146
+ /*
1147
+ * call-seq:
1148
+ * temperature -> Float
1149
+ */
1150
+ static VALUE ruby_whisper_params_get_temperature(VALUE self) {
1151
+ ruby_whisper_params *rwp;
1152
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1153
+ return DBL2NUM(rwp->params.temperature);
1154
+ }
1155
+ /*
1156
+ * call-seq:
1157
+ * temperature = temp -> temp
1158
+ */
1159
+ static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) {
1160
+ ruby_whisper_params *rwp;
1161
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1162
+ rwp->params.temperature = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1163
+ return value;
1164
+ }
1165
+ /*
1166
+ * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
1167
+ *
1168
+ * call-seq:
1169
+ * max_initial_ts -> Float
1170
+ */
1171
+ static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) {
1172
+ ruby_whisper_params *rwp;
1173
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1174
+ return DBL2NUM(rwp->params.max_initial_ts);
1175
+ }
1176
+ /*
1177
+ * call-seq:
1178
+ * max_initial_ts = timestamp -> timestamp
1179
+ */
1180
+ static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) {
1181
+ ruby_whisper_params *rwp;
1182
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1183
+ rwp->params.max_initial_ts = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1184
+ return value;
1185
+ }
1186
+ /*
1187
+ * call-seq:
1188
+ * length_penalty -> Float
1189
+ */
1190
+ static VALUE ruby_whisper_params_get_length_penalty(VALUE self) {
1191
+ ruby_whisper_params *rwp;
1192
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1193
+ return DBL2NUM(rwp->params.length_penalty);
1194
+ }
1195
+ /*
1196
+ * call-seq:
1197
+ * length_penalty = penalty -> penalty
1198
+ */
1199
+ static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) {
1200
+ ruby_whisper_params *rwp;
1201
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1202
+ rwp->params.length_penalty = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1203
+ return value;
1204
+ }
1205
+ /*
1206
+ * call-seq:
1207
+ * temperature_inc -> Float
1208
+ */
1209
+ static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) {
1210
+ ruby_whisper_params *rwp;
1211
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1212
+ return DBL2NUM(rwp->params.temperature_inc);
1213
+ }
1214
+ /*
1215
+ * call-seq:
1216
+ * temperature_inc = inc -> inc
1217
+ */
1218
+ static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) {
1219
+ ruby_whisper_params *rwp;
1220
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1221
+ rwp->params.temperature_inc = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1222
+ return value;
1223
+ }
1224
+ /*
1225
+ * Similar to OpenAI's "compression_ratio_threshold"
1226
+ *
1227
+ * call-seq:
1228
+ * entropy_thold -> Float
1229
+ */
1230
+ static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) {
1231
+ ruby_whisper_params *rwp;
1232
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1233
+ return DBL2NUM(rwp->params.entropy_thold);
1234
+ }
1235
+ /*
1236
+ * call-seq:
1237
+ * entropy_thold = threshold -> threshold
1238
+ */
1239
+ static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) {
1240
+ ruby_whisper_params *rwp;
1241
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1242
+ rwp->params.entropy_thold = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1243
+ return value;
1244
+ }
1245
+ /*
1246
+ * call-seq:
1247
+ * logprob_thold -> Float
1248
+ */
1249
+ static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) {
1250
+ ruby_whisper_params *rwp;
1251
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1252
+ return DBL2NUM(rwp->params.logprob_thold);
1253
+ }
1254
+ /*
1255
+ * call-seq:
1256
+ * logprob_thold = threshold -> threshold
1257
+ */
1258
+ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
1259
+ ruby_whisper_params *rwp;
1260
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1261
+ rwp->params.logprob_thold = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1262
+ return value;
1263
+ }
1264
+ /*
1265
+ * call-seq:
1266
+ * no_speech_thold -> Float
1267
+ */
1268
+ static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
1269
+ ruby_whisper_params *rwp;
1270
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1271
+ return DBL2NUM(rwp->params.no_speech_thold);
1272
+ }
1273
+ /*
1274
+ * call-seq:
1275
+ * no_speech_thold = threshold -> threshold
1276
+ */
1277
+ static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
1278
+ ruby_whisper_params *rwp;
1279
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1280
+ rwp->params.no_speech_thold = NUM2DBL(value); // NUM2DBL converts any Numeric and raises TypeError otherwise; RFLOAT_VALUE is only valid for Float objects
1281
+ return value;
1282
+ }
1283
+ /*
1284
+ * Sets new segment callback, called for every newly generated text segment.
1285
+ *
1286
+ * params.new_segment_callback = ->(context, _, n_new, user_data) {
1287
+ * # ...
1288
+ * }
1289
+ *
1290
+ * call-seq:
1291
+ * new_segment_callback = callback -> callback
1292
+ */
1293
+ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
1294
+ ruby_whisper_params *rwp;
1295
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1296
+ rwp->new_segment_callback_container->callback = value;
1297
+ return value;
1298
+ }
1299
+ /*
1300
+ * Sets user data passed to the last argument of new segment callback.
1301
+ *
1302
+ * call-seq:
1303
+ * new_segment_callback_user_data = user_data -> user_data
1304
+ */
1305
+ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
1306
+ ruby_whisper_params *rwp;
1307
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1308
+ rwp->new_segment_callback_container->user_data = value;
1309
+ return value;
1310
+ }
1311
+ /*
1312
+ * Sets progress callback, called on each progress update.
1313
+ *
1314
+ * params.progress_callback = ->(context, _, progress, user_data) {
1315
+ * # ...
1316
+ * }
1317
+ *
1318
+ * call-seq:
1319
+ * progress_callback = callback -> callback
1320
+ */
1321
+ static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
1322
+ ruby_whisper_params *rwp;
1323
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1324
+ rwp->progress_callback_container->callback = value;
1325
+ return value;
1326
+ }
1327
+ /*
1328
+ * Sets user data passed to the last argument of progress callback.
1329
+ *
1330
+ * call-seq:
1331
+ * progress_callback_user_data = user_data -> user_data
1332
+ */
1333
+ static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
1334
+ ruby_whisper_params *rwp;
1335
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1336
+ rwp->progress_callback_container->user_data = value;
1337
+ return value;
1338
+ }
1339
+ /*
1340
+ * Sets abort callback, called to check if the process should be aborted.
1341
+ *
1342
+ * params.abort_callback = ->(user_data) {
1343
+ * # ...
1344
+ * }
1345
+ *
1346
+ * call-seq:
1347
+ * abort_callback = callback -> callback
1348
+ */
1349
+ static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
1350
+ ruby_whisper_params *rwp;
1351
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1352
+ rwp->abort_callback_container->callback = value;
1353
+ return value;
1354
+ }
1355
+ /*
1356
+ * Sets user data passed to the last argument of abort callback.
1357
+ *
1358
+ * call-seq:
1359
+ * abort_callback_user_data = user_data -> user_data
1360
+ */
1361
+ static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
1362
+ ruby_whisper_params *rwp;
1363
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1364
+ rwp->abort_callback_container->user_data = value;
1365
+ return value;
1366
+ }
1367
+
1368
+ // High level API
1369
+
1370
+ typedef struct {
1371
+ VALUE context;
1372
+ int index;
1373
+ } ruby_whisper_segment;
1374
+
1375
+ typedef struct {
1376
+ VALUE context;
1377
+ } ruby_whisper_model;
1378
+
1379
+ static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
1380
+ rb_gc_mark(rws->context);
1381
+ }
1382
+
1383
+ static VALUE ruby_whisper_segment_allocate(VALUE klass) {
1384
+ ruby_whisper_segment *rws;
1385
+ rws = ZALLOC(ruby_whisper_segment); // zero-fill: rb_whisper_segment_mark passes rws->context to rb_gc_mark, so it must never hold garbage
1386
+ return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
1387
+ }
1388
+
1389
+ static VALUE rb_whisper_segment_initialize(VALUE context, int index) {
1390
+ ruby_whisper_segment *rws;
1391
+ const VALUE segment = ruby_whisper_segment_allocate(cSegment);
1392
+ Data_Get_Struct(segment, ruby_whisper_segment, rws);
1393
+ rws->context = context;
1394
+ rws->index = index;
1395
+ return segment;
1396
+ };
1397
+
1398
+ /*
1399
+ * Yields each Whisper::Segment:
1400
+ *
1401
+ * whisper.transcribe("path/to/audio.wav", params)
1402
+ * whisper.each_segment do |segment|
1403
+ * puts segment.text
1404
+ * end
1405
+ *
1406
+ * Returns an Enumerator if no block given:
1407
+ *
1408
+ * whisper.transcribe("path/to/audio.wav", params)
1409
+ * enum = whisper.each_segment
1410
+ * enum.to_a # => [#<Whisper::Segment>, ...]
1411
+ *
1412
+ * call-seq:
1413
+ * each_segment {|segment| ... }
1414
+ * each_segment -> Enumerator
1415
+ */
1416
+ static VALUE ruby_whisper_each_segment(VALUE self) {
1417
+ if (!rb_block_given_p()) {
1418
+ const VALUE method_name = rb_funcall(self, id___method__, 0);
1419
+ return rb_funcall(self, id_to_enum, 1, method_name);
1420
+ }
1421
+
1422
+ ruby_whisper *rw;
1423
+ Data_Get_Struct(self, ruby_whisper, rw);
1424
+
1425
+ const int n_segments = whisper_full_n_segments(rw->context);
1426
+ for (int i = 0; i < n_segments; ++i) {
1427
+ rb_yield(rb_whisper_segment_initialize(self, i));
1428
+ }
1429
+
1430
+ return self;
1431
+ }
1432
+
1433
+ /*
1434
+ * Hook called on new segment. Yields each Whisper::Segment.
1435
+ *
1436
+ * params.on_new_segment do |segment|
1437
+ * # ...
1438
+ * end
1439
+ *
1440
+ * call-seq:
1441
+ * on_new_segment {|segment| ... }
1442
+ */
1443
+ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
1444
+ ruby_whisper_params *rws;
1445
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1446
+ const VALUE blk = rb_block_proc();
1447
+ rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
1448
+ return Qnil;
1449
+ }
1450
+
1451
+ /*
1452
+ * Hook called on progress update. Yields each progress Integer between 0 and 100.
1453
+ *
1454
+ * params.on_progress do |progress|
1455
+ * # ...
1456
+ * end
1457
+ *
1458
+ * call-seq:
1459
+ * on_progress {|progress| ... }
1460
+ */
1461
+ static VALUE ruby_whisper_params_on_progress(VALUE self) {
1462
+ ruby_whisper_params *rws;
1463
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1464
+ const VALUE blk = rb_block_proc();
1465
+ rb_ary_push(rws->progress_callback_container->callbacks, blk);
1466
+ return Qnil;
1467
+ }
1468
+
1469
+ /*
1470
+ * Calls the block to determine whether to abort. Return +true+ from the block when you want to abort.
1471
+ *
1472
+ * params.abort_on do
1473
+ * if some_condition
1474
+ * true # abort
1475
+ * else
1476
+ * false # continue
1477
+ * end
1478
+ * end
1479
+ *
1480
+ * call-seq:
1481
+ * abort_on { ... }
1482
+ */
1483
+ static VALUE ruby_whisper_params_abort_on(VALUE self) {
1484
+ ruby_whisper_params *rws;
1485
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1486
+ const VALUE blk = rb_block_proc();
1487
+ rb_ary_push(rws->abort_callback_container->callbacks, blk);
1488
+ return Qnil;
1489
+ }
1490
+
1491
+ /*
1492
+ * Start time in milliseconds.
1493
+ *
1494
+ * call-seq:
1495
+ * start_time -> Integer
1496
+ */
1497
+ static VALUE ruby_whisper_segment_get_start_time(VALUE self) {
1498
+ ruby_whisper_segment *rws;
1499
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
1500
+ ruby_whisper *rw;
1501
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
1502
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
1503
+ // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
1504
+ return LL2NUM(t0 * 10); // the product is int64_t; INT2NUM would narrow it to int before boxing
1505
+ }
1506
+
1507
+ /*
1508
+ * End time in milliseconds.
1509
+ *
1510
+ * call-seq:
1511
+ * end_time -> Integer
1512
+ */
1513
+ static VALUE ruby_whisper_segment_get_end_time(VALUE self) {
1514
+ ruby_whisper_segment *rws;
1515
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
1516
+ ruby_whisper *rw;
1517
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
1518
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
1519
+ // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
1520
+ return LL2NUM(t1 * 10); // the product is int64_t; INT2NUM would narrow it to int before boxing
1521
+ }
1522
+
1523
+ /*
1524
+ * Whether the next segment is predicted as a speaker turn.
1525
+ *
1526
+ * call-seq:
1527
+ * speaker_turn_next? -> bool
1528
+ */
1529
+ static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) {
1530
+ ruby_whisper_segment *rws;
1531
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
1532
+ ruby_whisper *rw;
1533
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
1534
+ return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
1535
+ }
1536
+
1537
+ /*
1538
+ * call-seq:
1539
+ * text -> String
1540
+ */
1541
+ static VALUE ruby_whisper_segment_get_text(VALUE self) {
1542
+ ruby_whisper_segment *rws;
1543
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
1544
+ ruby_whisper *rw;
1545
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
1546
+ const char * text = whisper_full_get_segment_text(rw->context, rws->index);
1547
+ return rb_str_new2(text);
1548
+ }
1549
+
1550
+ static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
1551
+ rb_gc_mark(rwm->context);
1552
+ }
1553
+
1554
+ static VALUE ruby_whisper_model_allocate(VALUE klass) {
1555
+ ruby_whisper_model *rwm;
1556
+ rwm = ZALLOC(ruby_whisper_model); // zero-fill: rb_whisper_model_mark passes rwm->context to rb_gc_mark, so it must never hold garbage
1557
+ return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
1558
+ }
1559
+
1560
+ static VALUE rb_whisper_model_initialize(VALUE context) {
1561
+ ruby_whisper_model *rwm;
1562
+ const VALUE model = ruby_whisper_model_allocate(cModel);
1563
+ Data_Get_Struct(model, ruby_whisper_model, rwm);
1564
+ rwm->context = context;
1565
+ return model;
1566
+ };
1567
+
1568
+ /*
1569
+ * call-seq:
1570
+ * model -> Whisper::Model
1571
+ */
1572
+ static VALUE ruby_whisper_get_model(VALUE self) {
1573
+ return rb_whisper_model_initialize(self);
1574
+ }
1575
+
1576
+ /*
1577
+ * call-seq:
1578
+ * n_vocab -> Integer
1579
+ */
1580
+ static VALUE ruby_whisper_c_model_n_vocab(VALUE self) {
1581
+ ruby_whisper_model *rwm;
1582
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1583
+ ruby_whisper *rw;
1584
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1585
+ return INT2NUM(whisper_model_n_vocab(rw->context));
1586
+ }
1587
+
1588
+ /*
1589
+ * call-seq:
1590
+ * n_audio_ctx -> Integer
1591
+ */
1592
+ static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) {
1593
+ ruby_whisper_model *rwm;
1594
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1595
+ ruby_whisper *rw;
1596
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1597
+ return INT2NUM(whisper_model_n_audio_ctx(rw->context));
1598
+ }
1599
+
1600
+ /*
1601
+ * call-seq:
1602
+ * n_audio_state -> Integer
1603
+ */
1604
+ static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) {
1605
+ ruby_whisper_model *rwm;
1606
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1607
+ ruby_whisper *rw;
1608
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1609
+ return INT2NUM(whisper_model_n_audio_state(rw->context));
1610
+ }
1611
+
1612
+ /*
1613
+ * call-seq:
1614
+ * n_audio_head -> Integer
1615
+ */
1616
+ static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) {
1617
+ ruby_whisper_model *rwm;
1618
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1619
+ ruby_whisper *rw;
1620
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1621
+ return INT2NUM(whisper_model_n_audio_head(rw->context));
1622
+ }
1623
+
1624
+ /*
1625
+ * call-seq:
1626
+ * n_audio_layer -> Integer
1627
+ */
1628
+ static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) {
1629
+ ruby_whisper_model *rwm;
1630
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1631
+ ruby_whisper *rw;
1632
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1633
+ return INT2NUM(whisper_model_n_audio_layer(rw->context));
1634
+ }
1635
+
1636
+ /*
1637
+ * call-seq:
1638
+ * n_text_ctx -> Integer
1639
+ */
1640
+ static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) {
1641
+ ruby_whisper_model *rwm;
1642
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1643
+ ruby_whisper *rw;
1644
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1645
+ return INT2NUM(whisper_model_n_text_ctx(rw->context));
1646
+ }
1647
+
1648
+ /*
1649
+ * call-seq:
1650
+ * n_text_state -> Integer
1651
+ */
1652
+ static VALUE ruby_whisper_c_model_n_text_state(VALUE self) {
1653
+ ruby_whisper_model *rwm;
1654
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1655
+ ruby_whisper *rw;
1656
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1657
+ return INT2NUM(whisper_model_n_text_state(rw->context));
1658
+ }
1659
+
1660
+ /*
1661
+ * call-seq:
1662
+ * n_text_head -> Integer
1663
+ */
1664
+ static VALUE ruby_whisper_c_model_n_text_head(VALUE self) {
1665
+ ruby_whisper_model *rwm;
1666
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1667
+ ruby_whisper *rw;
1668
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1669
+ return INT2NUM(whisper_model_n_text_head(rw->context));
1670
+ }
1671
+
1672
+ /*
1673
+ * call-seq:
1674
+ * n_text_layer -> Integer
1675
+ */
1676
+ static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) {
1677
+ ruby_whisper_model *rwm;
1678
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1679
+ ruby_whisper *rw;
1680
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1681
+ return INT2NUM(whisper_model_n_text_layer(rw->context));
1682
+ }
1683
+
1684
+ /*
1685
+ * call-seq:
1686
+ * n_mels -> Integer
1687
+ */
1688
+ static VALUE ruby_whisper_c_model_n_mels(VALUE self) {
1689
+ ruby_whisper_model *rwm;
1690
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1691
+ ruby_whisper *rw;
1692
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1693
+ return INT2NUM(whisper_model_n_mels(rw->context));
1694
+ }
1695
+
1696
+ /*
1697
+ * call-seq:
1698
+ * ftype -> Integer
1699
+ */
1700
+ static VALUE ruby_whisper_c_model_ftype(VALUE self) {
1701
+ ruby_whisper_model *rwm;
1702
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1703
+ ruby_whisper *rw;
1704
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1705
+ return INT2NUM(whisper_model_ftype(rw->context));
1706
+ }
1707
+
1708
+ /*
1709
+ * call-seq:
1710
+ * type -> String
1711
+ */
1712
+ static VALUE ruby_whisper_c_model_type(VALUE self) {
1713
+ ruby_whisper_model *rwm;
1714
+ Data_Get_Struct(self, ruby_whisper_model, rwm);
1715
+ ruby_whisper *rw;
1716
+ Data_Get_Struct(rwm->context, ruby_whisper, rw);
1717
+ return rb_str_new2(whisper_model_type_readable(rw->context));
1718
+ }
1719
+
1720
+ static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
1721
+ const int c_code = NUM2INT(code);
1722
+ const char *raw_message; // only ever points at string literals, which must not be bound to a non-const char *
1723
+ switch (c_code) { // map the integer error code to a human-readable message
1724
+ case -2:
1725
+ raw_message = "failed to compute log mel spectrogram";
1726
+ break;
1727
+ case -3:
1728
+ raw_message = "failed to auto-detect language";
1729
+ break;
1730
+ case -4:
1731
+ raw_message = "too many decoders requested";
1732
+ break;
1733
+ case -5:
1734
+ raw_message = "audio_ctx is larger than the maximum allowed";
1735
+ break;
1736
+ case -6:
1737
+ raw_message = "failed to encode";
1738
+ break;
1739
+ case -7:
1740
+ raw_message = "whisper_kv_cache_init() failed for self-attention cache";
1741
+ break;
1742
+ case -8:
1743
+ raw_message = "failed to decode";
1744
+ break;
1745
+ case -9:
1746
+ raw_message = "failed to decode";
1747
+ break;
1748
+ default:
1749
+ raw_message = "unknown error";
1750
+ break;
1751
+ }
1752
+ const VALUE message = rb_str_new2(raw_message);
1753
+ rb_call_super(1, &message); // forward the message to the superclass initializer
1754
+ rb_iv_set(self, "@code", code); // keep the original code object on the instance
1755
+
1756
+ return self;
1757
+ }
1758
+
374
1759
 
375
1760
  void Init_whisper() {
1761
+ id_to_s = rb_intern("to_s");
1762
+ id_call = rb_intern("call");
1763
+ id___method__ = rb_intern("__method__");
1764
+ id_to_enum = rb_intern("to_enum");
1765
+ id_length = rb_intern("length");
1766
+ id_next = rb_intern("next");
1767
+ id_new = rb_intern("new");
1768
+ id_to_path = rb_intern("to_path");
1769
+ id_pre_converted_models = rb_intern("pre_converted_models");
1770
+
376
1771
  mWhisper = rb_define_module("Whisper");
377
1772
  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
378
1773
  cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
1774
+ eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
1775
+
1776
+ rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
1777
+ rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
1778
+ rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
1779
+ rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
1780
+ rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
1781
+ rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
1782
+
1783
+ rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
1784
+ rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
1785
+ rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
1786
+ rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
1787
+ rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
1788
+ rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
379
1789
 
380
1790
  rb_define_alloc_func(cContext, ruby_whisper_allocate);
381
1791
  rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
382
1792
 
383
1793
  rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
1794
+ rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
1795
+ rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
1796
+ rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
1797
+ rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
1798
+ rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
1799
+ rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
1800
+ rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
1801
+ rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
1802
+ rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
1803
+ rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
1804
+ rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
1805
+ rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
1806
+ rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
1807
+ rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
1808
+ rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
1809
+ rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
1810
+ rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
1811
+ rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
1812
+ rb_define_method(cContext, "full", ruby_whisper_full, -1);
1813
+ rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
384
1814
 
385
1815
  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
386
1816
 
@@ -408,8 +1838,8 @@ void Init_whisper() {
408
1838
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
409
1839
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
410
1840
  rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
411
- rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
412
- rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
1841
+ rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0);
1842
+ rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1);
413
1843
  rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
414
1844
  rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
415
1845
 
@@ -420,6 +1850,59 @@ void Init_whisper() {
420
1850
 
421
1851
  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
422
1852
  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
1853
+ rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0);
1854
+ rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1);
1855
+ rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0);
1856
+ rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1);
1857
+ rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0);
1858
+ rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1);
1859
+ rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0);
1860
+ rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1);
1861
+ rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0);
1862
+ rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
1863
+ rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
1864
+ rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
1865
+ rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
1866
+ rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
1867
+
1868
+ rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
1869
+ rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
1870
+ rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
1871
+ rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
1872
+ rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1873
+ rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
1874
+
1875
+ rb_define_attr(eError, "code", true, false);
1876
+ rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
1877
+
1878
+ // High level
1879
+ cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
1880
+
1881
+ rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
1882
+ rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
1883
+ rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1884
+ rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1885
+ rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1886
+ rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
1887
+ rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
1888
+ rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
1889
+ rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
1890
+
1891
+ cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
1892
+ rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
1893
+ rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
1894
+ rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0);
1895
+ rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0);
1896
+ rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0);
1897
+ rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0);
1898
+ rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0);
1899
+ rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0);
1900
+ rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0);
1901
+ rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0);
1902
+ rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0);
1903
+ rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
1904
+ rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
1905
+ rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
423
1906
  }
424
1907
  #ifdef __cplusplus
425
1908
  }