informers 1.0.2 → 1.1.0
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +213 -19
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -14
- data/lib/informers/models.rb +1027 -13
- data/lib/informers/pipelines.rb +781 -14
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +166 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
data/lib/informers/models.rb
CHANGED
@@ -44,7 +44,7 @@ module Informers
       end

      const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
-        model_info = model_class_mapping[config
+        model_info = model_class_mapping[config[:model_type]]
        if !model_info
          next # Item not found in this mapping
        end
@@ -52,15 +52,17 @@ module Informers
       end

      if const_defined?(:BASE_IF_FAIL)
-        warn "Unknown model class #{config
+        warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
        PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
      else
-        raise Error, "Unsupported model type: #{config
+        raise Error, "Unsupported model type: #{config[:model_type]}"
      end
    end
  end

  class PreTrainedModel
+    MAIN_INPUT_NAME = :input_ids
+
    attr_reader :config

    def initialize(config, session)
@@ -76,9 +78,19 @@ module Informers

      case model_type
      when MODEL_TYPES[:DecoderOnly]
-
+        @can_generate = true
+
+        @run_beam = method(:decoder_run_beam)
+        @get_start_beams = method(:decoder_start_beams)
+        @update_beam = method(:decoder_update_beam)
+        @forward = method(:decoder_forward)
      when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
-
+        @can_generate = true
+
+        @run_beam = method(:seq2seq_run_beam)
+        @get_start_beams = method(:seq2seq_start_beams)
+        @update_beam = method(:seq2seq_update_beam)
+        @forward = method(:seq2seq_forward)
      when MODEL_TYPES[:EncoderDecoder]
        raise Todo
      else
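
The constructor above binds the per-architecture beam-search hooks as `Method` objects, so `generate` can stay architecture-agnostic and simply call whatever was bound. A minimal sketch of that dispatch idiom (a toy class of ours, not part of the gem):

    # Toy illustration of the Method-object dispatch used by PreTrainedModel.
    class Dispatcher
      def initialize(model_type)
        @forward =
          if model_type == :decoder_only
            method(:decoder_forward)
          else
            method(:seq2seq_forward)
          end
      end

      def call(inputs)
        @forward.(inputs) # invokes whichever private method was bound
      end

      private

      def decoder_forward(inputs)
        "decoder-only forward: #{inputs.inspect}"
      end

      def seq2seq_forward(inputs)
        "seq2seq forward: #{inputs.inspect}"
      end
    end

    Dispatcher.new(:decoder_only).call([1, 2, 3])
    # => "decoder-only forward: [1, 2, 3]"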
@@ -110,10 +122,19 @@ module Informers
       model_type = MODEL_TYPE_MAPPING[model_name]

      if model_type == MODEL_TYPES[:DecoderOnly]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

      elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

      elsif model_type == MODEL_TYPES[:MaskGeneration]
        raise Todo
@@ -123,7 +144,7 @@ module Informers

      else
        if model_type != MODEL_TYPES[:EncoderOnly]
-          warn "Model type for '#{model_name || config
+          warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
        end
        info = [
          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -135,7 +156,15 @@ module Informers
     end

    def self.construct_session(pretrained_model_name_or_path, file_name, **options)
-
+      prefix = "onnx/"
+      if file_name.start_with?("../")
+        prefix = ""
+        file_name = file_name[3..]
+      elsif file_name.start_with?("/")
+        prefix = ""
+        file_name = file_name[1..]
+      end
+      model_file_name = "#{prefix}#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)

      OnnxRuntime::InferenceSession.new(path)
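
The new `construct_session` resolves which ONNX file to fetch: it defaults to an `onnx/` subdirectory, honors `../` and `/` prefixes as escapes out of it, and appends `_quantized` when `options[:quantized]` is set. A standalone sketch of just that naming rule (the helper name is ours, for illustration only):

    # Re-implementation of the file-name rule above, not part of the gem.
    def resolve_model_file_name(file_name, quantized: false)
      prefix = "onnx/"
      if file_name.start_with?("../")
        prefix = ""
        file_name = file_name[3..]
      elsif file_name.start_with?("/")
        prefix = ""
        file_name = file_name[1..]
      end
      "#{prefix}#{file_name}#{quantized ? "_quantized" : ""}.onnx"
    end

    resolve_model_file_name("decoder_model_merged", quantized: true)
    # => "onnx/decoder_model_merged_quantized.onnx"
    resolve_model_file_name("../model")
    # => "model.onnx"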
@@ -145,8 +174,445 @@ module Informers
       @forward.(model_inputs, **kwargs)
     end

+    def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+      if !@can_generate
+        model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+        error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+        raise Error, error_message
+      end
+
+      if !inputs.is_a?(Array)
+        raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+      end
+
+      if @config[:is_encoder_decoder]
+        # Generating from the encoder outputs
+        input_ids_seq_length = 0
+      else
+        input_ids_seq_length = inputs.length
+
+        # decoder-only
+        if input_ids_seq_length == 0
+          raise Error, "Must supply a non-empty array of input token ids."
+        end
+      end
+
+      # Update generation config with defaults
+      generation_config = get_generation_config(generation_config)
+
+      logits_processor ||= Utils::LogitsProcessorList.new
+
+      # Update logits processor
+      logits_processor = get_logits_processor(
+        generation_config,
+        input_ids_seq_length,
+        logits_processor
+      )
+
+      eos_token_ids = generation_config[:eos_token_id]
+      if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+        eos_token_ids = [eos_token_ids]
+      end
+
+      num_output_tokens = 1
+      max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+      # Only use max length if max_new_tokens is not provided
+      use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+      sampler = Utils::Sampler.get_sampler(generation_config)
+
+      beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+      while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+        newest_beams = []
+        beams.each do |beam|
+          if beam[:done]
+            # Add this beam back into the pool
+            newest_beams << beam
+            next
+          end
+          if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+            # Set this beam to done and add it back into the pool
+            beam[:done] = true
+            newest_beams << beam
+            next
+          end
+
+          output = run_beam(beam)
+
+          # add attentions/scores to beam only if user requested
+          if generation_config["output_attentions"]
+            add_attentions_to_beam(beam, output)
+          end
+
+          # Logits are of the form [batch_size, out_seq_length, vocab_size]
+          # In most cases, this will be [batch_size, 1, vocab_size]
+          # So, we select the last token's logits:
+          # (equivalent to `logits = outputs.logits[:, -1, :]`)
+          logits = output["logits"].map { |v| v[-1] }
+
+          # Apply logits processor
+          logits_processor.(beam[:output_token_ids], logits)
+
+          sampled_tokens = sampler.(logits)
+          sampled_tokens.each do |new_token_id, log_prob|
+            # use previous beam as a starting point
+            new_beam = beam.dup
+
+            # update new beam
+            update_beam(new_beam, new_token_id)
+
+            new_beam[:score] += log_prob
+
+            if eos_token_ids && eos_token_ids.include?(new_token_id)
+              new_beam[:done] = true
+            end
+
+            newest_beams << new_beam
+          end
+        end
+        num_output_tokens += 1
+
+        # Next, we get the best beams, per ID
+        newest_beams =
+          group_beams(newest_beams).map do |group|
+            group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+          end
+
+        # Flatten beams
+        beams = newest_beams.flatten(1)
+
+        # Run callback
+        if generation_config["callback_function"]
+          generation_config["callback_function"].(beams)
+        end
+      end
+
+      # TODO: Ensure that we can return non-batched outputs
+
+      grouped_beams = group_beams(beams)
+
+      get_flattened = lambda do |key|
+        grouped_beams.map do |batch|
+          if generation_config["num_return_sequences"] > 1
+            raise Todo
+          else
+            [batch[0][key]]
+          end
+        end.flatten(1)
+      end
+
+      sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+      if generation_config["return_dict_in_generate"]
+        raise Todo
+      else
+        sequences
+      end
+    end
+
     private

+    def get_logits_processor(
+      generation_config,
+      input_ids_seq_length,
+      logits_processor = nil
+    )
+      processors = Utils::LogitsProcessorList.new
+
+      if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+        processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+      end
+
+      if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+        processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+      end
+
+      if !generation_config["bad_words_ids"].nil?
+        processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+        processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+        processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+          input_ids_seq_length,
+          generation_config["min_new_tokens"],
+          generation_config["eos_token_id"]
+        ))
+      end
+
+      if !generation_config["forced_bos_token_id"].nil?
+        processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+      end
+
+      if !generation_config["forced_eos_token_id"].nil?
+        processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+          generation_config["max_length"],
+          generation_config["forced_eos_token_id"]
+        ))
+      end
+
+      if !generation_config["begin_suppress_tokens"].nil?
+        raise Todo
+      end
+
+      if !generation_config["forced_decoder_ids"].nil?
+        processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+      end
+
+      if !logits_processor.nil?
+        processors.concat(logits_processor)
+      end
+
+      processors
+    end
+
+    def get_generation_config(generation_config)
+      # Create empty generation config (contains defaults)
+      # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+      gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+      # Apply model's generation config, if it exists
+      if @generation_config
+        gen_config.merge!(@generation_config)
+      end
+
+      # Finally, use any generation config specified by the user
+      # when calling `generate`
+      if !generation_config.nil?
+        gen_config.merge!(generation_config)
+      end
+
+      gen_config
+    end
+
+    def seq2seq_forward(model_inputs)
+      encoder_outputs = model_inputs[:encoder_outputs]
+      past_key_values = model_inputs[:past_key_values]
+
+      if !encoder_outputs
+        # Encoder outputs are not given, so we must compute them.
+        encoder_outputs = encoder_forward(model_inputs)[0]
+      end
+      decoder_feeds = {
+        input_ids: model_inputs[:decoder_input_ids],
+        encoder_hidden_states: encoder_outputs
+      }
+      use_cache_branch = !!past_key_values
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+        decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+      end
+
+      prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+      decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+      logits = decoder_results["logits"]
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+      # Get cross attention and/or decoder attentions if they are present
+      attns = get_attentions(decoder_results)
+
+      Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+    end
+
+    def prepare_position_ids(session, feeds, use_cache_branch)
+      if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+        return
+      end
+
+      raise Todo
+    end
+
+    def get_past_key_values(decoder_results, past_key_values)
+      pkvs = {}
+
+      decoder_results.each_key do |name|
+        if name.start_with?("present")
+          new_name = name.sub("present", "past_key_values")
+
+          if past_key_values && name.include?("encoder")
+            # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+            # outputs with the previous past key values.
+            # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+            pkvs[new_name] = past_key_values[new_name]
+          else
+            pkvs[new_name] = decoder_results[name]
+          end
+        end
+      end
+      pkvs
+    end
+
+    def get_attentions(decoder_results)
+      attns = {}
+
+      ["cross_attentions", "decoder_attentions"].each do |attn_name|
+        result = []
+        decoder_results.each_key do |name|
+          if name.start_with?(attn_name)
+            index = name.split(".").pop
+            result[index] = decoder_results[name]
+          end
+        end
+        attns[attn_name] = result
+      end
+      attns
+    end
+
+    def add_past_key_values(decoder_feeds, past_key_values)
+      if past_key_values
+        decoder_feeds.merge!(past_key_values)
+      else
+        # TODO support batches (i.e., batch_size > 1)
+        batch_size = 1
+
+        if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+          _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+          _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+          @num_decoder_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+          end
+        elsif @config[:model_type] == "falcon"
+          raise Todo
+        elsif @config[:multi_query]
+          raise Todo
+        elsif @config[:model_type] == "bloom"
+          raise Todo
+        else
+          _dims = [batch_size, @num_heads, 0, @dim_kv]
+          @num_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+            # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+          end
+        end
+      end
+    end
+
+    def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+      beams = []
+      beam_id = 0
+
+      requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+      # decoder_input_ids == output_token_ids
+      decoder_input_ids =
+        generation_config["decoder_input_ids"] ||
+        generation_config["decoder_start_token_id"] ||
+        generation_config["bos_token_id"] ||
+        generation_config["eos_token_id"]
+
+      if !decoder_input_ids.is_a?(Array)
+        decoder_input_ids = [decoder_input_ids]
+      end
+
+      input_token_ids.each do |tokens|
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        # Create beam
+        start = {
+          inputs: tokens,
+          encoder_outputs: nil,
+          prev_model_outputs: nil,
+
+          output_token_ids: decoder_input_ids,
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        if requires_attention_mask
+          start[:attention_mask] = prepare_attention_mask(tokens)
+        end
+
+        beams << start
+      end
+
+      beams
+    end
+
+    def prepare_attention_mask(tokens)
+      # Prepare attention mask
+      pad_token_id = @config["pad_token_id"]
+      eos_token_id = @config["eos_token_id"]
+      if eos_token_id.is_a?(Integer)
+        eos_token_id = [eos_token_id]
+      end
+
+      is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+      is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+      if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+        raise Todo
+      else
+        Utils.ones_like(tokens)
+      end
+    end
+
+    def seq2seq_run_beam(beam)
+      input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+      decoder_input_ids = beam[:output_token_ids]
+      if beam[:prev_model_outputs]
+        # After the first step, `prev_model_outputs` won't be null.
+        # So, we cut decoder_input_ids if past is used
+        decoder_input_ids = [decoder_input_ids[-1]]
+      end
+
+      # 1. Prepare
+      model_inputs = {
+        input_name => beam[:inputs],
+        decoder_input_ids: [decoder_input_ids],
+        encoder_outputs: beam[:encoder_outputs],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+      if beam[:attention_mask]
+        model_inputs[:attention_mask] = beam[:attention_mask]
+      end
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+      beam[:encoder_outputs] = output[:encoder_outputs]
+
+      output
+    end
+
+    def seq2seq_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+    end
+
+    def group_beams(beams)
+      # Group beams by their ids
+      groups = {}
+      beams.each do |obj|
+        if !groups[obj[:id]]
+          groups[obj[:id]] = [obj]
+        else
+          groups[obj[:id]] << obj
+        end
+      end
+      groups.values
+    end
+
     def encoder_forward(model_inputs, output_names: nil)
       encoder_feeds = {}
       @session.inputs.each do |input|
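
Taken together, these additions implement a beam-search loop over ONNX Runtime sessions: `get_start_beams` seeds one beam per input, `run_beam` produces next-token logits through the session, the sampler and logits processors pick continuations, and `update_beam` appends the chosen token until an EOS id or the token budget ends the beam. A hedged sketch of driving `generate` directly (the model id, token ids, and config keys are illustrative; `generate` expects an array of already-tokenized id arrays and returns one generated sequence per input):

    # Hypothetical usage sketch; most callers would go through the new pipelines instead.
    model = Informers::AutoModelForSeq2SeqLM.from_pretrained("Xenova/t5-small")
    input_ids = [[13959, 1566, 12, 2968, 10, 8774, 55, 1]] # placeholder token ids, batch of one
    sequences = model.generate(input_ids, {"max_new_tokens" => 32})
    # => array of decoder token ids, one sequence per input

Decoder-only models (e.g. GPT-2 via `AutoModelForCausalLM`) run through the same loop with the `decoder_*` hooks bound instead of the `seq2seq_*` ones.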
@@ -159,7 +625,96 @@ module Informers
       session_run(@session, encoder_feeds, output_names:)
     end

-    def
+    def decoder_forward(model_inputs)
+      input_ids, past_key_values, attention_mask =
+        model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+      decoder_feeds = {
+        input_ids: input_ids,
+        attention_mask: attention_mask || prepare_attention_mask(input_ids)
+      }
+      use_cache_branch = !!past_key_values
+
+      if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@session, decoder_feeds)
+      decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+      logits = decoder_results["logits"]
+
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+      {"logits" => logits, past_key_values: past_key_values}
+    end
+
+    def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      beams = []
+
+      beam_id = 0
+      input_token_ids.each do |tokens|
+        output_token_ids = tokens.dup
+
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        if inputs_attention_mask
+          attn_mask = inputs_attention_mask[beam_id]
+          attn_mask = [attn_mask]
+        else
+          attn_mask = prepare_attention_mask(tokens)
+        end
+
+        start = {
+          input: tokens,
+          model_input_ids: tokens,
+          attention_mask: attn_mask,
+          prev_model_outputs: nil,
+
+          output_token_ids: output_token_ids,
+          num_output_tokens: num_output_tokens,
+
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        beams << start
+      end
+      beams
+    end
+
+    def decoder_run_beam(beam)
+      attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+      # 1. Prepare
+      model_inputs = {
+        input_ids: beam[:model_input_ids],
+        attention_mask: [attn_mask_data],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+
+      output
+    end
+
+    def decoder_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+      beam[:model_input_ids] = [[new_token_id]]
+    end
+
+    def session_run(session, inputs, output_names: nil)
       checked_inputs = validate_inputs(session, inputs)
      begin
        output = session.run(output_names || @output_names, checked_inputs)
@@ -179,6 +734,18 @@ module Informers
     def validate_inputs(session, inputs)
       inputs
     end
+
+    def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+    end
+
+    def run_beam(beam)
+      @run_beam.(beam)
+    end
+
+    def update_beam(beam, new_token_id)
+      @update_beam.(beam, new_token_id)
+    end
   end

   class BertPreTrainedModel < PreTrainedModel
@@ -187,6 +754,12 @@ module Informers
   class BertModel < BertPreTrainedModel
   end

+  class BertForMaskedLM < BertPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class BertForSequenceClassification < BertPreTrainedModel
     def call(model_inputs)
       SequenceClassifierOutput.new(*super(model_inputs))
@@ -229,31 +802,376 @@ module Informers
     end
   end

+  class MPNetPreTrainedModel < PreTrainedModel
+  end
+
+  class MPNetModel < MPNetPreTrainedModel
+  end
+
+  class T5PreTrainedModel < PreTrainedModel
+  end
+
+  class T5Model < T5PreTrainedModel
+  end
+
+  class T5ForConditionalGeneration < T5PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config[:num_decoder_layers]
+      @num_decoder_heads = @config[:num_heads]
+      @decoder_dim_kv = @config[:d_kv]
+
+      @num_encoder_layers = @config[:num_layers]
+      @num_encoder_heads = @config[:num_heads]
+      @encoder_dim_kv = @config[:d_kv]
+    end
+  end
+
+  class BartPretrainedModel < PreTrainedModel
+  end
+
+  class BartModel < BartPretrainedModel
+  end
+
+  class BartForConditionalGeneration < BartPretrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+    end
+  end
+
+  class BartForSequenceClassification < BartPretrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class MBartPreTrainedModel < PreTrainedModel
+  end
+
+  class MBartModel < MBartPreTrainedModel
+  end
+
+  class MBartForCausalLM < MBartPreTrainedModel
+    attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+      :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+    def initialize(config, decoder_merged_session, generation_config)
+      super(config, decoder_merged_session)
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class M2M100PreTrainedModel < PreTrainedModel
+  end
+
+  class M2M100Model < M2M100PreTrainedModel
+  end
+
+  class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class RobertaPreTrainedModel < PreTrainedModel
+  end
+
+  class RobertaModel < RobertaPreTrainedModel
+  end
+
+  class RobertaForMaskedLM < RobertaPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
+  class XLMRobertaPreTrainedModel < PreTrainedModel
+  end
+
+  class XLMRobertaModel < XLMRobertaPreTrainedModel
+  end
+
+  class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class ViTPreTrainedModel < PreTrainedModel
+  end
+
+  class ViTModel < ViTPreTrainedModel
+  end
+
+  class ViTForImageClassification < ViTPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class CLIPPreTrainedModel < PreTrainedModel
+  end
+
+  class CLIPModel < CLIPPreTrainedModel
+  end
+
+  class GPT2PreTrainedModel < PreTrainedModel
+    attr_reader :num_heads, :num_layers, :dim_kv
+
+    def initialize(config, session, generation_config)
+      super(config, session)
+      @generation_config = generation_config
+
+      # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+      @config["pad_token_id"] = @config["eos_token_id"]
+
+      @num_heads = @config["n_head"]
+      @num_layers = @config["n_layer"]
+      @dim_kv = @config["n_embd"] / @num_heads.to_f
+    end
+  end
+
+  class GPT2Model < GPT2PreTrainedModel
+  end
+
+  class GPT2LMHeadModel < GPT2PreTrainedModel
+  end
+
+  class OwlViTPreTrainedModel < PreTrainedModel
+  end
+
+  class OwlViTModel < OwlViTPreTrainedModel
+  end
+
+  class OwlViTForObjectDetection < OwlViTPreTrainedModel
+  end
+
+  class DetrPreTrainedModel < PreTrainedModel
+  end
+
+  class DetrModel < DetrPreTrainedModel
+  end
+
+  class DetrForObjectDetection < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrObjectDetectionOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DetrForSegmentation < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrSegmentationOutput.new(*super(model_inputs))
+    end
+  end
+
+  class Swin2SRPreTrainedModel < PreTrainedModel
+  end
+
+  class Swin2SRModel < Swin2SRPreTrainedModel
+  end
+
+  class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+  end
+
+  class DPTPreTrainedModel < PreTrainedModel
+  end
+
+  class DPTModel < DPTPreTrainedModel
+  end
+
+  class DPTForDepthEstimation < DPTPreTrainedModel
+  end
+
+  class VisionEncoderDecoderModel < PreTrainedModel
+    MAIN_INPUT_NAME = :pixel_values
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      # Extract configs
+      encoder_config = @config["encoder"]
+      decoder_config = @config["decoder"]
+
+      # Validate encoder
+      encoder_model_type = encoder_config["model_type"]
+      encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+      if !encoder_model
+        warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+      end
+
+      # Validate decoder
+      decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+      if !decoder_model
+        raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+      end
+
+      decoder_model_class = decoder_model[1]
+      decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+      @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+      if @add_encoder_pkv
+        # Decoder is part of an encoder-decoder model
+        @num_decoder_layers = decoder.num_decoder_layers
+        @num_decoder_heads = decoder.num_decoder_heads
+        @decoder_dim_kv = decoder.decoder_dim_kv
+
+        @num_encoder_layers = decoder.num_encoder_layers
+        @num_encoder_heads = decoder.num_encoder_heads
+        @encoder_dim_kv = decoder.encoder_dim_kv
+      else
+        # Decoder is a decoder-only model
+        @num_layers = decoder.num_layers
+        @num_heads = decoder.num_heads
+        @dim_kv = decoder.dim_kv
+      end
+    end
+  end
+
+  class DonutSwinPreTrainedModel < PreTrainedModel
+  end
+
+  class DonutSwinModel < DonutSwinPreTrainedModel
+  end
+
   MODEL_MAPPING_NAMES_ENCODER_ONLY = {
     "bert" => ["BertModel", BertModel],
     "nomic_bert" => ["NomicBertModel", NomicBertModel],
     "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
-    "
+    "mpnet" => ["MPNetModel", MPNetModel],
+    "distilbert" => ["DistilBertModel", DistilBertModel],
+    "roberta" => ["RobertaModel", RobertaModel],
+    "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+    "clip" => ["CLIPModel", CLIPModel],
+    "detr" => ["DetrModel", DetrModel],
+    "vit" => ["ViTModel", ViTModel],
+    "owlvit" => ["OwlViTModel", OwlViTModel],
+    "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+  }
+
+  MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+    "bart" => ["BartModel", BartModel]
   }

   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
-    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
+    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+    "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
   }

   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForTokenClassification", BertForTokenClassification]
   }

+  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+    "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+    "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+    "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+  }
+
+  MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+    "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+    "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+  }
+
+  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+    "bert" => ["BertForMaskedLM", BertForMaskedLM],
+    "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+  }
+
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
   }

+  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+    "vit" => ["ViTForImageClassification", ViTForImageClassification]
+  }
+
+  MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+    "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+  }
+
+  MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+    "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+  }
+
+  MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+    "detr" => ["DetrForSegmentation", DetrForSegmentation]
+  }
+
+  MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+    "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+  }
+
+  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+    "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+  }
+
+  MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+  }
+
   MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
-    [
+    [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
   ]

   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
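
Each `*_MAPPING_NAMES` table keys a `config[:model_type]` string to a `[class name, class]` pair, and `MODEL_CLASS_TYPE_MAPPING` tags every table with the `MODEL_TYPES` entry that tells `from_pretrained` which ONNX sessions to construct. The lookup performed by the `PretrainedMixin` change at the top of this diff reduces to the following (illustrative only):

    # Illustrative resolution, mirroring PretrainedMixin.from_pretrained.
    model_type = "t5" # read from the checkpoint's config.json
    name, klass = Informers::MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES[model_type]
    # name  => "T5ForConditionalGeneration"
    # klass => Informers::T5ForConditionalGeneration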
@@ -277,11 +1195,77 @@ module Informers
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
   end

+  class AutoModelForSeq2SeqLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+  end
+
+  class AutoModelForCausalLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskedLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+  end
+
   class AutoModelForQuestionAnswering < PretrainedMixin
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
   end

+  class AutoModelForVision2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForSemanticSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForZeroShotObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageToImage < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+  end
+
+  class AutoModelForDepthEstimation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageFeatureExtraction < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+  end
+
   class ModelOutput
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+  end
+
+  class Seq2SeqLMOutput < ModelOutput
+    def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+      super()
+      @logits = logits
+      @past_key_values = past_key_values
+      @encoder_outputs = encoder_outputs
+      @decoder_attentions = decoder_attentions
+      @cross_attentions = cross_attentions
+    end
   end

   class SequenceClassifierOutput < ModelOutput
@@ -302,6 +1286,15 @@ module Informers
     end
   end

+  class MaskedLMOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
   class QuestionAnsweringModelOutput < ModelOutput
     attr_reader :start_logits, :end_logits

@@ -311,4 +1304,25 @@ module Informers
     @end_logits = end_logits
   end
  end
+
+  class DetrObjectDetectionOutput < ModelOutput
+    attr_reader :logits, :pred_boxes
+
+    def initialize(logits, pred_boxes)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+    end
+  end
+
+  class DetrSegmentationOutput < ModelOutput
+    attr_reader :logits, :pred_boxes, :pred_masks
+
+    def initialize(logits, pred_boxes, pred_masks)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+      @pred_masks = pred_masks
+    end
+  end
 end
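
The output classes added at the end all inherit the new `ModelOutput#[]`, so a result can be read either through its `attr_reader` or hash-style by key. A small sketch with placeholder tensors:

    # Sketch of the new hash-style access on ModelOutput subclasses.
    logits = [[0.1, 0.9]]
    pred_boxes = [[0.5, 0.5, 0.2, 0.2]]
    out = Informers::DetrObjectDetectionOutput.new(logits, pred_boxes)
    out.logits       # => [[0.1, 0.9]]           (via attr_reader)
    out[:pred_boxes] # => [[0.5, 0.5, 0.2, 0.2]] (via ModelOutput#[])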