informers 1.0.2 → 1.1.0

@@ -44,7 +44,7 @@ module Informers
  end

  const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
- model_info = model_class_mapping[config.model_type]
+ model_info = model_class_mapping[config[:model_type]]
  if !model_info
  next # Item not found in this mapping
  end
@@ -52,15 +52,17 @@ module Informers
  end

  if const_defined?(:BASE_IF_FAIL)
- warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+ warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
  PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
  else
- raise Error, "Unsupported model type: #{config.model_type}"
+ raise Error, "Unsupported model type: #{config[:model_type]}"
  end
  end
  end

  class PreTrainedModel
+ MAIN_INPUT_NAME = :input_ids
+
  attr_reader :config

  def initialize(config, session)
@@ -76,9 +78,19 @@ module Informers

  case model_type
  when MODEL_TYPES[:DecoderOnly]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:decoder_run_beam)
+ @get_start_beams = method(:decoder_start_beams)
+ @update_beam = method(:decoder_update_beam)
+ @forward = method(:decoder_forward)
  when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:seq2seq_run_beam)
+ @get_start_beams = method(:seq2seq_start_beams)
+ @update_beam = method(:seq2seq_update_beam)
+ @forward = method(:seq2seq_forward)
  when MODEL_TYPES[:EncoderDecoder]
  raise Todo
  else
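
Note: the constructor above wires up generation by storing bound Method objects (`method(:decoder_run_beam)` and friends) and dispatching through them later. A minimal standalone sketch of that pattern, with illustrative names not taken from the gem:

    class Dispatcher
      def initialize(decoder_only)
        # method(:name) returns a bound Method; store it, invoke later with .()
        @forward = decoder_only ? method(:decoder_forward) : method(:seq2seq_forward)
      end

      def forward(inputs)
        @forward.(inputs)
      end

      private

      def decoder_forward(inputs)
        "decoder got #{inputs.inspect}"
      end

      def seq2seq_forward(inputs)
        "seq2seq got #{inputs.inspect}"
      end
    end

    Dispatcher.new(true).forward([1, 2]) # => "decoder got [1, 2]"
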
@@ -110,10 +122,19 @@ module Informers
  model_type = MODEL_TYPE_MAPPING[model_name]

  if model_type == MODEL_TYPES[:DecoderOnly]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]

  elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+ construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]

  elsif model_type == MODEL_TYPES[:MaskGeneration]
  raise Todo
@@ -123,7 +144,7 @@ module Informers

  else
  if model_type != MODEL_TYPES[:EncoderOnly]
- warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+ warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
  end
  info = [
  AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -135,7 +156,15 @@ module Informers
  end

  def self.construct_session(pretrained_model_name_or_path, file_name, **options)
- model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+ prefix = "onnx/"
+ if file_name.start_with?("../")
+ prefix = ""
+ file_name = file_name[3..]
+ elsif file_name.start_with?("/")
+ prefix = ""
+ file_name = file_name[1..]
+ end
+ model_file_name = "#{prefix}#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
  path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)

  OnnxRuntime::InferenceSession.new(path)
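
Note: `construct_session` now lets a file name opt out of the default `onnx/` directory with a leading `../` or `/`. The same resolution logic as a pure function, extracted here only for illustration:

    def resolve_model_file(file_name, quantized: false)
      prefix = "onnx/"
      if file_name.start_with?("../")
        prefix = ""
        file_name = file_name[3..]
      elsif file_name.start_with?("/")
        prefix = ""
        file_name = file_name[1..]
      end
      "#{prefix}#{file_name}#{quantized ? "_quantized" : ""}.onnx"
    end

    resolve_model_file("decoder_model_merged", quantized: true)
    # => "onnx/decoder_model_merged_quantized.onnx"
    resolve_model_file("../model")    # => "model.onnx"
    resolve_model_file("/fp16/model") # => "fp16/model.onnx"
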
@@ -145,8 +174,445 @@ module Informers
  @forward.(model_inputs, **kwargs)
  end

+ def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+ if !@can_generate
+ model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+ error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+ raise Error, error_message
+ end
+
+ if !inputs.is_a?(Array)
+ raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+ end
+
+ if @config[:is_encoder_decoder]
+ # Generating from the encoder outputs
+ input_ids_seq_length = 0
+ else
+ input_ids_seq_length = inputs.length
+
+ # decoder-only
+ if input_ids_seq_length == 0
+ raise Error, "Must supply a non-empty array of input token ids."
+ end
+ end
+
+ # Update generation config with defaults
+ generation_config = get_generation_config(generation_config)
+
+ logits_processor ||= Utils::LogitsProcessorList.new
+
+ # Update logits processor
+ logits_processor = get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor
+ )
+
+ eos_token_ids = generation_config[:eos_token_id]
+ if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+ eos_token_ids = [eos_token_ids]
+ end
+
+ num_output_tokens = 1
+ max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+ # Only use max length if max_new_tokens is not provided
+ use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+ sampler = Utils::Sampler.get_sampler(generation_config)
+
+ beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+ while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+ newest_beams = []
+ beams.each do |beam|
+ if beam[:done]
+ # Add this beam back into the pool
+ newest_beams << beam
+ next
+ end
+ if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+ # Set this beam to done and add it back into the pool
+ beam[:done] = true
+ newest_beams << beam
+ next
+ end
+
+ output = run_beam(beam)
+
+ # add attentions/scores to beam only if user requested
+ if generation_config["output_attentions"]
+ add_attentions_to_beam(beam, output)
+ end
+
+ # Logits are of the form [batch_size, out_seq_length, vocab_size]
+ # In most cases, this will be [batch_size, 1, vocab_size]
+ # So, we select the last token's logits:
+ # (equivalent to `logits = outputs.logits[:, -1, :]`)
+ logits = output["logits"].map { |v| v[-1] }
+
+ # Apply logits processor
+ logits_processor.(beam[:output_token_ids], logits)
+
+ sampled_tokens = sampler.(logits)
+ sampled_tokens.each do |new_token_id, log_prob|
+ # use previous beam as a starting point
+ new_beam = beam.dup
+
+ # update new beam
+ update_beam(new_beam, new_token_id)
+
+ new_beam[:score] += log_prob
+
+ if eos_token_ids && eos_token_ids.include?(new_token_id)
+ new_beam[:done] = true
+ end
+
+ newest_beams << new_beam
+ end
+ end
+ num_output_tokens += 1
+
+ # Next, we get the best beams, per ID
+ newest_beams =
+ group_beams(newest_beams).map do |group|
+ group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+ end
+
+ # Flatten beams
+ beams = newest_beams.flatten(1)
+
+ # Run callback
+ if generation_config["callback_function"]
+ generation_config["callback_function"].(beams)
+ end
+ end
+
+ # TODO: Ensure that we can return non-batched outputs
+
+ grouped_beams = group_beams(beams)
+
+ get_flattened = lambda do |key|
+ grouped_beams.map do |batch|
+ if generation_config["num_return_sequences"] > 1
+ raise Todo
+ else
+ [batch[0][key]]
+ end
+ end.flatten(1)
+ end
+
+ sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+ if generation_config["return_dict_in_generate"]
+ raise Todo
+ else
+ sequences
+ end
+ end
+
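
Note: a hedged usage sketch of the new `generate` path. The model name is illustrative (the checkpoint must ship ONNX weights in an `onnx/` folder for the gem to load it), the token ids are placeholders you would normally get from a tokenizer (tokenization is outside this diff), and the string-keyed config is assumed to be accepted by `Utils::GenerationConfig`:

    require "informers"

    model = Informers::AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    input_ids = [[0, 1, 2]] # placeholder token ids; one inner array per batch entry
    output_ids = model.generate(input_ids, {"max_new_tokens" => 16})
    # => one array of generated token ids per input sequence
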
  private

+ def get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor = nil
+ )
+ processors = Utils::LogitsProcessorList.new
+
+ if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+ processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+ end
+
+ if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+ processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+ end
+
+ if !generation_config["bad_words_ids"].nil?
+ processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+ processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+ processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+ input_ids_seq_length,
+ generation_config["min_new_tokens"],
+ generation_config["eos_token_id"]
+ ))
+ end
+
+ if !generation_config["forced_bos_token_id"].nil?
+ processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+ end
+
+ if !generation_config["forced_eos_token_id"].nil?
+ processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+ generation_config["max_length"],
+ generation_config["forced_eos_token_id"]
+ ))
+ end
+
+ if !generation_config["begin_suppress_tokens"].nil?
+ raise Todo
+ end
+
+ if !generation_config["forced_decoder_ids"].nil?
+ processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+ end
+
+ if !logits_processor.nil?
+ processors.concat(logits_processor)
+ end
+
+ processors
+ end
+
+ def get_generation_config(generation_config)
+ # Create empty generation config (contains defaults)
+ # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+ gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+ # Apply model's generation config, if it exists
+ if @generation_config
+ gen_config.merge!(@generation_config)
+ end
+
+ # Finally, use any generation config specified by the user
+ # when calling `generate`
+ if !generation_config.nil?
+ gen_config.merge!(generation_config)
+ end
+
+ gen_config
+ end
+
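
Note: `get_generation_config` layers three sources, with later ones winning: the model config's defaults, the checkpoint's generation_config.json, and finally arguments passed to `generate`. The same precedence with plain hashes:

    model_defaults = {"eos_token_id" => 1, "max_length" => 20}
    checkpoint_gen = {"max_length" => 128}
    caller_args = {"max_new_tokens" => 16}

    model_defaults.merge(checkpoint_gen).merge(caller_args)
    # => {"eos_token_id"=>1, "max_length"=>128, "max_new_tokens"=>16}
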
+ def seq2seq_forward(model_inputs)
+ encoder_outputs = model_inputs[:encoder_outputs]
+ past_key_values = model_inputs[:past_key_values]
+
+ if !encoder_outputs
+ # Encoder outputs are not given, so we must compute them.
+ encoder_outputs = encoder_forward(model_inputs)[0]
+ end
+ decoder_feeds = {
+ input_ids: model_inputs[:decoder_input_ids],
+ encoder_hidden_states: encoder_outputs
+ }
+ use_cache_branch = !!past_key_values
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+ decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+ end
+
+ prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+ decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+ logits = decoder_results["logits"]
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+ # Get cross attention and/or decoder attentions if they are present
+ attns = get_attentions(decoder_results)
+
+ Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+ end
+
+ def prepare_position_ids(session, feeds, use_cache_branch)
+ if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+ return
+ end
+
+ raise Todo
+ end
+
+ def get_past_key_values(decoder_results, past_key_values)
+ pkvs = {}
+
+ decoder_results.each_key do |name|
+ if name.start_with?("present")
+ new_name = name.sub("present", "past_key_values")
+
+ if past_key_values && name.include?("encoder")
+ # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+ # outputs with the previous past key values.
+ # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+ pkvs[new_name] = past_key_values[new_name]
+ else
+ pkvs[new_name] = decoder_results[name]
+ end
+ end
+ end
+ pkvs
+ end
+
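
Note: `get_past_key_values` renames each `present.*` output of the decoder session to the `past_key_values.*` feed the next step expects. The renaming in isolation (values are stand-ins for ONNX tensors):

    decoder_results = {
      "logits" => [],
      "present.0.decoder.key" => :k0,
      "present.0.decoder.value" => :v0
    }

    pkvs = {}
    decoder_results.each_key do |name|
      next unless name.start_with?("present")
      pkvs[name.sub("present", "past_key_values")] = decoder_results[name]
    end
    pkvs # => {"past_key_values.0.decoder.key"=>:k0, "past_key_values.0.decoder.value"=>:v0}
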
+ def get_attentions(decoder_results)
+ attns = {}
+
+ ["cross_attentions", "decoder_attentions"].each do |attn_name|
+ result = []
+ decoder_results.each_key do |name|
+ if name.start_with?(attn_name)
+ # the layer index arrives as a String; convert before using it as an Array index
+ index = name.split(".").pop.to_i
+ result[index] = decoder_results[name]
+ end
+ end
+ attns[attn_name] = result
+ end
+ attns
+ end
+
+ def add_past_key_values(decoder_feeds, past_key_values)
+ if past_key_values
+ decoder_feeds.merge!(past_key_values)
+ else
+ # TODO support batches (i.e., batch_size > 1)
+ batch_size = 1
+
+ if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+ _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+ _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+ @num_decoder_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ end
+ elsif @config[:model_type] == "falcon"
+ raise Todo
+ elsif @config[:multi_query]
+ raise Todo
+ elsif @config[:model_type] == "bloom"
+ raise Todo
+ else
+ _dims = [batch_size, @num_heads, 0, @dim_kv]
+ @num_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ end
+ end
+ end
+ end
+
+ def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+ beams = []
+ beam_id = 0
+
+ requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+ # decoder_input_ids == output_token_ids
+ decoder_input_ids =
+ generation_config["decoder_input_ids"] ||
+ generation_config["decoder_start_token_id"] ||
+ generation_config["bos_token_id"] ||
+ generation_config["eos_token_id"]
+
+ if !decoder_input_ids.is_a?(Array)
+ decoder_input_ids = [decoder_input_ids]
+ end
+
+ input_token_ids.each do |tokens|
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ # Create beam
+ start = {
+ inputs: tokens,
+ encoder_outputs: nil,
+ prev_model_outputs: nil,
+
+ output_token_ids: decoder_input_ids,
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ if requires_attention_mask
+ start[:attention_mask] = prepare_attention_mask(tokens)
+ end
+
+ beams << start
+ end
+
+ beams
+ end
+
+ def prepare_attention_mask(tokens)
+ # Prepare attention mask
+ pad_token_id = @config["pad_token_id"]
+ eos_token_id = @config["eos_token_id"]
+ if eos_token_id.is_a?(Integer)
+ eos_token_id = [eos_token_id]
+ end
+
+ is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+ is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+ if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+ raise Todo
+ else
+ Utils.ones_like(tokens)
+ end
+ end
+
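
Note: in the common case (no pad token in the input), `prepare_attention_mask` returns an all-ones mask with the batch's shape. An illustrative stand-in for `Utils.ones_like`, which this diff assumes exists elsewhere in the gem:

    def ones_like(tokens)
      tokens.map { |row| row.map { 1 } }
    end

    ones_like([[5, 7, 9]]) # => [[1, 1, 1]]
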
+ def seq2seq_run_beam(beam)
+ input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+ decoder_input_ids = beam[:output_token_ids]
+ if beam[:prev_model_outputs]
+ # After the first step, `prev_model_outputs` won't be null.
+ # So, we cut decoder_input_ids if past is used
+ decoder_input_ids = [decoder_input_ids[-1]]
+ end
+
+ # 1. Prepare
+ model_inputs = {
+ input_name => beam[:inputs],
+ decoder_input_ids: [decoder_input_ids],
+ encoder_outputs: beam[:encoder_outputs],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+ if beam[:attention_mask]
+ model_inputs[:attention_mask] = beam[:attention_mask]
+ end
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+ beam[:encoder_outputs] = output[:encoder_outputs]
+
+ output
+ end
+
+ def seq2seq_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ end
+
+ def group_beams(beams)
+ # Group beams by their ids
+ groups = {}
+ beams.each do |obj|
+ if !groups[obj[:id]]
+ groups[obj[:id]] = [obj]
+ else
+ groups[obj[:id]] << obj
+ end
+ end
+ groups.values
+ end
+
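
Note: `group_beams` buckets beams by the id of the input they came from, so the best candidates can be kept per input. An equivalent expression with Enumerable (illustrative; the gem builds the hash by hand):

    beams = [
      {id: 0, score: -0.2}, {id: 1, score: -0.9}, {id: 0, score: -1.4}
    ]
    beams.group_by { |beam| beam[:id] }.values
    # => [[{id: 0, score: -0.2}, {id: 0, score: -1.4}], [{id: 1, score: -0.9}]]
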
  def encoder_forward(model_inputs, output_names: nil)
  encoder_feeds = {}
  @session.inputs.each do |input|
@@ -159,7 +625,96 @@ module Informers
  session_run(@session, encoder_feeds, output_names:)
  end

- def session_run(session, inputs, output_names:)
+ def decoder_forward(model_inputs)
+ input_ids, past_key_values, attention_mask =
+ model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+ decoder_feeds = {
+ input_ids: input_ids,
+ attention_mask: attention_mask || prepare_attention_mask(input_ids)
+ }
+ use_cache_branch = !!past_key_values
+
+ if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@session, decoder_feeds)
+ decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+ logits = decoder_results["logits"]
+
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+ {"logits" => logits, past_key_values: past_key_values}
+ end
+
+ def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ beams = []
+
+ beam_id = 0
+ input_token_ids.each do |tokens|
+ output_token_ids = tokens.dup
+
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ if inputs_attention_mask
+ attn_mask = inputs_attention_mask[beam_id]
+ attn_mask = [attn_mask]
+ else
+ attn_mask = prepare_attention_mask(tokens)
+ end
+
+ start = {
+ input: tokens,
+ model_input_ids: tokens,
+ attention_mask: attn_mask,
+ prev_model_outputs: nil,
+
+ output_token_ids: output_token_ids,
+ num_output_tokens: num_output_tokens,
+
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ beams << start
+ end
+ beams
+ end
+
+ def decoder_run_beam(beam)
+ attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+ # 1. Prepare
+ model_inputs = {
+ input_ids: beam[:model_input_ids],
+ attention_mask: [attn_mask_data],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+
+ output
+ end
+
+ def decoder_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ beam[:model_input_ids] = [[new_token_id]]
+ end
+
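
Note: for decoder-only models, once `past_key_values` are cached only the newest token is fed to the session: `decoder_update_beam` appends to the full output sequence but resets `model_input_ids` to a single token. The bookkeeping in isolation:

    beam = {output_token_ids: [10, 11], model_input_ids: [[10, 11]]}

    new_token_id = 12
    beam[:output_token_ids] += [new_token_id] # full history, used for final decoding
    beam[:model_input_ids] = [[new_token_id]] # only the new token hits the session
    beam # => {output_token_ids: [10, 11, 12], model_input_ids: [[12]]}
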
+ def session_run(session, inputs, output_names: nil)
  checked_inputs = validate_inputs(session, inputs)
  begin
  output = session.run(output_names || @output_names, checked_inputs)
@@ -179,6 +734,18 @@ module Informers
  def validate_inputs(session, inputs)
  inputs
  end
+
+ def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ end
+
+ def run_beam(beam)
+ @run_beam.(beam)
+ end
+
+ def update_beam(beam, new_token_id)
+ @update_beam.(beam, new_token_id)
+ end
  end

  class BertPreTrainedModel < PreTrainedModel
@@ -187,6 +754,12 @@ module Informers
  class BertModel < BertPreTrainedModel
  end

+ class BertForMaskedLM < BertPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
  class BertForSequenceClassification < BertPreTrainedModel
  def call(model_inputs)
  SequenceClassifierOutput.new(*super(model_inputs))
@@ -229,31 +802,376 @@ module Informers
  end
  end

+ class MPNetPreTrainedModel < PreTrainedModel
+ end
+
+ class MPNetModel < MPNetPreTrainedModel
+ end
+
+ class T5PreTrainedModel < PreTrainedModel
+ end
+
+ class T5Model < T5PreTrainedModel
+ end
+
+ class T5ForConditionalGeneration < T5PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config[:num_decoder_layers]
+ @num_decoder_heads = @config[:num_heads]
+ @decoder_dim_kv = @config[:d_kv]
+
+ @num_encoder_layers = @config[:num_layers]
+ @num_encoder_heads = @config[:num_heads]
+ @encoder_dim_kv = @config[:d_kv]
+ end
+ end
+
+ class BartPretrainedModel < PreTrainedModel
+ end
+
+ class BartModel < BartPretrainedModel
+ end
+
+ class BartForConditionalGeneration < BartPretrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class BartForSequenceClassification < BartPretrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class MBartPreTrainedModel < PreTrainedModel
+ end
+
+ class MBartModel < MBartPreTrainedModel
+ end
+
+ class MBartForCausalLM < MBartPreTrainedModel
+ attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+ :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+ def initialize(config, decoder_merged_session, generation_config)
+ super(config, decoder_merged_session)
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class M2M100PreTrainedModel < PreTrainedModel
+ end
+
+ class M2M100Model < M2M100PreTrainedModel
+ end
+
+ class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class RobertaPreTrainedModel < PreTrainedModel
+ end
+
+ class RobertaModel < RobertaPreTrainedModel
+ end
+
+ class RobertaForMaskedLM < RobertaPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
+ class XLMRobertaPreTrainedModel < PreTrainedModel
+ end
+
+ class XLMRobertaModel < XLMRobertaPreTrainedModel
+ end
+
+ class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class ViTPreTrainedModel < PreTrainedModel
+ end
+
+ class ViTModel < ViTPreTrainedModel
+ end
+
+ class ViTForImageClassification < ViTPreTrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class CLIPPreTrainedModel < PreTrainedModel
+ end
+
+ class CLIPModel < CLIPPreTrainedModel
+ end
+
+ class GPT2PreTrainedModel < PreTrainedModel
+ attr_reader :num_heads, :num_layers, :dim_kv
+
+ def initialize(config, session, generation_config)
+ super(config, session)
+ @generation_config = generation_config
+
+ # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+ @config["pad_token_id"] = @config["eos_token_id"]
+
+ @num_heads = @config["n_head"]
+ @num_layers = @config["n_layer"]
+ @dim_kv = @config["n_embd"] / @num_heads.to_f
+ end
+ end
+
+ class GPT2Model < GPT2PreTrainedModel
+ end
+
+ class GPT2LMHeadModel < GPT2PreTrainedModel
+ end
+
+ class OwlViTPreTrainedModel < PreTrainedModel
+ end
+
+ class OwlViTModel < OwlViTPreTrainedModel
+ end
+
+ class OwlViTForObjectDetection < OwlViTPreTrainedModel
+ end
+
+ class DetrPreTrainedModel < PreTrainedModel
+ end
+
+ class DetrModel < DetrPreTrainedModel
+ end
+
+ class DetrForObjectDetection < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrObjectDetectionOutput.new(*super(model_inputs))
+ end
+ end
+
+ class DetrForSegmentation < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrSegmentationOutput.new(*super(model_inputs))
+ end
+ end
+
+ class Swin2SRPreTrainedModel < PreTrainedModel
+ end
+
+ class Swin2SRModel < Swin2SRPreTrainedModel
+ end
+
+ class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+ end
+
+ class DPTPreTrainedModel < PreTrainedModel
+ end
+
+ class DPTModel < DPTPreTrainedModel
+ end
+
+ class DPTForDepthEstimation < DPTPreTrainedModel
+ end
+
+ class VisionEncoderDecoderModel < PreTrainedModel
+ MAIN_INPUT_NAME = :pixel_values
+
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ # Extract configs
+ encoder_config = @config["encoder"]
+ decoder_config = @config["decoder"]
+
+ # Validate encoder
+ encoder_model_type = encoder_config["model_type"]
+ encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+ if !encoder_model
+ warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+ end
+
+ # Validate decoder
+ decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+ if !decoder_model
+ raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+ end
+
+ decoder_model_class = decoder_model[1]
+ decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+ @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+ if @add_encoder_pkv
+ # Decoder is part of an encoder-decoder model
+ @num_decoder_layers = decoder.num_decoder_layers
+ @num_decoder_heads = decoder.num_decoder_heads
+ @decoder_dim_kv = decoder.decoder_dim_kv
+
+ @num_encoder_layers = decoder.num_encoder_layers
+ @num_encoder_heads = decoder.num_encoder_heads
+ @encoder_dim_kv = decoder.encoder_dim_kv
+ else
+ # Decoder is a decoder-only model
+ @num_layers = decoder.num_layers
+ @num_heads = decoder.num_heads
+ @dim_kv = decoder.dim_kv
+ end
+ end
+ end
+
+ class DonutSwinPreTrainedModel < PreTrainedModel
+ end
+
+ class DonutSwinModel < DonutSwinPreTrainedModel
+ end
+
  MODEL_MAPPING_NAMES_ENCODER_ONLY = {
  "bert" => ["BertModel", BertModel],
  "nomic_bert" => ["NomicBertModel", NomicBertModel],
  "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
- "distilbert" => ["DistilBertModel", DistilBertModel]
+ "mpnet" => ["MPNetModel", MPNetModel],
+ "distilbert" => ["DistilBertModel", DistilBertModel],
+ "roberta" => ["RobertaModel", RobertaModel],
+ "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+ "clip" => ["CLIPModel", CLIPModel],
+ "detr" => ["DetrModel", DetrModel],
+ "vit" => ["ViTModel", ViTModel],
+ "owlvit" => ["OwlViTModel", OwlViTModel],
+ "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+ }
+
+ MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+ "bart" => ["BartModel", BartModel]
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
- "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+ "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
+ "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+ "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForTokenClassification", BertForTokenClassification]
  }

+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+ "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+ "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+ "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+ }
+
+ MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+ "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+ "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+ }
+
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+ "bert" => ["BertForMaskedLM", BertForMaskedLM],
+ "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+ }
+
  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
  "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
  }

+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+ "vit" => ["ViTForImageClassification", ViTForImageClassification]
+ }
+
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+ "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+ }
+
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+ "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+ }
+
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+ "detr" => ["DetrForSegmentation", DetrForSegmentation]
+ }
+
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+ "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+ }
+
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+ "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+ }
+
+ MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+ }
+
  MODEL_CLASS_TYPE_MAPPING = [
  [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
  [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
  [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
- [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+ [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+ [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+ [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+ [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
  ]

  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
@@ -277,11 +1195,77 @@ module Informers
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
  end

+ class AutoModelForSeq2SeqLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+ end
+
+ class AutoModelForCausalLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+ end
+
+ class AutoModelForMaskedLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+ end
+
  class AutoModelForQuestionAnswering < PretrainedMixin
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
  end

+ class AutoModelForVision2Seq < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageClassification < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForSemanticSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForZeroShotObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageToImage < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+ end
+
+ class AutoModelForDepthEstimation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageFeatureExtraction < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+ end
+
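
Note: each `AutoModel*` class only points at one or more name maps; `PretrainedMixin.from_pretrained` (partially visible in the first hunks) walks `MODEL_CLASS_MAPPINGS` and picks the entry keyed by the config's model type. Schematically, with stand-in values:

    config = {model_type: "t5"}
    model_class_mappings = [
      {"t5" => ["T5ForConditionalGeneration", :T5ForConditionalGeneration]}
    ]

    model_info = nil
    model_class_mappings.each do |mapping|
      model_info = mapping[config[:model_type]]
      break if model_info
    end
    model_info # => ["T5ForConditionalGeneration", :T5ForConditionalGeneration]
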
  class ModelOutput
+ def [](key)
+ instance_variable_get("@#{key}")
+ end
+ end
+
+ class Seq2SeqLMOutput < ModelOutput
+ def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+ super()
+ @logits = logits
+ @past_key_values = past_key_values
+ @encoder_outputs = encoder_outputs
+ @decoder_attentions = decoder_attentions
+ @cross_attentions = cross_attentions
+ end
  end

  class SequenceClassifierOutput < ModelOutput
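
Note: `ModelOutput#[]` maps hash-style access onto instance variables, so outputs accept both symbol and string keys. A standalone sketch of the same trick:

    class Output
      def initialize(logits)
        @logits = logits
      end

      def [](key)
        instance_variable_get("@#{key}")
      end
    end

    out = Output.new([0.1, 0.9])
    out[:logits]  # => [0.1, 0.9]
    out["logits"] # => [0.1, 0.9]
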
@@ -302,6 +1286,15 @@ module Informers
  end
  end

+ class MaskedLMOutput < ModelOutput
+ attr_reader :logits
+
+ def initialize(logits)
+ super()
+ @logits = logits
+ end
+ end
+
  class QuestionAnsweringModelOutput < ModelOutput
  attr_reader :start_logits, :end_logits

@@ -311,4 +1304,25 @@ module Informers
  @end_logits = end_logits
  end
  end
+
+ class DetrObjectDetectionOutput < ModelOutput
+ attr_reader :logits, :pred_boxes
+
+ def initialize(logits, pred_boxes)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ end
+ end
+
+ class DetrSegmentationOutput < ModelOutput
+ attr_reader :logits, :pred_boxes, :pred_masks
+
+ def initialize(logits, pred_boxes, pred_masks)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ @pred_masks = pred_masks
+ end
+ end
  end