informers 1.0.3 → 1.1.0

This diff shows the changes between these publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
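The main change in 1.1.0 is that several `raise Todo` stubs in `PreTrainedModel` are replaced with working text-generation support: a beam-search `generate` method, seq2seq and decoder-only forward passes with past-key-value handling, and new model classes (T5, BART, MBart, M2M100, GPT2, ViT, DETR, VisionEncoderDecoder, and others) wired into new `AutoModelFor*` loaders and mapping tables. A rough usage sketch of the new generation path, based only on the classes added in this diff; the model id, token ids, and option keys below are illustrative placeholders, not taken from the package:

  require "informers"

  # Hypothetical ONNX seq2seq checkpoint laid out the way from_pretrained
  # expects (encoder_model / decoder_model_merged files).
  model = Informers::AutoModelForSeq2SeqLM.from_pretrained("Xenova/t5-small")

  # generate takes a batch of token id arrays and returns generated token ids.
  input_ids = [[13959, 1566, 12, 2968, 10, 8774, 6, 296, 55, 1]] # placeholder ids
  output_ids = model.generate(input_ids, {"max_new_tokens" => 32})
  p output_ids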
@@ -44,7 +44,7 @@ module Informers
  end

  const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
- model_info = model_class_mapping[config.model_type]
+ model_info = model_class_mapping[config[:model_type]]
  if !model_info
  next # Item not found in this mapping
  end
@@ -52,15 +52,17 @@ module Informers
  end

  if const_defined?(:BASE_IF_FAIL)
- warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+ warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
  PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
  else
- raise Error, "Unsupported model type: #{config.model_type}"
+ raise Error, "Unsupported model type: #{config[:model_type]}"
  end
  end
  end

  class PreTrainedModel
+ MAIN_INPUT_NAME = :input_ids
+
  attr_reader :config

  def initialize(config, session)
@@ -76,9 +78,19 @@ module Informers

  case model_type
  when MODEL_TYPES[:DecoderOnly]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:decoder_run_beam)
+ @get_start_beams = method(:decoder_start_beams)
+ @update_beam = method(:decoder_update_beam)
+ @forward = method(:decoder_forward)
  when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
- raise Todo
+ @can_generate = true
+
+ @run_beam = method(:seq2seq_run_beam)
+ @get_start_beams = method(:seq2seq_start_beams)
+ @update_beam = method(:seq2seq_update_beam)
+ @forward = method(:seq2seq_forward)
  when MODEL_TYPES[:EncoderDecoder]
  raise Todo
  else
@@ -110,10 +122,19 @@ module Informers
  model_type = MODEL_TYPE_MAPPING[model_name]

  if model_type == MODEL_TYPES[:DecoderOnly]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]

  elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
- raise Todo
+ info = [
+ AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+ construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+ construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+ Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+ ]

  elsif model_type == MODEL_TYPES[:MaskGeneration]
  raise Todo
@@ -123,7 +144,7 @@ module Informers

  else
  if model_type != MODEL_TYPES[:EncoderOnly]
- warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+ warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
  end
  info = [
  AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -153,8 +174,445 @@ module Informers
  @forward.(model_inputs, **kwargs)
  end

+ def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+ if !@can_generate
+ model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+ error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+ raise Error, error_message
+ end
+
+ if !inputs.is_a?(Array)
+ raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+ end
+
+ if @config[:is_encoder_decoder]
+ # Generating from the encoder outputs
+ input_ids_seq_length = 0
+ else
+ input_ids_seq_length = inputs.length
+
+ # decoder-only
+ if input_ids_seq_length == 0
+ raise Error, "Must supply a non-empty array of input token ids."
+ end
+ end
+
+ # Update generation config with defaults
+ generation_config = get_generation_config(generation_config)
+
+ logits_processor ||= Utils::LogitsProcessorList.new
+
+ # Update logits processor
+ logits_processor = get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor
+ )
+
+ eos_token_ids = generation_config[:eos_token_id]
+ if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+ eos_token_ids = [eos_token_ids]
+ end
+
+ num_output_tokens = 1
+ max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+ # Only use max length if max_new_tokens is not provided
+ use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+ sampler = Utils::Sampler.get_sampler(generation_config)
+
+ beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+ while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+ newest_beams = []
+ beams.each do |beam|
+ if beam[:done]
+ # Add this beam back into the pool
+ newest_beams << beam
+ next
+ end
+ if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+ # Set this beam to done and add it back into the pool
+ beam[:done] = true
+ newest_beams << beam
+ next
+ end
+
+ output = run_beam(beam)
+
+ # add attentions/scores to beam only if user requested
+ if generation_config["output_attentions"]
+ add_attentions_to_beam(beam, output)
+ end
+
+ # Logits are of the form [batch_size, out_seq_length, vocab_size]
+ # In most cases, this will be [batch_size, 1, vocab_size]
+ # So, we select the last token's logits:
+ # (equivalent to `logits = outputs.logits[:, -1, :]`)
+ logits = output["logits"].map { |v| v[-1] }
+
+ # Apply logits processor
+ logits_processor.(beam[:output_token_ids], logits)
+
+ sampled_tokens = sampler.(logits)
+ sampled_tokens.each do |new_token_id, log_prob|
+ # use previous beam as a starting point
+ new_beam = beam.dup
+
+ # update new beam
+ update_beam(new_beam, new_token_id)
+
+ new_beam[:score] += log_prob
+
+ if eos_token_ids && eos_token_ids.include?(new_token_id)
+ new_beam[:done] = true
+ end
+
+ newest_beams << new_beam
+ end
+ end
+ num_output_tokens += 1
+
+ # Next, we get the best beams, per ID
+ newest_beams =
+ group_beams(newest_beams).map do |group|
+ group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+ end
+
+ # Flatten beams
+ beams = newest_beams.flatten(1)
+
+ # Run callback
+ if generation_config["callback_function"]
+ generation_config["callback_function"].(beams)
+ end
+ end
+
+ # TODO: Ensure that we can return non-batched outputs
+
+ grouped_beams = group_beams(beams)
+
+ get_flattened = lambda do |key|
+ grouped_beams.map do |batch|
+ if generation_config["num_return_sequences"] > 1
+ raise Todo
+ else
+ [batch[0][key]]
+ end
+ end.flatten(1)
+ end
+
+ sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+ if generation_config["return_dict_in_generate"]
+ raise Todo
+ else
+ sequences
+ end
+ end
+
  private

+ def get_logits_processor(
+ generation_config,
+ input_ids_seq_length,
+ logits_processor = nil
+ )
+ processors = Utils::LogitsProcessorList.new
+
+ if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+ processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+ end
+
+ if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+ processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+ end
+
+ if !generation_config["bad_words_ids"].nil?
+ processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+ processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+ end
+
+ if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+ processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+ input_ids_seq_length,
+ generation_config["min_new_tokens"],
+ generation_config["eos_token_id"]
+ ))
+ end
+
+ if !generation_config["forced_bos_token_id"].nil?
+ processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+ end
+
+ if !generation_config["forced_eos_token_id"].nil?
+ processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+ generation_config["max_length"],
+ generation_config["forced_eos_token_id"]
+ ))
+ end
+
+ if !generation_config["begin_suppress_tokens"].nil?
+ raise Todo
+ end
+
+ if !generation_config["forced_decoder_ids"].nil?
+ processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+ end
+
+ if !logits_processor.nil?
+ processors.concat(logits_processor)
+ end
+
+ processors
+ end
+
+ def get_generation_config(generation_config)
+ # Create empty generation config (contains defaults)
+ # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+ gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+ # Apply model's generation config, if it exists
+ if @generation_config
+ gen_config.merge!(@generation_config)
+ end
+
+ # Finally, use any generation config specified by the user
+ # when calling `generate`
+ if !generation_config.nil?
+ gen_config.merge!(generation_config)
+ end
+
+ gen_config
+ end
+
+ def seq2seq_forward(model_inputs)
+ encoder_outputs = model_inputs[:encoder_outputs]
+ past_key_values = model_inputs[:past_key_values]
+
+ if !encoder_outputs
+ # Encoder outputs are not given, so we must compute them.
+ encoder_outputs = encoder_forward(model_inputs)[0]
+ end
+ decoder_feeds = {
+ input_ids: model_inputs[:decoder_input_ids],
+ encoder_hidden_states: encoder_outputs
+ }
+ use_cache_branch = !!past_key_values
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+ decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+ end
+
+ prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+ decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+ logits = decoder_results["logits"]
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+ # Get cross attention and/or decoder attentions if they are present
+ attns = get_attentions(decoder_results)
+
+ Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+ end
+
+ def prepare_position_ids(session, feeds, use_cache_branch)
+ if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+ return
+ end
+
+ raise Todo
+ end
+
+ def get_past_key_values(decoder_results, past_key_values)
+ pkvs = {}
+
+ decoder_results.each_key do |name|
+ if name.start_with?("present")
+ new_name = name.sub("present", "past_key_values")
+
+ if past_key_values && name.include?("encoder")
+ # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+ # outputs with the previous past key values.
+ # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+ pkvs[new_name] = past_key_values[new_name]
+ else
+ pkvs[new_name] = decoder_results[name]
+ end
+ end
+ end
+ pkvs
+ end
+
+ def get_attentions(decoder_results)
+ attns = {}
+
+ ["cross_attentions", "decoder_attentions"].each do |attn_name|
+ result = []
+ decoder_results.each_key do |name|
+ if name.start_with?(attn_name)
+ index = name.split(".").pop
+ result[index] = decoder_results[name]
+ end
+ end
+ attns[attn_name] = result
+ end
+ attns
+ end
+
+ def add_past_key_values(decoder_feeds, past_key_values)
+ if past_key_values
+ decoder_feeds.merge!(past_key_values)
+ else
+ # TODO support batches (i.e., batch_size > 1)
+ batch_size = 1
+
+ if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+ _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+ _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+ @num_decoder_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+ end
+ elsif @config[:model_type] == "falcon"
+ raise Todo
+ elsif @config[:multi_query]
+ raise Todo
+ elsif @config[:model_type] == "bloom"
+ raise Todo
+ else
+ _dims = [batch_size, @num_heads, 0, @dim_kv]
+ @num_layers.times do |i|
+ # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+ end
+ end
+ end
+ end
+
+ def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+ beams = []
+ beam_id = 0
+
+ requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+ # decoder_input_ids == output_token_ids
+ decoder_input_ids =
+ generation_config["decoder_input_ids"] ||
+ generation_config["decoder_start_token_id"] ||
+ generation_config["bos_token_id"] ||
+ generation_config["eos_token_id"]
+
+ if !decoder_input_ids.is_a?(Array)
+ decoder_input_ids = [decoder_input_ids]
+ end
+
+ input_token_ids.each do |tokens|
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ # Create beam
+ start = {
+ inputs: tokens,
+ encoder_outputs: nil,
+ prev_model_outputs: nil,
+
+ output_token_ids: decoder_input_ids,
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ if requires_attention_mask
+ start[:attention_mask] = prepare_attention_mask(tokens)
+ end
+
+ beams << start
+ end
+
+ beams
+ end
+
+ def prepare_attention_mask(tokens)
+ # Prepare attention mask
+ pad_token_id = @config["pad_token_id"]
+ eos_token_id = @config["eos_token_id"]
+ if eos_token_id.is_a?(Integer)
+ eos_token_id = [eos_token_id]
+ end
+
+ is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+ is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+ if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+ raise Todo
+ else
+ Utils.ones_like(tokens)
+ end
+ end
+
+ def seq2seq_run_beam(beam)
+ input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+ decoder_input_ids = beam[:output_token_ids]
+ if beam[:prev_model_outputs]
+ # After the first step, `prev_model_outputs` won't be null.
+ # So, we cut decoder_input_ids if past is used
+ decoder_input_ids = [decoder_input_ids[-1]]
+ end
+
+ # 1. Prepare
+ model_inputs = {
+ input_name => beam[:inputs],
+ decoder_input_ids: [decoder_input_ids],
+ encoder_outputs: beam[:encoder_outputs],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+ if beam[:attention_mask]
+ model_inputs[:attention_mask] = beam[:attention_mask]
+ end
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+ beam[:encoder_outputs] = output[:encoder_outputs]
+
+ output
+ end
+
+ def seq2seq_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ end
+
+ def group_beams(beams)
+ # Group beams by their ids
+ groups = {}
+ beams.each do |obj|
+ if !groups[obj[:id]]
+ groups[obj[:id]] = [obj]
+ else
+ groups[obj[:id]] << obj
+ end
+ end
+ groups.values
+ end
+
  def encoder_forward(model_inputs, output_names: nil)
  encoder_feeds = {}
  @session.inputs.each do |input|
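The `generate` method added in this hunk drives decoding through per-input "beams" (hashes holding output_token_ids, score, done, and cached model outputs) and, after every step, invokes an optional callback_function with the current beams. A minimal sketch of tapping that hook for streaming; the model id, placeholder token ids, and the use of string option keys are assumptions for illustration, not shown in this diff:

  model = Informers::AutoModelForCausalLM.from_pretrained("Xenova/gpt2")

  model.generate(
    [[15496, 11, 995]], # one sequence of placeholder token ids
    {
      "max_new_tokens" => 8,
      # Called once per decoding step with the current array of beams.
      "callback_function" => ->(beams) { p beams.first[:output_token_ids].last }
    }
  )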
@@ -167,7 +625,96 @@ module Informers
  session_run(@session, encoder_feeds, output_names:)
  end

- def session_run(session, inputs, output_names:)
+ def decoder_forward(model_inputs)
+ input_ids, past_key_values, attention_mask =
+ model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+ decoder_feeds = {
+ input_ids: input_ids,
+ attention_mask: attention_mask || prepare_attention_mask(input_ids)
+ }
+ use_cache_branch = !!past_key_values
+
+ if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+ decoder_feeds[:use_cache_branch] = [use_cache_branch]
+ end
+
+ prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+ add_past_key_values(decoder_feeds, past_key_values)
+
+ decoder_results = session_run(@session, decoder_feeds)
+ decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+ logits = decoder_results["logits"]
+
+ past_key_values = get_past_key_values(decoder_results, past_key_values)
+ {"logits" => logits, past_key_values: past_key_values}
+ end
+
+ def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ beams = []
+
+ beam_id = 0
+ input_token_ids.each do |tokens|
+ output_token_ids = tokens.dup
+
+ # TODO: Improve
+ # Currently, just add back batch dimension.
+ # In future, allow for true parallel execution
+ tokens = [tokens]
+
+ if inputs_attention_mask
+ attn_mask = inputs_attention_mask[beam_id]
+ attn_mask = [attn_mask]
+ else
+ attn_mask = prepare_attention_mask(tokens)
+ end
+
+ start = {
+ input: tokens,
+ model_input_ids: tokens,
+ attention_mask: attn_mask,
+ prev_model_outputs: nil,
+
+ output_token_ids: output_token_ids,
+ num_output_tokens: num_output_tokens,
+
+ done: false,
+ score: 0,
+ id: beam_id # assign unique id to beams
+ }
+ beam_id += 1
+
+ beams << start
+ end
+ beams
+ end
+
+ def decoder_run_beam(beam)
+ attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+ # 1. Prepare
+ model_inputs = {
+ input_ids: beam[:model_input_ids],
+ attention_mask: [attn_mask_data],
+ past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+ }
+
+ # 2. Run
+ output = @forward.(model_inputs)
+
+ # 3. Update
+ beam[:prev_model_outputs] = output
+
+ output
+ end
+
+ def decoder_update_beam(beam, new_token_id)
+ beam[:output_token_ids] += [new_token_id]
+ beam[:model_input_ids] = [[new_token_id]]
+ end
+
+ def session_run(session, inputs, output_names: nil)
  checked_inputs = validate_inputs(session, inputs)
  begin
  output = session.run(output_names || @output_names, checked_inputs)
@@ -187,6 +734,18 @@ module Informers
  def validate_inputs(session, inputs)
  inputs
  end
+
+ def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+ end
+
+ def run_beam(beam)
+ @run_beam.(beam)
+ end
+
+ def update_beam(beam, new_token_id)
+ @update_beam.(beam, new_token_id)
+ end
  end

  class BertPreTrainedModel < PreTrainedModel
@@ -195,6 +754,12 @@ module Informers
  class BertModel < BertPreTrainedModel
  end

+ class BertForMaskedLM < BertPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
  class BertForSequenceClassification < BertPreTrainedModel
  def call(model_inputs)
  SequenceClassifierOutput.new(*super(model_inputs))
@@ -243,6 +808,114 @@ module Informers
  class MPNetModel < MPNetPreTrainedModel
  end

+ class T5PreTrainedModel < PreTrainedModel
+ end
+
+ class T5Model < T5PreTrainedModel
+ end
+
+ class T5ForConditionalGeneration < T5PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config[:num_decoder_layers]
+ @num_decoder_heads = @config[:num_heads]
+ @decoder_dim_kv = @config[:d_kv]
+
+ @num_encoder_layers = @config[:num_layers]
+ @num_encoder_heads = @config[:num_heads]
+ @encoder_dim_kv = @config[:d_kv]
+ end
+ end
+
+ class BartPretrainedModel < PreTrainedModel
+ end
+
+ class BartModel < BartPretrainedModel
+ end
+
+ class BartForConditionalGeneration < BartPretrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+ end
+ end
+
+ class BartForSequenceClassification < BartPretrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class MBartPreTrainedModel < PreTrainedModel
+ end
+
+ class MBartModel < MBartPreTrainedModel
+ end
+
+ class MBartForCausalLM < MBartPreTrainedModel
+ attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+ :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+ def initialize(config, decoder_merged_session, generation_config)
+ super(config, decoder_merged_session)
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class M2M100PreTrainedModel < PreTrainedModel
+ end
+
+ class M2M100Model < M2M100PreTrainedModel
+ end
+
+ class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ @num_decoder_layers = @config["decoder_layers"]
+ @num_decoder_heads = @config["decoder_attention_heads"]
+ @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+ @num_encoder_layers = @config["encoder_layers"]
+ @num_encoder_heads = @config["encoder_attention_heads"]
+ @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+ end
+ end
+
+ class RobertaPreTrainedModel < PreTrainedModel
+ end
+
+ class RobertaModel < RobertaPreTrainedModel
+ end
+
+ class RobertaForMaskedLM < RobertaPreTrainedModel
+ def call(model_inputs)
+ MaskedLMOutput.new(*super(model_inputs))
+ end
+ end
+
  class XLMRobertaPreTrainedModel < PreTrainedModel
  end

@@ -255,34 +928,250 @@ module Informers
  end
  end

+ class ViTPreTrainedModel < PreTrainedModel
+ end
+
+ class ViTModel < ViTPreTrainedModel
+ end
+
+ class ViTForImageClassification < ViTPreTrainedModel
+ def call(model_inputs)
+ SequenceClassifierOutput.new(*super(model_inputs))
+ end
+ end
+
+ class CLIPPreTrainedModel < PreTrainedModel
+ end
+
+ class CLIPModel < CLIPPreTrainedModel
+ end
+
+ class GPT2PreTrainedModel < PreTrainedModel
+ attr_reader :num_heads, :num_layers, :dim_kv
+
+ def initialize(config, session, generation_config)
+ super(config, session)
+ @generation_config = generation_config
+
+ # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+ @config["pad_token_id"] = @config["eos_token_id"]
+
+ @num_heads = @config["n_head"]
+ @num_layers = @config["n_layer"]
+ @dim_kv = @config["n_embd"] / @num_heads.to_f
+ end
+ end
+
+ class GPT2Model < GPT2PreTrainedModel
+ end
+
+ class GPT2LMHeadModel < GPT2PreTrainedModel
+ end
+
+ class OwlViTPreTrainedModel < PreTrainedModel
+ end
+
+ class OwlViTModel < OwlViTPreTrainedModel
+ end
+
+ class OwlViTForObjectDetection < OwlViTPreTrainedModel
+ end
+
+ class DetrPreTrainedModel < PreTrainedModel
+ end
+
+ class DetrModel < DetrPreTrainedModel
+ end
+
+ class DetrForObjectDetection < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrObjectDetectionOutput.new(*super(model_inputs))
+ end
+ end
+
+ class DetrForSegmentation < DetrPreTrainedModel
+ def call(model_inputs)
+ DetrSegmentationOutput.new(*super(model_inputs))
+ end
+ end
+
+ class Swin2SRPreTrainedModel < PreTrainedModel
+ end
+
+ class Swin2SRModel < Swin2SRPreTrainedModel
+ end
+
+ class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+ end
+
+ class DPTPreTrainedModel < PreTrainedModel
+ end
+
+ class DPTModel < DPTPreTrainedModel
+ end
+
+ class DPTForDepthEstimation < DPTPreTrainedModel
+ end
+
+ class VisionEncoderDecoderModel < PreTrainedModel
+ MAIN_INPUT_NAME = :pixel_values
+
+ def initialize(config, session, decoder_merged_session, generation_config)
+ super(config, session)
+ @decoder_merged_session = decoder_merged_session
+ @generation_config = generation_config
+
+ # Extract configs
+ encoder_config = @config["encoder"]
+ decoder_config = @config["decoder"]
+
+ # Validate encoder
+ encoder_model_type = encoder_config["model_type"]
+ encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+ if !encoder_model
+ warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+ end
+
+ # Validate decoder
+ decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+ if !decoder_model
+ raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+ end
+
+ decoder_model_class = decoder_model[1]
+ decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+ @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+ if @add_encoder_pkv
+ # Decoder is part of an encoder-decoder model
+ @num_decoder_layers = decoder.num_decoder_layers
+ @num_decoder_heads = decoder.num_decoder_heads
+ @decoder_dim_kv = decoder.decoder_dim_kv
+
+ @num_encoder_layers = decoder.num_encoder_layers
+ @num_encoder_heads = decoder.num_encoder_heads
+ @encoder_dim_kv = decoder.encoder_dim_kv
+ else
+ # Decoder is a decoder-only model
+ @num_layers = decoder.num_layers
+ @num_heads = decoder.num_heads
+ @dim_kv = decoder.dim_kv
+ end
+ end
+ end
+
+ class DonutSwinPreTrainedModel < PreTrainedModel
+ end
+
+ class DonutSwinModel < DonutSwinPreTrainedModel
+ end
+
  MODEL_MAPPING_NAMES_ENCODER_ONLY = {
  "bert" => ["BertModel", BertModel],
  "nomic_bert" => ["NomicBertModel", NomicBertModel],
  "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
  "mpnet" => ["MPNetModel", MPNetModel],
  "distilbert" => ["DistilBertModel", DistilBertModel],
- "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel]
+ "roberta" => ["RobertaModel", RobertaModel],
+ "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+ "clip" => ["CLIPModel", CLIPModel],
+ "detr" => ["DetrModel", DetrModel],
+ "vit" => ["ViTModel", ViTModel],
+ "owlvit" => ["OwlViTModel", OwlViTModel],
+ "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+ }
+
+ MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+ "bart" => ["BartModel", BartModel]
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
  "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
- "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
+ "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+ "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
  "bert" => ["BertForTokenClassification", BertForTokenClassification]
  }

+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+ "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+ "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+ "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+ }
+
+ MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+ "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+ "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+ }
+
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+ "bert" => ["BertForMaskedLM", BertForMaskedLM],
+ "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+ }
+
  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
  "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
  }

+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+ "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+ }
+
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+ "vit" => ["ViTForImageClassification", ViTForImageClassification]
+ }
+
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+ "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+ }
+
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+ "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+ }
+
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+ "detr" => ["DetrForSegmentation", DetrForSegmentation]
+ }
+
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+ }
+
+ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+ "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+ }
+
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+ "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+ }
+
+ MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+ }
+
  MODEL_CLASS_TYPE_MAPPING = [
  [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
  [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
  [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
- [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+ [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+ [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+ [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+ [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+ [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
  ]

  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
@@ -306,11 +1195,77 @@ module Informers
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
  end

+ class AutoModelForSeq2SeqLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+ end
+
+ class AutoModelForCausalLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+ end
+
+ class AutoModelForMaskedLM < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+ end
+
  class AutoModelForQuestionAnswering < PretrainedMixin
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
  end

+ class AutoModelForVision2Seq < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageClassification < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForSemanticSegmentation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForZeroShotObjectDetection < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+ end
+
+ class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageToImage < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+ end
+
+ class AutoModelForDepthEstimation < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+ end
+
+ class AutoModelForImageFeatureExtraction < PretrainedMixin
+ MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+ end
+
  class ModelOutput
+ def [](key)
+ instance_variable_get("@#{key}")
+ end
+ end
+
+ class Seq2SeqLMOutput < ModelOutput
+ def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+ super()
+ @logits = logits
+ @past_key_values = past_key_values
+ @encoder_outputs = encoder_outputs
+ @decoder_attentions = decoder_attentions
+ @cross_attentions = cross_attentions
+ end
  end

  class SequenceClassifierOutput < ModelOutput
@@ -331,6 +1286,15 @@ module Informers
  end
  end

+ class MaskedLMOutput < ModelOutput
+ attr_reader :logits
+
+ def initialize(logits)
+ super()
+ @logits = logits
+ end
+ end
+
  class QuestionAnsweringModelOutput < ModelOutput
  attr_reader :start_logits, :end_logits

@@ -340,4 +1304,25 @@ module Informers
  @end_logits = end_logits
  end
  end
+
+ class DetrObjectDetectionOutput < ModelOutput
+ attr_reader :logits, :pred_boxes
+
+ def initialize(logits, pred_boxes)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ end
+ end
+
+ class DetrSegmentationOutput < ModelOutput
+ attr_reader :logits, :pred_boxes, :pred_masks
+
+ def initialize(logits, pred_boxes, pred_masks)
+ super()
+ @logits = logits
+ @pred_boxes = pred_boxes
+ @pred_masks = pred_masks
+ end
+ end
  end
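The `ModelOutput#[]` accessor added in the final hunks makes output fields readable by name as well as through their attr_reader methods. A small illustration using the existing QuestionAnsweringModelOutput class; the logit values are placeholders:

  out = Informers::QuestionAnsweringModelOutput.new([0.1, 0.9], [0.8, 0.2])
  out.start_logits  # => [0.1, 0.9] via attr_reader
  out[:end_logits]  # => [0.8, 0.2] via the new [] / instance_variable_get accessor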