nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import random
9
+ from copy import deepcopy
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+ from iopath.common.file_io import g_pathmgr
15
+ from PIL import Image as PILImage
16
+ from torchvision.datasets.vision import VisionDataset
17
+
18
+ from training.dataset.vos_raw_dataset import VOSRawDataset
19
+ from training.dataset.vos_sampler import VOSSampler
20
+ from training.dataset.vos_segment_loader import JSONSegmentLoader
21
+
22
+ from training.utils.data_utils import Frame, Object, VideoDatapoint
23
+
24
# Upper bound on load attempts made by VOSDataset._get_datapoint in training mode.
MAX_RETRIES: int = 100
25
+
26
+
27
class VOSDataset(VisionDataset):
    """Dataset that turns raw VOS videos into transformed ``VideoDatapoint`` samples.

    For each index, a video and its segment loader are fetched from
    ``video_dataset``. ``sampler`` then picks the frames and object ids to use,
    ``construct`` loads the images and masks, and every callable in
    ``transforms`` is applied in order.

    Note: ``VisionDataset.__init__`` is not called, so attributes it would
    normally set are absent on instances.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables, each invoked as
                ``transform(datapoint, epoch=self.curr_epoch)``.
            training: if True, a failed load is logged and retried on a random
                video (at most ``MAX_RETRIES`` attempts in total). If False, the
                failure is re-raised.
            video_dataset: provides ``get_video(idx) -> (video, segment_loader)``
                and ``len()``.
            sampler: provides ``sample(video, segment_loader, epoch=...)``.
            multiplier: value stored in every entry of ``repeat_factors``.
            always_target: if True, a sampled object missing from a frame gets an
                all-zero mask. If False, it is omitted from that frame.
            target_segments_available: stored as-is; not read by this class.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # One float32 repeat factor per raw video, all equal to ``multiplier``.
        # Not read in this class; presumably consumed elsewhere (verify).
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """Load, construct and transform the datapoint for ``idx``.

        In training mode, a failure while loading or sampling is logged and
        retried on a uniformly random video. If all ``MAX_RETRIES`` attempts fail,
        a ``RuntimeError`` chained to the last failure is raised. In eval mode,
        the first failure is re-raised unchanged.
        """
        last_error = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(video, segment_loader, epoch=self.curr_epoch)
                break  # Successfully loaded video
            except Exception as e:
                if not self.training:
                    # Shouldn't fail to load a val video
                    raise
                last_error = e
                logging.warning(f"Loading failed (id={idx}); Retry {retry} with exception: {e}")
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Every attempt failed. Without this, the code below would reference
            # unbound locals and die with a confusing NameError.
            raise RuntimeError(f"Failed to load a datapoint after {MAX_RETRIES} attempts") from last_error

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms.

        One ``Frame`` is built per sampled frame. It holds the frame's RGB PIL
        image and one ``Object`` with a uint8 mask per kept object id. The
        datapoint's ``size`` is ``(h, w)`` of the last loaded image.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (segments)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size  # PIL size is (width, height)
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(frame.frame_idx, obj_ids=sampled_object_ids)
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert segments[obj_id] is not None, "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
128
+
129
+
130
def load_images(frames):
    """Return one RGB PIL image per entry of ``frames``, in the same order.

    A frame whose ``data`` is None is read from ``frame.image_path`` through
    ``g_pathmgr``. A path that appears several times is read only once; later
    occurrences receive a deep copy of the first decoded image. A frame whose
    ``data`` is set is converted with ``tensor_2_PIL``.
    """
    images = []
    first_index_by_path = {}
    for frame in frames:
        if frame.data is not None:
            # The frame rgb data has already been loaded: convert it back to a PIL image.
            images.append(tensor_2_PIL(frame.data))
            continue

        path = frame.image_path
        cached_at = first_index_by_path.get(path)
        if cached_at is not None:
            # Same file seen before: copy so each frame owns an independent image.
            images.append(deepcopy(images[cached_at]))
            continue

        with g_pathmgr.open(path, "rb") as handle:
            images.append(PILImage.open(handle).convert("RGB"))
        first_index_by_path[path] = len(images) - 1
    return images
149
+
150
+
151
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """Convert a channels-first image tensor to a uint8 PIL image.

    The tensor is moved to CPU, transposed from (C, H, W) to (H, W, C), scaled by
    255 and cast to uint8 (truncating, as before). Assumes values are nominally
    in [0, 1]; anything outside is clipped rather than allowed to overflow the
    uint8 cast.
    """
    array = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    # Casting out-of-range floats to uint8 overflows (wraps/undefined), turning
    # e.g. slightly >1.0 or negative values into garbage pixels; clip first.
    array = np.clip(array, 0.0, 255.0).astype(np.uint8)
    return PILImage.fromarray(array)
@@ -0,0 +1,290 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import glob
8
+ import logging
9
+ import os
10
+ from dataclasses import dataclass
11
+
12
+ from typing import List, Optional
13
+
14
+ import pandas as pd
15
+
16
+ import torch
17
+
18
+ from iopath.common.file_io import g_pathmgr
19
+
20
+ from omegaconf.listconfig import ListConfig
21
+
22
+ from training.dataset.vos_segment_loader import (
23
+ JSONSegmentLoader,
24
+ MultiplePNGSegmentLoader,
25
+ PalettisedPNGSegmentLoader,
26
+ SA1BSegmentLoader,
27
+ )
28
+
29
+
30
@dataclass
class VOSFrame:
    """A single frame of a VOS video."""

    # Index of the frame within its video (the raw datasets below parse it from
    # the file name or use a running counter).
    frame_idx: int
    # Location of the frame's image file.
    image_path: str
    # Optional pre-loaded image tensor; when None, loaders read image_path instead.
    data: Optional[torch.Tensor] = None
    # Flag carried with the frame; not read in this module.
    is_conditioning_only: Optional[bool] = False
36
+
37
+
38
@dataclass
class VOSVideo:
    """A named video made of an ordered list of frames."""

    # Human-readable name (the raw datasets use the directory/file stem).
    video_name: str
    # Numeric identifier of the video.
    video_id: int
    # Frames in the order the raw dataset produced them.
    frames: List[VOSFrame]

    def __len__(self):
        """Number of frames in this video."""
        frame_count = len(self.frames)
        return frame_count
46
+
47
+
48
class VOSRawDataset:
    """Interface for raw VOS datasets.

    Subclasses implement ``get_video(idx)``. The subclasses in this module return
    a ``(VOSVideo, segment_loader)`` tuple from it.
    """

    def __init__(self):
        """Nothing to initialise at this level."""

    def get_video(self, idx):
        """Return the video at position ``idx``; must be overridden."""
        raise NotImplementedError()
54
+
55
+
56
class PNGRawDataset(VOSRawDataset):
    """Raw VOS dataset of JPEG frames with PNG ground-truth masks.

    Frames are read from ``img_folder/<video>/*.jpg`` and masks from
    ``gt_folder/<entry>``. In single-object mode every entry of ``video_names``
    is ``<video>/<obj>``, one per item listed in ``gt_folder/<video>``. Its
    frames still come from ``img_folder/<video>``.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """
        Args:
            img_folder: root directory with one sub-directory of frames per video.
            gt_folder: root directory with the ground-truth masks per video.
            file_list_txt: optional text file listing the videos to use (one per
                line, extension stripped). Defaults to every entry of ``img_folder``.
            excluded_videos_list_txt: optional text file listing videos to skip.
            sample_rate: keep every ``sample_rate``-th frame.
            is_palette: use ``PalettisedPNGSegmentLoader`` if True, otherwise
                ``MultiplePNGSegmentLoader``.
            single_object_mode: expose each object of a video as its own entry.
            truncate_video: if > 0, keep only the first ``truncate_video`` frames
                (applied before sub-sampling).
            frames_sampling_mult: repeat each entry once per file in its frame
                directory. Under uniform index sampling, longer videos are then
                drawn more often.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided (a set gives O(1) membership tests)
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Check if it's not in excluded_files
        self.video_names = sorted(video_name for video_name in subset if video_name not in excluded_files)

        if self.single_object_mode:
            # single object mode: one entry per object found in the video's gt folder
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                # Use the same frame directory get_video reads from. In single-object
                # mode, img_folder/<video>/<obj> does not exist.
                num_frames = len(os.listdir(self._frame_root(video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def _frame_root(self, video_name):
        """Directory holding the JPEG frames for an entry of ``video_names``."""
        if self.single_object_mode:
            return os.path.join(self.img_folder, os.path.dirname(video_name))
        return os.path.join(self.img_folder, video_name)

    def get_video(self, idx):
        """Return ``(video, segment_loader)`` for entry ``idx`` of ``video_names``.

        ``video.frames`` holds one ``VOSFrame`` per kept ``*.jpg`` file (sorted by
        path, optionally truncated, then sub-sampled). Each ``frame_idx`` is parsed
        from the file name's stem.
        """
        video_name = self.video_names[idx]
        video_frame_root = self._frame_root(video_name)
        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(video_mask_root, self.single_object_mode)

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        frames = [
            VOSFrame(int(os.path.basename(fpath).split(".")[0]), image_path=fpath)
            for fpath in all_frames[:: self.sample_rate]
        ]
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
140
+
141
+
142
class SA1BRawDataset(VOSRawDataset):
    """Raw dataset over SA-1B style still images, each exposed as a "video".

    Entry ``<name>`` reads its image from ``img_folder/<name>.jpg`` and its masks
    from ``gt_folder/<name>.json``. The single image is repeated ``num_frames``
    times.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        """
        Args:
            img_folder: directory containing the ``*.jpg`` images.
            gt_folder: directory containing one ``<name>.json`` per image.
            file_list_txt: optional text file listing the images to use (one per
                line, extension stripped). Defaults to every ``*.jpg`` in
                ``img_folder``.
            excluded_videos_list_txt: optional text file listing images to skip.
            num_frames: number of (identical) frames per returned video.
            mask_area_frac_thresh: forwarded to ``SA1BSegmentLoader``.
            uncertain_iou: forwarded to ``SA1BSegmentLoader``.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            # Strip only the final extension, so names containing dots still map back
            # to "<name>.jpg" in get_video. Sort because os.listdir order is arbitrary.
            subset = sorted(
                os.path.splitext(path)[0] for path in os.listdir(self.img_folder) if path.endswith(".jpg")
            )

        # Read and process excluded files if provided (a set gives O(1) membership tests)
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep listed images that are not excluded (file existence is not checked here)
        self.video_names = [video_name for video_name in subset if video_name not in excluded_files]

    def get_video(self, idx):
        """Return ``(video, segment_loader)`` for the image at position ``idx``.

        ``video`` repeats the same image path ``num_frames`` times, with frame
        indices ``0..num_frames-1``. Its name and id are the numeric part after
        the last "_" in the file stem.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = [VOSFrame(frame_idx, image_path=video_frame_path) for frame_idx in range(self.num_frames)]
        # filename is sa_{int}; video id needs to be image_id to be able to load the
        # correct annotation file during eval
        image_id = video_name.split("_")[-1]
        video = VOSVideo(image_id, int(image_id), frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
203
+
204
+
205
class JSONRawDataset(VOSRawDataset):
    """
    Raw VOS dataset whose annotations are SA-V style JSON files.

    Frames are read from ``img_folder/<video>/<frame_id:05d>.jpg``. Annotations
    are read from ``gt_folder/<video>_manual.json``.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        """
        Args:
            img_folder: root directory with one sub-directory of frames per video.
            gt_folder: directory with the ``<video>_manual.json`` annotation files.
            file_list_txt: optional text file listing the videos to use. Defaults
                to every entry of ``img_folder``.
            excluded_videos_list_txt: a path, or a list/tuple/ListConfig of paths,
                to text files listing videos to skip.
            sample_rate: keep every ``sample_rate``-th frame.
            rm_unannotated: drop frames without a complete annotation.
            ann_every: forwarded to ``JSONSegmentLoader``; also used to map
                annotation positions back to frame ids.
            frames_fps: forwarded to ``JSONSegmentLoader``.

        Raises:
            NotImplementedError: if ``excluded_videos_list_txt`` has an unsupported type.
        """
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (ListConfig, list, tuple)):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError

            for list_path in excluded_videos_lists:
                # g_pathmgr, like every other list file read in this module.
                with g_pathmgr.open(list_path, "r") as f:
                    excluded_files.update(os.path.splitext(line.strip())[0] for line in f)

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(video_name for video_name in subset if video_name not in excluded_files)

    def get_video(self, video_idx):
        """Return ``(video, segment_loader)`` for entry ``video_idx`` of ``video_names``.

        Frame ids are the integer stems of the files in ``img_folder/<video>``
        (sorted, then sub-sampled). If ``rm_unannotated`` is set, only frames
        whose id is ``i * ann_every`` for an annotation ``i`` that is present and
        contains no None are kept.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(os.listdir(os.path.join(self.img_folder, video_name)))
        ]

        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(self.img_folder, f"{video_name}/{frame_id:05d}.jpg"),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated (set for O(1) lookups)
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import random
8
+ from dataclasses import dataclass
9
+ from typing import List
10
+
11
+ from training.dataset.vos_segment_loader import LazySegments
12
+
13
# Upper bound on window draws RandomUniformSampler makes before giving up.
MAX_RETRIES = 1_000
14
+
15
+
16
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler: the frames to load and the object ids to use."""

    # Frame objects taken from ``video.frames`` (not integer indices), in the
    # order they should be presented.
    frames: List
    # Ids of the objects whose masks should be loaded for every frame.
    object_ids: List[int]
+ object_ids: List[int]
20
+
21
+
22
class VOSSampler:
    """Base class for samplers that pick frames and objects from a video."""

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Return a ``SampledFramesAndObjects`` for ``video``; must be overridden.

        The optional parameters mirror the subclasses'
        ``sample(video, segment_loader, epoch=None)`` signature, so every sampler
        can be called the same way (``VOSDataset`` passes all three).
        """
        raise NotImplementedError()
29
+
30
+
31
class RandomUniformSampler(VOSSampler):
    """Sample a random contiguous window of frames and objects visible in its first frame."""

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: length of the contiguous frame window to sample.
            max_num_objects: maximum number of object ids returned.
            reverse_time_prob: probability of reversing the window's order.
        """
        # Initialise base-class state (sort_frames) so the sampler interface is complete.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Draw a frame window whose first frame has at least one visible object.

        Up to ``MAX_RETRIES`` windows are tried. For ``LazySegments``, every key
        counts as visible; otherwise an object is visible if its mask sums to a
        nonzero value. Up to ``max_num_objects`` of the visible ids are chosen at
        random.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or no
                window with a visible object was found.
        """
        # Window length is fixed, so this only needs checking once.
        if len(video.frames) < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, len(video.frames) - self.num_frames + 1)
            frames = video.frames[start : start + self.num_frames]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids (loaded once and reused)
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset
                visible_object_ids = list(loaded_segms.keys())
            else:
                visible_object_ids = [object_id for object_id, segment in loaded_segms.items() if segment.sum()]

            # First frame needs to have at least a target to track
            if visible_object_ids:
                break
        else:
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
77
+
78
+
79
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects.

        Frames are sorted by ``frame_idx`` when ``sort_frames`` is True, otherwise
        kept in their original order. Object ids are every key the segment loader
        returns for the first frame.

        Raises:
            Exception: if the first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames
        # Materialise as a list: a dict keys view stays tied to the loader's dict and
        # is not the List[int] that SampledFramesAndObjects declares.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if not object_ids:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)