whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/extsources.rb ADDED
@@ -0,0 +1,6 @@
1
require "yaml"

# Files tracked by git under the repository root (two levels up), NUL-separated
# so that unusual file names survive the split.
tracked_files = `git ls-files -z ../..`.split("\x0")

# The workflow's top-level key is looked up as `true`; presumably that is the
# YAML 1.1 boolean reading of the `on:` key -- verify against the workflow file.
workflow = YAML.load_file("../../.github/workflows/bindings-ruby.yml")
patterns = workflow[true]["push"]["paths"]
patterns.delete "bindings/ruby/**"

# Expand the remaining glob patterns relative to the repository root, add the
# license, and keep only entries that git actually tracks.
candidates = Dir.glob(patterns, base: "../..").map {|path| "../../#{path}"}
candidates << "../../LICENSE"
EXTSOURCES = candidates & tracked_files
@@ -0,0 +1,157 @@
1
+ require "whisper.so"
2
+ require "uri"
3
+ require "net/http"
4
+ require "time"
5
+ require "pathname"
6
+ require "io/console/size"
7
+
8
+ class Whisper::Model
9
# A model file identified by a remote URI and mirrored into a per-user cache
# directory. The cache entry is keyed by the URI's host and path.
class URI
  # uri - a String (or any value Kernel#URI accepts) locating the model file.
  def initialize(uri)
    @uri = URI(uri)
  end

  # Refreshes the cached copy (downloading it when necessary) and returns the
  # path of the cache file as a String.
  def to_path
    cache
    cache_path.to_path
  end

  # Removes the cached file if it exists.
  def clear_cache
    path = cache_path
    path.delete if path.exist?
  end

  private

  # Pathname of the cache file: <cache dir>/<host>/<URI path minus leading "/">.
  def cache_path
    base_cache_dir/@uri.host/@uri.path[1..]
  end

  # Platform-specific cache root with a "whisper.cpp" subdirectory appended.
  # Always returns a Pathname.
  def base_cache_dir
    base = case RUBY_PLATFORM
           when /mswin|mingw/
             ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
           when /darwin/
             Pathname(Dir.home)/"Library/Caches"
           else
             # The environment value must be wrapped in Pathname: a bare String
             # has no `/` operator, so the join below would raise NoMethodError.
             ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
           end
    base/"whisper.cpp"
  end

  # Issues a (conditional, when a cache file exists) GET for the URI and
  # returns the cache file's Pathname.
  def cache
    path = cache_path
    headers = {}
    # Let the server answer 304 when the cached file is still current.
    headers["if-modified-since"] = path.mtime.httpdate if path.exist?
    request @uri, headers
    path
  end

  # Performs a GET on +uri+ with +headers+ and dispatches on the response:
  # 304 keeps the cache, 200 downloads the body, other redirections are
  # followed, and anything else falls back to an existing cache file or raises
  # RuntimeError when there is none.
  def request(uri, headers)
    Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
      get = Net::HTTP::Get.new(uri, headers)
      http.request get do |response|
        case response
        when Net::HTTPNotModified
          # Must be tested before HTTPRedirection, which also matches 304.
        when Net::HTTPOK
          download response
        when Net::HTTPRedirection
          request URI(response["location"]), headers
        else
          return if headers.key?("if-modified-since") # Use cache file

          raise "#{response.code} #{response.message}\n#{response.body}"
        end
      end
    end
  end

  # Streams the response body into "<cache file>.downloading" and renames it
  # onto the cache path once the whole body has been written, so the cache
  # path never holds a partially written file.
  def download(response)
    path = cache_path
    path.dirname.mkpath unless path.dirname.exist?
    downloading_path = Pathname("#{path}.downloading")
    size = response.content_length
    downloading_path.open "wb" do |file|
      downloaded = 0
      response.read_body do |chunk|
        file << chunk
        downloaded += chunk.bytesize
        show_progress downloaded, size
      end
    end
    downloading_path.rename path
  end

  # Draws a one-line progress bar on $stderr, at most about once per second
  # and always when the download completes. Does nothing when $stderr is not a
  # terminal or when the total size is unknown (nil).
  #
  # current - bytes received so far.
  # size    - total bytes expected, or nil.
  def show_progress(current, size)
    return unless $stderr.tty?
    return unless size

    unless @prev
      @prev = Time.now
      $stderr.puts "Downloading #{@uri}"
    end

    now = Time.now
    return if now - @prev < 1 && current < size

    progress_width = 20
    # Clamp the ratio so a zero-length or over-long body can neither produce
    # NaN nor a negative repeat count.
    progress = size.zero? ? 1.0 : (current.to_f / size).clamp(0.0, 1.0)
    # Whole number of cells, at least one for the ">" head, so the bar always
    # has exactly progress_width characters.
    arrow_length = (progress * progress_width).floor.clamp(1, progress_width)
    arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
    line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
    # A terminal narrower than the line needs no padding; a negative count
    # would make String#* raise ArgumentError.
    padding = " " * [$stderr.winsize[1] - line.size, 0].max
    $stderr.print "\r#{line}#{padding}"
    $stderr.puts if current >= size
    @prev = now
  end

  # Formats a byte count with one decimal and a binary unit, e.g. "1.5 KiB".
  # Counts beyond the largest unit are expressed in that unit.
  def format_bytesize(bytesize)
    return "0.0 B" if bytesize.zero?

    units = %w[B KiB MiB GiB TiB]
    exp = 0
    # Integer comparison avoids floating-point log error at exact powers of
    # 1024 and never indexes past the unit table.
    exp += 1 while exp < units.size - 1 && bytesize >= 1024 ** (exp + 1)
    format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
  end
