informers 1.0.3 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,856 @@
1
+ module Informers
2
# Base class for every feature extractor / image processor in this file.
# It only stores the preprocessor configuration it was built with; subclasses
# read their individual settings out of it.
class FeatureExtractor
  # @return [Hash] the preprocessor configuration given to the constructor
  attr_reader(:config)

  # @param config [Hash] preprocessor configuration (AutoProcessor passes the
  #   resolved preprocessor config here)
  def initialize(config)
    super()
    @config = config
  end
end
10
+
11
+ class ImageFeatureExtractor < FeatureExtractor
12
# Reads all image-preprocessing options from the config.
#
# @param config [Hash] preprocessor configuration (string keys)
def initialize(config)
  super(config)

  @image_mean = @config["image_mean"] || @config["mean"]
  @image_std = @config["image_std"] || @config["std"]

  @resample = @config["resample"] || 2 # 2 => bilinear
  @do_rescale = @config.fetch("do_rescale", true)
  @rescale_factor = @config["rescale_factor"] || (1 / 255.0)
  @do_normalize = @config["do_normalize"]

  @do_resize = @config["do_resize"]
  @do_thumbnail = @config["do_thumbnail"]
  @size = @config["size"]
  @size_divisibility = @config["size_divisibility"] || @config["size_divisor"]

  @do_center_crop = @config["do_center_crop"]
  @crop_size = @config["crop_size"]
  @do_convert_rgb = @config.fetch("do_convert_rgb", true)
  @do_crop_margin = @config["do_crop_margin"]

  @pad_size = @config["pad_size"]
  @do_pad = @config["do_pad"]

  # Should pad, but no pad size specified: infer the pad size from an explicit
  # width/height resize size. `size` may also be a bare Numeric (handled by
  # get_resize_output_image_size), which has no width/height to borrow, and
  # indexing an Integer with a String raises TypeError, so only a Hash qualifies.
  if @do_pad && !@pad_size && @size.is_a?(Hash) && !@size["width"].nil? && !@size["height"].nil?
    @pad_size = @size
  end

  @do_flip_channel_order = @config["do_flip_channel_order"] || false
end
44
+
45
# Shrinks an image to fit inside size["width"] x size["height"], keeping its
# aspect ratio. Never upscales.
#
# @param image [#width, #height, #resize] the image to shrink
# @param size [Hash] target box with "height" and "width" keys
# @param resample [Integer] resampling filter forwarded to image.resize
# @return the same image when no shrinking is needed, otherwise the resized image
def thumbnail(image, size, resample = 2)
  src_h = image.height
  src_w = image.width

  # Clamp the target box to the source dimensions (no upscaling).
  target_h = [src_h, size["height"]].min
  target_w = [src_w, size["width"]].min

  return image if target_h == src_h && target_w == src_w

  # Recompute the shorter side from the longer one so the aspect ratio is kept.
  case src_h <=> src_w
  when 1 then target_w = (src_w * target_h / src_h).floor
  when -1 then target_h = (src_h * target_w / src_w).floor
  end

  image.resize(target_w, target_h, resample: resample)
