informers 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,796 @@
1
+ module Informers
2
# Common base for feature extractors: it only remembers the configuration
# it was built with.
class FeatureExtractor
  # @param config [Hash] preprocessor configuration, stored as-is in @config
  def initialize(config)
    # Explicit empty parentheses so `config` is not forwarded up the chain
    super()

    @config = config
  end
end
8
+
9
+ class ImageFeatureExtractor < FeatureExtractor
10
# Copies every preprocessing option out of the config hash into instance
# variables, applying the defaults used by the rest of this class.
#
# @param config [Hash] preprocessor configuration with string keys
def initialize(config)
  super(config)

  cfg = @config

  # Normalization statistics; the short key names are accepted as a fallback
  @image_mean = cfg["image_mean"] || cfg["mean"]
  @image_std = cfg["image_std"] || cfg["std"]
  @do_normalize = cfg["do_normalize"]

  # Rescaling (enabled unless the config says otherwise)
  @do_rescale = cfg.fetch("do_rescale", true)
  @rescale_factor = cfg["rescale_factor"] || (1 / 255.0)

  # Resizing
  @resample = cfg["resample"] || 2 # 2 => bilinear
  @do_resize = cfg["do_resize"]
  @do_thumbnail = cfg["do_thumbnail"]
  @size = cfg["size"]
  @size_divisibility = cfg["size_divisibility"] || cfg["size_divisor"]

  # Cropping and color handling
  @do_center_crop = cfg["do_center_crop"]
  @crop_size = cfg["crop_size"]
  @do_convert_rgb = cfg.fetch("do_convert_rgb", true)
  @do_crop_margin = cfg["do_crop_margin"]
  @do_flip_channel_order = cfg["do_flip_channel_order"] || false

  # Padding
  @pad_size = cfg["pad_size"]
  @do_pad = cfg["do_pad"]

  # Padding was requested without a pad size: fall back to the resize size,
  # but only when that size spells out both width and height.
  size_usable = @size && !@size["width"].nil? && !@size["height"].nil?
  @pad_size = @size if @do_pad && !@pad_size && size_usable
end
42
+
43
# Shrinks `image` so that it fits inside `size`; never enlarges it.
#
# @param image [#height, #width, #resize] image to shrink
# @param size [Hash] target box with "height" and "width" keys
# @param resample [Integer] resampling mode handed to `image.resize`
# @return the original image when it already fits, otherwise the value
#   returned by `image.resize`
def thumbnail(image, size, resample = 2)
  src_h = image.height
  src_w = image.width

  # Each side is capped at the smaller of the input and the requested size
  dst_h = [src_h, size["height"]].min
  dst_w = [src_w, size["width"]].min

  # Already within bounds: hand the image back untouched
  return image if dst_h == src_h && dst_w == src_w

  # Derive the shorter side from the longer one; square inputs keep both caps
  if src_w < src_h
    dst_w = (src_w * dst_h / src_h).floor
  elsif src_h < src_w
    dst_h = (src_h * dst_w / src_w).floor
  end

  image.resize(dst_w, dst_h, resample: resample)
