informers 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +123 -0
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -9
- data/lib/informers/models.rb +997 -12
- data/lib/informers/pipelines.rb +768 -8
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +154 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
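The headline change in this release is generation support: `PreTrainedModel` gains a `generate` method with beam search, loading paths for decoder-only, seq2seq, and vision2seq models, and new auto classes such as `AutoModelForSeq2SeqLM` and `AutoModelForCausalLM` (all in the `models.rb` diff below). A minimal usage sketch based on the signatures in this diff; the model id, token ids, and plain-Hash generation config are illustrative assumptions, not taken from the package:

require "informers"

# Hypothetical ONNX model repo id; a config.json with model_type "t5"
# resolves through MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.
model = Informers::AutoModelForSeq2SeqLM.from_pretrained("Xenova/t5-small")

# `generate` takes an array of token-id sequences and returns the
# generated token ids for each one (see `generate` in models.rb below).
input_ids = [[13959, 1566, 12, 2968, 10, 8774, 1]] # illustrative ids
output_ids = model.generate(input_ids, {"max_new_tokens" => 32})
p output_ids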
data/lib/informers/models.rb
CHANGED
@@ -44,7 +44,7 @@ module Informers
       end

       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
-        model_info = model_class_mapping[config
+        model_info = model_class_mapping[config[:model_type]]
         if !model_info
           next # Item not found in this mapping
         end
@@ -52,15 +52,17 @@ module Informers
       end

       if const_defined?(:BASE_IF_FAIL)
-        warn "Unknown model class #{config
+        warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
       else
-        raise Error, "Unsupported model type: #{config
+        raise Error, "Unsupported model type: #{config[:model_type]}"
       end
     end
   end

   class PreTrainedModel
+    MAIN_INPUT_NAME = :input_ids
+
     attr_reader :config

     def initialize(config, session)
@@ -76,9 +78,19 @@ module Informers

       case model_type
       when MODEL_TYPES[:DecoderOnly]
-
+        @can_generate = true
+
+        @run_beam = method(:decoder_run_beam)
+        @get_start_beams = method(:decoder_start_beams)
+        @update_beam = method(:decoder_update_beam)
+        @forward = method(:decoder_forward)
       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
-
+        @can_generate = true
+
+        @run_beam = method(:seq2seq_run_beam)
+        @get_start_beams = method(:seq2seq_start_beams)
+        @update_beam = method(:seq2seq_update_beam)
+        @forward = method(:seq2seq_forward)
       when MODEL_TYPES[:EncoderDecoder]
         raise Todo
       else
@@ -110,10 +122,19 @@ module Informers
       model_type = MODEL_TYPE_MAPPING[model_name]

       if model_type == MODEL_TYPES[:DecoderOnly]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

       elsif model_type == MODEL_TYPES[:MaskGeneration]
         raise Todo
@@ -123,7 +144,7 @@ module Informers

       else
         if model_type != MODEL_TYPES[:EncoderOnly]
-          warn "Model type for '#{model_name || config
+          warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
         end
         info = [
           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -153,8 +174,445 @@ module Informers
       @forward.(model_inputs, **kwargs)
     end

+    def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+      if !@can_generate
+        model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+        error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+        raise Error, error_message
+      end
+
+      if !inputs.is_a?(Array)
+        raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+      end
+
+      if @config[:is_encoder_decoder]
+        # Generating from the encoder outputs
+        input_ids_seq_length = 0
+      else
+        input_ids_seq_length = inputs.length
+
+        # decoder-only
+        if input_ids_seq_length == 0
+          raise Error, "Must supply a non-empty array of input token ids."
+        end
+      end
+
+      # Update generation config with defaults
+      generation_config = get_generation_config(generation_config)
+
+      logits_processor ||= Utils::LogitsProcessorList.new
+
+      # Update logits processor
+      logits_processor = get_logits_processor(
+        generation_config,
+        input_ids_seq_length,
+        logits_processor
+      )
+
+      eos_token_ids = generation_config[:eos_token_id]
+      if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+        eos_token_ids = [eos_token_ids]
+      end
+
+      num_output_tokens = 1
+      max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+      # Only use max length if max_new_tokens is not provided
+      use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+      sampler = Utils::Sampler.get_sampler(generation_config)
+
+      beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+      while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+        newest_beams = []
+        beams.each do |beam|
+          if beam[:done]
+            # Add this beam back into the pool
+            newest_beams << beam
+            next
+          end
+          if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+            # Set this beam to done and add it back into the pool
+            beam[:done] = true
+            newest_beams << beam
+            next
+          end
+
+          output = run_beam(beam)
+
+          # add attentions/scores to beam only if user requested
+          if generation_config["output_attentions"]
+            add_attentions_to_beam(beam, output)
+          end
+
+          # Logits are of the form [batch_size, out_seq_length, vocab_size]
+          # In most cases, this will be [batch_size, 1, vocab_size]
+          # So, we select the last token's logits:
+          # (equivalent to `logits = outputs.logits[:, -1, :]`)
+          logits = output["logits"].map { |v| v[-1] }
+
+          # Apply logits processor
+          logits_processor.(beam[:output_token_ids], logits)
+
+          sampled_tokens = sampler.(logits)
+          sampled_tokens.each do |new_token_id, log_prob|
+            # use previous beam as a starting point
+            new_beam = beam.dup
+
+            # update new beam
+            update_beam(new_beam, new_token_id)
+
+            new_beam[:score] += log_prob
+
+            if eos_token_ids && eos_token_ids.include?(new_token_id)
+              new_beam[:done] = true
+            end
+
+            newest_beams << new_beam
+          end
+        end
+        num_output_tokens += 1
+
+        # Next, we get the best beams, per ID
+        newest_beams =
+          group_beams(newest_beams).map do |group|
+            group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+          end
+
+        # Flatten beams
+        beams = newest_beams.flatten(1)
+
+        # Run callback
+        if generation_config["callback_function"]
+          generation_config["callback_function"].(beams)
+        end
+      end
+
+      # TODO: Ensure that we can return non-batched outputs
+
+      grouped_beams = group_beams(beams)
+
+      get_flattened = lambda do |key|
+        grouped_beams.map do |batch|
+          if generation_config["num_return_sequences"] > 1
+            raise Todo
+          else
+            [batch[0][key]]
+          end
+        end.flatten(1)
+      end
+
+      sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+      if generation_config["return_dict_in_generate"]
+        raise Todo
+      else
+        sequences
+      end
+    end
+
     private

+    def get_logits_processor(
+      generation_config,
+      input_ids_seq_length,
+      logits_processor = nil
+    )
+      processors = Utils::LogitsProcessorList.new
+
+      if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+        processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+      end
+
+      if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+        processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+      end
+
+      if !generation_config["bad_words_ids"].nil?
+        processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+        processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+        processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+          input_ids_seq_length,
+          generation_config["min_new_tokens"],
+          generation_config["eos_token_id"]
+        ))
+      end
+
+      if !generation_config["forced_bos_token_id"].nil?
+        processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+      end
+
+      if !generation_config["forced_eos_token_id"].nil?
+        processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+          generation_config["max_length"],
+          generation_config["forced_eos_token_id"]
+        ))
+      end
+
+      if !generation_config["begin_suppress_tokens"].nil?
+        raise Todo
+      end
+
+      if !generation_config["forced_decoder_ids"].nil?
+        processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+      end
+
+      if !logits_processor.nil?
+        processors.concat(logits_processor)
+      end
+
+      processors
+    end
+
+    def get_generation_config(generation_config)
+      # Create empty generation config (contains defaults)
+      # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+      gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+      # Apply model's generation config, if it exists
+      if @generation_config
+        gen_config.merge!(@generation_config)
+      end
+
+      # Finally, use any generation config specified by the user
+      # when calling `generate`
+      if !generation_config.nil?
+        gen_config.merge!(generation_config)
+      end
+
+      gen_config
+    end
+
+    def seq2seq_forward(model_inputs)
+      encoder_outputs = model_inputs[:encoder_outputs]
+      past_key_values = model_inputs[:past_key_values]
+
+      if !encoder_outputs
+        # Encoder outputs are not given, so we must compute them.
+        encoder_outputs = encoder_forward(model_inputs)[0]
+      end
+      decoder_feeds = {
+        input_ids: model_inputs[:decoder_input_ids],
+        encoder_hidden_states: encoder_outputs
+      }
+      use_cache_branch = !!past_key_values
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+        decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+      end
+
+      prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+      decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+      logits = decoder_results["logits"]
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+      # Get cross attention and/or decoder attentions if they are present
+      attns = get_attentions(decoder_results)
+
+      Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+    end
+
+    def prepare_position_ids(session, feeds, use_cache_branch)
+      if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+        return
+      end
+
+      raise Todo
+    end
+
+    def get_past_key_values(decoder_results, past_key_values)
+      pkvs = {}
+
+      decoder_results.each_key do |name|
+        if name.start_with?("present")
+          new_name = name.sub("present", "past_key_values")
+
+          if past_key_values && name.include?("encoder")
+            # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+            # outputs with the previous past key values.
+            # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+            pkvs[new_name] = past_key_values[new_name]
+          else
+            pkvs[new_name] = decoder_results[name]
+          end
+        end
+      end
+      pkvs
+    end
+
+    def get_attentions(decoder_results)
+      attns = {}
+
+      ["cross_attentions", "decoder_attentions"].each do |attn_name|
+        result = []
+        decoder_results.each_key do |name|
+          if name.start_with?(attn_name)
+            index = name.split(".").pop
+            result[index] = decoder_results[name]
+          end
+        end
+        attns[attn_name] = result
+      end
+      attns
+    end
+
+    def add_past_key_values(decoder_feeds, past_key_values)
+      if past_key_values
+        decoder_feeds.merge!(past_key_values)
+      else
+        # TODO support batches (i.e., batch_size > 1)
+        batch_size = 1
+
+        if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+          _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+          _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+          @num_decoder_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+          end
+        elsif @config[:model_type] == "falcon"
+          raise Todo
+        elsif @config[:multi_query]
+          raise Todo
+        elsif @config[:model_type] == "bloom"
+          raise Todo
+        else
+          _dims = [batch_size, @num_heads, 0, @dim_kv]
+          @num_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+            # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+          end
+        end
+      end
+    end
+
+    def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+      beams = []
+      beam_id = 0
+
+      requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+      # decoder_input_ids == output_token_ids
+      decoder_input_ids =
+        generation_config["decoder_input_ids"] ||
+        generation_config["decoder_start_token_id"] ||
+        generation_config["bos_token_id"] ||
+        generation_config["eos_token_id"]
+
+      if !decoder_input_ids.is_a?(Array)
+        decoder_input_ids = [decoder_input_ids]
+      end
+
+      input_token_ids.each do |tokens|
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        # Create beam
+        start = {
+          inputs: tokens,
+          encoder_outputs: nil,
+          prev_model_outputs: nil,
+
+          output_token_ids: decoder_input_ids,
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        if requires_attention_mask
+          start[:attention_mask] = prepare_attention_mask(tokens)
+        end
+
+        beams << start
+      end
+
+      beams
+    end
+
+    def prepare_attention_mask(tokens)
+      # Prepare attention mask
+      pad_token_id = @config["pad_token_id"]
+      eos_token_id = @config["eos_token_id"]
+      if eos_token_id.is_a?(Integer)
+        eos_token_id = [eos_token_id]
+      end
+
+      is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+      is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+      if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+        raise Todo
+      else
+        Utils.ones_like(tokens)
+      end
+    end
+
+    def seq2seq_run_beam(beam)
+      input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+      decoder_input_ids = beam[:output_token_ids]
+      if beam[:prev_model_outputs]
+        # After the first step, `prev_model_outputs` won't be null.
+        # So, we cut decoder_input_ids if past is used
+        decoder_input_ids = [decoder_input_ids[-1]]
+      end
+
+      # 1. Prepare
+      model_inputs = {
+        input_name => beam[:inputs],
+        decoder_input_ids: [decoder_input_ids],
+        encoder_outputs: beam[:encoder_outputs],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+      if beam[:attention_mask]
+        model_inputs[:attention_mask] = beam[:attention_mask]
+      end
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+      beam[:encoder_outputs] = output[:encoder_outputs]
+
+      output
+    end
+
+    def seq2seq_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+    end
+
+    def group_beams(beams)
+      # Group beams by their ids
+      groups = {}
+      beams.each do |obj|
+        if !groups[obj[:id]]
+          groups[obj[:id]] = [obj]
+        else
+          groups[obj[:id]] << obj
+        end
+      end
+      groups.values
+    end
+
     def encoder_forward(model_inputs, output_names: nil)
       encoder_feeds = {}
       @session.inputs.each do |input|
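In the `generate` method added above, each beam is a plain Hash and `group_beams` re-batches candidates by the `:id` of the input they came from, keeping the top `num_beams` per input each step. A self-contained sketch of that bookkeeping, with a stub sampler standing in for the ONNX forward pass (all values illustrative):

# Two inputs, one beam each; token ids and log probs are made up.
beams = [
  {id: 0, output_token_ids: [0], score: 0.0, done: false},
  {id: 1, output_token_ids: [0], score: 0.0, done: false}
]

# Stub sampler: candidate [token_id, log_prob] pairs for every step.
stub_sampler = ->(_logits) { [[7, -0.1], [9, -2.3]] }

3.times do
  newest_beams = []
  beams.each do |beam|
    stub_sampler.(nil).each do |new_token_id, log_prob|
      new_beam = beam.dup # same shallow copy `generate` uses
      new_beam[:output_token_ids] += [new_token_id]
      new_beam[:score] += log_prob
      newest_beams << new_beam
    end
  end
  # Keep the best beam per input id (num_beams = 1 here), as
  # group_beams plus sort_by do in the real loop.
  beams = newest_beams.group_by { |b| b[:id] }.values.map { |g| g.max_by { |b| b[:score] } }
end

p beams.map { |b| b[:output_token_ids] } # => [[0, 7, 7, 7], [0, 7, 7, 7]]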
@@ -167,7 +625,96 @@ module Informers
       session_run(@session, encoder_feeds, output_names:)
     end

-    def
+    def decoder_forward(model_inputs)
+      input_ids, past_key_values, attention_mask =
+        model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+      decoder_feeds = {
+        input_ids: input_ids,
+        attention_mask: attention_mask || prepare_attention_mask(input_ids)
+      }
+      use_cache_branch = !!past_key_values
+
+      if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@session, decoder_feeds)
+      decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+      logits = decoder_results["logits"]
+
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+      {"logits" => logits, past_key_values: past_key_values}
+    end
+
+    def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      beams = []
+
+      beam_id = 0
+      input_token_ids.each do |tokens|
+        output_token_ids = tokens.dup
+
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        if inputs_attention_mask
+          attn_mask = inputs_attention_mask[beam_id]
+          attn_mask = [attn_mask]
+        else
+          attn_mask = prepare_attention_mask(tokens)
+        end
+
+        start = {
+          input: tokens,
+          model_input_ids: tokens,
+          attention_mask: attn_mask,
+          prev_model_outputs: nil,
+
+          output_token_ids: output_token_ids,
+          num_output_tokens: num_output_tokens,
+
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        beams << start
+      end
+      beams
+    end
+
+    def decoder_run_beam(beam)
+      attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+      # 1. Prepare
+      model_inputs = {
+        input_ids: beam[:model_input_ids],
+        attention_mask: [attn_mask_data],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+
+      output
+    end
+
+    def decoder_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+      beam[:model_input_ids] = [[new_token_id]]
+    end
+
+    def session_run(session, inputs, output_names: nil)
       checked_inputs = validate_inputs(session, inputs)
       begin
         output = session.run(output_names || @output_names, checked_inputs)
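Both `decoder_forward` above and `seq2seq_forward` cycle the key/value cache by renaming session outputs: each `present.*` output becomes the next step's `past_key_values.*` feed via `get_past_key_values`. A stripped-down sketch of that rename, with placeholder output names and values:

# Placeholder ONNX output names/values; only the rename logic mirrors the diff.
decoder_results = {
  "logits" => [[0.1, 0.9]],
  "present.0.key" => :k0,
  "present.0.value" => :v0
}

pkvs = {}
decoder_results.each_key do |name|
  if name.start_with?("present")
    pkvs[name.sub("present", "past_key_values")] = decoder_results[name]
  end
end

p pkvs # => {"past_key_values.0.key"=>:k0, "past_key_values.0.value"=>:v0}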
@@ -187,6 +734,18 @@ module Informers
     def validate_inputs(session, inputs)
       inputs
     end
+
+    def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+    end
+
+    def run_beam(beam)
+      @run_beam.(beam)
+    end
+
+    def update_beam(beam, new_token_id)
+      @update_beam.(beam, new_token_id)
+    end
   end

   class BertPreTrainedModel < PreTrainedModel
@@ -195,6 +754,12 @@ module Informers
   class BertModel < BertPreTrainedModel
   end

+  class BertForMaskedLM < BertPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class BertForSequenceClassification < BertPreTrainedModel
     def call(model_inputs)
       SequenceClassifierOutput.new(*super(model_inputs))
@@ -243,6 +808,114 @@ module Informers
   class MPNetModel < MPNetPreTrainedModel
   end

+  class T5PreTrainedModel < PreTrainedModel
+  end
+
+  class T5Model < T5PreTrainedModel
+  end
+
+  class T5ForConditionalGeneration < T5PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config[:num_decoder_layers]
+      @num_decoder_heads = @config[:num_heads]
+      @decoder_dim_kv = @config[:d_kv]
+
+      @num_encoder_layers = @config[:num_layers]
+      @num_encoder_heads = @config[:num_heads]
+      @encoder_dim_kv = @config[:d_kv]
+    end
+  end
+
+  class BartPretrainedModel < PreTrainedModel
+  end
+
+  class BartModel < BartPretrainedModel
+  end
+
+  class BartForConditionalGeneration < BartPretrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+    end
+  end
+
+  class BartForSequenceClassification < BartPretrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class MBartPreTrainedModel < PreTrainedModel
+  end
+
+  class MBartModel < MBartPreTrainedModel
+  end
+
+  class MBartForCausalLM < MBartPreTrainedModel
+    attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+      :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+    def initialize(config, decoder_merged_session, generation_config)
+      super(config, decoder_merged_session)
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class M2M100PreTrainedModel < PreTrainedModel
+  end
+
+  class M2M100Model < M2M100PreTrainedModel
+  end
+
+  class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class RobertaPreTrainedModel < PreTrainedModel
+  end
+
+  class RobertaModel < RobertaPreTrainedModel
+  end
+
+  class RobertaForMaskedLM < RobertaPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class XLMRobertaPreTrainedModel < PreTrainedModel
   end

@@ -255,34 +928,250 @@ module Informers
     end
   end

+  class ViTPreTrainedModel < PreTrainedModel
+  end
+
+  class ViTModel < ViTPreTrainedModel
+  end
+
+  class ViTForImageClassification < ViTPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class CLIPPreTrainedModel < PreTrainedModel
+  end
+
+  class CLIPModel < CLIPPreTrainedModel
+  end
+
+  class GPT2PreTrainedModel < PreTrainedModel
+    attr_reader :num_heads, :num_layers, :dim_kv
+
+    def initialize(config, session, generation_config)
+      super(config, session)
+      @generation_config = generation_config
+
+      # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+      @config["pad_token_id"] = @config["eos_token_id"]
+
+      @num_heads = @config["n_head"]
+      @num_layers = @config["n_layer"]
+      @dim_kv = @config["n_embd"] / @num_heads.to_f
+    end
+  end
+
+  class GPT2Model < GPT2PreTrainedModel
+  end
+
+  class GPT2LMHeadModel < GPT2PreTrainedModel
+  end
+
+  class OwlViTPreTrainedModel < PreTrainedModel
+  end
+
+  class OwlViTModel < OwlViTPreTrainedModel
+  end
+
+  class OwlViTForObjectDetection < OwlViTPreTrainedModel
+  end
+
+  class DetrPreTrainedModel < PreTrainedModel
+  end
+
+  class DetrModel < DetrPreTrainedModel
+  end
+
+  class DetrForObjectDetection < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrObjectDetectionOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DetrForSegmentation < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrSegmentationOutput.new(*super(model_inputs))
+    end
+  end
+
+  class Swin2SRPreTrainedModel < PreTrainedModel
+  end
+
+  class Swin2SRModel < Swin2SRPreTrainedModel
+  end
+
+  class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+  end
+
+  class DPTPreTrainedModel < PreTrainedModel
+  end
+
+  class DPTModel < DPTPreTrainedModel
+  end
+
+  class DPTForDepthEstimation < DPTPreTrainedModel
+  end
+
+  class VisionEncoderDecoderModel < PreTrainedModel
+    MAIN_INPUT_NAME = :pixel_values
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      # Extract configs
+      encoder_config = @config["encoder"]
+      decoder_config = @config["decoder"]
+
+      # Validate encoder
+      encoder_model_type = encoder_config["model_type"]
+      encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+      if !encoder_model
+        warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+      end
+
+      # Validate decoder
+      decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+      if !decoder_model
+        raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+      end
+
+      decoder_model_class = decoder_model[1]
+      decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+      @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+      if @add_encoder_pkv
+        # Decoder is part of an encoder-decoder model
+        @num_decoder_layers = decoder.num_decoder_layers
+        @num_decoder_heads = decoder.num_decoder_heads
+        @decoder_dim_kv = decoder.decoder_dim_kv
+
+        @num_encoder_layers = decoder.num_encoder_layers
+        @num_encoder_heads = decoder.num_encoder_heads
+        @encoder_dim_kv = decoder.encoder_dim_kv
+      else
+        # Decoder is a decoder-only model
+        @num_layers = decoder.num_layers
+        @num_heads = decoder.num_heads
+        @dim_kv = decoder.dim_kv
+      end
+    end
+  end
+
+  class DonutSwinPreTrainedModel < PreTrainedModel
+  end
+
+  class DonutSwinModel < DonutSwinPreTrainedModel
+  end
+
   MODEL_MAPPING_NAMES_ENCODER_ONLY = {
     "bert" => ["BertModel", BertModel],
     "nomic_bert" => ["NomicBertModel", NomicBertModel],
     "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
     "mpnet" => ["MPNetModel", MPNetModel],
     "distilbert" => ["DistilBertModel", DistilBertModel],
-    "
+    "roberta" => ["RobertaModel", RobertaModel],
+    "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+    "clip" => ["CLIPModel", CLIPModel],
+    "detr" => ["DetrModel", DetrModel],
+    "vit" => ["ViTModel", ViTModel],
+    "owlvit" => ["OwlViTModel", OwlViTModel],
+    "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+  }
+
+  MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+    "bart" => ["BartModel", BartModel]
   }

   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
-    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
+    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+    "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
   }

   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForTokenClassification", BertForTokenClassification]
   }

+  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+    "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+    "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+    "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+  }
+
+  MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+    "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+    "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+  }
+
+  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+    "bert" => ["BertForMaskedLM", BertForMaskedLM],
+    "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+  }
+
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
   }

+  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+    "vit" => ["ViTForImageClassification", ViTForImageClassification]
+  }
+
+  MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+    "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+  }
+
+  MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+    "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+  }
+
+  MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+    "detr" => ["DetrForSegmentation", DetrForSegmentation]
+  }
+
+  MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+    "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+  }
+
+  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+    "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+  }
+
+  MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+  }
+
   MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
-    [
+    [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
   ]

   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
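These tables are what `PretrainedMixin.from_pretrained` walks in the first hunk: each maps the `model_type` string from a model's config.json to a `[name, class]` pair, and `MODEL_CLASS_TYPE_MAPPING` tags each table with its loading path (encoder-only, seq2seq, decoder-only, and so on). A sketch of the lookup, assuming the constants above are loaded:

config = {model_type: "t5"} # as parsed from a model's config.json

name, klass = Informers::MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES[config[:model_type]]
p name  # => "T5ForConditionalGeneration"
p klass # => Informers::T5ForConditionalGeneration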
@@ -306,11 +1195,77 @@ module Informers
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
   end

+  class AutoModelForSeq2SeqLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+  end
+
+  class AutoModelForCausalLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskedLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+  end
+
   class AutoModelForQuestionAnswering < PretrainedMixin
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
   end

+  class AutoModelForVision2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForSemanticSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForZeroShotObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageToImage < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+  end
+
+  class AutoModelForDepthEstimation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageFeatureExtraction < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+  end
+
   class ModelOutput
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+  end
+
+  class Seq2SeqLMOutput < ModelOutput
+    def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+      super()
+      @logits = logits
+      @past_key_values = past_key_values
+      @encoder_outputs = encoder_outputs
+      @decoder_attentions = decoder_attentions
+      @cross_attentions = cross_attentions
+    end
   end

   class SequenceClassifierOutput < ModelOutput
@@ -331,6 +1286,15 @@ module Informers
     end
   end

+  class MaskedLMOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
   class QuestionAnsweringModelOutput < ModelOutput
     attr_reader :start_logits, :end_logits

@@ -340,4 +1304,25 @@ module Informers
       @end_logits = end_logits
     end
   end
+
+  class DetrObjectDetectionOutput < ModelOutput
+    attr_reader :logits, :pred_boxes
+
+    def initialize(logits, pred_boxes)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+    end
+  end
+
+  class DetrSegmentationOutput < ModelOutput
+    attr_reader :logits, :pred_boxes, :pred_masks
+
+    def initialize(logits, pred_boxes, pred_masks)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+      @pred_masks = pred_masks
+    end
+  end
 end
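The `[]` method added to `ModelOutput` gives every output object hash-style access to its attributes via `instance_variable_get`, alongside the usual readers. An illustrative read (the logits value is made up):

require "informers"

output = Informers::MaskedLMOutput.new([[0.1, 0.9]])
p output.logits   # attr_reader    => [[0.1, 0.9]]
p output[:logits] # ModelOutput#[] => [[0.1, 0.9]]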