ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -1,472 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
2
-
3
- from io import BytesIO
4
- from pathlib import Path
5
- from typing import Any, List, Tuple, Union
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from PIL import Image
11
- from matplotlib import pyplot as plt
12
- from pandas import DataFrame
13
- from tqdm import tqdm
14
-
15
- from ultralytics.data.augment import Format
16
- from ultralytics.data.dataset import YOLODataset
17
- from ultralytics.data.utils import check_det_dataset
18
- from ultralytics.models.yolo.model import YOLO
19
- from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
20
- from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
21
-
22
-
23
class ExplorerDataset(YOLODataset):
    """YOLODataset subclass that loads images without resizing and only applies label formatting."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Forward all positional/keyword arguments, plus the dataset dict, to YOLODataset."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """
        Load the image at dataset index `i` without any resize ops.

        Args:
            i (int): Index into the dataset's image lists.

        Returns:
            (tuple): The image, its original (h, w) and its current (h, w). For an image that was not cached
                both sizes are identical because no resizing is performed.

        Raises:
            FileNotFoundError: If neither the .npy file nor the image file could be read.
        """
        cached, img_path, npy_path = self.ims[i], self.im_files[i], self.npy_files[i]

        # Already held in RAM: hand back the cached image together with its recorded sizes.
        if cached is not None:
            return cached, self.im_hw0[i], self.im_hw[i]

        if npy_path.exists():  # load npy
            image = np.load(npy_path)
        else:  # read image
            image = cv2.imread(img_path)  # BGR
            if image is None:
                raise FileNotFoundError(f"Image Not Found {img_path}")

        h0, w0 = image.shape[:2]  # orig hw
        return image, (h0, w0), image.shape[:2]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """
        Create the transforms for dataset images without resizing.

        Args:
            hyp (IterableSimpleNamespace): Hyperparameters; `mask_ratio` and `overlap_mask` are read from it.

        Returns:
            (Format): Formatter configured for un-normalized xyxy boxes with batch indices.
        """
        format_kwargs = dict(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
        return Format(**format_kwargs)
53
-
54
-
55
class Explorer:
    """
    Explore a dataset through image embeddings stored in a LanceDB table.

    The class embeds every image of a dataset split with a YOLO model, stores the result in a LanceDB table and
    offers similarity search, SQL-like filtering (through DuckDB) and plotting helpers on top of that table.

    Attributes:
        connection: LanceDB connection opened on `uri`.
        table_name (str): Name of the embeddings table, derived from the dataset file name and the model name.
        sim_idx_base_name (str): Prefix of the similarity-index table names.
        model (YOLO): Model used to compute the embeddings.
        data (str | Path): Dataset passed at construction time.
        choice_set (list | None): Image sources of the split used to build the table.
        table: The embeddings table, or None until `create_embeddings_table` has run.
        sim_index: The most recently built similarity-index table, or None.
        progress (float): Fraction of the dataset embedded so far (1 when an existing table is reused).
    """

    def __init__(
        self,
        data: Union[str, Path] = "coco128.yaml",
        model: str = "yolov8n.pt",
        uri: Union[str, Path] = USER_CONFIG_DIR / "explorer",
    ) -> None:
        """
        Connect to LanceDB and load the embedding model.

        Args:
            data (str | Path): Dataset to explore. Defaults to 'coco128.yaml'.
            model (str): Model passed to `YOLO(...)`. Defaults to 'yolov8n.pt'.
            uri (str | Path): Location handed to `lancedb.connect`. Defaults to `USER_CONFIG_DIR / "explorer"`.
        """
        # Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
        checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
        import lancedb

        self.connection = lancedb.connect(uri)
        self.table_name = Path(data).name.lower() + "_" + model.lower()
        self.sim_idx_base_name = (
            f"{self.table_name}_sim_idx".lower()
        )  # Use this name and append thres and top_k to reuse the table
        self.model = YOLO(model)
        self.data = data  # None
        self.choice_set = None

        self.table = None
        self.sim_index = None  # set by similarity_index(); initialised here so the attribute always exists
        self.progress = 0

    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
        Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Raises:
            ValueError: If no dataset was provided or `split` is not a key of the dataset.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        if self.table is not None and not force:
            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
            return
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        if self.data is None:
            raise ValueError("Data must be provided to create embeddings table")

        data_info = check_det_dataset(self.data)
        if split not in data_info:
            raise ValueError(
                f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
            )

        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Create the table schema; the first sample is embedded once only to learn the vector length
        batch = dataset[0]
        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
        table.add(
            self._yield_batches(
                dataset,
                data_info,
                self.model,
                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
            )
        )

        self.table = table

    def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
        """
        Generate single-sample batches for embedding, excluding specified keys.

        Updates `self.progress` with the fraction of the dataset processed so far.

        Args:
            dataset (ExplorerDataset): Dataset to iterate by index.
            data_info (dict): Dataset information passed to `sanitize_batch`.
            model (YOLO): Model whose `embed` method produces the vector for each image file.
            exclude_keys (List[str]): Keys removed from every sample before it is stored.

        Yields:
            (List[dict]): A one-element list holding the sanitized sample with its 'vector' entry.
        """
        total = len(dataset)  # loop-invariant, computed once
        for i in tqdm(range(total)):
            self.progress = float(i + 1) / total
            batch = dataset[i]
            for k in exclude_keys:
                batch.pop(k, None)
            batch = sanitize_batch(batch, data_info)
            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
            yield [batch]

    def query(
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
    ) -> Any:  # pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            imgs (str | np.ndarray | list): An image (path or array) or a list of images. When several images are
                given, the mean of their embeddings is used for the search.
            limit (int): Number of results to return.

        Returns:
            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Raises:
            ValueError: If the embeddings table has not been created yet.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # A single array is advertised by the signature, so wrap it exactly like a single path
        if isinstance(imgs, (str, np.ndarray)):
            imgs = [imgs]
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        embeds = self.model.embed(imgs)
        # Get avg if multiple images are passed (len > 1)
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(
        self, query: str, return_type: str = "pandas"
    ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
        """
        Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

        Args:
            query (str): SQL query to run. Either a full `SELECT ...` statement or only a `WHERE ...` clause, in
                which case `SELECT * FROM 'table'` is prepended. Surrounding whitespace and keyword case are ignored
                when checking the prefix.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The query result in the requested format.

        Raises:
            ValueError: If the table has not been created or the query starts with neither SELECT nor WHERE.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        import duckdb

        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")

        # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
        table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
        query = query.strip()
        keyword = query[:6].upper()  # long enough to hold either 'SELECT' or 'WHERE'
        if not keyword.startswith(("SELECT", "WHERE")):
            raise ValueError(
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
            )
        if keyword.startswith("WHERE"):
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f"Running query: {query}")

        rs = duckdb.sql(query)
        if return_type == "arrow":
            return rs.arrow()
        elif return_type == "pandas":
            return rs.df()

    def plot_sql_query(self, query: str, labels: bool = True) -> Union[Image.Image, None]:
        """
        Plot the results of a SQL-Like query on the table.

        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when the query returns no rows.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.plot_sql_query(query)
            ```
        """
        result = self.sql_query(query, return_type="arrow")
        if len(result) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(result, plot_labels=labels)
        return Image.fromarray(img)

    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        return_type: str = "pandas",
    ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The similar images in the requested format.

        Raises:
            ValueError: If neither or both of `img` and `idx` are given, or the table has not been created.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        img = self._check_imgs_or_idxs(img, idx)
        similar = self.query(img, limit=limit)

        if return_type == "arrow":
            return similar
        elif return_type == "pandas":
            return similar.to_pandas()

    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        labels: bool = True,
    ) -> Union[Image.Image, None]:
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            labels (bool): Whether to plot the labels or not.
            limit (int): Number of results to return. Defaults to 25.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when nothing similar was found.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        if len(similar) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)

    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
                vector search. Defaults: None, which searches over the whole table.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image and holds
                its index, its file, the number of similar images and the files of those similar images.

        Raises:
            ValueError: If the table has not been created, `top_k` is outside [0.0, 1.0] or `max_dist` is negative.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than or equal to 0. Got {max_dist}")

        # Convert the fraction into an absolute search limit, never below one row
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Generates one similarity-index row per image: its neighbours within max_dist and their count."""
            for i, embedding in enumerate(tqdm(embeddings)):
                sim_idx = self.table.search(embedding).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table
        return sim_table.to_pandas()

    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image.Image:
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
        max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
                running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()

            similarity_idx_plot = exp.plot_similarity_index()
            similarity_idx_plot.show() # view image preview
            similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
            ```
        """
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        sim_count = np.array(sim_idx["count"].tolist())
        indices = np.arange(len(sim_count))

        # Draw on a dedicated figure and always close it: the implicit pyplot figure would otherwise
        # persist between calls, so later plots would be drawn over earlier ones and figures would pile up.
        fig, ax = plt.subplots()
        try:
            # Create the bar plot
            ax.bar(indices, sim_count)

            # Customize the plot (optional)
            ax.set_xlabel("data idx")
            ax.set_ylabel("Count")
            ax.set_title("Similarity Count")
            buffer = BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            plt.close(fig)
        buffer.seek(0)

        # Use Pillow to open the image from the buffer
        return Image.fromarray(np.array(Image.open(buffer)))

    def _check_imgs_or_idxs(
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[Union[str, np.ndarray]]:
        """
        Normalize the `img`/`idx` pair into a list of images.

        Args:
            img (str | np.ndarray | list | None): Image(s) given directly.
            idx (int | List[int] | None): Row index or indexes whose 'im_file' values are looked up in the table.

        Returns:
            (list): The images as a list; for `idx`, the matching 'im_file' entries of the table.

        Raises:
            ValueError: If neither or both of `img` and `idx` are provided.
        """
        if img is None and idx is None:
            raise ValueError("Either img or idx must be provided.")
        if img is not None and idx is not None:
            raise ValueError("Only one of img or idx must be provided.")
        if idx is not None:
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        return img if isinstance(img, list) else [img]

    def ask_ai(self, query):
        """
        Ask AI a question.

        Args:
            query (str): Question to ask.

        Returns:
            (pandas.DataFrame | None): A dataframe containing filtered results to the SQL query, or None when the
                generated query could not be run.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
            ```
        """
        result = prompt_sql_query(query)
        try:
            df = self.sql_query(result)
        except Exception as e:  # deliberate best-effort: any failure of the generated query is logged, not raised
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None
        return df

    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        pass

    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        pass
@@ -1 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license