whispercpp 1.3.0 → 1.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
data/extsources.rb
ADDED
@@ -0,0 +1,6 @@
|
|
1
|
+
require "yaml"
|
2
|
+
|
3
|
+
# Every file git tracks under the repository root. `-z` NUL-separates the
# names so that unusual file names survive the split intact.
sources = `git ls-files -z ../..`.split("\x0")
# Reuse the CI workflow's trigger paths as the list of native sources. The
# workflow's `on:` key comes back from the YAML loader as boolean true, hence
# the `[true]` lookup.
workflow = YAML.load_file("../../.github/workflows/bindings-ruby.yml")
paths = workflow[true]["push"]["paths"]
# The binding's own directory is not an external source.
paths -= ["bindings/ruby/**"]
matched = Dir.glob(paths, base: "../..").map { |path| "../../#{path}" }
matched.push("../../LICENSE")
# Keep only the paths that git actually tracks.
EXTSOURCES = matched & sources
|
@@ -0,0 +1,157 @@
|
|
1
|
+
require "whisper.so"
|
2
|
+
require "uri"
|
3
|
+
require "net/http"
|
4
|
+
require "time"
|
5
|
+
require "pathname"
|
6
|
+
require "io/console/size"
|
7
|
+
|
8
|
+
class Whisper::Model
  # A remote model file (e.g. a pre-converted ggml model on Hugging Face)
  # that is downloaded on demand and cached on the local file system.
  #
  # Instances respond to #to_path, so they can be passed wherever a path-like
  # object is accepted.
  class URI
    # Upper bound on HTTP redirects followed before giving up.
    MAX_REDIRECTS = 10
    private_constant :MAX_REDIRECTS

    # @param uri [String, ::URI::Generic] location of the model file
    def initialize(uri)
      # `URI(...)` with arguments is the Kernel#URI method, not this class.
      @uri = URI(uri)
    end

    # Makes sure the cached copy is present and current (downloading or
    # revalidating it as needed), then returns its path.
    #
    # @return [String]
    def to_path
      cache
      cache_path.to_path
    end

    # Deletes the cached model file if it exists.
    def clear_cache
      path = cache_path
      path.delete if path.exist?
    end

    private

    # <base cache dir>/whisper.cpp/<host>/<URI path without leading "/">
    def cache_path
      base_cache_dir/@uri.host/@uri.path[1..]
    end

    # Platform-specific base directory for whisper.cpp's cache.
    #
    # @return [Pathname]
    def base_cache_dir
      base = case RUBY_PLATFORM
             when /mswin|mingw/
               ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
             when /darwin/
               Pathname(Dir.home)/"Library/Caches"
             else
               # Wrap in Pathname: a bare String has no `/` operator, which the
               # line below relies on.
               ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
             end
      base/"whisper.cpp"
    end

    # Fetches the model unless the cached copy is still current. When a cached
    # file exists, a conditional GET (If-Modified-Since = its mtime) is sent.
    #
    # @return [Pathname] the cache path
    def cache
      path = cache_path
      headers = {}
      headers["if-modified-since"] = path.mtime.httpdate if path.exist?
      request @uri, headers
      path
    end

    # Performs a GET for +uri+ and handles the response:
    # - 304 Not Modified: keep the cached file.
    # - 200 OK: stream the body into the cache.
    # - 3xx: follow the Location header (possibly relative), at most
    #   +redirects_left+ more times.
    # - anything else: fall back to the existing cached file if this was a
    #   conditional request, otherwise raise.
    #
    # @raise [RuntimeError] on too many redirects or an unexpected status
    def request(uri, headers, redirects_left = MAX_REDIRECTS)
      Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
        get = Net::HTTP::Get.new(uri, headers)
        http.request get do |response|
          case response
          when Net::HTTPNotModified
            # noop
          when Net::HTTPOK
            download response
          when Net::HTTPRedirection
            raise "too many redirects: #{uri}" if redirects_left <= 0

            # Location may be relative; resolve it against the current URI.
            request(uri + response["location"], headers, redirects_left - 1)
          else
            return if headers.key?("if-modified-since") # Use cache file

            raise "#{response.code} #{response.message}\n#{response.body}"
          end
        end
      end
    end

    # Streams +response+'s body into "<cache path>.downloading" and renames it
    # into place only once complete, so a partial download is never mistaken
    # for the cached model.
    def download(response)
      path = cache_path
      path.dirname.mkpath unless path.dirname.exist?
      downloading_path = Pathname("#{path}.downloading")
      size = response.content_length
      downloading_path.open "wb" do |file|
        downloaded = 0
        response.read_body do |chunk|
          file << chunk
          downloaded += chunk.bytesize
          show_progress downloaded, size
        end
      end
      downloading_path.rename path
    end

    # Draws a progress bar on stderr, only when stderr is a TTY and the total
    # size is known. Redraws at most once per second, plus a final redraw
    # once the download completes.
    def show_progress(current, size)
      return unless $stderr.tty?
      return unless size

      unless @prev
        @prev = Time.now
        $stderr.puts "Downloading #{@uri}"
      end

      now = Time.now
      return if now - @prev < 1 && current < size

      progress_width = 20
      # Clamp to 1.0 in case more bytes arrive than Content-Length announced.
      progress = [current.to_f / size, 1.0].min
      filled = [(progress * progress_width).to_i, 1].max
      arrow = "=" * (filled - 1) + ">" + " " * (progress_width - filled)
      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
      # String#* raises on a negative count, which would happen whenever the
      # terminal is narrower than the line.
      padding = " " * [$stderr.winsize[1] - line.size, 0].max
      $stderr.print "\r#{line}#{padding}"
      $stderr.puts if current >= size
      @prev = now
    end

    # Formats a byte count with binary units, e.g. 1536 => "1.5 KiB".
    def format_bytesize(bytesize)
      return "0.0 B" if bytesize.zero?

      units = %w[B KiB MiB GiB TiB]
      # Integer comparison avoids float-log rounding at exact powers of 1024
      # and never indexes past the largest unit.
      exp = 0
      exp += 1 while exp < units.size - 1 && bytesize >= 1024 ** (exp + 1)
      format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
    end
  end

  # Pre-converted ggml models published in the ggerganov/whisper.cpp
  # Hugging Face repository, keyed by model name.
  @pre_converted_models = {}
  %w[
    tiny
    tiny.en
    tiny-q5_1
    tiny.en-q5_1
    tiny-q8_0
    base
    base.en
    base-q5_1
    base.en-q5_1
    base-q8_0
    small
    small.en
    small.en-tdrz
    small-q5_1
    small.en-q5_1
    small-q8_0
    medium
    medium.en
    medium-q5_0
    medium.en-q5_0
    medium-q8_0
    large-v1
    large-v2
    large-v2-q5_0
    large-v2-q8_0
    large-v3
    large-v3-q5_0
    large-v3-turbo
    large-v3-turbo-q5_0
    large-v3-turbo-q8_0
  ].each do |name|
    @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
  end

  class << self
    # @return [Hash{String => Whisper::Model::URI}]
    attr_reader :pre_converted_models
  end
