ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of ultralytics has been flagged as potentially problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/data/explorer/explorer.py

@@ -1,11 +1,14 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
 from io import BytesIO
 from pathlib import Path
-from typing import List
+from typing import Any, List, Tuple, Union
 
 import cv2
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from pandas import DataFrame
 from PIL import Image
 from tqdm import tqdm
 
@@ -13,19 +16,17 @@ from ultralytics.data.augment import Format
 from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.models.yolo.model import YOLO
-from ultralytics.utils import LOGGER, checks
+from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
 
-from .utils import get_sim_index_schema, get_table_schema,
+from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
 
 
 class ExplorerDataset(YOLODataset):
-
-    def __init__(self, *args, data=None, **kwargs):
+    def __init__(self, *args, data: dict = None, **kwargs) -> None:
         super().__init__(*args, data=data, **kwargs)
 
-
-
-        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
+    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
+        """Loads 1 image from dataset index 'i' without any resize ops."""
         im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
         if im is None:  # not cached in RAM
             if fn.exists():  # load npy
@@ -33,15 +34,16 @@ class ExplorerDataset(YOLODataset):
             else:  # read image
                 im = cv2.imread(f)  # BGR
                 if im is None:
-                    raise FileNotFoundError(f
+                    raise FileNotFoundError(f"Image Not Found {f}")
             h0, w0 = im.shape[:2]  # orig hw
             return im, (h0, w0), im.shape[:2]
 
         return self.ims[i], self.im_hw0[i], self.im_hw[i]
 
-    def build_transforms(self, hyp=None):
+    def build_transforms(self, hyp: IterableSimpleNamespace = None):
+        """Creates transforms for dataset images without resizing."""
         return Format(
-            bbox_format=
+            bbox_format="xyxy",
             normalize=False,
             return_mask=self.use_segments,
             return_keypoint=self.use_keypoints,
@@ -52,14 +54,16 @@ class ExplorerDataset(YOLODataset):
 
 
 class Explorer:
-
-
-
+    def __init__(
+        self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
+    ) -> None:
+        checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
         import lancedb
 
         self.connection = lancedb.connect(uri)
-        self.table_name = Path(data).name.lower() +
-        self.sim_idx_base_name =
+        self.table_name = Path(data).name.lower() + "_" + model.lower()
+        self.sim_idx_base_name = (
+            f"{self.table_name}_sim_idx".lower()
         )  # Use this name and append thres and top_k to reuse the table
         self.model = YOLO(model)
         self.data = data  # None
@@ -68,7 +72,7 @@ class Explorer:
         self.table = None
         self.progress = 0
 
-    def create_embeddings_table(self, force=False, split=
+    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
         Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
         already exists. Pass force=True to overwrite the existing table.
@@ -84,20 +88,20 @@ class Explorer:
         ```
         """
         if self.table is not None and not force:
-            LOGGER.info(
+            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
             return
         if self.table_name in self.connection.table_names() and not force:
-            LOGGER.info(f
+            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
             self.table = self.connection.open_table(self.table_name)
             self.progress = 1
             return
         if self.data is None:
-            raise ValueError(
+            raise ValueError("Data must be provided to create embeddings table")
 
         data_info = check_det_dataset(self.data)
         if split not in data_info:
             raise ValueError(
-                f
+                f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
             )
 
         choice_set = data_info[split]
@@ -107,29 +111,33 @@ class Explorer:
 
         # Create the table schema
         batch = dataset[0]
-        vector_size = self.model.embed(batch[
-
-        table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
+        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
+        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
         table.add(
-            self._yield_batches(
-
-
-
+            self._yield_batches(
+                dataset,
+                data_info,
+                self.model,
+                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
+            )
+        )
 
         self.table = table
 
-    def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
-
+    def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
+        """Generates batches of data for embedding, excluding specified keys."""
         for i in tqdm(range(len(dataset))):
             self.progress = float(i + 1) / len(dataset)
             batch = dataset[i]
             for k in exclude_keys:
                 batch.pop(k, None)
             batch = sanitize_batch(batch, data_info)
-            batch[
+            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
             yield [batch]
 
-    def query(
+    def query(
+        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
+    ) -> Any:  # pyarrow.Table
         """
         Query the table for similar images. Accepts a single image or a list of images.
 
@@ -138,7 +146,7 @@ class Explorer:
             limit (int): Number of results to return.
 
         Returns:
-            An arrow table containing the results. Supports converting to:
+            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                 - pandas dataframe: `result.to_pandas()`
                 - dict of lists: `result.to_pydict()`
 
@@ -150,19 +158,18 @@ class Explorer:
             ```
         """
         if self.table is None:
-            raise ValueError(
+            raise ValueError("Table is not created. Please create the table first.")
         if isinstance(imgs, str):
             imgs = [imgs]
-
-            pass
-        else:
-            raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
+        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
         embeds = self.model.embed(imgs)
         # Get avg if multiple images are passed (len > 1)
         embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
         return self.table.search(embeds).limit(limit).to_arrow()
 
-    def sql_query(
+    def sql_query(
+        self, query: str, return_type: str = "pandas"
+    ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
         """
         Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
 
@@ -171,37 +178,42 @@ class Explorer:
             return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
 
         Returns:
-            An arrow table containing the results.
+            (pyarrow.Table): An arrow table containing the results.
 
         Example:
             ```python
             exp = Explorer()
             exp.create_embeddings_table()
-            query =
+            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
             result = exp.sql_query(query)
             ```
         """
+        assert return_type in [
+            "pandas",
+            "arrow",
+        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
         import duckdb
 
         if self.table is None:
-            raise ValueError(
+            raise ValueError("Table is not created. Please create the table first.")
 
         # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
-        table = self.table.to_arrow()  # noqa
-        if not query.startswith(
+        table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
+        if not query.startswith("SELECT") and not query.startswith("WHERE"):
             raise ValueError(
-
-
+                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
+            )
+        if query.startswith("WHERE"):
             query = f"SELECT * FROM 'table' {query}"
-        LOGGER.info(f
+        LOGGER.info(f"Running query: {query}")
 
         rs = duckdb.sql(query)
-        if return_type ==
+        if return_type == "pandas":
             return rs.df()
-        elif return_type ==
+        elif return_type == "arrow":
             return rs.arrow()
 
-    def plot_sql_query(self, query, labels=True):
+    def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
         """
         Plot the results of a SQL-Like query on the table.
         Args:
@@ -209,21 +221,30 @@ class Explorer:
             labels (bool): Whether to plot the labels or not.
 
         Returns:
-            PIL Image containing the plot.
+            (PIL.Image): Image containing the plot.
 
         Example:
            ```python
             exp = Explorer()
             exp.create_embeddings_table()
-            query =
+            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
             result = exp.plot_sql_query(query)
             ```
         """
-        result = self.sql_query(query, return_type=
-
+        result = self.sql_query(query, return_type="arrow")
+        if len(result) == 0:
+            LOGGER.info("No results found.")
+            return None
+        img = plot_query_result(result, plot_labels=labels)
         return Image.fromarray(img)
 
-    def get_similar(
+    def get_similar(
+        self,
+        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
+        idx: Union[int, List[int]] = None,
+        limit: int = 25,
+        return_type: str = "pandas",
+    ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
         """
         Query the table for similar images. Accepts a single image or a list of images.
 
@@ -234,7 +255,7 @@ class Explorer:
             return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
 
         Returns:
-            A
+            (pandas.DataFrame): A dataframe containing the results.
 
         Example:
             ```python
@@ -243,15 +264,25 @@ class Explorer:
             similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
             ```
         """
+        assert return_type in [
+            "pandas",
+            "arrow",
+        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
         img = self._check_imgs_or_idxs(img, idx)
         similar = self.query(img, limit=limit)
 
-        if return_type ==
+        if return_type == "pandas":
             return similar.to_pandas()
-        elif return_type ==
+        elif return_type == "arrow":
             return similar
 
-    def plot_similar(
+    def plot_similar(
+        self,
+        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
+        idx: Union[int, List[int]] = None,
+        limit: int = 25,
+        labels: bool = True,
+    ) -> Image.Image:
         """
         Plot the similar images. Accepts images or indexes.
 
@@ -262,7 +293,7 @@ class Explorer:
             limit (int): Number of results to return. Defaults to 25.
 
         Returns:
-            PIL Image containing the plot.
+            (PIL.Image): Image containing the plot.
 
         Example:
             ```python
@@ -271,11 +302,14 @@ class Explorer:
             similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
             ```
         """
-        similar = self.get_similar(img, idx, limit, return_type=
-
+        similar = self.get_similar(img, idx, limit, return_type="arrow")
+        if len(similar) == 0:
+            LOGGER.info("No results found.")
+            return None
+        img = plot_query_result(similar, plot_labels=labels)
         return Image.fromarray(img)
 
-    def similarity_index(self, max_dist=0.2, top_k=None, force=False):
+    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
         """
         Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
         are max_dist or closer to the image in the embedding space at a given index.
@@ -283,11 +317,12 @@ class Explorer:
         Args:
             max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
             top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
-
+                vector search. Defaults: None.
             force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
 
         Returns:
-            A
+            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
+                include indices of similar images and their respective distances.
 
         Example:
             ```python
@@ -297,39 +332,43 @@ class Explorer:
             ```
         """
         if self.table is None:
-            raise ValueError(
-        sim_idx_table_name = f
+            raise ValueError("Table is not created. Please create the table first.")
+        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
         if sim_idx_table_name in self.connection.table_names() and not force:
-            LOGGER.info(
+            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
             return self.connection.open_table(sim_idx_table_name).to_pandas()
 
         if top_k and not (1.0 >= top_k >= 0.0):
-            raise ValueError(f
+            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
         if max_dist < 0.0:
-            raise ValueError(f
+            raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
 
         top_k = int(top_k * len(self.table)) if top_k else len(self.table)
         top_k = max(top_k, 1)
-        features = self.table.to_lance().to_table(columns=[
-        im_files = features[
-        embeddings = features[
+        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
+        im_files = features["im_file"]
+        embeddings = features["vector"]
 
-        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode=
+        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
 
         def _yield_sim_idx():
+            """Generates a dataframe with similarity indices and distances for images."""
             for i in tqdm(range(len(embeddings))):
-                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f
-                yield [
-
-
-
-
+                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
+                yield [
+                    {
+                        "idx": i,
+                        "im_file": im_files[i],
+                        "count": len(sim_idx),
+                        "sim_im_files": sim_idx["im_file"].tolist(),
+                    }
+                ]
 
         sim_table.add(_yield_sim_idx())
         self.sim_index = sim_table
         return sim_table.to_pandas()
 
-    def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
+    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
         """
         Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
         max_dist or closer to the image in the embedding space at a given index.
@@ -341,17 +380,20 @@ class Explorer:
             force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
 
         Returns:
-            PIL Image containing the plot.
+            (PIL.Image): Image containing the plot.
 
         Example:
            ```python
             exp = Explorer()
             exp.create_embeddings_table()
-
+
+            similarity_idx_plot = exp.plot_similarity_index()
+            similarity_idx_plot.show() # view image preview
+            similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
             ```
         """
         sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
-        sim_count = sim_idx[
+        sim_count = sim_idx["count"].tolist()
         sim_count = np.array(sim_count)
 
         indices = np.arange(len(sim_count))
@@ -360,37 +402,68 @@ class Explorer:
         plt.bar(indices, sim_count)
 
         # Customize the plot (optional)
-        plt.xlabel(
-        plt.ylabel(
-        plt.title(
+        plt.xlabel("data idx")
+        plt.ylabel("Count")
+        plt.title("Similarity Count")
         buffer = BytesIO()
-        plt.savefig(buffer, format=
+        plt.savefig(buffer, format="png")
         buffer.seek(0)
 
         # Use Pillow to open the image from the buffer
-        return Image.open(buffer)
+        return Image.fromarray(np.array(Image.open(buffer)))
 
-    def _check_imgs_or_idxs(
+    def _check_imgs_or_idxs(
+        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
+    ) -> List[np.ndarray]:
         if img is None and idx is None:
-            raise ValueError(
+            raise ValueError("Either img or idx must be provided.")
         if img is not None and idx is not None:
-            raise ValueError(
+            raise ValueError("Only one of img or idx must be provided.")
         if idx is not None:
             idx = idx if isinstance(idx, list) else [idx]
-            img = self.table.to_lance().take(idx, columns=[
+            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
 
         return img if isinstance(img, list) else [img]
 
+    def ask_ai(self, query):
+        """
+        Ask AI a question.
+
+        Args:
+            query (str): Question to ask.
+
+        Returns:
+            (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
+
+        Example:
+            ```python
+            exp = Explorer()
+            exp.create_embeddings_table()
+            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
+            ```
+        """
+        result = prompt_sql_query(query)
+        try:
+            df = self.sql_query(result)
+        except Exception as e:
+            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
+            LOGGER.error(e)
+            return None
+        return df
+
     def visualize(self, result):
         """
-        Visualize the results of a query.
+        Visualize the results of a query. TODO.
 
         Args:
-            result (
+            result (pyarrow.Table): Table containing the results of a query.
         """
-        # TODO:
         pass
 
     def generate_report(self, result):
-        """
+        """
+        Generate a report of the dataset.
+
+        TODO
+        """
         pass

ultralytics/data/explorer/gui/__init__.py

@@ -0,0 +1 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license