whispercpp 1.3.0 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (132):
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/extsources.rb ADDED
@@ -0,0 +1,6 @@
1
require "yaml"

# Every file git tracks in the repository two directories up, split on NUL
# (`-z`) so unusual file names are not quoted or split incorrectly.
sources = `git ls-files -z ../..`.split("\x0")

# The `push` trigger's path filters of the Ruby-bindings workflow. The lookup
# uses `true` rather than "on" because a YAML 1.1 loader reads the bare `on:`
# key as a boolean.
workflow = YAML.load_file("../../.github/workflows/bindings-ruby.yml")
paths = workflow[true]["push"]["paths"]
# The bindings' own directory is not part of the extension sources.
paths.delete "bindings/ruby/**"

# Expand the globs from the repository root, add LICENSE, and keep only
# entries that git actually tracks.
globbed = Dir.glob(paths, base: "../..").map { |path| "../../#{path}" }
globbed.push "../../LICENSE"
EXTSOURCES = globbed & sources
@@ -0,0 +1,157 @@
1
+ require "whisper.so"
2
+ require "uri"
3
+ require "net/http"
4
+ require "time"
5
+ require "pathname"
6
+ require "io/console/size"
7
+
8
class Whisper::Model
  # A model file addressed by an HTTP(S) URI and cached on local disk.
  #
  # Responds to #to_path, so an instance can be passed where a model path is
  # expected; the file is downloaded (or revalidated) when #to_path is called.
  class URI
    # uri - String or URI of the remote model file.
    def initialize(uri)
      @uri = URI(uri)
    end

    # Makes sure the model is cached locally and returns the cache file path
    # as a String. An existing cache file triggers a conditional request.
    def to_path
      cache
      cache_path.to_path
    end

    # Deletes the cached model file if it exists.
    def clear_cache
      path = cache_path
      path.delete if path.exist?
    end

    private

    # <base cache dir>/<host>/<URI path without its leading "/">
    def cache_path
      base_cache_dir/@uri.host/@uri.path[1..]
    end

    # Per-platform cache root followed by "whisper.cpp":
    # Windows: %LOCALAPPDATA% (fallback ~/AppData/Local); macOS: ~/Library/Caches;
    # everything else: $XDG_CACHE_HOME (fallback ~/.cache).
    def base_cache_dir
      base = case RUBY_PLATFORM
             when /mswin|mingw/
               ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
             when /darwin/
               Pathname(Dir.home)/"Library/Caches"
             else
               # Wrap in Pathname: a plain String has no #/, so the join below
               # would raise NoMethodError whenever XDG_CACHE_HOME is set.
               ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
             end
      base/"whisper.cpp"
    end

    # Fetches the model into the cache if needed and returns the cache path.
    def cache
      path = cache_path
      headers = {}
      # Let the server answer 304 when the cached copy is still current.
      headers["if-modified-since"] = path.mtime.httpdate if path.exist?
      request @uri, headers
      path
    end

    # Performs a GET on +uri+, following redirects.
    # 304 keeps the cache, 200 streams the body into the cache, and any other
    # status falls back to the existing cache file when one was present
    # (signalled by the if-modified-since header); otherwise it raises
    # RuntimeError with the status and body.
    def request(uri, headers)
      Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
        get = Net::HTTP::Get.new(uri, headers)
        http.request get do |response|
          case response
          when Net::HTTPNotModified
            # Cached copy is current; nothing to do.
          when Net::HTTPOK
            download response
          when Net::HTTPRedirection
            # Location may be relative; resolve it against the current URI.
            request uri + response["location"], headers
          else
            return if headers.key?("if-modified-since") # Use cache file

            raise "#{response.code} #{response.message}\n#{response.body}"
          end
        end
      end
    end

    # Streams +response+ into "<cache path>.downloading" and renames it into
    # place once the body has been fully read.
    def download(response)
      path = cache_path
      path.dirname.mkpath unless path.dirname.exist?
      downloading_path = Pathname("#{path}.downloading")
      size = response.content_length
      downloading_path.open "wb" do |file|
        downloaded = 0
        response.read_body do |chunk|
          file << chunk
          downloaded += chunk.bytesize
          show_progress downloaded, size
        end
      end
      downloading_path.rename path
    end

    # Draws a progress bar on $stderr (TTY only) at most about once a second,
    # always drawing the final state.
    def show_progress(current, size)
      return unless $stderr.tty?
      # Unknown or zero length: no meaningful ratio (0 would yield NaN below).
      return if size.nil? || size.zero?

      unless @prev
        @prev = Time.now
        $stderr.puts "Downloading #{@uri}"
      end

      now = Time.now
      return if now - @prev < 1 && current < size

      progress_width = 20
      # Clamp so a body larger than Content-Length cannot yield negative counts.
      progress = [current.to_f / size, 1.0].min
      filled = (progress * progress_width).to_i
      # The ">" always occupies one cell, keeping the bar exactly progress_width wide.
      arrow = "=" * [filled - 1, 0].max + ">" + " " * (progress_width - [filled, 1].max)
      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
      # String#* raises on a negative count when the terminal is narrower than the line.
      padding = " " * [$stderr.winsize[1] - line.size, 0].max
      $stderr.print "\r#{line}#{padding}"
      $stderr.puts if current >= size
      @prev = now
    end

    # Formats a byte count with binary units, e.g. 1536 -> "1.5 KiB".
    def format_bytesize(bytesize)
      return "0.0 B" if bytesize.zero?

      units = %w[B KiB MiB GiB TiB]
      # Clamp so sizes of 1024 TiB and beyond still get a unit instead of nil.
      exp = [(Math.log(bytesize) / Math.log(1024)).to_i, units.size - 1].min
      format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
    end
  end

  # Maps each model name below to a URI object for
  # ggml-<name>.bin in the ggerganov/whisper.cpp Hugging Face repository.
  @pre_converted_models = {}
  %w[
    tiny
    tiny.en
    tiny-q5_1
    tiny.en-q5_1
    tiny-q8_0
    base
    base.en
    base-q5_1
    base.en-q5_1
    base-q8_0
    small
    small.en
    small.en-tdrz
    small-q5_1
    small.en-q5_1
    small-q8_0
    medium
    medium.en
    medium-q5_0
    medium.en-q5_0
    medium-q8_0
    large-v1
    large-v2
    large-v2-q5_0
    large-v2-q8_0
    large-v3
    large-v3-q5_0
    large-v3-turbo
    large-v3-turbo-q5_0
    large-v3-turbo-q8_0
  ].each do |name|
    @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
  end

  class << self
    # Hash of model name => Whisper::Model::URI.
    attr_reader :pre_converted_models
  end
