informers 1.0.2 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +213 -19
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -14
- data/lib/informers/models.rb +1027 -13
- data/lib/informers/pipelines.rb +781 -14
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +166 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
data/lib/informers/tokenizers.rb
CHANGED
@@ -1,16 +1,65 @@
|
|
1
1
|
module Informers
|
2
2
|
class PreTrainedTokenizer
|
3
|
-
attr_reader :sep_token_id
|
3
|
+
attr_reader :mask_token, :mask_token_id, :sep_token_id
|
4
4
|
|
5
5
|
# Loads the backing Tokenizers::Tokenizer from +tokenizer_json+ and derives
# token metadata from it and from +tokenizer_config+ (presumably the Hash
# parsed from tokenizer_config.json — confirm against the loader).
def initialize(tokenizer_json, tokenizer_config)
  super()

  @tokenizer_config = tokenizer_config
  @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

  # Collect every added token; the ones flagged special also contribute
  # their text (@special_tokens) and their id (@all_special_ids).
  @special_tokens = []
  @all_special_ids = []
  @added_tokens = []
  @tokenizer.added_tokens_decoder.each do |token_id, added_token|
    @added_tokens << added_token
    next unless added_token.special

    @special_tokens << added_token.content
    @all_special_ids << token_id
  end

  # Config-level additional_special_tokens count as special tokens too.
  @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
  @special_tokens.concat(@additional_special_tokens)

  # The id ivars are only assigned when the corresponding token is configured.
  @mask_token = get_token("mask_token")
  @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token

  @sep_token = get_token("sep_token")
  @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

  # Cap oversized sentinel values at 2**63 (needed for donut-base-finetuned-docvqa).
  limit = tokenizer_config["model_max_length"]
  limit = 1 << 63 if limit && limit > (1 << 63)
  @model_max_length = limit
end
|
43
|
+
|
44
|
+
# Returns the token string stored under the first of +keys+ whose config
# value is truthy. Hash values must be serialized AddedToken objects
# ("__type" == "AddedToken"); their "content" is returned, and any other
# Hash raises Error. Returns nil when no key has a truthy value.
def get_token(*keys)
  keys.each do |key|
    entry = @tokenizer_config[key]
    next unless entry
    return entry unless entry.is_a?(Hash)
    raise Error, "Unknown token: #{entry}" unless entry["__type"] == "AddedToken"

    return entry["content"]
  end

  nil
