informers 1.0.2 → 1.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +213 -19
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -14
- data/lib/informers/models.rb +1027 -13
- data/lib/informers/pipelines.rb +781 -14
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +166 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
data/lib/informers/tokenizers.rb
CHANGED
@@ -1,16 +1,65 @@
|
|
1
1
|
module Informers
|
2
2
|
class PreTrainedTokenizer
|
3
|
-
attr_reader :sep_token_id
|
3
|
+
attr_reader :mask_token, :mask_token_id, :sep_token_id
|
4
4
|
|
5
5
|
# Loads the tokenizer from +tokenizer_json+ (a tokenizer.json path) and reads
# special-token settings from the +tokenizer_config+ hash.
def initialize(tokenizer_json, tokenizer_config)
  super()

  @tokenizer_config = tokenizer_config

  @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

  # Add added_tokens to model, remembering which ones are special
  @special_tokens = []
  @all_special_ids = []
  @added_tokens = []
  @tokenizer.added_tokens_decoder.each do |id, token|
    @added_tokens << token

    if token.special
      @special_tokens << token.content
      @all_special_ids << id
    end
  end

  # Update additional_special_tokens. Entries may be serialized AddedToken
  # hashes ({"__type" => "AddedToken", "content" => ...}), so normalize them to
  # their string content; @special_tokens then holds strings only.
  @additional_special_tokens = (tokenizer_config["additional_special_tokens"] || []).map do |token|
    token.is_a?(Hash) && token["__type"] == "AddedToken" ? token["content"] : token
  end
  @special_tokens.concat(@additional_special_tokens)

  @mask_token = get_token("mask_token")
  @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token

  @sep_token = get_token("sep_token")
  @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

  @model_max_length = tokenizer_config["model_max_length"]

  # for donut-base-finetuned-docvqa (configs can carry a huge sentinel value)
  if @model_max_length && @model_max_length > (1 << 63)
    @model_max_length = 1 << 63
  end
end
|
43
|
+
|
44
|
+
# Looks up each of +keys+ in the tokenizer config and returns the first value
# that is present. A plain value is returned as-is. A serialized AddedToken hash
# yields its "content". Any other hash raises Error. Returns nil if no key has a value.
def get_token(*keys)
  keys.each do |key|
    item = @tokenizer_config[key]
    next unless item

    case item
    when Hash
      raise Error, "Unknown token: #{item}" unless item["__type"] == "AddedToken"

      return item["content"]
    else
      return item
    end
  end

  nil
