informers 1.0.2 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +213 -19
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -14
- data/lib/informers/models.rb +1027 -13
- data/lib/informers/pipelines.rb +781 -14
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +166 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
data/lib/informers/pipelines.rb
CHANGED
@@ -7,6 +7,26 @@ module Informers
       @tokenizer = tokenizer
       @processor = processor
     end
+
+    private
+
+    def prepare_images(images)
+      if !images.is_a?(Array)
+        images = [images]
+      end
+
+      # Possibly convert any non-images to images
+      images.map { |x| Utils::RawImage.read(x) }
+    end
+
+    def get_bounding_box(box, as_integer)
+      if as_integer
+        box = box.map { |x| x.to_i }
+      end
+      xmin, ymin, xmax, ymax = box
+
+      {xmin:, ymin:, xmax:, ymax:}
+    end
   end

   class TextClassificationPipeline < Pipeline
@@ -21,13 +41,13 @@ module Informers
       outputs = @model.(model_inputs)

       function_to_apply =
-        if @model.config
+        if @model.config[:problem_type] == "multi_label_classification"
           ->(batch) { Utils.sigmoid(batch) }
         else
           ->(batch) { Utils.softmax(batch) } # single_label_classification (default)
         end

-      id2label = @model.config
+      id2label = @model.config[:id2label]

       to_return = []
       outputs.logits.each do |batch|
@@ -70,7 +90,7 @@ module Informers
       outputs = @model.(model_inputs)

       logits = outputs.logits
-      id2label = @model.config
+      id2label = @model.config[:id2label]

       to_return = []
       logits.length.times do |i|
@@ -243,13 +263,535 @@ module Informers
     end
   end

+  class FillMaskPipeline < Pipeline
+    def call(texts, top_k: 5)
+      model_inputs = @tokenizer.(texts, padding: true, truncation: true)
+      outputs = @model.(model_inputs)
+
+      to_return = []
+      model_inputs[:input_ids].each_with_index do |ids, i|
+        mask_token_index = ids.index(@tokenizer.mask_token_id)
+
+        if mask_token_index.nil?
+          raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text."
+        end
+        logits = outputs.logits[i]
+        item_logits = logits[mask_token_index]
+
+        scores = Utils.get_top_items(Utils.softmax(item_logits), top_k)
+
+        to_return <<
+          scores.map do |x|
+            sequence = ids.dup
+            sequence[mask_token_index] = x[0]
+
+            {
+              score: x[1],
+              token: x[0],
+              token_str: @tokenizer.id_to_token(x[0]),
+              sequence: @tokenizer.decode(sequence, skip_special_tokens: true)
+            }
+          end
+      end
+      texts.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
+  class Text2TextGenerationPipeline < Pipeline
+    KEY = :generated_text
+
+    def call(texts, **generate_kwargs)
+      if !texts.is_a?(Array)
+        texts = [texts]
+      end
+
+      # Add global prefix, if present
+      if @model.config[:prefix]
+        texts = texts.map { |x| @model.config[:prefix] + x }
+      end
+
+      # Handle task specific params:
+      task_specific_params = @model.config[:task_specific_params]
+      if task_specific_params && task_specific_params[@task]
+        # Add prefixes, if present
+        if task_specific_params[@task]["prefix"]
+          texts = texts.map { |x| task_specific_params[@task]["prefix"] + x }
+        end
+
+        # TODO update generation config
+      end
+
+      tokenizer = @tokenizer
+      tokenizer_options = {
+        padding: true,
+        truncation: true
+      }
+      if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)
+        input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)[:input_ids]
+      else
+        input_ids = tokenizer.(texts, **tokenizer_options)[:input_ids]
+      end
+
+      output_token_ids = @model.generate(input_ids, generate_kwargs)
+
+      tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+        .map { |text| {self.class.const_get(:KEY) => text} }
+    end
+  end
+
+  class SummarizationPipeline < Text2TextGenerationPipeline
+    KEY = :summary_text
+  end
+
+  class TranslationPipeline < Text2TextGenerationPipeline
+    KEY = :translation_text
+  end
+
+  class TextGenerationPipeline < Pipeline
+    def call(texts, **generate_kwargs)
+      is_batched = false
+      is_chat_input = false
+
+      # Normalize inputs
+      if texts.is_a?(String)
+        texts = [texts]
+        inputs = texts
+      else
+        raise Todo
+      end
+
+      # By default, do not add special tokens
+      add_special_tokens = generate_kwargs[:add_special_tokens] || false
+
+      # /By default, return full text
+      return_full_text =
+        if is_chat_input
+          false
+        else
+          generate_kwargs[:return_full_text] || true
+        end
+
+      @tokenizer.padding_side = "left"
+      input_ids, attention_mask =
+        @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
+          .values_at(:input_ids, :attention_mask)
+
+      output_token_ids =
+        @model.generate(
+          input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
+        )
+
+      decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)
+
+      if !return_full_text && Utils.dims(input_ids)[-1] > 0
+        prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }
+      end
+
+      to_return = Array.new(texts.length) { [] }
+      decoded.length.times do |i|
+        text_index = (i / output_token_ids.length.to_i * texts.length).floor
+
+        if prompt_lengths
+          raise Todo
+        end
+        # TODO is_chat_input
+        to_return[text_index] << {
+          generated_text: decoded[i]
+        }
+      end
+      !is_batched && to_return.length == 1 ? to_return[0] : to_return
+    end
+  end
+
+  class ZeroShotClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @label2id = @model.config[:label2id].transform_keys(&:downcase)
+
+      @entailment_id = @label2id["entailment"]
+      if @entailment_id.nil?
+        warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
+        @entailment_id = 2
+      end
+
+      @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
+      if @contradiction_id.nil?
+        warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
+        @contradiction_id = 0
+      end
+    end
+
+    def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
+      is_batched = texts.is_a?(Array)
+      if !is_batched
+        texts = [texts]
+      end
+      if !candidate_labels.is_a?(Array)
+        candidate_labels = [candidate_labels]
+      end
+
+      # Insert labels into hypothesis template
+      hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # How to perform the softmax over the logits:
+      # - true: softmax over the entailment vs. contradiction dim for each label independently
+      # - false: softmax the "entailment" logits over all candidate labels
+      softmax_each = multi_label || candidate_labels.length == 1
+
+      to_return = []
+      texts.each do |premise|
+        entails_logits = []
+
+        hypotheses.each do |hypothesis|
+          inputs = @tokenizer.(
+            premise,
+            text_pair: hypothesis,
+            padding: true,
+            truncation: true
+          )
+          outputs = @model.(inputs)
+
+          if softmax_each
+            entails_logits << [
+              outputs.logits[0][@contradiction_id],
+              outputs.logits[0][@entailment_id]
+            ]
+          else
+            entails_logits << outputs.logits[0][@entailment_id]
+          end
+        end
+
+        scores =
+          if softmax_each
+            entails_logits.map { |x| Utils.softmax(x)[1] }
+          else
+            Utils.softmax(entails_logits)
+          end
+
+        # Sort by scores (desc) and return scores with indices
+        scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }
+
+        to_return << {
+          sequence: premise,
+          labels: scores_sorted.map { |x| candidate_labels[x[1]] },
+          scores: scores_sorted.map { |x| x[0] }
+        }
+      end
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ImageToTextPipeline < Pipeline
+    def call(images, **generate_kwargs)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      to_return = []
+      pixel_values.each do |batch|
+        batch = [batch]
+        output = @model.generate(batch, **generate_kwargs)
+        decoded = @tokenizer
+          .batch_decode(output, skip_special_tokens: true)
+          .map { |x| {generated_text: x.strip} }
+        to_return << decoded
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ImageClassificationPipeline < Pipeline
+    def call(images, top_k: 1)
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      output = @model.({pixel_values: pixel_values})
+
+      id2label = @model.config[:id2label]
+      to_return = []
+      output.logits.each do |batch|
+        scores = Utils.get_top_items(Utils.softmax(batch), top_k)
+
+        vals =
+          scores.map do |x|
+            {
+              label: id2label[x[0].to_s],
+              score: x[1]
+            }
+          end
+        if top_k == 1
+          to_return.push(*vals)
+        else
+          to_return << vals
+        end
+      end
+
+      is_batched || top_k == 1 ? to_return : to_return[0]
+    end
+  end
+
+  class ImageSegmentationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+
+      @subtasks_mapping = {
+        "panoptic" => "post_process_panoptic_segmentation",
+        "instance" => "post_process_instance_segmentation",
+        "semantic" => "post_process_semantic_segmentation"
+      }
+    end
+
+    def call(
+      images,
+      threshold: 0.5,
+      mask_threshold: 0.5,
+      overlap_mask_area_threshold: 0.8,
+      label_ids_to_fuse: nil,
+      target_sizes: nil,
+      subtask: nil
+    )
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Image segmentation pipeline currently only supports a batch size of 1."
+      end
+
+      prepared_images = prepare_images(images)
+      image_sizes = prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      if !subtask.nil?
+        fn = @subtasks_mapping[subtask]
+      else
+        @subtasks_mapping.each do |task, func|
+          if @processor.feature_extractor.respond_to?(func)
+            fn = @processor.feature_extractor.method(func)
+            subtask = task
+            break
+          end
+        end
+      end
+
+      id2label = @model.config[:id2label]
+
+      annotation = []
+      if subtask == "panoptic" || subtask == "instance"
+        processed = fn.(
+          output,
+          threshold:,
+          mask_threshold:,
+          overlap_mask_area_threshold:,
+          label_ids_to_fuse:,
+          target_sizes: target_sizes || image_sizes, # TODO FIX?
+        )[0]
+
+        _segmentation = processed[:segmentation]
+
+        processed[:segments_info].each do |segment|
+          annotation << {
+            label: id2label[segment[:label_id].to_s],
+            score: segment[:score]
+            # TODO mask
+          }
+        end
+      elsif subtask == "semantic"
+        raise Todo
+      else
+        raise Error, "Subtask #{subtask} not supported."
+      end
+
+      annotation
+    end
+  end
+
+  class ZeroShotImageClassificationPipeline < Pipeline
+    def call(images, candidate_labels, hypothesis_template: "This is a photo of {}")
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Insert label into hypothesis template
+      texts = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }
+
+      # Run tokenization
+      text_inputs = @tokenizer.(texts,
+        padding: @model.config[:model_type] == "siglip" ? "max_length" : true,
+        truncation: true
+      )
+
+      # Run processor
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+
+      # Run model with both text and pixel inputs
+      output = @model.(text_inputs.merge(pixel_values: pixel_values))
+
+      function_to_apply =
+        if @model.config[:model_type] == "siglip"
+          ->(batch) { Utils.sigmoid(batch) }
+        else
+          ->(batch) { Utils.softmax(batch) }
+        end
+
+      # Compare each image with each candidate label
+      to_return = []
+      output[0].each do |batch|
+        # Compute softmax per image
+        probs = function_to_apply.(batch)
+
+        result = probs
+          .map.with_index { |x, i| {label: candidate_labels[i], score: x} }
+          .sort_by { |v| -v[:score] }
+
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class ObjectDetectionPipeline < Pipeline
+    def call(images, threshold: 0.9, percentage: false)
+      is_batched = images.is_a?(Array)
+
+      if is_batched && images.length != 1
+        raise Error, "Object detection pipeline currently only supports a batch size of 1."
+      end
+      prepared_images = prepare_images(images)
+
+      image_sizes = percentage ? nil : prepared_images.map { |x| [x.height, x.width] }
+
+      model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
+      output = @model.(model_inputs)
+
+      processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)
+
+      # Add labels
+      id2label = @model.config[:id2label]
+
+      # Format output
+      result =
+        processed.map do |batch|
+          batch[:boxes].map.with_index do |box, i|
+            {
+              label: id2label[batch[:classes][i].to_s],
+              score: batch[:scores][i],
+              box: get_bounding_box(box, !percentage)
+            }
+          end.sort_by { |v| -v[:score] }
+        end
+
+      is_batched ? result : result[0]
+    end
+  end
+
+  class ZeroShotObjectDetectionPipeline < Pipeline
+    def call(
+      images,
+      candidate_labels,
+      threshold: 0.1,
+      top_k: nil,
+      percentage: false
+    )
+      is_batched = images.is_a?(Array)
+      prepared_images = prepare_images(images)
+
+      # Run tokenization
+      text_inputs = @tokenizer.(candidate_labels,
+        padding: true,
+        truncation: true
+      )
+
+      # Run processor
+      model_inputs = @processor.(prepared_images)
+
+      # Since non-maximum suppression is performed for exporting, we need to
+      # process each image separately. For more information, see:
+      # https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+      to_return = []
+      prepared_images.length.times do |i|
+        image = prepared_images[i]
+        image_size = percentage ? nil : [[image.height, image.width]]
+        pixel_values = [model_inputs[:pixel_values][i]]
+
+        # Run model with both text and pixel inputs
+        output = @model.(text_inputs.merge(pixel_values: pixel_values))
+        # TODO remove
+        output = @model.instance_variable_get(:@session).outputs.map { |v| v[:name].to_sym }.zip(output).to_h
+
+        processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_size, true)[0]
+        result =
+          processed[:boxes].map.with_index do |box, i|
+            {
+              label: candidate_labels[processed[:classes][i]],
+              score: processed[:scores][i],
+              box: get_bounding_box(box, !percentage),
+            }
+          end
+        result.sort_by! { |v| -v[:score] }
+        if !top_k.nil?
+          result = result[0...topk]
+        end
+        to_return << result
+      end
+
+      is_batched ? to_return : to_return[0]
+    end
+  end
+
+  class DocumentQuestionAnsweringPipeline < Pipeline
+    def call(image, question, **generate_kwargs)
+      # NOTE: For now, we only support a batch size of 1
+
+      # Preprocess image
+      prepared_image = prepare_images(image)[0]
+      pixel_values = @processor.(prepared_image)[:pixel_values]
+
+      # Run tokenization
+      task_prompt = "<s_docvqa><s_question>#{question}</s_question><s_answer>"
+      decoder_input_ids =
+        @tokenizer.(
+          task_prompt,
+          add_special_tokens: false,
+          padding: true,
+          truncation: true
+        )[:input_ids]
+
+      # Run model
+      output =
+        @model.generate(
+          pixel_values,
+          generate_kwargs.merge(
+            decoder_input_ids: decoder_input_ids[0],
+            max_length: @model.config["decoder"]["max_position_embeddings"]
+          ).transform_keys(&:to_s)
+        )
+
+      # Decode output
+      decoded = @tokenizer.batch_decode(output, skip_special_tokens: false)[0]
+
+      # Parse answer
+      match = decoded.match(/<s_answer>(.*?)<\/s_answer>/)
+      answer = nil
+      if match && match.length >= 2
+        answer = match[1].strip
+      end
+      [{answer:}]
+    end
+  end
+
   class FeatureExtractionPipeline < Pipeline
     def call(
       texts,
       pooling: "none",
       normalize: false,
       quantize: false,
-      precision: "binary"
+      precision: "binary",
+      model_output: nil
     )
       # Run tokenization
       model_inputs = @tokenizer.(texts,
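The pipeline classes added above are registered with default models in the SUPPORTED_TASKS additions later in this diff. A minimal usage sketch (not part of the package diff), assuming the existing Informers.pipeline entry point and the call signatures shown above; the input strings and labels are illustrative:

    require "informers"

    # fill-mask (default model per this diff: Xenova/bert-base-uncased)
    unmasker = Informers.pipeline("fill-mask")
    unmasker.("Paris is the [MASK] of France.", top_k: 3)

    # zero-shot-classification (default model per this diff: Xenova/distilbert-base-uncased-mnli)
    classifier = Informers.pipeline("zero-shot-classification")
    classifier.("I love Ruby", ["programming", "sports", "cooking"])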
@@ -258,8 +800,10 @@ module Informers
       )
       model_options = {}

-
-
+      if !model_output.nil?
+        model_options[:output_names] = Array(model_output)
+      elsif @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
+        # optimization for sentence-transformers/all-MiniLM-L6-v2
         model_options[:output_names] = ["sentence_embedding"]
         pooling = "none"
         normalize = false
@@ -271,7 +815,9 @@ module Informers
       # TODO improve
       result =
         if outputs.is_a?(Array)
-
+          # TODO show returned instead of all
+          output_names = @model.instance_variable_get(:@session).outputs.map { |v| v[:name] }
+          raise Error, "unexpected outputs: #{output_names}" if outputs.size != 1
           outputs[0]
         else
           outputs.logits
@@ -285,6 +831,7 @@ module Informers
       when "cls"
         result = result.map(&:first)
       else
+        # TODO raise ArgumentError in 2.0
         raise Error, "Pooling method '#{pooling}' not supported."
       end

@@ -300,13 +847,77 @@ module Informers
     end
   end

+  class ImageFeatureExtractionPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      pixel_values = @processor.(prepared_images)[:pixel_values]
+      outputs = @model.({pixel_values: pixel_values})
+
+      result = outputs[0]
+      result
+    end
+  end
+
+  class ImageToImagePipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+      inputs = @processor.(prepared_images)
+      outputs = @model.(inputs);
+
+      to_return = []
+      outputs[0].each do |batch|
+        # TODO flatten first
+        output =
+          batch.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3.clamp(0, 1) * 255).round
+              end
+            end
+          end
+        to_return << Utils::RawImage.from_array(output).image
+      end
+
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
+  class DepthEstimationPipeline < Pipeline
+    def call(images)
+      prepared_images = prepare_images(images)
+
+      inputs = @processor.(prepared_images)
+      predicted_depth = @model.(inputs)[0]
+
+      to_return = []
+      prepared_images.length.times do |i|
+        prediction = Utils.interpolate(predicted_depth[i], prepared_images[i].size.reverse, "bilinear", false)
+        max_prediction = Utils.max(prediction.flatten)[0]
+        formatted =
+          prediction.map do |v|
+            v.map do |v2|
+              v2.map do |v3|
+                (v3 * 255 / max_prediction).round
+              end
+            end
+          end
+        to_return << {
+          predicted_depth: predicted_depth[i],
+          depth: Utils::RawImage.from_array(formatted).image
+        }
+      end
+      to_return.length > 1 ? to_return : to_return[0]
+    end
+  end
+
   class EmbeddingPipeline < FeatureExtractionPipeline
     def call(
       texts,
       pooling: "mean",
-      normalize: true
+      normalize: true,
+      model_output: nil
     )
-      super(texts, pooling:, normalize:)
+      super(texts, pooling:, normalize:, model_output:)
     end
   end

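Both FeatureExtractionPipeline (earlier hunks) and EmbeddingPipeline (above) gain a model_output keyword that is forwarded as the ONNX output name(s) to fetch via model_options[:output_names]. A hedged sketch of how it might be used; the output name "token_embeddings" is illustrative and depends on how the underlying model was exported:

    require "informers"

    embed = Informers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
    # request a specific ONNX output by name (assumed to exist in the exported model)
    embeddings = embed.(["Hello world"], model_output: "token_embeddings")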
@@ -368,6 +979,145 @@ module Informers
       },
       type: "text"
     },
+    "fill-mask" => {
+      tokenizer: AutoTokenizer,
+      pipeline: FillMaskPipeline,
+      model: AutoModelForMaskedLM,
+      default: {
+        model: "Xenova/bert-base-uncased"
+      },
+      type: "text"
+    },
+    "summarization" => {
+      tokenizer: AutoTokenizer,
+      pipeline: SummarizationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/distilbart-cnn-6-6"
+      },
+      type: "text"
+    },
+    "translation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TranslationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/t5-small"
+      },
+      type: "text"
+    },
+    "text2text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: Text2TextGenerationPipeline,
+      model: AutoModelForSeq2SeqLM,
+      default: {
+        model: "Xenova/flan-t5-small"
+      },
+      type: "text"
+    },
+    "text-generation" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TextGenerationPipeline,
+      model: AutoModelForCausalLM,
+      default: {
+        model: "Xenova/gpt2"
+      },
+      type: "text"
+    },
+    "zero-shot-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotClassificationPipeline,
+      model: AutoModelForSequenceClassification,
+      default: {
+        model: "Xenova/distilbert-base-uncased-mnli"
+      },
+      type: "text"
+    },
+    "image-to-text" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ImageToTextPipeline,
+      model: AutoModelForVision2Seq,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-gpt2-image-captioning"
+      },
+      type: "multimodal"
+    },
+    "image-classification" => {
+      pipeline: ImageClassificationPipeline,
+      model: AutoModelForImageClassification,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/vit-base-patch16-224",
+      },
+      type: "multimodal"
+    },
+    "image-segmentation" => {
+      pipeline: ImageSegmentationPipeline,
+      model: [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50-panoptic",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-image-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotImageClassificationPipeline,
+      model: AutoModel,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/clip-vit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "object-detection" => {
+      pipeline: ObjectDetectionPipeline,
+      model: AutoModelForObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/detr-resnet-50",
+      },
+      type: "multimodal"
+    },
+    "zero-shot-object-detection" => {
+      tokenizer: AutoTokenizer,
+      pipeline: ZeroShotObjectDetectionPipeline,
+      model: AutoModelForZeroShotObjectDetection,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/owlvit-base-patch32"
+      },
+      type: "multimodal"
+    },
+    "document-question-answering" => {
+      tokenizer: AutoTokenizer,
+      pipeline: DocumentQuestionAnsweringPipeline,
+      model: AutoModelForDocumentQuestionAnswering,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/donut-base-finetuned-docvqa"
+      },
+      type: "multimodal"
+    },
+    "image-to-image" => {
+      pipeline: ImageToImagePipeline,
+      model: AutoModelForImageToImage,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/swin2SR-classical-sr-x2-64"
+      },
+      type: "image"
+    },
+    "depth-estimation" => {
+      pipeline: DepthEstimationPipeline,
+      model: AutoModelForDepthEstimation,
+      processor: AutoProcessor,
+      default: {
+        model: "Xenova/dpt-large"
+      },
+      type: "image"
+    },
     "feature-extraction" => {
       tokenizer: AutoTokenizer,
       pipeline: FeatureExtractionPipeline,
@@ -377,6 +1127,15 @@ module Informers
       },
       type: "text"
     },
+    "image-feature-extraction" => {
+      processor: AutoProcessor,
+      pipeline: ImageFeatureExtractionPipeline,
+      model: [AutoModelForImageFeatureExtraction, AutoModel],
+      default: {
+        model: "Xenova/vit-base-patch16-224"
+      },
+      type: "image"
+    },
     "embedding" => {
       tokenizer: AutoTokenizer,
       pipeline: EmbeddingPipeline,
@@ -432,14 +1191,14 @@ module Informers
     revision: "main",
     model_file_name: nil
   )
+    # Apply aliases
+    task = TASK_ALIASES[task] || task
+
     if quantized == NO_DEFAULT
       # TODO move default to task class
-      quantized =
+      quantized = ["text-classification", "token-classification", "question-answering", "feature-extraction"].include?(task)
     end

-    # Apply aliases
-    task = TASK_ALIASES[task] || task
-
     # Get pipeline info
     pipeline_info = SUPPORTED_TASKS[task.split("_", 1)[0]]
     if !pipeline_info
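With this change, quantized no longer defaults to true across the board; it is enabled by default only for the four tasks listed in the new line above. A small sketch (assuming quantized: remains a keyword argument of Informers.pipeline) of opting back in for another task:

    # embedding is not in the default-quantized list, so request quantized weights explicitly
    embed = Informers.pipeline("embedding", quantized: true)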
@@ -495,7 +1254,15 @@ module Informers
       next if !cls

       if cls.is_a?(Array)
-
+        e = nil
+        cls.each do |c|
+          begin
+            result[name] = c.from_pretrained(model, **pretrained_options)
+          rescue => err
+            e = err
+          end
+        end
+        raise e unless result[name]
       else
         result[name] = cls.from_pretrained(model, **pretrained_options)
       end
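Combined with the SUPPORTED_TASKS entries registered earlier in this diff, the new vision pipelines can be driven the same way as the text ones. A minimal sketch, assuming Informers.pipeline and that Utils::RawImage.read (used by prepare_images) accepts a file path; "image.jpg" is a placeholder:

    require "informers"

    classifier = Informers.pipeline("image-classification")  # default: Xenova/vit-base-patch16-224
    classifier.("image.jpg", top_k: 3)

    detector = Informers.pipeline("object-detection")        # default: Xenova/detr-resnet-50
    detector.("image.jpg", threshold: 0.9)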