informers 0.2.0 → 1.0.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +63 -99
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +28 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
@@ -0,0 +1,240 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
module Hub
|
4
|
+
# Read-only view of a file on disk that answers `status` and `read`,
# the two calls the download code makes on a response object.
class FileResponse
  attr_reader :exists, :status

  # file_path - location of the file. It is only probed for existence here;
  #             nothing is opened until `read` is called.
  def initialize(file_path)
    @file_path = file_path
    @exists = File.exist?(file_path)
    # HTTP-style [code, message] pair so callers can test `status[0]`.
    @status = @exists ? ["200", "OK"] : ["404", "Not Found"]
  end

  # Returns the entire file contents as a binary string.
  def read
    File.binread(@file_path)
  end
end
|
22
|
+
|
23
|
+
# Reports whether `string` parses as a URI and, when the optional allow-lists
# are supplied, whether its scheme and host appear in them.
#
# string      - candidate URL
# protocols   - collection of accepted schemes (e.g. ["http", "https"]); not checked when nil
# valid_hosts - collection of accepted host names; not checked when nil
#
# Returns true or false. Input that URI.parse rejects yields false.
def self.is_valid_url(string, protocols = nil, valid_hosts = nil)
  uri =
    begin
      URI.parse(string)
    rescue
      return false
    end

  scheme_rejected = protocols && !protocols.include?(uri.scheme)
  return false if scheme_rejected

  host_rejected = valid_hosts && !valid_hosts.include?(uri.host)
  return false if host_rejected

  true
end
|
37
|
+
|
38
|
+
# Opens an http(s) URL and returns whatever `URI#open` produces for it.
#
# url_or_path       - URL to fetch; anything that is not an http/https URL is rejected
# progress_callback - optional callable; when given it is notified (through
#                     Utils.dispatch_callback) with a "download" event carrying the
#                     reported total size and with "progress" events as data arrives
# progress_info     - extra key/value pairs merged into every progress event
#
# Raises Error ("Invalid url") when the URL fails validation.
def self.get_file(url_or_path, progress_callback = nil, progress_info = {})
  raise Error, "Invalid url" unless is_valid_url(url_or_path, ["http", "https"])

  request_headers = {"User-Agent" => "informers/#{VERSION};"}

  # The access token from the environment is only sent to Hugging Face Hub hosts.
  hub_request = is_valid_url(url_or_path, ["http", "https"], ["huggingface.co", "hf.co"])
  if hub_request
    hf_token = ENV["HF_TOKEN"]
    request_headers["Authorization"] = "Bearer #{hf_token}" if hf_token
  end

  open_options = {}
  if progress_callback
    # Captured by both lambdas so progress events can report the total as well.
    expected_bytes = nil
    open_options[:content_length_proc] = lambda do |size|
      expected_bytes = size
      Utils.dispatch_callback(progress_callback, {status: "download"}.merge(progress_info).merge(total_size: size))
    end
    open_options[:progress_proc] = lambda do |size|
      Utils.dispatch_callback(progress_callback, {status: "progress"}.merge(progress_info).merge(size: size, total_size: expected_bytes))
    end
  end

  URI.parse(url_or_path).open(**request_headers, **open_options)
end
|
69
|
+
|
70
|
+
# Filesystem-backed cache that maps request keys (relative paths) to files
# stored underneath a root directory.
class FileCache
  attr_reader :path

  # path - root directory of the cache
  def initialize(path)
    @path = path
  end

  # Looks up `request` in the cache.
  #
  # Returns a FileResponse when the file exists on disk, nil otherwise.
  def match(request)
    file_path = resolve_path(request)
    file = FileResponse.new(file_path)

    file if file.exists
  end

  # Stores `buffer` under `request`, creating parent directories as needed.
  #
  # The data is first written to a temporary sibling file and then renamed
  # into place, so a write that fails part-way (for example when the disk is
  # full) never leaves a truncated file that `match` would later serve.
  #
  # Caching is best effort: failures are reported with `warn` rather than
  # raised. Returns the number of bytes written, or nil on failure.
  def put(request, buffer)
    output_path = resolve_path(request)
    tmp_path = "#{output_path}.#{Process.pid}.tmp"

    begin
      FileUtils.mkdir_p(File.dirname(output_path))
      written = File.binwrite(tmp_path, buffer)
      File.rename(tmp_path, output_path)
      written
    rescue => e
      discard(tmp_path)
      warn "An error occurred while writing the file to cache: #{e}"
    end
  end

  # Returns the on-disk location for `request` inside the cache directory.
  def resolve_path(request)
    File.join(@path, request)
  end

  private

  # Removes a leftover temporary file; errors are ignored because this is
  # only cleanup after an already-reported failure.
  def discard(tmp_path)
    File.delete(tmp_path) if File.exist?(tmp_path)
  rescue SystemCallError
    nil
  end
