whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
data/extsources.rb
ADDED
@@ -0,0 +1,6 @@
|
|
1
|
+
require "yaml"
|
2
|
+
|
3
|
+
# Every file tracked by git under the repository root (two levels up),
# as NUL-separated paths relative to this directory.
tracked_files = `git ls-files -z ../..`.split("\x0")

# Path globs that trigger the Ruby bindings workflow on push.
# NOTE(review): the workflow's top-level key is looked up as boolean `true`,
# presumably because YAML loads the bare `on:` key as a boolean - confirm.
workflow = YAML.load_file("../../.github/workflows/bindings-ruby.yml")
glob_patterns = workflow[true]["push"]["paths"]
glob_patterns.delete "bindings/ruby/**"

# Candidate sources (glob matches plus the top-level LICENSE), restricted to
# files that git actually tracks.
candidates = Dir.glob(glob_patterns, base: "../..").map { |relative| "../../#{relative}" }
candidates << "../../LICENSE"
EXTSOURCES = candidates & tracked_files
|
@@ -0,0 +1,157 @@
|
|
1
|
+
require "whisper.so"
|
2
|
+
require "uri"
|
3
|
+
require "net/http"
|
4
|
+
require "time"
|
5
|
+
require "pathname"
|
6
|
+
require "io/console/size"
|
7
|
+
|
8
|
+
class Whisper::Model
  # A model file identified by an HTTP(S) URI. The file is downloaded on
  # demand into a per-user cache directory and revalidated with an
  # If-Modified-Since request on later accesses.
  class URI
    # @param uri [String, ::URI] location of the model file
    def initialize(uri)
      @uri = URI(uri)
    end

    # Ensures the model is cached (downloading or revalidating as needed).
    #
    # @return [String] path of the cached file
    # @raise [RuntimeError] when the server answers with an unexpected status
    #   and no cached copy exists
    def to_path
      cache
      cache_path.to_path
    end

    # Deletes the cached file, if there is one.
    def clear_cache
      path = cache_path
      path.delete if path.exist?
    end

    private

    # @return [Pathname] cache location derived from the URI host and path
    def cache_path
      # [1..] drops the leading "/" of the URI path so it joins as a relative part.
      base_cache_dir.join(@uri.host, @uri.path[1..])
    end

    # @return [Pathname] platform-specific cache root for whisper.cpp models
    def base_cache_dir
      base = case RUBY_PLATFORM
             when /mswin|mingw/
               ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home).join("AppData/Local")
             when /darwin/
               Pathname(Dir.home).join("Library/Caches")
             else
               # Wrap the environment value in Pathname: a bare String has no
               # path-joining operator and would raise on the join below.
               ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home).join(".cache")
             end
      base.join("whisper.cpp")
    end

    # Fetches the file unless the cached copy is still current.
    #
    # @return [Pathname] the cache path
    def cache
      path = cache_path
      headers = {}
      headers["if-modified-since"] = path.mtime.httpdate if path.exist?
      request @uri, headers
      path
    end

    # Issues a GET request and dispatches on the response class.
    #
    # @param uri [::URI] address to fetch
    # @param headers [Hash{String => String}] request headers
    # @raise [RuntimeError] on an unexpected status when no cached copy exists
    def request(uri, headers)
      Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
        get = Net::HTTP::Get.new(uri, headers)
        http.request get do |response|
          case response
          when Net::HTTPNotModified
            # Checked before the generic redirection branch below: the cached
            # copy is still current, so there is nothing to do.
          when Net::HTTPOK
            download response
          when Net::HTTPRedirection
            # Resolve against the current URI so that relative Location
            # headers are followed as well as absolute ones.
            request(uri + response["location"], headers)
          else
            return if headers.key?("if-modified-since") # Use cache file

            raise "#{response.code} #{response.message}\n#{response.body}"
          end
        end
      end
    end

    # Streams the response body into "<cache_path>.downloading" and renames it
    # to the final cache path once the whole body has been written.
    #
    # @param response [Net::HTTPResponse] response whose body is not yet read
    def download(response)
      path = cache_path
      path.dirname.mkpath unless path.dirname.exist?
      downloading_path = Pathname("#{path}.downloading")
      size = response.content_length
      downloading_path.open "wb" do |file|
        downloaded = 0
        response.read_body do |chunk|
          file << chunk
          downloaded += chunk.bytesize
          show_progress downloaded, size
        end
      end
      downloading_path.rename path
    end

    # Draws a progress bar on $stderr, at most about once per second.
    # Does nothing when $stderr is not a TTY or the total size is unknown.
    #
    # @param current [Integer] bytes downloaded so far
    # @param size [Integer, nil] total bytes expected
    def show_progress(current, size)
      return unless $stderr.tty?
      return unless size

      unless @prev
        @prev = Time.now
        $stderr.puts "Downloading #{@uri}"
      end

      now = Time.now
      return if now - @prev < 1 && current < size

      progress_width = 20
      # Clamp so that a body longer than the announced size cannot produce a
      # negative repeat count below.
      progress = (current.to_f / size).clamp(0.0, 1.0)
      arrow_length = progress * progress_width
      arrow = "=" * [(arrow_length - 1).to_i, 0].max + ">" + " " * (progress_width - arrow_length).to_i
      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
      # A terminal narrower than the line would yield a negative pad width.
      padding = " " * [$stderr.winsize[1] - line.size, 0].max
      $stderr.print "\r#{line}#{padding}"
      $stderr.puts if current >= size
      @prev = now
    end

    # @param bytesize [Integer] number of bytes
    # @return [String] size with one decimal and a binary unit, e.g. "1.5 MiB"
    def format_bytesize(bytesize)
      return "0.0 B" if bytesize.zero?

      units = %w[B KiB MiB GiB TiB]
      exp = (Math.log(bytesize) / Math.log(1024)).to_i
      format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
    end
  end

  # Model name => URI of the matching ggml-<name>.bin file.
  @pre_converted_models = {}
  %w[
    tiny
    tiny.en
    tiny-q5_1
    tiny.en-q5_1
    tiny-q8_0
    base
    base.en
    base-q5_1
    base.en-q5_1
    base-q8_0
    small
    small.en
    small.en-tdrz
    small-q5_1
    small.en-q5_1
    small-q8_0
    medium
    medium.en
    medium-q5_0
    medium.en-q5_0
    medium-q8_0
    large-v1
    large-v2
    large-v2-q5_0
    large-v2-q8_0
    large-v3
    large-v3-q5_0
    large-v3-turbo
    large-v3-turbo-q5_0
    large-v3-turbo-q8_0
  ].each do |name|
    @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
  end

  class << self
    # @return [Hash{String => Whisper::Model::URI}] known pre-converted models
    attr_reader :pre_converted_models
  end
