nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,481 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Transforms and data augmentation for both image + bbox.
9
+ """
10
+
11
+ import logging
12
+
13
+ import random
14
+ from typing import Iterable
15
+
16
+ import torch
17
+ import torchvision.transforms as T
18
+ import torchvision.transforms.functional as F
19
+ import torchvision.transforms.v2.functional as Fv2
20
+ from PIL import Image as PILImage
21
+
22
+ from torchvision.transforms import InterpolationMode
23
+
24
+ from training.utils.data_utils import VideoDatapoint
25
+
26
+
27
def hflip(datapoint, index):
    """Mirror one frame of ``datapoint`` horizontally, together with its object masks.

    ``datapoint.frames[index].data`` and every non-None ``obj.segment`` of that
    frame are replaced by their flipped versions in place.

    Returns:
        The same ``datapoint`` object.
    """
    frame = datapoint.frames[index]
    frame.data = F.hflip(frame.data)
    masked_objects = (o for o in frame.objects if o.segment is not None)
    for o in masked_objects:
        o.segment = F.hflip(o.segment)
    return datapoint
35
+
36
+
37
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Compute an output ``(height, width)`` whose shorter side equals ``size``.

    Args:
        image_size: current size given as ``(width, height)``.
        size: desired length of the shorter side.
        max_size: optional cap on the longer side. When scaling would push the
            longer side above it, ``size`` is reduced accordingly.

    Returns:
        ``(height, width)``. If the shorter side already equals ``size``, the
        input dimensions are returned unchanged.
    """
    width, height = image_size
    if max_size is not None:
        short_side = float(min((width, height)))
        long_side = float(max((width, height)))
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    already_sized = (width <= height and width == size) or (height <= width and height == size)
    if already_sized:
        return (height, width)

    if width < height:
        new_w = int(round(size))
        new_h = int(round(size * height / width))
    else:
        new_h = int(round(size))
        new_w = int(round(size * width / height))
    return (new_h, new_w)
56
+
57
+
58
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    """Resize frame ``index`` of ``datapoint`` (image and object masks) in place.

    Args:
        datapoint: video datapoint whose ``frames[index]`` is resized.
        index: frame index.
        size: either a scalar target for the shorter side (aspect ratio kept,
            optionally capped by ``max_size``) or an explicit ``(w, h)``
            list/tuple.
        max_size: optional cap on the longer side when ``size`` is a scalar.
        square: if True, resize to ``(size, size)`` ignoring the aspect ratio.
        v2: if True, the frame data is treated as a tensor (current size read
            from its two trailing dims) and resized with torchvision v2 and
            antialiasing. Otherwise its ``.size`` attribute is read as
            ``(w, h)`` (PIL-style) and torchvision's v1 ``resize`` is used.

    Returns:
        The same datapoint, with ``frames[index].size`` set to the new ``(h, w)``.
    """
    frame = datapoint.frames[index]

    if square:
        size = size, size
    elif isinstance(size, (list, tuple)):
        # Explicit (w, h) -> (h, w), the order torchvision's resize expects.
        size = size[::-1]
    else:
        # Current (w, h): tensors are (..., H, W); non-v2 data exposes .size as (w, h).
        cur_size = frame.data.size()[-2:][::-1] if v2 else frame.data.size
        size = get_size_with_aspect_ratio(cur_size, size, max_size)

    if v2:
        frame.data = Fv2.resize(frame.data, size, antialias=True)
    else:
        frame.data = F.resize(frame.data, size)

    for obj in frame.objects:
        if obj.segment is not None:
            # NOTE(review): masks use F.resize's default interpolation, not nearest — confirm intended.
            # Index [0, 0] (instead of .squeeze()) so a resized side of length 1
            # is not dropped from the 2-D mask.
            obj.segment = F.resize(obj.segment[None, None], size)[0, 0]

    h, w = size
    frame.size = (h, w)
    return datapoint
88
+
89
+
90
def pad(datapoint, index, padding, v2=False):
    """Zero-pad frame ``index`` of ``datapoint`` (image and object masks) in place.

    Args:
        datapoint: video datapoint whose ``frames[index]`` is padded.
        index: frame index.
        padding: either ``(right, bottom)`` (padding only on the bottom/right)
            or ``(left, top, right, bottom)``.
        v2: if True, masks are padded with torchvision v2 ``pad``. The image
            is always padded with the v1 functional ``pad``.

    Returns:
        The same datapoint, with ``frames[index].size`` grown to the padded ``(h, w)``.
    """
    frame = datapoint.frames[index]
    h, w = frame.size

    if len(padding) == 2:
        # bottom/right only
        img_padding = (0, 0, padding[0], padding[1])
        mask_padding = img_padding
        extra_h, extra_w = padding[1], padding[0]
    else:
        # left, top, right, bottom
        img_padding = (padding[0], padding[1], padding[2], padding[3])
        mask_padding = tuple(padding)
        extra_h, extra_w = padding[1] + padding[3], padding[0] + padding[2]

    frame.data = F.pad(frame.data, img_padding)
    frame.size = (h + extra_h, w + extra_w)

    pad_mask = Fv2.pad if v2 else F.pad
    for obj in frame.objects:
        if obj.segment is not None:
            obj.segment = pad_mask(obj.segment, mask_padding)
    return datapoint