end
64
+
65
# Pads flat height-width-channel pixel data up to `pad_size`.
#
# @param pixel_data [Array<Numeric>] flat pixel values in HWC order
# @param img_dims [Array<Integer>] `[height, width, channels]` of `pixel_data`
# @param pad_size [Numeric, Hash] one number used for both sides, or a hash
#   holding width/height under symbol or string keys
# @param mode [String] "constant" leaves the new area at `constant_values`;
#   "symmetric" overwrites it with image pixels chosen through
#   `Utils.calculate_reflect_offset`
# @param center [Boolean] put the image in the middle of the padded area
#   instead of its top-left corner
# @param constant_values [Numeric, Array<Numeric>] fill value, or one fill
#   value per channel
# @return [Array] `[pixel_data, img_dims]`; both inputs are returned unchanged
#   when the padded size equals the image size
# @raise [Error] when `center` is combined with `mode: "symmetric"` and
#   padding is actually needed
def pad_image(
  pixel_data,
  img_dims,
  pad_size,
  mode: "constant",
  center: false,
  constant_values: 0
)
  image_height, image_width, image_channels = img_dims

  if pad_size.is_a?(Numeric)
    padded_image_width = pad_size
    padded_image_height = pad_size
  else
    padded_image_width = pad_size[:width] || pad_size["width"]
    padded_image_height = pad_size[:height] || pad_size["height"]
  end

  # Only add padding if there is a difference in size
  if padded_image_width == image_width && padded_image_height == image_height
    return [pixel_data, img_dims]
  end

  if mode == "symmetric" && center
    raise Error, "`center` padding is not supported when `mode` is set to `symmetric`."
  end

  # Pre-fill the whole buffer with the padding value. A scalar (including the
  # default 0) must be written explicitly: a bare Array.new(n) holds nil.
  padded_length = padded_image_width * padded_image_height * image_channels
  padded_pixel_data =
    if constant_values.is_a?(Array)
      # Cycle through the per-channel values
      Array.new(padded_length) { |i| constant_values[i % image_channels] }
    else
      Array.new(padded_length, constant_values)
    end

  left, top =
    if center
      [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]
    else
      [0, 0]
    end

  # Copy the original image into the padded image
  image_height.times do |i|
    a = (i + top) * padded_image_width
    b = i * image_width
    image_width.times do |j|
      c = (a + j + left) * image_channels
      d = (b + j) * image_channels
      image_channels.times do |k|
        padded_pixel_data[c + k] = pixel_data[d + k]
      end
    end
  end

  if mode == "symmetric"
    h1 = image_height - 1
    w1 = image_width - 1
    padded_image_height.times do |i|
      a = i * padded_image_width
      b = Utils.calculate_reflect_offset(i, h1) * image_width

      padded_image_width.times do |j|
        next if i < image_height && j < image_width # Do not overwrite original image

        c = (a + j) * image_channels
        d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels

        # Copy channel-wise
        image_channels.times do |k|
          padded_pixel_data[c + k] = pixel_data[d + k]
        end
      end
    end
  end

  [padded_pixel_data, [padded_image_height, padded_image_width, image_channels]]
end
144
+
145
# Multiplies every entry of `pixel_data` by @rescale_factor, in place.
#
# @param pixel_data [Array<Numeric>] values to scale; the array is mutated
# @return [Integer] the number of entries in `pixel_data`
def rescale(pixel_data)
  factor = @rescale_factor
  pixel_data.each_index { |idx| pixel_data[idx] *= factor }
  pixel_data.length