end
66
+
67
# Pads an interleaved (hwc) pixel buffer out to `pad_size`.
#
# @param pixel_data [Array<Numeric>] pixel values in height, width, channel order
# @param img_dims [Array<Integer>] [height, width, channels] of pixel_data
# @param pad_size [Numeric, Hash] target size; a Numeric pads to a square, a Hash
#   supplies width/height under symbol or string keys
# @param mode [String] "constant" leaves the fill values in the padding;
#   "symmetric" overwrites the padding with reflected image pixels
# @param center [Boolean] place the image in the middle instead of the top-left corner
# @param constant_values [Numeric, Array<Numeric>] fill value, or one value per channel
# @return [Array] [pixel_data, img_dims], padded when the target size differs
# @raise [Error] when `center` is combined with mode "symmetric"
def pad_image(
  pixel_data,
  img_dims,
  pad_size,
  mode: "constant",
  center: false,
  constant_values: 0
)
  image_height, image_width, image_channels = img_dims

  if pad_size.is_a?(Numeric)
    padded_image_width = padded_image_height = pad_size
  else
    padded_image_width = pad_size[:width] || pad_size["width"]
    padded_image_height = pad_size[:height] || pad_size["height"]
  end

  # Only add padding if there is a difference in size
  return [pixel_data, img_dims] if padded_image_width == image_width && padded_image_height == image_height

  padded_length = padded_image_width * padded_image_height * image_channels
  padded_pixel_data =
    if constant_values.is_a?(Array)
      # Cycle through the per-channel values across the interleaved buffer
      Array.new(padded_length) { |i| constant_values[i % image_channels] }
    else
      # Always fill, including for the default 0: an unfilled Array.new would
      # leave nil in every padded slot
      Array.new(padded_length, constant_values)
    end

  left, top =
    if center
      [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]
    else
      [0, 0]
    end

  # Copy the original image into the padded image
  image_height.times do |i|
    a = (i + top) * padded_image_width
    b = i * image_width
    image_width.times do |j|
      c = (a + j + left) * image_channels
      d = (b + j) * image_channels
      image_channels.times do |k|
        padded_pixel_data[c + k] = pixel_data[d + k]
      end
    end
  end

  if mode == "symmetric"
    if center
      raise Error, "`center` padding is not supported when `mode` is set to `symmetric`."
    end
    h1 = image_height - 1
    w1 = image_width - 1
    padded_image_height.times do |i|
      a = i * padded_image_width
      b = Utils.calculate_reflect_offset(i, h1) * image_width

      padded_image_width.times do |j|
        next if i < image_height && j < image_width # Do not overwrite original image
        c = (a + j) * image_channels
        d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels

        # Copy channel-wise
        image_channels.times do |k|
          padded_pixel_data[c + k] = pixel_data[d + k]
        end
      end
    end
  end

  [padded_pixel_data, [padded_image_height, padded_image_width, image_channels]]
end
146
+
147
# Multiplies every element of pixel_data by @rescale_factor, in place.
# The return value (the element count, from Integer#times) is not used by callers.
#
# @param pixel_data [Array<Numeric>] mutated in place
def rescale(pixel_data)
  factor = @rescale_factor
  count = pixel_data.length
  count.times { |idx| pixel_data[idx] *= factor }
