informers 0.2.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +63 -99
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +28 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
@@ -0,0 +1,240 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
module Hub
|
4
|
+
# Minimal response-like wrapper around a file on disk, exposing an HTTP-style status.
class FileResponse
  attr_reader :exists, :status

  # file_path - path of the file this response represents.
  def initialize(file_path)
    @file_path = file_path
    @exists = File.exist?(@file_path)
    # mimic an HTTP [code, reason] pair so callers can treat it like a download
    @status = @exists ? ["200", "OK"] : ["404", "Not Found"]
  end

  # Returns the file contents as a binary String.
  def read
    File.binread(@file_path)
  end
end
|
22
|
+
|
23
|
+
# Returns true when +string+ parses as a URI whose scheme is in +protocols+ and whose
# host is in +valid_hosts+; a nil list skips that check. A string that cannot be
# parsed yields false instead of raising.
def self.is_valid_url(string, protocols = nil, valid_hosts = nil)
  parsed =
    begin
      URI.parse(string)
    rescue
      nil
    end
  return false if parsed.nil?

  scheme_ok = !protocols || protocols.include?(parsed.scheme)
  scheme_ok && (!valid_hosts || valid_hosts.include?(parsed.host))
end
|
37
|
+
|
38
|
+
# Opens +url_or_path+ (which must be an http/https URL) with open-uri and returns
# the resulting IO. A User-Agent header is always sent. For Hugging Face Hub hosts,
# the HF_TOKEN environment variable, when set, is sent as a Bearer token.
#
# progress_callback - optional callable; receives "download" (once the content length
#                     is known) and "progress" events merged with +progress_info+.
# Raises Error for anything that is not an http/https URL.
def self.get_file(url_or_path, progress_callback = nil, progress_info = {})
  raise Error, "Invalid url" unless is_valid_url(url_or_path, ["http", "https"])

  headers = {"User-Agent" => "informers/#{VERSION};"}

  # Only requests aimed at the Hugging Face Hub carry the access token
  if is_valid_url(url_or_path, ["http", "https"], ["huggingface.co", "hf.co"])
    hf_token = ENV["HF_TOKEN"]
    headers["Authorization"] = "Bearer #{hf_token}" if hf_token
  end

  open_options = {}
  if progress_callback
    expected_size = nil
    open_options[:content_length_proc] = lambda do |size|
      expected_size = size
      Utils.dispatch_callback(progress_callback, {status: "download"}.merge(progress_info).merge(total_size: size))
    end
    open_options[:progress_proc] = lambda do |size|
      Utils.dispatch_callback(progress_callback, {status: "progress"}.merge(progress_info).merge(size: size, total_size: expected_size))
    end
  end

  URI.parse(url_or_path).open(**headers, **open_options)
end
|
69
|
+
|
70
|
+
# File-system cache that maps request keys (relative paths) to files under a root directory.
class FileCache
  attr_reader :path

  # path - root directory under which cached files are stored.
  def initialize(path)
    @path = path
  end

  # Returns a FileResponse for the cached file at +request+, or nil when it is not cached.
  def match(request)
    file_path = resolve_path(request)
    file = FileResponse.new(file_path)

    file if file.exists
  end

  # Stores +buffer+ at the cache location for +request+, creating parent directories.
  #
  # The bytes are written to a temporary sibling file and renamed into place only
  # after the write succeeds. A failed or interrupted write (e.g. disk full) therefore
  # never leaves a truncated file at the final path, which #match would otherwise
  # report as a cache hit. Failures are reported with a warning, not raised.
  def put(request, buffer)
    output_path = resolve_path(request)
    # per-process name so concurrent writers do not clobber each other's partial file
    temp_path = "#{output_path}.#{Process.pid}.tmp"

    begin
      FileUtils.mkdir_p(File.dirname(output_path))
      File.binwrite(temp_path, buffer)
      File.rename(temp_path, output_path)
    rescue => e
      FileUtils.rm_f(temp_path)
      warn "An error occurred while writing the file to cache: #{e}"
    end
  end

  # Maps a cache key such as "org/model/config.json" to a path under the cache root.
  def resolve_path(request)
    File.join(@path, request)
  end