end
150
+
151
# Works out the `[width, height]` an image should be resized to.
#
# Edge-based sizes (a number, "shortest_edge"/"longest_edge" keys, or the
# thumbnail setting) keep the aspect ratio; a size with explicit "width" and
# "height" is returned as given.
#
# @param image [#size] image whose `size` yields `[width, height]`
# @param size [Numeric, Hash, nil] requested size
# @return [Array<Integer>] `[width, height]`
# @raise [Todo] for configurations this port does not handle: a size
#   divisibility setting, "keep_aspect_ratio" with "ensure_multiple_of", or a
#   size that provides neither edges nor both width and height
def get_resize_output_image_size(image, size)
  src_width, src_height = image.size

  shortest_edge = nil
  longest_edge = nil
  if @do_thumbnail
    # NOTE: custom logic for `Donut` models
    shortest_edge = [size["height"], size["width"]].min
  elsif size.is_a?(Numeric)
    shortest_edge = size
    longest_edge = @config["max_size"] || shortest_edge
  elsif !size.nil?
    # Extract known properties from `size`
    shortest_edge = size["shortest_edge"]
    longest_edge = size["longest_edge"]
  end

  unless shortest_edge.nil? && longest_edge.nil?
    # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
    # First scale so that the shortest side reaches `shortest_edge`;
    # without a `shortest_edge` the image is left at its own size.
    grow = shortest_edge.nil? ? 1 : [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max
    new_width = src_width * grow
    new_height = src_height * grow

    # That may overshoot `longest_edge`, so scale back down until the
    # largest side fits; without a `longest_edge` nothing is scaled down.
    shrink = longest_edge.nil? ? 1 : [longest_edge / new_width.to_f, longest_edge / new_height.to_f].min

    # To avoid certain floating point precision issues, we round to 2 decimal places
    final_width = (new_width * shrink).round(2).floor
    final_height = (new_height * shrink).round(2).floor

    raise Todo unless @size_divisibility.nil?

    return [final_width, final_height]
  end

  # No edge information: an explicit width and height are required
  raise Todo if size.nil? || size["width"].nil? || size["height"].nil?
  raise Todo if @config["keep_aspect_ratio"] && @config["ensure_multiple_of"]

  [size["width"], size["height"]]
end
211
+
212
# Resizes `image` to the dimensions that get_resize_output_image_size
# derives from @size, using the @resample mode.
#
# @param image [#resize] image to resize
# @return whatever `image.resize` returns
def resize(image)
  target_width, target_height = get_resize_output_image_size(image, @size)
  image.resize(target_width, target_height, resample: @resample)
end
216
+
217
# Runs the configured preprocessing steps on one image and returns its pixel
# values in channel-first layout.
#
# Order of steps: optional margin crop, RGB/grayscale conversion, resize,
# thumbnail, center crop, rescale, normalize, pad, then HWC -> CHW.
# Every keyword defaults to nil, meaning "use the configured value"
# (`do_convert_grayscale` has no configured counterpart).
#
# @param image image object providing `size`, `height`, `width`, `channels`
#   and `data`, plus `rgb`/`grayscale`/`resize`/`center_crop` for the steps
#   that are enabled
# @param do_normalize [Boolean, nil] override for @do_normalize
# @param do_pad [Boolean, nil] override for @do_pad
# @param do_convert_rgb [Boolean, nil] override for @do_convert_rgb
# @param do_convert_grayscale [Boolean, nil] convert to grayscale when RGB
#   conversion is off
# @param do_flip_channel_order [Boolean, nil] override for @do_flip_channel_order
# @return [Hash] `:original_size` and `:reshaped_input_size` as
#   `[height, width]`, and `:pixel_values` nested as channel/row/column
# @raise [Error] when mean/std lengths do not match the channel count
# @raise [Todo] for steps this port does not implement (channel flipping,
#   padding by size divisibility)
def preprocess(
  image,
  do_normalize: nil,
  do_pad: nil,
  do_convert_rgb: nil,
  do_convert_grayscale: nil,
  do_flip_channel_order: nil
)
  # Per-call arguments win; nil falls back to the configured value
  do_normalize = @do_normalize if do_normalize.nil?
  do_pad = @do_pad if do_pad.nil?
  do_convert_rgb = @do_convert_rgb if do_convert_rgb.nil?
  do_flip_channel_order = @do_flip_channel_order if do_flip_channel_order.nil?

  # NOTE: Specific to nougat processors. This is done before resizing,
  # and can be interpreted as a pre-preprocessing step.
  image = crop_margin(image) if @do_crop_margin

  src_width, src_height = image.size # original image size

  if do_convert_rgb
    image = image.rgb
  elsif do_convert_grayscale
    image = image.grayscale
  end

  image = resize(image) if @do_resize
  image = thumbnail(image, @size, @resample) if @do_thumbnail

  if @do_center_crop
    if @crop_size.is_a?(Integer)
      crop_width = @crop_size
      crop_height = @crop_size
    else
      crop_width = @crop_size["width"]
      crop_height = @crop_size["height"]
    end
    image = image.center_crop(crop_width, crop_height)
  end

  reshaped_input_size = [image.height, image.width]

  # NOTE: All pixel-level manipulation occurs with data in the hwc format
  # (height, width, channels), to emulate the behavior of the original
  # Python code (w/ numpy).
  channels = image.channels
  pixel_data = image.data
  img_dims = [image.height, image.width, channels]

  rescale(pixel_data) if @do_rescale

  if do_normalize
    # A scalar mean/std applies to every channel. Array.new(n, value) is the
    # Ruby spelling; the earlier `new Array(n) { ... }` raised NoMethodError.
    image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(channels, @image_mean)
    image_std = @image_std.is_a?(Array) ? @image_std : Array.new(channels, @image_std)

    if image_mean.length != channels || image_std.length != channels
      raise Error, "When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{channels})."
    end

    # Channels are interleaved, so the channel of an entry is its index
    # modulo the channel count.
    pixel_data.each_index do |idx|
      ch = idx % channels
      pixel_data[idx] = (pixel_data[idx] - image_mean[ch]) / image_std[ch]
    end
  end

  # do padding after rescaling/normalizing
  if do_pad
    if @pad_size
      # pad_image hands back the new pixel data and the new dimensions
      pixel_data, img_dims = pad_image(pixel_data, img_dims, @pad_size)
    elsif @size_divisibility
      raise Todo
    end
  end

  raise Todo if do_flip_channel_order

  # convert to channel dimension format (hwc -> chw)
  h, w, c = img_dims
  pixel_values =
    Array.new(c) do |ci|
      Array.new(h) do |hi|
        Array.new(w) { |wi| pixel_data[(hi * w * c) + (wi * c) + ci] }
      end
    end

  {
    original_size: [src_height, src_width],
    reshaped_input_size: reshaped_input_size,
    pixel_values: pixel_values
  }
end
329
+
330
# Preprocesses one image or an array of images and gathers the results.
#
# @param images [Object, Array] a single image or a list of images
# @param args extra positional arguments; accepted but not used here
# @return [Hash] `:pixel_values` (the per-image values passed through
#   `Utils.stack` with 0 as second argument), `:original_sizes` and
#   `:reshaped_input_sizes` (one entry per image)
def call(images, *args)
  batch = images.is_a?(Array) ? images : [images]
  processed = batch.map { |img| preprocess(img) }

  {
    # Stack pixel values
    pixel_values: Utils.stack(processed.map { |item| item[:pixel_values] }, 0),

    # Original sizes of images
    original_sizes: processed.map { |item| item[:original_size] },

    # Reshaped sizes of images, before padding or cropping
    reshaped_input_sizes: processed.map { |item| item[:reshaped_input_size] }
  }
end
350
+ end
351
+
352
# Adds nothing to ImageFeatureExtractor; it gives the "CLIPFeatureExtractor"
# config name a class of its own.
class CLIPFeatureExtractor < ImageFeatureExtractor; end
354
+
355
# Adds nothing to ImageFeatureExtractor; it gives the "DPTFeatureExtractor"
# config name a class of its own.
class DPTFeatureExtractor < ImageFeatureExtractor; end
357
+
358
# Adds nothing to ImageFeatureExtractor; it gives the "ViTFeatureExtractor"
# config name a class of its own.
class ViTFeatureExtractor < ImageFeatureExtractor; end
360
+
361
+ class OwlViTFeatureExtractor < ImageFeatureExtractor
362
# Hands all arguments to Utils.post_process_object_detection and returns its result.
def post_process_object_detection(*args) = Utils.post_process_object_detection(*args)
365
+ end
366
+
367
+ class Swin2SRImageProcessor < ImageFeatureExtractor
368
# Pads the image so that its width and height become multiples of `pad_size`.
#
# NOTE: In this case, `pad_size` represents the size of the sliding window
# for the local attention, not a target size.
#
# @param pixel_data flat pixel values, passed through to the parent
# @param img_dims [Array<Integer>] `[height, width, channels]`
# @param pad_size [Integer] window size the sides are rounded up to
# @param options keyword overrides forwarded to the parent implementation
# @return whatever the parent `pad_image` returns
def pad_image(pixel_data, img_dims, pad_size, **options)
  image_height, image_width, _image_channels = img_dims

  # NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
  # a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).
  # For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`.
  round_up = ->(extent) { extent + (pad_size - extent % pad_size) % pad_size }
  target = {width: round_up.(image_width), height: round_up.(image_height)}

  super(pixel_data, img_dims, target, mode: "symmetric", center: false, constant_values: -1, **options)
end
389
+ end
390
+
391
+ class DonutFeatureExtractor < ImageFeatureExtractor
392
# Pads with the image centered, filling the new area with `-mean / std` per
# channel.
#
# @param pixel_data flat pixel values, passed through to the parent
# @param img_dims [Array<Integer>] `[height, width, channels]`
# @param pad_size target size, passed through to the parent
# @param options keyword overrides forwarded to the parent implementation
# @return whatever the parent `pad_image` returns
def pad_image(pixel_data, img_dims, pad_size, **options)
  _image_height, _image_width, image_channels = img_dims

  # A scalar mean/std applies to every channel. Array.new(n, value) is the
  # Ruby spelling; the earlier `new Array(n, value)` raised at runtime.
  image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(image_channels, @image_mean)
  image_std = @image_std.is_a?(Array) ? @image_std : Array.new(image_channels, @image_std)

  constant_values = image_mean.zip(image_std).map { |mean, std| -mean / std }

  super(
    pixel_data,
    img_dims,
    pad_size,
    center: true,
    # Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed.
    # For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
    constant_values: constant_values,
    **options
  )
end
418
+ end
419
+
420
+ class DetrFeatureExtractor < ImageFeatureExtractor
421
# Runs the parent preprocessing and adds a `:pixel_mask` of ones.
#
# Extra positional arguments are accepted and passed to the parent, matching
# the parent signature and the way Processor#call forwards its arguments.
#
# @param images [Object, Array] a single image or a list of images
# @param args extra positional arguments, forwarded to the parent
# @return [Hash] the parent result plus `:pixel_mask`, one 64x64 grid of 1s
#   per entry of `:pixel_values`
def call(images, *args)
  result = super(images, *args)

  # TODO support differently-sized images, for now assume all images are the same size.
  # TODO support different mask sizes (not just 64x64)
  # Currently, just fill pixel mask with 1s
  batch = result[:pixel_values].size
  pixel_mask = Array.new(batch) { Array.new(64) { Array.new(64, 1) } }

  result.merge(pixel_mask: pixel_mask)
end
439
+
440
# Hands all arguments to Utils.post_process_object_detection and returns its result.
def post_process_object_detection(*args) = Utils.post_process_object_detection(*args)
443
+
444
# Keeps the queries that predict a real class with enough confidence.
#
# For each query, the second element returned by `Utils.max` is used as the
# predicted label and the matching `Utils.softmax` entry as its score.
#
# @param class_logits per-query class logits
# @param mask_logits per-query masks, indexed like `class_logits`
# @param object_mask_threshold [Numeric] a query is kept only when its score
#   is strictly greater than this
# @param num_labels [Integer] label value treated as background and dropped
# @return [Array] `[masks, scores, labels]` of the kept queries
def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)
  kept_masks = []
  kept_scores = []
  kept_labels = []

  (0...class_logits.size).each do |query|
    logits = class_logits[query]

    label = Utils.max(logits)[1]
    # Is the background, so we ignore it
    next if label == num_labels

    score = Utils.softmax(logits)[label]
    next unless score > object_mask_threshold

    kept_masks << mask_logits[query]
    kept_scores << score
    kept_labels << label
  end

  [kept_masks, kept_scores, kept_labels]