end
data/lib/whisper.rb ADDED
@@ -0,0 +1,2 @@
1
+ require "whisper.so"
2
+ require "whisper/model/uri"
data/tests/helper.rb ADDED
@@ -0,0 +1,7 @@
1
+ require "test/unit"
2
+ require "whisper"
3
+ require_relative "jfk_reader/jfk_reader"
4
+
5
+ class TestBase < Test::Unit::TestCase
6
+ AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
7
+ end
@@ -0,0 +1,5 @@
1
+ Makefile
2
+ jfk_reader.o
3
+ jfk_reader.so
4
+ jfk_reader.bundle
5
+ jfk_reader.dll
@@ -0,0 +1,3 @@
1
# Generates the Makefile for the test-only JFKReader C extension.
require "mkmf"

create_makefile "jfk_reader"
@@ -0,0 +1,68 @@
1
+ #include <ruby.h>
2
+ #include <ruby/memory_view.h>
3
+ #include <ruby/encoding.h>
4
+
5
+ static VALUE
6
+ jfk_reader_initialize(VALUE self, VALUE audio_path)
7
+ {
8
+ rb_iv_set(self, "audio_path", audio_path);
9
+ return Qnil;
10
+ }
11
+
12
+ static bool
13
+ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
14
+ {
15
+ VALUE audio_path = rb_iv_get(obj, "audio_path");
16
+ const char *audio_path_str = StringValueCStr(audio_path);
17
+ const int n_samples = 176000;
18
+ float *data = (float *)malloc(n_samples * sizeof(float));
19
+ short *samples = (short *)malloc(n_samples * sizeof(short));
20
+ FILE *file = fopen(audio_path_str, "rb");
21
+
22
+ fseek(file, 78, SEEK_SET);
23
+ fread(samples, sizeof(short), n_samples, file);
24
+ fclose(file);
25
+ for (int i = 0; i < n_samples; i++) {
26
+ data[i] = samples[i]/32768.0;
27
+ }
28
+
29
+ view->obj = obj;
30
+ view->data = (void *)data;
31
+ view->byte_size = sizeof(float) * n_samples;
32
+ view->readonly = true;
33
+ view->format = "f";
34
+ view->item_size = sizeof(float);
35
+ view->item_desc.components = NULL;
36
+ view->item_desc.length = 0;
37
+ view->ndim = 1;
38
+ view->shape = NULL;
39
+ view->sub_offsets = NULL;
40
+ view->private_data = NULL;
41
+
42
+ return true;
43
+ }
44
+
45
+ static bool
46
+ jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
47
+ {
48
+ return true;
49
+ }
50
+
51
+ static bool
52
+ jfk_reader_memory_view_available_p(const VALUE obj)
53
+ {
54
+ return true;
55
+ }
56
+
57
+ static const rb_memory_view_entry_t jfk_reader_view_entry = {
58
+ jfk_reader_get_memory_view,
59
+ jfk_reader_release_memory_view,
60
+ jfk_reader_memory_view_available_p
61
+ };
62
+
63
+ void Init_jfk_reader(void)
64
+ {
65
+ VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
66
+ rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
67
+ rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
68
+ }
@@ -0,0 +1,160 @@
1
+ require_relative "helper"
2
+
3
# Exercises the callback hooks of Whisper::Params during transcription.
class TestCallback < TestBase
  def setup
    GC.start
    @params = Whisper::Params.new
    @whisper = Whisper::Context.new("base.en")
    @audio = AUDIO
  end

  def test_new_segment_callback
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_kind_of Integer, n_new
      assert n_new > 0
      assert_same @whisper, context

      n_segments = context.full_n_segments
      # The n_new segments just produced are the last n_new of the result;
      # "n_segments - 1 + i" would run past the end whenever n_new > 1.
      n_new.times do |i|
        i_segment = n_segments - n_new + i
        start_time = context.full_get_segment_t0(i_segment) * 10
        end_time = context.full_get_segment_t1(i_segment) * 10
        text = context.full_get_segment_text(i_segment)

        assert_kind_of Integer, start_time
        assert start_time >= 0
        assert_kind_of Integer, end_time
        assert end_time > 0
        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
      end
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_new_segment_callback_closure
    search_word = "what"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      n_segments = context.full_n_segments
      n_new.times do |i|
        i_segment = n_segments - n_new + i
        text = context.full_get_segment_text(i_segment)
        if text.include?(search_word)
          t0 = context.full_get_segment_t0(i_segment)
          t1 = context.full_get_segment_t1(i_segment)
          raise "search word '#{search_word}' found at between #{t0} and #{t1}"
        end
      end
    }

    assert_raise RuntimeError do
      @whisper.transcribe(@audio, @params)
    end
  end

  def test_new_segment_callback_user_data
    udata = Object.new
    @params.new_segment_callback_user_data = udata
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_new_segment_callback_user_data_gc
    @params.new_segment_callback_user_data = "My user data"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_equal "My user data", user_data
    }
    GC.start

    assert_same @whisper, @whisper.transcribe(@audio, @params)
  end

  def test_progress_callback
    first = nil
    last = nil
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      assert_same @whisper, context
      first = progress if first.nil?
      last = progress
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_progress_callback_user_data
    udata = Object.new
    @params.progress_callback_user_data = udata
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_on_progress
    first = nil
    last = nil
    @params.on_progress do |progress|
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      first = progress if first.nil?
      last = progress
    end
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_abort_callback
    i = 0
    @params.abort_callback = ->(user_data) {
      assert_nil user_data
      i += 1
      return false
    }
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end

  def test_abort_callback_abort
    i = 0
    @params.abort_callback = ->(user_data) {
      i += 1
      return i == 3
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 3, i
  end

  def test_abort_callback_user_data
    udata = Object.new
    @params.abort_callback_user_data = udata
    yielded = nil
    @params.abort_callback = ->(user_data) {
      yielded = user_data
    }
    @whisper.transcribe(@audio, @params)
    assert_same udata, yielded
  end

  def test_abort_on
    do_abort = false
    @params.on_new_segment do |segment|
      do_abort = true if segment.text.match?(/ask/)
    end
    i = 0
    @params.abort_on do
      i += 1
      do_abort
    end
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end
end
@@ -0,0 +1,20 @@
1
+ require_relative "helper"
2
+
3
# Whisper::Error maps native error codes to messages.
class TestError < TestBase
  def test_error
    error = Whisper::Error.new(-2)
    assert_equal "failed to compute log mel spectrogram", error.message
    # Parenthesized: a bare "-2" first argument triggers an ambiguity warning.
    assert_equal(-2, error.code)
  end

  def test_unknown_error
    error = Whisper::Error.new(-20)
    assert_equal "unknown error", error.message
  end

  def test_non_int_code
    assert_raise TypeError do
      Whisper::Error.new("non int")
    end
  end
end
@@ -0,0 +1,71 @@
1
+ require_relative "helper"
2
+ require "pathname"
3
+
4
# Checks Whisper::Model attributes for the base.en model under several ways
# of obtaining the model.
class TestModel < TestBase
  def test_model
    whisper = Whisper::Context.new("base.en")
    assert_instance_of Whisper::Model, whisper.model
  end

  def test_attributes
    whisper = Whisper::Context.new("base.en")
    assert_base_en_model whisper.model
  end

  def test_gc
    # The model must stay valid after its Context becomes garbage.
    model = Whisper::Context.new("base.en").model
    GC.start

    assert_base_en_model model
  end

  def test_pathname
    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
    whisper = Whisper::Context.new(path)
    assert_base_en_model whisper.model
  end

  def test_auto_download
    path = Whisper::Model.pre_converted_models["base.en"].to_path

    assert_path_exist path
    assert_equal 147964211, File.size(path)
  end

  private

  # Asserts the hyperparameters expected of the base.en model.
  def assert_base_en_model(model)
    assert_equal 51864, model.n_vocab
    assert_equal 1500, model.n_audio_ctx
    assert_equal 512, model.n_audio_state
    assert_equal 8, model.n_audio_head
    assert_equal 6, model.n_audio_layer
    assert_equal 448, model.n_text_ctx
    assert_equal 512, model.n_text_state
    assert_equal 8, model.n_text_head
    assert_equal 6, model.n_text_layer
    assert_equal 80, model.n_mels
    assert_equal 1, model.ftype
    assert_equal "base", model.type
  end
end
@@ -0,0 +1,31 @@
1
+ require_relative "helper"
2
+ require 'tempfile'
3
+ require 'tmpdir'
4
+ require 'shellwords'
5
+
6
# Builds and installs the gem to verify packaging.
#
# Commands run through multi-argument Kernel#system, which execs directly
# without a shell, so arguments are passed verbatim and must NOT be
# shell-escaped (escaping would corrupt paths containing special characters).
class TestPackage < TestBase
  def test_build
    Tempfile.create do |file|
      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true)
      assert file.size > 0
      assert_path_exist file.to_path
    end
  end

  sub_test_case "Building binary on installation" do
    def setup
      system "rake", "build", exception: true
    end

    def test_install
      match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
      assert_not_nil match_data, "gem file name not found in `rake -Tbuild` output"
      filename = match_data[1]
      version = match_data[2]
      basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
      Dir.mktmpdir do |dir|
        system "gem", "install", "--install-dir", dir, "pkg/#{filename}", exception: true
        assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
      end
    end
  end
