informers 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +123 -0
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -9
- data/lib/informers/models.rb +997 -12
- data/lib/informers/pipelines.rb +768 -8
- data/lib/informers/processors.rb +796 -0
- data/lib/informers/tokenizers.rb +154 -4
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- metadata +8 -5
@@ -0,0 +1,796 @@
|
|
1
|
+
module Informers
|
2
|
+
# Root of the feature-extractor hierarchy. It only remembers the preprocessor
# configuration hash so subclasses can read their settings from `@config`.
class FeatureExtractor
  # @param config [Hash] preprocessor configuration (AutoProcessor passes the parsed
  #   preprocessor_config.json here)
  def initialize(config)
    @config = config
    super()
  end
end
|
8
|
+
|
9
|
+
# Generic image preprocessor configured from a preprocessor config hash (string keys).
#
# #preprocess applies, in order: optional margin crop, RGB/grayscale conversion, resize,
# thumbnail, center crop, rescale, normalize, padding, optional channel flip (not yet
# implemented), and finally an HWC -> CHW transpose. Pixel-level steps operate on a flat
# array in HWC order (height, width, channels).
class ImageFeatureExtractor < FeatureExtractor
  # @param config [Hash] preprocessor configuration with string keys
  def initialize(config)
    super(config)

    @image_mean = @config["image_mean"] || @config["mean"]
    @image_std = @config["image_std"] || @config["std"]

    @resample = @config["resample"] || 2 # 2 => bilinear
    @do_rescale = @config.fetch("do_rescale", true)
    @rescale_factor = @config["rescale_factor"] || (1 / 255.0)
    @do_normalize = @config["do_normalize"]

    @do_resize = @config["do_resize"]
    @do_thumbnail = @config["do_thumbnail"]
    @size = @config["size"]
    @size_divisibility = @config["size_divisibility"] || @config["size_divisor"]

    @do_center_crop = @config["do_center_crop"]
    @crop_size = @config["crop_size"]
    @do_convert_rgb = @config.fetch("do_convert_rgb", true)
    @do_crop_margin = @config["do_crop_margin"]

    @pad_size = @config["pad_size"]
    @do_pad = @config["do_pad"]

    # Should pad, but no pad size specified: infer the pad size from the resize size.
    # `size` may also be a bare number (see #get_resize_output_image_size); indexing an
    # Integer with "width" raises TypeError, so only a Hash size can supply width/height.
    if @do_pad && !@pad_size && @size.is_a?(Hash) && !@size["width"].nil? && !@size["height"].nil?
      @pad_size = @size
    end

    @do_flip_channel_order = @config["do_flip_channel_order"] || false
  end

  # Shrinks +image+ to fit inside size["width"] x size["height"] while keeping its aspect
  # ratio. Never upscales: returns +image+ itself when it already fits in both dimensions.
  #
  # @param image [Object] image responding to #height, #width and #resize
  # @param size [Hash] target box with "height" and "width" keys
  # @param resample [Integer] resampling filter forwarded to image.resize
  # @return [Object] the (possibly) resized image
  def thumbnail(image, size, resample = 2)
    input_height = image.height
    input_width = image.width

    output_height = size["height"]
    output_width = size["width"]

    # We always resize to the smallest of either the input or output size.
    height = [input_height, output_height].min
    width = [input_width, output_width].min

    if height == input_height && width == input_width
      return image
    end
    if input_height > input_width
      width = (input_width * height / input_height).floor
    elsif input_width > input_height
      height = (input_height * width / input_width).floor
    end
    image.resize(width, height, resample:)
  end

  # Pads a flat HWC pixel array to +pad_size+.
  #
  # @param pixel_data [Array<Numeric>] flat pixel values in HWC order
  # @param img_dims [Array(Integer, Integer, Integer)] [height, width, channels]
  # @param pad_size [Numeric, Hash] square side length, or a hash with width/height
  #   (symbol or string keys)
  # @param mode [String] "constant" fills the border with +constant_values+; "symmetric"
  #   additionally overwrites the border with reflected pixels
  # @param center [Boolean] place the original image in the middle instead of top-left
  # @param constant_values [Numeric, Array<Numeric>] fill value, or one value per channel
  # @return [Array(Array, Array)] [new_pixel_data, new_img_dims]; the inputs are returned
  #   unchanged when no padding is needed
  # @raise [Error] when center is combined with symmetric mode
  def pad_image(
    pixel_data,
    img_dims,
    pad_size,
    mode: "constant",
    center: false,
    constant_values: 0
  )
    image_height, image_width, image_channels = img_dims

    if pad_size.is_a?(Numeric)
      padded_image_width = pad_size
      padded_image_height = pad_size
    else
      padded_image_width = pad_size[:width] || pad_size["width"]
      padded_image_height = pad_size[:height] || pad_size["height"]
    end

    # Only add padding if there is a difference in size
    if padded_image_width != image_width || padded_image_height != image_height
      padded_length = padded_image_width * padded_image_height * image_channels
      padded_pixel_data =
        if constant_values.is_a?(Array)
          # Fill with constant values, cycling through the array (one value per channel)
          Array.new(padded_length) { |i| constant_values[i % image_channels] }
        else
          # Always pass the fill value: Array.new(n) alone fills with nil, which would
          # leave the padded border as nil for the default constant_values of 0.
          Array.new(padded_length, constant_values)
        end

      left, top =
        if center
          [((padded_image_width - image_width) / 2.0).floor, ((padded_image_height - image_height) / 2.0).floor]
        else
          [0, 0]
        end

      # Copy the original image into the padded image
      image_height.times do |i|
        a = (i + top) * padded_image_width
        b = i * image_width
        image_width.times do |j|
          c = (a + j + left) * image_channels
          d = (b + j) * image_channels
          image_channels.times do |k|
            padded_pixel_data[c + k] = pixel_data[d + k]
          end
        end
      end

      if mode == "symmetric"
        if center
          raise Error, "`center` padding is not supported when `mode` is set to `symmetric`."
        end
        h1 = image_height - 1
        w1 = image_width - 1
        padded_image_height.times do |i|
          a = i * padded_image_width
          # NOTE(review): Utils.calculate_reflect_offset is defined elsewhere; presumably it
          # maps an out-of-range index back into [0, max] by reflection — confirm.
          b = Utils.calculate_reflect_offset(i, h1) * image_width

          padded_image_width.times do |j|
            next if i < image_height && j < image_width # Do not overwrite original image
            c = (a + j) * image_channels
            d = (b + Utils.calculate_reflect_offset(j, w1)) * image_channels

            # Copy channel-wise
            image_channels.times do |k|
              padded_pixel_data[c + k] = pixel_data[d + k]
            end
          end
        end
      end

      # Update pixel data and image dimensions
      pixel_data = padded_pixel_data
      img_dims = [padded_image_height, padded_image_width, image_channels]
    end
    [pixel_data, img_dims]
  end

  # Multiplies every value of +pixel_data+ by the configured rescale factor, in place.
  def rescale(pixel_data)
    pixel_data.length.times do |i|
      pixel_data[i] *= @rescale_factor
    end
  end

  # Computes the [width, height] that #resize should produce for +image+.
  #
  # @param image [Object] image whose #size is assumed to be [width, height] — TODO confirm
  # @param size [Numeric, Hash, nil] a shortest-edge number, or a hash with
  #   "shortest_edge"/"longest_edge" or "width"/"height"
  # @return [Array(Integer, Integer)] [width, height]
  # @raise [Todo] for configurations not supported yet (size divisibility,
  #   keep_aspect_ratio + ensure_multiple_of, or no usable size)
  def get_resize_output_image_size(image, size)
    src_width, src_height = image.size
    shortest_edge = nil
    longest_edge = nil

    if @do_thumbnail
      # NOTE: custom logic for `Donut` models
      height = size["height"]
      width = size["width"]
      shortest_edge = [height, width].min
    elsif size.is_a?(Numeric)
      shortest_edge = size
      longest_edge = @config["max_size"] || shortest_edge
    elsif !size.nil?
      # Extract known properties from `size`
      shortest_edge = size["shortest_edge"]
      longest_edge = size["longest_edge"]
    end

    if !shortest_edge.nil? || !longest_edge.nil?
      # http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
      # Try resize so that shortest edge is `shortest_edge` (target)
      short_resize_factor =
        if shortest_edge.nil?
          1 # If `shortest_edge` is not set, don't upscale
        else
          [shortest_edge / src_width.to_f, shortest_edge / src_height.to_f].max
        end

      new_width = src_width * short_resize_factor
      new_height = src_height * short_resize_factor

      # The new width and height might be greater than `longest_edge`, so
      # we downscale again to ensure the largest dimension is `longest_edge`
      long_resize_factor =
        if longest_edge.nil?
          1 # If `longest_edge` is not set, don't downscale
        else
          [longest_edge / new_width.to_f, longest_edge / new_height.to_f].min
        end

      # To avoid certain floating point precision issues, we round to 2 decimal places
      final_width = (new_width * long_resize_factor).round(2).floor
      final_height = (new_height * long_resize_factor).round(2).floor

      if !@size_divisibility.nil?
        raise Todo
      end
      [final_width, final_height]
    elsif !size.nil? && !size["width"].nil? && !size["height"].nil?
      new_width = size["width"]
      new_height = size["height"]

      if @config["keep_aspect_ratio"] && @config["ensure_multiple_of"]
        raise Todo
      end

      [new_width, new_height]
    else
      raise Todo
    end
  end

  # Resizes +image+ to the dimensions computed from the configured size.
  def resize(image)
    new_width, new_height = get_resize_output_image_size(image, @size)
    image.resize(new_width, new_height, resample: @resample)
  end

  # Runs the full preprocessing pipeline on a single image.
  # Each keyword argument, when non-nil, overrides the corresponding config flag.
  #
  # @return [Hash] :original_size ([height, width] before any processing),
  #   :reshaped_input_size ([height, width] after resize/crop, before padding) and
  #   :pixel_values (nested [channel][row][col] array)
  # @raise [Error] when array image_mean/image_std lengths do not match the channel count
  # @raise [Todo] for pipeline features not implemented yet
  def preprocess(
    image,
    do_normalize: nil,
    do_pad: nil,
    do_convert_rgb: nil,
    do_convert_grayscale: nil,
    do_flip_channel_order: nil
  )
    if @do_crop_margin
      # NOTE: Specific to nougat processors. This is done before resizing,
      # and can be interpreted as a pre-preprocessing step.
      # NOTE(review): `crop_margin` is not defined in this class; presumably it must be
      # supplied by a subclass — confirm.
      image = crop_margin(image)
    end

    src_width, src_height = image.size # original image size

    # Convert image to RGB if specified in config.
    if do_convert_rgb.nil? ? @do_convert_rgb : do_convert_rgb
      image = image.rgb
    elsif do_convert_grayscale
      image = image.grayscale
    end

    # Resize all images
    if @do_resize
      image = resize(image)
    end

    # Resize the image using thumbnail method.
    if @do_thumbnail
      image = thumbnail(image, @size, @resample)
    end

    if @do_center_crop
      if @crop_size.is_a?(Integer)
        crop_width = @crop_size
        crop_height = @crop_size
      else
        crop_width = @crop_size["width"]
        crop_height = @crop_size["height"]
      end
      image = image.center_crop(crop_width, crop_height)
    end

    reshaped_input_size = [image.height, image.width]

    # NOTE: All pixel-level manipulation (i.e., modifying `pixel_data`)
    # occurs with data in the hwc format (height, width, channels),
    # to emulate the behavior of the original Python code (w/ numpy).
    pixel_data = image.data
    img_dims = [image.height, image.width, image.channels]

    if @do_rescale
      rescale(pixel_data)
    end

    if do_normalize.nil? ? @do_normalize : do_normalize
      # A scalar mean/std applies to every channel.
      image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(image.channels, @image_mean)
      image_std = @image_std.is_a?(Array) ? @image_std : Array.new(image.channels, @image_std)

      if image_mean.length != image.channels || image_std.length != image.channels
        raise Error, "When set to arrays, the length of `image_mean` (#{image_mean.length}) and `image_std` (#{image_std.length}) must match the number of channels in the image (#{image.channels})."
      end

      i = 0
      while i < pixel_data.length
        image.channels.times do |j|
          pixel_data[i + j] = (pixel_data[i + j] - image_mean[j]) / image_std[j]
        end
        i += image.channels
      end
    end

    # do padding after rescaling/normalizing
    if do_pad.nil? ? @do_pad : do_pad
      if @pad_size
        padded = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)
        pixel_data, img_dims = padded # Update pixel data and image dimensions
      elsif @size_divisibility
        raise Todo
      end
    end

    if do_flip_channel_order.nil? ? @do_flip_channel_order : do_flip_channel_order
      raise Todo
    end

    # convert to channel dimension format (hwc -> chw)
    h, w, c = img_dims
    pixel_values =
      c.times.map do |ci|
        h.times.map do |hi|
          w.times.map do |wi|
            index = (hi * w * c) + (wi * c) + ci
            pixel_data[index]
          end
        end
      end

    {
      original_size: [src_height, src_width],
      reshaped_input_size: reshaped_input_size,
      pixel_values: pixel_values
    }
  end

  # Preprocesses one image or an array of images and stacks the results into a batch.
  #
  # @param images [Object, Array] a single image or an array of images
  # @return [Hash] :pixel_values (stacked along a new batch axis), :original_sizes and
  #   :reshaped_input_sizes (per-image [height, width] before padding or cropping)
  def call(images, *args)
    if !images.is_a?(Array)
      images = [images]
    end

    image_data = images.map { |x| preprocess(x) }

    # Stack pixel values
    pixel_values = Utils.stack(image_data.map { |x| x[:pixel_values] }, 0)

    {
      pixel_values: pixel_values,

      # Original sizes of images
      original_sizes: image_data.map { |x| x[:original_size] },

      # Reshaped sizes of images, before padding or cropping
      reshaped_input_sizes: image_data.map { |x| x[:reshaped_input_size] }
    }
  end