122
+
123
+
124
class RandomHorizontalFlip:
    """Randomly mirror video frames horizontally.

    Args:
        consistent_transform: if True, one coin flip decides for all frames;
            otherwise every frame gets its own independent coin flip.
        p: probability of flipping.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        n_frames = len(datapoint.frames)
        if self.consistent_transform:
            to_flip = range(n_frames) if random.random() < self.p else range(0)
        else:
            to_flip = [i for i in range(n_frames) if random.random() < self.p]
        for i in to_flip:
            datapoint = hflip(datapoint, i)
        return datapoint
139
+
140
+
141
class RandomResizeAPI:
    """Resize frames to a size drawn at random from ``sizes``.

    Args:
        sizes: an int or an iterable of candidate sizes (each forwarded to ``resize``).
        consistent_transform: if True, one size is drawn for the whole video;
            otherwise a size is drawn per frame.
        max_size, square, v2: forwarded unchanged to ``resize``.
    """

    def __init__(self, sizes, consistent_transform, max_size=None, square=False, v2=False):
        sizes = (sizes,) if isinstance(sizes, int) else sizes
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        shared_size = random.choice(self.sizes) if self.consistent_transform else None
        for frame_idx in range(len(datapoint.frames)):
            frame_size = shared_size if self.consistent_transform else random.choice(self.sizes)
            datapoint = resize(datapoint, frame_idx, frame_size, self.max_size, square=self.square, v2=self.v2)
        return datapoint
162
+
163
+
164
class ToTensorAPI:
    """Convert the ``data`` of every frame to a tensor.

    Args:
        v2: if True, use the torchvision v2 functional image conversion;
            otherwise use ``torchvision.transforms.functional.to_tensor``.
    """

    def __init__(self, v2=False):
        self.v2 = v2

    @staticmethod
    def _v2_to_image():
        """Return the torchvision v2 function that converts an image to an image tensor.

        ``to_image_tensor`` only exists in older torchvision releases; newer
        releases provide the equivalent conversion as ``to_image``. The old
        name is preferred when present, so behaviour there is unchanged.
        """
        legacy = getattr(Fv2, "to_image_tensor", None)
        return legacy if legacy is not None else Fv2.to_image

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        convert = self._v2_to_image() if self.v2 else F.to_tensor
        for img in datapoint.frames:
            img.data = convert(img.data)
        return datapoint
175
+
176
+
177
class NormalizeAPI:
    """Normalize each frame's ``data`` channel-wise with ``mean`` and ``std``.

    Args:
        mean, std: per-channel statistics passed to torchvision ``normalize``.
        v2: if True, frames are first converted to ``float32`` with the v2
            ``convert_image_dtype`` and normalized with v2 ``normalize``.
    """

    def __init__(self, mean, std, v2=False):
        self.mean, self.std, self.v2 = mean, std, v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        stats = dict(mean=self.mean, std=self.std)
        for img in datapoint.frames:
            if not self.v2:
                img.data = F.normalize(img.data, **stats)
                continue
            img.data = Fv2.convert_image_dtype(img.data, torch.float32)
            img.data = Fv2.normalize(img.data, **stats)
        return datapoint
192
+
193
+
194
class ComposeAPI:
    """Chain datapoint transforms, applying them in order with shared kwargs."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        result = datapoint
        for transform in self.transforms:
            result = transform(result, **kwargs)
        return result

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
210
+
211
+
212
class RandomGrayscale:
    """Randomly convert frames to 3-channel grayscale.

    Args:
        consistent_transform: if True, one coin flip decides for the whole
            video; otherwise each frame is decided independently.
        p: probability of converting.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        frames = datapoint.frames
        if self.consistent_transform:
            selected = frames if random.random() < self.p else []
        else:
            selected = [img for img in frames if random.random() < self.p]
        for img in selected:
            img.data = self.Grayscale(img.data)
        return datapoint
228
+
229
+
230
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of video frames.

    ``brightness``, ``contrast`` and ``saturation`` each accept:
      - a number ``x``, sampled from ``[max(0, 1 - x), 1 + x]``;
      - an explicit ``[min, max]`` list or tuple;
      - ``None``, which disables that jitter.

    ``hue`` accepts:
      - a number ``x``, sampled from ``[-x, x]``;
      - an explicit ``[min, max]`` list or tuple;
      - ``None``.

    Args:
        consistent_transform: if True, the same sampled factors (and application
            order) are used for every frame; otherwise they are resampled per frame.
    """

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = self._factor_range(brightness)
        self.contrast = self._factor_range(contrast)
        self.saturation = self._factor_range(saturation)
        self.hue = self._symmetric_range(hue)

    @staticmethod
    def _factor_range(value):
        """Normalize a multiplicative jitter spec (centred on 1) to a list range or None."""
        if value is None or isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [max(0, 1 - value), 1 + value]

    @staticmethod
    def _symmetric_range(value):
        """Normalize an additive jitter spec (centred on 0) to a list range or None."""
        if value is None or isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [-value, value]

    def _sample_params(self):
        """Draw (fn_idx, brightness, contrast, saturation, hue) factors via torchvision."""
        return T.ColorJitter.get_params(self.brightness, self.contrast, self.saturation, self.hue)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        shared_params = self._sample_params() if self.consistent_transform else None
        for img in datapoint.frames:
            params = shared_params if self.consistent_transform else self._sample_params()
            fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = params
            # Apply the adjustments in the randomly permuted order fn_idx;
            # a None factor means that jitter is disabled.
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint
267
+
268
+
269
class RandomAffine:
    """Random affine augmentation applied jointly to video frames and object masks.

    Masks are warped with nearest-neighbour interpolation (fill 0). Images are
    warped with ``image_interpolation`` (fill ``image_mean``). If any object
    mask of the first frame becomes empty after warping, the sampled transform
    is rejected and another one is tried, up to ``num_tentatives`` times. If
    all attempts are rejected, the datapoint is returned unchanged.
    """

    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, then the same random affine is applied to all frames and masks.

        Args:
            degrees: rotation range; a number ``d`` means ``[-d, d]``, a list is used as-is.
            consistent_transform: share one sampled affine across all frames (True)
                or sample one per frame (False).
            scale, translate: forwarded to ``T.RandomAffine.get_params``.
            shear: a number ``s`` means ``[-s, s]``; a list is used as-is; a falsy value disables shear.
            image_mean: fill value for image pixels outside the warped area.
            log_warning: log a warning when every tentative is rejected.
            num_tentatives: how many transforms to try before giving up.
            image_interpolation: ``"bicubic"`` or ``"bilinear"``.

        Raises:
            NotImplementedError: for any other ``image_interpolation``.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives

        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res

        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _sample_affine_params(self, img_size):
        """Draw one set of affine parameters, in the order ``F.affine`` takes them positionally."""
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_datapoint(self, datapoint: VideoDatapoint):
        """Warp all frames and masks, or return None if a first-frame mask would vanish.

        Nothing is modified when None is returned: the rejection check happens
        on frame 0, before any of its masks or its image are written back.
        """
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        shared_params = self._sample_affine_params(img_size) if self.consistent_transform else None

        for img_idx, img in enumerate(datapoint.frames):
            affine_params = shared_params if self.consistent_transform else self._sample_affine_params(img_size)

            transformed_masks = []
            for obj in img.objects:
                if obj.segment is None:
                    transformed_masks.append(None)
                    continue
                transformed_mask = F.affine(
                    obj.segment.unsqueeze(0),
                    *affine_params,
                    interpolation=InterpolationMode.NEAREST,
                    fill=0.0,
                )
                if img_idx == 0 and transformed_mask.max() == 0:
                    # The object would not be visible in the first frame:
                    # reject this transform (the caller may retry).
                    return None
                # Drop exactly the channel dim added by unsqueeze(0); .squeeze()
                # would also drop a spatial dimension of length 1.
                transformed_masks.append(transformed_mask[0])

            for obj, mask in zip(img.objects, transformed_masks):
                obj.segment = mask

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
369
+
370
+
371
def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    """Replace frame ``index`` by a ``grid_h`` x ``grid_w`` mosaic of downscaled copies of itself.

    Every tile holds the same frame resized (bilinear, antialiased) to the tile
    size, mirrored when ``should_hflip[y, x]`` is true. Object masks are kept
    only in the tile at (``target_grid_y``, ``target_grid_x``) and are zero
    elsewhere. Masks must be ``torch.uint8`` tensors of shape ``(H, W)`` (asserted).

    Returns:
        The same datapoint, modified in place.
    """

    def tile_box(gy, gx):
        # (y0, y1, x0, x1) for tile (gy, gx); adjacent tiles share their edges
        # and the last tile ends exactly at the frame border.
        return (
            gy * H_im // grid_h,
            (gy + 1) * H_im // grid_h,
            gx * W_im // grid_w,
            (gx + 1) * W_im // grid_w,
        )

    frame = datapoint.frames[index]
    source = frame.data
    is_pil = isinstance(source, PILImage.Image)
    if is_pil:
        H_im, W_im = source.height, source.width
        mosaic = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im, W_im = source.size(-2), source.size(-1)
        mosaic = torch.zeros_like(source)

    # Step 1: paste a (possibly mirrored) downscaled copy of the frame into every
    # tile; tiles with the same shape reuse one resize.
    resized_by_shape = {}
    for gy in range(grid_h):
        for gx in range(grid_w):
            y0, y1, x0, x1 = tile_box(gy, gx)
            shape = (y1 - y0, x1 - x0)
            tile = resized_by_shape.get(shape)
            if tile is None:
                tile = F.resize(
                    source,
                    size=shape,
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                resized_by_shape[shape] = tile
            if should_hflip[gy, gx].item():
                tile = F.hflip(tile)
            if is_pil:
                mosaic.paste(tile, (x0, y0))
            else:
                mosaic[:, y0:y1, x0:x1] = tile

    frame.data = mosaic

    # Step 2: shrink each mask into the target tile only.
    ty0, ty1, tx0, tx1 = tile_box(target_grid_y, target_grid_x)
    for obj in frame.objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        canvas = torch.zeros_like(obj.segment)
        small = F.resize(
            obj.segment[None, None],
            size=(ty1 - ty0, tx1 - tx0),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            small = F.hflip(small[None, None])[0, 0]
        canvas[ty0:ty1, tx0:tx1] = small
        obj.segment = canvas

    return datapoint
449
+
450
+
451
class RandomMosaicVideoAPI:
    """With probability ``prob``, turn every frame of a video into a mosaic.

    The target tile and the optional per-tile horizontal flips are drawn once
    per call and shared by all frames, so the mosaic layout is identical
    across frames.
    """

    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # Tile that will hold the target masks.
        target_y = random.randint(0, self.grid_h - 1)
        target_x = random.randint(0, self.grid_w - 1)
        # Per-tile horizontal flip decisions (all False unless enabled).
        grid_shape = (self.grid_h, self.grid_w)
        if self.use_random_hflip:
            should_hflip = torch.rand(*grid_shape) < 0.5
        else:
            should_hflip = torch.zeros(*grid_shape, dtype=torch.bool)

        for frame_idx in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                frame_idx,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_y,
                target_grid_x=target_x,
                should_hflip=should_hflip,
            )
        return datapoint
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
8
+
9
+ from typing import Iterable
10
+
11
+ import torch
12
+ from torch.utils.data import (
13
+ ConcatDataset as TorchConcatDataset,
14
+ Dataset,
15
+ Subset as TorchSubset,
16
+ )
17
+
18
+
19
class ConcatDataset(TorchConcatDataset):
    """``torch.utils.data.ConcatDataset`` that also concatenates the children's ``repeat_factors``.

    Every child dataset must expose a tensor ``repeat_factors``. Presumably it
    has one entry per sample, so the concatenation lines up with the
    concatenated indices — confirm against the child datasets.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__(datasets)

        # Use the list materialized by the parent constructor: ``datasets`` may
        # be a one-shot iterator that has already been consumed at this point.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Propagate ``epoch`` to children with an ``epoch`` attribute and/or a ``set_epoch`` method."""
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
31
+
32
+
33
class Subset(TorchSubset):
    """``torch.utils.data.Subset`` that also selects the matching ``repeat_factors`` of its parent."""

    def __init__(self, dataset, indices) -> None:
        super().__init__(dataset, indices)

        subset_factors = dataset.repeat_factors[indices]
        self.repeat_factors = subset_factors
        # One repeat factor per selected index.
        assert len(indices) == len(subset_factors)
39
+
40
+
41
# Adapted from Detectron2
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.

    The underlying dataset must have a ``repeat_factors`` member giving the
    per-image factor. Set it to all ones to disable repeat factor sampling.
    ``set_epoch`` must be called before ``len()`` or indexing: it draws the
    (repeated) list of indices used for that epoch.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: int64 tensor of dataset indices to use in one epoch, in
            dataset order. Each index is repeated based on its sampled repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Repeat each index rep_factor times (vectorized). Counts are truncated
        # to integers, and clamped at 0 so a negative factor yields no copies.
        counts = rep_factors.long().clamp(min=0)
        return torch.repeat_interleave(torch.arange(len(counts)), counts)

    def __len__(self):
        if self.epoch_ids is None:
            # Raise instead of returning len(self.dataset), to avoid accidentally
            # using the unwrapped length: the wrapped length, len(self.epoch_ids),
            # only exists after set_epoch and changes from epoch to epoch.
            raise RuntimeError("please call set_epoch first to get wrapped length")

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Draw this epoch's indices deterministically from ``seed + epoch`` and forward the epoch."""
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError("Repeat ids haven't been computed. Did you forget to call set_epoch?")

        return self.dataset[self.epoch_ids[idx]]