end
470
+
471
# Decides whether query `k` owns enough pixels to count as a segment.
#
# @param mask_labels [Array] per-pixel index of the winning query
# @param mask_probs [Array] per-query mask values; `mask_probs[k]` is
#   flattened and read pixel by pixel
# @param k [Integer] query index to check
# @param mask_threshold [Numeric] a pixel counts toward the query's own area
#   when its value is at least this
# @param overlap_mask_area_threshold [Numeric] minimum ratio (exclusive) of
#   pixels won by `k` to pixels above `mask_threshold`
# @return [Array] `[mask_exists, mask_k]`, where `mask_k` lists the pixel
#   indices whose label is `k`
def check_segment_validity(
  mask_labels,
  mask_probs,
  k,
  mask_threshold = 0.5,
  overlap_mask_area_threshold = 0.8
)
  mask_probs_k_data = mask_probs[k].flatten

  # mask_k is a 1D array of indices, indicating where the mask is equal to k
  mask_k = []
  original_area = 0

  # Compute the area of all the stuff in query k
  mask_labels.length.times do |i|
    mask_k << i if mask_labels[i] == k
    original_area += 1 if mask_probs_k_data[i] >= mask_threshold
  end
  mask_k_area = mask_k.length

  mask_exists = mask_k_area > 0 && original_area > 0

  # Eliminate disconnected tiny segments
  if mask_exists
    # Both areas are Integers, so a plain `/` would truncate the ratio to a
    # whole number; fdiv keeps the fraction the threshold is compared with.
    area_ratio = mask_k_area.fdiv(original_area)
    mask_exists = area_ratio > overlap_mask_area_threshold
  end

  [mask_exists, mask_k]
