nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,289 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import glob
8
+ import json
9
+ import os
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+
15
+ from PIL import Image as PILImage
16
+
17
+ try:
18
+ from pycocotools import mask as mask_utils
19
+ except:
20
+ pass
21
+
22
+
23
class JSONSegmentLoader:
    """Segment loader for masklets stored as COCO RLEs in a JSON file.

    The JSON top level is either a list (one entry per annotated frame, each a
    list with one RLE or ``None`` per object) or a dict holding that list under
    ``"masklet"`` (or ``"masks"``), optionally with an ``"fps"`` field.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path to the JSON annotation file.
            ann_every: annotations are provided every ``ann_every``-th frame;
                overridden when the JSON dict carries an ``"fps"`` field.
            frames_fps: frame rate of the video frames, used with the annotation
                fps to derive ``ann_every``.
            valid_obj_ids: optional ids of the objects to consider when sampling
                this video (``None`` keeps every object).
        Raises:
            NotImplementedError: if the JSON top level is neither a list nor a dict.
        """
        # Annotations in the json are provided every ann_every-th frame
        self.ann_every = ann_every
        # Ids of the objects to consider when sampling this video
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as f:
            data = json.load(f)
        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            masklet_field_name = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[masklet_field_name]
            if "fps" in data:
                if isinstance(data["fps"], list):
                    annotations_fps = int(data["fps"][0])
                else:
                    annotations_fps = int(data["fps"])
                # the frame rate must be an integer multiple of the annotation rate
                assert frames_fps % annotations_fps == 0
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Decode the masks of annotated frame ``frame_id``.

        Args:
            frame_id: frame index; must be a multiple of ``self.ann_every``.
            obj_ids: optional ids of the sampled objects; only these are returned.
        Returns:
            dict mapping each kept object id (ascending) to its decoded mask
            tensor, or to ``None`` when the object has no mask in this frame.
        """
        assert frame_id % self.ann_every == 0
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            # Remove the masklets that have been filtered out for this video
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Only keep the objects that have been sampled
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Keep only the RLEs we are interested in, remembering the position of
        # each object's RLE in the filtered list (None when it has no mask).
        id_2_idx = {}
        rle_mask_filtered = []
        for obj_id in valid_objs_ids:
            if rle_mask[obj_id] is not None:
                id_2_idx[obj_id] = len(rle_mask_filtered)
                rle_mask_filtered.append(rle_mask[obj_id])
            else:
                id_2_idx[obj_id] = None

        # Nothing to decode (no kept object has a mask here): skip the decode
        # call, since pycocotools' decode reads h/w from the first RLE and an
        # empty list is not a valid input.
        if not rle_mask_filtered:
            return {obj_id: None for obj_id in valid_objs_ids}

        # Decode the masks
        raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(2, 0, 1)  # (num_obj, h, w)
        return {obj_id: None if idx is None else raw_segments[idx] for obj_id, idx in id_2_idx.items()}

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """Map each object id to the frame ids where it has a (non-None) mask.

        Args:
            num_frames_min: if given, drop objects with fewer valid frames.
        Returns:
            dict ``obj_id -> list of frame ids`` (annotation index * ``ann_every``).
        """
        num_objects = len(self.frame_annots[0])

        # The result dict associates each obj_id with the id of its valid frames
        res = {obj_id: [] for obj_id in range(num_objects)}
        for annot_idx, annot in enumerate(self.frame_annots):
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(int(annot_idx * self.ann_every))

        if num_frames_min is not None:
            # Remove masklets that have less than num_frames_min valid masks
            res = {obj_id: frames for obj_id, frames in res.items() if len(frames) >= num_frames_min}

        return res
99
+
100
+
101
class PalettisedPNGSegmentLoader:
    """SegmentLoader for datasets with masks stored as palettised PNGs.

    Each PNG holds every object of one frame; a non-zero palette index ``k``
    marks the pixels of object ``k`` and 0 is background.
    """

    def __init__(self, video_png_root):
        """
        Args:
            video_png_root: the folder containing all the PNG masks of one video,
                each named by its integer frame id.
        """
        self.video_png_root = video_png_root
        # build a mapping from frame id to its PNG mask filename; ids are parsed
        # as integers because in some datasets the PNG names have more than
        # 5 digits, e.g. "00000000.png" instead of "00000.png"
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            frame_id, _ = os.path.splitext(filename)
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """Load the palettised mask of ``frame_id`` and split it per object.

        Args:
            frame_id: int, key into the frame id -> PNG filename mapping built
                in ``__init__``.
        Return:
            binary_segments: dict mapping each non-zero palette index present
            in the mask to a boolean tensor marking its pixels.
        """
        mask_path = os.path.join(self.video_png_root, self.frame_id_to_png_filename[frame_id])

        # load the mask; the context manager closes the underlying file handle
        with PILImage.open(mask_path) as img:
            masks = np.array(img.convert("P"))

        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]  # remove background (0)

        # convert into N binary segmentation masks
        binary_segments = {}
        for i in object_id:
            binary_segments[i] = torch.from_numpy(masks == i)

        return binary_segments

    def __len__(self):
        """Number of frames that have a PNG mask in ``video_png_root``."""
        return len(self.frame_id_to_png_filename)
145
+
146
+
147
class MultiplePNGSegmentLoader:
    """SegmentLoader for datasets storing one binary PNG per object and frame.

    Multi-object mode reads ``{video_png_root}/{obj_folder}/{frame_id:05d}.png``;
    single-object mode treats ``video_png_root`` itself as one object folder and
    reads ``{video_png_root}/{frame_id:05d}.png``. Object folders are named by an
    integer and the object id is that integer + 1, since 0 is background.
    """

    def __init__(self, video_png_root, single_object_mode=False):
        """
        Args:
            video_png_root: the folder containing all the masks stored in png
            single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # read a mask to know the resolution of the video
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        tmp_mask = self._read_png(tmp_mask_path)
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]
        if self.single_object_mode:
            self.obj_id = self._folder_obj_id(video_png_root)
        else:
            self.obj_id = None

    @staticmethod
    def _read_png(path):
        """Read a PNG into a numpy array, closing the file handle afterwards."""
        with PILImage.open(path) as img:
            return np.array(img)

    @staticmethod
    def _folder_obj_id(folder):
        """Object id encoded by an object folder's name, offset by 1 as bg is 0.

        Uses the last path component after normalisation, so trailing separators
        and OS-specific separators (e.g. from ``glob`` on Windows) are handled.
        """
        return int(os.path.basename(os.path.normpath(folder))) + 1

    def _load_mask_or_empty(self, mask_path):
        """Boolean ``(H, W)`` tensor of pixels > 0 in ``mask_path``; all False if the file is missing."""
        if os.path.exists(mask_path):
            mask = self._read_png(mask_path)
        else:
            # if png doesn't exist, empty mask
            mask = np.zeros((self.H, self.W), dtype=bool)
        return torch.from_numpy(mask > 0)

    def load(self, frame_id):
        """Return ``{obj_id: boolean mask tensor}`` for ``frame_id``."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict with the single entry ``{self.obj_id: mask}``
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        return {self.obj_id: self._load_mask_or_empty(mask_path)}

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_folder}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping each object id to its mask
        """
        # get the path
        all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
        num_objects = len(all_objects)
        assert num_objects > 0

        # load the masks
        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
            obj_id = self._folder_obj_id(obj_folder)
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            binary_segments[obj_id] = self._load_mask_or_empty(mask_path)

        return binary_segments

    def __len__(self):
        # NOTE(review): stub that returns None, so len() on this loader raises
        # TypeError; the number of frames is not tracked by this loader.
        return