end
|
99
|
+
|
100
|
+
# Returns the first cache entry found for any of `names`, or nil when none
# of them is cached.
#
# cache - object responding to `match(name)`
# names - candidate cache keys, tried in the order given
#
# A lookup that raises is treated as a miss and the next name is tried.
def self.try_cache(cache, *names)
  names.each do |candidate|
    found =
      begin
        cache.match(candidate)
      rescue
        nil
      end
    return found if found
  end

  nil
end
|
111
|
+
|
112
|
+
# Makes sure a model file is available in the local cache and returns the
# path it is (or will be) stored at.
#
# path_or_repo_id - model identifier; substituted for "{model}" in the remote path template
# filename        - file to fetch, relative to the model
# fatal           - when true, a file that may not be fetched raises; when false, nil is returned
# options         - :progress_callback (callable notified with "initiate" and "done" events,
#                   also handed to get_file), :cache_dir (defaults to Informers.cache_dir),
#                   :revision (defaults to "main"), :local_files_only
#
# Returns the resolved cache path, or nil when the file is not cached, remote
# access is not allowed and `fatal` is false.
# Raises Error when the file is not cached, remote access is not allowed and
# `fatal` is true. Raises Todo when the download status is not "200".
def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)
  # Initiate file retrieval
  Utils.dispatch_callback(options[:progress_callback], {
    status: "initiate",
    name: path_or_repo_id,
    file: filename
  })

  # If `cache_dir` is not specified, use the default cache directory
  cache = FileCache.new(options[:cache_dir] || Informers.cache_dir)

  revision = options[:revision] || "main"

  request_url = path_join(path_or_repo_id, filename)

  remote_url = path_join(
    Informers.remote_host,
    Informers.remote_path_template
      .gsub("{model}", path_or_repo_id)
      .gsub("{revision}", URI.encode_www_form_component(revision)),
    filename
  )

  # Choose cache key for filesystem cache
  # When using the main revision (default), we use the request URL as the cache key.
  # If a specific revision is requested, we account for this in the cache key.
  cache_key = revision == "main" ? request_url : path_join(path_or_repo_id, revision, filename)

  resolved_path = cache.resolve_path(cache_key)

  # Try the cache first. On a hit the file is already where the caller will
  # look for it, so its contents are deliberately NOT read into memory here.
  response = try_cache(cache, cache_key)
  cache_hit = !response.nil?

  if response.nil?
    if options[:local_files_only] || !Informers.allow_remote_models
      # User requested local files only, but the file is not found locally.
      if fatal
        raise Error, "`local_files_only: true` or `Informers.allow_remote_models = false` and file was not found locally at #{resolved_path.inspect}."
      end

      # File not found, but this file is optional.
      # TODO in future, cache the response?
      return nil
    end

    progress_info = {
      name: path_or_repo_id,
      file: filename
    }

    # File not found locally, so we try to download it from the remote server
    response = get_file(remote_url, options[:progress_callback], progress_info)

    begin
      if response.status[0] != "200"
        # should not happen
        raise Todo
      end

      # Store the download unless another writer populated the cache meanwhile.
      if !response.is_a?(FileResponse) && cache.match(cache_key).nil?
        cache.put(cache_key, response.read)
      end
    ensure
      # Release the handle behind the downloaded response as soon as we are done with it.
      response.close if response.respond_to?(:close)
    end
  end

  Utils.dispatch_callback(options[:progress_callback], {
    status: "done",
    name: path_or_repo_id,
    file: filename,
    cache_hit: cache_hit
  })

  resolved_path
