ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
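Of the files listed above, the headline additions are the bundled test suite (`tests/`), the YOLO11 and YOLOv10 model configs, a much-expanded SAM module, and the rewritten `ultralytics.solutions` package; the Explorer API (`ultralytics/data/explorer/`) and `FastSAMPrompt` were removed. A minimal sketch of the new model family in use, assuming the standard `ultralytics` entry point (weights are downloaded on first use):

```python
from ultralytics import YOLO

# YOLO11 configs (ultralytics/cfg/models/11/) are new in this version range
model = YOLO("yolo11n.pt")
results = model("https://ultralytics.com/images/bus.jpg")  # run inference
results[0].show()  # display predictions
```

The full diff of the removed `ultralytics/models/fastsam/prompt.py` follows.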
ultralytics/models/fastsam/prompt.py (removed)
@@ -1,357 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import os
-from pathlib import Path
-
-import cv2
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from PIL import Image
-
-from ultralytics.utils import TQDM
-
-
-class FastSAMPrompt:
-    """
-    Fast Segment Anything Model class for image annotation and visualization.
-
-    Attributes:
-        device (str): Computing device ('cuda' or 'cpu').
-        results: Object detection or segmentation results.
-        source: Source image or image path.
-        clip: CLIP model for linear assignment.
-    """
-
-    def __init__(self, source, results, device="cuda") -> None:
-        """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
-        self.device = device
-        self.results = results
-        self.source = source
-
-        # Import and assign clip
-        try:
-            import clip
-        except ImportError:
-            from ultralytics.utils.checks import check_requirements
-
-            check_requirements("git+https://github.com/openai/CLIP.git")
-            import clip
-        self.clip = clip
-
-    @staticmethod
-    def _segment_image(image, bbox):
-        """Segments the given image according to the provided bounding box coordinates."""
-        image_array = np.array(image)
-        segmented_image_array = np.zeros_like(image_array)
-        x1, y1, x2, y2 = bbox
-        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
-        segmented_image = Image.fromarray(segmented_image_array)
-        black_image = Image.new("RGB", image.size, (255, 255, 255))
-        # transparency_mask = np.zeros_like((), dtype=np.uint8)
-        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
-        transparency_mask[y1:y2, x1:x2] = 255
-        transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
-        black_image.paste(segmented_image, mask=transparency_mask_image)
-        return black_image
-
-    @staticmethod
-    def _format_results(result, filter=0):
-        """Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
-        area.
-        """
-        annotations = []
-        n = len(result.masks.data) if result.masks is not None else 0
-        for i in range(n):
-            mask = result.masks.data[i] == 1.0
-            if torch.sum(mask) >= filter:
-                annotation = {
-                    "id": i,
-                    "segmentation": mask.cpu().numpy(),
-                    "bbox": result.boxes.data[i],
-                    "score": result.boxes.conf[i],
-                }
-                annotation["area"] = annotation["segmentation"].sum()
-                annotations.append(annotation)
-        return annotations
-
-    @staticmethod
-    def _get_bbox_from_mask(mask):
-        """Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
-        contours.
-        """
-        mask = mask.astype(np.uint8)
-        contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        x1, y1, w, h = cv2.boundingRect(contours[0])
-        x2, y2 = x1 + w, y1 + h
-        if len(contours) > 1:
-            for b in contours:
-                x_t, y_t, w_t, h_t = cv2.boundingRect(b)
-                x1 = min(x1, x_t)
-                y1 = min(y1, y_t)
-                x2 = max(x2, x_t + w_t)
-                y2 = max(y2, y_t + h_t)
-        return [x1, y1, x2, y2]
-
-    def plot(
-        self,
-        annotations,
-        output,
-        bbox=None,
-        points=None,
-        point_label=None,
-        mask_random_color=True,
-        better_quality=True,
-        retina=False,
-        with_contours=True,
-    ):
-        """
-        Plots annotations, bounding boxes, and points on images and saves the output.
-
-        Args:
-            annotations (list): Annotations to be plotted.
-            output (str or Path): Output directory for saving the plots.
-            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
-            points (list, optional): Points to be plotted. Defaults to None.
-            point_label (list, optional): Labels for the points. Defaults to None.
-            mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
-            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
-            retina (bool, optional): Whether to use retina mask. Defaults to False.
-            with_contours (bool, optional): Whether to plot contours. Defaults to True.
-        """
-        pbar = TQDM(annotations, total=len(annotations))
-        for ann in pbar:
-            result_name = os.path.basename(ann.path)
-            image = ann.orig_img[..., ::-1]  # BGR to RGB
-            original_h, original_w = ann.orig_shape
-            # For macOS only
-            # plt.switch_backend('TkAgg')
-            plt.figure(figsize=(original_w / 100, original_h / 100))
-            # Add subplot with no margin.
-            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-            plt.margins(0, 0)
-            plt.gca().xaxis.set_major_locator(plt.NullLocator())
-            plt.gca().yaxis.set_major_locator(plt.NullLocator())
-            plt.imshow(image)
-
-            if ann.masks is not None:
-                masks = ann.masks.data
-                if better_quality:
-                    if isinstance(masks[0], torch.Tensor):
-                        masks = np.array(masks.cpu())
-                    for i, mask in enumerate(masks):
-                        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
-                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
-
-                self.fast_show_mask(
-                    masks,
-                    plt.gca(),
-                    random_color=mask_random_color,
-                    bbox=bbox,
-                    points=points,
-                    pointlabel=point_label,
-                    retinamask=retina,
-                    target_height=original_h,
-                    target_width=original_w,
-                )
-
-                if with_contours:
-                    contour_all = []
-                    temp = np.zeros((original_h, original_w, 1))
-                    for i, mask in enumerate(masks):
-                        mask = mask.astype(np.uint8)
-                        if not retina:
-                            mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
-                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-                        contour_all.extend(iter(contours))
-                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
-                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
-                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
-                    plt.imshow(contour_mask)
-
-            # Save the figure
-            save_path = Path(output) / result_name
-            save_path.parent.mkdir(exist_ok=True, parents=True)
-            plt.axis("off")
-            plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
-            plt.close()
-            pbar.set_description(f"Saving {result_name} to {save_path}")
-
-    @staticmethod
-    def fast_show_mask(
-        annotation,
-        ax,
-        random_color=False,
-        bbox=None,
-        points=None,
-        pointlabel=None,
-        retinamask=True,
-        target_height=960,
-        target_width=960,
-    ):
-        """
-        Quickly shows the mask annotations on the given matplotlib axis.
-
-        Args:
-            annotation (array-like): Mask annotation.
-            ax (matplotlib.axes.Axes): Matplotlib axis.
-            random_color (bool, optional): Whether to use random color for masks. Defaults to False.
-            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
-            points (list, optional): Points to be plotted. Defaults to None.
-            pointlabel (list, optional): Labels for the points. Defaults to None.
-            retinamask (bool, optional): Whether to use retina mask. Defaults to True.
-            target_height (int, optional): Target height for resizing. Defaults to 960.
-            target_width (int, optional): Target width for resizing. Defaults to 960.
-        """
-        n, h, w = annotation.shape  # batch, height, width
-
-        areas = np.sum(annotation, axis=(1, 2))
-        annotation = annotation[np.argsort(areas)]
-
-        index = (annotation != 0).argmax(axis=0)
-        if random_color:
-            color = np.random.random((n, 1, 1, 3))
-        else:
-            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
-        transparency = np.ones((n, 1, 1, 1)) * 0.6
-        visual = np.concatenate([color, transparency], axis=-1)
-        mask_image = np.expand_dims(annotation, -1) * visual
-
-        show = np.zeros((h, w, 4))
-        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
-        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
-
-        show[h_indices, w_indices, :] = mask_image[indices]
-        if bbox is not None:
-            x1, y1, x2, y2 = bbox
-            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
-        # Draw point
-        if points is not None:
-            plt.scatter(
-                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
-                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
-                s=20,
-                c="y",
-            )
-            plt.scatter(
-                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
-                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
-                s=20,
-                c="m",
-            )
-
-        if not retinamask:
-            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
-        ax.imshow(show)
-
-    @torch.no_grad()
-    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
-        """Processes images and text with a model, calculates similarity, and returns softmax score."""
-        preprocessed_images = [preprocess(image).to(device) for image in elements]
-        tokenized_text = self.clip.tokenize([search_text]).to(device)
-        stacked_images = torch.stack(preprocessed_images)
-        image_features = model.encode_image(stacked_images)
-        text_features = model.encode_text(tokenized_text)
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        text_features /= text_features.norm(dim=-1, keepdim=True)
-        probs = 100.0 * image_features @ text_features.T
-        return probs[:, 0].softmax(dim=0)
-
-    def _crop_image(self, format_results):
-        """Crops an image based on provided annotation format and returns cropped images and related data."""
-        if os.path.isdir(self.source):
-            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
-        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
-        ori_w, ori_h = image.size
-        annotations = format_results
-        mask_h, mask_w = annotations[0]["segmentation"].shape
-        if ori_w != mask_w or ori_h != mask_h:
-            image = image.resize((mask_w, mask_h))
-        cropped_boxes = []
-        cropped_images = []
-        not_crop = []
-        filter_id = []
-        for _, mask in enumerate(annotations):
-            if np.sum(mask["segmentation"]) <= 100:
-                filter_id.append(_)
-                continue
-            bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
-            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
-            cropped_images.append(bbox)  # save cropped image bbox
-
-        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
-
-    def box_prompt(self, bbox):
-        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
-        if self.results[0].masks is not None:
-            assert bbox[2] != 0 and bbox[3] != 0
-            if os.path.isdir(self.source):
-                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
-            masks = self.results[0].masks.data
-            target_height, target_width = self.results[0].orig_shape
-            h = masks.shape[1]
-            w = masks.shape[2]
-            if h != target_height or w != target_width:
-                bbox = [
-                    int(bbox[0] * w / target_width),
-                    int(bbox[1] * h / target_height),
-                    int(bbox[2] * w / target_width),
-                    int(bbox[3] * h / target_height),
-                ]
-            bbox[0] = max(round(bbox[0]), 0)
-            bbox[1] = max(round(bbox[1]), 0)
-            bbox[2] = min(round(bbox[2]), w)
-            bbox[3] = min(round(bbox[3]), h)
-
-            # IoUs = torch.zeros(len(masks), dtype=torch.float32)
-            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
-
-            masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
-            orig_masks_area = torch.sum(masks, dim=(1, 2))
-
-            union = bbox_area + orig_masks_area - masks_area
-            iou = masks_area / union
-            max_iou_index = torch.argmax(iou)
-
-            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
-        return self.results
-
-    def point_prompt(self, points, pointlabel):  # numpy
-        """Adjusts points on detected masks based on user input and returns the modified results."""
-        if self.results[0].masks is not None:
-            if os.path.isdir(self.source):
-                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
-            masks = self._format_results(self.results[0], 0)
-            target_height, target_width = self.results[0].orig_shape
-            h = masks[0]["segmentation"].shape[0]
-            w = masks[0]["segmentation"].shape[1]
-            if h != target_height or w != target_width:
-                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
-            onemask = np.zeros((h, w))
-            for annotation in masks:
-                mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
-                for i, point in enumerate(points):
-                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
-                        onemask += mask
-                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
-                        onemask -= mask
-            onemask = onemask >= 1
-            self.results[0].masks.data = torch.tensor(np.array([onemask]))
-        return self.results
-
-    def text_prompt(self, text):
-        """Processes a text prompt, applies it to existing results and returns the updated results."""
-        if self.results[0].masks is not None:
-            format_results = self._format_results(self.results[0], 0)
-            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-            clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
-            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-            max_idx = scores.argsort()
-            max_idx = max_idx[-1]
-            max_idx += sum(np.array(filter_id) <= int(max_idx))
-            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
-        return self.results
-
-    def everything_prompt(self):
-        """Returns the processed results from the previous methods in the class."""
-        return self.results