end
|
15
64
|
|
16
65
|
def call(
|
@@ -76,6 +125,22 @@ module Informers
|
|
76
125
|
# Joins token strings back into text using the backing tokenizer's decoder.
def convert_tokens_to_string(tokens)
  decoder = @tokenizer.decoder
  decoder.decode(tokens)
end

# Looks up the vocabulary id of each token string, preserving order.
def convert_tokens_to_ids(tokens)
  tokens.map do |token|
    @tokenizer.token_to_id(token)
  end
end

# Returns the token string for a single vocabulary id.
def id_to_token(id)
  backend = @tokenizer
  backend.id_to_token(id)
end

# Decodes a batch of id sequences; keyword options are forwarded unchanged.
def batch_decode(batch, **decode_args)
  backend = @tokenizer
  backend.decode_batch(batch, **decode_args)
end

# Enables padding on the backing tokenizer on the given side (e.g. "left").
def padding_side=(side)
  backend = @tokenizer
  backend.enable_padding(direction: side)
end
|
79
144
|
end
|
80
145
|
|
81
146
|
class BertTokenizer < PreTrainedTokenizer
|
@@ -91,11 +156,108 @@ module Informers
|
|
91
156
|
# Model-specific tokenizer classes that need no behavior beyond
# PreTrainedTokenizer; each name matches a tokenizer class name that can
# appear in a model's tokenizer configuration.
class DistilBertTokenizer < PreTrainedTokenizer; end

class T5Tokenizer < PreTrainedTokenizer; end

class GPT2Tokenizer < PreTrainedTokenizer
  # _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
end

class BartTokenizer < PreTrainedTokenizer; end

class RobertaTokenizer < PreTrainedTokenizer; end

class XLMRobertaTokenizer < PreTrainedTokenizer; end

class MPNetTokenizer < PreTrainedTokenizer; end

class CLIPTokenizer < PreTrainedTokenizer; end
|
180
|
+
|
181
|
+
class NllbTokenizer < PreTrainedTokenizer
  attr_reader :language_regex, :language_codes, :lang_to_token

  # Builds the tokenizer, then treats every special token shaped like an NLLB
  # language code (three lowercase letters, "_", a capitalized four-letter
  # script, e.g. "eng_Latn") as a supported translation language.
  def initialize(tokenizer_json, tokenizer_config)
    super

    @language_regex = /^[a-z]{3}_[A-Z][a-z]{3}$/
    @language_codes = @special_tokens.select { |token| @language_regex.match?(token) }
    # Language codes are already token strings, so the mapping is the identity.
    @lang_to_token = ->(code) { code }
  end

  # Validates the requested languages and tokenizes +raw_inputs+ through the
  # shared Utils._build_translation_inputs helper.
  def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
    Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
  end
end
|
196
|
+
|
197
|
+
class M2M100Tokenizer < PreTrainedTokenizer
  attr_reader :language_regex, :language_codes, :lang_to_token

  # Builds the tokenizer and derives the translation languages from special
  # tokens shaped like "__en__" (language code "en").
  def initialize(tokenizer_json, tokenizer_config)
    super(tokenizer_json, tokenizer_config)

    @language_regex = /^__[a-z]{2,3}__$/
    # String#slice(start, length) takes a length, not an end index, so a
    # negative second argument returns nil. An exclusive range strips the two
    # leading and two trailing underscores instead ("__en__" -> "en").
    @language_codes = @special_tokens
      .filter { |x| @language_regex.match?(x) }
      .map { |x| x[2...-2] }
    # Inverse of the stripping above: "en" -> "__en__".
    @lang_to_token = ->(x) { "__#{x}__" }
  end

  # Validates the requested languages and tokenizes +raw_inputs+ through the
  # shared Utils._build_translation_inputs helper.
  def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
    Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
  end
end
|
214
|
+
|
215
|
+
module Utils
  # Shared translation-input builder for tokenizers that expose
  # language_codes (Array), language_regex (Regexp) and lang_to_token
  # (callable).
  #
  # Validates generate_kwargs[:tgt_lang] (required) and
  # generate_kwargs[:src_lang] (optional) against slf.language_codes. It then
  # stores the id of the target-language token in
  # generate_kwargs["forced_bos_token_id"] (mutating the caller's hash) and
  # returns the result of calling the tokenizer on raw_inputs with
  # tokenizer_options.
  #
  # NOTE(review): src_lang is only validated here; nothing applies it to the
  # tokenized output.
  #
  # Raises Error when a required attribute is missing or a language is invalid.
  def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)
    unless slf.respond_to?(:language_codes) && slf.language_codes.is_a?(Array)
      raise Error, "Tokenizer must have `language_codes` attribute set and it should be an array of language ids."
    end
    unless slf.respond_to?(:language_regex) && slf.language_regex.is_a?(Regexp)
      raise Error, "Tokenizer must have `language_regex` attribute set and it should be a regular expression."
    end
    unless slf.respond_to?(:lang_to_token) && slf.lang_to_token.respond_to?(:call)
      raise Error, "Tokenizer must have `lang_to_token` attribute set and it should be a function."
    end

    codes = slf.language_codes
    source_lang = generate_kwargs[:src_lang]
    target_lang = generate_kwargs[:tgt_lang]

    unless codes.include?(target_lang)
      raise Error, "Target language code #{target_lang.inspect} is not valid. Must be one of: #{codes.join(", ")}"
    end
    if !source_lang.nil? && !codes.include?(source_lang)
      raise Error, "Source language code #{source_lang.inspect} is not valid. Must be one of: #{codes.join(", ")}"
    end

    # Force generation to start with the target-language token.
    target_token = slf.lang_to_token.call(target_lang)
    generate_kwargs["forced_bos_token_id"] = slf.convert_tokens_to_ids([target_token]).first

    slf.call(raw_inputs, **tokenizer_options)
  end