end
507
+
508
# Builds a segmentation map and the list of segments from per-query masks.
#
# NOTE: `mask_probs` is updated in place (optionally interpolated to
# `target_size`, then weighted by the prediction scores).
#
# @param mask_probs [Array] per-query masks
# @param pred_scores [Array<Numeric>] per-query scores
# @param pred_labels [Array] per-query class labels
# @param mask_threshold [Numeric] passed to check_segment_validity
# @param overlap_mask_area_threshold [Numeric] passed to check_segment_validity
# @param label_ids_to_fuse accepted but not used yet (see TODO below)
# @param target_size [Array<Integer>, nil] `[height, width]` to resize to
# @return [Array] `[segmentation, segments]`: the map reshaped to
#   `[height, width]` holding segment ids (nil where no segment was assigned),
#   and one `{id:, label_id:, score:}` hash per segment
def compute_segments(
  mask_probs,
  pred_scores,
  pred_labels,
  mask_threshold,
  overlap_mask_area_threshold,
  label_ids_to_fuse = nil,
  target_size = nil
)
  height, width = target_size || Utils.dims(mask_probs[0])

  segmentation = Array.new(height * width)
  segments = []

  # 1. If target_size is not null, we need to resize the masks to the target size
  unless target_size.nil?
    mask_probs.map! { |mask| Utils.interpolate(mask, target_size, "bilinear", false) }
  end

  # 2. Weigh each mask by its prediction score
  #
  # Temporary storage for the best label/scores for each pixel ([height, width]):
  pixel_count = mask_probs[0].flatten.length
  mask_labels = Array.new(pixel_count)
  best_scores = Array.new(pixel_count, 0)

  mask_probs.map!.with_index do |mask, query|
    score = pred_scores[query]
    flat = mask.flatten
    dims = Utils.dims(mask)

    flat.each_with_index do |prob, pixel|
      weighted = prob * score
      flat[pixel] = weighted
      if weighted > best_scores[pixel]
        mask_labels[pixel] = query
        best_scores[pixel] = weighted
      end
    end

    Utils.reshape(flat, dims)
  end

  current_segment_id = 0

  pred_labels.each_with_index do |pred_class, k|
    # TODO add `should_fuse`
    # should_fuse = label_ids_to_fuse.include?(pred_class)

    # Check if mask exists and large enough to be a segment
    mask_exists, mask_k =
      check_segment_validity(mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold)

    # Nothing to see here
    next unless mask_exists

    current_segment_id += 1

    # Add current object segment to final segmentation map
    mask_k.each { |pixel| segmentation[pixel] = current_segment_id }

    segments << {id: current_segment_id, label_id: pred_class, score: pred_scores[k]}
  end

  [Utils.reshape(segmentation, [height, width]), segments]