end
@@ -0,0 +1,160 @@
1
+ require_relative "helper"
2
+
3
# Round-trips Whisper::Params accessors and checks their defaults.
# assert_equal/assert_in_delta take (expected, actual), so failure messages
# report the values the right way round. Negative literals in the first
# argument are parenthesized to avoid "ambiguous first argument" warnings.
class TestParams < TestBase
  def setup
    @params = Whisper::Params.new
  end

  def test_language
    @params.language = "en"
    assert_equal "en", @params.language
    @params.language = "auto"
    assert_equal "auto", @params.language
  end

  def test_offset
    @params.offset = 10_000
    assert_equal 10_000, @params.offset
    @params.offset = 0
    assert_equal 0, @params.offset
  end

  def test_duration
    @params.duration = 60_000
    assert_equal 60_000, @params.duration
    @params.duration = 0
    assert_equal 0, @params.duration
  end

  def test_max_text_tokens
    @params.max_text_tokens = 300
    assert_equal 300, @params.max_text_tokens
    @params.max_text_tokens = 0
    assert_equal 0, @params.max_text_tokens
  end

  def test_translate
    @params.translate = true
    assert @params.translate
    @params.translate = false
    assert !@params.translate
  end

  def test_no_context
    @params.no_context = true
    assert @params.no_context
    @params.no_context = false
    assert !@params.no_context
  end

  def test_single_segment
    @params.single_segment = true
    assert @params.single_segment
    @params.single_segment = false
    assert !@params.single_segment
  end

  def test_print_special
    @params.print_special = true
    assert @params.print_special
    @params.print_special = false
    assert !@params.print_special
  end

  def test_print_progress
    @params.print_progress = true
    assert @params.print_progress
    @params.print_progress = false
    assert !@params.print_progress
  end

  def test_print_realtime
    @params.print_realtime = true
    assert @params.print_realtime
    @params.print_realtime = false
    assert !@params.print_realtime
  end

  def test_print_timestamps
    @params.print_timestamps = true
    assert @params.print_timestamps
    @params.print_timestamps = false
    assert !@params.print_timestamps
  end

  def test_suppress_blank
    @params.suppress_blank = true
    assert @params.suppress_blank
    @params.suppress_blank = false
    assert !@params.suppress_blank
  end

  def test_suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = true
    assert @params.suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = false
    assert !@params.suppress_non_speech_tokens
  end

  def test_token_timestamps
    @params.token_timestamps = true
    assert @params.token_timestamps
    @params.token_timestamps = false
    assert !@params.token_timestamps
  end

  def test_split_on_word
    @params.split_on_word = true
    assert @params.split_on_word
    @params.split_on_word = false
    assert !@params.split_on_word
  end

  def test_initial_prompt
    assert_nil @params.initial_prompt
    @params.initial_prompt = "You are a polite person."
    assert_equal "You are a polite person.", @params.initial_prompt
  end

  def test_temperature
    assert_equal 0.0, @params.temperature
    @params.temperature = 0.5
    assert_equal 0.5, @params.temperature
  end

  def test_max_initial_ts
    assert_equal 1.0, @params.max_initial_ts
    @params.max_initial_ts = 600.0
    assert_equal 600.0, @params.max_initial_ts
  end

  def test_length_penalty
    assert_equal(-1.0, @params.length_penalty)
    @params.length_penalty = 0.5
    assert_equal 0.5, @params.length_penalty
  end

  def test_temperature_inc
    assert_in_delta 0.2, @params.temperature_inc
    @params.temperature_inc = 0.5
    assert_in_delta 0.5, @params.temperature_inc
  end

  def test_entropy_thold
    assert_in_delta 2.4, @params.entropy_thold
    @params.entropy_thold = 3.0
    assert_in_delta 3.0, @params.entropy_thold
  end

  def test_logprob_thold
    assert_in_delta(-1.0, @params.logprob_thold)
    @params.logprob_thold = -0.5
    assert_in_delta(-0.5, @params.logprob_thold)
  end

  def test_no_speech_thold
    assert_in_delta 0.6, @params.no_speech_thold
    @params.no_speech_thold = 0.2
    assert_in_delta 0.2, @params.no_speech_thold
  end
end
+ end