end
|
246
|
+
|
94
247
|
class AutoTokenizer
|
95
248
|
TOKENIZER_CLASS_MAPPING = {
|
249
|
+
"T5Tokenizer" => T5Tokenizer,
|
96
250
|
"BertTokenizer" => BertTokenizer,
|
97
251
|
"DebertaV2Tokenizer" => DebertaV2Tokenizer,
|
98
|
-
"DistilBertTokenizer" => DistilBertTokenizer
|
252
|
+
"DistilBertTokenizer" => DistilBertTokenizer,
|
253
|
+
"BartTokenizer" => BartTokenizer,
|
254
|
+
"RobertaTokenizer" => RobertaTokenizer,
|
255
|
+
"XLMRobertaTokenizer" => XLMRobertaTokenizer,
|
256
|
+
"MPNetTokenizer" => MPNetTokenizer,
|
257
|
+
"CLIPTokenizer" => CLIPTokenizer,
|
258
|
+
"GPT2Tokenizer" => GPT2Tokenizer,
|
259
|
+
"NllbTokenizer" => NllbTokenizer,
|
260
|
+
"M2M100Tokenizer" => M2M100Tokenizer
|
99
261
|
}
|
100
262
|
|
101
263
|
def self.from_pretrained(
|
data/lib/informers/utils/generation.rb
CHANGED
@@ -0,0 +1,294 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
class GenerationConfig
  # Generation parameters with their defaults, stored under String keys.
  #
  # kwargs - Hash of overrides with String or Symbol keys (nil is treated as
  #          an empty Hash). Defaults are applied with `||`, so a nil/false
  #          override falls back to the default. The exception is "use_cache",
  #          which uses fetch so an explicit false is kept.
  def initialize(kwargs)
    # Normalize keys so `max_length: 5` and `"max_length" => 5` behave the
    # same; every lookup below uses String keys.
    kwargs ||= {}
    kwargs = kwargs.transform_keys(&:to_s) if kwargs.is_a?(Hash)

    @config = {}

    # Parameters that control the length of the output
    @config["max_length"] = kwargs["max_length"] || 20
    @config["max_new_tokens"] = kwargs["max_new_tokens"]
    @config["min_length"] = kwargs["min_length"] || 0
    @config["min_new_tokens"] = kwargs["min_new_tokens"]
    @config["early_stopping"] = kwargs["early_stopping"] || false
    @config["max_time"] = kwargs["max_time"]

    # Parameters that control the generation strategy used
    @config["do_sample"] = kwargs["do_sample"] || false
    @config["num_beams"] = kwargs["num_beams"] || 1
    @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
    @config["penalty_alpha"] = kwargs["penalty_alpha"]
    @config["use_cache"] = kwargs.fetch("use_cache", true)

    # Parameters for manipulation of the model output logits
    @config["temperature"] = kwargs["temperature"] || 1.0
    @config["top_k"] = kwargs["top_k"] || 50
    @config["top_p"] = kwargs["top_p"] || 1.0
    @config["typical_p"] = kwargs["typical_p"] || 1.0
    @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
    @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
    @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
    @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
    @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
    @config["length_penalty"] = kwargs["length_penalty"] || 1.0
    @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
    @config["bad_words_ids"] = kwargs["bad_words_ids"]
    @config["force_words_ids"] = kwargs["force_words_ids"]
    @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
    @config["constraints"] = kwargs["constraints"]
    @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
    @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
    @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
    @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
    @config["suppress_tokens"] = kwargs["suppress_tokens"]
    @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
    @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]

    # Parameters that define the output variables of `generate`
    @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
    @config["output_attentions"] = kwargs["output_attentions"] || false
    @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
    @config["output_scores"] = kwargs["output_scores"] || false
    @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false

    # Special tokens that can be used at generation time
    @config["pad_token_id"] = kwargs["pad_token_id"]
    @config["bos_token_id"] = kwargs["bos_token_id"]
    @config["eos_token_id"] = kwargs["eos_token_id"]

    # Generation parameters exclusive to encoder-decoder models
    @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
    @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]

    # Wild card
    @generation_kwargs = kwargs["generation_kwargs"] || {}
  end

  # Looks up a parameter by String or Symbol name.
  def [](key)
    @config[key.to_s]
  end

  # Merges +config+ into the parameters. Hash keys are normalized to Strings
  # so that merged values are visible to #[]. Returns the underlying Hash.
  def merge!(config)
    config = config.transform_keys(&:to_s) if config.is_a?(Hash)
    @config.merge!(config)
  end