end
|
15
64
|
|
16
65
|
def call(
|
@@ -76,6 +125,22 @@ module Informers
|
|
76
125
|
# Joins token strings back into text using the tokenizer's decoder.
def convert_tokens_to_string(tokens)
  decoder = @tokenizer.decoder
  decoder.decode(tokens)
end

# Maps each token string to its vocabulary id (nil when unknown to the tokenizer).
def convert_tokens_to_ids(tokens)
  tokens.map { |token| @tokenizer.token_to_id(token) }
end

# Returns the token string for vocabulary id +id+.
def id_to_token(id)
  @tokenizer.id_to_token(id)
end

# Decodes a batch of id sequences; keyword options go straight to decode_batch.
def batch_decode(batch, **options)
  @tokenizer.decode_batch(batch, **options)
end

# Sets the padding direction by (re-)enabling padding on the tokenizer.
# NOTE(review): only `direction:` is passed; confirm enable_padding keeps the
# tokenizer's existing pad token/id rather than resetting them to defaults.
def padding_side=(side)
  @tokenizer.enable_padding(direction: side)
end
|
79
144
|
end
|
80
145
|
|
81
146
|
class BertTokenizer < PreTrainedTokenizer
|
@@ -91,11 +156,108 @@ module Informers
|
|
91
156
|
# Model-specific tokenizers that add no behaviour beyond PreTrainedTokenizer.
# Each is a distinct class so that tokenizer class names can be mapped to it.
class DistilBertTokenizer < PreTrainedTokenizer; end

class T5Tokenizer < PreTrainedTokenizer; end

# Kept for reference only (not used):
# _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
class GPT2Tokenizer < PreTrainedTokenizer; end

class BartTokenizer < PreTrainedTokenizer; end

class RobertaTokenizer < PreTrainedTokenizer; end

class XLMRobertaTokenizer < PreTrainedTokenizer; end

class MPNetTokenizer < PreTrainedTokenizer; end

class CLIPTokenizer < PreTrainedTokenizer; end
|
180
|
+
|
181
|
+
# Tokenizer for NLLB translation models. Its language tokens are the language
# codes themselves (e.g. "eng_Latn").
class NllbTokenizer < PreTrainedTokenizer
  attr_reader :language_regex, :language_codes, :lang_to_token

  def initialize(tokenizer_json, tokenizer_config)
    super(tokenizer_json, tokenizer_config)

    # \A...\z anchors the whole token; ^/$ would also accept multi-line strings.
    @language_regex = /\A[a-z]{3}_[A-Z][a-z]{3}\z/
    # @special_tokens combines added special tokens with
    # additional_special_tokens, so the same code can appear twice.
    # Non-string entries would make Regexp#match? raise, so skip them.
    @language_codes = @special_tokens
      .select { |x| x.is_a?(String) && @language_regex.match?(x) }
      .uniq
    @lang_to_token = ->(x) { x } # Identity function: code == token
  end

  # Validates src/tgt languages, sets forced_bos_token_id and tokenizes.
  def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
    Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
  end
end
|
196
|
+
|
197
|
+
# Tokenizer for M2M100 translation models. Language tokens look like "__en__".
# The public language codes are the bare codes ("en").
class M2M100Tokenizer < PreTrainedTokenizer
  attr_reader :language_regex, :language_codes, :lang_to_token

  def initialize(tokenizer_json, tokenizer_config)
    super(tokenizer_json, tokenizer_config)

    # \A...\z anchors the whole token; ^/$ would also accept multi-line strings.
    @language_regex = /\A__[a-z]{2,3}__\z/
    # Strip the surrounding "__". Ruby's String#slice(start, length) takes a
    # length, not an end index; slice(2, -2) returns nil, so use a range.
    # @special_tokens can list a token twice and may contain non-strings.
    @language_codes = @special_tokens
      .select { |x| x.is_a?(String) && @language_regex.match?(x) }
      .map { |x| x[2...-2] }
      .uniq
    @lang_to_token = ->(x) { "__#{x}__" }
  end

  # Validates src/tgt languages, sets forced_bos_token_id and tokenizes.
  def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
    Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
  end
end
|
214
|
+
|
215
|
+
module Utils
  # Shared translation-input builder for tokenizers exposing `language_codes`
  # (Array), `language_regex` (Regexp) and `lang_to_token` (callable).
  #
  # slf               - the tokenizer
  # raw_inputs        - text(s) passed to the tokenizer's `call`
  # tokenizer_options - keyword options forwarded to `call`
  # generate_kwargs   - hash read for :src_lang / :tgt_lang. It is MUTATED:
  #                     "forced_bos_token_id" is set to the target language token id.
  #
  # Returns whatever `slf.call(raw_inputs, **tokenizer_options)` returns.
  # Raises Error for a missing attribute or an unknown language code.
  #
  # NOTE(review): src_lang is validated but not otherwise used here.
  def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)
    unless slf.respond_to?(:language_codes) && slf.language_codes.is_a?(Array)
      raise Error, "Tokenizer must have `language_codes` attribute set and it should be an array of language ids."
    end
    unless slf.respond_to?(:language_regex) && slf.language_regex.is_a?(Regexp)
      raise Error, "Tokenizer must have `language_regex` attribute set and it should be a regular expression."
    end
    unless slf.respond_to?(:lang_to_token) && slf.lang_to_token.respond_to?(:call)
      raise Error, "Tokenizer must have `lang_to_token` attribute set and it should be a function."
    end

    codes = slf.language_codes
    src_lang = generate_kwargs[:src_lang]
    tgt_lang = generate_kwargs[:tgt_lang]

    unless codes.include?(tgt_lang)
      raise Error, "Target language code #{tgt_lang.inspect} is not valid. Must be one of: #{codes.join(", ")}"
    end

    # The source language is optional, but must be valid when given
    if !src_lang.nil? && !codes.include?(src_lang)
      raise Error, "Source language code #{src_lang.inspect} is not valid. Must be one of: #{codes.join(", ")}"
    end

    # Override the `forced_bos_token_id` so generation starts in the target language
    target_token = slf.lang_to_token.call(tgt_lang)
    generate_kwargs["forced_bos_token_id"] = slf.convert_tokens_to_ids([target_token]).first

    slf.call(raw_inputs, **tokenizer_options)
  end
