ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic; see the package's registry page for more details.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,5 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  from io import BytesIO
2
4
  from pathlib import Path
3
5
  from typing import Any, List, Tuple, Union
@@ -20,13 +22,11 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
20
22
 
21
23
 
22
24
  class ExplorerDataset(YOLODataset):
23
-
24
25
  def __init__(self, *args, data: dict = None, **kwargs) -> None:
25
26
  super().__init__(*args, data=data, **kwargs)
26
27
 
27
- # NOTE: Load the image directly without any resize operations.
28
28
  def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
29
- """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
29
+ """Loads 1 image from dataset index 'i' without any resize ops."""
30
30
  im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
31
31
  if im is None: # not cached in RAM
32
32
  if fn.exists(): # load npy
@@ -34,15 +34,16 @@ class ExplorerDataset(YOLODataset):
34
34
  else: # read image
35
35
  im = cv2.imread(f) # BGR
36
36
  if im is None:
37
- raise FileNotFoundError(f'Image Not Found {f}')
37
+ raise FileNotFoundError(f"Image Not Found {f}")
38
38
  h0, w0 = im.shape[:2] # orig hw
39
39
  return im, (h0, w0), im.shape[:2]
40
40
 
41
41
  return self.ims[i], self.im_hw0[i], self.im_hw[i]
42
42
 
43
43
  def build_transforms(self, hyp: IterableSimpleNamespace = None):
44
+ """Creates transforms for dataset images without resizing."""
44
45
  return Format(
45
- bbox_format='xyxy',
46
+ bbox_format="xyxy",
46
47
  normalize=False,
47
48
  return_mask=self.use_segments,
48
49
  return_keypoint=self.use_keypoints,
@@ -53,17 +54,16 @@ class ExplorerDataset(YOLODataset):
53
54
 
54
55
 
55
56
  class Explorer:
56
-
57
- def __init__(self,
58
- data: Union[str, Path] = 'coco128.yaml',
59
- model: str = 'yolov8n.pt',
60
- uri: str = '~/ultralytics/explorer') -> None:
61
- checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
57
+ def __init__(
58
+ self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
59
+ ) -> None:
60
+ checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
62
61
  import lancedb
63
62
 
64
63
  self.connection = lancedb.connect(uri)
65
- self.table_name = Path(data).name.lower() + '_' + model.lower()
66
- self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
64
+ self.table_name = Path(data).name.lower() + "_" + model.lower()
65
+ self.sim_idx_base_name = (
66
+ f"{self.table_name}_sim_idx".lower()
67
67
  ) # Use this name and append thres and top_k to reuse the table
68
68
  self.model = YOLO(model)
69
69
  self.data = data # None
@@ -72,7 +72,7 @@ class Explorer:
72
72
  self.table = None
73
73
  self.progress = 0
74
74
 
75
- def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
75
+ def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
76
76
  """
77
77
  Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
78
78
  already exists. Pass force=True to overwrite the existing table.
@@ -88,20 +88,20 @@ class Explorer:
88
88
  ```
89
89
  """
90
90
  if self.table is not None and not force:
91
- LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
91
+ LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
92
92
  return
93
93
  if self.table_name in self.connection.table_names() and not force:
94
- LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
94
+ LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
95
95
  self.table = self.connection.open_table(self.table_name)
96
96
  self.progress = 1
97
97
  return
98
98
  if self.data is None:
99
- raise ValueError('Data must be provided to create embeddings table')
99
+ raise ValueError("Data must be provided to create embeddings table")
100
100
 
101
101
  data_info = check_det_dataset(self.data)
102
102
  if split not in data_info:
103
103
  raise ValueError(
104
- f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
104
+ f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
105
105
  )
106
106
 
107
107
  choice_set = data_info[split]
@@ -111,30 +111,33 @@ class Explorer:
111
111
 
112
112
  # Create the table schema
113
113
  batch = dataset[0]
114
- vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
115
- table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
114
+ vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
115
+ table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
116
116
  table.add(
117
- self._yield_batches(dataset,
118
- data_info,
119
- self.model,
120
- exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
117
+ self._yield_batches(
118
+ dataset,
119
+ data_info,
120
+ self.model,
121
+ exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
122
+ )
123
+ )
121
124
 
122
125
  self.table = table
123
126
 
124
127
  def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
125
- # Implement Batching
128
+ """Generates batches of data for embedding, excluding specified keys."""
126
129
  for i in tqdm(range(len(dataset))):
127
130
  self.progress = float(i + 1) / len(dataset)
128
131
  batch = dataset[i]
129
132
  for k in exclude_keys:
130
133
  batch.pop(k, None)
131
134
  batch = sanitize_batch(batch, data_info)
132
- batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
135
+ batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
133
136
  yield [batch]
134
137
 
135
- def query(self,
136
- imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
137
- limit: int = 25) -> Any: # pyarrow.Table
138
+ def query(
139
+ self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
140
+ ) -> Any: # pyarrow.Table
138
141
  """
139
142
  Query the table for similar images. Accepts a single image or a list of images.
140
143
 
@@ -143,7 +146,7 @@ class Explorer:
143
146
  limit (int): Number of results to return.
144
147
 
145
148
  Returns:
146
- An arrow table containing the results. Supports converting to:
149
+ (pyarrow.Table): An arrow table containing the results. Supports converting to:
147
150
  - pandas dataframe: `result.to_pandas()`
148
151
  - dict of lists: `result.to_pydict()`
149
152
 
@@ -155,18 +158,18 @@ class Explorer:
155
158
  ```
