informers 1.0.3 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
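The headline change in 1.1.1 is a text-generation path on PreTrainedModel: a public `generate` method driven by beam-search helpers (`get_start_beams`, `run_beam`, `update_beam`) that are wired up per model type, plus new seq2seq, causal-LM, and masked-LM model classes with matching `AutoModel*` entry points. As a rough sketch of how the pieces fit together, assuming a model id and a source of token ids that are not part of this diff:

# Hypothetical usage sketch: the model id and the tokenizer producing
# `input_ids` are assumptions; AutoModelForSeq2SeqLM, #generate, and the
# max_new_tokens option all appear in the diff below.
model = Informers::AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# `generate` expects an Array of input token ids. It layers the passed
# options over the model's generation_config.json, then extends beams until
# every beam emits an EOS token or max_new_tokens tokens have been sampled.
output_token_ids = model.generate(input_ids, {max_new_tokens: 32})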
@@ -44,7 +44,7 @@ module Informers
       end
 
       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
-        model_info = model_class_mapping[config.model_type]
+        model_info = model_class_mapping[config[:model_type]]
         if !model_info
           next # Item not found in this mapping
         end
@@ -52,15 +52,17 @@ module Informers
       end
 
       if const_defined?(:BASE_IF_FAIL)
-        warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+        warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
       else
-        raise Error, "Unsupported model type: #{config.model_type}"
+        raise Error, "Unsupported model type: #{config[:model_type]}"
       end
     end
   end
 
   class PreTrainedModel
+    MAIN_INPUT_NAME = :input_ids
+
     attr_reader :config
 
     def initialize(config, session)
@@ -76,11 +78,24 @@ module Informers
 
       case model_type
       when MODEL_TYPES[:DecoderOnly]
-        raise Todo
+        @can_generate = true
+
+        @run_beam = method(:decoder_run_beam)
+        @get_start_beams = method(:decoder_start_beams)
+        @update_beam = method(:decoder_update_beam)
+        @forward = method(:decoder_forward)
+
       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
-        raise Todo
+        @can_generate = true
+
+        @run_beam = method(:seq2seq_run_beam)
+        @get_start_beams = method(:seq2seq_start_beams)
+        @update_beam = method(:seq2seq_update_beam)
+        @forward = method(:seq2seq_forward)
+
       when MODEL_TYPES[:EncoderDecoder]
-        raise Todo
+        @forward = method(:encoder_forward)
+
       else
         @forward = method(:encoder_forward)
       end
@@ -110,20 +125,37 @@ module Informers
       model_type = MODEL_TYPE_MAPPING[model_name]
 
       if model_type == MODEL_TYPES[:DecoderOnly]
-        raise Todo
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]
 
       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
-        raise Todo
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]
 
      elsif model_type == MODEL_TYPES[:MaskGeneration]
-        raise Todo
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "vision_encoder", **options),
+          construct_session(pretrained_model_name_or_path, "prompt_encoder_mask_decoder", **options)
+        ]
 
       elsif model_type == MODEL_TYPES[:EncoderDecoder]
-        raise Todo
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options)
+        ]
 
       else
         if model_type != MODEL_TYPES[:EncoderOnly]
-          warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+          warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
        end
         info = [
           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -153,8 +185,445 @@ module Informers
       @forward.(model_inputs, **kwargs)
     end
 
+    def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+      if !@can_generate
+        model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+        error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+        raise Error, error_message
+      end
+
+      if !inputs.is_a?(Array)
+        raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+      end
+
+      if @config[:is_encoder_decoder]
+        # Generating from the encoder outputs
+        input_ids_seq_length = 0
+      else
+        input_ids_seq_length = inputs.length
+
+        # decoder-only
+        if input_ids_seq_length == 0
+          raise Error, "Must supply a non-empty array of input token ids."
+        end
+      end
+
+      # Update generation config with defaults
+      generation_config = get_generation_config(generation_config)
+
+      logits_processor ||= Utils::LogitsProcessorList.new
+
+      # Update logits processor
+      logits_processor = get_logits_processor(
+        generation_config,
+        input_ids_seq_length,
+        logits_processor
+      )
+
+      eos_token_ids = generation_config[:eos_token_id]
+      if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+        eos_token_ids = [eos_token_ids]
+      end
+
+      num_output_tokens = 1
+      max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+      # Only use max length if max_new_tokens is not provided
+      use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+      sampler = Utils::Sampler.get_sampler(generation_config)
+
+      beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+      while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+        newest_beams = []
+        beams.each do |beam|
+          if beam[:done]
+            # Add this beam back into the pool
+            newest_beams << beam
+            next
+          end
+          if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+            # Set this beam to done and add it back into the pool
+            beam[:done] = true
+            newest_beams << beam
+            next
+          end
+
+          output = run_beam(beam)
+
+          # add attentions/scores to beam only if user requested
+          if generation_config["output_attentions"]
+            add_attentions_to_beam(beam, output)
+          end
+
+          # Logits are of the form [batch_size, out_seq_length, vocab_size]
+          # In most cases, this will be [batch_size, 1, vocab_size]
+          # So, we select the last token's logits:
+          # (equivalent to `logits = outputs.logits[:, -1, :]`)
+          logits = output["logits"].map { |v| v[-1] }
+
+          # Apply logits processor
+          logits_processor.(beam[:output_token_ids], logits)
+
+          sampled_tokens = sampler.(logits)
+          sampled_tokens.each do |new_token_id, log_prob|
+            # use previous beam as a starting point
+            new_beam = beam.dup
+
+            # update new beam
+            update_beam(new_beam, new_token_id)
+
+            new_beam[:score] += log_prob
+
+            if eos_token_ids && eos_token_ids.include?(new_token_id)
+              new_beam[:done] = true
+            end
+
+            newest_beams << new_beam
+          end
+        end
+        num_output_tokens += 1
+
+        # Next, we get the best beams, per ID
+        newest_beams =
+          group_beams(newest_beams).map do |group|
+            group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+          end
+
+        # Flatten beams
+        beams = newest_beams.flatten(1)
+
+        # Run callback
+        if generation_config["callback_function"]
+          generation_config["callback_function"].(beams)
+        end
+      end
+
+      # TODO: Ensure that we can return non-batched outputs
+
+      grouped_beams = group_beams(beams)
+
+      get_flattened = lambda do |key|
+        grouped_beams.flat_map do |batch|
+          if generation_config["num_return_sequences"] > 1
+            raise Todo
+          else
+            [batch[0][key]]
+          end
+        end
+      end
+
+      sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+      if generation_config["return_dict_in_generate"]
+        raise Todo
+      else
+        sequences
+      end
+    end
+
     private
 
+    def get_logits_processor(
+      generation_config,
+      input_ids_seq_length,
+      logits_processor = nil
+    )
+      processors = Utils::LogitsProcessorList.new
+
+      if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+        processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+      end
+
+      if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+        processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+      end
+
+      if !generation_config["bad_words_ids"].nil?
+        processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+        processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+        processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+          input_ids_seq_length,
+          generation_config["min_new_tokens"],
+          generation_config["eos_token_id"]
+        ))
+      end
+
+      if !generation_config["forced_bos_token_id"].nil?
+        processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+      end
+
+      if !generation_config["forced_eos_token_id"].nil?
+        processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+          generation_config["max_length"],
+          generation_config["forced_eos_token_id"]
+        ))
+      end
+
+      if !generation_config["begin_suppress_tokens"].nil?
+        raise Todo
+      end
+
+      if !generation_config["forced_decoder_ids"].nil?
+        processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+      end
+
+      if !logits_processor.nil?
+        processors.concat(logits_processor)
+      end
+
+      processors
+    end
+
+    def get_generation_config(generation_config)
+      # Create empty generation config (contains defaults)
+      # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+      gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+      # Apply model's generation config, if it exists
+      if @generation_config
+        gen_config.merge!(@generation_config)
+      end
+
+      # Finally, use any generation config specified by the user
+      # when calling `generate`
+      if !generation_config.nil?
+        gen_config.merge!(generation_config)
+      end
+
+      gen_config
+    end
+
+    def seq2seq_forward(model_inputs)
+      encoder_outputs = model_inputs[:encoder_outputs]
+      past_key_values = model_inputs[:past_key_values]
+
+      if !encoder_outputs
+        # Encoder outputs are not given, so we must compute them.
+        encoder_outputs = encoder_forward(model_inputs)[0]
+      end
+      decoder_feeds = {
+        input_ids: model_inputs[:decoder_input_ids],
+        encoder_hidden_states: encoder_outputs
+      }
+      use_cache_branch = !!past_key_values
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+        decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+      end
+
+      prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+      decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+      logits = decoder_results["logits"]
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+      # Get cross attention and/or decoder attentions if they are present
+      attns = get_attentions(decoder_results)
+
+      Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+    end
+
+    def prepare_position_ids(session, feeds, use_cache_branch)
+      if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+        return
+      end
+
+      raise Todo
+    end
+
+    def get_past_key_values(decoder_results, past_key_values)
+      pkvs = {}
+
+      decoder_results.each_key do |name|
+        if name.start_with?("present")
+          new_name = name.sub("present", "past_key_values")
+
+          if past_key_values && name.include?("encoder")
+            # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+            # outputs with the previous past key values.
+            # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+            pkvs[new_name] = past_key_values[new_name]
+          else
+            pkvs[new_name] = decoder_results[name]
+          end
+        end
+      end
+      pkvs
+    end
+
+    def get_attentions(decoder_results)
+      attns = {}
+
+      ["cross_attentions", "decoder_attentions"].each do |attn_name|
+        result = []
+        decoder_results.each_key do |name|
+          if name.start_with?(attn_name)
+            index = name.split(".").pop
+            result[index] = decoder_results[name]
+          end
+        end
+        attns[attn_name] = result
+      end
+      attns
+    end
+
+    def add_past_key_values(decoder_feeds, past_key_values)
+      if past_key_values
+        decoder_feeds.merge!(past_key_values)
+      else
+        # TODO support batches (i.e., batch_size > 1)
+        batch_size = 1
+
+        if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+          _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+          _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+          @num_decoder_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+          end
+        elsif @config[:model_type] == "falcon"
+          raise Todo
+        elsif @config[:multi_query]
+          raise Todo
+        elsif @config[:model_type] == "bloom"
+          raise Todo
+        else
+          _dims = [batch_size, @num_heads, 0, @dim_kv]
+          @num_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+            # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+          end
+        end
+      end
+    end
+
+    def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+      beams = []
+      beam_id = 0
+
+      requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+      # decoder_input_ids == output_token_ids
+      decoder_input_ids =
+        generation_config["decoder_input_ids"] ||
+        generation_config["decoder_start_token_id"] ||
+        generation_config["bos_token_id"] ||
+        generation_config["eos_token_id"]
+
+      if !decoder_input_ids.is_a?(Array)
+        decoder_input_ids = [decoder_input_ids]
+      end
+
+      input_token_ids.each do |tokens|
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        # Create beam
+        start = {
+          inputs: tokens,
+          encoder_outputs: nil,
+          prev_model_outputs: nil,
+
+          output_token_ids: decoder_input_ids,
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        if requires_attention_mask
+          start[:attention_mask] = prepare_attention_mask(tokens)
+        end
+
+        beams << start
+      end
+
+      beams
+    end
+
+    def prepare_attention_mask(tokens)
+      # Prepare attention mask
+      pad_token_id = @config["pad_token_id"]
+      eos_token_id = @config["eos_token_id"]
+      if eos_token_id.is_a?(Integer)
+        eos_token_id = [eos_token_id]
+      end
+
+      is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+      is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+      if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+        raise Todo
+      else
+        Utils.ones_like(tokens)
+      end
+    end
+
+    def seq2seq_run_beam(beam)
+      input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+      decoder_input_ids = beam[:output_token_ids]
+      if beam[:prev_model_outputs]
+        # After the first step, `prev_model_outputs` won't be null.
+        # So, we cut decoder_input_ids if past is used
+        decoder_input_ids = [decoder_input_ids[-1]]
+      end
+
+      # 1. Prepare
+      model_inputs = {
+        input_name => beam[:inputs],
+        decoder_input_ids: [decoder_input_ids],
+        encoder_outputs: beam[:encoder_outputs],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+      if beam[:attention_mask]
+        model_inputs[:attention_mask] = beam[:attention_mask]
+      end
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+      beam[:encoder_outputs] = output[:encoder_outputs]
+
+      output
+    end
+
+    def seq2seq_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+    end
+
+    def group_beams(beams)
+      # Group beams by their ids
+      groups = {}
+      beams.each do |obj|
+        if !groups[obj[:id]]
+          groups[obj[:id]] = [obj]
+        else
+          groups[obj[:id]] << obj
+        end
+      end
+      groups.values
+    end
+
     def encoder_forward(model_inputs, output_names: nil)
       encoder_feeds = {}
       @session.inputs.each do |input|
@@ -167,7 +636,96 @@ module Informers
       session_run(@session, encoder_feeds, output_names:)
     end
 
-    def session_run(session, inputs, output_names:)
+    def decoder_forward(model_inputs)
+      input_ids, past_key_values, attention_mask =
+        model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+      decoder_feeds = {
+        input_ids: input_ids,
+        attention_mask: attention_mask || prepare_attention_mask(input_ids)
+      }
+      use_cache_branch = !!past_key_values
+
+      if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@session, decoder_feeds)
+      decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+      logits = decoder_results["logits"]
+
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+      {"logits" => logits, past_key_values: past_key_values}
+    end
+
+    def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      beams = []
+
+      beam_id = 0
+      input_token_ids.each do |tokens|
+        output_token_ids = tokens.dup
+
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        if inputs_attention_mask
+          attn_mask = inputs_attention_mask[beam_id]
+          attn_mask = [attn_mask]
+        else
+          attn_mask = prepare_attention_mask(tokens)
+        end
+
+        start = {
+          input: tokens,
+          model_input_ids: tokens,
+          attention_mask: attn_mask,
+          prev_model_outputs: nil,
+
+          output_token_ids: output_token_ids,
+          num_output_tokens: num_output_tokens,
+
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        beams << start
+      end
+      beams
+    end
+
+    def decoder_run_beam(beam)
+      attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+      # 1. Prepare
+      model_inputs = {
+        input_ids: beam[:model_input_ids],
+        attention_mask: [attn_mask_data],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+
+      output
+    end
+
+    def decoder_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+      beam[:model_input_ids] = [[new_token_id]]
+    end
+
+    def session_run(session, inputs, output_names: nil)
       checked_inputs = validate_inputs(session, inputs)
       begin
         output = session.run(output_names || @output_names, checked_inputs)
@@ -187,6 +745,18 @@ module Informers
     def validate_inputs(session, inputs)
       inputs
     end
+
+    def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+    end
+
+    def run_beam(beam)
+      @run_beam.(beam)
+    end
+
+    def update_beam(beam, new_token_id)
+      @update_beam.(beam, new_token_id)
+    end
   end
 
   class BertPreTrainedModel < PreTrainedModel
@@ -195,6 +765,12 @@ module Informers
   class BertModel < BertPreTrainedModel
   end
 
+  class BertForMaskedLM < BertPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class BertForSequenceClassification < BertPreTrainedModel
     def call(model_inputs)
       SequenceClassifierOutput.new(*super(model_inputs))
@@ -243,6 +819,126 @@ module Informers
   class MPNetModel < MPNetPreTrainedModel
   end
 
+  class T5PreTrainedModel < PreTrainedModel
+  end
+
+  class T5Model < T5PreTrainedModel
+  end
+
+  class T5ForConditionalGeneration < T5PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config[:num_decoder_layers]
+      @num_decoder_heads = @config[:num_heads]
+      @decoder_dim_kv = @config[:d_kv]
+
+      @num_encoder_layers = @config[:num_layers]
+      @num_encoder_heads = @config[:num_heads]
+      @encoder_dim_kv = @config[:d_kv]
+    end
+  end
+
+  class BartPretrainedModel < PreTrainedModel
+  end
+
+  class BartModel < BartPretrainedModel
+  end
+
+  class BartForConditionalGeneration < BartPretrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+    end
+  end
+
+  class BartForSequenceClassification < BartPretrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class MBartPreTrainedModel < PreTrainedModel
+  end
+
+  class MBartModel < MBartPreTrainedModel
+  end
+
+  class MBartForCausalLM < MBartPreTrainedModel
+    attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+      :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+    def initialize(config, decoder_merged_session, generation_config)
+      super(config, decoder_merged_session)
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class M2M100PreTrainedModel < PreTrainedModel
+  end
+
+  class M2M100Model < M2M100PreTrainedModel
+  end
+
+  class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class Wav2Vec2PreTrainedModel < PreTrainedModel
+  end
+
+  class Wav2Vec2Model < Wav2Vec2PreTrainedModel
+  end
+
+  class Wav2Vec2ForSequenceClassification < Wav2Vec2PreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class RobertaPreTrainedModel < PreTrainedModel
+  end
+
+  class RobertaModel < RobertaPreTrainedModel
+  end
+
+  class RobertaForMaskedLM < RobertaPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class XLMRobertaPreTrainedModel < PreTrainedModel
   end
 
@@ -255,34 +951,351 @@ module Informers
     end
   end
 
+  class ViTPreTrainedModel < PreTrainedModel
+  end
+
+  class ViTModel < ViTPreTrainedModel
+  end
+
+  class ViTForImageClassification < ViTPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class CLIPPreTrainedModel < PreTrainedModel
+  end
+
+  class CLIPModel < CLIPPreTrainedModel
+  end
+
+  class GPT2PreTrainedModel < PreTrainedModel
+    attr_reader :num_heads, :num_layers, :dim_kv
+
+    def initialize(config, session, generation_config)
+      super(config, session)
+      @generation_config = generation_config
+
+      # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+      @config["pad_token_id"] = @config["eos_token_id"]
+
+      @num_heads = @config["n_head"]
+      @num_layers = @config["n_layer"]
+      @dim_kv = @config["n_embd"] / @num_heads.to_f
+    end
+  end
+
+  class GPT2Model < GPT2PreTrainedModel
+  end
+
+  class GPT2LMHeadModel < GPT2PreTrainedModel
+  end
+
+  class OwlViTPreTrainedModel < PreTrainedModel
+  end
+
+  class OwlViTModel < OwlViTPreTrainedModel
+  end
+
+  class OwlViTForObjectDetection < OwlViTPreTrainedModel
+  end
+
+  class DetrPreTrainedModel < PreTrainedModel
+  end
+
+  class DetrModel < DetrPreTrainedModel
+  end
+
+  class DetrForObjectDetection < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrObjectDetectionOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DetrForSegmentation < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrSegmentationOutput.new(*super(model_inputs))
+    end
+  end
+
+  class Swin2SRPreTrainedModel < PreTrainedModel
+  end
+
+  class Swin2SRModel < Swin2SRPreTrainedModel
+  end
+
+  class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+  end
+
+  class DPTPreTrainedModel < PreTrainedModel
+  end
+
+  class DPTModel < DPTPreTrainedModel
+  end
+
+  class DPTForDepthEstimation < DPTPreTrainedModel
+  end
+
+  class VisionEncoderDecoderModel < PreTrainedModel
+    MAIN_INPUT_NAME = :pixel_values
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      # Extract configs
+      encoder_config = @config["encoder"]
+      decoder_config = @config["decoder"]
+
+      # Validate encoder
+      encoder_model_type = encoder_config["model_type"]
+      encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+      if !encoder_model
+        warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+      end
+
+      # Validate decoder
+      decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+      if !decoder_model
+        raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+      end
+
+      decoder_model_class = decoder_model[1]
+      decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+      @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+      if @add_encoder_pkv
+        # Decoder is part of an encoder-decoder model
+        @num_decoder_layers = decoder.num_decoder_layers
+        @num_decoder_heads = decoder.num_decoder_heads
+        @decoder_dim_kv = decoder.decoder_dim_kv
+
+        @num_encoder_layers = decoder.num_encoder_layers
+        @num_encoder_heads = decoder.num_encoder_heads
+        @encoder_dim_kv = decoder.encoder_dim_kv
+      else
+        # Decoder is a decoder-only model
+        @num_layers = decoder.num_layers
+        @num_heads = decoder.num_heads
+        @dim_kv = decoder.dim_kv
+      end
+    end
+  end
+
+  class DonutSwinPreTrainedModel < PreTrainedModel
+  end
+
+  class DonutSwinModel < DonutSwinPreTrainedModel
+  end
+
+  class WhisperPreTrainedModel < PreTrainedModel
+  end
+
+  class WhisperModel < WhisperPreTrainedModel
+  end
+
+  class WhisperForConditionalGeneration < WhisperPreTrainedModel
+    REQUIRES_ATTENTION_MASK = false
+    MAIN_INPUT_NAME = :input_features
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+
+    def generate(inputs, generation_config = nil, logits_processor = nil)
+      raise Todo
+    end
+  end
+
+  class VitsPreTrainedModel < PreTrainedModel
+  end
+
+  class VitsModel < VitsPreTrainedModel
+    def call(model_inputs)
+      VitsModelOutput.new(*super(model_inputs))
+    end
+  end
+
+  class SpeechT5PreTrainedModel < PreTrainedModel
+  end
+
+  class SpeechT5Model < SpeechT5PreTrainedModel
+  end
+
+  class SpeechT5ForSpeechToText < SpeechT5PreTrainedModel
+  end
+
+  class SpeechT5ForTextToSpeech < SpeechT5PreTrainedModel
+  end
+
+  class ClapPreTrainedModel < PreTrainedModel
+  end
+
+  class ClapModel < ClapPreTrainedModel
+  end
+
   MODEL_MAPPING_NAMES_ENCODER_ONLY = {
     "bert" => ["BertModel", BertModel],
     "nomic_bert" => ["NomicBertModel", NomicBertModel],
     "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
     "mpnet" => ["MPNetModel", MPNetModel],
     "distilbert" => ["DistilBertModel", DistilBertModel],
-    "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel]
+    "roberta" => ["RobertaModel", RobertaModel],
+    "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+    "clap" => ["ClapModel", ClapModel],
+    "clip" => ["CLIPModel", CLIPModel],
+    "detr" => ["DetrModel", DetrModel],
+    "vit" => ["ViTModel", ViTModel],
+    "owlvit" => ["OwlViTModel", OwlViTModel],
+    "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+  }
+
+  MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+    "bart" => ["BartModel", BartModel]
+  }
+
+  MODEL_MAPPING_NAMES_DECODER_ONLY = {
+  }
+
+  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = {
+    "whisper" => ["WhisperForConditionalGeneration", WhisperForConditionalGeneration]
+  }
+
+  MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = {
+    "speecht5" => ["SpeechT5ForTextToSpeech", SpeechT5ForTextToSpeech]
+  }
+
+  MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = {
+    "vits" => ["VitsModel", VitsModel]
   }
 
   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
-    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
+    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+    "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
   }
 
   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForTokenClassification", BertForTokenClassification]
   }
 
+  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+    "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+    "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+    "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+  }
+
+  MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+    "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+    "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+  }
+
+  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+    "bert" => ["BertForMaskedLM", BertForMaskedLM],
+    "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+  }
+
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
   }
 
+  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+    "vit" => ["ViTForImageClassification", ViTForImageClassification]
+  }
+
+  MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+    "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+  }
+
+  MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+    "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+  }
+
+  MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+    "detr" => ["DetrForSegmentation", DetrForSegmentation]
+  }
+
+  MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_CTC_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = {
+    "wav2vec2" => ["Wav2Vec2ForSequenceClassification", Wav2Vec2ForSequenceClassification]
+  }
+
+  MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+    "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+  }
+
+  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+    "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+  }
+
+  MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+  }
+
   MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
+    [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES[:DecoderOnly]],
     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
-    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+    [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES[:MaskGeneration]],
+    [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
   ]
 
   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
@@ -306,11 +1319,113 @@ module Informers
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
   end
 
+  class AutoModelForSeq2SeqLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+  end
+
+  class AutoModelForSpeechSeq2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForTextToSpectrogram < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES]
+  end
+
+  class AutoModelForTextToWaveform < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES]
+  end
+
+  class AutoModelForCausalLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskedLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+  end
+
   class AutoModelForQuestionAnswering < PretrainedMixin
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
   end
 
+  class AutoModelForVision2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForSemanticSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForZeroShotObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskGeneration < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForCTC < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_CTC_MAPPING_NAMES]
+  end
+
+  class AutoModelForAudioClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForXVector < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES]
+  end
+
+  class AutoModelForAudioFrameClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageMatting < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageToImage < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+  end
+
+  class AutoModelForDepthEstimation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageFeatureExtraction < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+  end
+
   class ModelOutput
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+  end
+
+  class Seq2SeqLMOutput < ModelOutput
+    def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+      super()
+      @logits = logits
+      @past_key_values = past_key_values
+      @encoder_outputs = encoder_outputs
+      @decoder_attentions = decoder_attentions
+      @cross_attentions = cross_attentions
+    end
   end
 
   class SequenceClassifierOutput < ModelOutput
@@ -331,6 +1446,15 @@ module Informers
     end
   end
 
+  class MaskedLMOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
   class QuestionAnsweringModelOutput < ModelOutput
     attr_reader :start_logits, :end_logits
 
@@ -340,4 +1464,25 @@ module Informers
       @end_logits = end_logits
     end
   end
+
+  class DetrObjectDetectionOutput < ModelOutput
+    attr_reader :logits, :pred_boxes
+
+    def initialize(logits, pred_boxes)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+    end
+  end
+
+  class DetrSegmentationOutput < ModelOutput
+    attr_reader :logits, :pred_boxes, :pred_masks
+
+    def initialize(logits, pred_boxes, pred_masks)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+      @pred_masks = pred_masks
+    end
+  end
 end