end
|
351
|
+
|
352
|
+
# CLIP uses the generic ImageFeatureExtractor pipeline unchanged.
class CLIPFeatureExtractor < ImageFeatureExtractor; end
|
354
|
+
|
355
|
+
# DPT uses the generic ImageFeatureExtractor pipeline unchanged.
class DPTFeatureExtractor < ImageFeatureExtractor; end
|
357
|
+
|
358
|
+
# ViT uses the generic ImageFeatureExtractor pipeline unchanged.
class ViTFeatureExtractor < ImageFeatureExtractor; end
|
360
|
+
|
361
|
+
# OWL-ViT: generic image preprocessing plus the shared object-detection post-processing.
class OwlViTFeatureExtractor < ImageFeatureExtractor
  # Forwards all positional arguments to Utils.post_process_object_detection.
  def post_process_object_detection(*args) = Utils.post_process_object_detection(*args)
end
|
366
|
+
|
367
|
+
# Swin2SR preprocessing: the generic pipeline, but padding mirrors the image so that both
# spatial dimensions become multiples of the local-attention window size.
class Swin2SRImageProcessor < ImageFeatureExtractor
  # @param pixel_data [Array<Numeric>] flat pixel values in HWC order
  # @param img_dims [Array(Integer, Integer, Integer)] [height, width, channels]
  # @param pad_size [Integer] sliding-window size for local attention; each dimension is
  #   rounded up to the next multiple of it
  # @param options [Hash] extra keywords forwarded to ImageFeatureExtractor#pad_image
  # @return [Array(Array, Array)] [new_pixel_data, new_img_dims]
  def pad_image(pixel_data, img_dims, pad_size, **options)
    height, width, = img_dims

    # The reference Python implementation pads even when a dimension is already a multiple
    # of pad_size, which is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19).
    # Here the extra amount is 0 whenever the dimension is already a multiple.
    round_up = ->(length) { length + (pad_size - length % pad_size) % pad_size }
    target_size = { width: round_up.(width), height: round_up.(height) }

    super(
      pixel_data,
      img_dims,
      target_size,
      mode: "symmetric",
      center: false,
      constant_values: -1,
      **options
    )
  end