223
+
224
+
225
class LazySegments:
    """
    Dict-like container of RLE segments that only decodes segments that are
    actually used, caching each decoded mask.
    """

    def __init__(self):
        self.segments = {}  # key -> RLE, not decoded
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item
        # a new RLE invalidates any mask decoded from the previous value
        self.cache.pop(key, None)

    def __getitem__(self, key):
        if key in self.cache:
            return self.cache[key]
        rle = self.segments[key]
        mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
        self.cache[key] = mask
        return mask

    def __contains__(self, key):
        return key in self.segments

    def __iter__(self):
        # Iterate over keys like a dict. Without this, Python falls back to
        # __getitem__(0), __getitem__(1), ... which decodes every mask and then
        # fails with KeyError (the fallback only stops on IndexError).
        return iter(self.segments)

    def __len__(self):
        return len(self.segments)

    def keys(self):
        return self.segments.keys()
253
+
254
+
255
class SA1BSegmentLoader:
    """Segment loader for SA-1B style annotation JSONs.

    Filters the ``"annotations"`` of one JSON file and stores their
    ``"segmentation"`` RLEs in a ``LazySegments`` keyed 0..K-1, so masks are
    only decoded on access. ``load`` ignores ``frame_idx`` and always returns
    that same container.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path to the JSON file holding an ``"annotations"`` list.
            mask_area_frac_thresh: drop masks whose area is at least this fraction
                of the image area; values > 1.0 disable the filter (and the image
                is then never opened).
            video_frame_path: image path, only needed to get the image size when
                the area filter is enabled.
            uncertain_iou: drop annotations whose ``"uncertain_iou"`` (stability
                score) is below this value.
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        if mask_area_frac_thresh <= 1.0:
            # Only open the frame (to read its size) when the area filter is active;
            # the context manager closes the file handle afterwards.
            with PILImage.open(video_frame_path) as img:
                orig_w, orig_h = img.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            if not frame_annot["area"] > 0:
                continue
            if ("uncertain_iou" in frame_annot) and (frame_annot["uncertain_iou"] < uncertain_iou):
                # uncertain_iou is stability score
                continue
            if mask_area_frac_thresh <= 1.0 and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the lazily-decoded segments; ``frame_idx`` is ignored."""
        return self.segments
