nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import random
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from iopath.common.file_io import g_pathmgr
|
|
15
|
+
from PIL import Image as PILImage
|
|
16
|
+
from torchvision.datasets.vision import VisionDataset
|
|
17
|
+
|
|
18
|
+
from training.dataset.vos_raw_dataset import VOSRawDataset
|
|
19
|
+
from training.dataset.vos_sampler import VOSSampler
|
|
20
|
+
from training.dataset.vos_segment_loader import JSONSegmentLoader
|
|
21
|
+
|
|
22
|
+
from training.utils.data_utils import Frame, Object, VideoDatapoint
|
|
23
|
+
|
|
24
|
+
# Maximum number of load attempts made by VOSDataset._get_datapoint; in training
# mode each failed attempt is retried with a randomly chosen other video.
MAX_RETRIES = 100
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class VOSDataset(VisionDataset):
    """Map-style dataset producing transformed ``VideoDatapoint`` samples.

    For each index, ``video_dataset.get_video`` provides a video and its segment
    loader. ``sampler.sample`` then picks the frames and object ids, ``construct``
    builds a ``VideoDatapoint``, and every callable in ``transforms`` is applied
    in order.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: Iterable of callables invoked as ``transform(datapoint, epoch=...)``.
            training: If True, a failing video is replaced by a random one and
                retried; otherwise the first failure is re-raised.
            video_dataset: Source of ``(video, segment_loader)`` pairs via ``get_video(idx)``.
            sampler: Object whose ``sample(video, segment_loader, epoch=...)`` chooses
                the frames and object ids of a datapoint.
            multiplier: Value every entry of ``repeat_factors`` is scaled by.
            always_target: If True, an object with no segment in a frame gets an
                all-zero mask there; if False it is dropped from that frame.
            target_segments_available: Stored on the instance; not read by this class.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # One repeat factor per raw video, each equal to `multiplier`.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """Load, sample, construct and transform the datapoint for ``idx``.

        Raises:
            Exception: in eval mode, whatever loading or sampling raised.
            RuntimeError: in training mode, when all ``MAX_RETRIES`` attempts failed.
        """
        last_exc = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(video, segment_loader, epoch=self.curr_epoch)
                break  # Successfully loaded video
            except Exception as e:
                if not self.training:
                    # Shouldn't fail to load a val video
                    raise
                last_exc = e
                logging.warning(f"Loading failed (id={idx}); Retry {retry} with exception: {e}")
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Every attempt failed. Without this, `video` / `sampled_frms_and_objs`
            # would be unbound (or stale from a half-finished attempt) below.
            raise RuntimeError(f"Failed to load a datapoint after {MAX_RETRIES} attempts") from last_exc

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms.

        One ``Frame`` is built per sampled frame, holding its RGB image and one
        ``Object`` (uint8 segment) per sampled object id. ``size`` is ``(h, w)``
        of the last sampled frame.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size  # PIL reports (width, height)
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(frame.frame_idx, obj_ids=sampled_object_ids)
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert segments[obj_id] is not None, "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        """Return the transformed datapoint for ``idx``."""
        return self._get_datapoint(idx)

    def __len__(self):
        """Number of entries in the underlying raw video dataset."""
        return len(self.video_dataset)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def load_images(frames):
    """Return one RGB ``PIL.Image`` per entry of ``frames``, preserving order.

    If a frame has no in-memory ``data``, it is read from ``frame.image_path``.
    Each distinct path is opened only once; repeats receive a deep copy of the
    first image read for that path. Frames that already carry ``data`` are
    converted with ``tensor_2_PIL``.
    """
    images = []
    first_index_by_path = {}
    for frm in frames:
        if frm.data is not None:
            # The frame rgb data has already been loaded: convert it to a PILImage
            images.append(tensor_2_PIL(frm.data))
            continue
        img_path = frm.image_path
        if img_path in first_index_by_path:
            images.append(deepcopy(images[first_index_by_path[img_path]]))
            continue
        # Load the frame rgb data from file
        with g_pathmgr.open(img_path, "rb") as handle:
            images.append(PILImage.open(handle).convert("RGB"))
        first_index_by_path[img_path] = len(images) - 1
    return images
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """Convert a channel-first image tensor into a ``PIL.Image``.

    ``data`` is transposed with ``(1, 2, 0)``, so it must be 3-D and is treated as
    ``(C, H, W)``. Values are scaled by 255, so they are presumably in ``[0, 1]``
    (TODO confirm with callers). Anything outside that range is clipped, so the
    uint8 cast saturates instead of wrapping.
    """
    # detach(): .numpy() refuses tensors that require grad.
    array = data.detach().cpu().numpy().transpose((1, 2, 0)) * 255.0
    # Casting out-of-range floats to uint8 is undefined in NumPy, so clip first.
    # Truncation (not rounding) is kept so in-range output is unchanged.
    array = np.clip(array, 0.0, 255.0).astype(np.uint8)
    return PILImage.fromarray(array)
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import glob
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
from typing import List, Optional
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from iopath.common.file_io import g_pathmgr
|
|
19
|
+
|
|
20
|
+
from omegaconf.listconfig import ListConfig
|
|
21
|
+
|
|
22
|
+
from training.dataset.vos_segment_loader import (
|
|
23
|
+
JSONSegmentLoader,
|
|
24
|
+
MultiplePNGSegmentLoader,
|
|
25
|
+
PalettisedPNGSegmentLoader,
|
|
26
|
+
SA1BSegmentLoader,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class VOSFrame:
    """One frame of a raw VOS video: its index, image location and optional preloaded data."""

    # Frame index. The raw datasets in this module fill it with a number parsed
    # from the image file name (PNG/JSON datasets) or a 0-based counter (SA-1B).
    frame_idx: int
    # Path of the image file for this frame.
    image_path: str
    # Optional in-memory image tensor; None means the image is read from `image_path`.
    data: Optional[torch.Tensor] = None
    # Flag carried with the frame; not read anywhere in this module.
    is_conditioning_only: Optional[bool] = False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class VOSVideo:
    """A raw video: a name, an integer id and its ordered list of frames."""

    video_name: str
    # Integer id. The raw datasets in this module use the dataset index (PNG/JSON)
    # or the numeric suffix of the image name (SA-1B).
    video_id: int
    frames: List[VOSFrame]

    def __len__(self):
        """Number of frames in the video."""
        return len(self.frames)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class VOSRawDataset:
    """Interface for raw VOS data sources.

    Subclasses implement ``get_video(idx)``, returning a ``(VOSVideo,
    segment_loader)`` pair. The subclasses in this module also define ``__len__``.
    """

    def __init__(self):
        """The base class holds no state."""

    def get_video(self, idx):
        """Return the video at ``idx``; subclasses must override this."""
        raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PNGRawDataset(VOSRawDataset):
    """Raw VOS dataset backed by folders of JPEG frames and PNG masks.

    Frames are read from ``img_folder/<video>/*.jpg`` and masks from
    ``gt_folder/<video>``. In ``single_object_mode`` each ``<video>/<obj>``
    sub-folder of ``gt_folder`` becomes its own entry. Such an entry shares the
    frames of ``img_folder/<video>``.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """
        Args:
            img_folder: Root folder of the per-video JPEG frame folders.
            gt_folder: Root folder of the per-video mask folders.
            file_list_txt: Optional text file of video names (extensions stripped);
                defaults to every entry of ``img_folder``.
            excluded_videos_list_txt: Optional text file of video names to skip.
            sample_rate: Step used when slicing the sorted frame list.
            is_palette: Use ``PalettisedPNGSegmentLoader`` if True, else
                ``MultiplePNGSegmentLoader``.
            single_object_mode: Expand every video into one entry per mask sub-folder.
            truncate_video: If > 0, keep only the first ``truncate_video`` frames.
            frames_sampling_mult: Repeat each entry once per file in its frame folder.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided (a set keeps the filter O(n))
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Check if it's not in excluded_files
        self.video_names = sorted(video_name for video_name in subset if video_name not in excluded_files)

        if self.single_object_mode:
            # single object mode: one "<video>/<obj>" entry per mask sub-folder
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                # Count frames where get_video actually reads them; in single
                # object mode that is the parent video's folder, not "<video>/<obj>".
                num_frames = len(os.listdir(self._frame_root(video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def _frame_root(self, video_name):
        """Return the folder that holds the JPEG frames of entry ``video_name``."""
        if self.single_object_mode:
            return os.path.join(self.img_folder, os.path.dirname(video_name))
        return os.path.join(self.img_folder, video_name)

    def get_video(self, idx):
        """
        Return ``(VOSVideo, segment_loader)`` for entry ``idx``.

        Frames are the sorted ``*.jpg`` files of the entry's frame folder, truncated
        and then sub-sampled per the constructor options. Each frame id is the
        integer prefix of its file name.
        """
        video_name = self.video_names[idx]
        video_frame_root = self._frame_root(video_name)
        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(video_mask_root, self.single_object_mode)

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        frames = []
        for fpath in all_frames[:: self.sample_rate]:
            fid = int(os.path.basename(fpath).split(".")[0])
            frames.append(VOSFrame(fid, image_path=fpath))
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries (after exclusion, expansion and multiplication)."""
        return len(self.video_names)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class SA1BRawDataset(VOSRawDataset):
    """Raw dataset for SA-1B style data.

    Each entry is one ``img_folder/<name>.jpg`` image with its
    ``gt_folder/<name>.json`` annotation. It is presented as a video of
    ``num_frames`` frames that all point at that image.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        """
        Args:
            img_folder: Folder containing the ``.jpg`` images.
            gt_folder: Folder containing the ``.json`` annotations.
            file_list_txt: Optional text file of entry names (extensions stripped);
                defaults to every ``.jpg`` in ``img_folder``.
            excluded_videos_list_txt: Optional text file of entry names to skip.
            num_frames: Number of frames (copies of the image) per video.
            mask_area_frac_thresh: Forwarded to ``SA1BSegmentLoader``.
            uncertain_iou: Forwarded to ``SA1BSegmentLoader``.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)
            subset = [path.split(".")[0] for path in subset if path.endswith(".jpg")]  # remove extension

        # Read and process excluded files if provided. A set avoids an O(n*m)
        # scan, which matters for the very large SA-1B image lists.
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep entries that are not excluded (existence on disk is not checked here)
        self.video_names = [video_name for video_name in subset if video_name not in excluded_files]

    def get_video(self, idx):
        """
        Return ``(VOSVideo, SA1BSegmentLoader)`` for entry ``idx``.

        The video id is the integer after the last ``_`` in the entry name
        (e.g. ``sa_123`` -> 123).
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = [VOSFrame(frame_idx, image_path=video_frame_path) for frame_idx in range(self.num_frames)]
        video_name = video_name.split("_")[-1]  # filename is sa_{int}
        # video id needs to be image_id to be able to load correct annotation file during eval
        video = VOSVideo(video_name, int(video_name), frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries after exclusion."""
        return len(self.video_names)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class JSONRawDataset(VOSRawDataset):
    """
    Dataset where the annotation is in the format of SA-V json files.

    Frames are read from ``img_folder/<video>/<%05d>.jpg`` and annotations from
    ``gt_folder/<video>_manual.json``.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        """
        Args:
            img_folder: Root folder of the per-video frame folders.
            gt_folder: Folder of ``<video>_manual.json`` annotation files.
            file_list_txt: Optional text file of video names; defaults to every
                entry of ``img_folder``.
            excluded_videos_list_txt: A path, or a list/tuple/``ListConfig`` of
                paths, to text files of video names to skip.
            sample_rate: Step used when slicing the sorted frame ids.
            rm_unannotated: Drop frames without a complete annotation.
            ann_every: Forwarded to ``JSONSegmentLoader``.
            frames_fps: Forwarded to ``JSONSegmentLoader``.

        Raises:
            NotImplementedError: if ``excluded_videos_list_txt`` has an unsupported type.
        """
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (ListConfig, list, tuple)):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError

            for excluded_list_path in excluded_videos_lists:
                # g_pathmgr, like every other list file read in this module
                with g_pathmgr.open(excluded_list_path, "r") as f:
                    excluded_files.update(os.path.splitext(line.strip())[0] for line in f)

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(video_name for video_name in subset if video_name not in excluded_files)

    def get_video(self, video_idx):
        """
        Return ``(VOSVideo, JSONSegmentLoader)`` for entry ``video_idx``.

        Frame ids come from the file names in the video's frame folder, sorted and
        sub-sampled by ``sample_rate``. With ``rm_unannotated``, only frames whose
        annotation slot is present and contains no ``None`` are kept.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(os.listdir(os.path.join(self.img_folder, video_name)))
        ]

        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated (set: O(1) lookups)
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of videos after exclusion."""
        return len(self.video_names)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import random
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import List
|
|
10
|
+
|
|
11
|
+
from training.dataset.vos_segment_loader import LazySegments
|
|
12
|
+
|
|
13
|
+
# Maximum number of random clips RandomUniformSampler.sample draws while looking
# for one whose first frame has at least one visible object.
MAX_RETRIES = 1000
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class SampledFramesAndObjects:
    """Output of a sampler: the frames and the object ids chosen for one datapoint."""

    # NOTE(review): annotated as List[int], but the samplers in this module store
    # items of `video.frames` here (objects with a `.frame_idx`); confirm and fix the hint.
    frames: List[int]
    # Ids of the objects to use; RandomUniformSampler stores a list here.
    object_ids: List[int]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class VOSSampler:
    """Base class for strategies that choose frames and object ids from a video."""

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Return a ``SampledFramesAndObjects`` for ``video``; subclasses must override.

        The optional ``segment_loader`` and ``epoch`` parameters match the
        signature of the subclasses in this module. They are defaulted so the
        former single-argument form still works.
        """
        raise NotImplementedError()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RandomUniformSampler(VOSSampler):
    """Sample a random contiguous clip of ``num_frames`` frames and up to ``max_num_objects`` objects.

    The clip start is uniform over all valid positions, and the clip is reversed
    with probability ``reverse_time_prob``. Clips are redrawn (up to
    ``MAX_RETRIES`` times) until the first frame has at least one visible object.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        # Initialise base state (sort_frames) that was previously left unset.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Return a ``SampledFramesAndObjects`` for a random clip of ``video``.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or no
                drawn clip has a visible object in its first frame.
        """
        # Loop-invariant: checked once instead of on every retry.
        if len(video.frames) < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, len(video.frames) - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids (segments loaded once and reused)
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset
                visible_object_ids = list(loaded_segms.keys())
            else:
                visible_object_ids = [object_id for object_id, segment in loaded_segms.items() if segment.sum()]

            # First frame needs to have at least a target to track
            if len(visible_object_ids) > 0:
                break
        else:
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video.

    The object ids are the keys of the segments loaded for the first (after
    optional sorting) frame.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects.

        Raises:
            Exception: if the video has no frames or its first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames
        if len(frames) == 0:
            # Fail with a clear message instead of an IndexError on frames[0].
            raise Exception(f"Video {video.video_name} has no frames")
        # A list (not a live keys view) matches what RandomUniformSampler returns,
        # so downstream consumers see the same container type in train and eval.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if len(object_ids) == 0:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|