informers 1.0.3 → 1.1.0

@@ -7,6 +7,26 @@ module Informers
       @tokenizer = tokenizer
       @processor = processor
     end
+
+    private
+
+    def prepare_images(images)
+      if !images.is_a?(Array)
+        images = [images]
+      end
+
+      # Possibly convert any non-images to images
+      images.map { |x| Utils::RawImage.read(x) }
+    end
+
+    def get_bounding_box(box, as_integer)
+      if as_integer
+        box = box.map { |x| x.to_i }
+      end
+      xmin, ymin, xmax, ymax = box
+
+      {xmin:, ymin:, xmax:, ymax:}
+    end
   end

   class TextClassificationPipeline < Pipeline
@@ -21,13 +41,13 @@ module Informers
       outputs = @model.(model_inputs)

       function_to_apply =
-        if @model.config.problem_type == "multi_label_classification"
+        if @model.config[:problem_type] == "multi_label_classification"
           ->(batch) { Utils.sigmoid(batch) }
         else
           ->(batch) { Utils.softmax(batch) } # single_label_classification (default)
         end

-      id2label = @model.config.id2label
+      id2label = @model.config[:id2label]

       to_return = []
       outputs.logits.each do |batch|
@@ -70,7 +90,7 @@ module Informers
       outputs = @model.(model_inputs)

       logits = outputs.logits
-      id2label = @model.config.id2label
+      id2label = @model.config[:id2label]

       to_return = []
       logits.length.times do |i|
@@ -243,6 +263,527 @@ module Informers
     end
   end

+  class FillMaskPipeline < Pipeline
+    def call(texts, top_k: 5)
+      model_inputs = @tokenizer.(texts, padding: true, truncation: true)
+      outputs = @model.(model_inputs)
+
+      to_return = []
+      model_inputs[:input_ids].each_with_index do |ids, i|
+        mask_token_index = ids.index(@tokenizer.mask_token_id)
+
+        if mask_token_index.nil?
+          raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text."
+        end
+        logits = outputs.logits[i]
+        item_logits = logits[mask_token_index]
+
+        scores = Utils.get_top_items(Utils.softmax(item_logits), top_k)
+
+        to_return <<
+          scores.map do |x|
+            sequence = ids.dup
+            sequence[mask_token_index] = x[0]
+
+            {
+              score: x[1],
+              token: x[0],
+              token_str: @tokenizer.id_to_token(x[0]),
+              sequence: @tokenizer.decode(sequence, skip_special_tokens: true)
+            }
+          end
+      end
+      texts.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
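Note: a usage sketch for the new fill-mask task. The task name, default model, call signature, and result keys all come from this diff; result values are illustrative. The sketches below assume require "informers" has run.

    unmasker = Informers.pipeline("fill-mask")   # default model: Xenova/bert-base-uncased
    unmasker.("Paris is the [MASK] of France.", top_k: 2)
    # => [{score: ..., token: ..., token_str: "capital", sequence: "paris is the capital of france."}, ...]
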
+  class Text2TextGenerationPipeline < Pipeline
+    KEY = :generated_text
+
+    def call(texts, **generate_kwargs)
+      if !texts.is_a?(Array)
+        texts = [texts]
+      end
+
+      # Add global prefix, if present
+      if @model.config[:prefix]
+        texts = texts.map { |x| @model.config[:prefix] + x }
+      end
+
+      # Handle task specific params:
+      task_specific_params = @model.config[:task_specific_params]
+      if task_specific_params && task_specific_params[@task]
+        # Add prefixes, if present
+        if task_specific_params[@task]["prefix"]
+          texts = texts.map { |x| task_specific_params[@task]["prefix"] + x }
+        end
+
+        # TODO update generation config
+      end
+
+      tokenizer = @tokenizer
+      tokenizer_options = {
+        padding: true,
+        truncation: true
+      }
+      if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)
+        input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)[:input_ids]
+      else
+        input_ids = tokenizer.(texts, **tokenizer_options)[:input_ids]
+      end
+
+      output_token_ids = @model.generate(input_ids, generate_kwargs)
+
+      tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+        .map { |text| {self.class.const_get(:KEY) => text} }
+    end
+  end
+
+  class SummarizationPipeline < Text2TextGenerationPipeline
+    KEY = :summary_text
+  end
+
+  class TranslationPipeline < Text2TextGenerationPipeline
+    KEY = :translation_text
+  end
+
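Note: SummarizationPipeline and TranslationPipeline only override KEY, so they return {summary_text:} and {translation_text:} hashes respectively. A sketch using the defaults added below:

    summarizer = Informers.pipeline("summarization")   # default model: Xenova/distilbart-cnn-6-6
    summarizer.("...long article text...")[0][:summary_text]
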
+  class TextGenerationPipeline < Pipeline
+    def call(texts, **generate_kwargs)
+      is_batched = false
+      is_chat_input = false
+
+      # Normalize inputs
+      if texts.is_a?(String)
+        texts = [texts]
+        inputs = texts
+      else
+        raise Todo
+      end
+
+      # By default, do not add special tokens
+      add_special_tokens = generate_kwargs[:add_special_tokens] || false
+
+      # By default, return full text
+      return_full_text =
+        if is_chat_input
+          false
+        else
+          generate_kwargs.fetch(:return_full_text, true)
+        end
+
+      @tokenizer.padding_side = "left"
+      input_ids, attention_mask =
+        @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
+          .values_at(:input_ids, :attention_mask)
+
+      output_token_ids =
+        @model.generate(
+          input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
+        )
+
+      decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+
+      if !return_full_text && Utils.dims(input_ids)[-1] > 0
+        prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }
+      end
+
+      to_return = Array.new(texts.length) { [] }
+      decoded.length.times do |i|
+        text_index = (i / output_token_ids.length.to_f * texts.length).floor
+
+        if prompt_lengths
+          raise Todo
+        end
+        # TODO is_chat_input
+        to_return[text_index] << {
+          generated_text: decoded[i]
+        }
+      end
+      !is_batched && to_return.length == 1 ? to_return[0] : to_return
+    end
+  end
+
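Note: a text-generation sketch. Only String input is implemented above (other shapes raise Todo), and extra keyword arguments are forwarded to @model.generate, so which generation options apply depends on that method:

    generator = Informers.pipeline("text-generation")  # default model: Xenova/gpt2
    generator.("As far as I am concerned, I will")
    # => [{generated_text: "..."}]
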
+  class ZeroShotClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @label2id = @model.config[:label2id].transform_keys(&:downcase)
+
+      @entailment_id = @label2id["entailment"]
+      if @entailment_id.nil?
+        warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
+        @entailment_id = 2
+      end
+
+      @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
+      if @contradiction_id.nil?
+        warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
+        @contradiction_id = 0
+      end
+    end
+
+    def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
+      is_batched = texts.is_a?(Array)
+      if !is_batched
+        texts = [texts]
+      end
+      if !candidate_labels.is_a?(Array)
+        candidate_labels = [candidate_labels]
+      end
+
+      # Insert labels into hypothesis template
+      hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # How to perform the softmax over the logits:
+      # - true: softmax over the entailment vs. contradiction dim for each label independently
+      # - false: softmax the "entailment" logits over all candidate labels
+      softmax_each = multi_label || candidate_labels.length == 1
+
+      to_return = []
+      texts.each do |premise|
+        entails_logits = []
+
+        hypotheses.each do |hypothesis|
+          inputs = @tokenizer.(
+            premise,
+            text_pair: hypothesis,
+            padding: true,
+            truncation: true
+          )
+          outputs = @model.(inputs)
+
+          if softmax_each
+            entails_logits << [
+              outputs.logits[0][@contradiction_id],
+              outputs.logits[0][@entailment_id]
+            ]
+          else
+            entails_logits << outputs.logits[0][@entailment_id]
+          end
+        end
+
+        scores =
+          if softmax_each
+            entails_logits.map { |x| Utils.softmax(x)[1] }
+          else
+            Utils.softmax(entails_logits)
+          end
+
+        # Sort by scores (desc) and return scores with indices
+        scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }
+
+        to_return << {
+          sequence: premise,
+          labels: scores_sorted.map { |x| candidate_labels[x[1]] },
+          scores: scores_sorted.map { |x| x[0] }
+        }
+      end
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
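Note: a zero-shot-classification sketch; argument order and result keys come from the call method above:

    classifier = Informers.pipeline("zero-shot-classification")  # default: Xenova/distilbert-base-uncased-mnli
    classifier.("one day I will see the world", ["travel", "cooking", "dancing"])
    # => {sequence: "...", labels: ["travel", ...], scores: [...]}
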
+  class ImageToTextPipeline < Pipeline
+    def call(images, **generate_kwargs)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      to_return = []
+      pixel_values.each do |batch|
+        batch = [batch]
+        output = @model.generate(batch, **generate_kwargs)
+        decoded = @tokenizer
+          .batch_decode(output, skip_special_tokens: true)
+          .map { |x| {generated_text: x.strip} }
+        to_return << decoded
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
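Note: an image-to-text sketch. Image arguments can be anything Utils::RawImage.read accepts; the path below is a placeholder:

    captioner = Informers.pipeline("image-to-text")  # default: Xenova/vit-gpt2-image-captioning
    captioner.("photo.jpg")
    # => [{generated_text: "..."}]
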
+  class ImageClassificationPipeline < Pipeline
+    def call(images, top_k: 1)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      output = @model.({pixel_values: pixel_values})
+
+      id2label = @model.config[:id2label]
+      to_return = []
+      output.logits.each do |batch|
+        scores = Utils.get_top_items(Utils.softmax(batch), top_k)
+
+        vals =
+          scores.map do |x|
+            {
+              label: id2label[x[0].to_s],
+              score: x[1]
+            }
+          end
+        if top_k == 1
+          to_return.push(*vals)
+        else
+          to_return << vals
+        end
+      end
+
+      is_batched || top_k == 1 ? to_return : to_return[0]
+    end
+  end
+
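Note: an image-classification sketch; with the default top_k: 1, a flat array of label/score hashes comes back:

    classifier = Informers.pipeline("image-classification")  # default: Xenova/vit-base-patch16-224
    classifier.("photo.jpg")
    # => [{label: "...", score: ...}]
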
+  class ImageSegmentationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @subtasks_mapping = {
+        "panoptic" => "post_process_panoptic_segmentation",
+        "instance" => "post_process_instance_segmentation",
+        "semantic" => "post_process_semantic_segmentation"
+      }
+    end
+
+    def call(
+      images,
+      threshold: 0.5,
+      mask_threshold: 0.5,
+      overlap_mask_area_threshold: 0.8,
+      label_ids_to_fuse: nil,
+      target_sizes: nil,
+      subtask: nil
+    )
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Image segmentation pipeline currently only supports a batch size of 1."
+      end
+
+      prepared_images = prepare_images(images)
+      image_sizes = prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      if !subtask.nil?
+        fn = @processor.feature_extractor.method(@subtasks_mapping[subtask])
+      else
+        @subtasks_mapping.each do |task, func|
+          if @processor.feature_extractor.respond_to?(func)
+            fn = @processor.feature_extractor.method(func)
+            subtask = task
+            break
+          end
+        end
+      end
+
+      id2label = @model.config[:id2label]
+
+      annotation = []
+      if subtask == "panoptic" || subtask == "instance"
+        processed = fn.(
+          output,
+          threshold:,
+          mask_threshold:,
+          overlap_mask_area_threshold:,
+          label_ids_to_fuse:,
+          target_sizes: target_sizes || image_sizes # TODO FIX?
+        )[0]
+
+        _segmentation = processed[:segmentation]
+
+        processed[:segments_info].each do |segment|
+          annotation << {
+            label: id2label[segment[:label_id].to_s],
+            score: segment[:score]
+            # TODO mask
+          }
+        end
+      elsif subtask == "semantic"
+        raise Todo
+      else
+        raise Error, "Subtask #{subtask} not supported."
+      end
+
+      annotation
+    end
+  end
+
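Note: an image-segmentation sketch; per the code above, each annotation currently carries only label and score (masks are still a TODO):

    segmenter = Informers.pipeline("image-segmentation")  # default: Xenova/detr-resnet-50-panoptic
    segmenter.("photo.jpg")
    # => [{label: "...", score: ...}, ...]
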
+  class ZeroShotImageClassificationPipeline < Pipeline
+    def call(images, candidate_labels, hypothesis_template: "This is a photo of {}")
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Insert label into hypothesis template
+      texts = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # Run tokenization
+      text_inputs = @tokenizer.(texts,
+        padding: @model.config[:model_type] == "siglip" ? "max_length" : true,
+        truncation: true
+      )
+
+      # Run processor
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      # Run model with both text and pixel inputs
+      output = @model.(text_inputs.merge(pixel_values: pixel_values))
+
+      function_to_apply =
+        if @model.config[:model_type] == "siglip"
+          ->(batch) { Utils.sigmoid(batch) }
+        else
+          ->(batch) { Utils.softmax(batch) }
+        end
+
+      # Compare each image with each candidate label
+      to_return = []
+      output[0].each do |batch|
+        # Compute probabilities per image (softmax, or sigmoid for SigLIP)
+        probs = function_to_apply.(batch)
+
+        result = probs
+          .map.with_index { |x, i| {label: candidate_labels[i], score: x} }
+          .sort_by { |v| -v[:score] }
+
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
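Note: a zero-shot-image-classification sketch; results come back sorted by score, one entry per candidate label:

    clip = Informers.pipeline("zero-shot-image-classification")  # default: Xenova/clip-vit-base-patch32
    clip.("photo.jpg", ["cat", "dog"])
    # => [{label: "cat", score: ...}, {label: "dog", score: ...}]
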
+  class ObjectDetectionPipeline < Pipeline
+    def call(images, threshold: 0.9, percentage: false)
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Object detection pipeline currently only supports a batch size of 1."
+      end
+      prepared_images = prepare_images(images)
+
+      image_sizes = percentage ? nil : prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)
+
+      # Add labels
+      id2label = @model.config[:id2label]
+
+      # Format output
+      result =
+        processed.map do |batch|
+          batch[:boxes].map.with_index do |box, i|
+            {
+              label: id2label[batch[:classes][i].to_s],
+              score: batch[:scores][i],
+              box: get_bounding_box(box, !percentage)
+            }
+          end.sort_by { |v| -v[:score] }
+        end
+
+      is_batched ? result : result[0]
+    end
+  end
+
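Note: an object-detection sketch; boxes use the {xmin:, ymin:, xmax:, ymax:} shape from get_bounding_box, in pixels unless percentage: true:

    detector = Informers.pipeline("object-detection")  # default: Xenova/detr-resnet-50
    detector.("photo.jpg")
    # => [{label: "...", score: ..., box: {xmin: ..., ymin: ..., xmax: ..., ymax: ...}}, ...]
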
+  class ZeroShotObjectDetectionPipeline < Pipeline
+    def call(
+      images,
+      candidate_labels,
+      threshold: 0.1,
+      top_k: nil,
+      percentage: false
+    )
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Run tokenization
+      text_inputs = @tokenizer.(candidate_labels,
+        padding: true,
+        truncation: true
+      )
+
+      # Run processor
+      model_inputs = @processor.(prepared_images)
+
+      # Since non-maximum suppression is performed for exporting, we need to
+      # process each image separately. For more information, see:
+      # https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+      to_return = []
+      prepared_images.length.times do |i|
+        image = prepared_images[i]
+        image_size = percentage ? nil : [[image.height, image.width]]
+        pixel_values = [model_inputs[:pixel_values][i]]
+
+        # Run model with both text and pixel inputs
+        output = @model.(text_inputs.merge(pixel_values: pixel_values))
+        # TODO remove
+        output = @model.instance_variable_get(:@session).outputs.map { |v| v[:name].to_sym }.zip(output).to_h
+
+        processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_size, true)[0]
+        result =
+          processed[:boxes].map.with_index do |box, i|
+            {
+              label: candidate_labels[processed[:classes][i]],
+              score: processed[:scores][i],
+              box: get_bounding_box(box, !percentage)
+            }
+          end
+        result.sort_by! { |v| -v[:score] }
+        if !top_k.nil?
+          result = result[0...top_k]
+        end
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
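Note: a zero-shot-object-detection sketch; candidate labels are free-form text, and top_k caps the number of boxes per image:

    detector = Informers.pipeline("zero-shot-object-detection")  # default: Xenova/owlvit-base-patch32
    detector.("photo.jpg", ["helmet", "glove"], top_k: 5)
    # => [{label: "helmet", score: ..., box: {...}}, ...]
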
+  class DocumentQuestionAnsweringPipeline < Pipeline
+    def call(image, question, **generate_kwargs)
+      # NOTE: For now, we only support a batch size of 1
+
+      # Preprocess image
+      prepared_image = prepare_images(image)[0]
+      pixel_values = @processor.(prepared_image)[:pixel_values]
+
+      # Run tokenization
+      task_prompt = "<s_docvqa><s_question>#{question}</s_question><s_answer>"
+      decoder_input_ids =
+        @tokenizer.(
+          task_prompt,
+          add_special_tokens: false,
+          padding: true,
+          truncation: true
+        )[:input_ids]
+
+      # Run model
+      output =
+        @model.generate(
+          pixel_values,
+          generate_kwargs.merge(
+            decoder_input_ids: decoder_input_ids[0],
+            max_length: @model.config["decoder"]["max_position_embeddings"]
+          ).transform_keys(&:to_s)
+        )
+
+      # Decode output
+      decoded = @tokenizer.batch_decode(output, skip_special_tokens: false)[0]
+
+      # Parse answer
+      match = decoded.match(/<s_answer>(.*?)<\/s_answer>/)
+      answer = nil
+      if match && match.length >= 2
+        answer = match[1].strip
+      end
+      [{answer:}]
+    end
+  end
+
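Note: a document-question-answering sketch; the Donut prompt format is built internally from the question:

    qa = Informers.pipeline("document-question-answering")  # default: Xenova/donut-base-finetuned-docvqa
    qa.("invoice.png", "What is the invoice number?")
    # => [{answer: "..."}]
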
   class FeatureExtractionPipeline < Pipeline
     def call(
       texts,
@@ -306,6 +847,69 @@ module Informers
     end
   end

+  class ImageFeatureExtractionPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      outputs = @model.({pixel_values: pixel_values})
+
+      result = outputs[0]
+      result
+    end
+  end
+
+ class ImageToImagePipeline < Pipeline
862
+ def call(images)
863
+ prepared_images = prepare_images(images)
864
+ inputs = @processor.(prepared_images)
865
+ outputs = @model.(inputs);
866
+
867
+ to_return = []
868
+ outputs[0].each do |batch|
869
+ # TODO flatten first
870
+ output =
871
+ batch.map do |v|
872
+ v.map do |v2|
873
+ v2.map do |v3|
874
+ (v3.clamp(0, 1) * 255).round
875
+ end
876
+ end
877
+ end
878
+ to_return << Utils::RawImage.from_array(output).image
879
+ end
880
+
881
+ to_return.length > 1 ? to_return : to_return[0]
882
+ end
883
+ end
884
+
885
+ class DepthEstimationPipeline < Pipeline
886
+ def call(images)
887
+ prepared_images = prepare_images(images)
888
+
889
+ inputs = @processor.(prepared_images)
890
+ predicted_depth = @model.(inputs)[0]
891
+
892
+ to_return = []
893
+ prepared_images.length.times do |i|
894
+ prediction = Utils.interpolate(predicted_depth[i], prepared_images[i].size.reverse, "bilinear", false)
895
+ max_prediction = Utils.max(prediction.flatten)[0]
896
+ formatted =
897
+ prediction.map do |v|
898
+ v.map do |v2|
899
+ v2.map do |v3|
900
+ (v3 * 255 / max_prediction).round
901
+ end
902
+ end
903
+ end
904
+ to_return << {
905
+ predicted_depth: predicted_depth[i],
906
+ depth: Utils::RawImage.from_array(formatted).image
907
+ }
908
+ end
909
+ to_return.length > 1 ? to_return : to_return[0]
910
+ end
911
+ end
912
+
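Note: a depth-estimation sketch; each result pairs the raw predicted_depth values with a normalized grayscale depth image:

    estimator = Informers.pipeline("depth-estimation")  # default: Xenova/dpt-large
    result = estimator.("photo.jpg")
    # => {predicted_depth: [...], depth: <image>}
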
   class EmbeddingPipeline < FeatureExtractionPipeline
     def call(
       texts,
@@ -375,6 +979,145 @@ module Informers
       },
       type: "text"
     },
+    "fill-mask" => {
+      tokenizer: AutoTokenizer,
+      pipeline: FillMaskPipeline,
+      model: AutoModelForMaskedLM,
+      default: {
+        model: "Xenova/bert-base-uncased"
+      },
+      type: "text"
+    },
+    "summarization" => {
+      tokenizer: AutoTokenizer,
+      pipeline: SummarizationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/distilbart-cnn-6-6"
+      },
+      type: "text"
+    },
+    "translation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TranslationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/t5-small"
+      },
+      type: "text"
+    },
+    "text2text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: Text2TextGenerationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/flan-t5-small"
+      },
+      type: "text"
+    },
+    "text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TextGenerationPipeline,
+      model: AutoModelForCausalLM,
+      default: {
+        model: "Xenova/gpt2"
+      },
+      type: "text"
+    },
+    "zero-shot-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotClassificationPipeline,
+      model: AutoModelForSequenceClassification,
+      default: {
+        model: "Xenova/distilbert-base-uncased-mnli"
+      },
+      type: "text"
+    },
+    "image-to-text" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ImageToTextPipeline,
+      model: AutoModelForVision2Seq,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-gpt2-image-captioning"
+      },
+      type: "multimodal"
+    },
+    "image-classification" => {
+      pipeline: ImageClassificationPipeline,
+      model: AutoModelForImageClassification,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-base-patch16-224"
+      },
+      type: "multimodal"
+    },
+    "image-segmentation" => {
+      pipeline: ImageSegmentationPipeline,
+      model: [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50-panoptic"
+      },
+      type: "multimodal"
+    },
+    "zero-shot-image-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotImageClassificationPipeline,
+      model: AutoModel,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/clip-vit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "object-detection" => {
+      pipeline: ObjectDetectionPipeline,
+      model: AutoModelForObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50"
+      },
+      type: "multimodal"
+    },
+    "zero-shot-object-detection" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotObjectDetectionPipeline,
+      model: AutoModelForZeroShotObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/owlvit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "document-question-answering" => {
+      tokenizer: AutoTokenizer,
+      pipeline: DocumentQuestionAnsweringPipeline,
+      model: AutoModelForDocumentQuestionAnswering,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/donut-base-finetuned-docvqa"
+      },
+      type: "multimodal"
+    },
+    "image-to-image" => {
+      pipeline: ImageToImagePipeline,
+      model: AutoModelForImageToImage,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/swin2SR-classical-sr-x2-64"
+      },
+      type: "image"
+    },
+    "depth-estimation" => {
+      pipeline: DepthEstimationPipeline,
+      model: AutoModelForDepthEstimation,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/dpt-large"
+      },
+      type: "image"
+    },
     "feature-extraction" => {
       tokenizer: AutoTokenizer,
       pipeline: FeatureExtractionPipeline,
@@ -384,6 +1127,15 @@ module Informers
       },
       type: "text"
     },
+    "image-feature-extraction" => {
+      processor: AutoProcessor,
+      pipeline: ImageFeatureExtractionPipeline,
+      model: [AutoModelForImageFeatureExtraction, AutoModel],
+      default: {
+        model: "Xenova/vit-base-patch16-224"
+      },
+      type: "image"
+    },
     "embedding" => {
       tokenizer: AutoTokenizer,
       pipeline: EmbeddingPipeline,
@@ -439,14 +1191,14 @@ module Informers
     revision: "main",
     model_file_name: nil
   )
+    # Apply aliases
+    task = TASK_ALIASES[task] || task
+
     if quantized == NO_DEFAULT
       # TODO move default to task class
-      quantized = !["embedding", "reranking"].include?(task)
+      quantized = ["text-classification", "token-classification", "question-answering", "feature-extraction"].include?(task)
     end

-    # Apply aliases
-    task = TASK_ALIASES[task] || task
-
     # Get pipeline info
     pipeline_info = SUPPORTED_TASKS[task.split("_", 1)[0]]
     if !pipeline_info
@@ -502,7 +1254,15 @@ module Informers
       next if !cls

       if cls.is_a?(Array)
-        raise Todo
+        e = nil
+        cls.each do |c|
+          begin
+            result[name] = c.from_pretrained(model, **pretrained_options)
+          rescue => err
+            e = err
+          end
+        end
+        raise e unless result[name]
       else
         result[name] = cls.from_pretrained(model, **pretrained_options)
       end