@@ -0,0 +1,290 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from typing import Dict, List
9
+
10
+ import torch
11
+ import torch.distributed
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from training.trainer import CORE_LOSS_KEY
16
+
17
+ from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
18
+
19
+
20
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape [N, ...] holding logits
            (a sigmoid is applied here). The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled; then inputs
            and targets must be [N, M, H, W].
    Returns:
        Dice loss tensor: shape [N, M] (already divided by num_objects) when
        loss_on_multimask, otherwise a scalar summed over the batch and divided
        by num_objects.
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
    else:
        # Flatten all non-batch dims of BOTH tensors. Previously only inputs was
        # flattened, so targets that still had spatial dims (e.g. [N, H, W]) did
        # not line up with the flattened inputs. No-op for already-flat targets.
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the loss defined when both masks are empty
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects
50
+
51
+
52
def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape [N, ...] holding logits.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.25; pass a negative
                value to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        loss_on_multimask: True if multimask prediction is enabled; then inputs
            must be [N, M, H, W].
    Returns:
        focal loss tensor: shape [N, M] (spatial mean divided by num_objects)
        when loss_on_multimask, otherwise a scalar: the per-example mean over
        all non-batch elements, summed over the batch, divided by num_objects.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if loss_on_multimask:
        # loss is [N, M, H, W] where M corresponds to multiple predicted masks
        assert loss.dim() == 4
        return loss.flatten(2).mean(-1) / num_objects  # average over spatial dims
    # Average over every non-batch element; identical to loss.mean(1) for
    # [N, K] inputs, but also correct for inputs with more than two dims.
    return loss.flatten(1).mean(1).sum() / num_objects