end
|
99
|
+
|
100
|
+
# Looks up each of +names+ in +cache+ in order and returns the first hit.
# A lookup that raises counts as a miss. Returns nil when nothing matches.
def self.try_cache(cache, *names)
  names.each do |candidate|
    hit =
      begin
        cache.match(candidate)
      rescue
        nil
      end
    return hit if hit
  end

  nil
end
|
111
|
+
|
112
|
+
# Resolves +filename+ of +path_or_repo_id+ to a path in the local file cache.
# When the file is not already cached, it is first downloaded from the remote host.
#
# fatal   - when false, a missing file returns nil instead of raising. "Missing" means
#           either not cached while remote access is disabled, or answered with HTTP 404.
# options - :progress_callback, :cache_dir, :revision (default "main"), :local_files_only.
#
# Returns the cached file path, or nil for a missing optional file.
# Raises Error when the file is not cached and remote access is disabled (fatal only).
# Re-raises OpenURI::HTTPError for failed downloads other than an optional 404.
def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)
  progress_callback = options[:progress_callback]

  # Initiate file retrieval
  Utils.dispatch_callback(progress_callback, {
    status: "initiate",
    name: path_or_repo_id,
    file: filename
  })

  # If `cache_dir` is not specified, use the default cache directory
  cache = FileCache.new(options[:cache_dir] || Informers.cache_dir)

  revision = options[:revision] || "main"

  request_url = path_join(path_or_repo_id, filename)

  remote_url = path_join(
    Informers.remote_host,
    Informers.remote_path_template
      .gsub("{model}", path_or_repo_id)
      .gsub("{revision}", URI.encode_www_form_component(revision)),
    filename
  )

  # Choose cache key for filesystem cache.
  # For the main revision (default) the request URL is the key;
  # a specific revision is made part of the key.
  cache_key = revision == "main" ? request_url : path_join(path_or_repo_id, revision, filename)

  resolved_path = cache.resolve_path(cache_key)

  # Only existence matters on a hit: callers receive the path, so the cached file
  # (possibly a large model) is not read into memory here.
  cache_hit = !try_cache(cache, cache_key).nil?

  unless cache_hit
    if options[:local_files_only] || !Informers.allow_remote_models
      # User requested local files only, but the file is not found locally.
      if fatal
        raise Error, "`local_files_only: true` or `Informers.allow_remote_models = false` and file was not found locally at #{resolved_path.inspect}."
      end

      # File not found, but this file is optional.
      # TODO in future, cache the response?
      return nil
    end

    progress_info = {
      name: path_or_repo_id,
      file: filename
    }

    # File not found locally, so we try to download it from the remote server
    begin
      response = get_file(remote_url, progress_callback, progress_info)
    rescue OpenURI::HTTPError => e
      # open-uri raises on non-2xx responses; a missing (404) optional file is not an error
      raise if fatal || e.io.status[0] != "404"
      return nil
    end

    if response.status[0] != "200"
      # should not happen
      raise Todo
    end

    buffer = response.read
    cache.put(cache_key, buffer) if cache.match(cache_key).nil?
  end

  Utils.dispatch_callback(progress_callback, {
    status: "done",
    name: path_or_repo_id,
    file: filename,
    cache_hit: cache_hit
  })

  resolved_path
end
|
207
|
+
|
208
|
+
# Loads and parses the JSON file +file_name+ of +model_path+ via get_model_file.
# When get_model_file returns nil (a missing optional file), returns an empty Hash.
def self.get_model_json(model_path, file_name, fatal = true, **options)
  json_path = get_model_file(model_path, file_name, fatal, **options)
  return {} if json_path.nil?

  JSON.load_file(json_path)