end
117
+
118
# Registry of the pre-converted ggml models: model name => URI of
# "ggml-<name>.bin" in the ggerganov/whisper.cpp Hugging Face repository.
@pre_converted_models = %w[
  tiny
  tiny.en
  tiny-q5_1
  tiny.en-q5_1
  tiny-q8_0
  base
  base.en
  base-q5_1
  base.en-q5_1
  base-q8_0
  small
  small.en
  small.en-tdrz
  small-q5_1
  small.en-q5_1
  small-q8_0
  medium
  medium.en
  medium-q5_0
  medium.en-q5_0
  medium-q8_0
  large-v1
  large-v2
  large-v2-q5_0
  large-v2-q8_0
  large-v3
  large-v3-q5_0
  large-v3-turbo
  large-v3-turbo-q5_0
  large-v3-turbo-q8_0
].each_with_object({}) do |model_name, registry|
  registry[model_name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{model_name}.bin")
end

# Expose the registry through a class-level reader.
class << self
  attr_reader :pre_converted_models
end
157
+ end
data/lib/whisper.rb ADDED
@@ -0,0 +1,2 @@
1
+ require "whisper.so"
2
+ require "whisper/model/uri"
data/tests/helper.rb ADDED
@@ -0,0 +1,7 @@
1
+ require "test/unit"
2
+ require "whisper"
3
+ require_relative "jfk_reader/jfk_reader"
4
+
5
# Common base class for the binding's tests.
class TestBase < Test::Unit::TestCase
  # Path to samples/jfk.wav, three directories above this file's directory.
  AUDIO = File.join(__dir__, *%w[.. .. .. samples jfk.wav])
end
@@ -0,0 +1,5 @@
1
+ Makefile
2
+ jfk_reader.o
3
+ jfk_reader.so
4
+ jfk_reader.bundle
5
+ jfk_reader.dll
@@ -0,0 +1,3 @@
1
+ require "mkmf"
2
+
3
+ create_makefile("jfk_reader")
@@ -0,0 +1,68 @@
1
+ #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
+ #include <ruby/encoding.h>
4
+
5
+ static VALUE
6
+ jfk_reader_initialize(VALUE self, VALUE audio_path)
7
+ {
8
+ rb_iv_set(self, "audio_path", audio_path);
9
+ return Qnil;
10
+ }
11
+
12
+ static bool
13
+ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
14
+ {
15
+ VALUE audio_path = rb_iv_get(obj, "audio_path");
16
+ const char *audio_path_str = StringValueCStr(audio_path);
17
+ const int n_samples = 176000;
18
+ float *data = (float *)malloc(n_samples * sizeof(float));
19
+ short *samples = (short *)malloc(n_samples * sizeof(short));
20
+ FILE *file = fopen(audio_path_str, "rb");
21
+
22
+ fseek(file, 78, SEEK_SET);
23
+ fread(samples, sizeof(short), n_samples, file);
24
+ fclose(file);
25
+ for (int i = 0; i < n_samples; i++) {
26
+ data[i] = samples[i]/32768.0;
27
+ }
28
+
29
+ view->obj = obj;
30
+ view->data = (void *)data;
31
+ view->byte_size = sizeof(float) * n_samples;
32
+ view->readonly = true;
33
+ view->format = "f";
34
+ view->item_size = sizeof(float);
35
+ view->item_desc.components = NULL;
36
+ view->item_desc.length = 0;
37
+ view->ndim = 1;
38
+ view->shape = NULL;
39
+ view->sub_offsets = NULL;
40
+ view->private_data = NULL;
41
+
42
+ return true;
43
+ }
44
+
45
+ static bool
46
+ jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
47
+ {
48
+ return true;
49
+ }
50
+
51
+ static bool
52
+ jfk_reader_memory_view_available_p(const VALUE obj)
53
+ {
54
+ return true;
55
+ }
56
+
57
+ static const rb_memory_view_entry_t jfk_reader_view_entry = {
58
+ jfk_reader_get_memory_view,
59
+ jfk_reader_release_memory_view,
60
+ jfk_reader_memory_view_available_p
61
+ };
62
+
63
+ void Init_jfk_reader(void)
64
+ {
65
+ VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
66
+ rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
67
+ rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
68
+ }
@@ -0,0 +1,160 @@
1
+ require_relative "helper"
2
+
3
# Tests for the callback hooks of Whisper::Params (new segment, progress and
# abort), run by transcribing the sample audio with the "base.en" model.
class TestCallback < TestBase
  def setup
    GC.start
    @params = Whisper::Params.new
    @whisper = Whisper::Context.new("base.en")
    @audio = File.join(AUDIO)
  end

  # The new-segment callback receives the context and a positive segment count,
  # and the new segments can be inspected through the context.
  def test_new_segment_callback
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_kind_of Integer, n_new
      assert n_new > 0
      assert_same @whisper, context

      n_segments = context.full_n_segments
      n_new.times do |i|
        # The n_new newest segments are the last n_new indices; counting from
        # n_segments - 1 would run past the end whenever n_new > 1.
        i_segment = n_segments - n_new + i
        start_time = context.full_get_segment_t0(i_segment) * 10
        end_time = context.full_get_segment_t1(i_segment) * 10
        text = context.full_get_segment_text(i_segment)

        assert_kind_of Integer, start_time
        assert start_time >= 0
        assert_kind_of Integer, end_time
        assert end_time > 0
        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
      end
    }

    @whisper.transcribe(@audio, @params)
  end

  # The callback is a closure over local state, and an exception raised inside
  # it propagates out of #transcribe.
  def test_new_segment_callback_closure
    search_word = "what"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      n_segments = context.full_n_segments
      n_new.times do |i|
        i_segment = n_segments - n_new + i
        text = context.full_get_segment_text(i_segment)
        if text.include?(search_word)
          t0 = context.full_get_segment_t0(i_segment)
          t1 = context.full_get_segment_t1(i_segment)
          raise "search word '#{search_word}' found at between #{t0} and #{t1}"
        end
      end
    }

    assert_raise RuntimeError do
      @whisper.transcribe(@audio, @params)
    end
  end

  # The configured user data object is handed to the callback unchanged.
  def test_new_segment_callback_user_data
    udata = Object.new
    @params.new_segment_callback_user_data = udata
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  # User data that is referenced only by the params survives a GC run.
  def test_new_segment_callback_user_data_gc
    @params.new_segment_callback_user_data = "My user data"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_equal "My user data", user_data
    }
    GC.start

    assert_same @whisper, @whisper.transcribe(@audio, @params)
  end

  # Progress is reported as integers in 0..100, starting at 0 and ending at 100.
  def test_progress_callback
    first = nil
    last = nil
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      assert_same @whisper, context
      first = progress if first.nil?
      last = progress
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_progress_callback_user_data
    udata = Object.new
    @params.progress_callback_user_data = udata
    @params.progress_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  # Block-style variant of the progress callback.
  def test_on_progress
    first = nil
    last = nil
    @params.on_progress do |progress|
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      first = progress if first.nil?
      last = progress
    end
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  # The abort callback is consulted at least once; user data defaults to nil.
  def test_abort_callback
    i = 0
    @params.abort_callback = ->(user_data) {
      assert_nil user_data
      i += 1
      return false
    }
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end

  # Once the callback returns true it is not invoked again.
  def test_abort_callback_abort
    i = 0
    @params.abort_callback = ->(user_data) {
      i += 1
      return i == 3
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 3, i
  end

  def test_abort_callback_user_data
    udata = Object.new
    @params.abort_callback_user_data = udata
    yielded = nil
    @params.abort_callback = ->(user_data) {
      yielded = user_data
    }
    @whisper.transcribe(@audio, @params)
    assert_same udata, yielded
  end

  # Block-style abort hook combined with the block-style new-segment hook.
  def test_abort_on
    do_abort = false
    @params.on_new_segment do |segment|
      do_abort = true if segment.text.match?(/ask/)
    end
    i = 0
    @params.abort_on do
      i += 1
      do_abort
    end
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end
end
@@ -0,0 +1,20 @@
1
+ require_relative "helper"
2
+
3
# Tests for Whisper::Error, which is constructed from an integer error code.
class TestError < TestBase
  # A known code yields its message and is readable back through #code.
  def test_error
    error = Whisper::Error.new(-2)
    assert_equal "failed to compute log mel spectrogram", error.message
    # Parentheses keep the negative literal from being parsed ambiguously.
    assert_equal(-2, error.code)
  end

  # A code without a dedicated message falls back to a generic one.
  def test_unknown_error
    error = Whisper::Error.new(-20)
    assert_equal "unknown error", error.message
  end

  # A non-integer code is rejected with TypeError.
  def test_non_int_code
    assert_raise TypeError do
      Whisper::Error.new("non int")
    end
  end
end
@@ -0,0 +1,71 @@
1
+ require_relative "helper"
2
+ require "pathname"
3
+
4
# Tests for Whisper::Model as exposed by a context loaded with "base.en".
class TestModel < TestBase
  def test_model
    whisper = Whisper::Context.new("base.en")
    assert_instance_of Whisper::Model, whisper.model
  end

  def test_attributes
    whisper = Whisper::Context.new("base.en")

    assert_base_en_attributes whisper.model
  end

  # The model stays usable after its context became unreachable and GC ran.
  def test_gc
    model = Whisper::Context.new("base.en").model
    GC.start

    assert_base_en_attributes model
  end

  # A context can be created from a Pathname pointing at the cached model file.
  def test_pathname
    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
    whisper = Whisper::Context.new(path)

    assert_base_en_attributes whisper.model
  end

  # Resolving the path of a pre-converted model leaves the file on disk.
  def test_auto_download
    path = Whisper::Model.pre_converted_models["base.en"].to_path

    assert_path_exist path
    assert_equal 147964211, File.size(path)
  end

  private

  # Asserts the attribute values the tests expect of the "base.en" model.
  def assert_base_en_attributes(model)
    assert_equal 51864, model.n_vocab
    assert_equal 1500, model.n_audio_ctx
    assert_equal 512, model.n_audio_state
    assert_equal 8, model.n_audio_head
    assert_equal 6, model.n_audio_layer
    assert_equal 448, model.n_text_ctx
    assert_equal 512, model.n_text_state
    assert_equal 8, model.n_text_head
    assert_equal 6, model.n_text_layer
    assert_equal 80, model.n_mels
    assert_equal 1, model.ftype
    assert_equal "base", model.type
  end
end
@@ -0,0 +1,31 @@
1
+ require_relative "helper"
2
+ require 'tempfile'
3
+ require 'tmpdir'
4
+ require 'shellwords'
5
+
6
# Tests that the gem can be built and that installing it compiles the
# native extension.
class TestPackage < TestBase
  # `gem build` writes a non-empty gem file to the requested output path.
  def test_build
    Tempfile.create do |file|
      # Multi-argument system() bypasses the shell, so arguments are passed
      # verbatim; shell-escaping them would corrupt paths that contain
      # characters the shell treats specially.
      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true)
      assert file.size > 0
      assert_path_exist file.to_path
    end
  end

  sub_test_case "Building binary on installation" do
    def setup
      system "rake", "build", exception: true
    end

    # Installing the built gem into a scratch directory produces the compiled
    # extension under the gem's lib directory.
    def test_install
      match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
      assert_not_nil match_data, "gem file name not found in `rake -Tbuild` output"
      filename = match_data[1]
      version = match_data[2]
      basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
      Dir.mktmpdir do |dir|
        # No shell is involved here either, so paths are passed unescaped.
        system "gem", "install", "--install-dir", dir, File.join("pkg", filename), exception: true
        assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
      end
    end
  end