end
152
+
153
# Computes the [width, height] an image should be resized to.
#
# Edge constraints come from, in order of precedence:
# - thumbnail mode (Donut): shortest edge = min(size["height"], size["width"]);
# - a Numeric size: shortest edge = size, longest edge = config "max_size" (or size);
# - a Hash size: its "shortest_edge" / "longest_edge" entries.
# When either constraint exists the aspect ratio is preserved; otherwise an
# explicit size["width"] / size["height"] is returned unchanged.
#
# @param image [#size] object whose `size` returns [width, height]
# @param size [Numeric, Hash, nil] the configured target size
# @return [Array(Integer, Integer)] [width, height]
# @raise [Todo] for size_divisibility, keep_aspect_ratio + ensure_multiple_of, or an unusable size
def get_resize_output_image_size(image, size)
  src_width, src_height = image.size

  shortest_edge = nil
  longest_edge = nil
  if @do_thumbnail
    # NOTE: custom logic for `Donut` models
    shortest_edge = [size["height"], size["width"]].min
  elsif size.is_a?(Numeric)
    shortest_edge = size
    longest_edge = @config["max_size"] || shortest_edge
  elsif !size.nil?
    # Extract known properties from `size`
    shortest_edge = size["shortest_edge"]
    longest_edge = size["longest_edge"]
  end

  unless shortest_edge.nil? && longest_edge.nil?
    # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
    # First scale so the shorter side reaches `shortest_edge` (factor 1 = don't upscale when unset)...
    grow =
      if shortest_edge.nil?
        1
      else
        [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max
      end
    scaled_width = src_width * grow
    scaled_height = src_height * grow

    # ...then shrink again so neither side exceeds `longest_edge` (factor 1 = don't downscale when unset)
    shrink =
      if longest_edge.nil?
        1
      else
        [longest_edge / scaled_width.to_f, longest_edge / scaled_height.to_f].min
      end

    # Round to 2 decimals before flooring to absorb floating point noise
    final_width = (scaled_width * shrink).round(2).floor
    final_height = (scaled_height * shrink).round(2).floor

    raise Todo unless @size_divisibility.nil?

    return [final_width, final_height]
  end

  raise Todo if size.nil? || size["width"].nil? || size["height"].nil?
  raise Todo if @config["keep_aspect_ratio"] && @config["ensure_multiple_of"]

  [size["width"], size["height"]]
end
213
+
214
# Resizes the image to the size computed from the configured @size,
# using the configured @resample filter.
def resize(image)
  target_dims = get_resize_output_image_size(image, @size)
  image.resize(*target_dims, resample: @resample)
end
218
+
219
# Runs the configured preprocessing pipeline on a single image: optional margin
# crop, RGB/grayscale conversion, resize, thumbnail, center crop, rescale,
# normalize, pad and finally an hwc -> chw transpose.
#
# Each keyword option, when non-nil, overrides the matching config setting for
# this call only.
#
# @param image the input image (must respond to size, rgb/grayscale, height,
#   width, channels, data, and center_crop when center cropping is enabled)
# @return [Hash] :original_size ([height, width] after the optional margin crop),
#   :reshaped_input_size ([height, width] before padding), :pixel_values (chw nested arrays)
# @raise [Error] when mean/std arrays do not match the channel count
# @raise [Todo] for unsupported padding by size_divisibility or channel flipping
def preprocess(
  image,
  do_normalize: nil,
  do_pad: nil,
  do_convert_rgb: nil,
  do_convert_grayscale: nil,
  do_flip_channel_order: nil
)
  if @do_crop_margin
    # NOTE: Specific to nougat processors. This is done before resizing,
    # and can be interpreted as a pre-preprocessing step.
    image = crop_margin(image)
  end

  src_width, src_height = image.size # original image size

  # Convert image to RGB if specified (per-call option overrides the config).
  if do_convert_rgb.nil? ? @do_convert_rgb : do_convert_rgb
    image = image.rgb
  elsif do_convert_grayscale
    image = image.grayscale
  end

  # Resize all images
  image = resize(image) if @do_resize

  # Resize the image using thumbnail method.
  image = thumbnail(image, @size, @resample) if @do_thumbnail

  if @do_center_crop
    if @crop_size.is_a?(Integer)
      crop_width = crop_height = @crop_size
    else
      crop_width = @crop_size["width"]
      crop_height = @crop_size["height"]
    end
    image = image.center_crop(crop_width, crop_height)
  end

  reshaped_input_size = [image.height, image.width]

  # NOTE: All pixel-level manipulation (i.e., modifying `pixel_data`)
  # occurs with data in the hwc format (height, width, channels),
  # to emulate the behavior of the original Python code (w/ numpy).
  pixel_data = image.data
  img_dims = [image.height, image.width, image.channels]

  rescale(pixel_data) if @do_rescale

  if do_normalize.nil? ? @do_normalize : do_normalize
    channels = image.channels

    # A scalar mean/std applies to every channel. Ruby has no JS-style
    # `new Array(...)`; Array.new builds the per-channel copies.
    image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(channels, @image_mean)
    image_std = @image_std.is_a?(Array) ? @image_std : Array.new(channels, @image_std)

    if image_mean.length != channels || image_std.length != channels
      raise Error, "When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{image.channels})."
    end

    # In hwc order the channel of element idx is idx % channels.
    pixel_data.length.times do |idx|
      ch = idx % channels
      pixel_data[idx] = (pixel_data[idx] - image_mean[ch]) / image_std[ch]
    end
  end

  # do padding after rescaling/normalizing
  if do_pad.nil? ? @do_pad : do_pad
    if @pad_size
      pixel_data, img_dims = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)
    elsif @size_divisibility
      raise Todo
    end
  end

  if do_flip_channel_order.nil? ? @do_flip_channel_order : do_flip_channel_order
    raise Todo
  end

  # convert to channel dimension format (hwc -> chw)
  h, w, c = img_dims
  pixel_values =
    c.times.map do |ci|
      h.times.map do |hi|
        w.times.map do |wi|
          pixel_data[(hi * w * c) + (wi * c) + ci]
        end
      end
    end

  {
    original_size: [src_height, src_width],
    reshaped_input_size: reshaped_input_size,
    pixel_values: pixel_values
  }