end
|
217
|
+
|
218
|
+
# Joins +parts+ with "/". One leading "/" is removed from every part except the
# first, and one trailing "/" from every part except the last.
def self.path_join(*parts)
  last_index = parts.length - 1
  trimmed = parts.each_with_index.map do |segment, i|
    segment = segment.delete_prefix("/") unless i.zero?
    segment = segment.delete_suffix("/") unless i == last_index
    segment
  end
  trimmed.join("/")
end
|
230
|
+
|
231
|
+
# Renders a one-line text progress bar: "<filename> |████    |", sized to +width+ columns.
#
# size          - bytes received so far.
# expected_size - total bytes. nil or 0 (unknown, e.g. no Content-Length) draws an empty bar.
#
# The bar width is floored at 0 and progress is clamped to [0, 1]. A narrow terminal,
# a long filename, an unknown total, or size > expected_size therefore cannot raise
# (negative String#* or NaN/Infinity#round).
def self.display_progress(filename, width, size, expected_size)
  bar_width = [width - (filename.length + 3), 0].max
  total = expected_size.to_f
  progress = total > 0 ? (size / total).clamp(0.0, 1.0) : 0.0
  done = (progress * bar_width).round
  not_done = bar_width - done
  "#{filename} |#{"█" * done}#{" " * not_done}|"
end
|
238
|
+
end
|
239
|
+
end
|
240
|
+
end
|
@@ -0,0 +1,44 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Turns a flat array of numbers into probabilities that sum to 1.
# The maximum is subtracted before exponentiating so large inputs cannot overflow.
def self.softmax(arr)
  shift = arr.max
  weights = arr.map { |x| Math.exp(x - shift) }
  total = weights.sum
  weights.map { |w| w / total }
end
|
18
|
+
|
19
|
+
# Applies the logistic function 1 / (1 + e^-v) to every element of +arr+.
def self.sigmoid(arr)
  arr.map do |value|
    denominator = 1 + Math.exp(-value)
    1 / denominator
  end
end
|
22
|
+
|
23
|
+
# Ranks +items+ (scores) from highest to lowest as [index, score] pairs.
#
# top_k - when positive, keep only the first +top_k+ pairs; nil or 0 returns all.
#
# Ties are broken by the lower index. Enumerable#sort_by alone is not stable, so equal
# scores would otherwise come back in an unpredictable order.
def self.get_top_items(items, top_k = 0)
  ranked = items
    .each_with_index
    .map { |score, index| [index, score] }
    .sort_by { |index, score| [-score, index] }

  ranked = ranked.first(top_k) if top_k && top_k > 0

  ranked
end
|
36
|
+
|
37
|
+
# Returns the largest element of +arr+ paired with its position, as [value, index].
# Raises Error when +arr+ is empty.
def self.max(arr)
  raise Error, "Array must not be empty" if arr.empty?

  arr.each_with_index.max_by { |value, _index| value }
end
|
43
|
+
end
|
44
|
+
end
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Averages token vectors over the positions selected by the attention mask.
#
# last_hidden_state - nested arrays indexed [batch][token][feature].
# attention_mask    - nested arrays indexed [batch][token]; each value weights its token.
#
# Returns one pooled vector per batch entry. A sequence whose mask sums to zero yields
# a zero vector instead of NaN: the count is clamped to 1e-9, as sentence-transformers
# does. An empty sequence yields an empty vector.
def self.mean_pooling(last_hidden_state, attention_mask)
  last_hidden_state.zip(attention_mask).map do |state, mask|
    dim = state.empty? ? 0 : state[0].size
    sums = Array.new(dim, 0.0)
    count = 0

    # single pass over the tokens; each feature still accumulates in token order
    state.zip(mask) do |token, m|
      count += m
      token.each_with_index { |v, k| sums[k] += v * m }
    end

    denom = count > 0 ? count : 1e-9
    sums.map { |s| s / denom }
  end