end
595
+
596
# Turns raw model outputs into one segmentation result per batch item.
#
# @param outputs [Hash] model outputs with `:logits`
#   ([batch_size, num_queries, num_classes+1]) and `:pred_masks`
#   ([batch_size, num_queries, height, width])
# @param threshold [Numeric] score threshold passed to remove_low_and_no_objects
# @param mask_threshold [Numeric] passed to compute_segments
# @param overlap_mask_area_threshold [Numeric] passed to compute_segments
# @param label_ids_to_fuse labels to fuse; a warning is printed and an empty
#   Set is used when nil
# @param target_sizes [Array, nil] one target size per batch item
# @return [Array<Hash>] one `{segmentation:, segments_info:}` per batch item
# @raise [Error] when `target_sizes` does not have one entry per batch item
# @raise [Todo] when a batch item has no remaining predictions
def post_process_panoptic_segmentation(
  outputs,
  threshold: 0.5,
  mask_threshold: 0.5,
  overlap_mask_area_threshold: 0.8,
  label_ids_to_fuse: nil,
  target_sizes: nil
)
  if label_ids_to_fuse.nil?
    warn "`label_ids_to_fuse` unset. No instance will be fused."
    label_ids_to_fuse = Set.new
  end

  class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]
  mask_probs = Utils.sigmoid(outputs[:pred_masks]) # [batch_size, num_queries, height, width]

  batch_size = class_queries_logits.size
  # Remove last class (background)
  num_labels = class_queries_logits[0][0].size - 1

  unless target_sizes.nil? || target_sizes.length == batch_size
    raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  end

  Array.new(batch_size) do |batch_index|
    target_size = target_sizes.nil? ? nil : target_sizes[batch_index]

    masks, scores, labels = remove_low_and_no_objects(
      class_queries_logits[batch_index],
      mask_probs[batch_index],
      threshold,
      num_labels
    )

    raise Todo if labels.length == 0

    # Get segmentation map and segment information of batch item
    segmentation, segments = compute_segments(
      masks,
      scores,
      labels,
      mask_threshold,
      overlap_mask_area_threshold,
      label_ids_to_fuse,
      target_size
    )

    {segmentation: segmentation, segments_info: segments}
  end
