ultralytics 8.2.67__py3-none-any.whl → 8.2.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- tests/test_cli.py +4 -16
- ultralytics/__init__.py +1 -1
- ultralytics/data/augment.py +1 -1
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/models/fastsam/__init__.py +1 -2
- ultralytics/models/fastsam/model.py +18 -0
- ultralytics/models/fastsam/predict.py +115 -1
- ultralytics/utils/ops.py +1 -1
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/METADATA +1 -1
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/RECORD +14 -14
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/WHEEL +1 -1
- ultralytics/models/fastsam/prompt.py +0 -352
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.67.dist-info → ultralytics-8.2.69.dist-info}/top_level.txt +0 -0
tests/test_cli.py
CHANGED
|
@@ -68,7 +68,6 @@ def test_fastsam(task="segment", model=WEIGHTS_DIR / "FastSAM-s.pt", data="coco8
|
|
|
68
68
|
run(f"yolo segment predict model={model} source={source} imgsz=32 save save_crop save_txt")
|
|
69
69
|
|
|
70
70
|
from ultralytics import FastSAM
|
|
71
|
-
from ultralytics.models.fastsam import FastSAMPrompt
|
|
72
71
|
from ultralytics.models.sam import Predictor
|
|
73
72
|
|
|
74
73
|
# Create a FastSAM model
|
|
@@ -81,21 +80,10 @@ def test_fastsam(task="segment", model=WEIGHTS_DIR / "FastSAM-s.pt", data="coco8
|
|
|
81
80
|
# Remove small regions
|
|
82
81
|
new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)
|
|
83
82
|
|
|
84
|
-
#
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
|
|
89
|
-
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
|
|
90
|
-
|
|
91
|
-
# Text prompt
|
|
92
|
-
ann = prompt_process.text_prompt(text="a photo of a dog")
|
|
93
|
-
|
|
94
|
-
# Point prompt
|
|
95
|
-
# Points default [[0,0]] [[x1,y1],[x2,y2]]
|
|
96
|
-
# Point_label default [0] [1,0] 0:background, 1:foreground
|
|
97
|
-
ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
|
|
98
|
-
prompt_process.plot(annotations=ann, output="./")
|
|
83
|
+
# Run inference with bboxes and points and texts prompt at the same time
|
|
84
|
+
results = sam_model(
|
|
85
|
+
source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog"
|
|
86
|
+
)
|
|
99
87
|
|
|
100
88
|
|
|
101
89
|
def test_mobilesam():
|
ultralytics/__init__.py
CHANGED
ultralytics/data/augment.py
CHANGED
|
@@ -2221,7 +2221,7 @@ class RandomLoadText:
|
|
|
2221
2221
|
pos_labels = np.unique(cls).tolist()
|
|
2222
2222
|
|
|
2223
2223
|
if len(pos_labels) > self.max_samples:
|
|
2224
|
-
pos_labels =
|
|
2224
|
+
pos_labels = random.sample(pos_labels, k=self.max_samples)
|
|
2225
2225
|
|
|
2226
2226
|
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
|
|
2227
2227
|
neg_labels = [i for i in range(num_classes) if i not in pos_labels]
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import statistics
|
|
5
|
+
import time
|
|
6
|
+
from typing import List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GCPRegions:
|
|
12
|
+
"""
|
|
13
|
+
A class for managing and analyzing Google Cloud Platform (GCP) regions.
|
|
14
|
+
|
|
15
|
+
This class provides functionality to initialize, categorize, and analyze GCP regions based on their
|
|
16
|
+
geographical location, tier classification, and network latency.
|
|
17
|
+
|
|
18
|
+
Attributes:
|
|
19
|
+
regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
|
|
20
|
+
|
|
21
|
+
Methods:
|
|
22
|
+
tier1: Returns a list of tier 1 GCP regions.
|
|
23
|
+
tier2: Returns a list of tier 2 GCP regions.
|
|
24
|
+
lowest_latency: Determines the GCP region(s) with the lowest network latency.
|
|
25
|
+
|
|
26
|
+
Examples:
|
|
27
|
+
>>> from ultralytics.hub.google import GCPRegions
|
|
28
|
+
>>> regions = GCPRegions()
|
|
29
|
+
>>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)
|
|
30
|
+
>>> print(f"Lowest latency region: {lowest_latency_region[0][0]}")
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(self):
|
|
34
|
+
"""Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""
|
|
35
|
+
self.regions = {
|
|
36
|
+
"asia-east1": (1, "Taiwan", "China"),
|
|
37
|
+
"asia-east2": (2, "Hong Kong", "China"),
|
|
38
|
+
"asia-northeast1": (1, "Tokyo", "Japan"),
|
|
39
|
+
"asia-northeast2": (1, "Osaka", "Japan"),
|
|
40
|
+
"asia-northeast3": (2, "Seoul", "South Korea"),
|
|
41
|
+
"asia-south1": (2, "Mumbai", "India"),
|
|
42
|
+
"asia-south2": (2, "Delhi", "India"),
|
|
43
|
+
"asia-southeast1": (2, "Jurong West", "Singapore"),
|
|
44
|
+
"asia-southeast2": (2, "Jakarta", "Indonesia"),
|
|
45
|
+
"australia-southeast1": (2, "Sydney", "Australia"),
|
|
46
|
+
"australia-southeast2": (2, "Melbourne", "Australia"),
|
|
47
|
+
"europe-central2": (2, "Warsaw", "Poland"),
|
|
48
|
+
"europe-north1": (1, "Hamina", "Finland"),
|
|
49
|
+
"europe-southwest1": (1, "Madrid", "Spain"),
|
|
50
|
+
"europe-west1": (1, "St. Ghislain", "Belgium"),
|
|
51
|
+
"europe-west10": (2, "Berlin", "Germany"),
|
|
52
|
+
"europe-west12": (2, "Turin", "Italy"),
|
|
53
|
+
"europe-west2": (2, "London", "United Kingdom"),
|
|
54
|
+
"europe-west3": (2, "Frankfurt", "Germany"),
|
|
55
|
+
"europe-west4": (1, "Eemshaven", "Netherlands"),
|
|
56
|
+
"europe-west6": (2, "Zurich", "Switzerland"),
|
|
57
|
+
"europe-west8": (1, "Milan", "Italy"),
|
|
58
|
+
"europe-west9": (1, "Paris", "France"),
|
|
59
|
+
"me-central1": (2, "Doha", "Qatar"),
|
|
60
|
+
"me-west1": (1, "Tel Aviv", "Israel"),
|
|
61
|
+
"northamerica-northeast1": (2, "Montreal", "Canada"),
|
|
62
|
+
"northamerica-northeast2": (2, "Toronto", "Canada"),
|
|
63
|
+
"southamerica-east1": (2, "São Paulo", "Brazil"),
|
|
64
|
+
"southamerica-west1": (2, "Santiago", "Chile"),
|
|
65
|
+
"us-central1": (1, "Iowa", "United States"),
|
|
66
|
+
"us-east1": (1, "South Carolina", "United States"),
|
|
67
|
+
"us-east4": (1, "Northern Virginia", "United States"),
|
|
68
|
+
"us-east5": (1, "Columbus", "United States"),
|
|
69
|
+
"us-south1": (1, "Dallas", "United States"),
|
|
70
|
+
"us-west1": (1, "Oregon", "United States"),
|
|
71
|
+
"us-west2": (2, "Los Angeles", "United States"),
|
|
72
|
+
"us-west3": (2, "Salt Lake City", "United States"),
|
|
73
|
+
"us-west4": (2, "Las Vegas", "United States"),
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
def tier1(self) -> List[str]:
|
|
77
|
+
"""Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
|
|
78
|
+
return [region for region, info in self.regions.items() if info[0] == 1]
|
|
79
|
+
|
|
80
|
+
def tier2(self) -> List[str]:
|
|
81
|
+
"""Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
|
|
82
|
+
return [region for region, info in self.regions.items() if info[0] == 2]
|
|
83
|
+
|
|
84
|
+
@staticmethod
|
|
85
|
+
def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:
|
|
86
|
+
"""Pings a specified GCP region and returns latency statistics: mean, min, max, and standard deviation."""
|
|
87
|
+
url = f"https://{region}-docker.pkg.dev"
|
|
88
|
+
latencies = []
|
|
89
|
+
for _ in range(attempts):
|
|
90
|
+
try:
|
|
91
|
+
start_time = time.time()
|
|
92
|
+
_ = requests.head(url, timeout=5)
|
|
93
|
+
latency = (time.time() - start_time) * 1000 # convert latency to milliseconds
|
|
94
|
+
if latency != float("inf"):
|
|
95
|
+
latencies.append(latency)
|
|
96
|
+
except requests.RequestException:
|
|
97
|
+
pass
|
|
98
|
+
if not latencies:
|
|
99
|
+
return region, float("inf"), float("inf"), float("inf"), float("inf")
|
|
100
|
+
|
|
101
|
+
std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0
|
|
102
|
+
return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)
|
|
103
|
+
|
|
104
|
+
def lowest_latency(
|
|
105
|
+
self,
|
|
106
|
+
top: int = 1,
|
|
107
|
+
verbose: bool = False,
|
|
108
|
+
tier: Optional[int] = None,
|
|
109
|
+
attempts: int = 1,
|
|
110
|
+
) -> List[Tuple[str, float, float, float, float]]:
|
|
111
|
+
"""
|
|
112
|
+
Determines the GCP regions with the lowest latency based on ping tests.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
top (int): Number of top regions to return.
|
|
116
|
+
verbose (bool): If True, prints detailed latency information for all tested regions.
|
|
117
|
+
tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
|
|
118
|
+
attempts (int): Number of ping attempts per region.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
(List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
|
|
122
|
+
latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
|
|
123
|
+
|
|
124
|
+
Examples:
|
|
125
|
+
>>> regions = GCPRegions()
|
|
126
|
+
>>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
|
|
127
|
+
>>> print(results[0][0]) # Print the name of the lowest latency region
|
|
128
|
+
"""
|
|
129
|
+
if verbose:
|
|
130
|
+
print(f"Testing GCP regions for latency (with {attempts} {'retry' if attempts == 1 else 'attempts'})...")
|
|
131
|
+
|
|
132
|
+
regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())
|
|
133
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
|
|
134
|
+
results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))
|
|
135
|
+
|
|
136
|
+
sorted_results = sorted(results, key=lambda x: x[1])
|
|
137
|
+
|
|
138
|
+
if verbose:
|
|
139
|
+
print(f"{'Region':<25} {'Location':<35} {'Tier':<5} {'Latency (ms)'}")
|
|
140
|
+
for region, mean, std, min_, max_ in sorted_results:
|
|
141
|
+
tier, city, country = self.regions[region]
|
|
142
|
+
location = f"{city}, {country}"
|
|
143
|
+
if mean == float("inf"):
|
|
144
|
+
print(f"{region:<25} {location:<35} {tier:<5} {'Timeout'}")
|
|
145
|
+
else:
|
|
146
|
+
print(f"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
|
|
147
|
+
print(f"\nLowest latency region{'s' if top > 1 else ''}:")
|
|
148
|
+
for region, mean, std, min_, max_ in sorted_results[:top]:
|
|
149
|
+
tier, city, country = self.regions[region]
|
|
150
|
+
location = f"{city}, {country}"
|
|
151
|
+
print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))")
|
|
152
|
+
|
|
153
|
+
return sorted_results[:top]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Usage example
|
|
157
|
+
if __name__ == "__main__":
|
|
158
|
+
regions = GCPRegions()
|
|
159
|
+
top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3)
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from .model import FastSAM
|
|
4
4
|
from .predict import FastSAMPredictor
|
|
5
|
-
from .prompt import FastSAMPrompt
|
|
6
5
|
from .val import FastSAMValidator
|
|
7
6
|
|
|
8
|
-
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"
|
|
7
|
+
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
|
|
@@ -28,6 +28,24 @@ class FastSAM(Model):
|
|
|
28
28
|
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
|
|
29
29
|
super().__init__(model=model, task="segment")
|
|
30
30
|
|
|
31
|
+
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
|
|
32
|
+
"""
|
|
33
|
+
Performs segmentation prediction on the given image or video source.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
|
37
|
+
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
|
38
|
+
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
|
39
|
+
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
|
40
|
+
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
|
41
|
+
texts (list, optional): List of texts for prompted segmentation. Defaults to None.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
(list): The model predictions.
|
|
45
|
+
"""
|
|
46
|
+
prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
|
|
47
|
+
return super().predict(source, stream, prompts=prompts, **kwargs)
|
|
48
|
+
|
|
31
49
|
@property
|
|
32
50
|
def task_map(self):
|
|
33
51
|
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
import torch
|
|
3
|
+
from PIL import Image
|
|
3
4
|
|
|
4
5
|
from ultralytics.models.yolo.segment import SegmentationPredictor
|
|
6
|
+
from ultralytics.utils import DEFAULT_CFG, checks
|
|
5
7
|
from ultralytics.utils.metrics import box_iou
|
|
8
|
+
from ultralytics.utils.ops import scale_masks
|
|
6
9
|
|
|
7
10
|
from .utils import adjust_bboxes_to_image_border
|
|
8
11
|
|
|
@@ -17,8 +20,16 @@ class FastSAMPredictor(SegmentationPredictor):
|
|
|
17
20
|
class segmentation.
|
|
18
21
|
"""
|
|
19
22
|
|
|
23
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
24
|
+
super().__init__(cfg, overrides, _callbacks)
|
|
25
|
+
self.prompts = {}
|
|
26
|
+
|
|
20
27
|
def postprocess(self, preds, img, orig_imgs):
|
|
21
28
|
"""Applies box postprocess for FastSAM predictions."""
|
|
29
|
+
bboxes = self.prompts.pop("bboxes", None)
|
|
30
|
+
points = self.prompts.pop("points", None)
|
|
31
|
+
labels = self.prompts.pop("labels", None)
|
|
32
|
+
texts = self.prompts.pop("texts", None)
|
|
22
33
|
results = super().postprocess(preds, img, orig_imgs)
|
|
23
34
|
for result in results:
|
|
24
35
|
full_box = torch.tensor(
|
|
@@ -28,4 +39,107 @@ class FastSAMPredictor(SegmentationPredictor):
|
|
|
28
39
|
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
|
|
29
40
|
if idx.numel() != 0:
|
|
30
41
|
result.boxes.xyxy[idx] = full_box
|
|
31
|
-
|
|
42
|
+
|
|
43
|
+
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
|
|
44
|
+
|
|
45
|
+
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
|
|
46
|
+
"""
|
|
47
|
+
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
|
|
48
|
+
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
|
|
52
|
+
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
|
53
|
+
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
|
54
|
+
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
|
55
|
+
texts (str | List[str], optional): Textual prompts, a list contains string objects.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
(List[Results]): The output results determined by prompts.
|
|
59
|
+
"""
|
|
60
|
+
if bboxes is None and points is None and texts is None:
|
|
61
|
+
return results
|
|
62
|
+
prompt_results = []
|
|
63
|
+
if not isinstance(results, list):
|
|
64
|
+
results = [results]
|
|
65
|
+
for result in results:
|
|
66
|
+
masks = result.masks.data
|
|
67
|
+
if masks.shape[1:] != result.orig_shape:
|
|
68
|
+
masks = scale_masks(masks[None], result.orig_shape)[0]
|
|
69
|
+
# bboxes prompt
|
|
70
|
+
idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
|
|
71
|
+
if bboxes is not None:
|
|
72
|
+
bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
|
|
73
|
+
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
|
74
|
+
bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
|
75
|
+
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
|
|
76
|
+
full_mask_areas = torch.sum(masks, dim=(1, 2))
|
|
77
|
+
|
|
78
|
+
union = bbox_areas[:, None] + full_mask_areas - mask_areas
|
|
79
|
+
idx[torch.argmax(mask_areas / union, dim=1)] = True
|
|
80
|
+
if points is not None:
|
|
81
|
+
points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
|
|
82
|
+
points = points[None] if points.ndim == 1 else points
|
|
83
|
+
if labels is None:
|
|
84
|
+
labels = torch.ones(points.shape[0])
|
|
85
|
+
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
|
86
|
+
assert len(labels) == len(
|
|
87
|
+
points
|
|
88
|
+
), f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
|
|
89
|
+
point_idx = (
|
|
90
|
+
torch.ones(len(result), dtype=torch.bool, device=self.device)
|
|
91
|
+
if labels.sum() == 0 # all negative points
|
|
92
|
+
else torch.zeros(len(result), dtype=torch.bool, device=self.device)
|
|
93
|
+
)
|
|
94
|
+
for p, l in zip(points, labels):
|
|
95
|
+
point_idx[torch.nonzero(masks[:, p[1], p[0]], as_tuple=True)[0]] = True if l else False
|
|
96
|
+
idx |= point_idx
|
|
97
|
+
if texts is not None:
|
|
98
|
+
if isinstance(texts, str):
|
|
99
|
+
texts = [texts]
|
|
100
|
+
crop_ims, filter_idx = [], []
|
|
101
|
+
for i, b in enumerate(result.boxes.xyxy.tolist()):
|
|
102
|
+
x1, y1, x2, y2 = [int(x) for x in b]
|
|
103
|
+
if masks[i].sum() <= 100:
|
|
104
|
+
filter_idx.append(i)
|
|
105
|
+
continue
|
|
106
|
+
crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
|
|
107
|
+
similarity = self._clip_inference(crop_ims, texts)
|
|
108
|
+
text_idx = torch.argmax(similarity, dim=-1) # (M, )
|
|
109
|
+
if len(filter_idx):
|
|
110
|
+
text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
|
|
111
|
+
idx[text_idx] = True
|
|
112
|
+
|
|
113
|
+
prompt_results.append(result[idx])
|
|
114
|
+
|
|
115
|
+
return prompt_results
|
|
116
|
+
|
|
117
|
+
def _clip_inference(self, images, texts):
|
|
118
|
+
"""
|
|
119
|
+
CLIP Inference process.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order.
|
|
123
|
+
texts (List[str]): A list of prompt texts and each of them should be string object.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
(torch.Tensor): The similarity between given images and texts.
|
|
127
|
+
"""
|
|
128
|
+
try:
|
|
129
|
+
import clip
|
|
130
|
+
except ImportError:
|
|
131
|
+
checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
|
132
|
+
import clip
|
|
133
|
+
if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
|
|
134
|
+
self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
|
|
135
|
+
images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
|
|
136
|
+
tokenized_text = clip.tokenize(texts).to(self.device)
|
|
137
|
+
image_features = self.clip_model.encode_image(images)
|
|
138
|
+
text_features = self.clip_model.encode_text(tokenized_text)
|
|
139
|
+
image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)
|
|
140
|
+
text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)
|
|
141
|
+
return (image_features * text_features[:, None]).sum(-1) # (M, N)
|
|
142
|
+
|
|
143
|
+
def set_prompts(self, prompts):
|
|
144
|
+
"""Set prompts in advance."""
|
|
145
|
+
self.prompts = prompts
|
ultralytics/utils/ops.py
CHANGED
|
@@ -363,7 +363,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|
|
363
363
|
ratio_pad (tuple): the ratio of the padding to the original image.
|
|
364
364
|
|
|
365
365
|
Returns:
|
|
366
|
-
masks (
|
|
366
|
+
masks (np.ndarray): The masks that are being returned with shape [h, w, num].
|
|
367
367
|
"""
|
|
368
368
|
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
|
369
369
|
im1_shape = masks.shape
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ultralytics
|
|
3
|
-
Version: 8.2.67
|
|
3
|
+
Version: 8.2.69
|
|
4
4
|
Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
|
|
5
5
|
Author: Glenn Jocher, Ayush Chaurasia, Jing Qiu
|
|
6
6
|
Maintainer: Glenn Jocher, Ayush Chaurasia, Jing Qiu
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
tests/__init__.py,sha256=9evx3lOdKZeY1iWXvH-FkMkgf8jLucWICoabzeD6aYg,626
|
|
2
2
|
tests/conftest.py,sha256=3ZtD4VlMKK5jVJwIPCrNAcG63vywJzdLq7U2AfYR2VI,2919
|
|
3
|
-
tests/test_cli.py,sha256=
|
|
3
|
+
tests/test_cli.py,sha256=PqZVSKBjLeHwQzh_hVKucQibqTFtP-2ZS6ndZRpqUDI,4654
|
|
4
4
|
tests/test_cuda.py,sha256=uD-ddNEcBMFQmQ9iE4fIGh0EIcGwEoDEUNVCEHicaWE,5133
|
|
5
5
|
tests/test_engine.py,sha256=xW-UT9_9xZp-7-hSnbJgMw_ezTk6NqTOIiA59XZDmxA,4934
|
|
6
6
|
tests/test_explorer.py,sha256=NcxSJeB6FxwkN09hQl7nnQL--HjfHB_WcZk0mEmBNHI,2215
|
|
@@ -8,7 +8,7 @@ tests/test_exports.py,sha256=Uezf3OatpPHlo5qoPw-2kqkZxuMCF9L4XF2riD4vmII,8225
|
|
|
8
8
|
tests/test_integrations.py,sha256=xglcfMPjfVh346PV8WTpk6tBxraCXEFJEQyyJMr5tyU,6064
|
|
9
9
|
tests/test_python.py,sha256=cLK8dyRf_4H_znFIm-krnOFMydwkxKlVZvHwl9vbck8,21780
|
|
10
10
|
tests/test_solutions.py,sha256=EACnPXbeJe2aVTOKfqMk5jclKKCWCVgFEzjpR6y7Sh8,3304
|
|
11
|
-
ultralytics/__init__.py,sha256=
|
|
11
|
+
ultralytics/__init__.py,sha256=YWRj4FNGuxXRahBpsPRAOxm3h0rYMQTFTPqJcnwUEDE,694
|
|
12
12
|
ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
|
|
13
13
|
ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
|
|
14
14
|
ultralytics/cfg/__init__.py,sha256=fD3Llw12sIkJo4g667t6b051je9nEpwdBLGgbbVEzHY,32973
|
|
@@ -84,7 +84,7 @@ ultralytics/cfg/trackers/botsort.yaml,sha256=YrPmj18p1UU40kJH5NRdL_4S8f7knggkk_q
|
|
|
84
84
|
ultralytics/cfg/trackers/bytetrack.yaml,sha256=QvHmtuwulK4X6j3T5VEqtCm0sbWWBUVmWPcCcM20qe0,688
|
|
85
85
|
ultralytics/data/__init__.py,sha256=VGe-ATG7j35F4A4r8Jmzffjlhve4JAJPgRa5ahKTU18,616
|
|
86
86
|
ultralytics/data/annotator.py,sha256=1Hyu6ubrBL8KmRrt1keGn-K4XTqQdAVyIwTsQiBtzLU,2489
|
|
87
|
-
ultralytics/data/augment.py,sha256=
|
|
87
|
+
ultralytics/data/augment.py,sha256=ExU4khJfJ_TeczkJRLNUDscN57SJvAjnm-reouJcxGI,119309
|
|
88
88
|
ultralytics/data/base.py,sha256=C3teLnw97ZTbpJHT9P7yYWosAKocMzgJjRe1rxgfpls,13524
|
|
89
89
|
ultralytics/data/build.py,sha256=AfMmz0sHIYmwry_90tEJFRk_kz0S3SolScVXqYHiT08,7261
|
|
90
90
|
ultralytics/data/converter.py,sha256=7640xKuf7LPeoTwoCvgbIXM5xbzyq72Hu2Rf2lrgjRY,17554
|
|
@@ -109,11 +109,11 @@ ultralytics/hub/__init__.py,sha256=93bqI8x8-MfDYdKkQVduuocUiQj3WGnk1nIk0li08zA,5
|
|
|
109
109
|
ultralytics/hub/auth.py,sha256=FID58NE6fh7Op_B45QOpWBw1qoBN0ponL16uvyb2dZ8,5399
|
|
110
110
|
ultralytics/hub/session.py,sha256=UF_aVwyxnbP-OzpzKXGGhi4i6KGWjjhoj5Qsn46dFpE,16257
|
|
111
111
|
ultralytics/hub/utils.py,sha256=tXfM3QbXBcf4Y6StgHI1pktT4OM7Ic9eF3xiBFHGlhY,9721
|
|
112
|
+
ultralytics/hub/google/__init__.py,sha256=qyvvpGP-4NAtrn7GLqfqxP_aWuRP1T0OvJYafWKvL2Q,7512
|
|
112
113
|
ultralytics/models/__init__.py,sha256=TT9iLCL_n9Y80dcUq0Fo-p-GRZCSU2vrWXM3CoMwqqE,265
|
|
113
|
-
ultralytics/models/fastsam/__init__.py,sha256=
|
|
114
|
-
ultralytics/models/fastsam/model.py,sha256=
|
|
115
|
-
ultralytics/models/fastsam/predict.py,sha256=
|
|
116
|
-
ultralytics/models/fastsam/prompt.py,sha256=4d9e1fEuGpTPWRfu3rG6HT8Bc0rtqJtRpNrlHkmkKcY,15860
|
|
114
|
+
ultralytics/models/fastsam/__init__.py,sha256=W0rRSJM3vdxcsneuiN6_ajkUw86k6-opUKdLxVhKOoQ,203
|
|
115
|
+
ultralytics/models/fastsam/model.py,sha256=r5VZj-KLKaqZtEKTZxQik8vQI2N9uOF4xpV_gA-P8h0,2101
|
|
116
|
+
ultralytics/models/fastsam/predict.py,sha256=ej1Z93W73hThBxuHTdb-LB-yElijKnAMxrTUMlXJ8Qs,7262
|
|
117
117
|
ultralytics/models/fastsam/utils.py,sha256=dCSm6l5yua_PTT5aNvyOvn1Q0h42Ta_NovO7sTbsBxM,715
|
|
118
118
|
ultralytics/models/fastsam/val.py,sha256=ILKmw3U8FYmmQsO9wk9-bJ9Pyp_ZthJM36b61L75s3Y,1967
|
|
119
119
|
ultralytics/models/nas/__init__.py,sha256=d6-WTrYLXvbPs58ebA0-583ODi-VyzXc-t4aGIDQK6M,179
|
|
@@ -204,7 +204,7 @@ ultralytics/utils/files.py,sha256=TVfY0Wi5IsUc4YdsDzC0dAg-jAP5exYvwqB3VmXhDLY,67
|
|
|
204
204
|
ultralytics/utils/instance.py,sha256=5daM5nkxBv9hr5QzyII8zmuFj24hHuNtcr4EMCHAtpY,15654
|
|
205
205
|
ultralytics/utils/loss.py,sha256=mDHGmF-gjggAUVhI1dkCm7TtfZHCwz25XKm4M2xJKLs,33916
|
|
206
206
|
ultralytics/utils/metrics.py,sha256=UXMhBnTtMcpTANxmQqcYkVnj8NeAt39gZez0g6jbrW0,53786
|
|
207
|
-
ultralytics/utils/ops.py,sha256=
|
|
207
|
+
ultralytics/utils/ops.py,sha256=WJHyjyTH8xl5bRkBX0JB3K1sHAGONHx_joubUewE0A8,32709
|
|
208
208
|
ultralytics/utils/patches.py,sha256=Oo3DkP7MbXnNGvPfoFSocAkVvaPh9kwMT_9RQUfjVhI,3594
|
|
209
209
|
ultralytics/utils/plotting.py,sha256=5HRfiG2dklWZJheTxGTy0gFRk39utHcZbMJl7j2hnMI,55522
|
|
210
210
|
ultralytics/utils/tal.py,sha256=hia39MhWPFpDWOTAXC_5vz-9cUdiRHZs-UcTnxD4Dlo,16112
|
|
@@ -222,9 +222,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=5Z3ua5YBTUS56FH8VQKQG1aaIo9fH8GEyz
|
|
|
222
222
|
ultralytics/utils/callbacks/raytune.py,sha256=ODVYzy-CoM4Uge0zjkh3Hnh9nF2M0vhDrSenXnvcizw,705
|
|
223
223
|
ultralytics/utils/callbacks/tensorboard.py,sha256=QEgOVhUqY9akOs5TJIwz1Rvn6l32xWLpOxlwEyWF0B8,4136
|
|
224
224
|
ultralytics/utils/callbacks/wb.py,sha256=9-fjQIdLjr3b73DTE3rHO171KvbH1VweJ-bmbv-rqTw,6747
|
|
225
|
-
ultralytics-8.2.67.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
|
226
|
-
ultralytics-8.2.
|
|
227
|
-
ultralytics-8.2.
|
|
228
|
-
ultralytics-8.2.67.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
|
|
229
|
-
ultralytics-8.2.67.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
|
|
230
|
-
ultralytics-8.2.67.dist-info/RECORD,,
|
|
225
|
+
ultralytics-8.2.69.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
|
226
|
+
ultralytics-8.2.69.dist-info/METADATA,sha256=htZwlHV6f-WyWZpx2aAgEhKJYDRhK56EMOs0w0XwhZ4,41337
|
|
227
|
+
ultralytics-8.2.69.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
228
|
+
ultralytics-8.2.69.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
|
|
229
|
+
ultralytics-8.2.69.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
|
|
230
|
+
ultralytics-8.2.69.dist-info/RECORD,,
|
|
@@ -1,352 +0,0 @@
|
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
|
|
6
|
-
import cv2
|
|
7
|
-
import numpy as np
|
|
8
|
-
import torch
|
|
9
|
-
from PIL import Image
|
|
10
|
-
from torch import Tensor
|
|
11
|
-
|
|
12
|
-
from ultralytics.utils import TQDM, checks
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class FastSAMPrompt:
|
|
16
|
-
"""
|
|
17
|
-
Fast Segment Anything Model class for image annotation and visualization.
|
|
18
|
-
|
|
19
|
-
Attributes:
|
|
20
|
-
device (str): Computing device ('cuda' or 'cpu').
|
|
21
|
-
results: Object detection or segmentation results.
|
|
22
|
-
source: Source image or image path.
|
|
23
|
-
clip: CLIP model for linear assignment.
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
def __init__(self, source, results, device="cuda") -> None:
|
|
27
|
-
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
|
|
28
|
-
if isinstance(source, (str, Path)) and os.path.isdir(source):
|
|
29
|
-
raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
|
|
30
|
-
self.device = device
|
|
31
|
-
self.results = results
|
|
32
|
-
self.source = source
|
|
33
|
-
|
|
34
|
-
# Import and assign clip
|
|
35
|
-
try:
|
|
36
|
-
import clip
|
|
37
|
-
except ImportError:
|
|
38
|
-
checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
|
39
|
-
import clip
|
|
40
|
-
self.clip = clip
|
|
41
|
-
|
|
42
|
-
@staticmethod
|
|
43
|
-
def _segment_image(image, bbox):
|
|
44
|
-
"""Segments the given image according to the provided bounding box coordinates."""
|
|
45
|
-
image_array = np.array(image)
|
|
46
|
-
segmented_image_array = np.zeros_like(image_array)
|
|
47
|
-
x1, y1, x2, y2 = bbox
|
|
48
|
-
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
|
49
|
-
segmented_image = Image.fromarray(segmented_image_array)
|
|
50
|
-
black_image = Image.new("RGB", image.size, (255, 255, 255))
|
|
51
|
-
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
|
52
|
-
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
|
53
|
-
transparency_mask[y1:y2, x1:x2] = 255
|
|
54
|
-
transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
|
|
55
|
-
black_image.paste(segmented_image, mask=transparency_mask_image)
|
|
56
|
-
return black_image
|
|
57
|
-
|
|
58
|
-
@staticmethod
|
|
59
|
-
def _format_results(result, filter=0):
|
|
60
|
-
"""Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
|
|
61
|
-
area.
|
|
62
|
-
"""
|
|
63
|
-
annotations = []
|
|
64
|
-
n = len(result.masks.data) if result.masks is not None else 0
|
|
65
|
-
for i in range(n):
|
|
66
|
-
mask = result.masks.data[i] == 1.0
|
|
67
|
-
if torch.sum(mask) >= filter:
|
|
68
|
-
annotation = {
|
|
69
|
-
"id": i,
|
|
70
|
-
"segmentation": mask.cpu().numpy(),
|
|
71
|
-
"bbox": result.boxes.data[i],
|
|
72
|
-
"score": result.boxes.conf[i],
|
|
73
|
-
}
|
|
74
|
-
annotation["area"] = annotation["segmentation"].sum()
|
|
75
|
-
annotations.append(annotation)
|
|
76
|
-
return annotations
|
|
77
|
-
|
|
78
|
-
@staticmethod
|
|
79
|
-
def _get_bbox_from_mask(mask):
|
|
80
|
-
"""Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
|
|
81
|
-
contours.
|
|
82
|
-
"""
|
|
83
|
-
mask = mask.astype(np.uint8)
|
|
84
|
-
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
85
|
-
x1, y1, w, h = cv2.boundingRect(contours[0])
|
|
86
|
-
x2, y2 = x1 + w, y1 + h
|
|
87
|
-
if len(contours) > 1:
|
|
88
|
-
for b in contours:
|
|
89
|
-
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
|
90
|
-
x1 = min(x1, x_t)
|
|
91
|
-
y1 = min(y1, y_t)
|
|
92
|
-
x2 = max(x2, x_t + w_t)
|
|
93
|
-
y2 = max(y2, y_t + h_t)
|
|
94
|
-
return [x1, y1, x2, y2]
|
|
95
|
-
|
|
96
|
-
def plot(
    self,
    annotations,
    output,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Plots annotations, bounding boxes, and points on images and saves the output.

    One figure is created per annotation, sized from the annotation's original image shape, and written to
    `output` under the base name of the annotation's source path.

    Args:
        annotations (list): Annotations to be plotted. Each item is read for `path`, `orig_img`, `orig_shape`
            and `masks`.
        output (str or Path): Output directory for saving the plots.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
            Defaults to True.
        retina (bool, optional): Whether to use retina mask. When False, masks are resized to the original image
            size before contours are extracted. Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.
    """
    import matplotlib.pyplot as plt

    pbar = TQDM(annotations, total=len(annotations))
    for ann in pbar:
        result_name = os.path.basename(ann.path)
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        plt.figure(figsize=(original_w / 100, original_h / 100))
        # Add subplot with no margin.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            # Always continue with a NumPy copy: the drawing code below relies on ndarray methods (`astype`)
            # whether or not `better_quality` is set, and the copy keeps the in-place clean-up from touching
            # the data stored on the annotation.
            if isinstance(masks, torch.Tensor):
                masks = masks.detach().cpu().numpy().copy()
            else:
                masks = np.array(masks)

            if better_quality:
                close_kernel = np.ones((3, 3), np.uint8)
                open_kernel = np.ones((8, 8), np.uint8)
                for i, mask in enumerate(masks):
                    # Close small holes first, then open to remove small specks.
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                    masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

            self.fast_show_mask(
                masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                # Turn the single-channel contour image into an RGBA overlay in the contour color.
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        # Save the figure
        save_path = Path(output) / result_name
        save_path.parent.mkdir(exist_ok=True, parents=True)
        plt.axis("off")
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
        plt.close()
        pbar.set_description(f"Saving {result_name} to {save_path}")
|
|
182
|
-
|
|
183
|
-
@staticmethod
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    pointlabel=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """
    Quickly shows the mask annotations on the given matplotlib axis.

    All masks are composited into a single RGBA overlay. Where masks overlap, the mask with the smallest area is
    the one shown; pixels covered by no mask stay fully transparent.

    Args:
        annotation (array-like): Mask annotation of shape (n, h, w). A torch tensor is converted to NumPy.
        ax (matplotlib.axes.Axes): Matplotlib axis. The box, points and overlay are all drawn on this axis.
        random_color (bool, optional): Whether to use random color for masks. Defaults to False.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        pointlabel (list, optional): Labels for the points; 1 is drawn in yellow, 0 in magenta. When None, every
            point is treated as label 1. Defaults to None.
        retinamask (bool, optional): Whether to use retina mask. When False, the overlay is resized to
            (target_width, target_height). Defaults to True.
        target_height (int, optional): Target height for resizing. Defaults to 960.
        target_width (int, optional): Target width for resizing. Defaults to 960.
    """
    if isinstance(annotation, torch.Tensor):
        annotation = annotation.detach().cpu().numpy()
    annotation = np.asarray(annotation)
    n, h, w = annotation.shape  # number of masks, height, width

    show = np.zeros((h, w, 4))
    if n > 0:  # argmax over an empty mask axis would raise, so an empty input keeps the blank overlay
        # Sort ascending by area so that index 0 is the smallest mask.
        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]

        # For every pixel, the first (i.e. smallest) mask that covers it; 0 when none does.
        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n, 1, 1, 3))
        else:
            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
        transparency = np.ones((n, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        show[h_indices, w_indices, :] = mask_image[indices]

    if bbox is not None:
        import matplotlib.pyplot as plt  # only needed to build the rectangle patch

        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))

    # Draw points on the axis that was passed in rather than on whatever axis is globally current.
    if points is not None:
        labels = [1] * len(points) if pointlabel is None else pointlabel
        ax.scatter(
            [point[0] for i, point in enumerate(points) if labels[i] == 1],
            [point[1] for i, point in enumerate(points) if labels[i] == 1],
            s=20,
            c="y",
        )
        ax.scatter(
            [point[0] for i, point in enumerate(points) if labels[i] == 0],
            [point[1] for i, point in enumerate(points) if labels[i] == 0],
            s=20,
            c="m",
        )

    if not retinamask:
        show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
    ax.imshow(show)