end
|
data/lib/whisper.rb
ADDED
data/tests/helper.rb
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
#include <ruby.h>
|
2
|
+
#include <ruby/memory_view.h>
|
3
|
+
#include <ruby/encoding.h>
|
4
|
+
|
5
|
+
/*
 * JFKReader#initialize(audio_path)
 *
 * Remembers the given path on the receiver under the instance variable name
 * "audio_path" (no leading "@"), where the memory view getter reads it back.
 * Always returns nil.
 */
static VALUE
jfk_reader_initialize(VALUE self, VALUE audio_path)
{
  (void)rb_iv_set(self, "audio_path", audio_path);

  return Qnil;
}
|
11
|
+
|
12
|
+
/*
 * Memory view "get" callback for JFKReader.
 *
 * Opens the file named by the receiver's "audio_path" instance variable,
 * reads up to 176000 16-bit samples starting at byte offset 78, scales each
 * one by 1/32768 into a float and exposes the float buffer as a read-only,
 * one-dimensional view with format "f".
 * NOTE(review): offset 78 and the sample count presumably match the header
 * size and length of the bundled JFK sample - confirm.
 *
 * Returns false, leaving the view untouched, when the file cannot be opened
 * or positioned or when memory cannot be allocated. The float buffer stored
 * in view->data is heap-allocated with malloc.
 */