end
|
246
|
+
|
94
247
|
class AutoTokenizer
|
95
248
|
# Tokenizer class names (as strings) mapped to the Ruby classes implementing them.
TOKENIZER_CLASS_MAPPING = {
  "T5Tokenizer"         => T5Tokenizer,
  "BertTokenizer"       => BertTokenizer,
  "DebertaV2Tokenizer"  => DebertaV2Tokenizer,
  "DistilBertTokenizer" => DistilBertTokenizer,
  "BartTokenizer"       => BartTokenizer,
  "RobertaTokenizer"    => RobertaTokenizer,
  "XLMRobertaTokenizer" => XLMRobertaTokenizer,
  "MPNetTokenizer"      => MPNetTokenizer,
  "CLIPTokenizer"       => CLIPTokenizer,
  "GPT2Tokenizer"       => GPT2Tokenizer,
  "NllbTokenizer"       => NllbTokenizer,
  "M2M100Tokenizer"     => M2M100Tokenizer
}
|
100
262
|
|
101
263
|
def self.from_pretrained(
|
data/lib/informers/utils/generation.rb
CHANGED
@@ -0,0 +1,294 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Generation settings with defaults. +kwargs+ is read with String keys.
# Values are looked up with #[] by String or Symbol key.
class GenerationConfig
  def initialize(kwargs)
    @config = {}

    # Parameters that control the length of the output
    @config["max_length"] = kwargs["max_length"] || 20
    @config["max_new_tokens"] = kwargs["max_new_tokens"]
    @config["min_length"] = kwargs["min_length"] || 0
    @config["min_new_tokens"] = kwargs["min_new_tokens"]
    @config["early_stopping"] = kwargs["early_stopping"] || false
    @config["max_time"] = kwargs["max_time"]

    # Parameters that control the generation strategy used
    @config["do_sample"] = kwargs["do_sample"] || false
    @config["num_beams"] = kwargs["num_beams"] || 1
    @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
    @config["penalty_alpha"] = kwargs["penalty_alpha"]
    # fetch (not ||) so an explicit false is kept
    @config["use_cache"] = kwargs.fetch("use_cache", true)

    # Parameters for manipulation of the model output logits
    @config["temperature"] = kwargs["temperature"] || 1.0
    @config["top_k"] = kwargs["top_k"] || 50
    @config["top_p"] = kwargs["top_p"] || 1.0
    @config["typical_p"] = kwargs["typical_p"] || 1.0
    @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
    @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
    @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
    @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
    @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
    @config["length_penalty"] = kwargs["length_penalty"] || 1.0
    @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
    @config["bad_words_ids"] = kwargs["bad_words_ids"]
    @config["force_words_ids"] = kwargs["force_words_ids"]
    @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
    @config["constraints"] = kwargs["constraints"]
    @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
    @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
    @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
    @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
    @config["suppress_tokens"] = kwargs["suppress_tokens"]
    @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
    @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]

    # Parameters that define the output variables of `generate`
    @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
    @config["output_attentions"] = kwargs["output_attentions"] || false
    @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
    @config["output_scores"] = kwargs["output_scores"] || false
    @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false

    # Special tokens that can be used at generation time
    @config["pad_token_id"] = kwargs["pad_token_id"]
    @config["bos_token_id"] = kwargs["bos_token_id"]
    @config["eos_token_id"] = kwargs["eos_token_id"]

    # Generation parameters exclusive to encoder-decoder models
    @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
    @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]

    # Wild card
    @generation_kwargs = kwargs["generation_kwargs"] || {}
  end

  # Returns the setting for +key+ (String or Symbol).
  def [](key)
    @config[key.to_s]
  end

  # Merges +config+ into the settings and returns the updated settings hash.
  # Keys are stringified so Symbol-keyed overrides are visible through #[]
  # (which stringifies its lookup key); previously they were stored but never read.
  def merge!(config)
    @config.merge!(config.transform_keys(&:to_s))
  end