|
|
251
|
-
|
|
252
|
-
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
    """Scores every image in `elements` against `search_text` and returns a softmax over the images."""
    image_batch = [preprocess(element).to(device) for element in elements]
    text_tokens = self.clip.tokenize([search_text]).to(device)
    image_embeddings = model.encode_image(torch.stack(image_batch))
    text_embeddings = model.encode_text(text_tokens)
    # Scale both embedding sets to unit length (in place) so the matrix product is a cosine similarity.
    image_embeddings.div_(image_embeddings.norm(dim=-1, keepdim=True))
    text_embeddings.div_(text_embeddings.norm(dim=-1, keepdim=True))
    similarity = 100.0 * image_embeddings @ text_embeddings.T
    # Only one text was tokenized, so column 0 holds the score of each image for that text.
    per_image = similarity[:, 0]
    return per_image.softmax(dim=0)
|
|
264
|
-
|
|
265
|
-
def _crop_image(self, format_results):
    """Builds one crop per sufficiently large mask from the first result's image and reports the skipped masks."""
    rgb = cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(rgb)
    ori_w, ori_h = image.size
    annotations = format_results
    mask_h, mask_w = annotations[0]["segmentation"].shape
    # Bring the image to the mask resolution so that mask-derived boxes index the image directly.
    if (ori_w, ori_h) != (mask_w, mask_h):
        image = image.resize((mask_w, mask_h))

    cropped_images, filter_id = [], []
    for position, annotation in enumerate(annotations):
        segmentation = annotation["segmentation"]
        if np.sum(segmentation) > 100:
            box = self._get_bbox_from_mask(segmentation)  # bbox from mask
            cropped_images.append(self._segment_image(image, box))  # save cropped image
        else:
            # Masks whose pixel sum is at most 100 are not cropped; only their position is recorded.
            filter_id.append(position)

    return cropped_images, filter_id, annotations