end
|
207
|
+
|
208
|
+
# Fetches a model file through get_model_file and parses it as JSON.
#
# model_path, file_name, fatal and options are passed to get_model_file unchanged.
#
# Returns the parsed JSON, or an empty Hash when get_model_file returns nil.
def self.get_model_json(model_path, file_name, fatal = true, **options)
  # get_model_file returns the location of the file, not its contents.
  json_path = get_model_file(model_path, file_name, fatal, **options)

  json_path.nil? ? {} : JSON.load_file(json_path)
end
|
217
|
+
|
218
|
+
# Joins path segments with "/" without doubling the separator at the seams.
#
# One leading "/" is dropped from every segment except the first, and one
# trailing "/" from every segment except the last, before joining.
def self.path_join(*parts)
  final_position = parts.length - 1

  trimmed = parts.each_with_index.map do |segment, position|
    segment = segment.delete_prefix("/") unless position == 0
    segment = segment.delete_suffix("/") unless position == final_position
    segment
  end

  trimmed.join("/")
end
|
230
|
+
|
231
|
+
# Renders a one-line text progress bar: "<filename> |███   |".
#
# filename      - label printed before the bar
# width         - total width available; the bar gets what is left after the
#                 label and the three decoration characters
# size          - amount completed so far
# expected_size - total amount expected; may be nil or 0 when unknown
#
# The completed fraction is clamped to 0..1 and the bar width to >= 0, so an
# unknown total, a size that overshoots the total, or a width smaller than the
# label renders an empty/full/zero-width bar instead of raising.
def self.display_progress(filename, width, size, expected_size)
  bar_width = [width - (filename.length + 3), 0].max
  total = expected_size.to_f
  progress = total.positive? ? (size.to_f / total).clamp(0.0, 1.0) : 0.0
  done = (progress * bar_width).round
  not_done = bar_width - done
  "#{filename} |#{"█" * done}#{" " * not_done}|"
end
|
238
|
+
end
|
239
|
+
end
|
240
|
+
end
|
@@ -0,0 +1,44 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Softmax over an array of numbers.
#
# The largest value is subtracted from every element before exponentiating,
# so the largest exponent is always exp(0) = 1.
#
# Returns an array of floats of the same length as `arr`.
def self.softmax(arr)
  peak = arr.max
  weights = arr.map { |logit| Math.exp(logit - peak) }
  total = weights.sum

  weights.map { |weight| weight / total }
end
|
18
|
+
|
19
|
+
# Applies the logistic function 1 / (1 + e^-v) to every element of `arr`.
#
# Returns an array of floats of the same length as `arr`.
def self.sigmoid(arr)
  arr.map do |value|
    denominator = 1 + Math.exp(-value)
    1 / denominator
  end
end
|
22
|
+
|
23
|
+
# Ranks scores from highest to lowest, remembering each score's original index.
#
# items - array of numeric scores
# top_k - number of entries to keep; nil or a value <= 0 keeps all of them
#
# Returns an array of [index, score] pairs sorted by descending score.
def self.get_top_items(items, top_k = 0)
  indexed = items.each_with_index.map { |score, index| [index, score] }
  ranked = indexed.sort_by { |pair| -pair[1] }

  limited = !top_k.nil? && top_k > 0
  limited ? ranked.slice(0, top_k) : ranked
end
|
36
|
+
|
37
|
+
# Finds the largest element of `arr` together with its position.
#
# Returns a two-element array [value, index].
# Raises Error when `arr` is empty.
def self.max(arr)
  raise Error, "Array must not be empty" if arr.length == 0

  arr.each_with_index.max_by { |value, _index| value }