end
|
390
|
+
|
391
|
+
# Donut preprocessing: the generic pipeline, but padding is centered and its fill value per
# channel is -mean / std.
class DonutFeatureExtractor < ImageFeatureExtractor
  # Pads like ImageFeatureExtractor#pad_image with center: true.
  #
  # #preprocess pads *after* rescaling/normalizing, so the border is filled with
  # (0 - mean) / std per channel: the normalized value of a zero pixel. This reproduces
  # the result of padding with zeros before normalization.
  # See https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
  #
  # @param pixel_data [Array<Numeric>] flat pixel values in HWC order
  # @param img_dims [Array(Integer, Integer, Integer)] [height, width, channels]
  # @param pad_size [Numeric, Hash] target size, as accepted by the parent method
  # @param options [Hash] extra keywords forwarded to the parent method
  # @return [Array(Array, Array)] [new_pixel_data, new_img_dims]
  def pad_image(pixel_data, img_dims, pad_size, **options)
    _image_height, _image_width, image_channels = img_dims

    # A scalar mean/std in the config applies to every channel.
    image_mean = @image_mean.is_a?(Array) ? @image_mean : Array.new(image_channels, @image_mean)
    image_std = @image_std.is_a?(Array) ? @image_std : Array.new(image_channels, @image_std)

    constant_values = image_mean.map.with_index { |mean, i| -mean / image_std[i] }

    super(
      pixel_data,
      img_dims,
      pad_size,
      center: true,
      constant_values: constant_values,
      **options
    )
  end
