ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
1
3
|
from io import BytesIO
|
|
2
4
|
from pathlib import Path
|
|
3
5
|
from typing import Any, List, Tuple, Union
|
|
@@ -20,13 +22,11 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
|
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
class ExplorerDataset(YOLODataset):
|
|
23
|
-
|
|
24
25
|
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
|
25
26
|
super().__init__(*args, data=data, **kwargs)
|
|
26
27
|
|
|
27
|
-
# NOTE: Load the image directly without any resize operations.
|
|
28
28
|
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
|
|
29
|
-
"""Loads 1 image from dataset index 'i'
|
|
29
|
+
"""Loads 1 image from dataset index 'i' without any resize ops."""
|
|
30
30
|
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
|
31
31
|
if im is None: # not cached in RAM
|
|
32
32
|
if fn.exists(): # load npy
|
|
@@ -34,15 +34,16 @@ class ExplorerDataset(YOLODataset):
|
|
|
34
34
|
else: # read image
|
|
35
35
|
im = cv2.imread(f) # BGR
|
|
36
36
|
if im is None:
|
|
37
|
-
raise FileNotFoundError(f
|
|
37
|
+
raise FileNotFoundError(f"Image Not Found {f}")
|
|
38
38
|
h0, w0 = im.shape[:2] # orig hw
|
|
39
39
|
return im, (h0, w0), im.shape[:2]
|
|
40
40
|
|
|
41
41
|
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
|
42
42
|
|
|
43
43
|
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
|
44
|
+
"""Creates transforms for dataset images without resizing."""
|
|
44
45
|
return Format(
|
|
45
|
-
bbox_format=
|
|
46
|
+
bbox_format="xyxy",
|
|
46
47
|
normalize=False,
|
|
47
48
|
return_mask=self.use_segments,
|
|
48
49
|
return_keypoint=self.use_keypoints,
|
|
@@ -53,17 +54,16 @@ class ExplorerDataset(YOLODataset):
|
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
class Explorer:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
uri: str = '~/ultralytics/explorer') -> None:
|
|
61
|
-
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
|
|
57
|
+
def __init__(
|
|
58
|
+
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
|
|
59
|
+
) -> None:
|
|
60
|
+
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
|
|
62
61
|
import lancedb
|
|
63
62
|
|
|
64
63
|
self.connection = lancedb.connect(uri)
|
|
65
|
-
self.table_name = Path(data).name.lower() +
|
|
66
|
-
self.sim_idx_base_name =
|
|
64
|
+
self.table_name = Path(data).name.lower() + "_" + model.lower()
|
|
65
|
+
self.sim_idx_base_name = (
|
|
66
|
+
f"{self.table_name}_sim_idx".lower()
|
|
67
67
|
) # Use this name and append thres and top_k to reuse the table
|
|
68
68
|
self.model = YOLO(model)
|
|
69
69
|
self.data = data # None
|
|
@@ -72,7 +72,7 @@ class Explorer:
|
|
|
72
72
|
self.table = None
|
|
73
73
|
self.progress = 0
|
|
74
74
|
|
|
75
|
-
def create_embeddings_table(self, force: bool = False, split: str =
|
|
75
|
+
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
|
76
76
|
"""
|
|
77
77
|
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
|
78
78
|
already exists. Pass force=True to overwrite the existing table.
|
|
@@ -88,20 +88,20 @@ class Explorer:
|
|
|
88
88
|
```
|
|
89
89
|
"""
|
|
90
90
|
if self.table is not None and not force:
|
|
91
|
-
LOGGER.info(
|
|
91
|
+
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
|
92
92
|
return
|
|
93
93
|
if self.table_name in self.connection.table_names() and not force:
|
|
94
|
-
LOGGER.info(f
|
|
94
|
+
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
|
95
95
|
self.table = self.connection.open_table(self.table_name)
|
|
96
96
|
self.progress = 1
|
|
97
97
|
return
|
|
98
98
|
if self.data is None:
|
|
99
|
-
raise ValueError(
|
|
99
|
+
raise ValueError("Data must be provided to create embeddings table")
|
|
100
100
|
|
|
101
101
|
data_info = check_det_dataset(self.data)
|
|
102
102
|
if split not in data_info:
|
|
103
103
|
raise ValueError(
|
|
104
|
-
f
|
|
104
|
+
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
|
105
105
|
)
|
|
106
106
|
|
|
107
107
|
choice_set = data_info[split]
|
|
@@ -111,30 +111,33 @@ class Explorer:
|
|
|
111
111
|
|
|
112
112
|
# Create the table schema
|
|
113
113
|
batch = dataset[0]
|
|
114
|
-
vector_size = self.model.embed(batch[
|
|
115
|
-
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode=
|
|
114
|
+
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
|
115
|
+
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
|
116
116
|
table.add(
|
|
117
|
-
self._yield_batches(
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
117
|
+
self._yield_batches(
|
|
118
|
+
dataset,
|
|
119
|
+
data_info,
|
|
120
|
+
self.model,
|
|
121
|
+
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
|
122
|
+
)
|
|
123
|
+
)
|
|
121
124
|
|
|
122
125
|
self.table = table
|
|
123
126
|
|
|
124
127
|
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
|
125
|
-
|
|
128
|
+
"""Generates batches of data for embedding, excluding specified keys."""
|
|
126
129
|
for i in tqdm(range(len(dataset))):
|
|
127
130
|
self.progress = float(i + 1) / len(dataset)
|
|
128
131
|
batch = dataset[i]
|
|
129
132
|
for k in exclude_keys:
|
|
130
133
|
batch.pop(k, None)
|
|
131
134
|
batch = sanitize_batch(batch, data_info)
|
|
132
|
-
batch[
|
|
135
|
+
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
|
133
136
|
yield [batch]
|
|
134
137
|
|
|
135
|
-
def query(
|
|
136
|
-
|
|
137
|
-
|
|
138
|
+
def query(
|
|
139
|
+
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
|
|
140
|
+
) -> Any: # pyarrow.Table
|
|
138
141
|
"""
|
|
139
142
|
Query the table for similar images. Accepts a single image or a list of images.
|
|
140
143
|
|
|
@@ -143,7 +146,7 @@ class Explorer:
|
|
|
143
146
|
limit (int): Number of results to return.
|
|
144
147
|
|
|
145
148
|
Returns:
|
|
146
|
-
An arrow table containing the results. Supports converting to:
|
|
149
|
+
(pyarrow.Table): An arrow table containing the results. Supports converting to:
|
|
147
150
|
- pandas dataframe: `result.to_pandas()`
|
|
148
151
|
- dict of lists: `result.to_pydict()`
|
|
149
152
|
|
|
@@ -155,18 +158,18 @@ class Explorer:
|
|
|
155
158
|
```
|
|
156
159
|
"""
|
|
157
160
|
if self.table is None:
|
|
158
|
-
raise ValueError(
|
|
161
|
+
raise ValueError("Table is not created. Please create the table first.")
|
|
159
162
|
if isinstance(imgs, str):
|
|
160
163
|
imgs = [imgs]
|
|
161
|
-
assert isinstance(imgs, list), f
|
|
164
|
+
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
|
|
162
165
|
embeds = self.model.embed(imgs)
|
|
163
166
|
# Get avg if multiple images are passed (len > 1)
|
|
164
167
|
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
|
165
168
|
return self.table.search(embeds).limit(limit).to_arrow()
|
|
166
169
|
|
|
167
|
-
def sql_query(
|
|
168
|
-
|
|
169
|
-
|
|
170
|
+
def sql_query(
|
|
171
|
+
self, query: str, return_type: str = "pandas"
|
|
172
|
+
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
|
170
173
|
"""
|
|
171
174
|
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
|
172
175
|
|
|
@@ -175,7 +178,7 @@ class Explorer:
|
|
|
175
178
|
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
|
176
179
|
|
|
177
180
|
Returns:
|
|
178
|
-
An arrow table containing the results.
|
|
181
|
+
(pyarrow.Table): An arrow table containing the results.
|
|
179
182
|
|
|
180
183
|
Example:
|
|
181
184
|
```python
|
|
@@ -185,27 +188,29 @@ class Explorer:
|
|
|
185
188
|
result = exp.sql_query(query)
|
|
186
189
|
```
|
|
187
190
|
"""
|
|
188
|
-
assert return_type in [
|
|
189
|
-
|
|
191
|
+
assert return_type in [
|
|
192
|
+
"pandas",
|
|
193
|
+
"arrow",
|
|
194
|
+
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
|
190
195
|
import duckdb
|
|
191
196
|
|
|
192
197
|
if self.table is None:
|
|
193
|
-
raise ValueError(
|
|
198
|
+
raise ValueError("Table is not created. Please create the table first.")
|
|
194
199
|
|
|
195
200
|
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
|
196
201
|
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
|
197
|
-
if not query.startswith(
|
|
202
|
+
if not query.startswith("SELECT") and not query.startswith("WHERE"):
|
|
198
203
|
raise ValueError(
|
|
199
|
-
f
|
|
204
|
+
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
|
|
200
205
|
)
|
|
201
|
-
if query.startswith(
|
|
206
|
+
if query.startswith("WHERE"):
|
|
202
207
|
query = f"SELECT * FROM 'table' {query}"
|
|
203
|
-
LOGGER.info(f
|
|
208
|
+
LOGGER.info(f"Running query: {query}")
|
|
204
209
|
|
|
205
210
|
rs = duckdb.sql(query)
|
|
206
|
-
if return_type ==
|
|
211
|
+
if return_type == "pandas":
|
|
207
212
|
return rs.df()
|
|
208
|
-
elif return_type ==
|
|
213
|
+
elif return_type == "arrow":
|
|
209
214
|
return rs.arrow()
|
|
210
215
|
|
|
211
216
|
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
|
@@ -216,7 +221,7 @@ class Explorer:
|
|
|
216
221
|
labels (bool): Whether to plot the labels or not.
|
|
217
222
|
|
|
218
223
|
Returns:
|
|
219
|
-
PIL Image containing the plot.
|
|
224
|
+
(PIL.Image): Image containing the plot.
|
|
220
225
|
|
|
221
226
|
Example:
|
|
222
227
|
```python
|
|
@@ -226,18 +231,20 @@ class Explorer:
|
|
|
226
231
|
result = exp.plot_sql_query(query)
|
|
227
232
|
```
|
|
228
233
|
"""
|
|
229
|
-
result = self.sql_query(query, return_type=
|
|
234
|
+
result = self.sql_query(query, return_type="arrow")
|
|
230
235
|
if len(result) == 0:
|
|
231
|
-
LOGGER.info(
|
|
236
|
+
LOGGER.info("No results found.")
|
|
232
237
|
return None
|
|
233
238
|
img = plot_query_result(result, plot_labels=labels)
|
|
234
239
|
return Image.fromarray(img)
|
|
235
240
|
|
|
236
|
-
def get_similar(
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
+
def get_similar(
|
|
242
|
+
self,
|
|
243
|
+
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
|
244
|
+
idx: Union[int, List[int]] = None,
|
|
245
|
+
limit: int = 25,
|
|
246
|
+
return_type: str = "pandas",
|
|
247
|
+
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
|
241
248
|
"""
|
|
242
249
|
Query the table for similar images. Accepts a single image or a list of images.
|
|
243
250
|
|
|
@@ -248,7 +255,7 @@ class Explorer:
|
|
|
248
255
|
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
|
249
256
|
|
|
250
257
|
Returns:
|
|
251
|
-
A
|
|
258
|
+
(pandas.DataFrame): A dataframe containing the results.
|
|
252
259
|
|
|
253
260
|
Example:
|
|
254
261
|
```python
|
|
@@ -257,21 +264,25 @@ class Explorer:
|
|
|
257
264
|
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
|
258
265
|
```
|
|
259
266
|
"""
|
|
260
|
-
assert return_type in [
|
|
261
|
-
|
|
267
|
+
assert return_type in [
|
|
268
|
+
"pandas",
|
|
269
|
+
"arrow",
|
|
270
|
+
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
|
262
271
|
img = self._check_imgs_or_idxs(img, idx)
|
|
263
272
|
similar = self.query(img, limit=limit)
|
|
264
273
|
|
|
265
|
-
if return_type ==
|
|
274
|
+
if return_type == "pandas":
|
|
266
275
|
return similar.to_pandas()
|
|
267
|
-
elif return_type ==
|
|
276
|
+
elif return_type == "arrow":
|
|
268
277
|
return similar
|
|
269
278
|
|
|
270
|
-
def plot_similar(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
279
|
+
def plot_similar(
|
|
280
|
+
self,
|
|
281
|
+
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
|
282
|
+
idx: Union[int, List[int]] = None,
|
|
283
|
+
limit: int = 25,
|
|
284
|
+
labels: bool = True,
|
|
285
|
+
) -> Image.Image:
|
|
275
286
|
"""
|
|
276
287
|
Plot the similar images. Accepts images or indexes.
|
|
277
288
|
|
|
@@ -282,7 +293,7 @@ class Explorer:
|
|
|
282
293
|
limit (int): Number of results to return. Defaults to 25.
|
|
283
294
|
|
|
284
295
|
Returns:
|
|
285
|
-
PIL Image containing the plot.
|
|
296
|
+
(PIL.Image): Image containing the plot.
|
|
286
297
|
|
|
287
298
|
Example:
|
|
288
299
|
```python
|
|
@@ -291,9 +302,9 @@ class Explorer:
|
|
|
291
302
|
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
|
|
292
303
|
```
|
|
293
304
|
"""
|
|
294
|
-
similar = self.get_similar(img, idx, limit, return_type=
|
|
305
|
+
similar = self.get_similar(img, idx, limit, return_type="arrow")
|
|
295
306
|
if len(similar) == 0:
|
|
296
|
-
LOGGER.info(
|
|
307
|
+
LOGGER.info("No results found.")
|
|
297
308
|
return None
|
|
298
309
|
img = plot_query_result(similar, plot_labels=labels)
|
|
299
310
|
return Image.fromarray(img)
|
|
@@ -306,11 +317,12 @@ class Explorer:
|
|
|
306
317
|
Args:
|
|
307
318
|
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
|
308
319
|
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
|
|
309
|
-
|
|
320
|
+
vector search. Defaults: None.
|
|
310
321
|
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
|
311
322
|
|
|
312
323
|
Returns:
|
|
313
|
-
A
|
|
324
|
+
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
|
|
325
|
+
include indices of similar images and their respective distances.
|
|
314
326
|
|
|
315
327
|
Example:
|
|
316
328
|
```python
|
|
@@ -320,33 +332,37 @@ class Explorer:
|
|
|
320
332
|
```
|
|
321
333
|
"""
|
|
322
334
|
if self.table is None:
|
|
323
|
-
raise ValueError(
|
|
324
|
-
sim_idx_table_name = f
|
|
335
|
+
raise ValueError("Table is not created. Please create the table first.")
|
|
336
|
+
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
|
|
325
337
|
if sim_idx_table_name in self.connection.table_names() and not force:
|
|
326
|
-
LOGGER.info(
|
|
338
|
+
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
|
|
327
339
|
return self.connection.open_table(sim_idx_table_name).to_pandas()
|
|
328
340
|
|
|
329
341
|
if top_k and not (1.0 >= top_k >= 0.0):
|
|
330
|
-
raise ValueError(f
|
|
342
|
+
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
|
|
331
343
|
if max_dist < 0.0:
|
|
332
|
-
raise ValueError(f
|
|
344
|
+
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
|
|
333
345
|
|
|
334
346
|
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
|
|
335
347
|
top_k = max(top_k, 1)
|
|
336
|
-
features = self.table.to_lance().to_table(columns=[
|
|
337
|
-
im_files = features[
|
|
338
|
-
embeddings = features[
|
|
348
|
+
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
|
|
349
|
+
im_files = features["im_file"]
|
|
350
|
+
embeddings = features["vector"]
|
|
339
351
|
|
|
340
|
-
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode=
|
|
352
|
+
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
|
|
341
353
|
|
|
342
354
|
def _yield_sim_idx():
|
|
355
|
+
"""Generates a dataframe with similarity indices and distances for images."""
|
|
343
356
|
for i in tqdm(range(len(embeddings))):
|
|
344
|
-
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f
|
|
345
|
-
yield [
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
357
|
+
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
|
|
358
|
+
yield [
|
|
359
|
+
{
|
|
360
|
+
"idx": i,
|
|
361
|
+
"im_file": im_files[i],
|
|
362
|
+
"count": len(sim_idx),
|
|
363
|
+
"sim_im_files": sim_idx["im_file"].tolist(),
|
|
364
|
+
}
|
|
365
|
+
]
|
|
350
366
|
|
|
351
367
|
sim_table.add(_yield_sim_idx())
|
|
352
368
|
self.sim_index = sim_table
|
|
@@ -364,7 +380,7 @@ class Explorer:
|
|
|
364
380
|
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
|
365
381
|
|
|
366
382
|
Returns:
|
|
367
|
-
PIL.
|
|
383
|
+
(PIL.Image): Image containing the plot.
|
|
368
384
|
|
|
369
385
|
Example:
|
|
370
386
|
```python
|
|
@@ -377,7 +393,7 @@ class Explorer:
|
|
|
377
393
|
```
|
|
378
394
|
"""
|
|
379
395
|
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
|
380
|
-
sim_count = sim_idx[
|
|
396
|
+
sim_count = sim_idx["count"].tolist()
|
|
381
397
|
sim_count = np.array(sim_count)
|
|
382
398
|
|
|
383
399
|
indices = np.arange(len(sim_count))
|
|
@@ -386,25 +402,26 @@ class Explorer:
|
|
|
386
402
|
plt.bar(indices, sim_count)
|
|
387
403
|
|
|
388
404
|
# Customize the plot (optional)
|
|
389
|
-
plt.xlabel(
|
|
390
|
-
plt.ylabel(
|
|
391
|
-
plt.title(
|
|
405
|
+
plt.xlabel("data idx")
|
|
406
|
+
plt.ylabel("Count")
|
|
407
|
+
plt.title("Similarity Count")
|
|
392
408
|
buffer = BytesIO()
|
|
393
|
-
plt.savefig(buffer, format=
|
|
409
|
+
plt.savefig(buffer, format="png")
|
|
394
410
|
buffer.seek(0)
|
|
395
411
|
|
|
396
412
|
# Use Pillow to open the image from the buffer
|
|
397
413
|
return Image.fromarray(np.array(Image.open(buffer)))
|
|
398
414
|
|
|
399
|
-
def _check_imgs_or_idxs(
|
|
400
|
-
|
|
415
|
+
def _check_imgs_or_idxs(
|
|
416
|
+
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
|
417
|
+
) -> List[np.ndarray]:
|
|
401
418
|
if img is None and idx is None:
|
|
402
|
-
raise ValueError(
|
|
419
|
+
raise ValueError("Either img or idx must be provided.")
|
|
403
420
|
if img is not None and idx is not None:
|
|
404
|
-
raise ValueError(
|
|
421
|
+
raise ValueError("Only one of img or idx must be provided.")
|
|
405
422
|
if idx is not None:
|
|
406
423
|
idx = idx if isinstance(idx, list) else [idx]
|
|
407
|
-
img = self.table.to_lance().take(idx, columns=[
|
|
424
|
+
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
|
408
425
|
|
|
409
426
|
return img if isinstance(img, list) else [img]
|
|
410
427
|
|
|
@@ -416,7 +433,7 @@ class Explorer:
|
|
|
416
433
|
query (str): Question to ask.
|
|
417
434
|
|
|
418
435
|
Returns:
|
|
419
|
-
|
|
436
|
+
(pandas.DataFrame): A dataframe containing filtered results to the SQL query.
|
|
420
437
|
|
|
421
438
|
Example:
|
|
422
439
|
```python
|
|
@@ -429,21 +446,24 @@ class Explorer:
|
|
|
429
446
|
try:
|
|
430
447
|
df = self.sql_query(result)
|
|
431
448
|
except Exception as e:
|
|
432
|
-
LOGGER.error(
|
|
449
|
+
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
|
433
450
|
LOGGER.error(e)
|
|
434
451
|
return None
|
|
435
452
|
return df
|
|
436
453
|
|
|
437
454
|
def visualize(self, result):
|
|
438
455
|
"""
|
|
439
|
-
Visualize the results of a query.
|
|
456
|
+
Visualize the results of a query. TODO.
|
|
440
457
|
|
|
441
458
|
Args:
|
|
442
|
-
result (
|
|
459
|
+
result (pyarrow.Table): Table containing the results of a query.
|
|
443
460
|
"""
|
|
444
|
-
# TODO:
|
|
445
461
|
pass
|
|
446
462
|
|
|
447
463
|
def generate_report(self, result):
|
|
448
|
-
"""
|
|
464
|
+
"""
|
|
465
|
+
Generate a report of the dataset.
|
|
466
|
+
|
|
467
|
+
TODO
|
|
468
|
+
"""
|
|
449
469
|
pass
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|