end
|
74
|
+
|
75
|
+
# Base class for token samplers used during generation. Subclasses implement
# `sample(logits, index)` and return an array of [token_id, score] pairs.
class Sampler
  def initialize(generation_config)
    super()
    @generation_config = generation_config
  end

  # Sample from logits of dims [batch, sequence_length, vocab_size].
  # If index is specified, sample from [batch, index, vocab_size].
  def call(logits, index = -1)
    sample(logits, index)
  end

  # Returns the temperature-scaled logits (flat array of vocab_size values) for
  # sequence position +index+. -1 selects the last position.
  #
  # NOTE(review): rows are taken from the fully flattened logits, so this
  # effectively assumes a batch size of 1 — confirm with callers.
  #
  # Raises ArgumentError when +index+ does not address a full row.
  def get_logits(logits, index)
    vocab_size = Utils.dims(logits)[-1]

    logs = logits.flatten

    if index == -1
      logs = logs.last(vocab_size)
    else
      # Previously `raise Todo`: take the index-th row of vocab_size values
      row = logs[index * vocab_size, vocab_size]
      if row.nil? || row.length != vocab_size
        raise ArgumentError, "index #{index} is out of range for the given logits"
      end
      logs = row
    end

    # add temperature
    temperature = @generation_config["temperature"]
    logs = logs.map { |x| x / temperature } if temperature > 0
    logs
  end

  # Picks a sampler class from the generation config.
  # NOTE(review): MultinomialSampler is not defined alongside the other
  # samplers here; confirm it exists before relying on do_sample.
  def self.get_sampler(generation_config)
    if generation_config[:do_sample]
      MultinomialSampler.new(generation_config)
    elsif generation_config[:num_beams] > 1
      BeamSearchSampler.new(generation_config)
    else
      if generation_config[:num_return_sequences] > 1
        raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}."
      end
      GreedySampler.new(generation_config)
    end
  end
end
|
118
|
+
|
119
|
+
class GreedySampler < Sampler
|
120
|
+
def sample(logits, index = -1)
|
121
|
+
# NOTE: no need to do log_softmax here since we only take the maximum
|
122
|
+
logs = get_logits(logits, index)
|
123
|
+
argmax = Utils.max(logs)[1]
|
124
|
+
|
125
|
+
# Note: score is meaningless in this context, since we are performing
|
126
|
+
# greedy search (p = 1 => log(p) = 0)
|
127
|
+
[
|
128
|
+
[argmax, 0]
|
129
|
+
]
|
130
|
+
end
|
131
|
+
end
|
132
|
+
|
133
|
+
# Beam-search candidate selection: returns up to num_beams [token_id, log_prob]
# pairs drawn from the top-k tokens at the given position.
class BeamSearchSampler < Sampler
  def sample(logits, index = -1)
    k = Utils.dims(logits)[-1] # defaults to vocab size
    top_k = @generation_config["top_k"]
    k = [top_k, k].min if top_k > 0

    # Get logits of nth token
    logs = get_logits(logits, index)

    # Get top k tokens
    top_logits = Utils.get_top_items(logs, k)

    # Compute softmax over logits
    probabilities = Utils.softmax(top_logits.map { |x| x[1] })

    # When k < num_beams there are fewer candidates than beams. Indexing past
    # the end would give nil and raise NoMethodError, so return only what exists.
    num_candidates = [@generation_config["num_beams"], top_logits.length].min
    Array.new(num_candidates) do |i|
      [
        top_logits[i][0],
        Math.log(probabilities[i])
      ]
    end
  end