end
|
data/lib/whisper.rb
ADDED
data/tests/helper.rb
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
#include <ruby.h>
|
2
|
+
#include <ruby/memory_view.h>
|
3
|
+
#include <ruby/encoding.h>
|
4
|
+
|
5
|
+
static VALUE
|
6
|
+
jfk_reader_initialize(VALUE self, VALUE audio_path)
|
7
|
+
{
|
8
|
+
rb_iv_set(self, "audio_path", audio_path);
|
9
|
+
return Qnil;
|
10
|
+
}
|
11
|
+
|
12
|
+
static bool
|
13
|
+
jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
|
14
|
+
{
|
15
|
+
VALUE audio_path = rb_iv_get(obj, "audio_path");
|
16
|
+
const char *audio_path_str = StringValueCStr(audio_path);
|
17
|
+
const int n_samples = 176000;
|
18
|
+
float *data = (float *)malloc(n_samples * sizeof(float));
|
19
|
+
short *samples = (short *)malloc(n_samples * sizeof(short));
|
20
|
+
FILE *file = fopen(audio_path_str, "rb");
|
21
|
+
|
22
|
+
fseek(file, 78, SEEK_SET);
|
23
|
+
fread(samples, sizeof(short), n_samples, file);
|
24
|
+
fclose(file);
|
25
|
+
for (int i = 0; i < n_samples; i++) {
|
26
|
+
data[i] = samples[i]/32768.0;
|
27
|
+
}
|
28
|
+
|
29
|
+
view->obj = obj;
|
30
|
+
view->data = (void *)data;
|
31
|
+
view->byte_size = sizeof(float) * n_samples;
|
32
|
+
view->readonly = true;
|
33
|
+
view->format = "f";
|
34
|
+
view->item_size = sizeof(float);
|
35
|
+
view->item_desc.components = NULL;
|
36
|
+
view->item_desc.length = 0;
|
37
|
+
view->ndim = 1;
|
38
|
+
view->shape = NULL;
|
39
|
+
view->sub_offsets = NULL;
|
40
|
+
view->private_data = NULL;
|
41
|
+
|
42
|
+
return true;
|
43
|
+
}
|
44
|
+
|
45
|
+
static bool
|
46
|
+
jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
|
47
|
+
{
|
48
|
+
return true;
|
49
|
+
}
|
50
|
+
|
51
|
+
static bool
|
52
|
+
jfk_reader_memory_view_available_p(const VALUE obj)
|
53
|
+
{
|
54
|
+
return true;
|
55
|
+
}
|
56
|
+
|
57
|
+
static const rb_memory_view_entry_t jfk_reader_view_entry = {
|
58
|
+
jfk_reader_get_memory_view,
|
59
|
+
jfk_reader_release_memory_view,
|
60
|
+
jfk_reader_memory_view_available_p
|
61
|
+
};
|
62
|
+
|
63
|
+
void Init_jfk_reader(void)
|
64
|
+
{
|
65
|
+
VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
|
66
|
+
rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
|
67
|
+
rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
|
68
|
+
}
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
class TestCallback < TestBase
  def setup
    GC.start
    @params = Whisper::Params.new
    @whisper = Whisper::Context.new("base.en")
    @audio = File.join(AUDIO)
  end

  def test_new_segment_callback
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_kind_of Integer, n_new
      assert n_new > 0
      assert_same @whisper, context

      # The n_new segments added by this step are the last n_new ones.
      n_segments = context.full_n_segments
      (n_segments - n_new...n_segments).each do |i_segment|
        start_time = context.full_get_segment_t0(i_segment) * 10
        end_time = context.full_get_segment_t1(i_segment) * 10
        text = context.full_get_segment_text(i_segment)

        assert_kind_of Integer, start_time
        assert start_time >= 0
        assert_kind_of Integer, end_time
        assert end_time > 0
        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
      end
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_new_segment_callback_closure
    search_word = "what"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      n_segments = context.full_n_segments
      (n_segments - n_new...n_segments).each do |i_segment|
        text = context.full_get_segment_text(i_segment)
        next unless text.include?(search_word)

        t0 = context.full_get_segment_t0(i_segment)
        t1 = context.full_get_segment_t1(i_segment)
        raise "search word '#{search_word}' found at between #{t0} and #{t1}"
      end
    }

    assert_raise RuntimeError do
      @whisper.transcribe(@audio, @params)
    end
  end

  def test_new_segment_callback_user_data
    udata = Object.new
    @params.new_segment_callback_user_data = udata
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_new_segment_callback_user_data_gc
    @params.new_segment_callback_user_data = "My user data"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_equal "My user data", user_data
    }
    GC.start

    assert_same @whisper, @whisper.transcribe(@audio, @params)
  end

  def test_progress_callback
    first = nil
    last = nil
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      assert_same @whisper, context
      first = progress if first.nil?
      last = progress
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_progress_callback_user_data
    udata = Object.new
    @params.progress_callback_user_data = udata
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_on_progress
    first = nil
    last = nil
    @params.on_progress do |progress|
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      first = progress if first.nil?
      last = progress
    end
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_abort_callback
    i = 0
    @params.abort_callback = ->(user_data) {
      assert_nil user_data
      i += 1
      return false
    }
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end

  def test_abort_callback_abort
    i = 0
    @params.abort_callback = ->(user_data) {
      i += 1
      return i == 3
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 3, i
  end

  def test_abort_callback_user_data
    udata = Object.new
    @params.abort_callback_user_data = udata
    yielded = nil
    # The lambda's value (udata, which is truthy) requests an abort, so
    # transcription stops after the first call — enough to check user_data.
    @params.abort_callback = ->(user_data) {
      yielded = user_data
    }
    @whisper.transcribe(@audio, @params)
    assert_same udata, yielded
  end

  def test_abort_on
    do_abort = false
    @params.on_new_segment do |segment|
      do_abort = true if segment.text.match?(/ask/)
    end
    i = 0
    @params.abort_on do
      i += 1
      do_abort
    end
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end
end
|
data/tests/test_error.rb
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
class TestError < TestBase
  def test_error
    error = Whisper::Error.new(-2)
    assert_equal "failed to compute log mel spectrogram", error.message
    # Parenthesized: a bare `-2` first argument triggers an ambiguity warning.
    assert_equal(-2, error.code)
  end

  def test_unknown_error
    error = Whisper::Error.new(-20)
    assert_equal "unknown error", error.message
  end

  def test_non_int_code
    assert_raise TypeError do
      Whisper::Error.new("non int")
    end
  end
end
|
data/tests/test_model.rb
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
require "pathname"
|
3
|
+
|
4
|
+
class TestModel < TestBase
  def test_model
    whisper = Whisper::Context.new("base.en")
    assert_instance_of Whisper::Model, whisper.model
  end

  def test_attributes
    whisper = Whisper::Context.new("base.en")
    assert_base_en_model whisper.model
  end

  def test_gc
    model = Whisper::Context.new("base.en").model
    GC.start

    assert_base_en_model model
  end

  def test_pathname
    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
    whisper = Whisper::Context.new(path)
    assert_base_en_model whisper.model
  end

  def test_auto_download
    path = Whisper::Model.pre_converted_models["base.en"].to_path

    assert_path_exist path
    assert_equal 147964211, File.size(path)
  end

  private

  # Checks every hyperparameter exposed by +model+ against the values of the
  # base.en model used throughout these tests.
  def assert_base_en_model(model)
    assert_equal 51864, model.n_vocab
    assert_equal 1500, model.n_audio_ctx
    assert_equal 512, model.n_audio_state
    assert_equal 8, model.n_audio_head
    assert_equal 6, model.n_audio_layer
    assert_equal 448, model.n_text_ctx
    assert_equal 512, model.n_text_state
    assert_equal 8, model.n_text_head
    assert_equal 6, model.n_text_layer
    assert_equal 80, model.n_mels
    assert_equal 1, model.ftype
    assert_equal "base", model.type
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
require 'tempfile'
|
3
|
+
require 'tmpdir'
|
4
|
+
require 'shellwords'
|
5
|
+
|
6
|
+
class TestPackage < TestBase
  # NOTE: multi-argument `system` calls run the program directly (no shell),
  # so arguments must NOT be shell-escaped — escaping would embed literal
  # backslashes in paths containing spaces or special characters.

  def test_build
    Tempfile.create do |file|
      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true)
      assert_path_exist file.to_path
      # Check by path so the result does not depend on how gem writes the file.
      assert File.size(file.to_path) > 0
    end
  end

  sub_test_case "Building binary on installation" do
    def setup
      system "rake", "build", exception: true
    end

    def test_install
      match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
      filename = match_data[1]
      version = match_data[2]
      basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
      Dir.mktmpdir do |dir|
        system "gem", "install", "--install-dir", dir, "pkg/#{filename}", exception: true
        assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
      end
    end
  end
