ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,14 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  from io import BytesIO
2
4
  from pathlib import Path
3
- from typing import List
5
+ from typing import Any, List, Tuple, Union
4
6
 
5
7
  import cv2
6
8
  import numpy as np
7
9
  import torch
8
10
  from matplotlib import pyplot as plt
11
+ from pandas import DataFrame
9
12
  from PIL import Image
10
13
  from tqdm import tqdm
11
14
 
@@ -13,19 +16,17 @@ from ultralytics.data.augment import Format
13
16
  from ultralytics.data.dataset import YOLODataset
14
17
  from ultralytics.data.utils import check_det_dataset
15
18
  from ultralytics.models.yolo.model import YOLO
16
- from ultralytics.utils import LOGGER, checks
19
+ from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
17
20
 
18
- from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
21
+ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
19
22
 
20
23
 
21
24
  class ExplorerDataset(YOLODataset):
22
-
23
- def __init__(self, *args, data=None, **kwargs):
25
+ def __init__(self, *args, data: dict = None, **kwargs) -> None:
24
26
  super().__init__(*args, data=data, **kwargs)
25
27
 
26
- # NOTE: Load the image directly without any resize operations.
27
- def load_image(self, i):
28
- """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
28
+ def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
29
+ """Loads 1 image from dataset index 'i' without any resize ops."""
29
30
  im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
30
31
  if im is None: # not cached in RAM
31
32
  if fn.exists(): # load npy
@@ -33,15 +34,16 @@ class ExplorerDataset(YOLODataset):
33
34
  else: # read image
34
35
  im = cv2.imread(f) # BGR
35
36
  if im is None:
36
- raise FileNotFoundError(f'Image Not Found {f}')
37
+ raise FileNotFoundError(f"Image Not Found {f}")
37
38
  h0, w0 = im.shape[:2] # orig hw
38
39
  return im, (h0, w0), im.shape[:2]
39
40
 
40
41
  return self.ims[i], self.im_hw0[i], self.im_hw[i]
41
42
 
42
- def build_transforms(self, hyp=None):
43
+ def build_transforms(self, hyp: IterableSimpleNamespace = None):
44
+ """Creates transforms for dataset images without resizing."""
43
45
  return Format(
44
- bbox_format='xyxy',
46
+ bbox_format="xyxy",
45
47
  normalize=False,
46
48
  return_mask=self.use_segments,
47
49
  return_keypoint=self.use_keypoints,
@@ -52,14 +54,16 @@ class ExplorerDataset(YOLODataset):
52
54
 
53
55
 
54
56
  class Explorer:
55
-
56
- def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
57
- checks.check_requirements(['lancedb', 'duckdb'])
57
+ def __init__(
58
+ self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
59
+ ) -> None:
60
+ checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
58
61
  import lancedb
59
62
 
60
63
  self.connection = lancedb.connect(uri)
61
- self.table_name = Path(data).name.lower() + '_' + model.lower()
62
- self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
64
+ self.table_name = Path(data).name.lower() + "_" + model.lower()
65
+ self.sim_idx_base_name = (
66
+ f"{self.table_name}_sim_idx".lower()
63
67
  ) # Use this name and append thres and top_k to reuse the table
64
68
  self.model = YOLO(model)
65
69
  self.data = data # None
@@ -68,7 +72,7 @@ class Explorer:
68
72
  self.table = None
69
73
  self.progress = 0
70
74
 
71
- def create_embeddings_table(self, force=False, split='train'):
75
+ def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
72
76
  """
73
77
  Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
74
78
  already exists. Pass force=True to overwrite the existing table.