|
|
283
|
-
|
|
284
|
-
def box_prompt(self, bbox):
    """
    Keeps only the mask of the first result that best overlaps a bounding box.

    The box is rescaled to the mask resolution when that differs from the original image shape, clamped to the
    mask extent, and compared with every mask by IoU. The best-scoring mask replaces
    `self.results[0].masks.data` as a single-mask CPU tensor.

    Args:
        bbox (list | tuple): Box as `[x1, y1, x2, y2]` in original-image coordinates. The argument itself is
            left unmodified.

    Returns:
        (list): `self.results`, updated in place; unchanged when the first result has no masks.
    """
    result = self.results[0]
    if result.masks is not None:
        # NOTE(review): the message speaks of width/height, but the values checked are x2 and y2.
        assert bbox[2] != 0 and bbox[3] != 0, "Bounding box width and height should not be zero"
        masks = result.masks.data
        target_height, target_width = result.orig_shape
        h = masks.shape[1]
        w = masks.shape[2]

        # Work on local copies so the caller's box is never written to (and tuples are accepted).
        x1, y1, x2, y2 = (bbox[i] for i in range(4))
        if h != target_height or w != target_width:
            x1 = int(x1 * w / target_width)
            y1 = int(y1 * h / target_height)
            x2 = int(x2 * w / target_width)
            y2 = int(y2 * h / target_height)
        x1 = max(int(round(x1)), 0)
        y1 = max(int(round(y1)), 0)
        x2 = min(int(round(x2)), w)
        y2 = min(int(round(y2)), h)

        bbox_area = (y2 - y1) * (x2 - x1)

        masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))  # mask pixels inside the box
        orig_masks_area = torch.sum(masks, dim=(1, 2))

        union = bbox_area + orig_masks_area - masks_area
        iou = masks_area / union
        max_iou_index = torch.argmax(iou)

        result.masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
    return self.results