static bool
jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
{
  VALUE audio_path = rb_iv_get(obj, "audio_path");
  const char *audio_path_str = StringValueCStr(audio_path);
  const int n_samples = 176000;
  float *data = NULL;
  short *samples = NULL;
  size_t n_read = 0;
  FILE *file = fopen(audio_path_str, "rb");

  (void)flags;

  if (file == NULL) {
    return false;
  }

  data = (float *)malloc(n_samples * sizeof(float));
  samples = (short *)malloc(n_samples * sizeof(short));
  if (data == NULL || samples == NULL || fseek(file, 78, SEEK_SET) != 0) {
    /* free(NULL) is a no-op, so both pointers can be released unconditionally. */
    free(samples);
    free(data);
    fclose(file);
    return false;
  }

  n_read = fread(samples, sizeof(short), n_samples, file);
  fclose(file);

  for (size_t i = 0; i < (size_t)n_samples; i++) {
    /* Samples past a short read are zero-filled instead of left uninitialised. */
    data[i] = i < n_read ? samples[i]/32768.0 : 0.0;
  }
  /* The 16-bit staging buffer is no longer needed once converted. */
  free(samples);

  view->obj = obj;
  view->data = (void *)data;
  view->byte_size = sizeof(float) * n_samples;
  view->readonly = true;
  view->format = "f";
  view->item_size = sizeof(float);
  view->item_desc.components = NULL;
  view->item_desc.length = 0;
  view->ndim = 1;
  view->shape = NULL;
  view->strides = NULL;
  view->sub_offsets = NULL;
  view->private_data = NULL;

  return true;
}
|
44
|
+
|
45
|
+
/*
 * Memory view "release" callback for JFKReader.
 *
 * Frees the float buffer that the get callback allocated with malloc and
 * stored in view->data, then clears the pointer so a repeated release does
 * not free it twice. Always reports success.
 */
static bool
jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
{
  (void)obj;

  free(view->data);
  view->data = NULL;

  return true;
}
|
50
|
+
|
51
|
+
/*
 * Memory view "available?" callback for JFKReader.
 *
 * Unconditionally answers true, regardless of the receiver.
 */
static bool
jfk_reader_memory_view_available_p(const VALUE obj)
{
  (void)obj;

  return true;
}
|
56
|
+
|
57
|
+
/* Callback table registered for JFKReader: get, release and availability check. */
static const rb_memory_view_entry_t jfk_reader_view_entry = {
  .get_func         = jfk_reader_get_memory_view,
  .release_func     = jfk_reader_release_memory_view,
  .available_p_func = jfk_reader_memory_view_available_p,
};
|
62
|
+
|
63
|
+
/*
 * Extension entry point: defines the top-level JFKReader class, registers
 * its memory view callbacks and defines #initialize taking one argument.
 */
