nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,481 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Transforms and data augmentation for both image + bbox.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import random
|
|
14
|
+
from typing import Iterable
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torchvision.transforms as T
|
|
18
|
+
import torchvision.transforms.functional as F
|
|
19
|
+
import torchvision.transforms.v2.functional as Fv2
|
|
20
|
+
from PIL import Image as PILImage
|
|
21
|
+
|
|
22
|
+
from torchvision.transforms import InterpolationMode
|
|
23
|
+
|
|
24
|
+
from training.utils.data_utils import VideoDatapoint
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def hflip(datapoint, index):
    """Horizontally flip frame ``index`` of ``datapoint`` in place, together with its object masks.

    Objects whose ``segment`` is ``None`` are skipped.

    Returns:
        The same ``datapoint`` object, so calls can be chained.
    """
    frame = datapoint.frames[index]
    frame.data = F.hflip(frame.data)
    for obj in (o for o in frame.objects if o.segment is not None):
        obj.segment = F.hflip(obj.segment)
    return datapoint
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Return the ``(h, w)`` that rescales ``image_size`` so its shorter side equals ``size``.

    Args:
        image_size: ``(w, h)`` of the input.
        size: target length of the shorter side.
        max_size: optional cap on the longer side. When scaling to ``size`` would push
            the longer side past it, ``size`` is reduced so the longer side lands on
            ``max_size`` instead.

    Returns:
        ``(h, w)``. If the shorter side already equals ``size``, the input size is
        returned unchanged (reordered to ``(h, w)``). Otherwise both sides are rounded
        to ints.
    """
    w, h = image_size
    if max_size is not None:
        short_side, long_side = sorted((float(w), float(h)))
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    # Shorter side already at the requested length: nothing to rescale.
    if min(w, h) == size:
        return (h, w)

    if w < h:
        return (int(round(size * h / w)), int(round(size)))
    return (int(round(size)), int(round(size * w / h)))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    """Resize frame ``index`` of ``datapoint`` and its object masks in place.

    Args:
        datapoint: video datapoint whose ``frames[index]`` is resized.
        index: index of the frame to resize.
        size: either a scalar target for the shorter side (aspect ratio kept, optionally
            capped by ``max_size``) or an explicit ``(w, h)`` pair. With ``square=True``
            the scalar is used for both sides.
        max_size: optional cap on the longer side when ``size`` is a scalar.
        square: resize to ``(size, size)`` regardless of aspect ratio.
        v2: frame data is a tensor resized with ``torchvision.transforms.v2``. Otherwise
            the data must expose a PIL-style ``.size`` attribute of ``(w, h)``.

    Returns:
        The same ``datapoint``, with ``frames[index].size`` set to the target ``(h, w)``.
    """
    frame = datapoint.frames[index]

    if square:
        size = size, size
    else:
        # PIL reports (w, h); for tensors, reverse the trailing (H, W) dims to match.
        cur_size = frame.data.size()[-2:][::-1] if v2 else frame.data.size
        if isinstance(size, (list, tuple)):
            # explicit (w, h) -> torchvision's (h, w) order
            size = size[::-1]
        else:
            size = get_size_with_aspect_ratio(cur_size, size, max_size)

    if v2:
        frame.data = Fv2.resize(frame.data, size, antialias=True)
    else:
        frame.data = F.resize(frame.data, size)

    for obj in frame.objects:
        if obj.segment is not None:
            # Strip only the two dims added here. ``.squeeze()`` would also drop a
            # spatial dim whenever the target height or width is 1.
            obj.segment = F.resize(obj.segment[None, None], size)[0, 0]

    h, w = size
    frame.size = (h, w)
    return datapoint
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def pad(datapoint, index, padding, v2=False):
    """Pad frame ``index`` of ``datapoint`` and its object masks in place.

    Args:
        datapoint: video datapoint; ``frames[index].size`` is read as ``(h, w)``.
        index: index of the frame to pad.
        padding: either ``(pad_w, pad_h)``, applied to the right and bottom edges only,
            or ``(left, top, right, bottom)``.
        v2: pad the masks with ``torchvision.transforms.v2.functional.pad``. The frame
            data itself is always padded with the v1 ``F.pad``.

    Returns:
        The same ``datapoint``, with ``frames[index].size`` grown to the padded ``(h, w)``.
    """
    frame = datapoint.frames[index]
    h, w = frame.size

    if len(padding) == 2:
        # torchvision order is (left, top, right, bottom): nothing on the top/left edges.
        image_padding = (0, 0, padding[0], padding[1])
        mask_padding = image_padding
        extra_h, extra_w = padding[1], padding[0]
    else:
        image_padding = (padding[0], padding[1], padding[2], padding[3])
        mask_padding = tuple(padding)
        extra_h, extra_w = padding[1] + padding[3], padding[0] + padding[2]

    frame.data = F.pad(frame.data, image_padding)
    frame.size = (h + extra_h, w + extra_w)

    pad_mask = Fv2.pad if v2 else F.pad
    for obj in frame.objects:
        if obj.segment is not None:
            obj.segment = pad_mask(obj.segment, mask_padding)
    return datapoint
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class RandomHorizontalFlip:
    """Randomly flip the frames of a video datapoint horizontally (masks included).

    Args:
        consistent_transform: if true, one coin flip decides for every frame;
            otherwise each frame is flipped independently.
        p: probability of flipping.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        # A single draw up front in consistent mode; otherwise one draw per frame.
        shared_decision = (random.random() < self.p) if self.consistent_transform else None
        for frame_idx in range(len(datapoint.frames)):
            if self.consistent_transform:
                flip = shared_decision
            else:
                flip = random.random() < self.p
            if flip:
                datapoint = hflip(datapoint, frame_idx)
        return datapoint
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class RandomResizeAPI:
    """Resize each frame to a size drawn at random from ``sizes``.

    Args:
        sizes: a single int or an iterable of candidate sizes; each candidate is passed
            to ``resize`` (a scalar short-side length or a ``(w, h)`` pair).
        consistent_transform: if true, one size is drawn and used for all frames;
            otherwise every frame draws its own.
        max_size: optional cap on the longer side, forwarded to ``resize``.
        square: forwarded to ``resize``.
        v2: forwarded to ``resize``.
    """

    def __init__(self, sizes, consistent_transform, max_size=None, square=False, v2=False):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        shared_size = random.choice(self.sizes) if self.consistent_transform else None
        for frame_idx in range(len(datapoint.frames)):
            size = shared_size if self.consistent_transform else random.choice(self.sizes)
            datapoint = resize(datapoint, frame_idx, size, self.max_size, square=self.square, v2=self.v2)
        return datapoint
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class ToTensorAPI:
    """Convert the ``data`` of every frame to a tensor.

    Args:
        v2: convert with the ``torchvision.transforms.v2`` image conversion instead of
            ``torchvision.transforms.functional.to_tensor``.
    """

    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.v2:
            # ``to_image_tensor`` was renamed ``to_image`` in newer torchvision releases;
            # prefer the current name and fall back to the old one for older versions.
            convert = getattr(Fv2, "to_image", None) or Fv2.to_image_tensor
        else:
            convert = F.to_tensor
        for img in datapoint.frames:
            img.data = convert(img.data)
        return datapoint
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class NormalizeAPI:
    """Normalize every frame with per-channel ``mean`` and ``std``.

    With ``v2=True`` each frame is first converted to ``torch.float32`` via
    ``Fv2.convert_image_dtype``. Otherwise the frame is passed straight to
    ``F.normalize``, which expects a float tensor (presumably produced earlier,
    e.g. by ``ToTensorAPI``).
    """

    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for frame in datapoint.frames:
            if self.v2:
                frame.data = Fv2.convert_image_dtype(frame.data, torch.float32)
                normalize = Fv2.normalize
            else:
                normalize = F.normalize
            frame.data = normalize(frame.data, mean=self.mean, std=self.std)

        return datapoint
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class ComposeAPI:
    """Chain datapoint transforms, forwarding the same ``kwargs`` to each one in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        result = datapoint
        for transform in self.transforms:
            result = transform(result, **kwargs)
        return result

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class RandomGrayscale:
    """Randomly convert frames to 3-channel grayscale.

    Args:
        consistent_transform: if true, one draw decides for all frames; otherwise each
            frame is converted independently.
        p: probability of converting.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        # Keep 3 output channels so downstream code still sees an RGB-shaped image.
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        shared_decision = (random.random() < self.p) if self.consistent_transform else None
        for frame in datapoint.frames:
            convert = shared_decision if self.consistent_transform else (random.random() < self.p)
            if convert:
                frame.data = self.Grayscale(frame.data)
        return datapoint
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of every frame.

    Args:
        consistent_transform: if true, one set of jitter parameters (including the
            order of the adjustments) is sampled and shared by all frames; otherwise
            every frame samples its own.
        brightness, contrast, saturation: either an explicit ``[min, max]`` list of
            factors or a scalar ``x`` meaning ``[max(0, 1 - x), 1 + x]``.
        hue: an explicit ``[min, max]`` list, ``None`` (passed through as-is), or a
            scalar ``x`` meaning ``[-x, x]``.
    """

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        def around_one(value):
            # Scalar jitter strength -> multiplicative factor range around 1, floored at 0.
            if isinstance(value, list):
                return value
            return [max(0, 1 - value), 1 + value]

        self.consistent_transform = consistent_transform
        self.brightness = around_one(brightness)
        self.contrast = around_one(contrast)
        self.saturation = around_one(saturation)
        if hue is None or isinstance(hue, list):
            self.hue = hue
        else:
            self.hue = [-hue, hue]

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        def draw_params():
            # -> (order of adjustments, brightness, contrast, saturation, hue factors)
            return T.ColorJitter.get_params(self.brightness, self.contrast, self.saturation, self.hue)

        params = draw_params() if self.consistent_transform else None
        for frame in datapoint.frames:
            if not self.consistent_transform:
                params = draw_params()
            order, b_factor, c_factor, s_factor, h_factor = params
            for op_id in order:
                if op_id == 0 and b_factor is not None:
                    frame.data = F.adjust_brightness(frame.data, b_factor)
                elif op_id == 1 and c_factor is not None:
                    frame.data = F.adjust_contrast(frame.data, c_factor)
                elif op_id == 2 and s_factor is not None:
                    frame.data = F.adjust_saturation(frame.data, s_factor)
                elif op_id == 3 and h_factor is not None:
                    frame.data = F.adjust_hue(frame.data, h_factor)
        return datapoint
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """Random affine augmentation applied jointly to frames and their object masks.

        The mask is required for this transform. If ``consistent_transform`` is True,
        the same random affine is applied to all frames and masks; otherwise a new
        affine is sampled for every frame.

        Args:
            degrees: rotation range, either a ``[min, max]`` list or a scalar ``d``
                meaning ``[-d, d]``.
            consistent_transform: share one affine across all frames (see above).
            scale: forwarded to ``T.RandomAffine.get_params`` as ``scale_ranges``.
            translate: forwarded to ``T.RandomAffine.get_params`` as ``translate``.
            shear: a ``[min, max]`` list, or a scalar ``s`` meaning ``[-s, s]``.
                A falsy scalar (e.g. ``None`` or ``0``) disables shearing.
            image_mean: fill value for image pixels uncovered by the affine
                (presumably 0-255 RGB means -- confirm against the image type in use).
            log_warning: log a warning when every attempt is rejected.
            num_tentatives: number of affine draws to try before giving up.
            image_interpolation: ``"bicubic"`` or ``"bilinear"``. Masks always use
                nearest-neighbour interpolation.

        Raises:
            NotImplementedError: for any other ``image_interpolation``.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives

        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Try up to ``num_tentatives`` affine draws; return ``datapoint`` unchanged if all are rejected."""
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res

        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _sample_affine_params(self, img_size):
        """Draw one set of affine parameters with ``T.RandomAffine.get_params``."""
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_datapoint(self, datapoint: VideoDatapoint):
        """Apply one random affine draw to all frames and masks in place.

        Returns:
            ``datapoint`` on success, or ``None`` if some object's mask becomes empty
            in the first frame. Rejection only happens while processing frame 0, before
            any frame has been modified, so a rejected datapoint is left untouched.
        """
        # The first frame's size is used for every frame; get_params expects [width, height].
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            affine_params = self._sample_affine_params(img_size)

        for img_idx, img in enumerate(datapoint.frames):
            # Add a leading channel dim so F.affine sees a (1, H, W) image.
            this_masks = [obj.segment.unsqueeze(0) if obj.segment is not None else None for obj in img.objects]
            if not self.consistent_transform:
                # Fresh affine parameters for every frame & mask pair.
                affine_params = self._sample_affine_params(img_size)

            transformed_masks = []
            for mask in this_masks:
                if mask is None:
                    transformed_masks.append(None)
                    continue
                transformed_mask = F.affine(
                    mask,
                    *affine_params,
                    interpolation=InterpolationMode.NEAREST,
                    fill=0.0,
                )
                if img_idx == 0 and transformed_mask.max() == 0:
                    # The object would not be visible in the first frame: reject this draw.
                    return None
                # Drop only the channel dim added above (keeps 1-pixel spatial dims intact).
                transformed_masks.append(transformed_mask.squeeze(0))

            for obj, mask in zip(img.objects, transformed_masks):
                obj.segment = mask

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    """Turn frame ``index`` into a ``grid_h`` x ``grid_w`` mosaic of downsized copies of itself.

    Every cell receives a downsized copy of the whole frame, horizontally flipped
    where ``should_hflip[y, x]`` is true. Each object mask is downsized into cell
    ``(target_grid_y, target_grid_x)`` only and zeroed everywhere else.

    Args:
        datapoint: video datapoint, modified in place.
        index: index of the frame to transform.
        grid_h, grid_w: number of mosaic rows / columns.
        target_grid_y, target_grid_x: cell that keeps the object masks.
        should_hflip: boolean table indexed ``[y, x]`` with ``.item()``.

    Returns:
        The same ``datapoint``.
    """
    frame = datapoint.frames[index]
    source = frame.data

    # Step 1: downsize the images and paste them into a mosaic
    is_pil = isinstance(source, PILImage.Image)
    if is_pil:
        H_im, W_im = source.height, source.width
        mosaic = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im, W_im = source.size(-2), source.size(-1)
        mosaic = torch.zeros_like(source)

    def cell_bounds(cell_y, cell_x):
        # (y_begin, y_end, x_begin, x_end). Integer division spreads the remainder,
        # so adjacent cells share edges and the grid covers the frame exactly.
        return (
            cell_y * H_im // grid_h,
            (cell_y + 1) * H_im // grid_h,
            cell_x * W_im // grid_w,
            (cell_x + 1) * W_im // grid_w,
        )

    # Many cells share a shape, so each distinct downsized copy is computed once.
    tiles_by_shape = {}
    for cell_y in range(grid_h):
        for cell_x in range(grid_w):
            y_b, y_e, x_b, x_e = cell_bounds(cell_y, cell_x)
            cell_shape = (y_e - y_b, x_e - x_b)
            if cell_shape not in tiles_by_shape:
                tiles_by_shape[cell_shape] = F.resize(
                    source,
                    size=cell_shape,
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
            tile = tiles_by_shape[cell_shape]
            if should_hflip[cell_y, cell_x].item():
                tile = F.hflip(tile)

            if is_pil:
                mosaic.paste(tile, (x_b, y_b))
            else:
                mosaic[:, y_b:y_e, x_b:x_e] = tile

    frame.data = mosaic

    # Step 2: downsize the masks and paste them into the target grid of the mosaic
    t_y_b, t_y_e, t_x_b, t_x_e = cell_bounds(target_grid_y, target_grid_x)
    target_shape = (t_y_e - t_y_b, t_x_e - t_x_b)
    for obj in frame.objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8

        small_mask = F.resize(
            obj.segment[None, None],
            size=target_shape,
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            small_mask = F.hflip(small_mask[None, None])[0, 0]

        placed = torch.zeros_like(obj.segment)
        placed[t_y_b:t_y_e, t_x_b:t_x_e] = small_mask
        obj.segment = placed

    return datapoint
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
class RandomMosaicVideoAPI:
    """With probability ``prob``, turn every frame into a mosaic via ``random_mosaic_frame``.

    The target cell and the per-cell flip table are drawn once and shared by all
    frames, so the object stays in the same cell throughout the video.

    Args:
        prob: probability of applying the mosaic.
        grid_h, grid_w: mosaic rows / columns.
        use_random_hflip: flip each cell horizontally with probability 0.5.
    """

    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # Cell that will hold the object masks.
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)
        grid_shape = (self.grid_h, self.grid_w)
        should_hflip = (
            torch.rand(*grid_shape) < 0.5 if self.use_random_hflip else torch.zeros(*grid_shape, dtype=torch.bool)
        )

        for frame_idx in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                frame_idx,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )

        return datapoint
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
|
|
8
|
+
|
|
9
|
+
from typing import Iterable
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from torch.utils.data import (
|
|
13
|
+
ConcatDataset as TorchConcatDataset,
|
|
14
|
+
Dataset,
|
|
15
|
+
Subset as TorchSubset,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConcatDataset(TorchConcatDataset):
    """``torch.utils.data.ConcatDataset`` that also concatenates the children's ``repeat_factors``.

    Every child dataset must expose a ``repeat_factors`` tensor. The factors are
    concatenated in the same order as the children, so they line up with the
    concatenated sample indices.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

        # Read from ``self.datasets`` (the list built by the parent constructor), not from
        # the ``datasets`` argument: a one-shot iterable would already be exhausted here.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Propagate ``epoch`` to every child.

        Sets the child's ``epoch`` attribute if it has one, and calls its
        ``set_epoch`` method if it has one.
        """
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Subset(TorchSubset):
    """``torch.utils.data.Subset`` that keeps the matching slice of ``repeat_factors``."""

    def __init__(self, dataset, indices) -> None:
        super().__init__(dataset, indices)

        # Select the factors of the kept samples, in the same order as ``indices``.
        selected_factors = dataset.repeat_factors[indices]
        self.repeat_factors = selected_factors
        assert len(indices) == len(self.repeat_factors)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Adapted from Detectron2
|
|
42
|
+
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a repeat_factors member to indicate the per-image factor.
    Set it to uniformly ones to disable repeat factor sampling.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None  # populated by set_epoch()
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: int64 tensor of dataset indices to use in one epoch, in
                ascending order. Each index is repeated based on its calculated
                repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Repeat each dataset index ``rep_factor`` times, vectorized. Clamping keeps the
        # old behaviour that a (meaningless) negative factor yields no copies.
        counts = rep_factors.clamp(min=0).long()
        return torch.repeat_interleave(torch.arange(len(counts), dtype=torch.int64), counts)

    def __len__(self):
        if self.epoch_ids is None:
            # Raise instead of falling back to len(self.dataset): the wrapped length
            # becomes len(self.epoch_ids) once set_epoch is called, so silently
            # returning the unwrapped length would be error-prone.
            raise RuntimeError("please call set_epoch first to get wrapped length")

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Resample the repeated index list for ``epoch`` and forward ``epoch`` to the dataset.

        The generator is seeded with ``seed + epoch``, so the indices for a given epoch
        are reproducible.
        """
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError("Repeat ids haven't been computed. Did you forget to call set_epoch?")

        return self.dataset[self.epoch_ids[idx]]
|