end
|
419
|
+
|
420
|
+
# DETR: generic image preprocessing plus a pixel mask, object-detection post-processing and
# panoptic-segmentation post-processing.
class DetrFeatureExtractor < ImageFeatureExtractor
  # Preprocesses +images+ like the parent and adds a `pixel_mask` entry.
  #
  # @param images [Object, Array] one image or an array of images
  # @param args [Array] extra positional arguments forwarded to ImageFeatureExtractor#call
  #   (accepted so Processor#call can forward arguments without an ArgumentError)
  # @return [Hash] the parent's result merged with `pixel_mask:` ([batch][64][64] of 1s)
  def call(images, *args)
    result = super(images, *args)

    # TODO support differently-sized images, for now assume all images are the same size.
    # TODO support different mask sizes (not just 64x64)
    # Currently, just fill pixel mask with 1s
    batch_size = result[:pixel_values].size
    pixel_mask = Array.new(batch_size) { Array.new(64) { Array.new(64, 1) } }

    result.merge(pixel_mask: pixel_mask)
  end

  # Forwards all positional arguments to Utils.post_process_object_detection.
  def post_process_object_detection(*args)
    Utils.post_process_object_detection(*args)
  end

  # Keeps only the queries whose predicted label is not the background label and whose
  # softmax score exceeds +object_mask_threshold+.
  #
  # @param class_logits [Array] per-query class logits
  # @param mask_logits [Array] per-query masks (same order as class_logits)
  # @param object_mask_threshold [Numeric] minimum score to keep a query
  # @param num_labels [Integer] index of the background label
  # @return [Array(Array, Array, Array)] kept masks, scores and labels
  def remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels)
    mask_probs_item = []
    pred_scores_item = []
    pred_labels_item = []

    class_logits.size.times do |j|
      cls = class_logits[j]
      mask = mask_logits[j]

      # NOTE(review): Utils.max is assumed to return [value, index] — confirm.
      pred_label = Utils.max(cls)[1]
      # Is the background, so we ignore it
      next if pred_label == num_labels

      scores = Utils.softmax(cls)
      pred_score = scores[pred_label]
      next unless pred_score > object_mask_threshold

      mask_probs_item << mask
      pred_scores_item << pred_score
      pred_labels_item << pred_label
    end

    [mask_probs_item, pred_scores_item, pred_labels_item]
  end

  # Decides whether query +k+ forms a valid segment.
  #
  # @param mask_labels [Array<Integer>] flat per-pixel winning query index
  # @param mask_probs [Array] per-query (weighted) probability masks
  # @param k [Integer] query index to check
  # @param mask_threshold [Numeric] probability at which a pixel counts toward query k's own mask
  # @param overlap_mask_area_threshold [Numeric] minimum ratio of won pixels to own-mask pixels
  # @return [Array(Boolean, Array<Integer>)] [mask_exists, flat indices of pixels won by k]
  def check_segment_validity(
    mask_labels,
    mask_probs,
    k,
    mask_threshold = 0.5,
    overlap_mask_area_threshold = 0.8
  )
    # mask_k is a 1D array of indices, indicating where the mask is equal to k
    mask_k = []
    mask_k_area = 0
    original_area = 0

    mask_probs_k_data = mask_probs[k].flatten

    # Compute the area of all the stuff in query k
    mask_labels.length.times do |i|
      if mask_labels[i] == k
        mask_k << i
        mask_k_area += 1
      end

      original_area += 1 if mask_probs_k_data[i] >= mask_threshold
    end
    mask_exists = mask_k_area > 0 && original_area > 0

    # Eliminate disconnected tiny segments
    if mask_exists
      # Both areas are Integers; fdiv keeps the fractional ratio (plain `/` would truncate).
      area_ratio = mask_k_area.fdiv(original_area)
      mask_exists = area_ratio > overlap_mask_area_threshold
    end

    [mask_exists, mask_k]
  end

  # Builds a segmentation map and segment metadata from the kept queries.
  #
  # NOTE: mask_probs is modified in place (resized and weighted by score).
  #
  # @param target_size [Array(Integer, Integer), nil] [height, width] to resize masks to
  # @return [Array(Array, Array<Hash>)] [segmentation ([height][width] of segment ids,
  #   0 = no segment), segments ({id:, label_id:, score:})]
  def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold,
    overlap_mask_area_threshold,
    label_ids_to_fuse = nil,
    target_size = nil
  )
    height, width = target_size || Utils.dims(mask_probs[0])

    # Pixels not claimed by any segment keep id 0; Array.new(n) alone would leave them nil.
    segmentation = Array.new(height * width, 0)
    segments = []

    # 1. If target_size is not null, we need to resize the masks to the target size
    if !target_size.nil?
      # resize the masks to the target size
      mask_probs.length.times do |i|
        mask_probs[i] = Utils.interpolate(mask_probs[i], target_size, "bilinear", false)
      end
    end

    # 2. Weigh each mask by its prediction score
    # NOTE: `mask_probs` is updated in-place
    #
    # Temporary storage for the best label/scores for each pixel ([height, width]).
    # Labels default to 0 (argmax-style) rather than nil.
    num_pixels = mask_probs[0].flatten.length
    mask_labels = Array.new(num_pixels, 0)
    best_scores = Array.new(num_pixels, 0)

    mask_probs.length.times do |i|
      score = pred_scores[i]

      mask_probs_i_data = mask_probs[i].flatten
      mask_probs_i_dims = Utils.dims(mask_probs[i])

      mask_probs_i_data.length.times do |j|
        mask_probs_i_data[j] *= score
        if mask_probs_i_data[j] > best_scores[j]
          mask_labels[j] = i
          best_scores[j] = mask_probs_i_data[j]
        end
      end

      mask_probs[i] = Utils.reshape(mask_probs_i_data, mask_probs_i_dims)
    end

    current_segment_id = 0

    # stuff_memory_list = {}
    pred_labels.length.times do |k|
      pred_class = pred_labels[k]

      # TODO add `should_fuse`
      # should_fuse = label_ids_to_fuse.include?(pred_class)

      # Check if mask exists and large enough to be a segment
      mask_exists, mask_k = check_segment_validity(
        mask_labels,
        mask_probs,
        k,
        mask_threshold,
        overlap_mask_area_threshold
      )

      # Nothing to see here
      next unless mask_exists

      current_segment_id += 1

      # Add current object segment to final segmentation map
      mask_k.each do |index|
        segmentation[index] = current_segment_id
      end

      segments << {
        id: current_segment_id,
        label_id: pred_class,
        score: pred_scores[k]
      }
    end

    segmentation = Utils.reshape(segmentation, [height, width])

    [segmentation, segments]
  end

  # Converts raw model outputs into per-image panoptic segmentations.
  #
  # @param outputs [Hash] :logits ([batch][queries][classes + 1]) and
  #   :pred_masks ([batch][queries][height][width])
  # @param target_sizes [Array, nil] per-image [height, width]; must match the batch size
  # @return [Array<Hash>] one {segmentation:, segments_info:} hash per batch item
  # @raise [Error] when target_sizes does not match the batch size
  # @raise [Todo] when an item has no detected objects
  def post_process_panoptic_segmentation(
    outputs,
    threshold: 0.5,
    mask_threshold: 0.5,
    overlap_mask_area_threshold: 0.8,
    label_ids_to_fuse: nil,
    target_sizes: nil
  )
    if label_ids_to_fuse.nil?
      warn "`label_ids_to_fuse` unset. No instance will be fused."
      label_ids_to_fuse = Set.new
    end

    class_queries_logits = outputs[:logits] # [batch_size, num_queries, num_classes+1]
    masks_queries_logits = outputs[:pred_masks] # [batch_size, num_queries, height, width]

    mask_probs = Utils.sigmoid(masks_queries_logits) # [batch_size, num_queries, height, width]

    batch_size = class_queries_logits.size
    num_labels = class_queries_logits[0][0].size
    num_labels -= 1 # Remove last class (background)

    if !target_sizes.nil? && target_sizes.length != batch_size
      raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
    end

    to_return = []
    batch_size.times do |i|
      target_size = !target_sizes.nil? ? target_sizes[i] : nil

      class_logits = class_queries_logits[i]
      mask_logits = mask_probs[i]

      mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels)

      if pred_labels_item.length == 0
        raise Todo
      end

      # Get segmentation map and segment information of batch item
      segmentation, segments = compute_segments(
        mask_probs_item,
        pred_scores_item,
        pred_labels_item,
        mask_threshold,
        overlap_mask_area_threshold,
        label_ids_to_fuse,
        target_size
      )

      to_return << {
        segmentation: segmentation,
        segments_info: segments
      }
    end

    to_return
  end