end
|
157
|
+
|
158
|
+
# Ordered collection of logits processors, each callable as
# processor.call(input_ids, logits).
class LogitsProcessorList
  def initialize
    super
    @processors = []
  end

  # Appends one processor; returns the underlying array.
  def push(item)
    @processors.push(item)
  end

  # Appends several processors; returns the underlying array.
  def concat(items)
    @processors.concat(items)
  end

  # Applies every processor, in order, to each row of +batched_logits+.
  # Unlike a vectorized implementation, each row is handled separately, and the
  # processors are expected to modify the row in place. Returns +batched_logits+.
  def call(input_ids, batched_logits)
    batched_logits.each do |row|
      @processors.each { |processor| processor.call(input_ids, row) }
    end
  end

  # Implicit Array conversion, e.g. so the list can be splatted or concatenated.
  def to_ary
    @processors
  end
end
|
187
|
+
|
188
|
+
# Common base class for the logits processors. It defines no behaviour of its
# own; subclasses provide `call(input_ids, logits)`.
class LogitsProcessor; end
|
190
|
+
|
191
|
+
# Bans any token that would complete an n-gram (n = no_repeat_ngram_size)
# already present in the sequence. A size below 1 disables the processor.
class NoRepeatNGramLogitsProcessor < LogitsProcessor
  def initialize(no_repeat_ngram_size)
    super()
    @no_repeat_ngram_size = no_repeat_ngram_size
  end

  # Maps each (n-1)-token prefix in +prev_input_ids+ (JSON-encoded as the hash
  # key) to the tokens that followed it. Returns {} when the size is below 1.
  def get_ngrams(prev_input_ids)
    generated_ngram = {}
    # each_cons(0) would raise; a size of 0 previously produced nil prefixes
    return generated_ngram if @no_repeat_ngram_size < 1

    prev_input_ids.each_cons(@no_repeat_ngram_size) do |ngram|
      prev_ngram_key = JSON.generate(ngram[0...-1])
      (generated_ngram[prev_ngram_key] ||= []) << ngram[-1]
    end
    generated_ngram
  end

  # Returns the tokens that followed the current trailing (n-1)-token prefix.
  def get_generated_ngrams(banned_ngrams, prev_input_ids)
    ngram_idx = prev_input_ids.slice(prev_input_ids.length + 1 - @no_repeat_ngram_size, prev_input_ids.length)
    banned_ngrams[JSON.generate(ngram_idx)] || []
  end

  # Returns the token ids banned for the next position.
  def calc_banned_ngram_tokens(prev_input_ids)
    # Nothing to ban when disabled, or before n-1 tokens exist to form a prefix
    return [] if @no_repeat_ngram_size < 1 || prev_input_ids.length + 1 < @no_repeat_ngram_size

    generated_ngrams = get_ngrams(prev_input_ids)
    get_generated_ngrams(generated_ngrams, prev_input_ids)
  end

  # Sets banned token logits to -Infinity in place; returns +logits+.
  def call(input_ids, logits)
    calc_banned_ngram_tokens(input_ids).each do |token|
      logits[token] = -Float::INFINITY
    end
    logits
  end
end
|
249
|
+
|
250
|
+
# Prevents end-of-sequence tokens until the sequence reaches +min_length+.
class MinLengthLogitsProcessor < LogitsProcessor
  # min_length   - minimum number of ids before EOS is allowed
  # eos_token_id - a single id, an array of ids, or nil (no EOS to suppress)
  def initialize(min_length, eos_token_id)
    super()
    @min_length = min_length
    # Array() turns nil into [] instead of [nil]; [nil] would make
    # `logits[nil] = ...` raise a TypeError.
    @eos_token_id = Array(eos_token_id)
  end

  # Sets EOS logits to -Infinity in place while too short; returns +logits+.
  def call(input_ids, logits)
    if input_ids.length < @min_length
      @eos_token_id.each do |eos_token|
        logits[eos_token] = -Float::INFINITY
      end
    end

    logits
  end
end
|
267
|
+
|
268
|
+
# Forces +bos_token_id+ to be the first generated token. The check is
# input_ids.length == 1, i.e. only one id is present so far.
class ForcedBOSTokenLogitsProcessor < LogitsProcessor
  def initialize(bos_token_id)
    super()
    @bos_token_id = bos_token_id
  end

  # Modifies +logits+ in place and returns it.
  def call(input_ids, logits)
    return logits unless input_ids.length == 1

    # Make every other token impossible
    logits.fill(-Float::INFINITY)
    logits[@bos_token_id] = 0
    logits
  end