end
|
43
|
+
end
|
44
|
+
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Mask-weighted average of the per-token vectors of each sequence.
#
# last_hidden_state - nested arrays indexed as [sequence][token][dimension]
# attention_mask    - nested arrays indexed as [sequence][token]; each entry is
#                     the weight of that token (assumed to be 0 or 1 — confirm with callers)
#
# Returns one averaged vector per sequence.
def self.mean_pooling(last_hidden_state, attention_mask)
  last_hidden_state.zip(attention_mask).map do |token_vectors, token_weights|
    # The first token's vector determines how many dimensions are pooled.
    dimensions = token_vectors[0].size

    Array.new(dimensions) do |dim|
      weighted_sum = 0.0
      weight_total = 0

      token_vectors.zip(token_weights).each do |vector, weight|
        weight_total += weight
        weighted_sum += vector[dim] * weight
      end

      weighted_sum / weight_total
    end
  end
end
|
18
|
+
|
19
|
+
# Scales every row of `result` to unit Euclidean length.
#
# result - array of numeric arrays
#
# Returns a new array of rows, each divided element-wise by its L2 norm.
def self.normalize(result)
  result.map do |vector|
    squared_total = vector.sum { |component| component * component }
    length = Math.sqrt(squared_total)

    vector.map { |component| component / length }
  end
end
|
25
|
+
end
|
26
|
+
end
|
data/lib/informers/version.rb
CHANGED
data/lib/informers.rb
CHANGED
@@ -1,13 +1,32 @@
|
|
1
1
|
# dependencies
|
2
|
-
require "blingfire"
|
3
|
-
require "numo/narray"
|
4
2
|
require "onnxruntime"
|
3
|
+
require "tokenizers"
|
4
|
+
|
5
|
+
# stdlib
|
6
|
+
require "io/console"
|
7
|
+
require "json"
|
8
|
+
require "open-uri"
|
9
|
+
require "stringio"
|
10
|
+
require "uri"
|
5
11
|
|
6
12
|
# modules
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
13
|
+
require_relative "informers/utils/core"
|
14
|
+
require_relative "informers/utils/hub"
|
15
|
+
require_relative "informers/utils/math"
|
16
|
+
require_relative "informers/utils/tensor"
|
17
|
+
require_relative "informers/configs"
|
18
|
+
require_relative "informers/env"
|
19
|
+
require_relative "informers/model"
|
20
|
+
require_relative "informers/models"
|
21
|
+
require_relative "informers/tokenizers"
|
22
|
+
require_relative "informers/pipelines"
|
23
|
+
|
24
|
+
module Informers
  # Base class for errors raised by this library.
  Error = Class.new(StandardError)

  # Raised on code paths that are not implemented yet. The message is fixed,
  # regardless of what is passed when raising.
  class Todo < Error
    def message = "not implemented yet"
  end
