dgenerate-ultralytics-headless 8.3.185__py3-none-any.whl → 8.3.187__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/METADATA +6 -8
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/RECORD +31 -30
- tests/test_python.py +2 -10
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/datasets/Argoverse.yaml +2 -2
- ultralytics/cfg/datasets/Objects365.yaml +3 -3
- ultralytics/cfg/datasets/SKU-110K.yaml +4 -4
- ultralytics/cfg/datasets/VOC.yaml +2 -4
- ultralytics/cfg/datasets/VisDrone.yaml +2 -2
- ultralytics/cfg/datasets/xView.yaml +2 -2
- ultralytics/data/build.py +2 -2
- ultralytics/data/utils.py +0 -2
- ultralytics/engine/exporter.py +4 -1
- ultralytics/engine/results.py +1 -4
- ultralytics/engine/trainer.py +3 -3
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/modules/sam.py +6 -6
- ultralytics/models/sam/predict.py +363 -6
- ultralytics/solutions/region_counter.py +3 -2
- ultralytics/utils/__init__.py +25 -162
- ultralytics/utils/autodevice.py +1 -1
- ultralytics/utils/benchmarks.py +9 -8
- ultralytics/utils/callbacks/wb.py +9 -3
- ultralytics/utils/downloads.py +29 -19
- ultralytics/utils/logger.py +10 -11
- ultralytics/utils/plotting.py +13 -20
- ultralytics/utils/tqdm.py +462 -0
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.185.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/top_level.txt +0 -0
```diff
@@ -9,7 +9,9 @@ segmentation tasks.
 """
 
 from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
```
```diff
@@ -283,7 +285,7 @@ class Predictor(BasePredictor):
             bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
             points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
             labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
-            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+            masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array with shape (H, W).
 
         Returns:
             bboxes (torch.Tensor | None): Transformed bounding boxes.
```
```diff
@@ -315,7 +317,11 @@ class Predictor(BasePredictor):
             bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
             bboxes *= r
         if masks is not None:
-            masks =
+            masks = np.asarray(masks, dtype=np.uint8)
+            masks = masks[None] if masks.ndim == 2 else masks
+            letterbox = LetterBox(dst_shape, auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
+            masks = np.stack([letterbox(image=x).squeeze() for x in masks], axis=0)
+            masks = torch.tensor(masks, dtype=self.torch_dtype, device=self.device)
         return bboxes, points, labels, masks
 
     def generate(
```
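The new branch converts mask prompts to uint8, adds a leading dimension for a single mask, letterboxes each mask to the model input size with nearest-neighbor interpolation (so values stay binary), and only then builds the tensor. Below is a standalone sketch of that letterboxing step on a toy mask, assuming the `LetterBox(..., padding_value=..., interpolation=...)` keywords used in the hunk are available in this release; the mask shape and target size are invented for illustration:

```python
import cv2
import numpy as np
from ultralytics.data.augment import LetterBox

# Toy binary mask for a 480x640 source image, resized/padded to a 1024x1024 model input.
mask = np.zeros((480, 640), dtype=np.uint8)
mask[100:300, 200:400] = 1  # rough region covering the object

letterbox = LetterBox((1024, 1024), auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
lb_mask = letterbox(image=mask).squeeze()
print(lb_mask.shape, lb_mask.max())  # (1024, 1024) 1 -> values stay binary thanks to nearest-neighbor
```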
```diff
@@ -514,7 +520,9 @@ class Predictor(BasePredictor):
                     pred_bboxes = batched_mask_to_box(masks)
                 # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
                 cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
-
+                idx = pred_scores > self.args.conf
+                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
+                masks = masks[idx]
             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
```
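The added lines keep only segment-everything outputs whose score clears `args.conf`, after packing score and placeholder class columns onto the boxes. A toy sketch of that boolean-mask filtering pattern on synthetic tensors (all shapes and values are invented for illustration):

```python
import torch

# Toy stand-ins for the tensors in the hunk above.
pred_bboxes = torch.rand(5, 4) * 100           # (N, 4) xyxy boxes
pred_scores = torch.tensor([0.9, 0.2, 0.6, 0.1, 0.8])
cls = torch.arange(5, dtype=torch.float32)     # placeholder class ids
masks = torch.rand(5, 64, 64) > 0.5            # (N, H, W) binary masks
conf = 0.25

idx = pred_scores > conf                       # keep only predictions above the confidence threshold
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
masks = masks[idx]
print(pred_bboxes.shape, masks.shape)          # torch.Size([3, 6]) torch.Size([3, 64, 64]) for this toy input
```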
```diff
@@ -815,9 +823,8 @@ class SAM2Predictor(Predictor):
         if self.model.directly_add_no_mem_embed:
             vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
         feats = [
-            feat.permute(1, 2, 0).view(1, -1, *feat_size)
-
-        ][::-1]
+            feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats, self._bb_feat_sizes)
+        ]
         return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
 
     def _inference_features(
```
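The rewritten comprehension pairs `vision_feats` with `self._bb_feat_sizes` directly instead of building the list backwards and reversing it; each `(HW, B, C)` token map is reshaped to `(B, C, H, W)`, and the last entry serves as the image embedding. A toy sketch of that reshape with made-up feature sizes and channel count:

```python
import torch

# Made-up backbone feature sizes, highest resolution first (mirroring SAM2's _bb_feat_sizes ordering).
feat_sizes = [(64, 64), (32, 32), (16, 16)]
vision_feats = [torch.rand(h * w, 1, 32) for h, w in feat_sizes]  # (HW, B, C) with B=1, C=32

feats = [
    feat.permute(1, 2, 0).view(1, -1, *feat_size)  # (HW, B, C) -> (B, C, H, W)
    for feat, feat_size in zip(vision_feats, feat_sizes)
]
# feats[-1] is the lowest-resolution map (the image embedding); the earlier entries are the high-res features.
print([tuple(f.shape) for f in feats])  # [(1, 32, 64, 64), (1, 32, 32, 32), (1, 32, 16, 16)]
```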
```diff
@@ -1678,3 +1685,353 @@ class SAM2VideoPredictor(SAM2Predictor):
             self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
             for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
                 obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+class SAM2DynamicInteractivePredictor(SAM2Predictor):
+    """
+    SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
+    sequence of images.
+
+    Attributes:
+        memory_bank (list): OrderedDict: Stores the states of each image with prompts.
+        obj_idx_set (set): A set to keep track of the object indices that have been added.
+        obj_id_to_idx (OrderedDict): Maps object IDs to their corresponding indices.
+        obj_idx_to_id (OrderedDict): Maps object indices to their corresponding IDs.
+
+    Methods:
+        get_model: Retrieves and configures the model with binarization enabled.
+        inference: Performs inference on a single image with optional prompts and object IDs.
+        postprocess: Post-processes the predictions to apply non-overlapping constraints if required.
+        update_memory: Append the imgState to the memory_bank and update the memory for the model.
+        track_step: Tracking step for the current image state to predict masks.
+        get_maskmem_enc: Get memory and positional encoding from the memory bank.
+
+    Examples:
+        >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+        >>> predictor(source=support_img1, bboxes=bboxes1, obj_ids=labels1, update_memory=True)
+        >>> results1 = predictor(source=query_img1)
+        >>> predictor(source=support_img2, bboxes=bboxes2, obj_ids=labels2, update_memory=True)
+        >>> results2 = predictor(source=query_img2)
+    """
+
+    def __init__(
+        self,
+        cfg: Any = DEFAULT_CFG,
+        overrides: Optional[Dict[str, Any]] = None,
+        max_obj_num: int = 3,
+        _callbacks: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initialize the predictor with configuration and optional overrides.
+
+        This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
+        specified overrides
+
+        Args:
+            cfg (Dict[str, Any]): Configuration dictionary containing default settings.
+            overrides (Dict[str, Any] | None): Dictionary of values to override default configuration.
+            max_obj_num (int): Maximum number of objects to track. Default is 3. this is set to keep fix feature size for the model.
+            _callbacks (Dict[str, Any] | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+            >>> predictor_example_with_imgsz = SAM2DynamicInteractivePredictor(overrides={"imgsz": 640})
+            >>> predictor_example_with_callback = SAM2DynamicInteractivePredictor(
+            ...     _callbacks={"on_predict_start": custom_callback}
+            ... )
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.non_overlap_masks = True
+
+        # Initialize the memory bank to store image states
+        # NOTE: probably need to use dict for better query
+        self.memory_bank = []
+
+        # Initialize the object index set and mappings
+        self.obj_idx_set = set()
+        self.obj_id_to_idx = OrderedDict()
+        self.obj_idx_to_id = OrderedDict()
+        self._max_obj_num = max_obj_num
+        for i in range(self._max_obj_num):
+            self.obj_id_to_idx[i + 1] = i
+            self.obj_idx_to_id[i] = i + 1
+
+    @smart_inference_mode()
+    def inference(
+        self,
+        img: Union[torch.Tensor, np.ndarray],
+        bboxes: Optional[List[List[float]]] = None,
+        masks: Optional[Union[torch.Tensor, np.ndarray]] = None,
+        points: Optional[List[List[float]]] = None,
+        labels: Optional[List[int]] = None,
+        obj_ids: Optional[List[int]] = None,
+        update_memory: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Perform inference on a single image with optional bounding boxes, masks, points and object IDs.
+        It has two modes: one is to run inference on a single image without updating the memory,
+        and the other is to update the memory with the provided prompts and object IDs.
+        When update_memory is True, it will update the memory with the provided prompts and obj_ids.
+        When update_memory is False, it will only run inference on the provided image without updating the memory.
+
+        Args:
+            img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+            bboxes (List[List[float]] | None): Optional list of bounding boxes to update the memory.
+            masks (List[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
+            points (List[List[float]] | None): Optional list of points to update the memory, each point is [x, y].
+            labels (List[int] | None): Optional list of object IDs corresponding to the points (>0 for positive, 0 for negative).
+            obj_ids (List[int] | None): Optional list of object IDs corresponding to the prompts.
+            update_memory (bool): Flag to indicate whether to update the memory with new objects.
+
+        Returns:
+            res_masks (torch.Tensor): The output masks in shape (C, H, W)
+            object_score_logits (torch.Tensor): Quality scores for each mask
+        """
+        self.get_im_features(img)
+        points, labels, masks = self._prepare_prompts(
+            dst_shape=self.imgsz,
+            src_shape=self.batch[1][0].shape[:2],
+            points=points,
+            bboxes=bboxes,
+            labels=labels,
+            masks=masks,
+        )
+
+        if update_memory:
+            if isinstance(obj_ids, int):
+                obj_ids = [obj_ids]
+            assert obj_ids is not None, "obj_ids must be provided when update_memory is True"
+            assert masks is not None or points is not None, (
+                "bboxes, masks, or points must be provided when update_memory is True"
+            )
+            if points is None:  # placeholder
+                points = torch.zeros((len(obj_ids), 0, 2), dtype=self.torch_dtype, device=self.device)
+                labels = torch.zeros((len(obj_ids), 0), dtype=torch.int32, device=self.device)
+            if masks is not None:
+                assert len(masks) == len(obj_ids), "masks and obj_ids must have the same length."
+            assert len(points) == len(obj_ids), "points and obj_ids must have the same length."
+            self.update_memory(obj_ids, points, labels, masks)
+
+        current_out = self.track_step()
+        pred_masks, pred_scores = current_out["pred_masks"], current_out["object_score_logits"]
+        # filter the masks and logits based on the object indices
+        if len(self.obj_idx_set) == 0:
+            raise RuntimeError("No objects have been added to the state. Please add objects before inference.")
+        idx = list(self.obj_idx_set)  # cls id
+        pred_masks, pred_scores = pred_masks[idx], pred_scores[idx]
+        # the original score are in [-32,32], and a object score larger than 0 means the object is present, we map it to [-1,1] range,
+        # and use a activate function to make sure the object score logits are non-negative, so that we can use it as a mask
+        pred_scores = torch.clamp_(pred_scores / 32, min=0)
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def get_im_features(self, img: Union[torch.Tensor, np.ndarray]) -> None:
+        """
+        Initialize the image state by processing the input image and extracting features.
+
+        Args:
+            img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+        """
+        vis_feats, vis_pos_embed, feat_sizes = SAM2VideoPredictor.get_im_features(self, img, batch=self._max_obj_num)
+        self.high_res_features = [
+            feat.permute(1, 2, 0).view(*feat.shape[1:], *feat_size)
+            for feat, feat_size in zip(vis_feats[:-1], feat_sizes[:-1])
+        ]
+
+        self.vision_feats = vis_feats
+        self.vision_pos_embeds = vis_pos_embed
+        self.feat_sizes = feat_sizes
+
+    @smart_inference_mode()
+    def update_memory(
+        self,
+        obj_ids: List[int] = None,
+        points: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        masks: Optional[torch.Tensor] = None,
+    ) -> None:
+        """
+        Append the imgState to the memory_bank and update the memory for the model.
+
+        Args:
+            obj_ids (List[int]): List of object IDs corresponding to the prompts.
+            points (torch.Tensor | None): Tensor of shape (B, N, 2) representing the input points for N objects.
+            labels (torch.Tensor | None): Tensor of shape (B, N) representing the labels for the input points.
+            masks (torch.Tensor | None): Optional tensor of shape (N, H, W) representing the input masks for N objects.
+        """
+        consolidated_out = {
+            "maskmem_features": None,
+            "maskmem_pos_enc": None,
+            "pred_masks": torch.full(
+                size=(self._max_obj_num, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+                fill_value=-1024.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+            "obj_ptr": torch.full(
+                size=(self._max_obj_num, self.model.hidden_dim),
+                fill_value=-1024.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+            "object_score_logits": torch.full(
+                size=(self._max_obj_num, 1),
+                # default to 10.0 for object_score_logits, i.e. assuming the object is
+                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+                fill_value=-32,  # 10.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+        }
+
+        for i, obj_id in enumerate(obj_ids):
+            assert obj_id < self._max_obj_num
+            obj_idx = self._obj_id_to_idx(int(obj_id))
+            self.obj_idx_set.add(obj_idx)
+            point, label = points[[i]], labels[[i]]
+            mask = masks[[i]][None] if masks is not None else None
+            # Currently, only bbox prompt or mask prompt is supported, so we assert that bbox is not None.
+            assert point is not None or mask is not None, "Either bbox, points or mask is required"
+            out = self.track_step(obj_idx, point, label, mask)
+            if out is not None:
+                obj_mask = out["pred_masks"]
+                assert obj_mask.shape[-2:] == consolidated_out["pred_masks"].shape[-2:], (
+                    f"Expected mask shape {consolidated_out['pred_masks'].shape[-2:]} but got {obj_mask.shape[-2:]} for object {obj_idx}."
+                )
+                consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = obj_mask
+                consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+                if "object_score_logits" in out.keys():
+                    consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out["object_score_logits"]
+
+        high_res_masks = F.interpolate(
+            consolidated_out["pred_masks"].to(self.device, non_blocking=True),
+            size=self.imgsz,
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        if self.model.non_overlap_masks_for_mem_enc:
+            high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+        maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+            current_vision_feats=self.vision_feats,
+            feat_sizes=self.feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            object_score_logits=consolidated_out["object_score_logits"],
+            is_mask_from_pts=True,
+        )
+        consolidated_out["maskmem_features"] = maskmem_features
+        consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+        self.memory_bank.append(consolidated_out)
+
+    def _prepare_memory_conditioned_features(self, obj_idx: Optional[int]) -> torch.Tensor:
+        """
+        Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
+        prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features for all
+        objects in the image. If there is no memory, it will directly add a no-memory embedding to the current vision
+        features. If there is memory, it will use the memory features from previous frames to condition the current
+        vision features using a transformer attention mechanism.
+
+        Args:
+            obj_idx (int | None): The index of the object for which to prepare the features.
+
+        Returns:
+            pix_feat_with_mem (torch.Tensor): The memory-conditioned pixel features.
+        """
+        if len(self.memory_bank) == 0 or isinstance(obj_idx, int):
+            # for initial conditioning frames with, encode them without using any previous memory
+            # directly add no-mem embedding (instead of using the transformer encoder)
+            pix_feat_with_mem = self.vision_feats[-1] + self.model.no_mem_embed
+        else:
+            # for inference frames, use the memory features from previous frames
+            memory, memory_pos_embed = self.get_maskmem_enc()
+            pix_feat_with_mem = self.model.memory_attention(
+                curr=self.vision_feats[-1:],
+                curr_pos=self.vision_pos_embeds[-1:],
+                memory=memory,
+                memory_pos=memory_pos_embed,
+                num_obj_ptr_tokens=0,  # num_obj_ptr_tokens
+            )
+        # reshape the output (HW)BC => BCHW
+        return pix_feat_with_mem.permute(1, 2, 0).view(
+            self._max_obj_num,
+            self.model.memory_attention.d_model,
+            *self.feat_sizes[-1],
+        )
+
+    def get_maskmem_enc(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get the memory and positional encoding from the memory, which is used to condition the current image
+        features.
+        """
+        to_cat_memory, to_cat_memory_pos_embed = [], []
+        for consolidated_out in self.memory_bank:
+            to_cat_memory.append(consolidated_out["maskmem_features"].flatten(2).permute(2, 0, 1))  # (H*W, B, C)
+            maskmem_enc = consolidated_out["maskmem_pos_enc"][-1].flatten(2).permute(2, 0, 1)
+            maskmem_enc = maskmem_enc + self.model.maskmem_tpos_enc[self.model.num_maskmem - 1]
+            to_cat_memory_pos_embed.append(maskmem_enc)
+
+        memory = torch.cat(to_cat_memory, dim=0)
+        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+        return memory, memory_pos_embed
+
+    def _obj_id_to_idx(self, obj_id: int) -> Optional[int]:
+        """
+        Map client-side object id to model-side object index.
+
+        Args:
+            obj_id (int): The client-side object ID.
+
+        Returns:
+            (int): The model-side object index, or None if not found.
+        """
+        return self.obj_id_to_idx.get(obj_id, None)
+
+    def track_step(
+        self,
+        obj_idx: Optional[int] = None,
+        point: Optional[torch.Tensor] = None,
+        label: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
+        """
+        Tracking step for the current image state to predict masks.
+
+        This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
+        processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
+        features for all objects in the image. The method supports both mask-based output without SAM and full
+        SAM processing with memory-conditioned features.
+
+        Args:
+            obj_idx (int | None): The index of the object for which to predict masks. If None, it processes all objects.
+            point (torch.Tensor | None): The coordinates of the points of interest with shape (N, 2).
+            label (torch.Tensor | None): The labels corresponding to the points where 1 means positive clicks, 0 means negative clicks.
+            mask (torch.Tensor | None): The mask input for the object with shape (H, W).
+
+        Returns:
+            current_out (Dict[str, Any]): A dictionary containing the current output with mask predictions and object pointers.
+                Keys include 'point_inputs', 'mask_inputs', 'pred_masks', 'pred_masks_high_res', 'obj_ptr', 'object_score_logits'.
+        """
+        current_out = {}
+        if mask is not None and self.model.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = self.vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.model.memory_attention.d_model, *self.feat_sizes[-1])
+            _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._use_mask_as_output(mask)
+        else:
+            # fused the visual feature with previous memory features in the memory bank
+            pix_feat_with_mem = self._prepare_memory_conditioned_features(obj_idx)
+            # calculate the first feature if adding obj_idx exists(means adding prompts)
+            pix_feat_with_mem = pix_feat_with_mem[0:1] if obj_idx is not None else pix_feat_with_mem
+            _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._forward_sam_heads(
+                backbone_features=pix_feat_with_mem,
+                point_inputs={"point_coords": point, "point_labels": label} if obj_idx is not None else None,
+                mask_inputs=mask,
+                multimask_output=False,
+                high_res_features=[feat[: pix_feat_with_mem.size(0)] for feat in self.high_res_features],
+            )
+        current_out["pred_masks"] = low_res_masks
+        current_out["pred_masks_high_res"] = high_res_masks
+        current_out["obj_ptr"] = obj_ptr
+        current_out["object_score_logits"] = object_score_logits
+
+        return current_out
```
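The class docstring's Examples block sketches the intended workflow: prompt one or more support images and push them into the memory bank with `update_memory=True`, then run prompt-free calls on query images that reuse that memory. A slightly expanded sketch along the same lines; the import path assumes the updated `ultralytics/models/sam/__init__.py` (listed above with +8 -2) re-exports the new class, the image arrays and box coordinates are placeholders, and model/checkpoint setup is omitted exactly as in the upstream docstring example:

```python
import numpy as np
from ultralytics.models.sam import SAM2DynamicInteractivePredictor  # assumed re-export

# Placeholder support/query images and a single xyxy box prompt for object id 1.
support_img = np.zeros((720, 1280, 3), dtype=np.uint8)
query_img = np.zeros((720, 1280, 3), dtype=np.uint8)
bboxes = [[100, 100, 300, 400]]
obj_ids = [1]

predictor = SAM2DynamicInteractivePredictor(overrides={"imgsz": 1024}, max_obj_num=3)

# Register the prompted object in the memory bank, then segment it in a new image without prompts.
predictor(source=support_img, bboxes=bboxes, obj_ids=obj_ids, update_memory=True)
results = predictor(source=query_img)
```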
```diff
@@ -115,8 +115,9 @@ class RegionCounter(BaseSolution):
 
         # Display region counts
         for region in self.counting_regions:
-
-            pts =
+            poly = region["polygon"]
+            pts = list(map(tuple, np.array(poly.exterior.coords, dtype=np.int32)))
+            (x1, y1), (x2, y2) = [(int(poly.centroid.x), int(poly.centroid.y))] * 2
             annotator.draw_region(pts, region["region_color"], self.line_width * 2)
             annotator.adaptive_label(
                 [x1, y1, x2, y2],
```
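The replacement derives both the drawable outline and the label anchor from the same shapely polygon: `exterior.coords` supplies the outline points for `draw_region`, and the centroid is duplicated into a degenerate `[x1, y1, x2, y2]` box so `adaptive_label` centers the count text on it. A standalone sketch of that geometry step using shapely directly (the region coordinates are arbitrary):

```python
import numpy as np
from shapely.geometry import Polygon

# Arbitrary rectangular counting region.
poly = Polygon([(50, 80), (250, 80), (250, 200), (50, 200)])

pts = list(map(tuple, np.array(poly.exterior.coords, dtype=np.int32)))
(x1, y1), (x2, y2) = [(int(poly.centroid.x), int(poly.centroid.y))] * 2

print(len(pts))        # 5 -> the exterior ring is closed, so the first point repeats at the end
print(x1, y1, x2, y2)  # 150 140 150 140 -> a zero-area box centered on the polygon centroid
```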