end
331
+
332
# Preprocesses one image or an array of images and batches the results.
#
# @param images a single image or an Array of images
# @param args ignored; accepted so Processor#call can forward extra arguments
# @return [Hash] :pixel_values (stacked via Utils.stack along dim 0),
#   :original_sizes and :reshaped_input_sizes (one entry per image)
def call(images, *args)
  batch = images.is_a?(Array) ? images : [images]
  processed = batch.map { |img| preprocess(img) }

  {
    # Stack pixel values (dim 0)
    pixel_values: Utils.stack(processed.map { |item| item[:pixel_values] }, 0),
    # Original sizes of images
    original_sizes: processed.map { |item| item[:original_size] },
    # Reshaped sizes of images, before padding or cropping
    reshaped_input_sizes: processed.map { |item| item[:reshaped_input_size] }
  }
end
352
+ end
353
+
354
# Image preprocessor for CLIP checkpoints (AutoProcessor key "CLIPFeatureExtractor").
# All behavior is inherited from ImageFeatureExtractor.
class CLIPFeatureExtractor < ImageFeatureExtractor; end
356
+
357
# Image preprocessor for DPT checkpoints (AutoProcessor key "DPTFeatureExtractor").
# All behavior is inherited from ImageFeatureExtractor.
class DPTFeatureExtractor < ImageFeatureExtractor; end
359
+
360
# Image preprocessor for ViT checkpoints (AutoProcessor key "ViTFeatureExtractor").
# All behavior is inherited from ImageFeatureExtractor.
class ViTFeatureExtractor < ImageFeatureExtractor; end
362
+
363
# Image preprocessor for OWL-ViT checkpoints. Preprocessing is inherited from
# ImageFeatureExtractor; box post-processing uses the shared Utils helper.
class OwlViTFeatureExtractor < ImageFeatureExtractor
  # Forwards every argument unchanged to Utils.post_process_object_detection.
  def post_process_object_detection(*args) = Utils.post_process_object_detection(*args)
end
368
+
369
# Image processor for Swin2SR super-resolution checkpoints.
class Swin2SRImageProcessor < ImageFeatureExtractor
  # Pads so that both height and width become multiples of `pad_size`, using
  # symmetric (reflective) padding anchored at the top-left corner.
  #
  # NOTE: In this case, `pad_size` represents the size of the sliding window for the local attention.
  #
  # @param pixel_data [Array<Numeric>] hwc pixel buffer
  # @param img_dims [Array<Integer>] [height, width, channels]
  # @param pad_size [Integer] window size both sides are rounded up to
  # @param options [Hash] extra keywords forwarded to ImageFeatureExtractor#pad_image
  # @return [Array] [padded_pixel_data, padded_img_dims]
  def pad_image(pixel_data, img_dims, pad_size, **options)
    height, width, = img_dims

    # NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
    # a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).
    # For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`.
    round_up = ->(dim) { dim + (pad_size - dim % pad_size) % pad_size }
    target = { width: round_up.(width), height: round_up.(height) }

    super(pixel_data, img_dims, target, mode: "symmetric", center: false, constant_values: -1, **options)
  end
end
392
+
393
# Image preprocessor for Donut checkpoints.
class DonutFeatureExtractor < ImageFeatureExtractor
  # Pads with the image centered, filling the border with the normalized value
  # of a zero pixel.
  #
  # preprocess pads after rescaling/normalizing, while the reference
  # implementation normalizes after padding (see
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451).
  # Using -mean / std per channel, i.e. what a 0-valued pixel becomes once
  # normalized, keeps the padded border consistent with that behaviour.
  #
  # @param pixel_data [Array<Numeric>] hwc pixel buffer
  # @param img_dims [Array<Integer>] [height, width, channels]
  # @param pad_size [Numeric, Hash] target size forwarded to ImageFeatureExtractor#pad_image
  # @param options [Hash] extra keywords forwarded to ImageFeatureExtractor#pad_image
  # @return [Array] [padded_pixel_data, padded_img_dims]
  def pad_image(pixel_data, img_dims, pad_size, **options)
    image_channels = img_dims[2]

    # A scalar mean/std applies to every channel. (`new Array(...)` is JS syntax
    # and raises NoMethodError in Ruby; Array.new is the Ruby constructor.)
    image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(image_channels, @image_mean)
    image_std = @image_std.is_a?(Array) ? @image_std : Array.new(image_channels, @image_std)

    constant_values = image_mean.each_with_index.map { |mean, i| -mean / image_std[i] }

    super(pixel_data, img_dims, pad_size, center: true, constant_values:, **options)
  end