|
|
316
|
-
|
|
317
|
-
def point_prompt(self, points, pointlabel):  # numpy
    """
    Builds a single mask for the first result from labelled points.

    Every formatted mask that contains a point labelled 1 is added to an accumulator and every mask that contains
    a point labelled 0 is subtracted from it; pixels whose accumulated value is at least 1 form the final mask,
    which replaces `self.results[0].masks.data`.

    Args:
        points (list): Points as `[x, y]` pairs in original-image coordinates. They are rescaled to the mask
            resolution when needed, truncated to integers and clamped to the mask extent.
        pointlabel (list): One label per point; 1 adds the masks under the point, 0 removes them.

    Returns:
        (list): `self.results`, updated in place; unchanged when the first result has no masks.
    """
    if self.results[0].masks is not None:
        masks = self._format_results(self.results[0], 0)
        target_height, target_width = self.results[0].orig_shape
        h = masks[0]["segmentation"].shape[0]
        w = masks[0]["segmentation"].shape[1]
        if h != target_height or w != target_width:
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        # Make every point a valid integer pixel index on both paths: float input would otherwise fail as an
        # index, negative values would wrap around to the opposite border, and large values would be out of range.
        points = [[min(max(int(point[0]), 0), w - 1), min(max(int(point[1]), 0), h - 1)] for point in points]

        onemask = np.zeros((h, w))
        for annotation in masks:
            mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
            for i, (x, y) in enumerate(points):
                if mask[y, x] != 1:
                    continue  # this point does not fall on the mask
                if pointlabel[i] == 1:
                    onemask += mask
                elif pointlabel[i] == 0:
                    onemask -= mask
        onemask = onemask >= 1
        self.results[0].masks.data = torch.tensor(np.array([onemask]))
    return self.results