@@ -84,20 +88,20 @@ class Explorer:
84
88
  ```
85
89
  """
86
90
  if self.table is not None and not force:
87
- LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
91
+ LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
88
92
  return
89
93
  if self.table_name in self.connection.table_names() and not force:
90
- LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
94
+ LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
91
95
  self.table = self.connection.open_table(self.table_name)
92
96
  self.progress = 1
93
97
  return
94
98
  if self.data is None:
95
- raise ValueError('Data must be provided to create embeddings table')
99
+ raise ValueError("Data must be provided to create embeddings table")
96
100
 
97
101
  data_info = check_det_dataset(self.data)
98
102
  if split not in data_info:
99
103
  raise ValueError(
100
- f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
104
+ f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
101
105
  )
102
106
 
103
107
  choice_set = data_info[split]
@@ -107,29 +111,33 @@ class Explorer:
107
111
 
108
112
  # Create the table schema
109
113
  batch = dataset[0]
110
- vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
111
- Schema = get_table_schema(vector_size)
112
- table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
114
+ vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
115
+ table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
113
116
  table.add(
114
- self._yield_batches(dataset,
115
- data_info,
116
- self.model,
117
- exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
117
+ self._yield_batches(
118
+ dataset,
119
+ data_info,
120
+ self.model,
121
+ exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
122
+ )
123
+ )
118
124
 
119
125
  self.table = table
120
126
 
121
- def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
122
- # Implement Batching
127
+ def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
128
+ """Generates batches of data for embedding, excluding specified keys."""
123
129
  for i in tqdm(range(len(dataset))):
124
130
  self.progress = float(i + 1) / len(dataset)
125
131
  batch = dataset[i]
126
132
  for k in exclude_keys:
127
133
  batch.pop(k, None)
128
134
  batch = sanitize_batch(batch, data_info)
129
- batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
135
+ batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
130
136
  yield [batch]
131
137
 
132
- def query(self, imgs=None, limit=25):
138
+ def query(
139
+ self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
140
+ ) -> Any: # pyarrow.Table
133
141
  """
134
142
  Query the table for similar images. Accepts a single image or a list of images.
135
143
 
@@ -138,7 +146,7 @@ class Explorer:
138
146
  limit (int): Number of results to return.
139
147
 
140
148
  Returns:
141
- An arrow table containing the results. Supports converting to:
149
+ (pyarrow.Table): An arrow table containing the results. Supports converting to:
142
150
  - pandas dataframe: `result.to_pandas()`
143
151
  - dict of lists: `result.to_pydict()`
144
152
 
@@ -150,19 +158,18 @@ class Explorer:
150
158
  ```
151
159
  """
152
160
  if self.table is None:
153
- raise ValueError('Table is not created. Please create the table first.')
161
+ raise ValueError("Table is not created. Please create the table first.")
154
162
  if isinstance(imgs, str):
155
163
  imgs = [imgs]
156
- elif isinstance(imgs, list):
157
- pass
158
- else:
159
- raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
164
+ assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
160
165
  embeds = self.model.embed(imgs)
161
166
  # Get avg if multiple images are passed (len > 1)
162
167
  embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
163
168
  return self.table.search(embeds).limit(limit).to_arrow()
164
169
 
165
- def sql_query(self, query, return_type='pandas'):
170
+ def sql_query(
171
+ self, query: str, return_type: str = "pandas"
172
+ ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
166
173
  """
167
174
  Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
168
175
 
@@ -171,37 +178,42 @@ class Explorer:
171
178
  return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
172
179
 
173
180
  Returns:
174
- An arrow table containing the results.
181
+ (pyarrow.Table): An arrow table containing the results.
175
182
 
176
183
  Example:
177
184
  ```python
178
185
  exp = Explorer()
179
186
  exp.create_embeddings_table()
180
- query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
187
+ query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
181
188
  result = exp.sql_query(query)
182
189
  ```
183
190
  """
191
+ assert return_type in [
192
+ "pandas",
193
+ "arrow",
194
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
184
195
  import duckdb
185
196
 
186
197
  if self.table is None:
187
- raise ValueError('Table is not created. Please create the table first.')
198
+ raise ValueError("Table is not created. Please create the table first.")
188
199
 
189
200
  # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
190
- table = self.table.to_arrow() # noqa
191
- if not query.startswith('SELECT') and not query.startswith('WHERE'):
201
+ table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
202
+ if not query.startswith("SELECT") and not query.startswith("WHERE"):
192
203
  raise ValueError(
193
- 'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
194
- if query.startswith('WHERE'):
204
+ f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
205
+ )
206
+ if query.startswith("WHERE"):
195
207
  query = f"SELECT * FROM 'table' {query}"
196
- LOGGER.info(f'Running query: {query}')
208
+ LOGGER.info(f"Running query: {query}")
197
209
 
198
210
  rs = duckdb.sql(query)
199
- if return_type == 'pandas':
211
+ if return_type == "pandas":
200
212
  return rs.df()
201
- elif return_type == 'arrow':
213
+ elif return_type == "arrow":
202
214
  return rs.arrow()
203
215
 
204
- def plot_sql_query(self, query, labels=True):
216
+ def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
205
217
  """
206
218
  Plot the results of a SQL-Like query on the table.
207
219
  Args:
@@ -209,21 +221,30 @@ class Explorer:
209
221
  labels (bool): Whether to plot the labels or not.
210
222
 
211
223
  Returns:
212
- PIL Image containing the plot.
224
+ (PIL.Image): Image containing the plot.
213
225
 
214
226
  Example:
215
227
  ```python
216
228
  exp = Explorer()
217
229
  exp.create_embeddings_table()
218
- query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
230
+ query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
219
231
  result = exp.plot_sql_query(query)
220
232
  ```