end
|
74
|
+
|
75
|
+
class Sampler
  # Base class for token samplers; subclasses implement #sample(logits, index).
  def initialize(generation_config)
    super()
    @generation_config = generation_config
  end

  # Sample from logits, of dims [batch, sequence_length, vocab_size].
  # If index is specified, sample from [batch, index, vocab_size].
  def call(logits, index = -1)
    sample(logits, index)
  end

  # Returns the vocab_size logits at sequence position +index+, divided by the
  # temperature when the temperature is positive. Negative indexes count from
  # the end; -1 is the last position.
  #
  # NOTE(review): rows are taken from the flattened logits, so with batch > 1
  # a non-negative index addresses the first batch item. Confirm callers only
  # pass a batch of one.
  #
  # Raises IndexError when +index+ is outside the sequence.
  def get_logits(logits, index)
    vocab_size = Utils.dims(logits)[-1]
    flat = logits.flatten

    # Array#[](start, length) accepts negative starts, so one expression
    # covers both the default -1 and explicit positions.
    logs = flat[index * vocab_size, vocab_size]
    if logs.nil? || logs.length != vocab_size
      raise IndexError, "index #{index} is out of range for logits"
    end

    # add temperature
    temperature = @generation_config["temperature"]
    logs = logs.map { |x| x / temperature } if temperature > 0
    logs
  end

  # Chooses a sampler from the generation config: multinomial when
  # do_sample is set, beam search when num_beams > 1, otherwise greedy.
  #
  # NOTE(review): MultinomialSampler is not defined alongside the other
  # samplers here; confirm it exists before relying on do_sample.
  def self.get_sampler(generation_config)
    if generation_config[:do_sample]
      MultinomialSampler.new(generation_config)
    elsif generation_config[:num_beams] > 1
      BeamSearchSampler.new(generation_config)
    else
      if generation_config[:num_return_sequences] > 1
        raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}."
      end
      GreedySampler.new(generation_config)
    end
  end
end
|
118
|
+
|
119
|
+
class GreedySampler < Sampler
  # Greedy decoding: returns one candidate, [[token_id, 0]], for the highest
  # logit at +index+.
  def sample(logits, index = -1)
    # log_softmax is unnecessary: it does not change which logit is largest.
    candidate_logits = get_logits(logits, index)
    # Utils.max presumably returns [value, position]; element 1 is the token id.
    best_token = Utils.max(candidate_logits)[1]

    # The score is a constant 0: greedy search treats its choice as p = 1,
    # and log(1) = 0.
    [[best_token, 0]]
  end
end
|
132
|
+
|
133
|
+
class BeamSearchSampler < Sampler
  # Returns up to num_beams [token_id, log_probability] candidates for
  # position +index+. Candidates come from the top-k logits (k = top_k when
  # positive, otherwise the whole vocabulary), with probabilities from a
  # softmax over those k values.
  def sample(logits, index = -1)
    k = Utils.dims(logits)[-1] # defaults to vocab size
    top_k = @generation_config["top_k"]
    k = [top_k, k].min if top_k > 0

    # Get logits of nth token
    logs = get_logits(logits, index)

    # Get top k tokens (entries are indexed as [token_id, logit] below)
    top_logits = Utils.get_top_items(logs, k)

    # Compute softmax over logits
    probabilities = Utils.softmax(top_logits.map { |x| x[1] })

    # With top_k < num_beams there are fewer candidates than beams; indexing
    # past the end of top_logits would call [] on nil.
    beam_count = [@generation_config["num_beams"], top_logits.length].min
    Array.new(beam_count) do |i|
      [top_logits[i][0], Math.log(probabilities[i])]
    end
  end
