ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,357 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
2
-
3
- import os
4
- from pathlib import Path
5
-
6
- import cv2
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- import torch
10
- from PIL import Image
11
-
12
- from ultralytics.utils import TQDM
13
-
14
-
15
class FastSAMPrompt:
    """
    Fast Segment Anything Model class for image annotation and visualization.

    Wraps a list of segmentation results and narrows their masks down according to a box, point or text prompt.

    Attributes:
        device (str): Computing device ('cuda' or 'cpu').
        results: Object detection or segmentation results; only ``results[0]`` is used by the prompt methods.
        source: Source image or image path.
        clip: The imported ``clip`` module, used by the text prompt.
    """

    def __init__(self, source, results, device="cuda") -> None:
        """Initializes FastSAMPrompt with given source, results and device, and imports the clip module."""
        self.device = device
        self.results = results
        self.source = source

        # CLIP is optional: on ImportError the requirement check is run and the import retried.
        try:
            import clip
        except ImportError:
            from ultralytics.utils.checks import check_requirements

            check_requirements("git+https://github.com/openai/CLIP.git")
            import clip
        self.clip = clip

    def _raise_if_source_is_directory(self):
        """Raises ValueError when the source is a directory path; non-path sources (e.g. arrays) are accepted."""
        # os.path.isdir raises TypeError on non-path objects, so only path-like sources are probed.
        if isinstance(self.source, (str, bytes, os.PathLike)) and os.path.isdir(self.source):
            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")

    @staticmethod
    def _original_index(position, filter_id, total):
        """
        Maps a position among the kept (unfiltered) annotations back to its index in the full annotation list.

        Args:
            position (int): Index into the list of annotations that survived filtering.
            filter_id (list): Indices of the annotations that were filtered out.
            total (int): Total number of annotations before filtering.

        Returns:
            (int): Index of the selected annotation in the unfiltered list.
        """
        removed = set(filter_id)
        kept = [i for i in range(total) if i not in removed]
        return kept[position]

    @staticmethod
    def _segment_image(image, bbox):
        """
        Keeps the pixels inside the bounding box and paints everything else white.

        Args:
            image (PIL.Image.Image): Source image.
            bbox (list): Box as [x1, y1, x2, y2] in pixel coordinates.

        Returns:
            (PIL.Image.Image): RGB image of the same size with only the box region preserved.
        """
        image_array = np.array(image)
        x1, y1, x2, y2 = bbox

        segmented_image_array = np.zeros_like(image_array)
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)

        # White canvas; the paste mask is opaque (255) only inside the box.
        canvas = Image.new("RGB", image.size, (255, 255, 255))
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
        transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
        canvas.paste(segmented_image, mask=transparency_mask_image)
        return canvas

    @staticmethod
    def _format_results(result, filter=0):
        """
        Formats detection results into a list of annotations.

        Args:
            result: A single result exposing ``masks.data``, ``boxes.data`` and ``boxes.conf``.
            filter (int, optional): Minimum number of mask pixels for an annotation to be kept. Defaults to 0.

        Returns:
            (list): Dicts with keys "id", "segmentation", "bbox", "score" and "area"; empty if there are no masks.
        """
        annotations = []
        if result.masks is None:
            return annotations
        for i, mask_data in enumerate(result.masks.data):
            mask = mask_data == 1.0
            if torch.sum(mask) >= filter:
                segmentation = mask.cpu().numpy()
                annotations.append(
                    {
                        "id": i,
                        "segmentation": segmentation,
                        "bbox": result.boxes.data[i],
                        "score": result.boxes.conf[i],
                        "area": segmentation.sum(),
                    }
                )
        return annotations

    @staticmethod
    def _get_bbox_from_mask(mask):
        """
        Computes the bounding box enclosing all external contours of a binary mask.

        Args:
            mask (np.ndarray): 2D mask; it is cast to uint8 before contour extraction.

        Returns:
            (list): Box as [x1, y1, x2, y2]; [0, 0, 0, 0] when the mask has no contours.
        """
        mask = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return [0, 0, 0, 0]

        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        # Grow the box so it also covers every remaining contour.
        for contour in contours[1:]:
            x_t, y_t, w_t, h_t = cv2.boundingRect(contour)
            x1 = min(x1, x_t)
            y1 = min(y1, y_t)
            x2 = max(x2, x_t + w_t)
            y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]

    def plot(
        self,
        annotations,
        output,
        bbox=None,
        points=None,
        point_label=None,
        mask_random_color=True,
        better_quality=True,
        retina=False,
        with_contours=True,
    ):
        """
        Plots annotations, bounding boxes, and points on images and saves the output.

        Args:
            annotations (list): Annotations to be plotted.
            output (str or Path): Output directory for saving the plots.
            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
            points (list, optional): Points to be plotted. Defaults to None.
            point_label (list, optional): Labels for the points. Defaults to None.
            mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
                Defaults to True.
            retina (bool, optional): Whether to use retina mask. Defaults to False.
            with_contours (bool, optional): Whether to plot contours. Defaults to True.
        """
        pbar = TQDM(annotations, total=len(annotations))
        for ann in pbar:
            result_name = os.path.basename(ann.path)
            image = ann.orig_img[..., ::-1]  # BGR to RGB
            original_h, original_w = ann.orig_shape

            plt.figure(figsize=(original_w / 100, original_h / 100))
            # Add subplot with no margin.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(image)

            # An empty mask set is skipped: there is nothing to overlay and argmax over zero masks would raise.
            if ann.masks is not None and len(ann.masks.data) > 0:
                masks = ann.masks.data
                # Always work on a NumPy copy: the code below uses ndarray-only methods and writes into `masks`,
                # which must not modify the stored result tensor.
                if isinstance(masks, torch.Tensor):
                    masks = np.array(masks.cpu())
                if better_quality:
                    close_kernel = np.ones((3, 3), np.uint8)
                    open_kernel = np.ones((8, 8), np.uint8)
                    for i, mask in enumerate(masks):
                        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

                self.fast_show_mask(
                    masks,
                    plt.gca(),
                    random_color=mask_random_color,
                    bbox=bbox,
                    points=points,
                    pointlabel=point_label,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )

                if with_contours:
                    contour_all = []
                    temp = np.zeros((original_h, original_w, 1))
                    for mask in masks:
                        mask = mask.astype(np.uint8)
                        if not retina:
                            mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                        contour_all.extend(contours)
                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
                    plt.imshow(contour_mask)

            # Save the figure
            save_path = Path(output) / result_name
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.axis("off")
            plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
            plt.close()
            pbar.set_description(f"Saving {result_name} to {save_path}")

    @staticmethod
    def fast_show_mask(
        annotation,
        ax,
        random_color=False,
        bbox=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        """
        Quickly shows the mask annotations on the given matplotlib axis.

        Args:
            annotation (array-like): Mask annotation of shape (n, h, w).
            ax (matplotlib.axes.Axes): Matplotlib axis.
            random_color (bool, optional): Whether to use random color for masks. Defaults to False.
            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
            points (list, optional): Points to be plotted. Defaults to None.
            pointlabel (list, optional): Labels for the points (1 or 0); required when points is given.
                Defaults to None.
            retinamask (bool, optional): Whether to use retina mask. Defaults to True.
            target_height (int, optional): Target height for resizing. Defaults to 960.
            target_width (int, optional): Target width for resizing. Defaults to 960.
        """
        n, h, w = annotation.shape  # batch, height, width

        # Order masks by ascending area so that, per pixel, the smallest covering mask is found first.
        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]

        # For each pixel: index of the first (smallest) mask that is non-zero there.
        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n, 1, 1, 3))
        else:
            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
        transparency = np.ones((n, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        show = np.zeros((h, w, 4))
        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

        show[h_indices, w_indices, :] = mask_image[indices]
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
        # Draw points on the supplied axis: label 1 in yellow, label 0 in magenta.
        if points is not None:
            ax.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c="y",
            )
            ax.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c="m",
            )

        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)

    @torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
        """
        Scores each image against a text query and returns a softmax over the images.

        Args:
            model: Model exposing ``encode_image`` and ``encode_text``.
            preprocess (callable): Converts one image into a tensor accepted by the model.
            elements (list): Images to score.
            search_text (str): Text query.
            device: Device the inputs are moved to.

        Returns:
            (torch.Tensor): One score per element, softmax-normalized across the elements.
        """
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = self.clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
        # L2-normalize both embeddings so the dot product is a cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)

    def _crop_image(self, format_results):
        """
        Builds one box-cropped image per sufficiently large annotation.

        Args:
            format_results (list): Annotations as produced by ``_format_results``; must be non-empty.

        Returns:
            (tuple): ``(cropped_boxes, cropped_images, not_crop, filter_id, annotations)`` where
                ``cropped_boxes`` holds the cropped PIL images, ``cropped_images`` holds their [x1, y1, x2, y2]
                boxes, ``not_crop`` is always an empty list, ``filter_id`` lists the indices of annotations
                skipped for covering 100 pixels or fewer, and ``annotations`` is the input list.

        Raises:
            ValueError: If the source is a directory.
        """
        self._raise_if_source_is_directory()
        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]["segmentation"].shape
        # Bring the image to mask resolution so that mask-derived boxes index it correctly.
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))

        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
        for idx, annotation in enumerate(annotations):
            if np.sum(annotation["segmentation"]) <= 100:
                filter_id.append(idx)
                continue
            bbox = self._get_bbox_from_mask(annotation["segmentation"])  # bbox from mask
            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
            cropped_images.append(bbox)  # save cropped image bbox

        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):
        """
        Keeps only the mask with the highest IoU against the given bounding box.

        Args:
            bbox (list or tuple): Box as [x1, y1, x2, y2] in original-image coordinates. It is not modified.

        Returns:
            The results, with ``results[0].masks.data`` replaced by the single best mask (if masks exist).

        Raises:
            AssertionError: If ``bbox[2]`` or ``bbox[3]`` is zero.
            ValueError: If the source is a directory.
        """
        if self.results[0].masks is not None:
            # Raised explicitly (same exception type as the former assert) so the check survives `python -O`.
            if bbox[2] == 0 or bbox[3] == 0:
                raise AssertionError("bbox[2] and bbox[3] must be non-zero")
            self._raise_if_source_is_directory()
            masks = self.results[0].masks.data
            target_height, target_width = self.results[0].orig_shape
            h = masks.shape[1]
            w = masks.shape[2]

            # Work on local copies so the caller's box is left untouched and tuples are accepted.
            x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if h != target_height or w != target_width:
                # Rescale the box from original-image coordinates to mask coordinates.
                x1 = int(x1 * w / target_width)
                y1 = int(y1 * h / target_height)
                x2 = int(x2 * w / target_width)
                y2 = int(y2 * h / target_height)
            x1 = max(round(x1), 0)
            y1 = max(round(y1), 0)
            x2 = min(round(x2), w)
            y2 = min(round(y2), h)

            bbox_area = (y2 - y1) * (x2 - x1)

            masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))
            orig_masks_area = torch.sum(masks, dim=(1, 2))

            union = bbox_area + orig_masks_area - masks_area
            iou = masks_area / union
            max_iou_index = torch.argmax(iou)

            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
        return self.results

    def point_prompt(self, points, pointlabel):  # numpy
        """
        Merges the masks selected by labelled points into a single mask.

        Masks hit by a point labelled 1 are added and masks hit by a point labelled 0 are subtracted; the result
        is the set of pixels whose accumulated value is at least 1.

        Args:
            points (list): Points as [x, y] in original-image coordinates.
            pointlabel (list): Label per point, 1 or 0.

        Returns:
            The results, with ``results[0].masks.data`` replaced by the merged mask (if masks exist).

        Raises:
            ValueError: If the source is a directory.
        """
        if self.results[0].masks is not None:
            self._raise_if_source_is_directory()
            masks = self._format_results(self.results[0], 0)
            if not masks:
                return self.results  # nothing to select from
            target_height, target_width = self.results[0].orig_shape
            h = masks[0]["segmentation"].shape[0]
            w = masks[0]["segmentation"].shape[1]
            if h != target_height or w != target_width:
                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
            onemask = np.zeros((h, w))
            for annotation in masks:
                mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
                for i, point in enumerate(points):
                    if mask[point[1], point[0]] != 1:
                        continue
                    if pointlabel[i] == 1:
                        onemask += mask
                    elif pointlabel[i] == 0:
                        onemask -= mask
            onemask = onemask >= 1
            self.results[0].masks.data = torch.tensor(np.array([onemask]))
        return self.results

    def text_prompt(self, text):
        """
        Keeps only the mask whose cropped image best matches the text prompt.

        Args:
            text (str): Text query scored against each cropped mask region.

        Returns:
            The results, with ``results[0].masks.data`` replaced by the best-matching mask. They are returned
            unchanged when there are no masks or no mask is large enough to be cropped.

        Raises:
            ValueError: If the source is a directory.
        """
        if self.results[0].masks is not None:
            format_results = self._format_results(self.results[0], 0)
            if not format_results:
                return self.results
            cropped_boxes, _, _, filter_id, annotations = self._crop_image(format_results)
            if not cropped_boxes:
                return self.results  # every mask was filtered out; there is nothing to score
            clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
            best_position = int(scores.argsort()[-1])
            # Scores are indexed over the kept crops only; translate back to the unfiltered annotation index.
            max_idx = self._original_index(best_position, filter_id, len(annotations))
            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
        return self.results

    def everything_prompt(self):
        """Returns the current results without applying any prompt."""
        return self.results