end
421
+
422
+ class DetrFeatureExtractor < ImageFeatureExtractor
423
# Runs the standard image pipeline, then attaches a `pixel_mask` of 1s.
#
# @param images a single image or an Array of images
# @param args extra positional arguments forwarded to ImageFeatureExtractor#call.
#   They are accepted so Processor#call, which forwards `*args`, works with this
#   extractor the same way it works with the others.
# @return [Hash] the parent result plus :pixel_mask
#   ([batch_size][64][64], all 1)
def call(images, *args)
  result = super(images, *args)

  # TODO support differently-sized images, for now assume all images are the same size.
  # TODO support different mask sizes (not just 64x64)
  # Currently, just fill pixel mask with 1s
  batch_size = result[:pixel_values].size
  pixel_mask = Array.new(batch_size) { Array.new(64) { Array.new(64, 1) } }

  result.merge(pixel_mask: pixel_mask)
end
441
+
442
# Forwards every argument unchanged to Utils.post_process_object_detection.
def post_process_object_detection(*args) = Utils.post_process_object_detection(*args)
445
+
446
# Keeps only the queries whose predicted class is a real object (not the
# trailing background class) with a softmax score above the threshold.
#
# @param class_logits [Array<Array<Float>>] per-query class logits
# @param mask_logits [Array] per-query masks, aligned with class_logits
# @param object_mask_threshold [Float] minimum score to keep a query
# @param num_labels [Integer] index of the background class
# @return [Array] [masks, scores, labels] of the kept queries
def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)
  kept_masks = []
  kept_scores = []
  kept_labels = []

  class_logits.size.times do |query|
    logits = class_logits[query]
    label = Utils.max(logits)[1]
    # Is the background, so we ignore it
    next if label == num_labels

    score = Utils.softmax(logits)[label]
    next unless score > object_mask_threshold

    kept_masks << mask_logits[query]
    kept_scores << score
    kept_labels << label
  end

  [kept_masks, kept_scores, kept_labels]
end
472
+
473
# Decides whether query k yields a real segment.
#
# @param mask_labels [Array<Integer>] per-pixel index of the winning query
# @param mask_probs [Array] per-query probability maps (flattened here)
# @param k [Integer] query index to check
# @param mask_threshold [Float] probability at which a pixel belongs to query k's own mask
# @param overlap_mask_area_threshold [Float] minimum ratio of won pixels to own-mask pixels
# @return [Array] [mask_exists, mask_k]; mask_k lists the pixel indices labelled k
def check_segment_validity(
  mask_labels,
  mask_probs,
  k,
  mask_threshold = 0.5,
  overlap_mask_area_threshold = 0.8
)
  mask_probs_k_data = mask_probs[k].flatten

  # mask_k is a 1D array of indices, indicating where the mask is equal to k
  mask_k = mask_labels.each_index.select { |i| mask_labels[i] == k }
  mask_k_area = mask_k.length

  # Area of all the stuff in query k (pixels whose own probability clears the threshold)
  original_area = mask_labels.length.times.count { |i| mask_probs_k_data[i] >= mask_threshold }

  mask_exists = mask_k_area > 0 && original_area > 0

  # Eliminate disconnected tiny segments. Use float division: Integer division
  # would truncate every ratio below 1 to 0 and reject the segment.
  if mask_exists
    mask_exists = mask_k_area.fdiv(original_area) > overlap_mask_area_threshold
  end

  [mask_exists, mask_k]