end
|
157
|
+
|
158
|
+
class LogitsProcessorList
  # Ordered collection of logits processors (anything responding to
  # call(input_ids, logits)).
  def initialize
    super()
    @processors = []
  end

  # Appends one processor; returns the underlying processor Array.
  def push(item)
    @processors.push(item)
  end

  # Appends several processors; returns the underlying processor Array.
  def concat(items)
    @processors.concat(items)
  end

  # Runs every processor over every row of +batched_logits+. Plain Ruby
  # arrays have no vectorized operations, so each row is handed to each
  # processor separately; processors are expected to modify the row in place.
  # Returns +batched_logits+.
  def call(input_ids, batched_logits)
    batched_logits.each do |row|
      @processors.each { |processor| processor.call(input_ids, row) }
    end
  end

  # Exposes the processors as an Array (enables implicit Array conversion).
  def to_ary
    @processors
  end
end
|
187
|
+
|
188
|
+
class LogitsProcessor
  # Base class for logits processors. Subclasses implement
  # call(input_ids, logits), receive one row of logits, modify it in place,
  # and return it.
end

class NoRepeatNGramLogitsProcessor < LogitsProcessor
  # Bans any token that would complete an n-gram (n = no_repeat_ngram_size)
  # already present in input_ids. A non-positive size disables the processor.
  def initialize(no_repeat_ngram_size)
    super()
    @no_repeat_ngram_size = no_repeat_ngram_size
  end

  # Maps each (n-1)-token prefix, JSON-encoded, to the tokens that followed
  # that prefix in +prev_input_ids+.
  def get_ngrams(prev_input_ids)
    generated_ngram = {}
    prev_input_ids.each_cons(@no_repeat_ngram_size) do |ngram|
      prefix_key = JSON.generate(ngram[0...-1])
      (generated_ngram[prefix_key] ||= []) << ngram[-1]
    end
    generated_ngram
  end

  # Returns the tokens that followed the current (n-1)-token suffix of
  # +prev_input_ids+, i.e. the tokens that would repeat an n-gram.
  def get_generated_ngrams(banned_ngrams, prev_input_ids)
    start = prev_input_ids.length + 1 - @no_repeat_ngram_size
    ngram_idx = prev_input_ids[start..-1]
    banned_ngrams[JSON.generate(ngram_idx)] || []
  end

  # Returns the token ids banned at the next step.
  def calc_banned_ngram_tokens(prev_input_ids)
    # A non-positive size means "disabled"; previously it built nil keys and crashed.
    return [] if @no_repeat_ngram_size <= 0
    # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
    return [] if prev_input_ids.length + 1 < @no_repeat_ngram_size

    get_generated_ngrams(get_ngrams(prev_input_ids), prev_input_ids)
  end

  def call(input_ids, logits)
    calc_banned_ngram_tokens(input_ids).each do |token|
      logits[token] = -Float::INFINITY
    end
    logits
  end
end

class MinLengthLogitsProcessor < LogitsProcessor
  # Sets every EOS logit to -Infinity while input_ids is shorter than min_length.
  def initialize(min_length, eos_token_id)
    super()
    @min_length = min_length
    # Array() accepts a single id, a list of ids, or nil (no EOS => no-op);
    # a nil EOS previously led to logits[nil] raising TypeError.
    @eos_token_id = Array(eos_token_id)
  end

  def call(input_ids, logits)
    if input_ids.length < @min_length
      @eos_token_id.each do |eos_token|
        logits[eos_token] = -Float::INFINITY
      end
    end

    logits
  end