221
233
  """
222
- result = self.sql_query(query, return_type='arrow')
223
- img = plot_similar_images(result, plot_labels=labels)
234
+ result = self.sql_query(query, return_type="arrow")
235
+ if len(result) == 0:
236
+ LOGGER.info("No results found.")
237
+ return None
238
+ img = plot_query_result(result, plot_labels=labels)
224
239
  return Image.fromarray(img)
225
240
 
226
- def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
241
+ def get_similar(
242
+ self,
243
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
244
+ idx: Union[int, List[int]] = None,
245
+ limit: int = 25,
246
+ return_type: str = "pandas",
247
+ ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
227
248
  """
228
249
  Query the table for similar images. Accepts a single image or a list of images.
229
250
 
@@ -234,7 +255,7 @@ class Explorer:
234
255
  return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
235
256
 
236
257
  Returns:
237
- A table or pandas dataframe containing the results.
258
+ (pandas.DataFrame): A dataframe containing the results.
238
259
 
239
260
  Example:
240
261
  ```python
@@ -243,15 +264,25 @@ class Explorer:
243
264
  similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
244
265
  ```
245
266
  """
267
+ assert return_type in [
268
+ "pandas",
269
+ "arrow",
270
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
246
271
  img = self._check_imgs_or_idxs(img, idx)
247
272
  similar = self.query(img, limit=limit)
248
273
 
249
- if return_type == 'pandas':
274
+ if return_type == "pandas":
250
275
  return similar.to_pandas()
251
- elif return_type == 'arrow':
276
+ elif return_type == "arrow":
252
277
  return similar
253
278
 
254
- def plot_similar(self, img=None, idx=None, limit=25, labels=True):
279
+ def plot_similar(
280
+ self,
281
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
282
+ idx: Union[int, List[int]] = None,
283
+ limit: int = 25,
284
+ labels: bool = True,
285
+ ) -> Image.Image:
255
286
  """
256
287
  Plot the similar images. Accepts images or indexes.
257
288
 
@@ -262,7 +293,7 @@ class Explorer:
262
293
  limit (int): Number of results to return. Defaults to 25.
263
294
 
264
295
  Returns:
265
- PIL Image containing the plot.
296
+ (PIL.Image): Image containing the plot.
266
297
 
267
298
  Example:
268
299
  ```python
@@ -271,11 +302,14 @@ class Explorer:
271
302
  similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
272
303
  ```
273
304
  """
274
- similar = self.get_similar(img, idx, limit, return_type='arrow')
275
- img = plot_similar_images(similar, plot_labels=labels)
305
+ similar = self.get_similar(img, idx, limit, return_type="arrow")
306
+ if len(similar) == 0:
307
+ LOGGER.info("No results found.")
308
+ return None
309
+ img = plot_query_result(similar, plot_labels=labels)
276
310
  return Image.fromarray(img)
277
311
 
278
- def similarity_index(self, max_dist=0.2, top_k=None, force=False):
312
+ def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
279
313
  """
280
314
  Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
281
315
  are max_dist or closer to the image in the embedding space at a given index.
@@ -283,11 +317,12 @@ class Explorer:
283
317
  Args:
284
318
  max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
285
319
  top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
286
- vector search. Defaults to 0.01.
320
+ vector search. Defaults: None.
287
321
  force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
288
322
 
289
323
  Returns:
290
- A pandas dataframe containing the similarity index.
324
+ (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
325
+ include indices of similar images and their respective distances.
291
326
 
292
327
  Example:
293
328
  ```python
@@ -297,39 +332,43 @@ class Explorer:
297
332
  ```
298
333
  """
299
334
  if self.table is None:
300
- raise ValueError('Table is not created. Please create the table first.')
301
- sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
335
+ raise ValueError("Table is not created. Please create the table first.")
336
+ sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
302
337
  if sim_idx_table_name in self.connection.table_names() and not force:
303
- LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
338
+ LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
304
339
  return self.connection.open_table(sim_idx_table_name).to_pandas()
305
340
 
306
341
  if top_k and not (1.0 >= top_k >= 0.0):
307
- raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
342
+ raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
308
343
  if max_dist < 0.0:
309
- raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
344
+ raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
310
345
 
311
346
  top_k = int(top_k * len(self.table)) if top_k else len(self.table)
312
347
  top_k = max(top_k, 1)
313
- features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
314
- im_files = features['im_file']
315
- embeddings = features['vector']
348
+ features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
349
+ im_files = features["im_file"]
350
+ embeddings = features["vector"]
316
351
 
317
- sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
352
+ sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
318
353
 
319
354
  def _yield_sim_idx():
355
+ """Generates a dataframe with similarity indices and distances for images."""
320
356
  for i in tqdm(range(len(embeddings))):
321
- sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
322
- yield [{
323
- 'idx': i,
324
- 'im_file': im_files[i],
325
- 'count': len(sim_idx),
326
- 'sim_im_files': sim_idx['im_file'].tolist()}]
357
+ sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
358
+ yield [
359
+ {
360
+ "idx": i,
361
+ "im_file": im_files[i],
362
+ "count": len(sim_idx),
363
+ "sim_im_files": sim_idx["im_file"].tolist(),
364
+ }
365
+ ]
327
366
 
328
367
  sim_table.add(_yield_sim_idx())
329
368
  self.sim_index = sim_table
330
369
  return sim_table.to_pandas()
331
370
 
332
- def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
371
+ def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
333
372
  """
334
373
  Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
335
374
  max_dist or closer to the image in the embedding space at a given index.
@@ -341,17 +380,20 @@ class Explorer:
341
380
  force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
342
381
 
343
382
  Returns:
344
- PIL Image containing the plot.
383
+ (PIL.Image): Image containing the plot.
345
384
 
346
385
  Example:
347
386
  ```python
348
387
  exp = Explorer()
349
388
  exp.create_embeddings_table()
350
- exp.plot_similarity_index()
389
+
390
+ similarity_idx_plot = exp.plot_similarity_index()
391
+ similarity_idx_plot.show() # view image preview
392
+ similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
351
393
  ```
352
394
  """
353
395
  sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
354
- sim_count = sim_idx['count'].tolist()
396
+ sim_count = sim_idx["count"].tolist()
355
397
  sim_count = np.array(sim_count)
356
398
 
357
399
  indices = np.arange(len(sim_count))
@@ -360,37 +402,68 @@ class Explorer:
360
402
  plt.bar(indices, sim_count)
361
403
 
362
404
  # Customize the plot (optional)
363
- plt.xlabel('data idx')
364
- plt.ylabel('Count')
365
- plt.title('Similarity Count')
405
+ plt.xlabel("data idx")
406
+ plt.ylabel("Count")
407
+ plt.title("Similarity Count")
366
408
  buffer = BytesIO()
367
- plt.savefig(buffer, format='png')
409
+ plt.savefig(buffer, format="png")
368
410
  buffer.seek(0)
369
411
 
370
412
  # Use Pillow to open the image from the buffer
371
- return Image.open(buffer)
413
+ return Image.fromarray(np.array(Image.open(buffer)))
372
414
 
373
- def _check_imgs_or_idxs(self, img, idx):
415
+ def _check_imgs_or_idxs(
416
+ self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
417
+ ) -> List[np.ndarray]:
374
418
  if img is None and idx is None:
375
- raise ValueError('Either img or idx must be provided.')
419
+ raise ValueError("Either img or idx must be provided.")
376
420
  if img is not None and idx is not None:
377
- raise ValueError('Only one of img or idx must be provided.')
421
+ raise ValueError("Only one of img or idx must be provided.")
378
422
  if idx is not None:
379
423
  idx = idx if isinstance(idx, list) else [idx]
380
- img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
424
+ img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
381
425
 
382
426
  return img if isinstance(img, list) else [img]
383
427
 
428
+ def ask_ai(self, query):
429
+ """
430
+ Ask AI a question.
431
+
432
+ Args:
433
+ query (str): Question to ask.
434
+
435
+ Returns:
436
+ (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
437
+
438
+ Example:
439
+ ```python
440
+ exp = Explorer()
441
+ exp.create_embeddings_table()
442
+ answer = exp.ask_ai('Show images with 1 person and 2 dogs')
443
+ ```
444
+ """
445
+ result = prompt_sql_query(query)
446
+ try:
447
+ df = self.sql_query(result)
448
+ except Exception as e:
449
+ LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
450
+ LOGGER.error(e)
451
+ return None
452
+ return df
453
+
384
454
  def visualize(self, result):
385
455
  """
386
- Visualize the results of a query.
456
+ Visualize the results of a query. TODO.
387
457
 
388
458
  Args:
389
- result (arrow table): Arrow table containing the results of a query.
459
+ result (pyarrow.Table): Table containing the results of a query.
390
460
  """
391
- # TODO:
392
461
  pass
393
462
 
394
463
  def generate_report(self, result):
395
- """Generate a report of the dataset."""
464
+ """
465
+ Generate a report of the dataset.
466
+
467
+ TODO
468
+ """
396
469
  pass
@@ -0,0 +1 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license