end
509
+
510
# Builds a panoptic segmentation map from per-query mask probabilities.
#
# Each pixel is assigned to the query with the highest score-weighted
# probability. Queries that pass check_segment_validity become segments with
# ids 1, 2, ... in pred_labels order; 0 marks pixels that belong to no segment.
#
# @param mask_probs [Array] one probability map per kept query (assumed
#   [height, width] — TODO confirm); resized and re-weighted in place
# @param pred_scores [Array<Float>] score per query
# @param pred_labels [Array<Integer>] class label per query
# @param mask_threshold [Float] forwarded to check_segment_validity
# @param overlap_mask_area_threshold [Float] forwarded to check_segment_validity
# @param label_ids_to_fuse [Set, nil] currently unused (fusing is not implemented)
# @param target_size [Array(Integer, Integer), nil] [height, width] to resize the masks to
# @return [Array] [segmentation (height x width nested arrays), segments (Array<Hash>)]
def compute_segments(
  mask_probs,
  pred_scores,
  pred_labels,
  mask_threshold,
  overlap_mask_area_threshold,
  label_ids_to_fuse = nil,
  target_size = nil
)
  height, width = target_size || Utils.dims(mask_probs[0])

  # 0 = "no segment" (segment ids start at 1); an unfilled Array.new would leave nil
  segmentation = Array.new(height * width, 0)
  segments = []

  # 1. If target_size is not null, we need to resize the masks to the target size
  unless target_size.nil?
    mask_probs.length.times do |i|
      mask_probs[i] = Utils.interpolate(mask_probs[i], target_size, "bilinear", false)
    end
  end

  # 2. Weigh each mask by its prediction score (NOTE: `mask_probs` is updated in-place)
  # and track, per pixel, the query with the best weighted score. Labels start at 0
  # so that, like an argmax over queries, a pixel no query scores above 0 falls to query 0.
  num_pixels = mask_probs[0].flatten.length
  mask_labels = Array.new(num_pixels, 0)
  best_scores = Array.new(num_pixels, 0)

  mask_probs.length.times do |i|
    score = pred_scores[i]

    mask_probs_i_data = mask_probs[i].flatten
    mask_probs_i_dims = Utils.dims(mask_probs[i])

    mask_probs_i_data.length.times do |j|
      mask_probs_i_data[j] *= score
      if mask_probs_i_data[j] > best_scores[j]
        mask_labels[j] = i
        best_scores[j] = mask_probs_i_data[j]
      end
    end

    mask_probs[i] = Utils.reshape(mask_probs_i_data, mask_probs_i_dims)
  end

  current_segment_id = 0

  pred_labels.length.times do |k|
    pred_class = pred_labels[k]

    # TODO add `should_fuse`
    # should_fuse = label_ids_to_fuse.include?(pred_class)

    # Check if mask exists and large enough to be a segment
    mask_exists, mask_k = check_segment_validity(
      mask_labels,
      mask_probs,
      k,
      mask_threshold,
      overlap_mask_area_threshold
    )

    # Nothing to see here
    next unless mask_exists

    current_segment_id += 1

    # Add current object segment to final segmentation map
    mask_k.each { |index| segmentation[index] = current_segment_id }

    segments << {
      id: current_segment_id,
      label_id: pred_class,
      score: pred_scores[k]
    }
  end

  segmentation = Utils.reshape(segmentation, [height, width])

  [segmentation, segments]
end
597
+
598
# Converts DETR panoptic outputs into a segmentation map plus segment info for
# each batch item.
#
# @param outputs [Hash] :logits [batch_size, num_queries, num_classes+1] and
#   :pred_masks [batch_size, num_queries, height, width]
# @param threshold [Float] minimum class score for a query to be kept
# @param mask_threshold [Float] forwarded to compute_segments
# @param overlap_mask_area_threshold [Float] forwarded to compute_segments
# @param label_ids_to_fuse [Set, nil] defaults to an empty Set (with a warning)
# @param target_sizes [Array, nil] optional [height, width] per batch item
# @return [Array<Hash>] one { segmentation:, segments_info: } per batch item
# @raise [Error] when target_sizes does not match the batch size
# @raise [Todo] when a batch item keeps no queries
def post_process_panoptic_segmentation(
  outputs,
  threshold: 0.5,
  mask_threshold: 0.5,
  overlap_mask_area_threshold: 0.8,
  label_ids_to_fuse: nil,
  target_sizes: nil
)
  if label_ids_to_fuse.nil?
    warn "`label_ids_to_fuse` unset. No instance will be fused."
    label_ids_to_fuse = Set.new
  end

  class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]
  masks_queries_logits = outputs[:pred_masks] # [batch_size, num_queries, height, width]
  mask_probs = Utils.sigmoid(masks_queries_logits) # [batch_size, num_queries, height, width]

  batch_size = class_queries_logits.size
  # The last class is the background ("no object") class, so it is not counted.
  num_labels = class_queries_logits[0][0].size - 1

  if !target_sizes.nil? && target_sizes.length != batch_size
    raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  end

  Array.new(batch_size) do |item|
    target_size = target_sizes.nil? ? nil : target_sizes[item]

    kept_masks, kept_scores, kept_labels =
      remove_low_and_no_objects(class_queries_logits[item], mask_probs[item], threshold, num_labels)

    raise Todo if kept_labels.empty?

    # Get segmentation map and segment information of batch item
    segmentation, segments = compute_segments(
      kept_masks,
      kept_scores,
      kept_labels,
      mask_threshold,
      overlap_mask_area_threshold,
      label_ids_to_fuse,
      target_size
    )

    { segmentation: segmentation, segments_info: segments }
  end
