informers 1.0.3 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +137 -7
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -9
- data/lib/informers/models.rb +1160 -15
- data/lib/informers/pipelines.rb +943 -11
- data/lib/informers/processors.rb +856 -0
- data/lib/informers/tokenizers.rb +159 -5
- data/lib/informers/utils/audio.rb +18 -0
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/ffmpeg.rb +45 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +6 -0
- metadata +10 -5
data/lib/informers/tokenizers.rb
CHANGED
@@ -1,16 +1,65 @@
 module Informers
   class PreTrainedTokenizer
-    attr_reader :sep_token_id
+    attr_reader :mask_token, :mask_token_id, :sep_token_id

     def initialize(tokenizer_json, tokenizer_config)
       super()

+      @tokenizer_config = tokenizer_config
+
       @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

-
-      @
+      # Add added_tokens to model
+      @special_tokens = []
+      @all_special_ids = []
+
+      @added_tokens = []
+      @tokenizer.added_tokens_decoder.each do |id, token|
+        @added_tokens << token
+
+        if token.special
+          @special_tokens << token.content
+          @all_special_ids << id
+        end
+      end
+
+      # Update additional_special_tokens
+      @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
+      @special_tokens.concat(@additional_special_tokens)
+
+      @mask_token = get_token("mask_token")
+      @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token
+
+      @sep_token = get_token("sep_token")
+      @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

       @model_max_length = tokenizer_config["model_max_length"]
+
+      # for donut-base-finetuned-docvqa
+      if @model_max_length && @model_max_length > (1 << 63)
+        @model_max_length = 1 << 63
+      end
+    end
+
+    def get_token(*keys)
+      keys.each do |key|
+        item = @tokenizer_config[key]
+        if !item
+          next
+        end
+
+        if item.is_a?(Hash)
+          if item["__type"] == "AddedToken"
+            return item["content"]
+          else
+            raise Error, "Unknown token: #{item}"
+          end
+        else
+          return item
+        end
+      end
+
+      nil
     end

     def call(
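For reference, a minimal standalone sketch of the lookup the new get_token helper performs on tokenizer_config — the config hashes below are hypothetical examples, not taken from a real model:

# resolve a token name to its literal text, whether the config stores a plain
# string or an AddedToken hash (mirrors the logic added in the hunk above)
def resolve_token(tokenizer_config, *keys)
  keys.each do |key|
    item = tokenizer_config[key]
    next if !item

    if item.is_a?(Hash)
      raise "Unknown token: #{item}" if item["__type"] != "AddedToken"
      return item["content"]
    else
      return item
    end
  end
  nil
end

resolve_token({ "mask_token" => "[MASK]" }, "mask_token")
# => "[MASK]"
resolve_token({ "mask_token" => { "__type" => "AddedToken", "content" => "<mask>" } }, "mask_token")
# => "<mask>"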
@@ -76,6 +125,22 @@ module Informers
     def convert_tokens_to_string(tokens)
       @tokenizer.decoder.decode(tokens)
     end
+
+    def convert_tokens_to_ids(tokens)
+      tokens.map { |t| @tokenizer.token_to_id(t) }
+    end
+
+    def id_to_token(id)
+      @tokenizer.id_to_token(id)
+    end
+
+    def batch_decode(batch, **decode_args)
+      @tokenizer.decode_batch(batch, **decode_args)
+    end
+
+    def padding_side=(side)
+      @tokenizer.enable_padding(direction: side)
+    end
   end

   class BertTokenizer < PreTrainedTokenizer
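A hedged usage sketch of the convenience methods added above; the model id is an assumption for illustration, and any tokenizer loadable through AutoTokenizer should behave the same way:

require "informers"

# model id is an assumption, not taken from this diff
tokenizer = Informers::AutoTokenizer.from_pretrained("Xenova/bert-base-uncased")

ids = tokenizer.convert_tokens_to_ids(["hello", "world"])  # token strings -> ids
tokenizer.id_to_token(ids.first)                           # single id -> token string
tokenizer.batch_decode([ids])                              # decode a batch of id sequences
tokenizer.padding_side = "left"                            # pad on the left (e.g. for decoder-only models)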
@@ -91,6 +156,16 @@ module Informers
   class DistilBertTokenizer < PreTrainedTokenizer
   end

+  class T5Tokenizer < PreTrainedTokenizer
+  end
+
+  class GPT2Tokenizer < PreTrainedTokenizer
+    # _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
+  end
+
+  class BartTokenizer < PreTrainedTokenizer
+  end
+
   class RobertaTokenizer < PreTrainedTokenizer
   end

@@ -100,14 +175,93 @@ module Informers
   class MPNetTokenizer < PreTrainedTokenizer
   end

+  class CLIPTokenizer < PreTrainedTokenizer
+  end
+
+  class NllbTokenizer < PreTrainedTokenizer
+    attr_reader :language_regex, :language_codes, :lang_to_token
+
+    def initialize(tokenizer_json, tokenizer_config)
+      super(tokenizer_json, tokenizer_config)
+
+      @language_regex = /^[a-z]{3}_[A-Z][a-z]{3}$/
+      @language_codes = @special_tokens.filter { |x| @language_regex.match?(x) }
+      @lang_to_token = ->(x) { x } # Identity function
+    end
+
+    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
+      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
+    end
+  end
+
+  class M2M100Tokenizer < PreTrainedTokenizer
+    attr_reader :language_regex, :language_codes, :lang_to_token
+
+    def initialize(tokenizer_json, tokenizer_config)
+      super(tokenizer_json, tokenizer_config)
+
+      @language_regex = /^__[a-z]{2,3}__$/
+      @language_codes = @special_tokens
+        .filter { |x| @language_regex.match?(x) }
+        .map { |x| x.slice(2, -2) }
+      @lang_to_token = ->(x) { "__#{x}__" }
+    end
+
+    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
+      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
+    end
+  end
+
+  module Utils
+    def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)
+      if !slf.respond_to?(:language_codes) || !slf.language_codes.is_a?(Array)
+        raise Error, "Tokenizer must have `language_codes` attribute set and it should be an array of language ids."
+      end
+      if !slf.respond_to?(:language_regex) || !slf.language_regex.is_a?(Regexp)
+        raise Error, "Tokenizer must have `language_regex` attribute set and it should be a regular expression."
+      end
+      if !slf.respond_to?(:lang_to_token) || !slf.lang_to_token.respond_to?(:call)
+        raise Error, "Tokenizer must have `lang_to_token` attribute set and it should be a function."
+      end
+      src_lang_token = generate_kwargs[:src_lang]
+      tgt_lang_token = generate_kwargs[:tgt_lang]
+
+      if !slf.language_codes.include?(tgt_lang_token)
+        raise Error, "Target language code #{tgt_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
+      end
+
+      if !src_lang_token.nil?
+        # Check that the source language is valid:
+        if !slf.language_codes.include?(src_lang_token)
+          raise Error, "Source language code #{src_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
+        end
+      end
+
+      # Override the `forced_bos_token_id` to force the correct language
+      generate_kwargs["forced_bos_token_id"] = slf.convert_tokens_to_ids([slf.lang_to_token.(tgt_lang_token)])[0]
+
+      slf.(raw_inputs, **tokenizer_options)
+    end
+  end
+
+  class SpeechT5Tokenizer < PreTrainedTokenizer
+  end
+
   class AutoTokenizer
     TOKENIZER_CLASS_MAPPING = {
+      "T5Tokenizer" => T5Tokenizer,
       "BertTokenizer" => BertTokenizer,
       "DebertaV2Tokenizer" => DebertaV2Tokenizer,
       "DistilBertTokenizer" => DistilBertTokenizer,
+      "BartTokenizer" => BartTokenizer,
       "RobertaTokenizer" => RobertaTokenizer,
       "XLMRobertaTokenizer" => XLMRobertaTokenizer,
-      "MPNetTokenizer" => MPNetTokenizer
+      "MPNetTokenizer" => MPNetTokenizer,
+      "CLIPTokenizer" => CLIPTokenizer,
+      "GPT2Tokenizer" => GPT2Tokenizer,
+      "NllbTokenizer" => NllbTokenizer,
+      "M2M100Tokenizer" => M2M100Tokenizer,
+      "SpeechT5Tokenizer" => SpeechT5Tokenizer
     }

     def self.from_pretrained(
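A hedged sketch of how the NLLB/M2M100 additions above would be exercised end to end; the pipeline task name and model id are assumptions based on the pipelines.rb changes in this release (not shown here), and the language codes follow the NllbTokenizer regex:

require "informers"

# assumed model id; any NLLB checkpoint with ONNX weights should work the same way
translator = Informers.pipeline("translation", "Xenova/nllb-200-distilled-600M")

# src_lang/tgt_lang flow into generate_kwargs, where _build_translation_inputs
# validates them against language_codes and sets forced_bos_token_id
translator.("Hello, world!", src_lang: "eng_Latn", tgt_lang: "fra_Latn")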
@@ -146,7 +300,7 @@ module Informers
     def self.load_tokenizer(pretrained_model_name_or_path, **options)
       info = [
         Utils::Hub.get_model_file(pretrained_model_name_or_path, "tokenizer.json", true, **options),
-        Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options)
+        Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options)
       ]

       # Override legacy option if `options.legacy` is not null
data/lib/informers/utils/audio.rb
ADDED
@@ -0,0 +1,18 @@
+module Informers
+  module Utils
+    def self.read_audio(input, sampling_rate)
+      data =
+        if input.is_a?(URI)
+          require "open-uri"
+
+          input.read
+        elsif input.is_a?(String)
+          File.binread(input)
+        else
+          raise ArgumentError, "Unsupported input type: #{input.class.name}"
+        end
+
+      ffmpeg_read(data, sampling_rate)
+    end
+  end
+end
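A hedged usage sketch of the new helper; the file path and sampling rate are assumptions, and ffmpeg must be installed since decoding is delegated to ffmpeg_read:

require "informers"

# decode a local file (or a URI) to mono Float samples at the requested rate
samples = Informers::Utils.read_audio("speech.wav", 16000)
samples.length # number of mono samples at 16 kHz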
data/lib/informers/utils/ffmpeg.rb
ADDED
@@ -0,0 +1,45 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Informers
+  module Utils
+    # from the Transformers Python library
+    def self.ffmpeg_read(data, sampling_rate)
+      ar = "#{sampling_rate}"
+      ac = "1"
+      format_for_conversion = "f32le"
+      ffmpeg_command = [
+        "ffmpeg",
+        "-i",
+        "pipe:0",
+        "-ac",
+        ac,
+        "-ar",
+        ar,
+        "-f",
+        format_for_conversion,
+        "-hide_banner",
+        "-loglevel",
+        "quiet",
+        "pipe:1"
+      ]
+
+      stdout, status = Open3.capture2(*ffmpeg_command, stdin_data: data)
+      if !status.success?
+        raise Error, "ffmpeg was not found but is required to load audio files from filename"
+      end
+      stdout.unpack("e*")
+    end
+  end
+end
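For context, the helper above pipes the raw bytes through ffmpeg and reads the output as little-endian 32-bit floats; the sampling rate shown below is an assumption, and Ruby's "e*" pack directive round-trips that sample format:

# shell equivalent of the command array built in ffmpeg_read:
#   ffmpeg -i pipe:0 -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet pipe:1

# "e" = single-precision float, little-endian
[0.0, 0.5, -0.25].pack("e*").unpack("e*")
# => [0.0, 0.5, -0.25]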
data/lib/informers/utils/generation.rb
ADDED
@@ -0,0 +1,294 @@
+module Informers
+  module Utils
+    class GenerationConfig
+      def initialize(kwargs)
+        @config = {}
+
+        # Parameters that control the length of the output
+        @config["max_length"] = kwargs["max_length"] || 20
+        @config["max_new_tokens"] = kwargs["max_new_tokens"]
+        @config["min_length"] = kwargs["min_length"] || 0
+        @config["min_new_tokens"] = kwargs["min_new_tokens"]
+        @config["early_stopping"] = kwargs["early_stopping"] || false
+        @config["max_time"] = kwargs["max_time"]
+
+        # Parameters that control the generation strategy used
+        @config["do_sample"] = kwargs["do_sample"] || false
+        @config["num_beams"] = kwargs["num_beams"] || 1
+        @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
+        @config["penalty_alpha"] = kwargs["penalty_alpha"]
+        @config["use_cache"] = kwargs.fetch("use_cache", true)
+
+        # Parameters for manipulation of the model output logits
+        @config["temperature"] = kwargs["temperature"] || 1.0
+        @config["top_k"] = kwargs["top_k"] || 50
+        @config["top_p"] = kwargs["top_p"] || 1.0
+        @config["typical_p"] = kwargs["typical_p"] || 1.0
+        @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
+        @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
+        @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
+        @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
+        @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
+        @config["length_penalty"] = kwargs["length_penalty"] || 1.0
+        @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
+        @config["bad_words_ids"] = kwargs["bad_words_ids"]
+        @config["force_words_ids"] = kwargs["force_words_ids"]
+        @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
+        @config["constraints"] = kwargs["constraints"]
+        @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
+        @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
+        @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
+        @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
+        @config["suppress_tokens"] = kwargs["suppress_tokens"]
+        @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
+        @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]
+
+        # Parameters that define the output variables of `generate`
+        @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
+        @config["output_attentions"] = kwargs["output_attentions"] || false
+        @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
+        @config["output_scores"] = kwargs["output_scores"] || false
+        @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false
+
+        # Special tokens that can be used at generation time
+        @config["pad_token_id"] = kwargs["pad_token_id"]
+        @config["bos_token_id"] = kwargs["bos_token_id"]
+        @config["eos_token_id"] = kwargs["eos_token_id"]
+
+        # Generation parameters exclusive to encoder-decoder models
+        @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
+        @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]
+
+        # Wild card
+        @generation_kwargs = kwargs["generation_kwargs"] || {}
+      end
+
+      def [](key)
+        @config[key.to_s]
+      end
+
+      def merge!(config)
+        @config.merge!(config)
+      end
+    end
+
+    class Sampler
+      def initialize(generation_config)
+        super()
+        @generation_config = generation_config
+      end
+
+      def call(logits, index = -1)
+        # Sample from logits, of dims [batch, sequence_length, vocab_size].
+        # If index is specified, sample from [batch, index, vocab_size].
+        sample(logits, index)
+      end
+
+      def get_logits(logits, index)
+        vocab_size = Utils.dims(logits)[-1]
+
+        logs = logits.flatten
+
+        if index == -1
+          logs = logs.last(vocab_size)
+        else
+          raise Todo
+        end
+
+        # add temperature
+        if @generation_config["temperature"] > 0
+          logs = logs.map { |x| x / @generation_config["temperature"] }
+        end
+        logs
+      end
+
+      def self.get_sampler(generation_config)
+        if generation_config[:do_sample]
+          MultinomialSampler.new(generation_config)
+        elsif generation_config[:num_beams] > 1
+          BeamSearchSampler.new(generation_config)
+        else
+          if generation_config[:num_return_sequences] > 1
+            raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}."
+          end
+          GreedySampler.new(generation_config)
+        end
+      end
+    end
+
+    class GreedySampler < Sampler
+      def sample(logits, index = -1)
+        # NOTE: no need to do log_softmax here since we only take the maximum
+        logs = get_logits(logits, index)
+        argmax = Utils.max(logs)[1]
+
+        # Note: score is meaningless in this context, since we are performing
+        # greedy search (p = 1 => log(p) = 0)
+        [
+          [argmax, 0]
+        ]
+      end
+    end
+
+    class BeamSearchSampler < Sampler
+      def sample(logits, index = -1)
+        k = Utils.dims(logits)[-1] # defaults to vocab size
+        if @generation_config["top_k"] > 0
+          k = [@generation_config["top_k"], k].min
+        end
+
+        # Get logits of nth token
+        logs = get_logits(logits, index)
+
+        # Get top k tokens
+        top_logits = Utils.get_top_items(logs, k)
+
+        # Compute softmax over logits
+        probabilities = Utils.softmax(top_logits.map { |x| x[1] })
+
+        Array.new(@generation_config["num_beams"]) do |i|
+          [
+            top_logits[i][0],
+            Math.log(probabilities[i])
+          ]
+        end
+      end
+    end
+
+    class LogitsProcessorList
+      def initialize
+        super
+        @processors = []
+      end
+
+      def push(item)
+        @processors << item
+      end
+
+      def concat(items)
+        @processors.concat(items)
+      end
+
+      def call(input_ids, batched_logits)
+        # NOTE: This is different from the Python code, since vanilla Ruby does not support vectorized operations.
+        # As a result, we apply each processor to each item in the batch.
+        batched_logits.each do |logits|
+          # Modifies logits inplace
+          @processors.each do |func|
+            func.(input_ids, logits)
+          end
+        end
+      end
+
+      def to_ary
+        @processors
+      end
+    end
+
+    class LogitsProcessor
+    end
+
+    class NoRepeatNGramLogitsProcessor < LogitsProcessor
+      def initialize(no_repeat_ngram_size)
+        super()
+        @no_repeat_ngram_size = no_repeat_ngram_size
+      end
+
+      def get_ngrams(prev_input_ids)
+        cur_len = prev_input_ids.length
+
+        ngrams = []
+        j = 0
+        while j < cur_len + 1 - @no_repeat_ngram_size
+          ngram = []
+          @no_repeat_ngram_size.times do |k|
+            ngram << prev_input_ids[j + k]
+          end
+          ngrams << ngram
+          j += 1
+        end
+
+        generated_ngram = {}
+        ngrams.each do |ngram|
+          prev_ngram = ngram.slice(0, ngram.length - 1)
+          prev_ngram_key = JSON.generate(prev_ngram)
+          prev_ngram_value = generated_ngram[prev_ngram_key] || []
+          prev_ngram_value << ngram[ngram.length - 1]
+          generated_ngram[prev_ngram_key] = prev_ngram_value
+        end
+        generated_ngram
+      end
+
+      def get_generated_ngrams(banned_ngrams, prev_input_ids)
+        ngram_idx = prev_input_ids.slice(prev_input_ids.length + 1 - @no_repeat_ngram_size, prev_input_ids.length)
+        banned = banned_ngrams[JSON.generate(ngram_idx)] || []
+        banned
+      end
+
+      def calc_banned_ngram_tokens(prev_input_ids)
+        banned_tokens = []
+        if prev_input_ids.length + 1 < @no_repeat_ngram_size
+          # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+          banned_tokens
+        else
+          generated_ngrams = get_ngrams(prev_input_ids)
+          banned_tokens = get_generated_ngrams(generated_ngrams, prev_input_ids)
+          banned_tokens
+        end
+      end
+
+      def call(input_ids, logits)
+        banned_tokens = calc_banned_ngram_tokens(input_ids)
+
+        banned_tokens.each do |token|
+          logits[token] = -Float::INFINITY
+        end
+        logits
+      end
+    end
+
+    class MinLengthLogitsProcessor < LogitsProcessor
+      def initialize(min_length, eos_token_id)
+        super()
+        @min_length = min_length
+        @eos_token_id = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]
+      end
+
+      def call(input_ids, logits)
+        if input_ids.length < @min_length
+          @eos_token_id.each do |eos_token|
+            logits[eos_token] = -Float::INFINITY
+          end
+        end
+
+        logits
+      end
+    end
+
+    class ForcedBOSTokenLogitsProcessor < LogitsProcessor
+      def initialize(bos_token_id)
+        super()
+        @bos_token_id = bos_token_id
+      end
+
+      def call(input_ids, logits)
+        if input_ids.length == 1
+          logits.map! { -Float::INFINITY }
+          logits[@bos_token_id] = 0
+        end
+        logits
+      end
+    end
+
+    class ForcedEOSTokenLogitsProcessor < LogitsProcessor
+      def initialize(max_length, forced_eos_token_id)
+        super()
+        @max_length = max_length
+        @forced_eos_token_id = forced_eos_token_id
+      end
+
+      def call(input_ids, logits)
+      end
+    end
+  end
+end
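A hedged sketch of how these internal utilities fit together — shown only to illustrate the defaulting and sampler-selection logic, not as public API; the option values are arbitrary:

require "informers"

config = Informers::Utils::GenerationConfig.new({"num_beams" => 3, "top_k" => 5})
config["max_length"]  # => 20 (default)
config[:num_beams]    # => 3 (the [] reader accepts symbols or strings)

# do_sample is false and num_beams > 1, so beam search is chosen
sampler = Informers::Utils::Sampler.get_sampler(config)
sampler.class # => Informers::Utils::BeamSearchSampler

# ban repeated 2-grams by pushing the processor into a list
processors = Informers::Utils::LogitsProcessorList.new
processors.push(Informers::Utils::NoRepeatNGramLogitsProcessor.new(2))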
data/lib/informers/utils/image.rb
ADDED
@@ -0,0 +1,116 @@
+module Informers
+  module Utils
+    class RawImage
+      RESAMPLING_MAPPING = {
+        0 => "nearest",
+        1 => "lanczos",
+        2 => "bilinear",
+        3 => "bicubic",
+        4 => "box",
+        5 => "hamming"
+      }
+
+      attr_reader :image, :width, :height, :channels
+
+      def initialize(image)
+        @image = image
+        @width = image.width
+        @height = image.height
+        @channels = image.bands
+      end
+
+      def data
+        @image.write_to_memory.unpack("C*")
+      end
+
+      def size
+        [@width, @height]
+      end
+
+      def resize(width, height, resample: 2)
+        resample_method = RESAMPLING_MAPPING[resample] || resample
+
+        case resample_method
+        when "bilinear", "bicubic"
+          img =
+            @image.affine(
+              [width / @width.to_f, 0, 0, height / @height.to_f],
+              interpolate: Vips::Interpolate.new(resample_method.to_sym)
+            )
+        else
+          raise Todo
+        end
+
+        RawImage.new(img)
+      end
+
+      def center_crop(crop_width, crop_height)
+        # If the image is already the desired size, return it
+        if @width == crop_width && @height == crop_height
+          return self
+        end
+
+        # Determine bounds of the image in the new canvas
+        width_offset = (@width - crop_width) / 2.0
+        height_offset = (@height - crop_height) / 2.0
+
+        if width_offset >= 0 && height_offset >= 0
+          # Cropped image lies entirely within the original image
+          img = @image.crop(
+            width_offset.floor,
+            height_offset.floor,
+            crop_width,
+            crop_height
+          )
+        elsif width_offset <= 0 && height_offset <= 0
+          raise Todo
+        else
+          raise Todo
+        end
+
+        RawImage.new(img)
+      end
+
+      def rgb
+        if @channels == 3
+          return self
+        end
+
+        raise Todo
+      end
+
+      def save(path)
+        @image.write_to_file(path)
+      end
+
+      def self.read(input)
+        if input.is_a?(RawImage)
+          input
+        elsif input.is_a?(URI)
+          require "open-uri"
+
+          RawImage.new(Vips::Image.new_from_buffer(input.read, ""))
+        elsif input.is_a?(String)
+          RawImage.new(Vips::Image.new_from_file(input))
+        else
+          raise ArgumentError, "Unsupported input type: #{input.class.name}"
+        end
+      end
+
+      def self.from_array(input)
+        c, h, w = Utils.dims(input)
+        pixel_data = Array.new(w * h * c)
+
+        input.each_with_index do |cv, ci|
+          cv.each_with_index do |hv, hi|
+            hv.each_with_index do |v, wi|
+              pixel_data[(hi * w * c) + (wi * c) + ci] = v
+            end
+          end
+        end
+
+        RawImage.new(Vips::Image.new_from_memory_copy(pixel_data.pack("C*"), w, h, c, :uchar))
+      end
+    end
+  end
+end
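A hedged usage sketch of the RawImage wrapper added above; the file names are assumptions, and the class is backed by ruby-vips (Vips::Image):

require "informers"

image = Informers::Utils::RawImage.read("input.jpg")  # also accepts a URI or another RawImage
image.size                                            # => [width, height]

resized = image.resize(256, 256)                      # resample: 2 (bilinear) by default
cropped = resized.center_crop(224, 224)
cropped.save("output.jpg")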