end

class ForcedBOSTokenLogitsProcessor < LogitsProcessor
  # On the first step (input_ids has exactly one token), masks every logit
  # except bos_token_id, which is set to 0.
  def initialize(bos_token_id)
    super()
    @bos_token_id = bos_token_id
  end

  def call(input_ids, logits)
    if input_ids.length == 1
      logits.map! { -Float::INFINITY }
      logits[@bos_token_id] = 0
    end
    logits
  end
end

class ForcedEOSTokenLogitsProcessor < LogitsProcessor
  # Once input_ids holds max_length - 1 tokens, masks every logit except the
  # forced EOS id(s), which are set to 0, so the sequence ends at max_length.
  # forced_eos_token_id may be a single id or an Array of ids.
  def initialize(max_length, forced_eos_token_id)
    super()
    @max_length = max_length
    @forced_eos_token_id = forced_eos_token_id
  end

  def call(input_ids, logits)
    eos_ids = Array(@forced_eos_token_id)
    # Without an EOS id or a max_length there is nothing to force.
    if !eos_ids.empty? && @max_length && input_ids.length == @max_length - 1
      logits.map! { -Float::INFINITY }
      eos_ids.each { |eos_id| logits[eos_id] = 0 }
    end
    logits
  end
end
|
293
|
+
end
|
294
|
+
end
|
data/lib/informers/utils/image.rb
CHANGED
@@ -0,0 +1,116 @@
|
|
1
|
+
module Informers
|
2
|
+
module Utils
|
3
|
+
class RawImage
  # Integer resampling codes -> method names (the numbering appears to follow
  # PIL's resampling constants — confirm if adding more methods).
  RESAMPLING_MAPPING = {
    0 => "nearest",
    1 => "lanczos",
    2 => "bilinear",
    3 => "bicubic",
    4 => "box",
    5 => "hamming",
  }

  attr_reader :image, :width, :height, :channels

  # Wraps an image object (a Vips::Image elsewhere in this file) and caches
  # its width, height and band count.
  def initialize(image)
    @image = image
    @width, @height, @channels = image.width, image.height, image.bands
  end

  # Pixel values as an Array of unsigned bytes (the in-memory buffer unpacked with "C*").
  def data
    raw = @image.write_to_memory
    raw.unpack("C*")
  end

  # [width, height]
  def size
    [@width, @height]
  end

  # Returns a new RawImage scaled to width x height with an affine transform.
  # +resample+ is an integer code from RESAMPLING_MAPPING or a method name;
  # only "bilinear" and "bicubic" are implemented, anything else raises Todo.
  def resize(width, height, resample: 2)
    kind = RESAMPLING_MAPPING[resample] || resample
    raise Todo unless ["bilinear", "bicubic"].include?(kind)

    matrix = [width / @width.to_f, 0, 0, height / @height.to_f]
    interpolator = Vips::Interpolate.new(kind.to_sym)
    RawImage.new(@image.affine(matrix, interpolate: interpolator))
  end

  # Returns a centered crop_width x crop_height region (self when the size
  # already matches). A crop larger than the image on either axis would need
  # padding, which is not implemented (raises Todo).
  def center_crop(crop_width, crop_height)
    return self if @width == crop_width && @height == crop_height

    left = (@width - crop_width) / 2.0
    top = (@height - crop_height) / 2.0
    raise Todo if left < 0 || top < 0

    RawImage.new(@image.crop(left.floor, top.floor, crop_width, crop_height))
  end

  # Returns self for 3-channel images; conversion from other channel counts
  # is not implemented (raises Todo).
  def rgb
    raise Todo unless @channels == 3

    self
  end

  # Writes the image to +path+.
  def save(path)
    @image.write_to_file(path)
  end

  # Builds a RawImage from an existing RawImage (returned as-is), a URI
  # (fetched via open-uri, loaded lazily), or a file path String.
  def self.read(input)
    case input
    when RawImage
      input
    when URI
      require "open-uri"

      RawImage.new(Vips::Image.new_from_buffer(input.read, ""))
    when String
      RawImage.new(Vips::Image.new_from_file(input))
    else
      raise ArgumentError, "Unsupported input type: #{input.class.name}"
    end
  end

  # Builds an 8-bit RawImage from a nested [channel][row][column] Array,
  # interleaving the channels into a row-major pixel buffer.
  def self.from_array(input)
    c, h, w = Utils.dims(input)
    pixels = Array.new(w * h * c)

    input.each_with_index do |plane, ci|
      plane.each_with_index do |row, hi|
        row.each_with_index do |value, wi|
          pixels[(((hi * w) + wi) * c) + ci] = value
        end
      end
    end

    RawImage.new(Vips::Image.new_from_memory_copy(pixels.pack("C*"), w, h, c, :uchar))
  end
