informers 1.0.3 → 1.1.0

@@ -0,0 +1,796 @@
+ module Informers
+   class FeatureExtractor
+     def initialize(config)
+       super()
+       @config = config
+     end
+   end
+
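+   # Config-driven image preprocessor. The keys read below (do_resize,
+   # do_normalize, size, crop_size, etc.) follow the Hugging Face
+   # preprocessor_config.json conventions.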
+   class ImageFeatureExtractor < FeatureExtractor
+     def initialize(config)
+       super(config)
+
+       @image_mean = @config["image_mean"] || @config["mean"]
+       @image_std = @config["image_std"] || @config["std"]
+
+       @resample = @config["resample"] || 2 # 2 => bilinear
+       @do_rescale = @config.fetch("do_rescale", true)
+       @rescale_factor = @config["rescale_factor"] || (1 / 255.0)
+       @do_normalize = @config["do_normalize"]
+
+       @do_resize = @config["do_resize"]
+       @do_thumbnail = @config["do_thumbnail"]
+       @size = @config["size"]
+       @size_divisibility = @config["size_divisibility"] || @config["size_divisor"]
+
+       @do_center_crop = @config["do_center_crop"]
+       @crop_size = @config["crop_size"]
+       @do_convert_rgb = @config.fetch("do_convert_rgb", true)
+       @do_crop_margin = @config["do_crop_margin"]
+
+       @pad_size = @config["pad_size"]
+       @do_pad = @config["do_pad"]
+
+       if @do_pad && !@pad_size && @size && !@size["width"].nil? && !@size["height"].nil?
+         # Should pad, but no pad size specified
+         # We infer the pad size from the resize size
+         @pad_size = @size
+       end
+
+       @do_flip_channel_order = @config["do_flip_channel_order"] || false
+     end
+
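+     # Downscales the image so it fits within `size` while preserving the
+     # aspect ratio (Donut-style); never upscales.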
+     def thumbnail(image, size, resample = 2)
+       input_height = image.height
+       input_width = image.width
+
+       output_height = size["height"]
+       output_width = size["width"]
+
+       # We always resize to the smallest of either the input or output size.
+       height = [input_height, output_height].min
+       width = [input_width, output_width].min
+
+       if height == input_height && width == input_width
+         return image
+       end
+       if input_height > input_width
+         width = (input_width * height / input_height).floor
+       elsif input_width > input_height
+         height = (input_height * width / input_width).floor
+       end
+       image.resize(width, height, resample:)
+     end
+
+     def pad_image(
+       pixel_data,
+       img_dims,
+       pad_size,
+       mode: "constant",
+       center: false,
+       constant_values: 0
+     )
+       image_height, image_width, image_channels = img_dims
+
+       if pad_size.is_a?(Numeric)
+         padded_image_width = pad_size
+         padded_image_height = pad_size
+       else
+         padded_image_width = pad_size[:width] || pad_size["width"]
+         padded_image_height = pad_size[:height] || pad_size["height"]
+       end
+
+       # Only add padding if there is a difference in size
+       if padded_image_width != image_width || padded_image_height != image_height
+         # Zero-fill by default so padded pixels are 0 rather than nil
+         padded_pixel_data = Array.new(padded_image_width * padded_image_height * image_channels, 0)
+         if constant_values.is_a?(Array)
+           # Fill with constant values, cycling through the array
+           padded_pixel_data.length.times do |i|
+             padded_pixel_data[i] = constant_values[i % image_channels]
+           end
+         elsif constant_values != 0
+           padded_pixel_data.fill(constant_values)
+         end
+
+         left, top =
+           if center
+             [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]
+           else
+             [0, 0]
+           end
+
+         # Copy the original image into the padded image
+         image_height.times do |i|
+           a = (i + top) * padded_image_width
+           b = i * image_width
+           image_width.times do |j|
+             c = (a + j + left) * image_channels
+             d = (b + j) * image_channels
+             image_channels.times do |k|
+               padded_pixel_data[c + k] = pixel_data[d + k]
+             end
+           end
+         end
+
+         if mode == "symmetric"
+           if center
+             raise Error, "`center` padding is not supported when `mode` is set to `symmetric`."
+           end
+           h1 = image_height - 1
+           w1 = image_width - 1
+           padded_image_height.times do |i|
+             a = i * padded_image_width
+             b = Utils.calculate_reflect_offset(i, h1) * image_width
+
+             padded_image_width.times do |j|
+               next if i < image_height && j < image_width # Do not overwrite original image
+               c = (a + j) * image_channels
+               d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels
+
+               # Copy channel-wise
+               image_channels.times do |k|
+                 padded_pixel_data[c + k] = pixel_data[d + k]
+               end
+             end
+           end
+         end
+
+         # Update pixel data and image dimensions
+         pixel_data = padded_pixel_data
+         img_dims = [padded_image_height, padded_image_width, image_channels]
+       end
+       [pixel_data, img_dims]
+     end
+
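+     # Multiplies each pixel value in place by @rescale_factor
+     # (1 / 255.0 by default, mapping 0..255 to 0.0..1.0).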
+     def rescale(pixel_data)
+       pixel_data.length.times do |i|
+         pixel_data[i] *= @rescale_factor
+       end
+     end
+
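+     # Computes the output [width, height] for a resize, applying the
+     # shortest-edge/longest-edge constraints from the config.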
+     def get_resize_output_image_size(image, size)
+       src_width, src_height = image.size
+
+       if @do_thumbnail
+         # NOTE: custom logic for `Donut` models
+         height = size["height"]
+         width = size["width"]
+         shortest_edge = [height, width].min
+       elsif size.is_a?(Numeric)
+         shortest_edge = size
+         longest_edge = @config["max_size"] || shortest_edge
+       elsif !size.nil?
+         # Extract known properties from `size`
+         shortest_edge = size["shortest_edge"]
+         longest_edge = size["longest_edge"]
+       end
+
+       if !shortest_edge.nil? || !longest_edge.nil?
+         # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
+         # Try resize so that shortest edge is `shortest_edge` (target)
+         short_resize_factor =
+           if shortest_edge.nil?
+             1 # If `shortest_edge` is not set, don't upscale
+           else
+             [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max
+           end
+
+         new_width = src_width * short_resize_factor
+         new_height = src_height * short_resize_factor
+
+         # The new width and height might be greater than `longest_edge`, so
+         # we downscale again to ensure the largest dimension is `longest_edge`
+         long_resize_factor =
+           if longest_edge.nil?
+             1 # If `longest_edge` is not set, don't downscale
+           else
+             [longest_edge / new_width.to_f, longest_edge / new_height.to_f].min
+           end
+
+         # To avoid certain floating point precision issues, we round to 2 decimal places
+         final_width = (new_width * long_resize_factor).round(2).floor
+         final_height = (new_height * long_resize_factor).round(2).floor
+
+         if !@size_divisibility.nil?
+           raise Todo
+         end
+         [final_width, final_height]
+       elsif !size.nil? && !size["width"].nil? && !size["height"].nil?
+         new_width = size["width"]
+         new_height = size["height"]
+
+         if @config["keep_aspect_ratio"] && @config["ensure_multiple_of"]
+           raise Todo
+         end
+
+         [new_width, new_height]
+       else
+         raise Todo
+       end
+     end
+
+     def resize(image)
+       new_width, new_height = get_resize_output_image_size(image, @size)
+       image.resize(new_width, new_height, resample: @resample)
+     end
+
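+     # Full preprocessing pipeline for a single image: optional margin crop,
+     # color conversion, resize/thumbnail/center crop, then rescale,
+     # normalize, pad, and HWC -> CHW conversion.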
+     def preprocess(
+       image,
+       do_normalize: nil,
+       do_pad: nil,
+       do_convert_rgb: nil,
+       do_convert_grayscale: nil,
+       do_flip_channel_order: nil
+     )
+       if @do_crop_margin
+         # NOTE: Specific to nougat processors. This is done before resizing,
+         # and can be interpreted as a pre-preprocessing step.
+         image = crop_margin(image)
+       end
+
+       src_width, src_height = image.size # original image size
+
+       # Convert image to RGB if specified in config.
+       if !do_convert_rgb.nil? ? do_convert_rgb : @do_convert_rgb
+         image = image.rgb
+       elsif do_convert_grayscale
+         image = image.grayscale
+       end
+
+       # Resize all images
+       if @do_resize
+         image = resize(image)
+       end
+
+       # Resize the image using thumbnail method.
+       if @do_thumbnail
+         image = thumbnail(image, @size, @resample)
+       end
+
+       if @do_center_crop
+         if @crop_size.is_a?(Integer)
+           crop_width = @crop_size
+           crop_height = @crop_size
+         else
+           crop_width = @crop_size["width"]
+           crop_height = @crop_size["height"]
+         end
+         image = image.center_crop(crop_width, crop_height)
+       end
+
+       reshaped_input_size = [image.height, image.width]
+
+       # NOTE: All pixel-level manipulation (i.e., modifying `pixel_data`)
+       # occurs with data in the hwc format (height, width, channels),
+       # to emulate the behavior of the original Python code (w/ numpy).
+       pixel_data = image.data
+       img_dims = [image.height, image.width, image.channels]
+
+       if @do_rescale
+         rescale(pixel_data)
+       end
+
+       if !do_normalize.nil? ? do_normalize : @do_normalize
+         image_mean = @image_mean
+         if !@image_mean.is_a?(Array)
+           # Broadcast a scalar mean across all channels
+           image_mean = Array.new(image.channels, image_mean)
+         end
+
+         image_std = @image_std
+         if !@image_std.is_a?(Array)
+           # Broadcast a scalar std across all channels
+           image_std = Array.new(image.channels, image_std)
+         end
+
+         if image_mean.length != image.channels || image_std.length != image.channels
+           raise Error, "When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{image.channels})."
+         end
+
+         i = 0
+         while i < pixel_data.length
+           image.channels.times do |j|
+             pixel_data[i + j] = (pixel_data[i + j] - image_mean[j]) / image_std[j]
+           end
+           i += image.channels
+         end
+       end
+
+       # do padding after rescaling/normalizing
+       if !do_pad.nil? ? do_pad : @do_pad
+         if @pad_size
+           padded = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)
+           pixel_data, img_dims = padded # Update pixel data and image dimensions
+         elsif @size_divisibility
+           raise Todo
+         end
+       end
+
+       if !do_flip_channel_order.nil? ? do_flip_channel_order : @do_flip_channel_order
+         raise Todo
+       end
+
+       # convert to channel dimension format (hwc -> chw)
+       h, w, c = img_dims
+       pixel_values =
+         c.times.map do |ci|
+           h.times.map do |hi|
+             w.times.map do |wi|
+               index = (hi * w * c) + (wi * c) + ci
+               pixel_data[index]
+             end
+           end
+         end
+
+       {
+         original_size: [src_height, src_width],
+         reshaped_input_size: reshaped_input_size,
+         pixel_values: pixel_values
+       }
+     end
+
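+     # Preprocesses one image (or an array of images) and stacks the results
+     # into a single batch.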
+     def call(images, *args)
+       if !images.is_a?(Array)
+         images = [images]
+       end
+
+       image_data = images.map { |x| preprocess(x) }
+
+       # Stack pixel values
+       pixel_values = Utils.stack(image_data.map { |x| x[:pixel_values] }, 0)
+
+       {
+         pixel_values: pixel_values,
+
+         # Original sizes of images
+         original_sizes: image_data.map { |x| x[:original_size] },
+
+         # Reshaped sizes of images, before padding or cropping
+         reshaped_input_sizes: image_data.map { |x| x[:reshaped_input_size] }
+       }
+     end
+   end
+
+   class CLIPFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class DPTFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class ViTFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class OwlViTFeatureExtractor < ImageFeatureExtractor
+     def post_process_object_detection(*args)
+       Utils.post_process_object_detection(*args)
+     end
+   end
+
+   class Swin2SRImageProcessor < ImageFeatureExtractor
+     def pad_image(pixel_data, img_dims, pad_size, **options)
+       # NOTE: In this case, `pad_size` represents the size of the sliding window for the local attention.
+       # In other words, the image is padded so that its width and height are multiples of `pad_size`.
+       image_height, image_width, _image_channels = img_dims
+
+       super(
+         pixel_data,
+         img_dims,
+         {
+           # NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
+           # a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).
+           # For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`.
+           width: image_width + (pad_size - image_width % pad_size) % pad_size,
+           height: image_height + (pad_size - image_height % pad_size) % pad_size
+         },
+         mode: "symmetric",
+         center: false,
+         constant_values: -1,
+         **options
+       )
+     end
+   end
+
+   class DonutFeatureExtractor < ImageFeatureExtractor
+     def pad_image(pixel_data, img_dims, pad_size, **options)
+       _image_height, _image_width, image_channels = img_dims
+
+       image_mean = @image_mean
+       if !image_mean.is_a?(Array)
+         image_mean = Array.new(image_channels, image_mean)
+       end
+
+       image_std = @image_std
+       if !image_std.is_a?(Array)
+         image_std = Array.new(image_channels, image_std)
+       end
+
+       constant_values = image_mean.map.with_index { |x, i| -x / image_std[i] }
+
+       super(
+         pixel_data,
+         img_dims,
+         pad_size,
+         center: true,
+         # Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed.
+         # For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
+         constant_values: constant_values,
+         **options
+       )
+     end
+   end
+
+   class DetrFeatureExtractor < ImageFeatureExtractor
+     def call(images)
+       result = super(images)
+
+       # TODO support differently-sized images, for now assume all images are the same size.
+       # TODO support different mask sizes (not just 64x64)
+       # Currently, just fill pixel mask with 1s
+       mask_size = [result[:pixel_values].size, 64, 64]
+       pixel_mask =
+         mask_size[0].times.map do
+           mask_size[1].times.map do
+             mask_size[2].times.map do
+               1
+             end
+           end
+         end
+
+       result.merge(pixel_mask: pixel_mask)
+     end
+
+     def post_process_object_detection(*args)
+       Utils.post_process_object_detection(*args)
+     end
+
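+     # Drops queries whose argmax class is the background label or whose
+     # score does not exceed `object_mask_threshold`.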
+     def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)
+       mask_probs_item = []
+       pred_scores_item = []
+       pred_labels_item = []
+
+       class_logits.size.times do |j|
+         cls = class_logits[j]
+         mask = mask_logits[j]
+
+         pred_label = Utils.max(cls)[1]
+         if pred_label == num_labels
+           # Is the background, so we ignore it
+           next
+         end
+
+         scores = Utils.softmax(cls)
+         pred_score = scores[pred_label]
+         if pred_score > object_mask_threshold
+           mask_probs_item << mask
+           pred_scores_item << pred_score
+           pred_labels_item << pred_label
+         end
+       end
+
+       [mask_probs_item, pred_scores_item, pred_labels_item]
+     end
+
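+     # Checks whether query k's mask covers enough area to count as a real
+     # segment; returns [mask_exists, mask_k_indices].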
+     def check_segment_validity(
+       mask_labels,
+       mask_probs,
+       k,
+       mask_threshold = 0.5,
+       overlap_mask_area_threshold = 0.8
+     )
+       # mask_k is a 1D array of indices, indicating where the mask is equal to k
+       mask_k = []
+       mask_k_area = 0
+       original_area = 0
+
+       mask_probs_k_data = mask_probs[k].flatten
+
+       # Compute the area of all the stuff in query k
+       mask_labels.length.times do |i|
+         if mask_labels[i] == k
+           mask_k << i
+           mask_k_area += 1
+         end
+
+         if mask_probs_k_data[i] >= mask_threshold
+           original_area += 1
+         end
+       end
+       mask_exists = mask_k_area > 0 && original_area > 0
+
+       # Eliminate disconnected tiny segments
+       if mask_exists
+         # Perform additional check (float division avoids integer truncation)
+         area_ratio = mask_k_area.to_f / original_area
+         mask_exists = area_ratio > overlap_mask_area_threshold
+       end
+
+       [mask_exists, mask_k]
+     end
+
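+     # Builds the final segmentation map by assigning each pixel to the
+     # highest-scoring query mask, then records metadata for each segment.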
+     def compute_segments(
+       mask_probs,
+       pred_scores,
+       pred_labels,
+       mask_threshold,
+       overlap_mask_area_threshold,
+       label_ids_to_fuse = nil,
+       target_size = nil
+     )
+       height, width = target_size || Utils.dims(mask_probs[0])
+
+       segmentation = Array.new(height * width)
+       segments = []
+
+       # 1. If target_size is not nil, we need to resize the masks to the target size
+       if !target_size.nil?
+         # resize the masks to the target size
+         mask_probs.length.times do |i|
+           mask_probs[i] = Utils.interpolate(mask_probs[i], target_size, "bilinear", false)
+         end
+       end
+
+       # 2. Weigh each mask by its prediction score
+       # NOTE: `mask_probs` is updated in-place
+       #
+       # Temporary storage for the best label/scores for each pixel ([height, width]):
+       mask_labels = Array.new(mask_probs[0].flatten.length)
+       best_scores = Array.new(mask_probs[0].flatten.length, 0)
+
+       mask_probs.length.times do |i|
+         score = pred_scores[i]
+
+         mask_probs_i_data = mask_probs[i].flatten
+         mask_probs_i_dims = Utils.dims(mask_probs[i])
+
+         mask_probs_i_data.length.times do |j|
+           mask_probs_i_data[j] *= score
+           if mask_probs_i_data[j] > best_scores[j]
+             mask_labels[j] = i
+             best_scores[j] = mask_probs_i_data[j]
+           end
+         end
+
+         mask_probs[i] = Utils.reshape(mask_probs_i_data, mask_probs_i_dims)
+       end
+
+       current_segment_id = 0
+
+       # stuff_memory_list = {}
+       pred_labels.length.times do |k|
+         pred_class = pred_labels[k]
+
+         # TODO add `should_fuse`
+         # should_fuse = label_ids_to_fuse.include?(pred_class)
+
+         # Check if mask exists and large enough to be a segment
+         mask_exists, mask_k = check_segment_validity(
+           mask_labels,
+           mask_probs,
+           k,
+           mask_threshold,
+           overlap_mask_area_threshold
+         )
+
+         if !mask_exists
+           # Nothing to see here
+           next
+         end
+
+         current_segment_id += 1
+
+         # Add current object segment to final segmentation map
+         mask_k.each do |index|
+           segmentation[index] = current_segment_id
+         end
+
+         segments << {
+           id: current_segment_id,
+           label_id: pred_class,
+           score: pred_scores[k]
+         }
+       end
+
+       segmentation = Utils.reshape(segmentation, [height, width])
+
+       [segmentation, segments]
+     end
+
+     def post_process_panoptic_segmentation(
+       outputs,
+       threshold: 0.5,
+       mask_threshold: 0.5,
+       overlap_mask_area_threshold: 0.8,
+       label_ids_to_fuse: nil,
+       target_sizes: nil
+     )
+       if label_ids_to_fuse.nil?
+         warn "`label_ids_to_fuse` unset. No instance will be fused."
+         label_ids_to_fuse = Set.new
+       end
+
+       class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]
+       masks_queries_logits = outputs[:pred_masks] # [batch_size, num_queries, height, width]
+
+       mask_probs = Utils.sigmoid(masks_queries_logits) # [batch_size, num_queries, height, width]
+
+       batch_size, _num_queries, num_labels = class_queries_logits.size, class_queries_logits[0].size, class_queries_logits[0][0].size
+       num_labels -= 1 # Remove last class (background)
+
+       if !target_sizes.nil? && target_sizes.length != batch_size
+         raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+       end
+
+       to_return = []
+       batch_size.times do |i|
+         target_size = !target_sizes.nil? ? target_sizes[i] : nil
+
+         class_logits = class_queries_logits[i]
+         mask_logits = mask_probs[i]
+
+         mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels)
+
+         if pred_labels_item.length == 0
+           raise Todo
+         end
+
+         # Get segmentation map and segment information of batch item
+         segmentation, segments = compute_segments(
+           mask_probs_item,
+           pred_scores_item,
+           pred_labels_item,
+           mask_threshold,
+           overlap_mask_area_threshold,
+           label_ids_to_fuse,
+           target_size
+         )
+
+         to_return << {
+           segmentation: segmentation,
+           segments_info: segments
+         }
+       end
+
+       to_return
+     end
+   end
+
+   module Utils
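+     # Converts a box from [center_x, center_y, width, height] to
+     # [x0, y0, x1, y1] corner format.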
+     def self.center_to_corners_format(v)
+       center_x, center_y, width, height = v
+       [
+         center_x - width / 2.0,
+         center_y - height / 2.0,
+         center_x + width / 2.0,
+         center_y + height / 2.0
+       ]
+     end
+
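+     # Converts raw detection outputs into per-image hashes of boxes,
+     # classes, and scores, keeping predictions above `threshold`.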
+     def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
+       out_logits = outputs[:logits]
+       out_bbox = outputs[:pred_boxes]
+       batch_size, num_boxes, num_classes = out_logits.size, out_logits[0].size, out_logits[0][0].size
+
+       if !target_sizes.nil? && target_sizes.length != batch_size
+         raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+       end
+       to_return = []
+       batch_size.times do |i|
+         target_size = !target_sizes.nil? ? target_sizes[i] : nil
+         info = {
+           boxes: [],
+           classes: [],
+           scores: []
+         }
+         logits = out_logits[i]
+         bbox = out_bbox[i]
+
+         num_boxes.times do |j|
+           logit = logits[j]
+
+           indices = []
+           if is_zero_shot
+             # Get indices of classes with high enough probability
+             probs = Utils.sigmoid(logit)
+             probs.length.times do |k|
+               if probs[k] > threshold
+                 indices << k
+               end
+             end
+           else
+             # Get most probable class
+             max_index = Utils.max(logit)[1]
+
+             if max_index == num_classes - 1
+               # This is the background class, skip it
+               next
+             end
+             indices << max_index
+
+             # Compute softmax over classes
+             probs = Utils.softmax(logit)
+           end
+
+           indices.each do |index|
+             box = bbox[j]
+
+             # convert to [x0, y0, x1, y1] format
+             box = center_to_corners_format(box)
+             if !target_size.nil?
+               box = box.map.with_index { |x, i| x * target_size[(i + 1) % 2] }
+             end
+
+             info[:boxes] << box
+             info[:classes] << index
+             info[:scores] << probs[index]
+           end
+         end
+         to_return << info
+       end
+       to_return
+     end
+   end
+
+   class Processor
+     attr_reader :feature_extractor
+
+     def initialize(feature_extractor)
+       @feature_extractor = feature_extractor
+     end
+
+     def call(input, *args)
+       @feature_extractor.(input, *args)
+     end
+   end
+
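+   # Builds a Processor wrapping the feature extractor named in a model's
+   # preprocessor_config.json. A minimal usage sketch (the model id here is
+   # hypothetical, not part of this file):
+   #
+   #   processor = Informers::AutoProcessor.from_pretrained("Xenova/vit-base-patch16-224")
+   #   inputs = processor.(image)
+   #   inputs[:pixel_values] # batched CHW pixel data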
+   class AutoProcessor
+     FEATURE_EXTRACTOR_CLASS_MAPPING = {
+       "ViTFeatureExtractor" => ViTFeatureExtractor,
+       "OwlViTFeatureExtractor" => OwlViTFeatureExtractor,
+       "CLIPFeatureExtractor" => CLIPFeatureExtractor,
+       "DPTFeatureExtractor" => DPTFeatureExtractor,
+       "DetrFeatureExtractor" => DetrFeatureExtractor,
+       "Swin2SRImageProcessor" => Swin2SRImageProcessor,
+       "DonutFeatureExtractor" => DonutFeatureExtractor
+     }
+
+     PROCESSOR_CLASS_MAPPING = {}
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       **kwargs
+     )
+       preprocessor_config = config || Utils::Hub.get_model_json(pretrained_model_name_or_path, "preprocessor_config.json", true,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:
+       )
+
+       # Determine feature extractor class
+       # TODO: Ensure backwards compatibility with old configs
+       key = preprocessor_config["feature_extractor_type"] || preprocessor_config["image_processor_type"]
+       feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING[key]
+
+       if !feature_extractor_class
+         if preprocessor_config["size"]
+           # Assume ImageFeatureExtractor
+           warn "Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config."
+           feature_extractor_class = ImageFeatureExtractor
+         else
+           raise Error, "Unknown Feature Extractor type: #{key}"
+         end
+       end
+
+       # If no associated processor class, use default
+       processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config["processor_class"]] || Processor
+
+       # Instantiate processor and feature extractor
+       feature_extractor = feature_extractor_class.new(preprocessor_config)
+       processor_class.new(feature_extractor)
+     end
+   end
+ end