end
|
metadata
CHANGED
@@ -1,57 +1,43 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: informers
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 1.0.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
|
-
name:
|
15
|
-
requirement: !ruby/object:Gem::Requirement
|
16
|
-
requirements:
|
17
|
-
- - ">="
|
18
|
-
- !ruby/object:Gem::Version
|
19
|
-
version: 0.1.7
|
20
|
-
type: :runtime
|
21
|
-
prerelease: false
|
22
|
-
version_requirements: !ruby/object:Gem::Requirement
|
23
|
-
requirements:
|
24
|
-
- - ">="
|
25
|
-
- !ruby/object:Gem::Version
|
26
|
-
version: 0.1.7
|
27
|
-
- !ruby/object:Gem::Dependency
|
28
|
-
name: numo-narray
|
14
|
+
name: onnxruntime
|
29
15
|
requirement: !ruby/object:Gem::Requirement
|
30
16
|
requirements:
|
31
17
|
- - ">="
|
32
18
|
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
19
|
+
version: '0.9'
|
34
20
|
type: :runtime
|
35
21
|
prerelease: false
|
36
22
|
version_requirements: !ruby/object:Gem::Requirement
|
37
23
|
requirements:
|
38
24
|
- - ">="
|
39
25
|
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
26
|
+
version: '0.9'
|
41
27
|
- !ruby/object:Gem::Dependency
|
42
|
-
name:
|
28
|
+
name: tokenizers
|
43
29
|
requirement: !ruby/object:Gem::Requirement
|
44
30
|
requirements:
|
45
31
|
- - ">="
|
46
32
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.5.
|
33
|
+
version: 0.5.2
|
48
34
|
type: :runtime
|
49
35
|
prerelease: false
|
50
36
|
version_requirements: !ruby/object:Gem::Requirement
|
51
37
|
requirements:
|
52
38
|
- - ">="
|
53
39
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.5.
|
40
|
+
version: 0.5.2
|
55
41
|
description:
|
56
42
|
email: andrew@ankane.org
|
57
43
|
executables: []
|
@@ -62,23 +48,17 @@ files:
|
|
62
48
|
- LICENSE.txt
|
63
49
|
- README.md
|
64
50
|
- lib/informers.rb
|
65
|
-
- lib/informers/
|
66
|
-
- lib/informers/
|
67
|
-
- lib/informers/
|
68
|
-
- lib/informers/
|
69
|
-
- lib/informers/
|
70
|
-
- lib/informers/
|
51
|
+
- lib/informers/configs.rb
|
52
|
+
- lib/informers/env.rb
|
53
|
+
- lib/informers/model.rb
|
54
|
+
- lib/informers/models.rb
|
55
|
+
- lib/informers/pipelines.rb
|
56
|
+
- lib/informers/tokenizers.rb
|
57
|
+
- lib/informers/utils/core.rb
|
58
|
+
- lib/informers/utils/hub.rb
|
59
|
+
- lib/informers/utils/math.rb
|
60
|
+
- lib/informers/utils/tensor.rb
|
71
61
|
- lib/informers/version.rb
|
72
|
-
- vendor/LICENSE-bert.txt
|
73
|
-
- vendor/LICENSE-blingfire.txt
|
74
|
-
- vendor/LICENSE-gpt2.txt
|
75
|
-
- vendor/LICENSE-roberta.txt
|
76
|
-
- vendor/bert_base_cased_tok.bin
|
77
|
-
- vendor/bert_base_tok.bin
|
78
|
-
- vendor/gpt2.bin
|
79
|
-
- vendor/gpt2.i2w
|
80
|
-
- vendor/roberta.bin
|
81
|
-
- vendor/roberta.i2w
|
82
62
|
homepage: https://github.com/ankane/informers
|
83
63
|
licenses:
|
84
64
|
- Apache-2.0
|
@@ -91,15 +71,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
91
71
|
requirements:
|
92
72
|
- - ">="
|
93
73
|
- !ruby/object:Gem::Version
|
94
|
-
version: '
|
74
|
+
version: '3.1'
|
95
75
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
96
76
|
requirements:
|
97
77
|
- - ">="
|
98
78
|
- !ruby/object:Gem::Version
|
99
79
|
version: '0'
|
100
80
|
requirements: []
|
101
|
-
rubygems_version: 3.
|
81
|
+
rubygems_version: 3.5.11
|
102
82
|
signing_key:
|
103
83
|
specification_version: 4
|
104
|
-
summary:
|
84
|
+
summary: Fast transformer inference for Ruby
|
105
85
|
test_files: []
|
@@ -1,59 +0,0 @@
|
|
1
|
-
# Copyright 2018 The HuggingFace Inc. team.
|
2
|
-
# Copyright 2020 Andrew Kane.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
module Informers
  # Runs an ONNX model over tokenized text and returns its per-token output.
  class FeatureExtraction
    # model_path - path to the ONNX model file
    def initialize(model_path)
      @tokenizer = BlingFire.load_model(
        File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
      )
      @model = OnnxRuntime::Model.new(model_path)
    end

    # texts - a single string or an array of strings
    #
    # Returns the model's "output_0" (or, failing that, "last_hidden_state")
    # output: one entry per text for an array input, a single entry otherwise.
    def predict(texts)
      batched = texts.is_a?(Array)
      batch = batched ? texts : [texts]

      # tokenize, wrapping every sequence in the cls (101) and sep (102) ids;
      # 100 is passed as the unk id
      input_ids = batch.map do |text|
        [101, *@tokenizer.text_to_ids(text, nil, 100), 102]
      end

      # right-pad every sequence with 0 to the longest one; the mask marks real tokens
      longest = input_ids.map(&:size).max
      attention_mask = input_ids.map do |ids|
        real_count = ids.size
        padding = [0] * (longest - real_count)
        ids.concat(padding)
        ([1] * real_count) + padding
      end

      # infer
      output = @model.predict({input_ids: input_ids, attention_mask: attention_mask})
      scores = output["output_0"] || output["last_hidden_state"]

      batched ? scores : scores.first
    end
  end
