informers 1.0.2 → 1.1.0

This diff compares the contents of publicly available package versions as released to their public registries. It is provided for informational purposes only.
@@ -7,6 +7,26 @@ module Informers
       @tokenizer = tokenizer
       @processor = processor
     end
+
+    private
+
+    def prepare_images(images)
+      if !images.is_a?(Array)
+        images = [images]
+      end
+
+      # Possibly convert any non-images to images
+      images.map { |x| Utils::RawImage.read(x) }
+    end
+
+    def get_bounding_box(box, as_integer)
+      if as_integer
+        box = box.map { |x| x.to_i }
+      end
+      xmin, ymin, xmax, ymax = box
+
+      {xmin:, ymin:, xmax:, ymax:}
+    end
   end
 
   class TextClassificationPipeline < Pipeline
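
The new private get_bounding_box helper (its {xmin:, ymin:, xmax:, ymax:} literal uses Ruby 3.1's shorthand hash syntax) converts a box array into a hash; a quick sketch with illustrative values:

  get_bounding_box([4.2, 8.7, 15.1, 16.9], true)
  # => {xmin: 4, ymin: 8, xmax: 15, ymax: 16}
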
@@ -21,13 +41,13 @@ module Informers
       outputs = @model.(model_inputs)
 
       function_to_apply =
-        if @model.config.problem_type == "multi_label_classification"
+        if @model.config[:problem_type] == "multi_label_classification"
           ->(batch) { Utils.sigmoid(batch) }
         else
           ->(batch) { Utils.softmax(batch) } # single_label_classification (default)
         end
 
-      id2label = @model.config.id2label
+      id2label = @model.config[:id2label]
 
       to_return = []
       outputs.logits.each do |batch|
@@ -70,7 +90,7 @@ module Informers
       outputs = @model.(model_inputs)
 
       logits = outputs.logits
-      id2label = @model.config.id2label
+      id2label = @model.config[:id2label]
 
       to_return = []
       logits.length.times do |i|
@@ -243,13 +263,535 @@ module Informers
     end
   end
 
+  class FillMaskPipeline < Pipeline
+    def call(texts, top_k: 5)
+      model_inputs = @tokenizer.(texts, padding: true, truncation: true)
+      outputs = @model.(model_inputs)
+
+      to_return = []
+      model_inputs[:input_ids].each_with_index do |ids, i|
+        mask_token_index = ids.index(@tokenizer.mask_token_id)
+
+        if mask_token_index.nil?
+          raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text."
+        end
+        logits = outputs.logits[i]
+        item_logits = logits[mask_token_index]
+
+        scores = Utils.get_top_items(Utils.softmax(item_logits), top_k)
+
+        to_return <<
+          scores.map do |x|
+            sequence = ids.dup
+            sequence[mask_token_index] = x[0]
+
+            {
+              score: x[1],
+              token: x[0],
+              token_str: @tokenizer.id_to_token(x[0]),
+              sequence: @tokenizer.decode(sequence, skip_special_tokens: true)
+            }
+          end
+      end
+      texts.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
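
A minimal usage sketch for the new fill-mask task, assuming the default checkpoint (Xenova/bert-base-uncased); scores and completions are illustrative:

  unmasker = Informers.pipeline("fill-mask")
  unmasker.("Paris is the [MASK] of France.")
  # => [{score: ..., token: ..., token_str: "capital", sequence: "paris is the capital of france."}, ...]
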
+  class Text2TextGenerationPipeline < Pipeline
+    KEY = :generated_text
+
+    def call(texts, **generate_kwargs)
+      if !texts.is_a?(Array)
+        texts = [texts]
+      end
+
+      # Add global prefix, if present
+      if @model.config[:prefix]
+        texts = texts.map { |x| @model.config[:prefix] + x }
+      end
+
+      # Handle task specific params:
+      task_specific_params = @model.config[:task_specific_params]
+      if task_specific_params && task_specific_params[@task]
+        # Add prefixes, if present
+        if task_specific_params[@task]["prefix"]
+          texts = texts.map { |x| task_specific_params[@task]["prefix"] + x }
+        end
+
+        # TODO update generation config
+      end
+
+      tokenizer = @tokenizer
+      tokenizer_options = {
+        padding: true,
+        truncation: true
+      }
+      if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)
+        input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)[:input_ids]
+      else
+        input_ids = tokenizer.(texts, **tokenizer_options)[:input_ids]
+      end
+
+      output_token_ids = @model.generate(input_ids, generate_kwargs)
+
+      tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+        .map { |text| {self.class.const_get(:KEY) => text} }
+    end
+  end
+
+  class SummarizationPipeline < Text2TextGenerationPipeline
+    KEY = :summary_text
+  end
+
+  class TranslationPipeline < Text2TextGenerationPipeline
+    KEY = :translation_text
+  end
+
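
SummarizationPipeline and TranslationPipeline reuse the Text2TextGenerationPipeline implementation and only change the output key. A hedged sketch with the default summarization model; the summary text is illustrative:

  summarizer = Informers.pipeline("summarization")
  summarizer.("Very long article text...")
  # => [{summary_text: "..."}]
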
+  class TextGenerationPipeline < Pipeline
+    def call(texts, **generate_kwargs)
+      is_batched = false
+      is_chat_input = false
+
+      # Normalize inputs
+      if texts.is_a?(String)
+        texts = [texts]
+        inputs = texts
+      else
+        raise Todo
+      end
+
+      # By default, do not add special tokens
+      add_special_tokens = generate_kwargs[:add_special_tokens] || false
+
+      # By default, return full text
+      return_full_text =
+        if is_chat_input
+          false
+        else
+          generate_kwargs.fetch(:return_full_text, true)
+        end
+
+      @tokenizer.padding_side = "left"
+      input_ids, attention_mask =
+        @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
+          .values_at(:input_ids, :attention_mask)
+
+      output_token_ids =
+        @model.generate(
+          input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
+        )
+
+      decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+
+      if !return_full_text && Utils.dims(input_ids)[-1] > 0
+        prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }
+      end
+
+      to_return = Array.new(texts.length) { [] }
+      decoded.length.times do |i|
+        text_index = (i / output_token_ids.length.to_i * texts.length).floor
+
+        if prompt_lengths
+          raise Todo
+        end
+        # TODO is_chat_input
+        to_return[text_index] << {
+          generated_text: decoded[i]
+        }
+      end
+      !is_batched && to_return.length == 1 ? to_return[0] : to_return
+    end
+  end
+
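
Usage sketch with the default Xenova/gpt2 checkpoint (generated text is illustrative):

  generator = Informers.pipeline("text-generation")
  generator.("Once upon a time,")
  # => [{generated_text: "Once upon a time, ..."}]
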
+  class ZeroShotClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @label2id = @model.config[:label2id].transform_keys(&:downcase)
+
+      @entailment_id = @label2id["entailment"]
+      if @entailment_id.nil?
+        warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
+        @entailment_id = 2
+      end
+
+      @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
+      if @contradiction_id.nil?
+        warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
+        @contradiction_id = 0
+      end
+    end
+
+    def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
+      is_batched = texts.is_a?(Array)
+      if !is_batched
+        texts = [texts]
+      end
+      if !candidate_labels.is_a?(Array)
+        candidate_labels = [candidate_labels]
+      end
+
+      # Insert labels into hypothesis template
+      hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # How to perform the softmax over the logits:
+      # - true: softmax over the entailment vs. contradiction dim for each label independently
+      # - false: softmax the "entailment" logits over all candidate labels
+      softmax_each = multi_label || candidate_labels.length == 1
+
+      to_return = []
+      texts.each do |premise|
+        entails_logits = []
+
+        hypotheses.each do |hypothesis|
+          inputs = @tokenizer.(
+            premise,
+            text_pair: hypothesis,
+            padding: true,
+            truncation: true
+          )
+          outputs = @model.(inputs)
+
+          if softmax_each
+            entails_logits << [
+              outputs.logits[0][@contradiction_id],
+              outputs.logits[0][@entailment_id]
+            ]
+          else
+            entails_logits << outputs.logits[0][@entailment_id]
+          end
+        end
+
+        scores =
+          if softmax_each
+            entails_logits.map { |x| Utils.softmax(x)[1] }
+          else
+            Utils.softmax(entails_logits)
+          end
+
+        # Sort by scores (desc) and return scores with indices
+        scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }
+
+        to_return << {
+          sequence: premise,
+          labels: scores_sorted.map { |x| candidate_labels[x[1]] },
+          scores: scores_sorted.map { |x| x[0] }
+        }
+      end
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
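
Usage sketch; note that a single input returns a hash rather than an array (labels and scores are illustrative):

  classifier = Informers.pipeline("zero-shot-classification")
  classifier.("I have a problem with my iPhone", ["urgent", "not urgent", "phone", "computer"])
  # => {sequence: "...", labels: ["urgent", "phone", ...], scores: [...]}
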
+  class ImageToTextPipeline < Pipeline
+    def call(images, **generate_kwargs)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      to_return = []
+      pixel_values.each do |batch|
+        batch = [batch]
+        output = @model.generate(batch, **generate_kwargs)
+        decoded = @tokenizer
+          .batch_decode(output, skip_special_tokens: true)
+          .map { |x| {generated_text: x.strip} }
+        to_return << decoded
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
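
Usage sketch, assuming image.jpg is a local file that Utils::RawImage.read can load (the caption is illustrative):

  captioner = Informers.pipeline("image-to-text")
  captioner.("image.jpg")
  # => [{generated_text: "..."}]
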
+  class ImageClassificationPipeline < Pipeline
+    def call(images, top_k: 1)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      output = @model.({pixel_values: pixel_values})
+
+      id2label = @model.config[:id2label]
+      to_return = []
+      output.logits.each do |batch|
+        scores = Utils.get_top_items(Utils.softmax(batch), top_k)
+
+        vals =
+          scores.map do |x|
+            {
+              label: id2label[x[0].to_s],
+              score: x[1]
+            }
+          end
+        if top_k == 1
+          to_return.push(*vals)
+        else
+          to_return << vals
+        end
+      end
+
+      is_batched || top_k == 1 ? to_return : to_return[0]
+    end
+  end
+
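
Usage sketch (label and score are illustrative):

  classifier = Informers.pipeline("image-classification")
  classifier.("image.jpg")
  # => [{label: "tabby, tabby cat", score: ...}]
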
537
+ class ImageSegmentationPipeline < Pipeline
538
+ def initialize(**options)
539
+ super(**options)
540
+
541
+ @subtasks_mapping = {
542
+ "panoptic" => "post_process_panoptic_segmentation",
543
+ "instance" => "post_process_instance_segmentation",
544
+ "semantic" => "post_process_semantic_segmentation"
545
+ }
546
+ end
547
+
548
+ def call(
549
+ images,
550
+ threshold: 0.5,
551
+ mask_threshold: 0.5,
552
+ overlap_mask_area_threshold: 0.8,
553
+ label_ids_to_fuse: nil,
554
+ target_sizes: nil,
555
+ subtask: nil
556
+ )
557
+ is_batched = images.is_a?(Array)
558
+
559
+ if is_batched && images.length != 1
560
+ raise Error, "Image segmentation pipeline currently only supports a batch size of 1."
561
+ end
562
+
563
+ prepared_images = prepare_images(images)
564
+ image_sizes = prepared_images.map { |x| [x.height, x.width] }
565
+
566
+ model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
567
+ output = @model.(model_inputs)
568
+
569
+ if !subtask.nil?
570
+ fn = @subtasks_mapping[subtask]
571
+ else
572
+ @subtasks_mapping.each do |task, func|
573
+ if @processor.feature_extractor.respond_to?(func)
574
+ fn = @processor.feature_extractor.method(func)
575
+ subtask = task
576
+ break
577
+ end
578
+ end
579
+ end
580
+
581
+ id2label = @model.config[:id2label]
582
+
583
+ annotation = []
584
+ if subtask == "panoptic" || subtask == "instance"
585
+ processed = fn.(
586
+ output,
587
+ threshold:,
588
+ mask_threshold:,
589
+ overlap_mask_area_threshold:,
590
+ label_ids_to_fuse:,
591
+ target_sizes: target_sizes || image_sizes, # TODO FIX?
592
+ )[0]
593
+
594
+ _segmentation = processed[:segmentation]
595
+
596
+ processed[:segments_info].each do |segment|
597
+ annotation << {
598
+ label: id2label[segment[:label_id].to_s],
599
+ score: segment[:score]
600
+ # TODO mask
601
+ }
602
+ end
603
+ elsif subtask == "semantic"
604
+ raise Todo
605
+ else
606
+ raise Error, "Subtask #{subtask} not supported."
607
+ end
608
+
609
+ annotation
610
+ end
611
+ end
612
+
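
Usage sketch; per the TODO above, segmentation masks are not yet included in the output:

  segmenter = Informers.pipeline("image-segmentation")
  segmenter.("image.jpg")
  # => [{label: "cat", score: ...}, ...]
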
613
+ class ZeroShotImageClassificationPipeline < Pipeline
614
+ def call(images, candidate_labels, hypothesis_template: "This is a photo of {}")
615
+ is_batched = images.is_a?(Array)
616
+ prepared_images = prepare_images(images)
617
+
618
+ # Insert label into hypothesis template
619
+ texts = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
620
+
621
+ # Run tokenization
622
+ text_inputs = @tokenizer.(texts,
623
+ padding: @model.config[:model_type] == "siglip" ? "max_length" : true,
624
+ truncation: true
625
+ )
626
+
627
+ # Run processor
628
+ pixel_values = @processor.(prepared_images)[:pixel_values]
629
+
630
+ # Run model with both text and pixel inputs
631
+ output = @model.(text_inputs.merge(pixel_values: pixel_values))
632
+
633
+ function_to_apply =
634
+ if @model.config[:model_type] == "siglip"
635
+ ->(batch) { Utils.sigmoid(batch) }
636
+ else
637
+ ->(batch) { Utils.softmax(batch) }
638
+ end
639
+
640
+ # Compare each image with each candidate label
641
+ to_return = []
642
+ output[0].each do |batch|
643
+ # Compute softmax per image
644
+ probs = function_to_apply.(batch)
645
+
646
+ result = probs
647
+ .map.with_index { |x, i| {label: candidate_labels[i], score: x} }
648
+ .sort_by { |v| -v[:score] }
649
+
650
+ to_return << result
651
+ end
652
+
653
+ is_batched ? to_return : to_return[0]
654
+ end
655
+ end
656
+
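
Usage sketch (scores are illustrative; results are sorted by score, best match first):

  classifier = Informers.pipeline("zero-shot-image-classification")
  classifier.("image.jpg", ["cat", "dog"])
  # => [{label: "cat", score: ...}, {label: "dog", score: ...}]
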
657
+ class ObjectDetectionPipeline < Pipeline
658
+ def call(images, threshold: 0.9, percentage: false)
659
+ is_batched = images.is_a?(Array)
660
+
661
+ if is_batched && images.length != 1
662
+ raise Error, "Object detection pipeline currently only supports a batch size of 1."
663
+ end
664
+ prepared_images = prepare_images(images)
665
+
666
+ image_sizes = percentage ? nil : prepared_images.map { |x| [x.height, x.width] }
667
+
668
+ model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
669
+ output = @model.(model_inputs)
670
+
671
+ processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)
672
+
673
+ # Add labels
674
+ id2label = @model.config[:id2label]
675
+
676
+ # Format output
677
+ result =
678
+ processed.map do |batch|
679
+ batch[:boxes].map.with_index do |box, i|
680
+ {
681
+ label: id2label[batch[:classes][i].to_s],
682
+ score: batch[:scores][i],
683
+ box: get_bounding_box(box, !percentage)
684
+ }
685
+ end.sort_by { |v| -v[:score] }
686
+ end
687
+
688
+ is_batched ? result : result[0]
689
+ end
690
+ end
691
+
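
Usage sketch; with percentage: false (the default), box coordinates are integer pixels:

  detector = Informers.pipeline("object-detection")
  detector.("image.jpg")
  # => [{label: "cat", score: ..., box: {xmin: ..., ymin: ..., xmax: ..., ymax: ...}}, ...]
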
692
+ class ZeroShotObjectDetectionPipeline < Pipeline
693
+ def call(
694
+ images,
695
+ candidate_labels,
696
+ threshold: 0.1,
697
+ top_k: nil,
698
+ percentage: false
699
+ )
700
+ is_batched = images.is_a?(Array)
701
+ prepared_images = prepare_images(images)
702
+
703
+ # Run tokenization
704
+ text_inputs = @tokenizer.(candidate_labels,
705
+ padding: true,
706
+ truncation: true
707
+ )
708
+
709
+ # Run processor
710
+ model_inputs = @processor.(prepared_images)
711
+
712
+ # Since non-maximum suppression is performed for exporting, we need to
713
+ # process each image separately. For more information, see:
714
+ # https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
715
+ to_return = []
716
+ prepared_images.length.times do |i|
717
+ image = prepared_images[i]
718
+ image_size = percentage ? nil : [[image.height, image.width]]
719
+ pixel_values = [model_inputs[:pixel_values][i]]
720
+
721
+ # Run model with both text and pixel inputs
722
+ output = @model.(text_inputs.merge(pixel_values: pixel_values))
723
+ # TODO remove
724
+ output = @model.instance_variable_get(:@session).outputs.map { |v| v[:name].to_sym }.zip(output).to_h
725
+
726
+ processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_size, true)[0]
727
+ result =
728
+ processed[:boxes].map.with_index do |box, i|
729
+ {
730
+ label: candidate_labels[processed[:classes][i]],
731
+ score: processed[:scores][i],
732
+ box: get_bounding_box(box, !percentage),
733
+ }
734
+ end
735
+ result.sort_by! { |v| -v[:score] }
736
+ if !top_k.nil?
737
+ result = result[0...topk]
738
+ end
739
+ to_return << result
740
+ end
741
+
742
+ is_batched ? to_return : to_return[0]
743
+ end
744
+ end
745
+
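
Usage sketch (labels and scores are illustrative):

  detector = Informers.pipeline("zero-shot-object-detection")
  detector.("image.jpg", ["cat", "remote"], top_k: 2)
  # => [{label: "cat", score: ..., box: {...}}, ...]
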
746
+ class DocumentQuestionAnsweringPipeline < Pipeline
747
+ def call(image, question, **generate_kwargs)
748
+ # NOTE: For now, we only support a batch size of 1
749
+
750
+ # Preprocess image
751
+ prepared_image = prepare_images(image)[0]
752
+ pixel_values = @processor.(prepared_image)[:pixel_values]
753
+
754
+ # Run tokenization
755
+ task_prompt = "<s_docvqa><s_question>#{question}</s_question><s_answer>"
756
+ decoder_input_ids =
757
+ @tokenizer.(
758
+ task_prompt,
759
+ add_special_tokens: false,
760
+ padding: true,
761
+ truncation: true
762
+ )[:input_ids]
763
+
764
+ # Run model
765
+ output =
766
+ @model.generate(
767
+ pixel_values,
768
+ generate_kwargs.merge(
769
+ decoder_input_ids: decoder_input_ids[0],
770
+ max_length: @model.config["decoder"]["max_position_embeddings"]
771
+ ).transform_keys(&:to_s)
772
+ )
773
+
774
+ # Decode output
775
+ decoded = @tokenizer.batch_decode(output, skip_special_tokens: false)[0]
776
+
777
+ # Parse answer
778
+ match = decoded.match(/<s_answer>(.*?)<\/s_answer>/)
779
+ answer = nil
780
+ if match && match.length >= 2
781
+ answer = match[1].strip
782
+ end
783
+ [{answer:}]
784
+ end
785
+ end
786
+
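
Usage sketch, assuming invoice.png is a readable document image (the answer is illustrative):

  qa = Informers.pipeline("document-question-answering")
  qa.("invoice.png", "What is the invoice number?")
  # => [{answer: "..."}]
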
   class FeatureExtractionPipeline < Pipeline
     def call(
       texts,
       pooling: "none",
       normalize: false,
       quantize: false,
-      precision: "binary"
+      precision: "binary",
+      model_output: nil
     )
       # Run tokenization
       model_inputs = @tokenizer.(texts,
@@ -258,8 +800,10 @@ module Informers
       )
       model_options = {}
 
-      # optimization for sentence-transformers/all-MiniLM-L6-v2
-      if @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
+      if !model_output.nil?
+        model_options[:output_names] = Array(model_output)
+      elsif @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
+        # optimization for sentence-transformers/all-MiniLM-L6-v2
         model_options[:output_names] = ["sentence_embedding"]
         pooling = "none"
         normalize = false
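
The new model_output option lets callers request a specific model output by name instead of relying on the defaults above. Sketch (output names vary by model; "token_embeddings" is just an example):

  extractor = Informers.pipeline("feature-extraction")
  extractor.("Ruby is great", model_output: "token_embeddings")
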
@@ -271,7 +815,9 @@ module Informers
       # TODO improve
       result =
         if outputs.is_a?(Array)
-          raise Error, "unexpected outputs" if outputs.size != 1
+          # TODO show returned instead of all
+          output_names = @model.instance_variable_get(:@session).outputs.map { |v| v[:name] }
+          raise Error, "unexpected outputs: #{output_names}" if outputs.size != 1
           outputs[0]
         else
           outputs.logits
@@ -285,6 +831,7 @@ module Informers
       when "cls"
         result = result.map(&:first)
       else
+        # TODO raise ArgumentError in 2.0
        raise Error, "Pooling method '#{pooling}' not supported."
      end
 
@@ -300,13 +847,77 @@ module Informers
     end
   end
 
+  class ImageFeatureExtractionPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      outputs = @model.({pixel_values: pixel_values})
+
+      result = outputs[0]
+      result
+    end
+  end
+
+  class ImageToImagePipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      inputs = @processor.(prepared_images)
+      outputs = @model.(inputs)
+
+      to_return = []
+      outputs[0].each do |batch|
+        # TODO flatten first
+        output =
+          batch.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3.clamp(0, 1) * 255).round
+              end
+            end
+          end
+        to_return << Utils::RawImage.from_array(output).image
+      end
+
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
+  class DepthEstimationPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+
+      inputs = @processor.(prepared_images)
+      predicted_depth = @model.(inputs)[0]
+
+      to_return = []
+      prepared_images.length.times do |i|
+        prediction = Utils.interpolate(predicted_depth[i], prepared_images[i].size.reverse, "bilinear", false)
+        max_prediction = Utils.max(prediction.flatten)[0]
+        formatted =
+          prediction.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3 * 255 / max_prediction).round
+              end
+            end
+          end
+        to_return << {
+          predicted_depth: predicted_depth[i],
+          depth: Utils::RawImage.from_array(formatted).image
+        }
+      end
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
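
Usage sketch for the new depth-estimation pipeline; a single image returns a hash:

  estimator = Informers.pipeline("depth-estimation")
  result = estimator.("image.jpg")
  result[:predicted_depth] # raw model output
  result[:depth]           # grayscale depth map image
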
   class EmbeddingPipeline < FeatureExtractionPipeline
     def call(
       texts,
       pooling: "mean",
-      normalize: true
+      normalize: true,
+      model_output: nil
     )
-      super(texts, pooling:, normalize:)
+      super(texts, pooling:, normalize:, model_output:)
     end
   end
 
@@ -368,6 +979,145 @@ module Informers
       },
       type: "text"
     },
+    "fill-mask" => {
+      tokenizer: AutoTokenizer,
+      pipeline: FillMaskPipeline,
+      model: AutoModelForMaskedLM,
+      default: {
+        model: "Xenova/bert-base-uncased"
+      },
+      type: "text"
+    },
+    "summarization" => {
+      tokenizer: AutoTokenizer,
+      pipeline: SummarizationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/distilbart-cnn-6-6"
+      },
+      type: "text"
+    },
+    "translation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TranslationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/t5-small"
+      },
+      type: "text"
+    },
+    "text2text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: Text2TextGenerationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/flan-t5-small"
+      },
+      type: "text"
+    },
+    "text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TextGenerationPipeline,
+      model: AutoModelForCausalLM,
+      default: {
+        model: "Xenova/gpt2"
+      },
+      type: "text"
+    },
+    "zero-shot-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotClassificationPipeline,
+      model: AutoModelForSequenceClassification,
+      default: {
+        model: "Xenova/distilbert-base-uncased-mnli"
+      },
+      type: "text"
+    },
+    "image-to-text" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ImageToTextPipeline,
+      model: AutoModelForVision2Seq,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-gpt2-image-captioning"
+      },
+      type: "multimodal"
+    },
+    "image-classification" => {
+      pipeline: ImageClassificationPipeline,
+      model: AutoModelForImageClassification,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-base-patch16-224",
+      },
+      type: "multimodal"
+    },
+    "image-segmentation" => {
+      pipeline: ImageSegmentationPipeline,
+      model: [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50-panoptic",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-image-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotImageClassificationPipeline,
+      model: AutoModel,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/clip-vit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "object-detection" => {
+      pipeline: ObjectDetectionPipeline,
+      model: AutoModelForObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-object-detection" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotObjectDetectionPipeline,
+      model: AutoModelForZeroShotObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/owlvit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "document-question-answering" => {
+      tokenizer: AutoTokenizer,
+      pipeline: DocumentQuestionAnsweringPipeline,
+      model: AutoModelForDocumentQuestionAnswering,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/donut-base-finetuned-docvqa"
+      },
+      type: "multimodal"
+    },
+    "image-to-image" => {
+      pipeline: ImageToImagePipeline,
+      model: AutoModelForImageToImage,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/swin2SR-classical-sr-x2-64"
+      },
+      type: "image"
+    },
+    "depth-estimation" => {
+      pipeline: DepthEstimationPipeline,
+      model: AutoModelForDepthEstimation,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/dpt-large"
+      },
+      type: "image"
+    },
     "feature-extraction" => {
       tokenizer: AutoTokenizer,
       pipeline: FeatureExtractionPipeline,
@@ -377,6 +1127,15 @@ module Informers
       },
       type: "text"
     },
+    "image-feature-extraction" => {
+      processor: AutoProcessor,
+      pipeline: ImageFeatureExtractionPipeline,
+      model: [AutoModelForImageFeatureExtraction, AutoModel],
+      default: {
+        model: "Xenova/vit-base-patch16-224"
+      },
+      type: "image"
+    },
     "embedding" => {
       tokenizer: AutoTokenizer,
       pipeline: EmbeddingPipeline,
@@ -432,14 +1191,14 @@ module Informers
     revision: "main",
     model_file_name: nil
   )
+    # Apply aliases
+    task = TASK_ALIASES[task] || task
+
     if quantized == NO_DEFAULT
       # TODO move default to task class
-      quantized = !["embedding", "reranking"].include?(task)
+      quantized = ["text-classification", "token-classification", "question-answering", "feature-extraction"].include?(task)
     end
 
-    # Apply aliases
-    task = TASK_ALIASES[task] || task
-
     # Get pipeline info
     pipeline_info = SUPPORTED_TASKS[task.split("_", 1)[0]]
     if !pipeline_info
@@ -495,7 +1254,15 @@ module Informers
       next if !cls
 
       if cls.is_a?(Array)
-        raise Todo
+        e = nil
+        cls.each do |c|
+          begin
+            result[name] = c.from_pretrained(model, **pretrained_options)
+          rescue => err
+            e = err
+          end
+        end
+        raise e unless result[name]
       else
         result[name] = cls.from_pretrained(model, **pretrained_options)
       end