end
|
18
|
+
|
19
|
+
# Scales each row of +result+ to unit L2 norm.
#
# The norm is floored at 1e-12 (the default eps of torch.nn.functional.normalize), so an
# all-zero row stays zero instead of becoming NaN from 0/0.
def self.normalize(result)
  result.map do |row|
    norm = Math.sqrt(row.sum { |v| v * v })
    norm = 1e-12 if norm < 1e-12
    row.map { |v| v / norm }
  end
end
|
25
|
+
end
|
26
|
+
end
|
data/lib/informers/version.rb
CHANGED
data/lib/informers.rb
CHANGED
@@ -1,13 +1,32 @@
|
|
1
1
|
# dependencies
|
2
|
-
require "blingfire"
|
3
|
-
require "numo/narray"
|
4
2
|
require "onnxruntime"
|
3
|
+
require "tokenizers"
|
4
|
+
|
5
|
+
# stdlib
|
6
|
+
require "fileutils"
require "io/console"
require "json"
require "open-uri"
require "stringio"
require "uri"
|
5
11
|
|
6
12
|
# modules
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
13
|
+
require_relative "informers/utils/core"
|
14
|
+
require_relative "informers/utils/hub"
|
15
|
+
require_relative "informers/utils/math"
|
16
|
+
require_relative "informers/utils/tensor"
|
17
|
+
require_relative "informers/configs"
|
18
|
+
require_relative "informers/env"
|
19
|
+
require_relative "informers/model"
|
20
|
+
require_relative "informers/models"
|
21
|
+
require_relative "informers/tokenizers"
|
22
|
+
require_relative "informers/pipelines"
|
23
|
+
|
24
|
+
module Informers
  # Base class for errors raised by this library.
  class Error < StandardError
  end

  # Raised for code paths that are not supported yet; the message is always fixed.
  class Todo < Error
    def message = "not implemented yet"
  end