end
|
654
|
+
|
655
|
+
module Utils
  # Converts a box from center format [center_x, center_y, width, height] to corner
  # format [x0, y0, x1, y1].
  def self.center_to_corners_format(v)
    center_x, center_y, width, height = v
    half_width = width / 2.0
    half_height = height / 2.0
    [
      center_x - half_width,
      center_y - half_height,
      center_x + half_width,
      center_y + half_height
    ]
  end

  # Turns raw detection outputs into per-image boxes, class indices and scores.
  #
  # @param outputs [Hash] :logits ([batch][boxes][classes]) and :pred_boxes
  #   ([batch][boxes][4] in center format); nesting inferred from the indexing below
  # @param threshold [Numeric] minimum sigmoid probability per class (zero-shot mode only)
  # @param target_sizes [Array, nil] per-image [height, width] used to scale boxes;
  #   must have one entry per batch item
  # @param is_zero_shot [Boolean] when true, keep every class whose sigmoid probability
  #   exceeds +threshold+; otherwise keep only the arg-max class, skipping it when it is
  #   the last (background) class
  # @return [Array<Hash>] one {boxes:, classes:, scores:} hash per batch item
  #   (an empty batch yields [])
  # @raise [Error] when target_sizes does not match the batch size
  def self.post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
    out_logits = outputs[:logits]
    out_bbox = outputs[:pred_boxes]
    batch_size = out_logits.size

    if !target_sizes.nil? && target_sizes.length != batch_size
      raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
    end

    Array.new(batch_size) do |batch_index|
      target_size = target_sizes.nil? ? nil : target_sizes[batch_index]
      info = { boxes: [], classes: [], scores: [] }
      bbox = out_bbox[batch_index]

      # Box/class counts are read per item, so an empty batch never touches out_logits[0].
      out_logits[batch_index].each_with_index do |logit, j|
        if is_zero_shot
          # Get indices of classes with high enough probability
          probs = Utils.sigmoid(logit)
          indices = probs.length.times.select { |k| probs[k] > threshold }
        else
          # Get most probable class (Utils.max is assumed to return [value, index])
          max_index = Utils.max(logit)[1]
          # This is the background class, skip it
          next if max_index == logit.size - 1

          indices = [max_index]
          # Compute softmax over classes
          probs = Utils.softmax(logit)
        end

        indices.each do |index|
          # convert to [x0, y0, x1, y1] format
          box = center_to_corners_format(bbox[j])
          unless target_size.nil?
            # target_size is taken as [height, width]: x coords (even positions) scale by
            # target_size[1], y coords (odd positions) by target_size[0].
            box = box.each_with_index.map { |coord, pos| coord * target_size[(pos + 1) % 2] }
          end

          info[:boxes] << box
          info[:classes] << index
          info[:scores] << probs[index]
        end
      end

      info
    end
  end