end
|
115
|
+
end
|
116
|
+
end
|
data/lib/informers/utils/math.rb
CHANGED
@@ -1,5 +1,75 @@
|
|
1
1
|
module Informers
|
2
2
|
module Utils
|
3
|
+
# Resizes an image stored as a flat, channel-major Array using bilinear
# interpolation with half-pixel centers.
#
# input     - flat Array; element (c, y, x) lives at c * H * W + y * W + x
# in_shape  - [channels, in_height, in_width]
# out_shape - [out_height, out_width]
# mode, align_corners - accepted but not used yet
#
# Returns a flat Array of channels * out_height * out_width values in the same layout.
def self.interpolate_data(input, in_shape, out_shape, mode = "bilinear", align_corners = false)
  channels, src_h, src_w = in_shape
  dst_h, dst_w = out_shape

  # TODO use mode and align_corners

  scale_x = dst_w / src_w.to_f
  scale_y = dst_h / src_h.to_f

  src_plane = src_h * src_w
  dst_plane = dst_h * dst_w
  result = Array.new(dst_h * dst_w * channels)

  dst_h.times do |row|
    # Source y coordinate, its two neighbouring rows (clamped to the image),
    # and the fractional distance from the upper row.
    src_y = (row + 0.5) / scale_y - 0.5
    y0 = src_y.floor
    y1 = [y0 + 1, src_h - 1].min
    y0 = [y0, 0].max
    ty = src_y - y0
    upper = y0 * src_w
    lower = y1 * src_w

    dst_w.times do |col|
      src_x = (col + 0.5) / scale_x - 0.5
      x0 = src_x.floor
      x1 = [x0 + 1, src_w - 1].min
      x0 = [x0, 0].max
      tx = src_x - x0

      # Bilinear weights for the four neighbouring source pixels.
      w_ul = (1 - tx) * (1 - ty)
      w_ur = tx * (1 - ty)
      w_ll = (1 - tx) * ty
      w_lr = tx * ty

      i_ul = upper + x0
      i_ur = upper + x1
      i_ll = lower + x0
      i_lr = lower + x1
      dst_index = row * dst_w + col

      channels.times do |ch|
        base = ch * src_plane
        result[ch * dst_plane + dst_index] =
          w_ul * input[base + i_ul] +
          w_ur * input[base + i_ur] +
          w_ll * input[base + i_ll] +
          w_lr * input[base + i_lr]
      end
    end
  end

  result
end
|
72
|
+
|
3
73
|
def self.softmax(arr)
|
4
74
|
# Compute the maximum value in the array
|
5
75
|
max_val = arr.max
|
@@ -17,6 +87,9 @@ module Informers
|
|
17
87
|
end
|
18
88
|
|
19
89
|
# Elementwise logistic sigmoid, 1 / (1 + e^-x). When the first element is
# itself an Array, each element is processed recursively.
def self.sigmoid(arr)
  return arr.map { |row| sigmoid(row) } if arr[0].is_a?(Array)

  arr.map do |x|
    1 / (1 + Math.exp(-x))
  end
end
|
22
95
|
|