void
Init_jfk_reader(void)
{
  const VALUE klass = rb_define_class("JFKReader", rb_cObject);

  rb_memory_view_register(klass, &jfk_reader_view_entry);
  rb_define_method(klass, "initialize", jfk_reader_initialize, 1);
}
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
# Exercises the callback hooks of Whisper::Params (new-segment, progress and
# abort callbacks) by transcribing the sample audio with a "base.en" context.
class TestCallback < TestBase
  def setup
    GC.start
    @params = Whisper::Params.new
    @whisper = Whisper::Context.new("base.en")
    @audio = File.join(AUDIO)
  end

  def test_new_segment_callback
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_kind_of Integer, n_new
      assert n_new > 0
      assert_same @whisper, context

      n_segments = context.full_n_segments
      n_new.times do |i|
        # The newly added segments are the last n_new ones.
        i_segment = n_segments - n_new + i
        start_time = context.full_get_segment_t0(i_segment) * 10
        end_time = context.full_get_segment_t1(i_segment) * 10
        text = context.full_get_segment_text(i_segment)

        assert_kind_of Integer, start_time
        assert start_time >= 0
        assert_kind_of Integer, end_time
        assert end_time > 0
        if i_segment == 0
          assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
        end
      end
    }

    @whisper.transcribe(@audio, @params)
  end

  # An exception raised inside the callback must reach the caller of #transcribe.
  def test_new_segment_callback_closure
    search_word = "what"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      n_segments = context.full_n_segments
      n_new.times do |i|
        # The newly added segments are the last n_new ones.
        i_segment = n_segments - n_new + i
        text = context.full_get_segment_text(i_segment)
        if text.include?(search_word)
          t0 = context.full_get_segment_t0(i_segment)
          t1 = context.full_get_segment_t1(i_segment)
          raise "search word '#{search_word}' found at between #{t0} and #{t1}"
        end
      end
    }

    assert_raise RuntimeError do
      @whisper.transcribe(@audio, @params)
    end
  end

  def test_new_segment_callback_user_data
    udata = Object.new
    @params.new_segment_callback_user_data = udata
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  # The user data must still be delivered after a GC run between setup and use.
  def test_new_segment_callback_user_data_gc
    @params.new_segment_callback_user_data = "My user data"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_equal "My user data", user_data
    }
    GC.start

    assert_same @whisper, @whisper.transcribe(@audio, @params)
  end

  def test_progress_callback
    first = nil
    last = nil
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      assert_same @whisper, context
      first = progress if first.nil?
      last = progress
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_progress_callback_user_data
    udata = Object.new
    @params.progress_callback_user_data = udata
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  def test_on_progress
    first = nil
    last = nil
    @params.on_progress do |progress|
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      first = progress if first.nil?
      last = progress
    end
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  def test_abort_callback
    i = 0
    @params.abort_callback = ->(user_data) {
      assert_nil user_data
      i += 1
      false
    }
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end

  def test_abort_callback_abort
    i = 0
    @params.abort_callback = ->(user_data) {
      i += 1
      i == 3
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 3, i
  end

  def test_abort_callback_user_data
    udata = Object.new
    @params.abort_callback_user_data = udata
    yielded = nil
    @params.abort_callback = ->(user_data) {
      yielded = user_data
    }
    @whisper.transcribe(@audio, @params)
    assert_same udata, yielded
  end

  def test_abort_on
    do_abort = false
    @params.on_new_segment do |segment|
      do_abort = true if segment.text.match?(/ask/)
    end
    i = 0
    @params.abort_on do
      i += 1
      do_abort
    end
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end
end
|
data/tests/test_error.rb
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
# Checks how Whisper::Error maps integer codes to messages and that it
# rejects non-integer codes.
class TestError < TestBase
  def test_error
    known = Whisper::Error.new(-2)

    assert_equal "failed to compute log mel spectrogram", known.message
    assert_equal(-2, known.code)
  end

  def test_unknown_error
    unknown = Whisper::Error.new(-20)

    assert_equal "unknown error", unknown.message
  end

  def test_non_int_code
    assert_raise(TypeError) { Whisper::Error.new("non int") }
  end
end
|
data/tests/test_model.rb
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
require "pathname"
|
3
|
+
|
4
|
+
# Checks the attributes exposed by Whisper::Model for the "base.en" model and
# the on-demand download of pre-converted models.
class TestModel < TestBase
  def test_model
    whisper = Whisper::Context.new("base.en")
    assert_instance_of Whisper::Model, whisper.model
  end

  def test_attributes
    whisper = Whisper::Context.new("base.en")

    assert_base_en_attributes whisper.model
  end

  # The model must stay usable after its context is no longer referenced and
  # a GC run has happened.
  def test_gc
    model = Whisper::Context.new("base.en").model
    GC.start

    assert_base_en_attributes model
  end

  # A context can be created from a Pathname as well as from a model name.
  def test_pathname
    path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
    whisper = Whisper::Context.new(path)

    assert_base_en_attributes whisper.model
  end

  def test_auto_download
    path = Whisper::Model.pre_converted_models["base.en"].to_path

    assert_path_exist path
    assert_equal 147964211, File.size(path)
  end

  private

  # Asserts the attribute values these tests expect for the "base.en" model.
  def assert_base_en_attributes(model)
    assert_equal 51864, model.n_vocab
    assert_equal 1500, model.n_audio_ctx
    assert_equal 512, model.n_audio_state
    assert_equal 8, model.n_audio_head
    assert_equal 6, model.n_audio_layer
    assert_equal 448, model.n_text_ctx
    assert_equal 512, model.n_text_state
    assert_equal 8, model.n_text_head
    assert_equal 6, model.n_text_layer
    assert_equal 80, model.n_mels
    assert_equal 1, model.ftype
    assert_equal "base", model.type
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
require 'tempfile'
|
3
|
+
require 'tmpdir'
|
4
|
+
require 'shellwords'
|
5
|
+
|
6
|
+
# Checks that the gem can be built, and that installing the built gem
# compiles the native extension into the installed gem's lib directory.
class TestPackage < TestBase
  def test_build
    Tempfile.create do |file|
      # Arguments are passed to `gem` directly (no shell is involved when
      # system receives several arguments), so the path must not be
      # shell-escaped.
      assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true)
      assert file.size > 0
      assert_path_exist file.to_path
    end
  end

  sub_test_case "Building binary on installation" do
    def setup
      system "rake", "build", exception: true
    end

    def test_install
      match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
      assert_not_nil match_data, "gem file name not found in `rake -Tbuild` output"
      filename = match_data[1]
      version = match_data[2]
      basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
      Dir.mktmpdir do |dir|
        # No shell here either: pass the raw directory and file names.
        system "gem", "install", "--install-dir", dir, "pkg/#{filename}", exception: true
        assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
      end
    end
  end