end
|
metadata
CHANGED
@@ -1,57 +1,43 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: informers
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 1.0.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
|
-
name:
|
15
|
-
requirement: !ruby/object:Gem::Requirement
|
16
|
-
requirements:
|
17
|
-
- - ">="
|
18
|
-
- !ruby/object:Gem::Version
|
19
|
-
version: 0.1.7
|
20
|
-
type: :runtime
|
21
|
-
prerelease: false
|
22
|
-
version_requirements: !ruby/object:Gem::Requirement
|
23
|
-
requirements:
|
24
|
-
- - ">="
|
25
|
-
- !ruby/object:Gem::Version
|
26
|
-
version: 0.1.7
|
27
|
-
- !ruby/object:Gem::Dependency
|
28
|
-
name: numo-narray
|
14
|
+
name: onnxruntime
|
29
15
|
requirement: !ruby/object:Gem::Requirement
|
30
16
|
requirements:
|
31
17
|
- - ">="
|
32
18
|
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
19
|
+
version: '0.9'
|
34
20
|
type: :runtime
|
35
21
|
prerelease: false
|
36
22
|
version_requirements: !ruby/object:Gem::Requirement
|
37
23
|
requirements:
|
38
24
|
- - ">="
|
39
25
|
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
26
|
+
version: '0.9'
|
41
27
|
- !ruby/object:Gem::Dependency
|
42
|
-
name:
|
28
|
+
name: tokenizers
|
43
29
|
requirement: !ruby/object:Gem::Requirement
|
44
30
|
requirements:
|
45
31
|
- - ">="
|
46
32
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.5.
|
33
|
+
version: 0.5.2
|
48
34
|
type: :runtime
|
49
35
|
prerelease: false
|
50
36
|
version_requirements: !ruby/object:Gem::Requirement
|
51
37
|
requirements:
|
52
38
|
- - ">="
|
53
39
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.5.
|
40
|
+
version: 0.5.2
|
55
41
|
description:
|
56
42
|
email: andrew@ankane.org
|
57
43
|
executables: []
|
@@ -62,23 +48,17 @@ files:
|
|
62
48
|
- LICENSE.txt
|
63
49
|
- README.md
|
64
50
|
- lib/informers.rb
|
65
|
-
- lib/informers/
|
66
|
-
- lib/informers/
|
67
|
-
- lib/informers/
|
68
|
-
- lib/informers/
|
69
|
-
- lib/informers/
|
70
|
-
- lib/informers/
|
51
|
+
- lib/informers/configs.rb
|
52
|
+
- lib/informers/env.rb
|
53
|
+
- lib/informers/model.rb
|
54
|
+
- lib/informers/models.rb
|
55
|
+
- lib/informers/pipelines.rb
|
56
|
+
- lib/informers/tokenizers.rb
|
57
|
+
- lib/informers/utils/core.rb
|
58
|
+
- lib/informers/utils/hub.rb
|
59
|
+
- lib/informers/utils/math.rb
|
60
|
+
- lib/informers/utils/tensor.rb
|
71
61
|
- lib/informers/version.rb
|
72
|
-
- vendor/LICENSE-bert.txt
|
73
|
-
- vendor/LICENSE-blingfire.txt
|
74
|
-
- vendor/LICENSE-gpt2.txt
|
75
|
-
- vendor/LICENSE-roberta.txt
|
76
|
-
- vendor/bert_base_cased_tok.bin
|
77
|
-
- vendor/bert_base_tok.bin
|
78
|
-
- vendor/gpt2.bin
|
79
|
-
- vendor/gpt2.i2w
|
80
|
-
- vendor/roberta.bin
|
81
|
-
- vendor/roberta.i2w
|
82
62
|
homepage: https://github.com/ankane/informers
|
83
63
|
licenses:
|
84
64
|
- Apache-2.0
|
@@ -91,15 +71,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
91
71
|
requirements:
|
92
72
|
- - ">="
|
93
73
|
- !ruby/object:Gem::Version
|
94
|
-
version: '
|
74
|
+
version: '3.1'
|
95
75
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
96
76
|
requirements:
|
97
77
|
- - ">="
|
98
78
|
- !ruby/object:Gem::Version
|
99
79
|
version: '0'
|
100
80
|
requirements: []
|
101
|
-
rubygems_version: 3.
|
81
|
+
rubygems_version: 3.5.11
|
102
82
|
signing_key:
|
103
83
|
specification_version: 4
|
104
|
-
summary:
|
84
|
+
summary: Fast transformer inference for Ruby
|
105
85
|
test_files: []
|
@@ -1,59 +0,0 @@
|
|
1
|
-
# Copyright 2018 The HuggingFace Inc. team.
|
2
|
-
# Copyright 2020 Andrew Kane.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
module Informers
  # Feature extraction with a BERT (cased) ONNX model and the vendored BlingFire tokenizer.
  class FeatureExtraction
    # model_path - path to the ONNX model file.
    def initialize(model_path)
      vocab_file = File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
      @tokenizer = BlingFire.load_model(vocab_file)
      @model = OnnxRuntime::Model.new(model_path)
    end

    # Runs the model on one text (String) or a batch of texts (Array).
    #
    # Returns the model's "output_0" output, falling back to "last_hidden_state".
    # A String argument gets the first batch entry; an Array gets the whole batch.
    def predict(texts)
      singular = !texts.is_a?(Array)
      batch = singular ? [texts] : texts

      # wrap each tokenization in 101 (cls) ... 102 (sep); 100 is passed as the unk token
      sequences = batch.map { |text| [101, *@tokenizer.text_to_ids(text, nil, 100), 102] }

      # right-pad with zeros to the longest sequence; the mask is 1 for real tokens, 0 for padding
      longest = sequences.map(&:size).max
      attention_mask = []
      input_ids = sequences.map do |ids|
        padding = [0] * (longest - ids.size)
        attention_mask << (([1] * ids.size) + padding)
        ids + padding
      end

      input = {input_ids: input_ids, attention_mask: attention_mask}
      output = @model.predict(input)
      scores = output["output_0"] || output["last_hidden_state"]

      singular ? scores.first : scores
    end
  end