end
653
+ end
654
+
655
module Utils
  # Converts a box given as center and size into corner coordinates.
  #
  # @param v [Array<Numeric>] `[center_x, center_y, width, height]`
  # @return [Array<Float>] `[x0, y0, x1, y1]`
  def self.center_to_corners_format(v)
    center_x, center_y, width, height = v
    half_width = width / 2.0
    half_height = height / 2.0

    [
      center_x - half_width,
      center_y - half_height,
      center_x + half_width,
      center_y + half_height
    ]
  end

  # Collects boxes, class indices and scores from detection outputs.
  #
  # @param outputs [Hash] model outputs with `:logits` and `:pred_boxes`,
  #   both indexed by batch item and then by box
  # @param threshold [Numeric] in zero-shot mode, a class is reported when
  #   its `Utils.sigmoid` value is strictly greater than this
  # @param target_sizes [Array, nil] one size per batch item; when given, box
  #   x coordinates are multiplied by the size's second entry and y
  #   coordinates by its first
  # @param is_zero_shot [Boolean] report every class above `threshold`
  #   instead of only the top class
  # @return [Array<Hash>] one `{boxes:, classes:, scores:}` per batch item
  # @raise [Error] when `target_sizes` does not have one entry per batch item
  def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
    out_logits = outputs[:logits]
    out_bbox = outputs[:pred_boxes]
    batch_size = out_logits.size
    num_boxes = out_logits[0].size
    num_classes = out_logits[0][0].size

    unless target_sizes.nil? || target_sizes.length == batch_size
      raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
    end

    Array.new(batch_size) do |batch_index|
      target_size = target_sizes.nil? ? nil : target_sizes[batch_index]
      info = {boxes: [], classes: [], scores: []}
      logits = out_logits[batch_index]
      bbox = out_bbox[batch_index]

      num_boxes.times do |j|
        logit = logits[j]

        if is_zero_shot
          # Get indices of classes with high enough probability
          probs = Utils.sigmoid(logit)
          indices = probs.length.times.select { |k| probs[k] > threshold }
        else
          # Get most probable class
          max_index = Utils.max(logit)[1]

          # This is the background class, skip it
          next if max_index == num_classes - 1

          indices = [max_index]

          # Compute softmax over classes
          probs = Utils.softmax(logit)
        end

        indices.each do |index|
          # convert to [x0, y0, x1, y1] format
          box = center_to_corners_format(bbox[j])
          unless target_size.nil?
            # Even positions (x) use target_size[1], odd positions (y) use target_size[0]
            box = box.each_with_index.map { |coord, pos| coord * target_size[(pos + 1) % 2] }
          end

          info[:boxes] << box
          info[:classes] << index
          info[:scores] << probs[index]
        end
      end

      info
    end
  end