91
+
92
+
93
def iou_loss(inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False):
    """
    Regression loss between predicted IoU scores and the IoUs actually achieved.

    Args:
        inputs: [N, M, H, W] mask logits; a pixel counts as predicted foreground when > 0.
        targets: [N, M, H, W] ground-truth masks; a pixel is foreground when > 0.
        pred_ious: A float tensor containing the predicted IoU score per mask.
        num_objects: Number of objects in the batch (normaliser).
        loss_on_multimask: True if multimask prediction is enabled; then the
            per-mask losses are returned instead of their sum.
        use_l1_loss: use L1 loss instead of MSE loss.
    Returns:
        IoU loss tensor, divided by num_objects.
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    predicted = inputs.flatten(2) > 0
    ground_truth = targets.flatten(2) > 0
    intersection = (predicted & ground_truth).sum(dim=-1).float()
    union = (predicted | ground_truth).sum(dim=-1).float()
    # clamp avoids 0/0 when both masks are empty
    actual_ious = intersection / union.clamp(min=1.0)

    criterion = F.l1_loss if use_l1_loss else F.mse_loss
    loss = criterion(pred_ious, actual_ious, reduction="none")
    return loss / num_objects if loss_on_multimask else loss.sum() / num_objects
122
+
123
+
124
class MultiStepMultiMasksAndIous(nn.Module):
    """Multi-step, multi-mask segmentation loss (focal + dice + IoU + object score).

    For every prediction step, focal and dice losses are back-propagated only
    through the predicted mask with the lowest weighted focal+dice loss. The IoU
    head is supervised on that mask (or on all masks if ``supervise_all_iou``),
    and an optional focal loss is applied to object-score logits.
    """

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses.
        Args:
            weight_dict: dict containing weights for focal, dice, iou losses
                (keys "loss_mask", "loss_dice", "loss_iou", optionally "loss_class",
                which defaults to 0.0). It is copied, so the caller's mapping is
                not modified.
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """

        super().__init__()
        # Copy so that defaulting "loss_class" below never mutates the mapping
        # the caller passed in.
        self.weight_dict = dict(weight_dict)
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """Accumulate the per-entry losses over the batch.

        Args:
            outs_batch: one output dict per entry of ``targets_batch`` (see ``_forward``).
            targets_batch: target masks, iterated along dim 0 in lockstep with
                ``outs_batch``; ``shape[1]`` is taken as the number of objects.
        Returns:
            dict of summed losses, including the ``CORE_LOSS_KEY`` total.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objects)
        # average over ranks and never normalise by fewer than one object
        num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()

        losses = defaultdict(int)
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v

        return losses

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the losses related to the masks: the focal loss and the dice loss,
        and also the MAE or MSE loss between predicted IoUs and actual IoUs.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger, corresponding to
        one or multiple predicted masks from a click.

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest focal+dice loss between predicted mask and ground-truth.
        If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.
        """

        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]

        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(src_masks_list, ious_list, object_score_logits_list):
            self._update_losses(losses, src_masks, target_masks, ious, num_objects, object_score_logits)
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(self, losses, src_masks, target_masks, ious, num_objects, object_score_logits):
        """Add one prediction step's focal/dice/IoU/class losses into ``losses`` in place."""
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(src_masks, target_masks, num_objects, loss_on_multimask=True)
        if not self.pred_obj_scores:
            # no object-score head: zero class loss and treat every object as present
            loss_class = torch.tensor(0.0, dtype=loss_multimask.dtype, device=loss_multimask.device)
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # an object is present if its target mask has any foreground pixel
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[..., None].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = loss_multimask * self.weight_dict["loss_mask"] + loss_multidice * self.weight_dict["loss_dice"]
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """Weighted sum of ``losses`` using ``self.weight_dict`` (zero weights are skipped).

        Raises:
            ValueError: if a key of ``weight_dict`` is missing from ``losses``.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight

        return reduced_loss
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.