156
159
  """
157
160
  if self.table is None:
158
- raise ValueError('Table is not created. Please create the table first.')
161
+ raise ValueError("Table is not created. Please create the table first.")
159
162
  if isinstance(imgs, str):
160
163
  imgs = [imgs]
161
- assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
164
+ assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
162
165
  embeds = self.model.embed(imgs)
163
166
  # Get avg if multiple images are passed (len > 1)
164
167
  embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
165
168
  return self.table.search(embeds).limit(limit).to_arrow()
166
169
 
167
- def sql_query(self,
168
- query: str,
169
- return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
170
+ def sql_query(
171
+ self, query: str, return_type: str = "pandas"
172
+ ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
170
173
  """
171
174
  Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
172
175
 
@@ -175,7 +178,7 @@ class Explorer:
175
178
  return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
176
179
 
177
180
  Returns:
178
- An arrow table containing the results.
181
+ (pyarrow.Table): An arrow table containing the results.
179
182
 
180
183
  Example:
181
184
  ```python
@@ -185,27 +188,29 @@ class Explorer:
185
188
  result = exp.sql_query(query)
186
189
  ```
187
190
  """
188
- assert return_type in ['pandas',
189
- 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
191
+ assert return_type in [
192
+ "pandas",
193
+ "arrow",
194
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
190
195
  import duckdb
191
196
 
192
197
  if self.table is None:
193
- raise ValueError('Table is not created. Please create the table first.')
198
+ raise ValueError("Table is not created. Please create the table first.")
194
199
 
195
200
  # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
196
201
  table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
197
- if not query.startswith('SELECT') and not query.startswith('WHERE'):
202
+ if not query.startswith("SELECT") and not query.startswith("WHERE"):
198
203
  raise ValueError(
199
- f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
204
+ f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
200
205
  )
201
- if query.startswith('WHERE'):
206
+ if query.startswith("WHERE"):
202
207
  query = f"SELECT * FROM 'table' {query}"
203
- LOGGER.info(f'Running query: {query}')
208
+ LOGGER.info(f"Running query: {query}")
204
209
 
205
210
  rs = duckdb.sql(query)
206
- if return_type == 'pandas':
211
+ if return_type == "pandas":
207
212
  return rs.df()
208
- elif return_type == 'arrow':
213
+ elif return_type == "arrow":
209
214
  return rs.arrow()
210
215
 
211
216
  def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
@@ -216,7 +221,7 @@ class Explorer:
216
221
  labels (bool): Whether to plot the labels or not.
217
222
 
218
223
  Returns:
219
- PIL Image containing the plot.
224
+ (PIL.Image): Image containing the plot.
220
225
 
221
226
  Example:
222
227
  ```python
@@ -226,18 +231,20 @@ class Explorer:
226
231
  result = exp.plot_sql_query(query)
227
232
  ```
228
233
  """
229
- result = self.sql_query(query, return_type='arrow')
234
+ result = self.sql_query(query, return_type="arrow")
230
235
  if len(result) == 0:
231
- LOGGER.info('No results found.')
236
+ LOGGER.info("No results found.")
232
237
  return None
233
238
  img = plot_query_result(result, plot_labels=labels)
234
239
  return Image.fromarray(img)
235
240
 
236
- def get_similar(self,
237
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
238
- idx: Union[int, List[int]] = None,
239
- limit: int = 25,
240
- return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
241
+ def get_similar(
242
+ self,
243
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
244
+ idx: Union[int, List[int]] = None,
245
+ limit: int = 25,
246
+ return_type: str = "pandas",
247
+ ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
241
248
  """
242
249
  Query the table for similar images. Accepts a single image or a list of images.
243
250
 
@@ -248,7 +255,7 @@ class Explorer:
248
255
  return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
249
256
 
250
257
  Returns:
251
- A table or pandas dataframe containing the results.
258
+ (pandas.DataFrame): A dataframe containing the results.
252
259
 
253
260
  Example:
254
261
  ```python
@@ -257,21 +264,25 @@ class Explorer:
257
264
  similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
258
265
  ```
259
266
  """
260
- assert return_type in ['pandas',
261
- 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
267
+ assert return_type in [
268
+ "pandas",
269
+ "arrow",
270
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
262
271
  img = self._check_imgs_or_idxs(img, idx)
263
272
  similar = self.query(img, limit=limit)
264
273
 
265
- if return_type == 'pandas':
274
+ if return_type == "pandas":
266
275
  return similar.to_pandas()
267
- elif return_type == 'arrow':
276
+ elif return_type == "arrow":
268
277
  return similar
269
278
 
270
- def plot_similar(self,
271
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
272
- idx: Union[int, List[int]] = None,
273
- limit: int = 25,
274
- labels: bool = True) -> Image.Image:
279
+ def plot_similar(
280
+ self,
281
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
282
+ idx: Union[int, List[int]] = None,
283
+ limit: int = 25,
284
+ labels: bool = True,
285
+ ) -> Image.Image:
275
286
  """
276
287
  Plot the similar images. Accepts images or indexes.
277
288
 
@@ -282,7 +293,7 @@ class Explorer:
282
293
  limit (int): Number of results to return. Defaults to 25.
283
294
 
284
295
  Returns:
285
- PIL Image containing the plot.
296
+ (PIL.Image): Image containing the plot.
286
297
 
287
298
  Example:
288
299
  ```python
@@ -291,9 +302,9 @@ class Explorer:
291
302
  similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
292
303
  ```
293
304
  """
294
- similar = self.get_similar(img, idx, limit, return_type='arrow')
305
+ similar = self.get_similar(img, idx, limit, return_type="arrow")
295
306
  if len(similar) == 0:
296
- LOGGER.info('No results found.')
307
+ LOGGER.info("No results found.")
297
308
  return None
298
309
  img = plot_query_result(similar, plot_labels=labels)
299
310
  return Image.fromarray(img)
@@ -306,11 +317,12 @@ class Explorer:
306
317
  Args:
307
318
  max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
308
319
  top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
309
- vector search. Defaults: None.
320
+ vector search. Defaults: None.
310
321
  force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
311
322
 
312
323
  Returns:
313
- A pandas dataframe containing the similarity index.
324
+ (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
325
+ include indices of similar images and their respective distances.
314
326
 
315
327
  Example:
316
328
  ```python
@@ -320,33 +332,37 @@ class Explorer:
320
332
  ```
321
333
  """
322
334
  if self.table is None:
323
- raise ValueError('Table is not created. Please create the table first.')
324
- sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
335
+ raise ValueError("Table is not created. Please create the table first.")
336
+ sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
325
337
  if sim_idx_table_name in self.connection.table_names() and not force:
326
- LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
338
+ LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
327
339
  return self.connection.open_table(sim_idx_table_name).to_pandas()
328
340
 
329
341
  if top_k and not (1.0 >= top_k >= 0.0):
330
- raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
342
+ raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
331
343
  if max_dist < 0.0:
332
- raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
344
+ raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
333
345
 
334
346
  top_k = int(top_k * len(self.table)) if top_k else len(self.table)
335
347
  top_k = max(top_k, 1)
336
- features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
337
- im_files = features['im_file']
338
- embeddings = features['vector']
348
+ features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
349
+ im_files = features["im_file"]
350
+ embeddings = features["vector"]
339
351
 
340
- sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
352
+ sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
341
353
 
342
354
  def _yield_sim_idx():
355
+ """Generates a dataframe with similarity indices and distances for images."""
343
356
  for i in tqdm(range(len(embeddings))):
344
- sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
345
- yield [{
346
- 'idx': i,
347
- 'im_file': im_files[i],
348
- 'count': len(sim_idx),
349
- 'sim_im_files': sim_idx['im_file'].tolist()}]
357
+ sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
358
+ yield [
359
+ {
360
+ "idx": i,
361
+ "im_file": im_files[i],
362
+ "count": len(sim_idx),
363
+ "sim_im_files": sim_idx["im_file"].tolist(),
364
+ }
365
+ ]
350
366
 
351
367
  sim_table.add(_yield_sim_idx())
352
368
  self.sim_index = sim_table
@@ -364,7 +380,7 @@ class Explorer:
364
380
  force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
365
381
 
366
382
  Returns:
367
- PIL.PngImagePlugin.PngImageFile containing the plot.
383
+ (PIL.Image): Image containing the plot.
368
384
 
369
385
  Example:
370
386
  ```python
@@ -377,7 +393,7 @@ class Explorer:
377
393
  ```
378
394
  """
379
395
  sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
380
- sim_count = sim_idx['count'].tolist()
396
+ sim_count = sim_idx["count"].tolist()
381
397
  sim_count = np.array(sim_count)
382
398
 
383
399
  indices = np.arange(len(sim_count))
@@ -386,25 +402,26 @@ class Explorer:
386
402
  plt.bar(indices, sim_count)
387
403
 
388
404
  # Customize the plot (optional)
389
- plt.xlabel('data idx')
390
- plt.ylabel('Count')
391
- plt.title('Similarity Count')
405
+ plt.xlabel("data idx")
406
+ plt.ylabel("Count")
407
+ plt.title("Similarity Count")
392
408
  buffer = BytesIO()
393
- plt.savefig(buffer, format='png')
409
+ plt.savefig(buffer, format="png")
394
410
  buffer.seek(0)
395
411
 
396
412
  # Use Pillow to open the image from the buffer
397
413
  return Image.fromarray(np.array(Image.open(buffer)))
398
414
 
399
- def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
400
- idx: Union[None, int, List[int]]) -> List[np.ndarray]:
415
+ def _check_imgs_or_idxs(
416
+ self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
417
+ ) -> List[np.ndarray]:
401
418
  if img is None and idx is None:
402
- raise ValueError('Either img or idx must be provided.')
419
+ raise ValueError("Either img or idx must be provided.")
403
420
  if img is not None and idx is not None:
404
- raise ValueError('Only one of img or idx must be provided.')
421
+ raise ValueError("Only one of img or idx must be provided.")
405
422
  if idx is not None:
406
423
  idx = idx if isinstance(idx, list) else [idx]
407
- img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
424
+ img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
408
425
 
409
426
  return img if isinstance(img, list) else [img]
410
427
 
@@ -416,7 +433,7 @@ class Explorer:
416
433
  query (str): Question to ask.
417
434
 
418
435
  Returns:
419
- Answer from AI.
436
+ (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
420
437
 
421
438
  Example:
422
439
  ```python
@@ -429,21 +446,24 @@ class Explorer:
429
446
  try:
430
447
  df = self.sql_query(result)
431
448
  except Exception as e:
432
- LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
449
+ LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
433
450
  LOGGER.error(e)
434
451
  return None
435
452
  return df
436
453
 
437
454
  def visualize(self, result):
438
455
  """
439
- Visualize the results of a query.
456
+ Visualize the results of a query. TODO.
440
457
 
441
458
  Args:
442
- result (arrow table): Arrow table containing the results of a query.
459
+ result (pyarrow.Table): Table containing the results of a query.
443
460
  """
444
- # TODO:
445
461
  pass
446
462
 
447
463
  def generate_report(self, result):
448
- """Generate a report of the dataset."""
464
+ """
465
+ Generate a report of the dataset.
466
+
467
+ TODO
468
+ """
449
469
  pass
@@ -0,0 +1 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license