end
@@ -0,0 +1,160 @@
1
+ require_relative "helper"
2
+
3
# Tests that every Whisper::Params accessor stores and returns its value.
# Expected values are passed first to assert_equal so that failure messages
# label expected and actual correctly.
class TestParams < TestBase
  def setup
    @params = Whisper::Params.new
  end

  def test_language
    @params.language = "en"
    assert_equal "en", @params.language
    @params.language = "auto"
    assert_equal "auto", @params.language
  end

  def test_offset
    @params.offset = 10_000
    assert_equal 10_000, @params.offset
    @params.offset = 0
    assert_equal 0, @params.offset
  end

  def test_duration
    @params.duration = 60_000
    assert_equal 60_000, @params.duration
    @params.duration = 0
    assert_equal 0, @params.duration
  end

  def test_max_text_tokens
    @params.max_text_tokens = 300
    assert_equal 300, @params.max_text_tokens
    @params.max_text_tokens = 0
    assert_equal 0, @params.max_text_tokens
  end

  # Boolean switches: each gets a test_<name> method that sets the flag to
  # true and then to false and checks the reader after each write.
  %i[
    translate
    no_context
    single_segment
    print_special
    print_progress
    print_realtime
    print_timestamps
    suppress_blank
    suppress_non_speech_tokens
    token_timestamps
    split_on_word
  ].each do |flag|
    define_method("test_#{flag}") do
      @params.public_send("#{flag}=", true)
      assert @params.public_send(flag)
      @params.public_send("#{flag}=", false)
      assert !@params.public_send(flag)
    end
  end

  def test_initial_prompt
    assert_nil @params.initial_prompt
    @params.initial_prompt = "You are a polite person."
    assert_equal "You are a polite person.", @params.initial_prompt
  end

  def test_temperature
    assert_equal 0.0, @params.temperature
    @params.temperature = 0.5
    assert_equal 0.5, @params.temperature
  end

  def test_max_initial_ts
    assert_equal 1.0, @params.max_initial_ts
    @params.max_initial_ts = 600.0
    assert_equal 600.0, @params.max_initial_ts
  end

  def test_length_penalty
    # Parentheses keep the negative literal from being parsed ambiguously.
    assert_equal(-1.0, @params.length_penalty)
    @params.length_penalty = 0.5
    assert_equal 0.5, @params.length_penalty
  end

  def test_temperature_inc
    assert_in_delta 0.2, @params.temperature_inc
    @params.temperature_inc = 0.5
    assert_in_delta 0.5, @params.temperature_inc
  end

  def test_entropy_thold
    assert_in_delta 2.4, @params.entropy_thold
    @params.entropy_thold = 3.0
    assert_in_delta 3.0, @params.entropy_thold
  end

  def test_logprob_thold
    assert_in_delta(-1.0, @params.logprob_thold)
    @params.logprob_thold = -0.5
    assert_in_delta(-0.5, @params.logprob_thold)
  end

  def test_no_speech_thold
    assert_in_delta 0.6, @params.no_speech_thold
    @params.no_speech_thold = 0.2
    assert_in_delta 0.2, @params.no_speech_thold
  end
end