end
|
282
|
+
|
283
|
+
# Forces the end-of-sequence token(s) when the sequence is one token short of
# +max_length+, following the Hugging Face ForcedEOSTokenLogitsProcessor.
# (The previous `call` was an empty stub that returned nil.)
class ForcedEOSTokenLogitsProcessor < LogitsProcessor
  # max_length          - maximum sequence length (nil disables forcing)
  # forced_eos_token_id - a single id or an array of ids
  def initialize(max_length, forced_eos_token_id)
    super()
    @max_length = max_length
    @forced_eos_token_id = forced_eos_token_id
  end

  # Modifies +logits+ in place and returns it, like the other processors.
  def call(input_ids, logits)
    if @max_length && input_ids.length == @max_length - 1
      # Only the forced EOS token(s) remain possible for the final position
      logits.fill(-Float::INFINITY)
      Array(@forced_eos_token_id).each { |token_id| logits[token_id] = 0 }
    end
    logits
  end
end
|
293
|
+
end
|
294
|
+
end
|
data/lib/informers/utils/image.rb
CHANGED
@@ -0,0 +1,116 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
# Thin wrapper around a Vips::Image that tracks its size and band count and
# provides the resize/crop/colour helpers used for preprocessing.
class RawImage
  # PIL-style integer resampling filters mapped to method names.
  RESAMPLING_MAPPING = {
    0 => "nearest",
    1 => "lanczos",
    2 => "bilinear",
    3 => "bicubic",
    4 => "box",
    5 => "hamming"
  }.freeze

  attr_reader :image, :width, :height, :channels

  def initialize(image)
    @image = image
    @width = image.width
    @height = image.height
    @channels = image.bands
  end

  # Pixel memory unpacked as unsigned bytes (bands interleaved per pixel).
  def data
    @image.write_to_memory.unpack("C*")
  end

  # [width, height]
  def size
    [@width, @height]
  end

  # Resizes to +width+ x +height+. +resample+ is either a PIL filter number
  # (see RESAMPLING_MAPPING) or a method name given as a String or Symbol.
  # "nearest", "bilinear" and "bicubic" map onto libvips interpolators used by
  # an affine transform. Other methods raise Todo.
  def resize(width, height, resample: 2)
    resample_method = (RESAMPLING_MAPPING[resample] || resample).to_s

    case resample_method
    when "nearest", "bilinear", "bicubic"
      img =
        @image.affine(
          [width / @width.to_f, 0, 0, height / @height.to_f],
          interpolate: Vips::Interpolate.new(resample_method.to_sym)
        )
    else
      raise Todo
    end

    RawImage.new(img)
  end

  # Crops a +crop_width+ x +crop_height+ region centred in the image.
  # Padding (crop larger than the image) is not implemented and raises Todo.
  def center_crop(crop_width, crop_height)
    # If the image is already the desired size, return it
    if @width == crop_width && @height == crop_height
      return self
    end

    # Determine bounds of the image in the new canvas
    width_offset = (@width - crop_width) / 2.0
    height_offset = (@height - crop_height) / 2.0

    if width_offset >= 0 && height_offset >= 0
      # Cropped image lies entirely within the original image
      img = @image.crop(
        width_offset.floor,
        height_offset.floor,
        crop_width,
        crop_height
      )
    elsif width_offset <= 0 && height_offset <= 0
      raise Todo
    else
      raise Todo
    end

    RawImage.new(img)
  end

  # Returns a 3-band version of the image (self when it already has 3 bands).
  # RGBA drops the alpha band. Grayscale (with or without alpha) replicates the
  # first band three times. Other band counts raise Todo.
  def rgb
    return self if @channels == 3

    img =
      case @channels
      when 4
        @image.extract_band(0, n: 3)
      when 1, 2
        gray = @channels == 1 ? @image : @image.extract_band(0)
        gray.bandjoin([gray, gray])
      else
        raise Todo
      end

    RawImage.new(img)
  end

  def save(path)
    @image.write_to_file(path)
  end

  # Builds a RawImage from a RawImage (returned as-is), a URI (downloaded via
  # open-uri), or a String file path.
  def self.read(input)
    if input.is_a?(RawImage)
      input
    elsif input.is_a?(URI)
      require "open-uri"

      RawImage.new(Vips::Image.new_from_buffer(input.read, ""))
    elsif input.is_a?(String)
      RawImage.new(Vips::Image.new_from_file(input))
    else
      raise ArgumentError, "Unsupported input type: #{input.class.name}"
    end
  end

  # Builds an image from a nested [channels][height][width] array of byte
  # values, interleaving the bands per pixel.
  def self.from_array(input)
    c, h, w = Utils.dims(input)
    pixel_data = Array.new(w * h * c)

    input.each_with_index do |cv, ci|
      cv.each_with_index do |hv, hi|
        hv.each_with_index do |v, wi|
          pixel_data[(hi * w * c) + (wi * c) + ci] = v
        end
      end
    end

    RawImage.new(Vips::Image.new_from_memory_copy(pixel_data.pack("C*"), w, h, c, :uchar))
  end