end
655
+ end
656
+
657
module Utils
  # Converts a box from center format [cx, cy, w, h] to corner format [x0, y0, x1, y1].
  def self.center_to_corners_format(v)
    center_x, center_y, width, height = v
    half_w = width / 2.0
    half_h = height / 2.0
    [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
  end

  # Turns raw detection outputs into per-image boxes, class indices and scores.
  #
  # @param outputs [Hash] :logits [batch, boxes, classes] and :pred_boxes
  #   [batch, boxes, 4] in center format
  # @param threshold [Float] minimum sigmoid probability (zero-shot mode only)
  # @param target_sizes [Array, nil] optional [height, width] per image used to
  #   scale boxes to pixel coordinates
  # @param is_zero_shot [Boolean] sigmoid per class (multi-label) instead of
  #   softmax argmax with the last class treated as background
  # @return [Array<Hash>] one { boxes:, classes:, scores: } per image
  # @raise [Error] when target_sizes does not match the batch size
  def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
    out_logits = outputs[:logits]
    out_bbox = outputs[:pred_boxes]
    batch_size = out_logits.size
    num_boxes = out_logits[0].size
    num_classes = out_logits[0][0].size

    unless target_sizes.nil? || target_sizes.length == batch_size
      raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
    end

    Array.new(batch_size) do |batch_idx|
      target_size = target_sizes.nil? ? nil : target_sizes[batch_idx]
      info = { boxes: [], classes: [], scores: [] }
      logits = out_logits[batch_idx]
      bbox = out_bbox[batch_idx]

      num_boxes.times do |box_idx|
        logit = logits[box_idx]

        if is_zero_shot
          # Get indices of classes with high enough probability
          probs = Utils.sigmoid(logit)
          selected = probs.length.times.select { |k| probs[k] > threshold }
        else
          # Get most probable class; the last class is background, skip it
          best = Utils.max(logit)[1]
          next if best == num_classes - 1

          selected = [best]
          # Compute softmax over classes
          probs = Utils.softmax(logit)
        end

        selected.each do |cls|
          # convert to [x0, y0, x1, y1] format; x coords scale by width, y by height
          corners = center_to_corners_format(bbox[box_idx])
          unless target_size.nil?
            corners = corners.map.with_index { |coord, n| coord * target_size[(n + 1) % 2] }
          end

          info[:boxes] << corners
          info[:classes] << cls
          info[:scores] << probs[cls]
        end
      end

      info
    end
  end
end
732
+
733
# Placeholder for the Whisper audio feature extractor: not implemented yet.
# Construction stores the config and then raises Todo.
class WhisperFeatureExtractor < FeatureExtractor
  def initialize(config)
    super
    raise Todo
  end

  def _extract_fbank_features(waveform) = raise(Todo)

  def call(audio) = raise(Todo)
end
748
+
749
# Audio feature extractor for Wav2Vec2 checkpoints. It optionally normalizes a
# raw waveform and wraps it as a batch of one with an all-ones attention mask.
class Wav2Vec2FeatureExtractor < FeatureExtractor
  # Shifts values to zero mean and scales them to unit variance. 1e-7 is added
  # to the variance so a constant signal does not divide by zero.
  #
  # @param input_values [Array<Numeric>]
  # @return [Array<Float>]
  def _zero_mean_unit_var_norm(input_values)
    count = input_values.length.to_f
    mean = input_values.sum / count
    variance = input_values.sum { |v| (v - mean)**2 } / count
    std = Math.sqrt(variance + 1e-7)
    input_values.map { |v| (v - mean) / std }
  end

  # @param audio [Array<Numeric>] raw waveform samples
  # @return [Hash] :input_values and :attention_mask, each wrapped in a batch of one
  def call(audio)
    # TODO
    # validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor')

    # zero-mean and unit-variance normalization, when enabled in the config
    input_values = @config["do_normalize"] ? _zero_mean_unit_var_norm(audio) : audio

    # TODO: allow user to pass in attention mask
    mask = Array.new(input_values.length, 1)
    { input_values: [input_values], attention_mask: [mask] }
  end
end
775
+
776
# Placeholder for the CLAP audio feature extractor: construction only stores
# the config, and calling it raises Todo.
class ClapFeatureExtractor < FeatureExtractor
  def initialize(config)
    super
    # TODO
  end

  def call(audio, max_length: nil) = raise(Todo)
end
787
+
788
# Default processor: a thin wrapper that forwards calls to its feature extractor.
class Processor
  # @return the wrapped feature extractor
  attr_reader :feature_extractor

  # @param feature_extractor [#call] extractor invoked by #call
  def initialize(feature_extractor)
    @feature_extractor = feature_extractor
  end

  # Runs the feature extractor on the input, forwarding any extra positional args.
  def call(input, *args) = @feature_extractor.call(input, *args)
end
799
+
800
# Resolves and builds the processor for a pretrained checkpoint from its
# preprocessor configuration.
class AutoProcessor
  # Maps a config's "feature_extractor_type" / "image_processor_type" to a class.
  FEATURE_EXTRACTOR_CLASS_MAPPING = {
    "ViTFeatureExtractor" => ViTFeatureExtractor,
    "OwlViTFeatureExtractor" => OwlViTFeatureExtractor,
    "CLIPFeatureExtractor" => CLIPFeatureExtractor,
    "DPTFeatureExtractor" => DPTFeatureExtractor,
    "DetrFeatureExtractor" => DetrFeatureExtractor,
    "Swin2SRImageProcessor" => Swin2SRImageProcessor,
    "DonutFeatureExtractor" => DonutFeatureExtractor,
    "WhisperFeatureExtractor" => WhisperFeatureExtractor,
    "Wav2Vec2FeatureExtractor" => Wav2Vec2FeatureExtractor,
    "ClapFeatureExtractor" => ClapFeatureExtractor
  }

  # Maps a config's "processor_class" to a processor class. It is empty here,
  # so lookups fall back to Processor unless entries are added.
  PROCESSOR_CLASS_MAPPING = {}

  # Builds a processor (wrapping the matching feature extractor) for a model.
  #
  # @param pretrained_model_name_or_path [String] model id or local path
  # @param config [Hash, nil] preprocessor config to use instead of fetching
  #   preprocessor_config.json
  # @return [Processor] processor wrapping the instantiated feature extractor
  # @raise [Error] when the extractor type is unknown and the config has no "size"
  def self.from_pretrained(
    pretrained_model_name_or_path,
    progress_callback: nil,
    config: nil,
    cache_dir: nil,
    local_files_only: false,
    revision: "main",
    **kwargs
  )
    # An explicit config wins; otherwise load preprocessor_config.json for the model.
    preprocessor_config = config || Utils::Hub.get_model_json(
      pretrained_model_name_or_path,
      "preprocessor_config.json",
      true,
      progress_callback:,
      config:,
      cache_dir:,
      local_files_only:,
      revision:
    )

    # Determine feature extractor class
    # TODO: Ensure backwards compatibility with old configs
    key = preprocessor_config["feature_extractor_type"] || preprocessor_config["image_processor_type"]
    feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING[key]

    unless feature_extractor_class
      # A "size" entry means we can fall back to the generic image extractor.
      raise Error, "Unknown Feature Extractor type: #{key}" unless preprocessor_config["size"]

      warn "Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config."
      feature_extractor_class = ImageFeatureExtractor
    end

    # If no associated processor class, use default
    processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config["processor_class"]] || Processor

    processor_class.new(feature_extractor_class.new(preprocessor_config))
  end
end
856
+ end