end
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require_relative "helper"
|
2
|
+
|
3
|
+
# Round-trips every Whisper::Params attribute through its writer and reader,
# and checks the default values of the numeric tuning parameters.
class TestParams < TestBase
  def setup
    @params = Whisper::Params.new
  end

  def test_language
    @params.language = "en"
    assert_equal "en", @params.language
    @params.language = "auto"
    assert_equal "auto", @params.language
  end

  def test_offset
    @params.offset = 10_000
    assert_equal 10_000, @params.offset
    @params.offset = 0
    assert_equal 0, @params.offset
  end

  def test_duration
    @params.duration = 60_000
    assert_equal 60_000, @params.duration
    @params.duration = 0
    assert_equal 0, @params.duration
  end

  def test_max_text_tokens
    @params.max_text_tokens = 300
    assert_equal 300, @params.max_text_tokens
    @params.max_text_tokens = 0
    assert_equal 0, @params.max_text_tokens
  end

  def test_translate
    @params.translate = true
    assert @params.translate
    @params.translate = false
    assert !@params.translate
  end

  def test_no_context
    @params.no_context = true
    assert @params.no_context
    @params.no_context = false
    assert !@params.no_context
  end

  def test_single_segment
    @params.single_segment = true
    assert @params.single_segment
    @params.single_segment = false
    assert !@params.single_segment
  end

  def test_print_special
    @params.print_special = true
    assert @params.print_special
    @params.print_special = false
    assert !@params.print_special
  end

  def test_print_progress
    @params.print_progress = true
    assert @params.print_progress
    @params.print_progress = false
    assert !@params.print_progress
  end

  def test_print_realtime
    @params.print_realtime = true
    assert @params.print_realtime
    @params.print_realtime = false
    assert !@params.print_realtime
  end

  def test_print_timestamps
    @params.print_timestamps = true
    assert @params.print_timestamps
    @params.print_timestamps = false
    assert !@params.print_timestamps
  end

  def test_suppress_blank
    @params.suppress_blank = true
    assert @params.suppress_blank
    @params.suppress_blank = false
    assert !@params.suppress_blank
  end

  def test_suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = true
    assert @params.suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = false
    assert !@params.suppress_non_speech_tokens
  end

  def test_token_timestamps
    @params.token_timestamps = true
    assert @params.token_timestamps
    @params.token_timestamps = false
    assert !@params.token_timestamps
  end

  def test_split_on_word
    @params.split_on_word = true
    assert @params.split_on_word
    @params.split_on_word = false
    assert !@params.split_on_word
  end

  def test_initial_prompt
    assert_nil @params.initial_prompt
    @params.initial_prompt = "You are a polite person."
    assert_equal "You are a polite person.", @params.initial_prompt
  end

  def test_temperature
    assert_equal 0.0, @params.temperature
    @params.temperature = 0.5
    assert_equal 0.5, @params.temperature
  end

  def test_max_initial_ts
    assert_equal 1.0, @params.max_initial_ts
    @params.max_initial_ts = 600.0
    assert_equal 600.0, @params.max_initial_ts
  end

  def test_length_penalty
    assert_equal(-1.0, @params.length_penalty)
    @params.length_penalty = 0.5
    assert_equal 0.5, @params.length_penalty
  end

  def test_temperature_inc
    assert_in_delta 0.2, @params.temperature_inc
    @params.temperature_inc = 0.5
    assert_in_delta 0.5, @params.temperature_inc
  end

  def test_entropy_thold
    assert_in_delta 2.4, @params.entropy_thold
    @params.entropy_thold = 3.0
    assert_in_delta 3.0, @params.entropy_thold
  end

  def test_logprob_thold
    assert_in_delta(-1.0, @params.logprob_thold)
    @params.logprob_thold = -0.5
    assert_in_delta(-0.5, @params.logprob_thold)
  end

  def test_no_speech_thold
    assert_in_delta 0.6, @params.no_speech_thold
    @params.no_speech_thold = 0.2
    assert_in_delta 0.2, @params.no_speech_thold
  end
end
|