end
|
115
|
+
end
|
116
|
+
end
|
data/lib/informers/utils/math.rb
CHANGED
@@ -1,5 +1,75 @@
|
|
1
1
|
module Informers
|
2
2
|
module Utils
|
3
|
+
# Resizes a planar image with bilinear interpolation.
#
# input     - flat array laid out as [channel][row][col]
# in_shape  - [channels, height, width] of +input+
# out_shape - [height, width] of the result
# mode, align_corners - accepted but currently ignored (TODO). Sampling always
#   maps output pixel centres to half-pixel source coordinates.
#
# Returns a flat array laid out as [channel][row][col] at the output size.
# Source coordinates outside the image are clamped to the border pixels.
def self.interpolate_data(input, in_shape, out_shape, mode = "bilinear", align_corners = false)
  channels, src_height, src_width = in_shape
  dst_height, dst_width = out_shape

  scale_x = dst_width / src_width.to_f
  scale_y = dst_height / src_height.to_f

  result = Array.new(dst_height * dst_width * channels)

  # Size of one channel plane in the source and destination
  src_plane = src_height * src_width
  dst_plane = dst_height * dst_width

  (0...dst_height).each do |row|
    # Map the output row centre back into source coordinates
    src_y = (row + 0.5) / scale_y - 0.5
    top = src_y.floor
    bottom = [top + 1, src_height - 1].min
    top = [top, 0].max
    dy = src_y - top
    top_base = top * src_width
    bottom_base = bottom * src_width

    (0...dst_width).each do |col|
      src_x = (col + 0.5) / scale_x - 0.5
      left = src_x.floor
      right = [left + 1, src_width - 1].min
      left = [left, 0].max
      dx = src_x - left

      # Weights of the four neighbouring source pixels
      w_tl = (1 - dx) * (1 - dy)
      w_tr = dx * (1 - dy)
      w_bl = (1 - dx) * dy
      w_br = dx * dy

      i_tl = top_base + left
      i_tr = top_base + right
      i_bl = bottom_base + left
      i_br = bottom_base + right
      dst_index = row * dst_width + col

      channels.times do |ch|
        base = ch * src_plane
        result[ch * dst_plane + dst_index] =
          w_tl * input[base + i_tl] +
          w_tr * input[base + i_tr] +
          w_bl * input[base + i_bl] +
          w_br * input[base + i_br]
      end
    end
  end

  result
end
|
72
|
+
|
3
73
|
def self.softmax(arr)
|
4
74
|
# Compute the maximum value in the array
|
5
75
|
max_val = arr.max
|
@@ -17,6 +87,9 @@ module Informers
|
|
17
87
|
end
|
18
88
|
|
19
89
|
# Element-wise logistic function 1 / (1 + e^-x). When the first element is an
# Array, each element is treated as a row and processed recursively.
def self.sigmoid(arr)
  return arr.map { |row| sigmoid(row) } if arr[0].is_a?(Array)

  arr.map do |x|
    denominator = 1 + Math.exp(-x)
    1 / denominator
  end
end
|
22
95
|
|