informers 1.0.3 → 1.1.1

@@ -0,0 +1,856 @@
+ module Informers
+   class FeatureExtractor
+     attr_reader :config
+
+     def initialize(config)
+       super()
+       @config = config
+     end
+   end
+
+   class ImageFeatureExtractor < FeatureExtractor
+     def initialize(config)
+       super(config)
+
+       @image_mean = @config["image_mean"] || @config["mean"]
+       @image_std = @config["image_std"] || @config["std"]
+
+       @resample = @config["resample"] || 2 # 2 => bilinear
+       @do_rescale = @config.fetch("do_rescale", true)
+       @rescale_factor = @config["rescale_factor"] || (1 / 255.0)
+       @do_normalize = @config["do_normalize"]
+
+       @do_resize = @config["do_resize"]
+       @do_thumbnail = @config["do_thumbnail"]
+       @size = @config["size"]
+       @size_divisibility = @config["size_divisibility"] || @config["size_divisor"]
+
+       @do_center_crop = @config["do_center_crop"]
+       @crop_size = @config["crop_size"]
+       @do_convert_rgb = @config.fetch("do_convert_rgb", true)
+       @do_crop_margin = @config["do_crop_margin"]
+
+       @pad_size = @config["pad_size"]
+       @do_pad = @config["do_pad"]
+
+       if @do_pad && !@pad_size && @size && !@size["width"].nil? && !@size["height"].nil?
+         # Should pad, but no pad size specified
+         # We infer the pad size from the resize size
+         @pad_size = @size
+       end
+
+       @do_flip_channel_order = @config["do_flip_channel_order"] || false
+     end
+
+     def thumbnail(image, size, resample = 2)
+       input_height = image.height
+       input_width = image.width
+
+       output_height = size["height"]
+       output_width = size["width"]
+
+       # We always resize to the smallest of either the input or output size.
+       height = [input_height, output_height].min
+       width = [input_width, output_width].min
+
+       if height == input_height && width == input_width
+         return image
+       end
+       if input_height > input_width
+         width = (input_width * height / input_height).floor
+       elsif input_width > input_height
+         height = (input_height * width / input_width).floor
+       end
+       image.resize(width, height, resample:)
+     end
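+
+     # For example, thumbnail-ing a 1000x400 image to {"height" => 224, "width" => 224}
+     # yields a 224x89 image: the width is capped at 224 and the height is scaled
+     # down proportionally (floor(400 * 224 / 1000) = 89), preserving the aspect ratio.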
+
+     def pad_image(
+       pixel_data,
+       img_dims,
+       pad_size,
+       mode: "constant",
+       center: false,
+       constant_values: 0
+     )
+       image_height, image_width, image_channels = img_dims
+
+       if pad_size.is_a?(Numeric)
+         padded_image_width = pad_size
+         padded_image_height = pad_size
+       else
+         padded_image_width = pad_size[:width] || pad_size["width"]
+         padded_image_height = pad_size[:height] || pad_size["height"]
+       end
+
+       # Only add padding if there is a difference in size
+       if padded_image_width != image_width || padded_image_height != image_height
+         padded_pixel_data = Array.new(padded_image_width * padded_image_height * image_channels, 0)
+         if constant_values.is_a?(Array)
+           # Fill with constant values, cycling through the array
+           padded_pixel_data.length.times do |i|
+             padded_pixel_data[i] = constant_values[i % image_channels]
+           end
+         elsif constant_values != 0
+           padded_pixel_data.fill(constant_values)
+         end
+
+         left, top =
+           if center
+             [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]
+           else
+             [0, 0]
+           end
+
+         # Copy the original image into the padded image
+         image_height.times do |i|
+           a = (i + top) * padded_image_width
+           b = i * image_width
+           image_width.times do |j|
+             c = (a + j + left) * image_channels
+             d = (b + j) * image_channels
+             image_channels.times do |k|
+               padded_pixel_data[c + k] = pixel_data[d + k]
+             end
+           end
+         end
+
+         if mode == "symmetric"
+           if center
+             raise Error, "`center` padding is not supported when `mode` is set to `symmetric`."
+           end
+           h1 = image_height - 1
+           w1 = image_width - 1
+           padded_image_height.times do |i|
+             a = i * padded_image_width
+             b = Utils.calculate_reflect_offset(i, h1) * image_width
+
+             padded_image_width.times do |j|
+               next if i < image_height && j < image_width # Do not overwrite original image
+               c = (a + j) * image_channels
+               d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels
+
+               # Copy channel-wise
+               image_channels.times do |k|
+                 padded_pixel_data[c + k] = pixel_data[d + k]
+               end
+             end
+           end
+         end
+
+         # Update pixel data and image dimensions
+         pixel_data = padded_pixel_data
+         img_dims = [padded_image_height, padded_image_width, image_channels]
+       end
+       [pixel_data, img_dims]
+     end
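+
+     # With mode: "symmetric", out-of-range coordinates reflect back into the
+     # image: assuming Utils.calculate_reflect_offset implements standard
+     # reflection, a row of width 4 (w1 = 3) maps padded column indices 0..7
+     # to source columns 0, 1, 2, 3, 2, 1, 0, 1.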
+
+     def rescale(pixel_data)
+       pixel_data.length.times do |i|
+         pixel_data[i] *= @rescale_factor
+       end
+     end
+
+     def get_resize_output_image_size(image, size)
+       src_width, src_height = image.size
+
+       if @do_thumbnail
+         # NOTE: custom logic for `Donut` models
+         height = size["height"]
+         width = size["width"]
+         shortest_edge = [height, width].min
+       elsif size.is_a?(Numeric)
+         shortest_edge = size
+         longest_edge = @config["max_size"] || shortest_edge
+       elsif !size.nil?
+         # Extract known properties from `size`
+         shortest_edge = size["shortest_edge"]
+         longest_edge = size["longest_edge"]
+       end
+
+       if !shortest_edge.nil? || !longest_edge.nil?
+         # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
+         # Try to resize so that the shortest edge becomes `shortest_edge` (target)
+         short_resize_factor =
+           if shortest_edge.nil?
+             1 # If `shortest_edge` is not set, don't upscale
+           else
+             [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max
+           end
+
+         new_width = src_width * short_resize_factor
+         new_height = src_height * short_resize_factor
+
+         # The new width and height might be greater than `longest_edge`, so
+         # we downscale again to ensure the largest dimension is `longest_edge`
+         long_resize_factor =
+           if longest_edge.nil?
+             1 # If `longest_edge` is not set, don't downscale
+           else
+             [longest_edge / new_width.to_f, longest_edge / new_height.to_f].min
+           end
+
+         # To avoid certain floating point precision issues, we round to 2 decimal places
+         final_width = (new_width * long_resize_factor).round(2).floor
+         final_height = (new_height * long_resize_factor).round(2).floor
+
+         if !@size_divisibility.nil?
+           raise Todo
+         end
+         [final_width, final_height]
+       elsif !size.nil? && !size["width"].nil? && !size["height"].nil?
+         new_width = size["width"]
+         new_height = size["height"]
+
+         if @config["keep_aspect_ratio"] && @config["ensure_multiple_of"]
+           raise Todo
+         end
+
+         [new_width, new_height]
+       else
+         raise Todo
+       end
+     end
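+
+     # For example, with size = {"shortest_edge" => 224, "longest_edge" => 512}
+     # and a 1000x400 image: short_resize_factor = max(224/1000.0, 224/400.0) = 0.56,
+     # giving 560x224; then long_resize_factor = min(512/560.0, 512/224.0) ~= 0.914,
+     # so the final size is 512x204 (the shortest edge can end up below the target).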
+
+     def resize(image)
+       new_width, new_height = get_resize_output_image_size(image, @size)
+       image.resize(new_width, new_height, resample: @resample)
+     end
+
+     def preprocess(
+       image,
+       do_normalize: nil,
+       do_pad: nil,
+       do_convert_rgb: nil,
+       do_convert_grayscale: nil,
+       do_flip_channel_order: nil
+     )
+       if @do_crop_margin
+         # NOTE: Specific to nougat processors. This is done before resizing,
+         # and can be interpreted as a pre-preprocessing step.
+         image = crop_margin(image)
+       end
+
+       src_width, src_height = image.size # original image size
+
+       # Convert image to RGB if specified in config.
+       if !do_convert_rgb.nil? ? do_convert_rgb : @do_convert_rgb
+         image = image.rgb
+       elsif do_convert_grayscale
+         image = image.grayscale
+       end
+
+       # Resize all images
+       if @do_resize
+         image = resize(image)
+       end
+
+       # Resize the image using thumbnail method.
+       if @do_thumbnail
+         image = thumbnail(image, @size, @resample)
+       end
+
+       if @do_center_crop
+         if @crop_size.is_a?(Integer)
+           crop_width = @crop_size
+           crop_height = @crop_size
+         else
+           crop_width = @crop_size["width"]
+           crop_height = @crop_size["height"]
+         end
+         image = image.center_crop(crop_width, crop_height)
+       end
+
+       reshaped_input_size = [image.height, image.width]
+
+       # NOTE: All pixel-level manipulation (i.e., modifying `pixel_data`)
+       # occurs with data in the hwc format (height, width, channels),
+       # to emulate the behavior of the original Python code (w/ numpy).
+       pixel_data = image.data
+       img_dims = [image.height, image.width, image.channels]
+
+       if @do_rescale
+         rescale(pixel_data)
+       end
+
+       if !do_normalize.nil? ? do_normalize : @do_normalize
+         image_mean = @image_mean
+         if !@image_mean.is_a?(Array)
+           image_mean = Array.new(image.channels, image_mean)
+         end
+
+         image_std = @image_std
+         if !@image_std.is_a?(Array)
+           image_std = Array.new(image.channels, image_std)
+         end
+
+         if image_mean.length != image.channels || image_std.length != image.channels
+           raise Error, "When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{image.channels})."
+         end
+
+         i = 0
+         while i < pixel_data.length
+           image.channels.times do |j|
+             pixel_data[i + j] = (pixel_data[i + j] - image_mean[j]) / image_std[j]
+           end
+           i += image.channels
+         end
+       end
+
+       # do padding after rescaling/normalizing
+       if !do_pad.nil? ? do_pad : @do_pad
+         if @pad_size
+           padded = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)
+           pixel_data, img_dims = padded # Update pixel data and image dimensions
+         elsif @size_divisibility
+           raise Todo
+         end
+       end
+
+       if !do_flip_channel_order.nil? ? do_flip_channel_order : @do_flip_channel_order
+         raise Todo
+       end
+
+       # convert to channel dimension format (hwc -> chw)
+       h, w, c = img_dims
+       pixel_values =
+         c.times.map do |ci|
+           h.times.map do |hi|
+             w.times.map do |wi|
+               index = (hi * w * c) + (wi * c) + ci
+               pixel_data[index]
+             end
+           end
+         end
+
+       {
+         original_size: [src_height, src_width],
+         reshaped_input_size: reshaped_input_size,
+         pixel_values: pixel_values
+       }
+     end
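+
+     # In the flat hwc buffer, pixel (hi, wi, ci) lives at index
+     # (hi * w * c) + (wi * c) + ci, so the nested map above regroups the data
+     # into a [channels, height, width] array, the channels-first layout most
+     # ONNX vision models expect.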
+
+     def call(images, *args)
+       if !images.is_a?(Array)
+         images = [images]
+       end
+
+       image_data = images.map { |x| preprocess(x) }
+
+       # Stack pixel values
+       pixel_values = Utils.stack(image_data.map { |x| x[:pixel_values] }, 0)
+
+       {
+         pixel_values: pixel_values,
+
+         # Original sizes of images
+         original_sizes: image_data.map { |x| x[:original_size] },
+
+         # Reshaped sizes of images, before padding or cropping
+         reshaped_input_sizes: image_data.map { |x| x[:reshaped_input_size] }
+       }
+     end
+   end
+
+   class CLIPFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class DPTFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class ViTFeatureExtractor < ImageFeatureExtractor
+   end
+
+   class OwlViTFeatureExtractor < ImageFeatureExtractor
+     def post_process_object_detection(*args)
+       Utils.post_process_object_detection(*args)
+     end
+   end
+
+   class Swin2SRImageProcessor < ImageFeatureExtractor
+     def pad_image(pixel_data, img_dims, pad_size, **options)
+       # NOTE: In this case, `pad_size` represents the size of the sliding window for the local attention.
+       # In other words, the image is padded so that its width and height are multiples of `pad_size`.
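+       # For example, with pad_size = 8, a 125x100 image is padded to 128x104:
+       # 125 + (8 - 125 % 8) % 8 = 128 and 100 + (8 - 100 % 8) % 8 = 104.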
+       image_height, image_width, _image_channels = img_dims
+
+       super(
+         pixel_data,
+         img_dims,
+         {
+           # NOTE: For Swin2SR models, the original Python implementation adds padding even when the image's width/height is already
+           # a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).
+           # For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`.
+           width: image_width + (pad_size - image_width % pad_size) % pad_size,
+           height: image_height + (pad_size - image_height % pad_size) % pad_size
+         },
+         mode: "symmetric",
+         center: false,
+         constant_values: -1,
+         **options
+       )
+     end
+   end
+
+   class DonutFeatureExtractor < ImageFeatureExtractor
+     def pad_image(pixel_data, img_dims, pad_size, **options)
+       _image_height, _image_width, image_channels = img_dims
+
+       image_mean = @image_mean
+       if !image_mean.is_a?(Array)
+         image_mean = Array.new(image_channels, image_mean)
+       end
+
+       image_std = @image_std
+       if !image_std.is_a?(Array)
+         image_std = Array.new(image_channels, image_std)
+       end
+
+       constant_values = image_mean.map.with_index { |x, i| -x / image_std[i] }
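+       # e.g. with mean 0.5 and std 0.5, padded pixels get the value -1.0, which
+       # is exactly what a black (0) input pixel becomes after normalization:
+       # (0 - 0.5) / 0.5 = -1.0 (padding runs after rescaling/normalizing here).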
+
+       super(
+         pixel_data,
+         img_dims,
+         pad_size,
+         center: true,
+         # In the reference Python implementation, padding (with zeros) happens before
+         # normalization, while here padding runs after normalization, so we pad with
+         # `-mean/std` (the normalized value of zero) to keep the behaviour identical.
+         # For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
+         constant_values: constant_values,
+         **options
+       )
+     end
+   end
+
+   class DetrFeatureExtractor < ImageFeatureExtractor
+     def call(images)
+       result = super(images)
+
+       # TODO: support differently-sized images; for now, assume all images are the same size
+       # TODO: support different mask sizes (not just 64x64)
+       # Currently, just fill pixel mask with 1s
+       mask_size = [result[:pixel_values].size, 64, 64]
+       pixel_mask =
+         mask_size[0].times.map do
+           mask_size[1].times.map do
+             mask_size[2].times.map do
+               1
+             end
+           end
+         end
+
+       result.merge(pixel_mask: pixel_mask)
+     end
+
+     def post_process_object_detection(*args)
+       Utils.post_process_object_detection(*args)
+     end
+
+     def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)
+       mask_probs_item = []
+       pred_scores_item = []
+       pred_labels_item = []
+
+       class_logits.size.times do |j|
+         cls = class_logits[j]
+         mask = mask_logits[j]
+
+         pred_label = Utils.max(cls)[1]
+         if pred_label == num_labels
+           # This is the background class, so we ignore it
+           next
+         end
+
+         scores = Utils.softmax(cls)
+         pred_score = scores[pred_label]
+         if pred_score > object_mask_threshold
+           mask_probs_item << mask
+           pred_scores_item << pred_score
+           pred_labels_item << pred_label
+         end
+       end
+
+       [mask_probs_item, pred_scores_item, pred_labels_item]
+     end
+
+     def check_segment_validity(
+       mask_labels,
+       mask_probs,
+       k,
+       mask_threshold = 0.5,
+       overlap_mask_area_threshold = 0.8
+     )
+       # mask_k is a 1D array of indices, indicating where the mask is equal to k
+       mask_k = []
+       mask_k_area = 0
+       original_area = 0
+
+       mask_probs_k_data = mask_probs[k].flatten
+
+       # Compute the area of all the stuff in query k
+       mask_labels.length.times do |i|
+         if mask_labels[i] == k
+           mask_k << i
+           mask_k_area += 1
+         end
+
+         if mask_probs_k_data[i] >= mask_threshold
+           original_area += 1
+         end
+       end
+       mask_exists = mask_k_area > 0 && original_area > 0
+
+       # Eliminate disconnected tiny segments
+       if mask_exists
+         # Perform additional check
+         area_ratio = mask_k_area / original_area.to_f
+         mask_exists = area_ratio > overlap_mask_area_threshold
+       end
+
+       [mask_exists, mask_k]
+     end
+
+     def compute_segments(
+       mask_probs,
+       pred_scores,
+       pred_labels,
+       mask_threshold,
+       overlap_mask_area_threshold,
+       label_ids_to_fuse = nil,
+       target_size = nil
+     )
+       height, width = target_size || Utils.dims(mask_probs[0])
+
+       segmentation = Array.new(height * width)
+       segments = []
+
+       # 1. If target_size is not nil, we need to resize the masks to the target size
+       if !target_size.nil?
+         # resize the masks to the target size
+         mask_probs.length.times do |i|
+           mask_probs[i] = Utils.interpolate(mask_probs[i], target_size, "bilinear", false)
+         end
+       end
+
+       # 2. Weigh each mask by its prediction score
+       # NOTE: `mask_probs` is updated in-place
+       #
+       # Temporary storage for the best label/scores for each pixel ([height, width]):
+       mask_labels = Array.new(mask_probs[0].flatten.length)
+       best_scores = Array.new(mask_probs[0].flatten.length, 0)
+
+       mask_probs.length.times do |i|
+         score = pred_scores[i]
+
+         mask_probs_i_data = mask_probs[i].flatten
+         mask_probs_i_dims = Utils.dims(mask_probs[i])
+
+         mask_probs_i_data.length.times do |j|
+           mask_probs_i_data[j] *= score
+           if mask_probs_i_data[j] > best_scores[j]
+             mask_labels[j] = i
+             best_scores[j] = mask_probs_i_data[j]
+           end
+         end
+
+         mask_probs[i] = Utils.reshape(mask_probs_i_data, mask_probs_i_dims)
+       end
+
+       current_segment_id = 0
+
+       # stuff_memory_list = {}
+       pred_labels.length.times do |k|
+         pred_class = pred_labels[k]
+
+         # TODO: add `should_fuse`
+         # should_fuse = label_ids_to_fuse.include?(pred_class)
+
+         # Check if mask exists and is large enough to be a segment
+         mask_exists, mask_k = check_segment_validity(
+           mask_labels,
+           mask_probs,
+           k,
+           mask_threshold,
+           overlap_mask_area_threshold
+         )
+
+         if !mask_exists
+           # Nothing to see here
+           next
+         end
+
+         current_segment_id += 1
+
+         # Add current object segment to final segmentation map
+         mask_k.each do |index|
+           segmentation[index] = current_segment_id
+         end
+
+         segments << {
+           id: current_segment_id,
+           label_id: pred_class,
+           score: pred_scores[k]
+         }
+       end
+
+       segmentation = Utils.reshape(segmentation, [height, width])
+
+       [segmentation, segments]
+     end
+
+     def post_process_panoptic_segmentation(
+       outputs,
+       threshold: 0.5,
+       mask_threshold: 0.5,
+       overlap_mask_area_threshold: 0.8,
+       label_ids_to_fuse: nil,
+       target_sizes: nil
+     )
+       if label_ids_to_fuse.nil?
+         warn "`label_ids_to_fuse` unset. No instance will be fused."
+         label_ids_to_fuse = Set.new
+       end
+
+       class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]
+       masks_queries_logits = outputs[:pred_masks] # [batch_size, num_queries, height, width]
+
+       mask_probs = Utils.sigmoid(masks_queries_logits) # [batch_size, num_queries, height, width]
+
+       batch_size, _num_queries, num_labels = class_queries_logits.size, class_queries_logits[0].size, class_queries_logits[0][0].size
+       num_labels -= 1 # Remove last class (background)
+
+       if !target_sizes.nil? && target_sizes.length != batch_size
+         raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+       end
+
+       to_return = []
+       batch_size.times do |i|
+         target_size = !target_sizes.nil? ? target_sizes[i] : nil
+
+         class_logits = class_queries_logits[i]
+         mask_logits = mask_probs[i]
+
+         mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels)
+
+         if pred_labels_item.length == 0
+           raise Todo
+         end
+
+         # Get segmentation map and segment information of batch item
+         segmentation, segments = compute_segments(
+           mask_probs_item,
+           pred_scores_item,
+           pred_labels_item,
+           mask_threshold,
+           overlap_mask_area_threshold,
+           label_ids_to_fuse,
+           target_size
+         )
+
+         to_return << {
+           segmentation: segmentation,
+           segments_info: segments
+         }
+       end
+
+       to_return
+     end
+   end
+
+   module Utils
+     def self.center_to_corners_format(v)
+       center_x, center_y, width, height = v
+       [
+         center_x - width / 2.0,
+         center_y - height / 2.0,
+         center_x + width / 2.0,
+         center_y + height / 2.0
+       ]
+     end
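+
+     # e.g. center_to_corners_format([0.5, 0.5, 0.2, 0.4]) # => [0.4, 0.3, 0.6, 0.7]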
+
+     def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
+       out_logits = outputs[:logits]
+       out_bbox = outputs[:pred_boxes]
+       batch_size, num_boxes, num_classes = out_logits.size, out_logits[0].size, out_logits[0][0].size
+
+       if !target_sizes.nil? && target_sizes.length != batch_size
+         raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+       end
+       to_return = []
+       batch_size.times do |i|
+         target_size = !target_sizes.nil? ? target_sizes[i] : nil
+         info = {
+           boxes: [],
+           classes: [],
+           scores: []
+         }
+         logits = out_logits[i]
+         bbox = out_bbox[i]
+
+         num_boxes.times do |j|
+           logit = logits[j]
+
+           indices = []
+           if is_zero_shot
+             # Get indices of classes with high enough probability
+             probs = Utils.sigmoid(logit)
+             probs.length.times do |k|
+               if probs[k] > threshold
+                 indices << k
+               end
+             end
+           else
+             # Get most probable class
+             max_index = Utils.max(logit)[1]
+
+             if max_index == num_classes - 1
+               # This is the background class, skip it
+               next
+             end
+             indices << max_index
+
+             # Compute softmax over classes
+             probs = Utils.softmax(logit)
+           end
+
+           indices.each do |index|
+             box = bbox[j]
+
+             # convert to [x0, y0, x1, y1] format
+             box = center_to_corners_format(box)
+             if !target_size.nil?
+               box = box.map.with_index { |x, i| x * target_size[(i + 1) % 2] }
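+               # `target_size` is [height, width] and the box is [x0, y0, x1, y1],
+               # so x coordinates (even indices) are scaled by the width and
+               # y coordinates (odd indices) by the height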
+             end
+
+             info[:boxes] << box
+             info[:classes] << index
+             info[:scores] << probs[index]
+           end
+         end
+         to_return << info
+       end
+       to_return
+     end
+   end
+
+   class WhisperFeatureExtractor < FeatureExtractor
+     def initialize(config)
+       super(config)
+
+       raise Todo
+     end
+
+     def _extract_fbank_features(waveform)
+       raise Todo
+     end
+
+     def call(audio)
+       raise Todo
+     end
+   end
+
+   class Wav2Vec2FeatureExtractor < FeatureExtractor
+     def _zero_mean_unit_var_norm(input_values)
+       sum = input_values.sum
+       mean = sum / input_values.length.to_f
+       variance = input_values.sum { |b| (b - mean) ** 2 } / input_values.length.to_f
+       input_values.map { |x| (x - mean) / Math.sqrt(variance + 1e-7) }
+     end
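+
+     # e.g. _zero_mean_unit_var_norm([1.0, 2.0, 3.0]): mean = 2.0 and
+     # variance = 2/3, so the result is roughly [-1.22, 0.0, 1.22]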
+
+     def call(audio)
+       # TODO
+       # validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor')
+
+       input_values = audio
+
+       # zero-mean and unit-variance normalization
+       if @config["do_normalize"]
+         input_values = _zero_mean_unit_var_norm(input_values)
+       end
+
+       # TODO: allow user to pass in attention mask
+       {
+         input_values: [input_values],
+         attention_mask: [Array.new(input_values.length, 1)]
+       }
+     end
+   end
+
+   class ClapFeatureExtractor < FeatureExtractor
+     def initialize(config)
+       super(config)
+
+       # TODO
+     end
+
+     def call(audio, max_length: nil)
+       raise Todo
+     end
+   end
+
+   class Processor
+     attr_reader :feature_extractor
+
+     def initialize(feature_extractor)
+       @feature_extractor = feature_extractor
+     end
+
+     def call(input, *args)
+       @feature_extractor.(input, *args)
+     end
+   end
+
+   class AutoProcessor
+     FEATURE_EXTRACTOR_CLASS_MAPPING = {
+       "ViTFeatureExtractor" => ViTFeatureExtractor,
+       "OwlViTFeatureExtractor" => OwlViTFeatureExtractor,
+       "CLIPFeatureExtractor" => CLIPFeatureExtractor,
+       "DPTFeatureExtractor" => DPTFeatureExtractor,
+       "DetrFeatureExtractor" => DetrFeatureExtractor,
+       "Swin2SRImageProcessor" => Swin2SRImageProcessor,
+       "DonutFeatureExtractor" => DonutFeatureExtractor,
+       "WhisperFeatureExtractor" => WhisperFeatureExtractor,
+       "Wav2Vec2FeatureExtractor" => Wav2Vec2FeatureExtractor,
+       "ClapFeatureExtractor" => ClapFeatureExtractor
+     }
+
+     PROCESSOR_CLASS_MAPPING = {}
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       **kwargs
+     )
+       preprocessor_config = config || Utils::Hub.get_model_json(pretrained_model_name_or_path, "preprocessor_config.json", true,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:
+       )
+
+       # Determine feature extractor class
+       # TODO: Ensure backwards compatibility with old configs
+       key = preprocessor_config["feature_extractor_type"] || preprocessor_config["image_processor_type"]
+       feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING[key]
+
+       if !feature_extractor_class
+         if preprocessor_config["size"]
+           # Assume ImageFeatureExtractor
+           warn "Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config."
+           feature_extractor_class = ImageFeatureExtractor
+         else
+           raise Error, "Unknown Feature Extractor type: #{key}"
+         end
+       end
+
+       # If no associated processor class, use default
+       processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config["processor_class"]] || Processor
+
+       # Instantiate processor and feature extractor
+       feature_extractor = feature_extractor_class.new(preprocessor_config)
+       processor_class.new(feature_extractor)
+     end
+   end
+ end
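+
+ # A minimal usage sketch (not part of this file): assumes the model repo ships
+ # a preprocessor_config.json and `image` is an already-loaded image object the
+ # extractor understands.
+ #
+ #   processor = Informers::AutoProcessor.from_pretrained("google/vit-base-patch16-224")
+ #   features = processor.(image)
+ #   features[:pixel_values] # => [batch, channels, height, width] nested arrays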