|
|
337
|
-
|
|
338
|
-
def text_prompt(self, text, clip_download_root=None):
    """
    Keeps only the mask of the first result whose image crop best matches a text prompt.

    The formatted masks are cropped via `self._crop_image`, the crops are scored against `text` via
    `self.retrieve` using the "ViT-B/32" model obtained from `self.clip.load`, and the segmentation of the
    best-scoring crop replaces `self.results[0].masks.data`.

    Args:
        text (str): Text prompt the crops are scored against.
        clip_download_root (str, optional): Passed to `self.clip.load` as `download_root`. Defaults to None.

    Returns:
        (list): `self.results`, updated in place; unchanged when the first result has no masks or when no mask
            produced a crop.
    """
    if self.results[0].masks is not None:
        format_results = self._format_results(self.results[0], 0)
        cropped_images, filter_id, annotations = self._crop_image(format_results)
        if not cropped_images:
            # Nothing to score; bail out before loading the model or stacking an empty batch.
            return self.results
        clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
        scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
        # Scores are indexed over the crops only. Map the winner back to its position in `annotations` through
        # the list of positions that were actually cropped; adding a count of skipped ids is wrong as soon as
        # more than one skipped mask precedes the winner.
        skipped = set(filter_id)
        kept_ids = [i for i in range(len(annotations)) if i not in skipped]
        best = kept_ids[int(torch.argmax(scores))]
        self.results[0].masks.data = torch.tensor(np.array([annotations[best]["segmentation"]]))
    return self.results
|
|
349
|
-
|
|
350
|
-
def everything_prompt(self):
    """Returns `self.results` as currently stored, without applying any further prompt-based selection."""
    stored_results = self.results
    return stored_results
|
|
File without changes
|
|
File without changes
|
|
File without changes
|