informers 1.0.3 → 1.1.1

@@ -44,7 +44,7 @@ module Informers
  end
 
  const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
- model_info = model_class_mapping[config.model_type]
+ model_info = model_class_mapping[config[:model_type]]
  if !model_info
  next # Item not found in this mapping
  end
@@ -52,15 +52,17 @@ module Informers
  end
 
  if const_defined?(:BASE_IF_FAIL)
- warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+ warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
  PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
  else
- raise Error, "Unsupported model type: #{config.model_type}"
+ raise Error, "Unsupported model type: #{config[:model_type]}"
  end
  end
  end
 
  class PreTrainedModel
+ MAIN_INPUT_NAME = :input_ids
+
  attr_reader :config
 
  def initialize(config, session)
@@ -76,11 +78,24 @@ module Informers
 
  case model_type
  when MODEL_TYPES[:DecoderOnly]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:decoder_run_beam)
+ @get_start_beams = method(:decoder_start_beams)
+ @update_beam = method(:decoder_update_beam)
+ @forward = method(:decoder_forward)
+
  when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:seq2seq_run_beam)
+ @get_start_beams = method(:seq2seq_start_beams)
+ @update_beam = method(:seq2seq_update_beam)
+ @forward = method(:seq2seq_forward)
+
  when MODEL_TYPES[:EncoderDecoder]
- raise Todo
+ @forward = method(:encoder_forward)
+
  else
  @forward = method(:encoder_forward)
  end
@@ -110,20 +125,37 @@ module Informers
  model_type = MODEL_TYPE_MAPPING[model_name]
 
  if model_type == MODEL_TYPES[:DecoderOnly]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]
 
  elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+ construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]
 
  elsif model_type == MODEL_TYPES[:MaskGeneration]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, "vision_encoder", **options),
+ construct_session(pretrained_model_name_or_path, "prompt_encoder_mask_decoder", **options)
+ ]
 
  elsif model_type == MODEL_TYPES[:EncoderDecoder]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+ construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options)
+ ]
 
  else
  if model_type != MODEL_TYPES[:EncoderOnly]
- warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+ warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
  end
  info = [
  AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
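
The branches above decide which ONNX sessions from_pretrained loads: decoder-only models get a single merged decoder session, seq2seq/vision2seq models get separate encoder and merged decoder sessions, and the generative branches also fetch generation_config.json. A minimal sketch of loading through one of the new auto classes (the repo id is a placeholder; it assumes a checkpoint exported with the encoder_model and decoder_model_merged ONNX files named in this hunk):

    model = Informers::AutoModelForSeq2SeqLM.from_pretrained("Xenova/t5-small") # hypothetical repo id
    model.config[:model_type] # => "t5", assuming a T5 checkpoint
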
@@ -153,8 +185,445 @@ module Informers
  @forward.(model_inputs, **kwargs)
  end
 
+ def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+ if !@can_generate
+ model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+ error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+ raise Error, error_message
+ end
+
+ if !inputs.is_a?(Array)
+ raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+ end
+
+ if @config[:is_encoder_decoder]
+ # Generating from the encoder outputs
+ input_ids_seq_length = 0
+ else
+ input_ids_seq_length = inputs.length
+
+ # decoder-only
+ if input_ids_seq_length == 0
+ raise Error, "Must supply a non-empty array of input token ids."
+ end
+ end
+
+ # Update generation config with defaults
+ generation_config = get_generation_config(generation_config)
+
+ logits_processor ||= Utils::LogitsProcessorList.new
+
+ # Update logits processor
+ logits_processor = get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor
+ )
+
+ eos_token_ids = generation_config[:eos_token_id]
+ if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+ eos_token_ids = [eos_token_ids]
+ end
+
+ num_output_tokens = 1
+ max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+ # Only use max length if max_new_tokens is not provided
+ use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+ sampler = Utils::Sampler.get_sampler(generation_config)
+
+ beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+ while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+ newest_beams = []
+ beams.each do |beam|
+ if beam[:done]
+ # Add this beam back into the pool
+ newest_beams << beam
+ next
+ end
+ if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+ # Set this beam to done and add it back into the pool
+ beam[:done] = true
+ newest_beams << beam
+ next
+ end
+
+ output = run_beam(beam)
+
+ # add attentions/scores to beam only if user requested
+ if generation_config["output_attentions"]
+ add_attentions_to_beam(beam, output)
+ end
+
+ # Logits are of the form [batch_size, out_seq_length, vocab_size]
+ # In most cases, this will be [batch_size, 1, vocab_size]
+ # So, we select the last token's logits:
+ # (equivalent to `logits = outputs.logits[:, -1, :]`)
+ logits = output["logits"].map { |v| v[-1] }
+
+ # Apply logits processor
+ logits_processor.(beam[:output_token_ids], logits)
+
+ sampled_tokens = sampler.(logits)
+ sampled_tokens.each do |new_token_id, log_prob|
+ # use previous beam as a starting point
+ new_beam = beam.dup
+
+ # update new beam
+ update_beam(new_beam, new_token_id)
+
+ new_beam[:score] += log_prob
+
+ if eos_token_ids && eos_token_ids.include?(new_token_id)
+ new_beam[:done] = true
+ end
+
+ newest_beams << new_beam
+ end
+ end
+ num_output_tokens += 1
+
+ # Next, we get the best beams, per ID
+ newest_beams =
+ group_beams(newest_beams).map do |group|
+ group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+ end
+
+ # Flatten beams
+ beams = newest_beams.flatten(1)
+
+ # Run callback
+ if generation_config["callback_function"]
+ generation_config["callback_function"].(beams)
+ end
+ end
+
+ # TODO: Ensure that we can return non-batched outputs
+
+ grouped_beams = group_beams(beams)
+
+ get_flattened = lambda do |key|
+ grouped_beams.flat_map do |batch|
+ if generation_config["num_return_sequences"] > 1
+ raise Todo
+ else
+ [batch[0][key]]
+ end
+ end
+ end
+
+ sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+ if generation_config["return_dict_in_generate"]
+ raise Todo
+ else
+ sequences
+ end
+ end
+
  private
 
+ def get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor = nil
+ )
+ processors = Utils::LogitsProcessorList.new
+
+ if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+ processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+ end
+
+ if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+ processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+ end
+
+ if !generation_config["bad_words_ids"].nil?
+ processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+ processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+ processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+ input_ids_seq_length,
+ generation_config["min_new_tokens"],
+ generation_config["eos_token_id"]
+ ))
+ end
+
+ if !generation_config["forced_bos_token_id"].nil?
+ processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+ end
+
+ if !generation_config["forced_eos_token_id"].nil?
+ processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+ generation_config["max_length"],
+ generation_config["forced_eos_token_id"]
+ ))
+ end
+
+ if !generation_config["begin_suppress_tokens"].nil?
+ raise Todo
+ end
+
+ if !generation_config["forced_decoder_ids"].nil?
+ processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+ end
+
+ if !logits_processor.nil?
+ processors.concat(logits_processor)
+ end
+
+ processors
+ end
+
+ def get_generation_config(generation_config)
+ # Create empty generation config (contains defaults)
+ # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+ gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+ # Apply model's generation config, if it exists
+ if @generation_config
+ gen_config.merge!(@generation_config)
+ end
+
+ # Finally, use any generation config specified by the user
+ # when calling `generate`
+ if !generation_config.nil?
+ gen_config.merge!(generation_config)
+ end
+
+ gen_config
+ end
+
+ def seq2seq_forward(model_inputs)
+ encoder_outputs = model_inputs[:encoder_outputs]
+ past_key_values = model_inputs[:past_key_values]
+
+ if !encoder_outputs
+ # Encoder outputs are not given, so we must compute them.
+ encoder_outputs = encoder_forward(model_inputs)[0]
+ end
+ decoder_feeds = {
+ input_ids: model_inputs[:decoder_input_ids],
+ encoder_hidden_states: encoder_outputs
+ }
+ use_cache_branch = !!past_key_values
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+ decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+ end
+
+ prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+ decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+ logits = decoder_results["logits"]
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+ # Get cross attention and/or decoder attentions if they are present
+ attns = get_attentions(decoder_results)
+
+ Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+ end
+
+ def prepare_position_ids(session, feeds, use_cache_branch)
+ if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+ return
+ end
+
+ raise Todo
+ end
+
+ def get_past_key_values(decoder_results, past_key_values)
+ pkvs = {}
+
+ decoder_results.each_key do |name|
+ if name.start_with?("present")
+ new_name = name.sub("present", "past_key_values")
+
+ if past_key_values && name.include?("encoder")
+ # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+ # outputs with the previous past key values.
+ # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+ pkvs[new_name] = past_key_values[new_name]
+ else
+ pkvs[new_name] = decoder_results[name]
+ end
+ end
+ end
+ pkvs
+ end
+
+ def get_attentions(decoder_results)
+ attns = {}
+
+ ["cross_attentions", "decoder_attentions"].each do |attn_name|
+ result = []
+ decoder_results.each_key do |name|
+ if name.start_with?(attn_name)
+ index = name.split(".").pop
+ result[index] = decoder_results[name]
+ end
+ end
+ attns[attn_name] = result
+ end
+ attns
+ end
+
+ def add_past_key_values(decoder_feeds, past_key_values)
+ if past_key_values
+ decoder_feeds.merge!(past_key_values)
+ else
+ # TODO support batches (i.e., batch_size > 1)
+ batch_size = 1
+
+ if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+ _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+ _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+ @num_decoder_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ end
+ elsif @config[:model_type] == "falcon"
+ raise Todo
+ elsif @config[:multi_query]
+ raise Todo
+ elsif @config[:model_type] == "bloom"
+ raise Todo
+ else
+ _dims = [batch_size, @num_heads, 0, @dim_kv]
+ @num_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ end
+ end
+ end
+ end
+
+ def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+ beams = []
+ beam_id = 0
+
+ requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+ # decoder_input_ids == output_token_ids
+ decoder_input_ids =
+ generation_config["decoder_input_ids"] ||
+ generation_config["decoder_start_token_id"] ||
+ generation_config["bos_token_id"] ||
+ generation_config["eos_token_id"]
+
+ if !decoder_input_ids.is_a?(Array)
+ decoder_input_ids = [decoder_input_ids]
+ end
+
+ input_token_ids.each do |tokens|
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ # Create beam
+ start = {
+ inputs: tokens,
+ encoder_outputs: nil,
+ prev_model_outputs: nil,
+
+ output_token_ids: decoder_input_ids,
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ if requires_attention_mask
+ start[:attention_mask] = prepare_attention_mask(tokens)
+ end
+
+ beams << start
+ end
+
+ beams
+ end
+
+ def prepare_attention_mask(tokens)
+ # Prepare attention mask
+ pad_token_id = @config["pad_token_id"]
+ eos_token_id = @config["eos_token_id"]
+ if eos_token_id.is_a?(Integer)
+ eos_token_id = [eos_token_id]
+ end
+
+ is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+ is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+ if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+ raise Todo
+ else
+ Utils.ones_like(tokens)
+ end
+ end
+
+ def seq2seq_run_beam(beam)
+ input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+ decoder_input_ids = beam[:output_token_ids]
+ if beam[:prev_model_outputs]
+ # After the first step, `prev_model_outputs` won't be null.
+ # So, we cut decoder_input_ids if past is used
+ decoder_input_ids = [decoder_input_ids[-1]]
+ end
+
+ # 1. Prepare
+ model_inputs = {
+ input_name => beam[:inputs],
+ decoder_input_ids: [decoder_input_ids],
+ encoder_outputs: beam[:encoder_outputs],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+ if beam[:attention_mask]
+ model_inputs[:attention_mask] = beam[:attention_mask]
+ end
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+ beam[:encoder_outputs] = output[:encoder_outputs]
+
+ output
+ end
+
+ def seq2seq_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ end
+
+ def group_beams(beams)
+ # Group beams by their ids
+ groups = {}
+ beams.each do |obj|
+ if !groups[obj[:id]]
+ groups[obj[:id]] = [obj]
+ else
+ groups[obj[:id]] << obj
+ end
+ end
+ groups.values
+ end
+
  def encoder_forward(model_inputs, output_names: nil)
  encoder_feeds = {}
  @session.inputs.each do |input|
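
Together with `generate` above, these helpers form the whole loop: `get_start_beams` seeds one beam per input sequence, `run_beam` invokes the architecture-specific `@forward`, and `update_beam` appends each sampled token until an EOS id or `max_new_tokens` stops the beam. A hedged sketch of the calling convention (the repo id and token ids are placeholders; tokenization happens outside this file):

    model = Informers::AutoModelForCausalLM.from_pretrained("gpt2") # hypothetical ONNX export
    input_ids = [[15496, 11, 995]] # one batch entry of token ids, e.g. from a tokenizer
    output_ids = model.generate(input_ids, {max_new_tokens: 10})
    # => array of output token id sequences, one per input
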
@@ -167,7 +636,96 @@ module Informers
  session_run(@session, encoder_feeds, output_names:)
  end
 
- def session_run(session, inputs, output_names:)
+ def decoder_forward(model_inputs)
+ input_ids, past_key_values, attention_mask =
+ model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+ decoder_feeds = {
+ input_ids: input_ids,
+ attention_mask: attention_mask || prepare_attention_mask(input_ids)
+ }
+ use_cache_branch = !!past_key_values
+
+ if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@session, decoder_feeds)
+ decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+ logits = decoder_results["logits"]
+
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+ {"logits" => logits, past_key_values: past_key_values}
+ end
+
+ def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ beams = []
+
+ beam_id = 0
+ input_token_ids.each do |tokens|
+ output_token_ids = tokens.dup
+
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ if inputs_attention_mask
+ attn_mask = inputs_attention_mask[beam_id]
+ attn_mask = [attn_mask]
+ else
+ attn_mask = prepare_attention_mask(tokens)
+ end
+
+ start = {
+ input: tokens,
+ model_input_ids: tokens,
+ attention_mask: attn_mask,
+ prev_model_outputs: nil,
+
+ output_token_ids: output_token_ids,
+ num_output_tokens: num_output_tokens,
+
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ beams << start
+ end
+ beams
+ end
+
+ def decoder_run_beam(beam)
+ attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+ # 1. Prepare
+ model_inputs = {
+ input_ids: beam[:model_input_ids],
+ attention_mask: [attn_mask_data],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+
+ output
+ end
+
+ def decoder_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ beam[:model_input_ids] = [[new_token_id]]
+ end
+
+ def session_run(session, inputs, output_names: nil)
  checked_inputs = validate_inputs(session, inputs)
  begin
  output = session.run(output_names || @output_names, checked_inputs)
@@ -187,6 +745,18 @@ module Informers
  def validate_inputs(session, inputs)
  inputs
  end
+
+ def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ end
+
+ def run_beam(beam)
+ @run_beam.(beam)
+ end
+
+ def update_beam(beam, new_token_id)
+ @update_beam.(beam, new_token_id)
+ end
  end
 
  class BertPreTrainedModel < PreTrainedModel
@@ -195,6 +765,12 @@ module Informers
  class BertModel < BertPreTrainedModel
  end
 
+ class BertForMaskedLM < BertPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
  class BertForSequenceClassification < BertPreTrainedModel
  def call(model_inputs)
  SequenceClassifierOutput.new(*super(model_inputs))
@@ -243,6 +819,126 @@ module Informers
  class MPNetModel < MPNetPreTrainedModel
  end
 
+ class T5PreTrainedModel < PreTrainedModel
+ end
+
+ class T5Model < T5PreTrainedModel
+ end
+
+ class T5ForConditionalGeneration < T5PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config[:num_decoder_layers]
+ @num_decoder_heads = @config[:num_heads]
+ @decoder_dim_kv = @config[:d_kv]
+
+ @num_encoder_layers = @config[:num_layers]
+ @num_encoder_heads = @config[:num_heads]
+ @encoder_dim_kv = @config[:d_kv]
+ end
+ end
+
+ class BartPretrainedModel < PreTrainedModel
+ end
+
+ class BartModel < BartPretrainedModel
+ end
+
+ class BartForConditionalGeneration < BartPretrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+ end
+ end
+
+ class BartForSequenceClassification < BartPretrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class MBartPreTrainedModel < PreTrainedModel
+ end
+
+ class MBartModel < MBartPreTrainedModel
+ end
+
+ class MBartForCausalLM < MBartPreTrainedModel
+ attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+ :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+ def initialize(config, decoder_merged_session, generation_config)
+ super(config, decoder_merged_session)
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class M2M100PreTrainedModel < PreTrainedModel
+ end
+
+ class M2M100Model < M2M100PreTrainedModel
+ end
+
+ class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class Wav2Vec2PreTrainedModel < PreTrainedModel
+ end
+
+ class Wav2Vec2Model < Wav2Vec2PreTrainedModel
+ end
+
+ class Wav2Vec2ForSequenceClassification < Wav2Vec2PreTrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class RobertaPreTrainedModel < PreTrainedModel
+ end
+
+ class RobertaModel < RobertaPreTrainedModel
+ end
+
+ class RobertaForMaskedLM < RobertaPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
  class XLMRobertaPreTrainedModel < PreTrainedModel
  end
 
@@ -255,34 +951,351 @@ module Informers
  end
  end
 
+ class ViTPreTrainedModel < PreTrainedModel
+ end
+
+ class ViTModel < ViTPreTrainedModel
+ end
+
+ class ViTForImageClassification < ViTPreTrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class CLIPPreTrainedModel < PreTrainedModel
+ end
+
+ class CLIPModel < CLIPPreTrainedModel
+ end
+
+ class GPT2PreTrainedModel < PreTrainedModel
+ attr_reader :num_heads, :num_layers, :dim_kv
+
+ def initialize(config, session, generation_config)
+ super(config, session)
+ @generation_config = generation_config
+
+ # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+ @config["pad_token_id"] = @config["eos_token_id"]
+
+ @num_heads = @config["n_head"]
+ @num_layers = @config["n_layer"]
+ @dim_kv = @config["n_embd"] / @num_heads.to_f
+ end
+ end
+
+ class GPT2Model < GPT2PreTrainedModel
+ end
+
+ class GPT2LMHeadModel < GPT2PreTrainedModel
+ end
+
+ class OwlViTPreTrainedModel < PreTrainedModel
+ end
+
+ class OwlViTModel < OwlViTPreTrainedModel
+ end
+
+ class OwlViTForObjectDetection < OwlViTPreTrainedModel
+ end
+
+ class DetrPreTrainedModel < PreTrainedModel
+ end
+
+ class DetrModel < DetrPreTrainedModel
+ end
+
+ class DetrForObjectDetection < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrObjectDetectionOutput.new(*super(model_inputs))
+ end
+ end
+
+ class DetrForSegmentation < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrSegmentationOutput.new(*super(model_inputs))
+ end
+ end
+
+ class Swin2SRPreTrainedModel < PreTrainedModel
+ end
+
+ class Swin2SRModel < Swin2SRPreTrainedModel
+ end
+
+ class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+ end
+
+ class DPTPreTrainedModel < PreTrainedModel
+ end
+
+ class DPTModel < DPTPreTrainedModel
+ end
+
+ class DPTForDepthEstimation < DPTPreTrainedModel
+ end
+
+ class VisionEncoderDecoderModel < PreTrainedModel
+ MAIN_INPUT_NAME = :pixel_values
+
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ # Extract configs
+ encoder_config = @config["encoder"]
+ decoder_config = @config["decoder"]
+
+ # Validate encoder
+ encoder_model_type = encoder_config["model_type"]
+ encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+ if !encoder_model
+ warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+ end
+
+ # Validate decoder
+ decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+ if !decoder_model
+ raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+ end
+
+ decoder_model_class = decoder_model[1]
+ decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+ @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+ if @add_encoder_pkv
+ # Decoder is part of an encoder-decoder model
+ @num_decoder_layers = decoder.num_decoder_layers
+ @num_decoder_heads = decoder.num_decoder_heads
+ @decoder_dim_kv = decoder.decoder_dim_kv
+
+ @num_encoder_layers = decoder.num_encoder_layers
+ @num_encoder_heads = decoder.num_encoder_heads
+ @encoder_dim_kv = decoder.encoder_dim_kv
+ else
+ # Decoder is a decoder-only model
+ @num_layers = decoder.num_layers
+ @num_heads = decoder.num_heads
+ @dim_kv = decoder.dim_kv
+ end
+ end
+ end
+
+ class DonutSwinPreTrainedModel < PreTrainedModel
+ end
+
+ class DonutSwinModel < DonutSwinPreTrainedModel
+ end
+
+ class WhisperPreTrainedModel < PreTrainedModel
+ end
+
+ class WhisperModel < WhisperPreTrainedModel
+ end
+
+ class WhisperForConditionalGeneration < WhisperPreTrainedModel
+ REQUIRES_ATTENTION_MASK = false
+ MAIN_INPUT_NAME = :input_features
+
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+
+ def generate(inputs, generation_config = nil, logits_processor = nil)
+ raise Todo
+ end
+ end
+
+ class VitsPreTrainedModel < PreTrainedModel
+ end
+
+ class VitsModel < VitsPreTrainedModel
+ def call(model_inputs)
+ VitsModelOutput.new(*super(model_inputs))
+ end
+ end
+
+ class SpeechT5PreTrainedModel < PreTrainedModel
+ end
+
+ class SpeechT5Model < SpeechT5PreTrainedModel
+ end
+
+ class SpeechT5ForSpeechToText < SpeechT5PreTrainedModel
+ end
+
+ class SpeechT5ForTextToSpeech < SpeechT5PreTrainedModel
+ end
+
+ class ClapPreTrainedModel < PreTrainedModel
+ end
+
+ class ClapModel < ClapPreTrainedModel
+ end
+
  MODEL_MAPPING_NAMES_ENCODER_ONLY = {
  "bert" => ["BertModel", BertModel],
  "nomic_bert" => ["NomicBertModel", NomicBertModel],
  "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
  "mpnet" => ["MPNetModel", MPNetModel],
  "distilbert" => ["DistilBertModel", DistilBertModel],
- "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel]
+ "roberta" => ["RobertaModel", RobertaModel],
+ "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+ "clap" => ["ClapModel", ClapModel],
+ "clip" => ["CLIPModel", CLIPModel],
+ "detr" => ["DetrModel", DetrModel],
+ "vit" => ["ViTModel", ViTModel],
+ "owlvit" => ["OwlViTModel", OwlViTModel],
+ "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+ }
+
+ MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+ "bart" => ["BartModel", BartModel]
+ }
+
+ MODEL_MAPPING_NAMES_DECODER_ONLY = {
+ }
+
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = {
+ "whisper" => ["WhisperForConditionalGeneration", WhisperForConditionalGeneration]
+ }
+
+ MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = {
+ "speecht5" => ["SpeechT5ForTextToSpeech", SpeechT5ForTextToSpeech]
+ }
+
+ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = {
+ "vits" => ["VitsModel", VitsModel]
  }
 
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
  "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
- "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
+ "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+ "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
  }
 
  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForTokenClassification", BertForTokenClassification]
  }
 
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+ "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+ "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+ "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+ }
+
+ MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+ "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+ "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+ }
+
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+ "bert" => ["BertForMaskedLM", BertForMaskedLM],
+ "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+ }
+
  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
  "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
  }
 
+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+ "vit" => ["ViTForImageClassification", ViTForImageClassification]
+ }
+
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+ "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+ }
+
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+ "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+ }
+
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+ "detr" => ["DetrForSegmentation", DetrForSegmentation]
+ }
+
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_CTC_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = {
+ "wav2vec2" => ["Wav2Vec2ForSequenceClassification", Wav2Vec2ForSequenceClassification]
+ }
+
+ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+ "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+ }
+
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+ "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+ }
+
+ MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+ }
+
  MODEL_CLASS_TYPE_MAPPING = [
  [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
+ [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES[:DecoderOnly]],
  [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
  [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
- [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+ [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+ [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+ [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+ [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+ [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES[:MaskGeneration]],
+ [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+ [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
  ]
 
  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
@@ -306,11 +1319,113 @@ module Informers
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
  end
 
+ class AutoModelForSeq2SeqLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+ end
+
+ class AutoModelForSpeechSeq2Seq < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES]
+ end
+
+ class AutoModelForTextToSpectrogram < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES]
+ end
+
+ class AutoModelForTextToWaveform < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES]
+ end
+
+ class AutoModelForCausalLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+ end
+
+ class AutoModelForMaskedLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+ end
+
  class AutoModelForQuestionAnswering < PretrainedMixin
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
  end
 
+ class AutoModelForVision2Seq < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageClassification < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForSemanticSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForZeroShotObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForMaskGeneration < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForCTC < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_CTC_MAPPING_NAMES]
+ end
+
+ class AutoModelForAudioClassification < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForXVector < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES]
+ end
+
+ class AutoModelForAudioFrameClassification < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageMatting < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageToImage < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+ end
+
+ class AutoModelForDepthEstimation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageFeatureExtraction < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+ end
+
  class ModelOutput
+ def [](key)
+ instance_variable_get("@#{key}")
+ end
+ end
+
+ class Seq2SeqLMOutput < ModelOutput
+ def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+ super()
+ @logits = logits
+ @past_key_values = past_key_values
+ @encoder_outputs = encoder_outputs
+ @decoder_attentions = decoder_attentions
+ @cross_attentions = cross_attentions
+ end
  end
 
  class SequenceClassifierOutput < ModelOutput
@@ -331,6 +1446,15 @@ module Informers
  end
  end
 
+ class MaskedLMOutput < ModelOutput
+ attr_reader :logits
+
+ def initialize(logits)
+ super()
+ @logits = logits
+ end
+ end
+
  class QuestionAnsweringModelOutput < ModelOutput
  attr_reader :start_logits, :end_logits
 
@@ -340,4 +1464,25 @@ module Informers
  @end_logits = end_logits
  end
  end
+
+ class DetrObjectDetectionOutput < ModelOutput
+ attr_reader :logits, :pred_boxes
+
+ def initialize(logits, pred_boxes)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ end
+ end
+
+ class DetrSegmentationOutput < ModelOutput
+ attr_reader :logits, :pred_boxes, :pred_masks
+
+ def initialize(logits, pred_boxes, pred_masks)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ @pred_masks = pred_masks
+ end
+ end
  end
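
A closing note on the new output classes: ModelOutput#[] reads the instance variable named by its key, so every structured output can be accessed by reader or by key. A small sketch using MaskedLMOutput from this diff (the logits value is a placeholder):

    output = Informers::MaskedLMOutput.new([[0.1, 0.9]])
    output.logits   # => [[0.1, 0.9]] via attr_reader
    output[:logits] # => [[0.1, 0.9]] via ModelOutput#[]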