end
|
data/lib/informers/fill_mask.rb
DELETED
@@ -1,109 +0,0 @@
|
|
1
|
-
# Copyright 2018 The HuggingFace Inc. team.
|
2
|
-
# Copyright 2021 Andrew Kane.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
module Informers
  # Fill-mask pipeline: scores candidate tokens for the single "<mask>" in each text
  # using a RoBERTa ONNX model and the vendored BlingFire encoder/decoder.
  class FillMask
    # model_path - path to the ONNX model file.
    def initialize(model_path)
      encoder_path = File.expand_path("../../vendor/roberta.bin", __dir__)
      @encoder = BlingFire.load_model(encoder_path, prefix: false)

      decoder_path = File.expand_path("../../vendor/roberta.i2w", __dir__)
      @decoder = BlingFire.load_model(decoder_path)

      @model = OnnxRuntime::Model.new(model_path)
    end

    # texts - a String, or an Array of Strings, each containing exactly one "<mask>".
    #
    # Returns, per text, the 5 highest-scoring candidates as Hashes with :sequence,
    # :score, :token and :token_str. A String argument gets one list; an Array gets a
    # list of lists.
    # Raises RuntimeError when a text contains no mask or more than one.
    def predict(texts)
      singular = !texts.is_a?(Array)
      texts = [texts] if singular

      mask_token = 50264

      # tokenize
      input_ids =
        texts.map do |text|
          tokens = @encoder.text_to_ids(text, nil, 3) # unk token

          # add mask token: the encoder is assumed to emit "<mask>" as this three-id
          # sequence, which is collapsed into the single mask id
          mask_sequence = [28696, 43776, 15698]
          masks = []
          (tokens.size - 2).times do |i|
            masks << i if tokens[i..(i + 2)] == mask_sequence
          end
          # replace from the end so earlier match positions stay valid
          masks.reverse_each do |mask|
            tokens = tokens[0...mask] + [mask_token] + tokens[(mask + 3)..-1]
          end

          tokens.unshift(0) # cls token
          tokens << 2 # sep token

          tokens
        end

      # right-pad with zeros to the longest sequence; the mask is 1 for real tokens
      max_tokens = input_ids.map(&:size).max
      attention_mask = []
      input_ids.each do |ids|
        zeros = [0] * (max_tokens - ids.size)

        mask = ([1] * ids.size) + zeros
        attention_mask << mask

        ids.concat(zeros)
      end

      input = {
        input_ids: input_ids,
        attention_mask: attention_mask
      }

      masked_index = input_ids.map { |v| v.each_index.select { |i| v[i] == mask_token } }
      masked_index.each do |v|
        raise "No mask_token (<mask>) found on the input" if v.size < 1
        raise "More than one mask_token (<mask>) is not supported" if v.size > 1
      end

      output = @model.predict(input)
      outputs = output["output_0"] || output["logits"]

      results = outputs.each_with_index.map do |sequence_logits, i|
        position = masked_index[i][0]
        logits = sequence_logits[position]

        # numerically stable softmax: shifting by the max logit keeps Math.exp from
        # overflowing to Infinity, which would turn every score into NaN
        max_logit = logits.max
        values = logits.map { |v| Math.exp(v - max_logit) }
        sum = values.sum
        probs = values.map { |v| v / sum }
        top = probs.each_with_index.sort_by { |v| -v[0] }.first(5)

        top.map do |(score, token_id)|
          tokens = input[:input_ids][i].dup
          tokens[position] = token_id
          {
            sequence: @decoder.ids_to_text(tokens),
            score: score,
            token: token_id,
            # TODO figure out prefix space
            token_str: @decoder.ids_to_text([token_id], skip_special_tokens: false)
          }
        end
      end

      singular ? results.first : results
    end
  end
end
|