end
|
data/lib/informers/fill_mask.rb
DELETED
@@ -1,109 +0,0 @@
|
|
1
|
-
# Copyright 2018 The HuggingFace Inc. team.
|
2
|
-
# Copyright 2021 Andrew Kane.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
module Informers
  # Predicts replacements for a single <mask> token in each input text using
  # an ONNX masked-language model.
  class FillMask
    # model_path - path to the ONNX model file
    def initialize(model_path)
      encoder_path = File.expand_path("../../vendor/roberta.bin", __dir__)
      @encoder = BlingFire.load_model(encoder_path, prefix: false)

      decoder_path = File.expand_path("../../vendor/roberta.i2w", __dir__)
      @decoder = BlingFire.load_model(decoder_path)

      @model = OnnxRuntime::Model.new(model_path)
    end

    # texts - a single string or an array of strings, each containing exactly one mask
    #
    # Returns, per text, up to five candidates ordered by descending score; each
    # is a Hash with :sequence, :score, :token and :token_str. A single string
    # input yields one such list, an array input yields a list of lists.
    # Raises RuntimeError when a text contains no mask or more than one.
    def predict(texts)
      singular = !texts.is_a?(Array)
      texts = [texts] if singular

      mask_token = 50264

      # tokenize
      input_ids =
        texts.map do |text|
          tokens = @encoder.text_to_ids(text, nil, 3) # unk token

          # Collapse every occurrence of this three-id run into the single mask id
          # (presumably how the encoder splits the literal "<mask>" — confirm).
          mask_sequence = [28696, 43776, 15698]
          masks = []
          (tokens.size - 2).times do |i|
            masks << i if tokens[i..(i + 2)] == mask_sequence
          end
          # Replace back to front so earlier positions stay valid.
          masks.reverse_each do |mask|
            tokens = tokens[0...mask] + [mask_token] + tokens[(mask + 3)..-1]
          end

          tokens.unshift(0) # cls token
          tokens << 2 # sep token

          tokens
        end

      # Right-pad with 0 to the longest sequence; the mask marks real tokens.
      max_tokens = input_ids.map(&:size).max
      attention_mask = []
      input_ids.each do |ids|
        zeros = [0] * (max_tokens - ids.size)

        attention_mask << ([1] * ids.size) + zeros

        ids.concat(zeros)
      end

      input = {
        input_ids: input_ids,
        attention_mask: attention_mask
      }

      masked_index = input_ids.map { |v| v.each_index.select { |i| v[i] == mask_token } }
      masked_index.each do |v|
        raise "No mask_token (<mask>) found on the input" if v.size < 1
        raise "More than one mask_token (<mask>) is not supported" if v.size > 1
      end

      output = @model.predict(input)
      outputs = output["output_0"] || output["logits"]

      results =
        outputs.size.times.map do |i|
          position = masked_index[i][0]
          logits = outputs[i][position]

          # Softmax with the maximum subtracted first: Math.exp of a large raw
          # logit overflows to Infinity, which turns the scores into NaN and
          # makes the sort below fail.
          max_logit = logits.max
          values = logits.map { |v| Math.exp(v - max_logit) }
          sum = values.sum
          probs = values.map { |v| v / sum }
          best = probs.each_with_index.sort_by { |v| -v[0] }.first(5)

          best.map do |(score, token)|
            tokens = input[:input_ids][i].dup
            tokens[position] = token
            {
              sequence: @decoder.ids_to_text(tokens),
              score: score,
              token: token,
              # TODO figure out prefix space
              token_str: @decoder.ids_to_text([token], skip_special_tokens: false)
            }
          end
        end

      singular ? results.first : results
    end
  end
end
|