informers 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +123 -0
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -9
- data/lib/informers/models.rb +997 -12
- data/lib/informers/pipelines.rb +768 -8
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +154 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
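
The largest change is data/lib/informers/pipelines.rb, which adds pipeline classes for fill-mask, text generation, summarization, translation, zero-shot classification, and a range of image tasks. As a rough orientation before the diff below, here is a minimal usage sketch; it is not taken from the package itself, and it assumes the gem's public `Informers.pipeline` helper together with the default model registered for the task in this release.

```ruby
require "informers"

# Hypothetical usage of one of the newly registered tasks (a sketch, not from the diff).
# The task name and default model ("Xenova/vit-base-patch16-224") match the
# "image-classification" entry added to SUPPORTED_TASKS below; the call form assumes
# the gem's documented pipeline API.
classifier = Informers.pipeline("image-classification")
result = classifier.("path/to/cat.jpg") # images are read via Utils::RawImage.read in prepare_images
# result is an array of {label:, score:} hashes (top_k: 1 by default)
```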
data/lib/informers/pipelines.rb
CHANGED
@@ -7,6 +7,26 @@ module Informers
       @tokenizer = tokenizer
       @processor = processor
     end
+
+    private
+
+    def prepare_images(images)
+      if !images.is_a?(Array)
+        images = [images]
+      end
+
+      # Possibly convert any non-images to images
+      images.map { |x| Utils::RawImage.read(x) }
+    end
+
+    def get_bounding_box(box, as_integer)
+      if as_integer
+        box = box.map { |x| x.to_i }
+      end
+      xmin, ymin, xmax, ymax = box
+
+      {xmin:, ymin:, xmax:, ymax:}
+    end
   end

   class TextClassificationPipeline < Pipeline
@@ -21,13 +41,13 @@ module Informers
       outputs = @model.(model_inputs)

       function_to_apply =
-        if @model.config
+        if @model.config[:problem_type] == "multi_label_classification"
           ->(batch) { Utils.sigmoid(batch) }
         else
           ->(batch) { Utils.softmax(batch) } # single_label_classification (default)
         end

-      id2label = @model.config
+      id2label = @model.config[:id2label]

       to_return = []
       outputs.logits.each do |batch|
@@ -70,7 +90,7 @@ module Informers
       outputs = @model.(model_inputs)

       logits = outputs.logits
-      id2label = @model.config
+      id2label = @model.config[:id2label]

       to_return = []
       logits.length.times do |i|
@@ -243,6 +263,527 @@ module Informers
     end
   end

+  class FillMaskPipeline < Pipeline
+    def call(texts, top_k: 5)
+      model_inputs = @tokenizer.(texts, padding: true, truncation: true)
+      outputs = @model.(model_inputs)
+
+      to_return = []
+      model_inputs[:input_ids].each_with_index do |ids, i|
+        mask_token_index = ids.index(@tokenizer.mask_token_id)
+
+        if mask_token_index.nil?
+          raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text."
+        end
+        logits = outputs.logits[i]
+        item_logits = logits[mask_token_index]
+
+        scores = Utils.get_top_items(Utils.softmax(item_logits), top_k)
+
+        to_return <<
+          scores.map do |x|
+            sequence = ids.dup
+            sequence[mask_token_index] = x[0]
+
+            {
+              score: x[1],
+              token: x[0],
+              token_str: @tokenizer.id_to_token(x[0]),
+              sequence: @tokenizer.decode(sequence, skip_special_tokens: true)
+            }
+          end
+      end
+      texts.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
+  class Text2TextGenerationPipeline < Pipeline
+    KEY = :generated_text
+
+    def call(texts, **generate_kwargs)
+      if !texts.is_a?(Array)
+        texts = [texts]
+      end
+
+      # Add global prefix, if present
+      if @model.config[:prefix]
+        texts = texts.map { |x| @model.config[:prefix] + x }
+      end
+
+      # Handle task specific params:
+      task_specific_params = @model.config[:task_specific_params]
+      if task_specific_params && task_specific_params[@task]
+        # Add prefixes, if present
+        if task_specific_params[@task]["prefix"]
+          texts = texts.map { |x| task_specific_params[@task]["prefix"] + x }
+        end
+
+        # TODO update generation config
+      end
+
+      tokenizer = @tokenizer
+      tokenizer_options = {
+        padding: true,
+        truncation: true
+      }
+      if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)
+        input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)[:input_ids]
+      else
+        input_ids = tokenizer.(texts, **tokenizer_options)[:input_ids]
+      end
+
+      output_token_ids = @model.generate(input_ids, generate_kwargs)
+
+      tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+        .map { |text| {self.class.const_get(:KEY) => text} }
+    end
+  end
+
+  class SummarizationPipeline < Text2TextGenerationPipeline
+    KEY = :summary_text
+  end
+
+  class TranslationPipeline < Text2TextGenerationPipeline
+    KEY = :translation_text
+  end
+
+  class TextGenerationPipeline < Pipeline
+    def call(texts, **generate_kwargs)
+      is_batched = false
+      is_chat_input = false
+
+      # Normalize inputs
+      if texts.is_a?(String)
+        texts = [texts]
+        inputs = texts
+      else
+        raise Todo
+      end
+
+      # By default, do not add special tokens
+      add_special_tokens = generate_kwargs[:add_special_tokens] || false
+
+      # /By default, return full text
+      return_full_text =
+        if is_chat_input
+          false
+        else
+          generate_kwargs[:return_full_text] || true
+        end
+
+      @tokenizer.padding_side = "left"
+      input_ids, attention_mask =
+        @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
+          .values_at(:input_ids, :attention_mask)
+
+      output_token_ids =
+        @model.generate(
+          input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
+        )
+
+      decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+
+      if !return_full_text && Utils.dims(input_ids)[-1] > 0
+        prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }
+      end
+
+      to_return = Array.new(texts.length) { [] }
+      decoded.length.times do |i|
+        text_index = (i / output_token_ids.length.to_i * texts.length).floor
+
+        if prompt_lengths
+          raise Todo
+        end
+        # TODO is_chat_input
+        to_return[text_index] << {
+          generated_text: decoded[i]
+        }
+      end
+      !is_batched && to_return.length == 1 ? to_return[0] : to_return
+    end
+  end
+
+  class ZeroShotClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @label2id = @model.config[:label2id].transform_keys(&:downcase)
+
+      @entailment_id = @label2id["entailment"]
+      if @entailment_id.nil?
+        warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
+        @entailment_id = 2
+      end
+
+      @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
+      if @contradiction_id.nil?
+        warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
+        @contradiction_id = 0
+      end
+    end
+
+    def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
+      is_batched = texts.is_a?(Array)
+      if !is_batched
+        texts = [texts]
+      end
+      if !candidate_labels.is_a?(Array)
+        candidate_labels = [candidate_labels]
+      end
+
+      # Insert labels into hypothesis template
+      hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # How to perform the softmax over the logits:
+      # - true: softmax over the entailment vs. contradiction dim for each label independently
+      # - false: softmax the "entailment" logits over all candidate labels
+      softmax_each = multi_label || candidate_labels.length == 1
+
+      to_return = []
+      texts.each do |premise|
+        entails_logits = []
+
+        hypotheses.each do |hypothesis|
+          inputs = @tokenizer.(
+            premise,
+            text_pair: hypothesis,
+            padding: true,
+            truncation: true
+          )
+          outputs = @model.(inputs)
+
+          if softmax_each
+            entails_logits << [
+              outputs.logits[0][@contradiction_id],
+              outputs.logits[0][@entailment_id]
+            ]
+          else
+            entails_logits << outputs.logits[0][@entailment_id]
+          end
+        end
+
+        scores =
+          if softmax_each
+            entails_logits.map { |x| Utils.softmax(x)[1] }
+          else
+            Utils.softmax(entails_logits)
+          end
+
+        # Sort by scores (desc) and return scores with indices
+        scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }
+
+        to_return << {
+          sequence: premise,
+          labels: scores_sorted.map { |x| candidate_labels[x[1]] },
+          scores: scores_sorted.map { |x| x[0] }
+        }
+      end
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ImageToTextPipeline < Pipeline
+    def call(images, **generate_kwargs)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      to_return = []
+      pixel_values.each do |batch|
+        batch = [batch]
+        output = @model.generate(batch, **generate_kwargs)
+        decoded = @tokenizer
+          .batch_decode(output, skip_special_tokens: true)
+          .map { |x| {generated_text: x.strip} }
+        to_return << decoded
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ImageClassificationPipeline < Pipeline
+    def call(images, top_k: 1)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      output = @model.({pixel_values: pixel_values})
+
+      id2label = @model.config[:id2label]
+      to_return = []
+      output.logits.each do |batch|
+        scores = Utils.get_top_items(Utils.softmax(batch), top_k)
+
+        vals =
+          scores.map do |x|
+            {
+              label: id2label[x[0].to_s],
+              score: x[1]
+            }
+          end
+        if top_k == 1
+          to_return.push(*vals)
+        else
+          to_return << vals
+        end
+      end
+
+      is_batched || top_k == 1 ? to_return : to_return[0]
+    end
+  end
+
+  class ImageSegmentationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @subtasks_mapping = {
+        "panoptic" => "post_process_panoptic_segmentation",
+        "instance" => "post_process_instance_segmentation",
+        "semantic" => "post_process_semantic_segmentation"
+      }
+    end
+
+    def call(
+      images,
+      threshold: 0.5,
+      mask_threshold: 0.5,
+      overlap_mask_area_threshold: 0.8,
+      label_ids_to_fuse: nil,
+      target_sizes: nil,
+      subtask: nil
+    )
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Image segmentation pipeline currently only supports a batch size of 1."
+      end
+
+      prepared_images = prepare_images(images)
+      image_sizes = prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      if !subtask.nil?
+        fn = @subtasks_mapping[subtask]
+      else
+        @subtasks_mapping.each do |task, func|
+          if @processor.feature_extractor.respond_to?(func)
+            fn = @processor.feature_extractor.method(func)
+            subtask = task
+            break
+          end
+        end
+      end
+
+      id2label = @model.config[:id2label]
+
+      annotation = []
+      if subtask == "panoptic" || subtask == "instance"
+        processed = fn.(
+          output,
+          threshold:,
+          mask_threshold:,
+          overlap_mask_area_threshold:,
+          label_ids_to_fuse:,
+          target_sizes: target_sizes || image_sizes, # TODO FIX?
+        )[0]
+
+        _segmentation = processed[:segmentation]
+
+        processed[:segments_info].each do |segment|
+          annotation << {
+            label: id2label[segment[:label_id].to_s],
+            score: segment[:score]
+            # TODO mask
+          }
+        end
+      elsif subtask == "semantic"
+        raise Todo
+      else
+        raise Error, "Subtask #{subtask} not supported."
+      end
+
+      annotation
+    end
+  end
+
+  class ZeroShotImageClassificationPipeline < Pipeline
+    def call(images, candidate_labels, hypothesis_template: "This is a photo of {}")
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Insert label into hypothesis template
+      texts = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # Run tokenization
+      text_inputs = @tokenizer.(texts,
+        padding: @model.config[:model_type] == "siglip" ? "max_length" : true,
+        truncation: true
+      )
+
+      # Run processor
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      # Run model with both text and pixel inputs
+      output = @model.(text_inputs.merge(pixel_values: pixel_values))
+
+      function_to_apply =
+        if @model.config[:model_type] == "siglip"
+          ->(batch) { Utils.sigmoid(batch) }
+        else
+          ->(batch) { Utils.softmax(batch) }
+        end
+
+      # Compare each image with each candidate label
+      to_return = []
+      output[0].each do |batch|
+        # Compute softmax per image
+        probs = function_to_apply.(batch)
+
+        result = probs
+          .map.with_index { |x, i| {label: candidate_labels[i], score: x} }
+          .sort_by { |v| -v[:score] }
+
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ObjectDetectionPipeline < Pipeline
+    def call(images, threshold: 0.9, percentage: false)
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Object detection pipeline currently only supports a batch size of 1."
+      end
+      prepared_images = prepare_images(images)
+
+      image_sizes = percentage ? nil : prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)
+
+      # Add labels
+      id2label = @model.config[:id2label]
+
+      # Format output
+      result =
+        processed.map do |batch|
+          batch[:boxes].map.with_index do |box, i|
+            {
+              label: id2label[batch[:classes][i].to_s],
+              score: batch[:scores][i],
+              box: get_bounding_box(box, !percentage)
+            }
+          end.sort_by { |v| -v[:score] }
+        end
+
+      is_batched ? result : result[0]
+    end
+  end
+
+  class ZeroShotObjectDetectionPipeline < Pipeline
+    def call(
+      images,
+      candidate_labels,
+      threshold: 0.1,
+      top_k: nil,
+      percentage: false
+    )
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Run tokenization
+      text_inputs = @tokenizer.(candidate_labels,
+        padding: true,
+        truncation: true
+      )
+
+      # Run processor
+      model_inputs = @processor.(prepared_images)
+
+      # Since non-maximum suppression is performed for exporting, we need to
+      # process each image separately. For more information, see:
+      # https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+      to_return = []
+      prepared_images.length.times do |i|
+        image = prepared_images[i]
+        image_size = percentage ? nil : [[image.height, image.width]]
+        pixel_values = [model_inputs[:pixel_values][i]]
+
+        # Run model with both text and pixel inputs
+        output = @model.(text_inputs.merge(pixel_values: pixel_values))
+        # TODO remove
+        output = @model.instance_variable_get(:@session).outputs.map { |v| v[:name].to_sym }.zip(output).to_h
+
+        processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_size, true)[0]
+        result =
+          processed[:boxes].map.with_index do |box, i|
+            {
+              label: candidate_labels[processed[:classes][i]],
+              score: processed[:scores][i],
+              box: get_bounding_box(box, !percentage),
+            }
+          end
+        result.sort_by! { |v| -v[:score] }
+        if !top_k.nil?
+          result = result[0...topk]
+        end
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class DocumentQuestionAnsweringPipeline < Pipeline
+    def call(image, question, **generate_kwargs)
+      # NOTE: For now, we only support a batch size of 1
+
+      # Preprocess image
+      prepared_image = prepare_images(image)[0]
+      pixel_values = @processor.(prepared_image)[:pixel_values]
+
+      # Run tokenization
+      task_prompt = "<s_docvqa><s_question>#{question}</s_question><s_answer>"
+      decoder_input_ids =
+        @tokenizer.(
+          task_prompt,
+          add_special_tokens: false,
+          padding: true,
+          truncation: true
+        )[:input_ids]
+
+      # Run model
+      output =
+        @model.generate(
+          pixel_values,
+          generate_kwargs.merge(
+            decoder_input_ids: decoder_input_ids[0],
+            max_length: @model.config["decoder"]["max_position_embeddings"]
+          ).transform_keys(&:to_s)
+        )
+
+      # Decode output
+      decoded = @tokenizer.batch_decode(output, skip_special_tokens: false)[0]
+
+      # Parse answer
+      match = decoded.match(/<s_answer>(.*?)<\/s_answer>/)
+      answer = nil
+      if match && match.length >= 2
+        answer = match[1].strip
+      end
+      [{answer:}]
+    end
+  end
+
   class FeatureExtractionPipeline < Pipeline
     def call(
       texts,
@@ -306,6 +847,69 @@ module Informers
     end
   end

+  class ImageFeatureExtractionPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      outputs = @model.({pixel_values: pixel_values})
+
+      result = outputs[0]
+      result
+    end
+  end
+
+  class ImageToImagePipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      inputs = @processor.(prepared_images)
+      outputs = @model.(inputs);
+
+      to_return = []
+      outputs[0].each do |batch|
+        # TODO flatten first
+        output =
+          batch.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3.clamp(0, 1) * 255).round
+              end
+            end
+          end
+        to_return << Utils::RawImage.from_array(output).image
+      end
+
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
+  class DepthEstimationPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+
+      inputs = @processor.(prepared_images)
+      predicted_depth = @model.(inputs)[0]
+
+      to_return = []
+      prepared_images.length.times do |i|
+        prediction = Utils.interpolate(predicted_depth[i], prepared_images[i].size.reverse, "bilinear", false)
+        max_prediction = Utils.max(prediction.flatten)[0]
+        formatted =
+          prediction.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3 * 255 / max_prediction).round
+              end
+            end
+          end
+        to_return << {
+          predicted_depth: predicted_depth[i],
+          depth: Utils::RawImage.from_array(formatted).image
+        }
+      end
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
   class EmbeddingPipeline < FeatureExtractionPipeline
     def call(
       texts,
@@ -375,6 +979,145 @@ module Informers
       },
       type: "text"
     },
+    "fill-mask" => {
+      tokenizer: AutoTokenizer,
+      pipeline: FillMaskPipeline,
+      model: AutoModelForMaskedLM,
+      default: {
+        model: "Xenova/bert-base-uncased"
+      },
+      type: "text"
+    },
+    "summarization" => {
+      tokenizer: AutoTokenizer,
+      pipeline: SummarizationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/distilbart-cnn-6-6"
+      },
+      type: "text"
+    },
+    "translation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TranslationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/t5-small"
+      },
+      type: "text"
+    },
+    "text2text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: Text2TextGenerationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/flan-t5-small"
+      },
+      type: "text"
+    },
+    "text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TextGenerationPipeline,
+      model: AutoModelForCausalLM,
+      default: {
+        model: "Xenova/gpt2"
+      },
+      type: "text"
+    },
+    "zero-shot-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotClassificationPipeline,
+      model: AutoModelForSequenceClassification,
+      default: {
+        model: "Xenova/distilbert-base-uncased-mnli"
+      },
+      type: "text"
+    },
+    "image-to-text" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ImageToTextPipeline,
+      model: AutoModelForVision2Seq,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-gpt2-image-captioning"
+      },
+      type: "multimodal"
+    },
+    "image-classification" => {
+      pipeline: ImageClassificationPipeline,
+      model: AutoModelForImageClassification,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-base-patch16-224",
+      },
+      type: "multimodal"
+    },
+    "image-segmentation" => {
+      pipeline: ImageSegmentationPipeline,
+      model: [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50-panoptic",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-image-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotImageClassificationPipeline,
+      model: AutoModel,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/clip-vit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "object-detection" => {
+      pipeline: ObjectDetectionPipeline,
+      model: AutoModelForObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-object-detection" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotObjectDetectionPipeline,
+      model: AutoModelForZeroShotObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/owlvit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "document-question-answering" => {
+      tokenizer: AutoTokenizer,
+      pipeline: DocumentQuestionAnsweringPipeline,
+      model: AutoModelForDocumentQuestionAnswering,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/donut-base-finetuned-docvqa"
+      },
+      type: "multimodal"
+    },
+    "image-to-image" => {
+      pipeline: ImageToImagePipeline,
+      model: AutoModelForImageToImage,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/swin2SR-classical-sr-x2-64"
+      },
+      type: "image"
+    },
+    "depth-estimation" => {
+      pipeline: DepthEstimationPipeline,
+      model: AutoModelForDepthEstimation,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/dpt-large"
+      },
+      type: "image"
+    },
     "feature-extraction" => {
       tokenizer: AutoTokenizer,
       pipeline: FeatureExtractionPipeline,
@@ -384,6 +1127,15 @@ module Informers
       },
       type: "text"
     },
+    "image-feature-extraction" => {
+      processor: AutoProcessor,
+      pipeline: ImageFeatureExtractionPipeline,
+      model: [AutoModelForImageFeatureExtraction, AutoModel],
+      default: {
+        model: "Xenova/vit-base-patch16-224"
+      },
+      type: "image"
+    },
     "embedding" => {
       tokenizer: AutoTokenizer,
       pipeline: EmbeddingPipeline,
@@ -439,14 +1191,14 @@ module Informers
     revision: "main",
     model_file_name: nil
   )
+    # Apply aliases
+    task = TASK_ALIASES[task] || task
+
     if quantized == NO_DEFAULT
       # TODO move default to task class
-      quantized =
+      quantized = ["text-classification", "token-classification", "question-answering", "feature-extraction"].include?(task)
     end

-    # Apply aliases
-    task = TASK_ALIASES[task] || task
-
     # Get pipeline info
     pipeline_info = SUPPORTED_TASKS[task.split("_", 1)[0]]
     if !pipeline_info
@@ -502,7 +1254,15 @@ module Informers
       next if !cls

       if cls.is_a?(Array)
-
+        e = nil
+        cls.each do |c|
+          begin
+            result[name] = c.from_pretrained(model, **pretrained_options)
+          rescue => err
+            e = err
+          end
+        end
+        raise e unless result[name]
       else
         result[name] = cls.from_pretrained(model, **pretrained_options)
       end
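
For comparison, a similarly hedged sketch for the new fill-mask task registered above, again assuming the public `Informers.pipeline` helper; the default model and the result keys come from the FillMaskPipeline class and the SUPPORTED_TASKS entry in this diff.

```ruby
require "informers"

# Hypothetical fill-mask usage (sketch, not from the package). The default model
# "Xenova/bert-base-uncased" and the result keys mirror FillMaskPipeline#call above.
unmasker = Informers.pipeline("fill-mask")
predictions = unmasker.("Paris is the [MASK] of France.", top_k: 5)
# Each prediction: {score:, token:, token_str:, sequence:}
```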