end
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
class TestParams < TestBase
  def setup
    @params = Whisper::Params.new
  end

  def test_language
    @params.language = "en"
    assert_equal "en", @params.language
    @params.language = "auto"
    assert_equal "auto", @params.language
  end

  def test_offset
    @params.offset = 10_000
    assert_equal 10_000, @params.offset
    @params.offset = 0
    assert_equal 0, @params.offset
  end

  def test_duration
    @params.duration = 60_000
    assert_equal 60_000, @params.duration
    @params.duration = 0
    assert_equal 0, @params.duration
  end

  def test_max_text_tokens
    @params.max_text_tokens = 300
    assert_equal 300, @params.max_text_tokens
    @params.max_text_tokens = 0
    assert_equal 0, @params.max_text_tokens
  end

  def test_translate
    assert_toggles :translate
  end

  def test_no_context
    assert_toggles :no_context
  end

  def test_single_segment
    assert_toggles :single_segment
  end

  def test_print_special
    assert_toggles :print_special
  end

  def test_print_progress
    assert_toggles :print_progress
  end

  def test_print_realtime
    assert_toggles :print_realtime
  end

  def test_print_timestamps
    assert_toggles :print_timestamps
  end

  def test_suppress_blank
    assert_toggles :suppress_blank
  end

  def test_suppress_non_speech_tokens
    assert_toggles :suppress_non_speech_tokens
  end

  def test_token_timestamps
    assert_toggles :token_timestamps
  end

  def test_split_on_word
    assert_toggles :split_on_word
  end

  def test_initial_prompt
    assert_nil @params.initial_prompt
    @params.initial_prompt = "You are a polite person."
    assert_equal "You are a polite person.", @params.initial_prompt
  end

  def test_temperature
    assert_equal 0.0, @params.temperature
    @params.temperature = 0.5
    assert_equal 0.5, @params.temperature
  end

  def test_max_initial_ts
    assert_equal 1.0, @params.max_initial_ts
    @params.max_initial_ts = 600.0
    assert_equal 600.0, @params.max_initial_ts
  end

  # Negative literals are parenthesized below to avoid Ruby's
  # "ambiguous first argument" warning.

  def test_length_penalty
    assert_equal(-1.0, @params.length_penalty)
    @params.length_penalty = 0.5
    assert_equal 0.5, @params.length_penalty
  end

  def test_temperature_inc
    assert_in_delta 0.2, @params.temperature_inc
    @params.temperature_inc = 0.5
    assert_in_delta 0.5, @params.temperature_inc
  end

  def test_entropy_thold
    assert_in_delta 2.4, @params.entropy_thold
    @params.entropy_thold = 3.0
    assert_in_delta 3.0, @params.entropy_thold
  end

  def test_logprob_thold
    assert_in_delta(-1.0, @params.logprob_thold)
    @params.logprob_thold = -0.5
    assert_in_delta(-0.5, @params.logprob_thold)
  end

  def test_no_speech_thold
    assert_in_delta 0.6, @params.no_speech_thold
    @params.no_speech_thold = 0.2
    assert_in_delta 0.2, @params.no_speech_thold
  end

  private

  # Sets the boolean param +name+ to true and then false, checking that the
  # getter reflects each assignment.
  def assert_toggles(name)
    @params.public_send("#{name}=", true)
    assert @params.public_send(name), "expected #{name} to be true"
    @params.public_send("#{name}=", false)
    assert !@params.public_send(name), "expected #{name} to be false"
  end
end
|