end
730
+
731
# Thin wrapper that forwards calls to the feature extractor it holds.
class Processor
  attr_reader :feature_extractor

  # @param feature_extractor [#call] object that does the actual processing
  def initialize(feature_extractor)
    @feature_extractor = feature_extractor
  end

  # Passes `input` and any extra arguments to the feature extractor.
  #
  # @return whatever the feature extractor returns
  def call(input, *args)
    @feature_extractor.call(input, *args)
  end
end
742
+
743
# Builds the right processor for a model from its preprocessor config.
class AutoProcessor
  # Config type name => feature extractor class
  FEATURE_EXTRACTOR_CLASS_MAPPING = {
    "ViTFeatureExtractor" => ViTFeatureExtractor,
    "OwlViTFeatureExtractor" => OwlViTFeatureExtractor,
    "CLIPFeatureExtractor" => CLIPFeatureExtractor,
    "DPTFeatureExtractor" => DPTFeatureExtractor,
    "DetrFeatureExtractor" => DetrFeatureExtractor,
    "Swin2SRImageProcessor" => Swin2SRImageProcessor,
    "DonutFeatureExtractor" => DonutFeatureExtractor
  }

  # Config "processor_class" name => processor class (none registered here)
  PROCESSOR_CLASS_MAPPING = {}

  # Creates a processor wrapping the feature extractor named in the config.
  #
  # @param pretrained_model_name_or_path model identifier handed to
  #   `Utils::Hub.get_model_json` when no `config` is supplied
  # @param progress_callback forwarded to `Utils::Hub.get_model_json`
  # @param config [Hash, nil] preprocessor config to use instead of loading
  #   "preprocessor_config.json"
  # @param cache_dir forwarded to `Utils::Hub.get_model_json`
  # @param local_files_only forwarded to `Utils::Hub.get_model_json`
  # @param revision forwarded to `Utils::Hub.get_model_json`
  # @param kwargs accepted but not used here
  # @return a processor (Processor unless the config names a registered class)
  # @raise [Error] when the extractor type is unknown and the config has no "size"
  def self.from_pretrained(
    pretrained_model_name_or_path,
    progress_callback: nil,
    config: nil,
    cache_dir: nil,
    local_files_only: false,
    revision: "main",
    **kwargs
  )
    preprocessor_config = config
    preprocessor_config ||= Utils::Hub.get_model_json(
      pretrained_model_name_or_path,
      "preprocessor_config.json",
      true,
      progress_callback: progress_callback,
      config: config,
      cache_dir: cache_dir,
      local_files_only: local_files_only,
      revision: revision
    )

    # Determine feature extractor class
    # TODO: Ensure backwards compatibility with old configs
    key = preprocessor_config["feature_extractor_type"] || preprocessor_config["image_processor_type"]
    feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING[key]

    unless feature_extractor_class
      raise Error, "Unknown Feature Extractor type: #{key}" unless preprocessor_config["size"]

      # Assume ImageFeatureExtractor
      warn "Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config."
      feature_extractor_class = ImageFeatureExtractor
    end

    # If no associated processor class, use default
    processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config["processor_class"]] || Processor

    # Instantiate processor and feature extractor
    processor_class.new(feature_extractor_class.new(preprocessor_config))
  end
end
796
+ end