end
|
730
|
+
|
731
|
+
# Default processor: a thin wrapper whose #call hands everything to its feature extractor.
class Processor
  attr_reader :feature_extractor

  # @param feature_extractor [#call] object that performs the actual preprocessing
  def initialize(feature_extractor)
    @feature_extractor = feature_extractor
  end

  # Passes +input+ plus any extra positional arguments straight to the feature extractor
  # and returns whatever it returns.
  def call(input, *args)
    @feature_extractor.call(input, *args)
  end
end
|
742
|
+
|
743
|
+
# Factory that builds a Processor (wrapping the right feature extractor) from a model's
# preprocessor configuration.
class AutoProcessor
  # Maps the config's "feature_extractor_type" / "image_processor_type" to an extractor class.
  FEATURE_EXTRACTOR_CLASS_MAPPING = {
    "ViTFeatureExtractor" => ViTFeatureExtractor,
    "OwlViTFeatureExtractor" => OwlViTFeatureExtractor,
    "CLIPFeatureExtractor" => CLIPFeatureExtractor,
    "DPTFeatureExtractor" => DPTFeatureExtractor,
    "DetrFeatureExtractor" => DetrFeatureExtractor,
    "Swin2SRImageProcessor" => Swin2SRImageProcessor,
    "DonutFeatureExtractor" => DonutFeatureExtractor
  }

  # Maps the config's "processor_class" to a processor class. It is currently empty, so the
  # plain Processor is always used.
  PROCESSOR_CLASS_MAPPING = {}

  # Loads the preprocessor config (unless +config+ is given) and instantiates the processor.
  #
  # @param pretrained_model_name_or_path [String] model id or local path
  # @param config [Hash, nil] preprocessor config to use instead of loading
  #   preprocessor_config.json
  # @return [Processor] processor wrapping the matching feature extractor
  # @raise [Error] when the extractor type is unknown and the config has no "size" entry
  def self.from_pretrained(
    pretrained_model_name_or_path,
    progress_callback: nil,
    config: nil,
    cache_dir: nil,
    local_files_only: false,
    revision: "main",
    **kwargs
  )
    preprocessor_config =
      config ||
      Utils::Hub.get_model_json(
        pretrained_model_name_or_path,
        "preprocessor_config.json",
        true,
        progress_callback: progress_callback,
        config: config,
        cache_dir: cache_dir,
        local_files_only: local_files_only,
        revision: revision
      )

    # TODO: Ensure backwards compatibility with old configs
    key = preprocessor_config["feature_extractor_type"] || preprocessor_config["image_processor_type"]

    feature_extractor_class = FEATURE_EXTRACTOR_CLASS_MAPPING.fetch(key) do
      # Unknown type: fall back to the generic image extractor only when the config has a size.
      raise Error, "Unknown Feature Extractor type: #{key}" unless preprocessor_config["size"]

      warn "Feature extractor type #{key.inspect} not found, assuming ImageFeatureExtractor due to size parameter in config."
      ImageFeatureExtractor
    end

    # Fall back to the default processor when no specific class is registered.
    processor_class = PROCESSOR_CLASS_MAPPING[preprocessor_config["processor_class"]] || Processor

    processor_class.new(feature_extractor_class.new(preprocessor_config))
  end
end
|
796
|
+
end
|