ultralytics 8.3.11__py3-none-any.whl → 8.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

@@ -804,31 +804,30 @@ class Annotator:
804
804
  self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
805
805
  )
806
806
 
807
- def plot_distance_and_line(self, pixels_distance, centroids, line_color, centroid_color):
807
+ def plot_distance_and_line(
808
+ self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
809
+ ):
808
810
  """
809
811
  Plot the distance and line on frame.
810
812
 
811
813
  Args:
812
814
  pixels_distance (float): Pixels distance between two bbox centroids.
813
815
  centroids (list): Bounding box centroids data.
814
- line_color (tuple): RGB distance line color.
815
- centroid_color (tuple): RGB bounding box centroid color.
816
+ line_color (tuple, optional): Distance line color.
817
+ centroid_color (tuple, optional): Bounding box centroid color.
816
818
  """
817
819
  # Get the text size
818
- (text_width_m, text_height_m), _ = cv2.getTextSize(
819
- f"Pixels Distance: {pixels_distance:.2f}", 0, self.sf, self.tf
820
- )
820
+ text = f"Pixels Distance: {pixels_distance:.2f}"
821
+ (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
821
822
 
822
823
  # Define corners with 10-pixel margin and draw rectangle
823
- top_left = (15, 25)
824
- bottom_right = (15 + text_width_m + 20, 25 + text_height_m + 20)
825
- cv2.rectangle(self.im, top_left, bottom_right, centroid_color, -1)
824
+ cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
826
825
 
827
826
  # Calculate the position for the text with a 10-pixel margin and draw text
828
- text_position = (top_left[0] + 10, top_left[1] + text_height_m + 10)
827
+ text_position = (25, 25 + text_height_m + 10)
829
828
  cv2.putText(
830
829
  self.im,
831
- f"Pixels Distance: {pixels_distance:.2f}",
830
+ text,
832
831
  text_position,
833
832
  0,
834
833
  self.sf,
@@ -1156,16 +1155,16 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
1156
1155
  save_dir = Path(file).parent if file else Path(dir)
1157
1156
  if classify:
1158
1157
  fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
1159
- index = [1, 4, 2, 3]
1158
+ index = [2, 5, 3, 4]
1160
1159
  elif segment:
1161
1160
  fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
1162
- index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
1161
+ index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
1163
1162
  elif pose:
1164
1163
  fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
1165
- index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
1164
+ index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
1166
1165
  else:
1167
1166
  fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
1168
- index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
1167
+ index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
1169
1168
  ax = ax.ravel()
1170
1169
  files = list(save_dir.glob("results*.csv"))
1171
1170
  assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ultralytics
3
- Version: 8.3.11
3
+ Version: 8.3.13
4
4
  Summary: Ultralytics YOLO for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
5
5
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
6
6
  Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -60,10 +60,6 @@ Requires-Dist: mkdocs-jupyter; extra == "dev"
60
60
  Requires-Dist: mkdocs-redirects; extra == "dev"
61
61
  Requires-Dist: mkdocs-ultralytics-plugin>=0.1.8; extra == "dev"
62
62
  Requires-Dist: mkdocs-macros-plugin>=1.0.5; extra == "dev"
63
- Provides-Extra: explorer
64
- Requires-Dist: lancedb; extra == "explorer"
65
- Requires-Dist: duckdb<=0.9.2; extra == "explorer"
66
- Requires-Dist: streamlit; extra == "explorer"
67
63
  Provides-Extra: export
68
64
  Requires-Dist: onnx>=1.12.0; extra == "export"
69
65
  Requires-Dist: openvino>=2024.0.0; extra == "export"
@@ -85,6 +81,9 @@ Provides-Extra: logging
85
81
  Requires-Dist: comet; extra == "logging"
86
82
  Requires-Dist: tensorboard>=2.13.0; extra == "logging"
87
83
  Requires-Dist: dvclive>=2.12.0; extra == "logging"
84
+ Provides-Extra: solutions
85
+ Requires-Dist: shapely>=2.0.0; extra == "solutions"
86
+ Requires-Dist: streamlit; extra == "solutions"
88
87
 
89
88
  <div align="center">
90
89
  <p>
@@ -1,16 +1,16 @@
1
1
  tests/__init__.py,sha256=iVH5nXrACTDv0_ZIVRPi-9f6oYBl6g-tCkeR2Hb8MFM,666
2
2
  tests/conftest.py,sha256=9PFAiwAy6eeORGspr5dOKxVuFDVKqYg8Nn_RxSJ27UI,2919
3
- tests/test_cli.py,sha256=E4lMt49TGo12Lb5CgQfpk1bwyFUZuFxF0V9j_ykV7xM,4821
4
- tests/test_cuda.py,sha256=KoRtRLUB7KOb9IXYX4mCi295Uh_cZEEFhCyvCDGRK9s,5381
3
+ tests/test_cli.py,sha256=G7OJ1ErQYsGy2Dx1zP-0p7EZR4aPoAdtLGiY4Hm7jQM,5006
4
+ tests/test_cuda.py,sha256=cxohrqAjjPCgTaChYhQ4CObQzz7ZyhVGPTxkcxHSf0g,5940
5
5
  tests/test_engine.py,sha256=dcEcJsMQh61rDSNv7l4TIAgybLpzjVwerv9JZC_KCM8,4934
6
6
  tests/test_exports.py,sha256=fpTKEVBUGLF3WiZPNKRs-IEcIY4cfxgvgKjUNfodjww,8042
7
7
  tests/test_integrations.py,sha256=f5-QCUk1SU_-qn4mBCZwS3GN3tXEBIIXo4z2EhExbHw,6126
8
8
  tests/test_python.py,sha256=I1RRdCwLdrc3jX06huVxct8HX8ccQOmQgVpuEflRl0U,23560
9
9
  tests/test_solutions.py,sha256=dpxWGKO-aJ3Yff4KR7BQGajX9VyFdGTWEtcbmFC3WwE,3005
10
- ultralytics/__init__.py,sha256=19JcU9M-VZ6RCIz0c3u-8ynzEpeqhYIKNSN9t_kpNuI,753
10
+ ultralytics/__init__.py,sha256=2rg2RMDy6HqtBcSx4b7eBss9eXQ6leZvZ6drzM-8sFI,681
11
11
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
12
12
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
13
- ultralytics/cfg/__init__.py,sha256=N-XONBXwmD3vzoE4icBXznkV8LOLmf6ak6mRdGPucvw,33146
13
+ ultralytics/cfg/__init__.py,sha256=CUg1z4zY3KyR-V4_bghMY8s1xuu-M50gm-v_vpHdXEM,31753
14
14
  ultralytics/cfg/default.yaml,sha256=ul49zgSzTegMmc8CFeu9tXkWNvQhETdZMa9EgDNSnY4,8319
15
15
  ultralytics/cfg/datasets/Argoverse.yaml,sha256=FyeuJT5CHq_9d4hlfAf0kpZlnbUMO0S--UJ1yIqcdKk,3134
16
16
  ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=QVfp_Qp-4rukuicaB4qx86NxSHM8Mrzym8l_fIDo8gw,1195
@@ -85,7 +85,7 @@ ultralytics/cfg/models/v9/yolov9e.yaml,sha256=dhaR47WxuLOrZWDCceS4bQG00sQdrMc8FQ
85
85
  ultralytics/cfg/models/v9/yolov9m.yaml,sha256=l6CmivzNu44sRVmkQXk4-tXflbV1nWnk5MSc8su2vhs,1311
86
86
  ultralytics/cfg/models/v9/yolov9s.yaml,sha256=lPWcu-6ub1kCBD6zIDFwthYZ3RvdJfODWKy3vEQWRjo,1291
87
87
  ultralytics/cfg/models/v9/yolov9t.yaml,sha256=qL__kr6GoefpQWP4jV0jdzwTp46bdFUcqtPRnfDbkY8,1275
88
- ultralytics/cfg/solutions/default.yaml,sha256=CByxINYMyoGzGKdurDk2GhYc8XOa8Z6H7CZx7uZSPpc,1532
88
+ ultralytics/cfg/solutions/default.yaml,sha256=zZ_ksoZ-BcAIL5jjw0jzHgraoe7363oxXqfSsg9yopk,1673
89
89
  ultralytics/cfg/trackers/botsort.yaml,sha256=8B0xNbnG_E-9DCUpap72PWkUgBb1AjuApEn7gHiVngE,916
90
90
  ultralytics/cfg/trackers/bytetrack.yaml,sha256=8vpTZ2x9mhRXJymoJvs1G8kTXo_HxbSwHup2FQALT3A,721
91
91
  ultralytics/data/__init__.py,sha256=VGe-ATG7j35F4A4r8Jmzffjlhve4JAJPgRa5ahKTU18,616
@@ -98,13 +98,8 @@ ultralytics/data/dataset.py,sha256=D556AW0ZEsW3V8c5zJiHM_prc_YfZqymIkDKPw3k9Io,2
98
98
  ultralytics/data/loaders.py,sha256=Fr70Q9p9t7buLW_8R2_lI_nyCMG033gWSxvwy1M-a-U,28449
99
99
  ultralytics/data/split_dota.py,sha256=yOtypHoY5HvIVBKZgFXdfj2tuCLLEBnMwNfAeG94Eik,10680
100
100
  ultralytics/data/utils.py,sha256=u6OZ7InLpI1em5aEPz13ZzS9BcO37dcY9_s2btXGZYQ,31076
101
- ultralytics/data/explorer/__init__.py,sha256=-Y3m1ZedepOQUv_KW82zaGxvU_PSHcuwUTFqG9BhAr4,113
102
- ultralytics/data/explorer/explorer.py,sha256=JWmLHHhp68h2q3vx4poBou5RYoAX3R89yihR50YLDb0,18881
103
- ultralytics/data/explorer/utils.py,sha256=EvvukQiQUTBrsZznmMnyEX2EqTuwZo_Geyc8yfi8NIA,7085
104
- ultralytics/data/explorer/gui/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
105
- ultralytics/data/explorer/gui/dash.py,sha256=6XOZy9NrkPEXREJPbi0EBkGgu78TAdHpdhSB2HuBOAo,10222
106
101
  ultralytics/engine/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
107
- ultralytics/engine/exporter.py,sha256=b3OIHAABVyqqSizJKQxiWPZvKzIvThK9kucN2iYWnwE,57487
102
+ ultralytics/engine/exporter.py,sha256=lKmVaypzVY3R-RkCs1KQNrSpF6W4jKMZqNBKZ0CfzmA,57670
108
103
  ultralytics/engine/model.py,sha256=pvL1uf-wwdWL8Iph7VEAYn1-z7wEHzVug21V_0_gO6M,51456
109
104
  ultralytics/engine/predictor.py,sha256=keTelEeo23Dcbs-XvmRWAPIs4pbCNDtsMBz88WM1eK8,17534
110
105
  ultralytics/engine/results.py,sha256=BxanBI8PhBCfs-9cSy-GS6naScuiD3hdvUAJWPW2mS0,75043
@@ -135,7 +130,7 @@ ultralytics/models/sam/__init__.py,sha256=o4_D6y8YJlOXIK7Lwo9RHnIJJ9xoFNi4zK99QS
135
130
  ultralytics/models/sam/amg.py,sha256=GrmO_8YfIDt_QkPEMF_WFjPZkhwhf7iwx7ig8JgOUnE,8709
136
131
  ultralytics/models/sam/build.py,sha256=np9vP7AETCZA2Wdds-uj2eQKVnpHQaVpRrE2-U2uMTI,12153
137
132
  ultralytics/models/sam/model.py,sha256=2KFUp8SHiqOgwUjkdqdau0oduJwKQxm4N9GHWjdhUFo,7382
138
- ultralytics/models/sam/predict.py,sha256=_spP0uYNFzUnybwBvzZhF3iEMwvAi6bxryRdUwxwweM,38608
133
+ ultralytics/models/sam/predict.py,sha256=LSxys7fuQycrAoqf_EFohk9ftu7cq1F2GY9_fuIl5uE,40384
139
134
  ultralytics/models/sam/modules/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
140
135
  ultralytics/models/sam/modules/blocks.py,sha256=Q-KwhFbdyZhl1tjG_kP2LcQkZbzoNt618i-NRrKNx2Y,45919
141
136
  ultralytics/models/sam/modules/decoders.py,sha256=mODsqnTN_CjE3H0Sh9cd8PfTnHANPjGB1bjqHxfezSg,25830
@@ -152,7 +147,7 @@ ultralytics/models/yolo/__init__.py,sha256=e1cZr9pbSbf3Ya2OvkTjGRwD_E2YZpe610xsk
152
147
  ultralytics/models/yolo/model.py,sha256=E4TuJZZux0L_SG7sC0SDgxrmeBvuZRpxprPrCC26lvs,4233
153
148
  ultralytics/models/yolo/classify/__init__.py,sha256=t-4pUHmgI2gjhc-l3bqNEcEtKD1dO40nD4Vc6Y2xD6o,355
154
149
  ultralytics/models/yolo/classify/predict.py,sha256=0CEJ4B4fXbOMUnJy79gRvG-qdszOzTSLOb1xxkgsKek,2444
155
- ultralytics/models/yolo/classify/train.py,sha256=THXSkQVQVBuw1QxcEVA8MtLHYYdaAEqepObJCXoLcZ8,6358
150
+ ultralytics/models/yolo/classify/train.py,sha256=3aYzLDqX_03xR1xqlTn1TxA4t58cCIGI8RCtWheTrm0,6273
156
151
  ultralytics/models/yolo/classify/val.py,sha256=Tzizhp3ebzPvwJejrE8tb-TuXw4MdkEI9mOANV74eXQ,4909
157
152
  ultralytics/models/yolo/detect/__init__.py,sha256=JR8gZJWn7wMBbh-0j_073nxJVZTMFZVWTOG5Wnvk6w0,229
158
153
  ultralytics/models/yolo/detect/predict.py,sha256=-uZFLutxGYZX47RANcaxC-LFStRbv0nBv_8-ypadQoI,1471
@@ -185,10 +180,10 @@ ultralytics/nn/modules/transformer.py,sha256=tGiK8NmPfswwW1rbF21r5ILUkkZQ6Nk4s8j
185
180
  ultralytics/nn/modules/utils.py,sha256=a88cKl2wz1nMVSEBiajtvaCbDBQIkESWOKTZ_WAJy90,3195
186
181
  ultralytics/solutions/__init__.py,sha256=6RDeXWO1QSaMgCq8YrWXaj2xvPw2sJwJL_a0dgjCvz0,648
187
182
  ultralytics/solutions/ai_gym.py,sha256=lBAkWV8vrEdKAXcBFVbugPeZZ08MOjGYTdnFlG22vKM,3772
188
- ultralytics/solutions/analytics.py,sha256=bGuZes11D7DNiTsHdwu6PJ0QA0vCiqMMAtZ7NyEkshY,11568
189
- ultralytics/solutions/distance_calculation.py,sha256=o_DAHk4JX8n2Vt7E68MX67mREOBZuy5skbXtVZ6iu_4,5228
183
+ ultralytics/solutions/analytics.py,sha256=w5hnnBNSTQ35tJp6DDeWYw2ASjylp3ZmzrTXcdWwDw8,9319
184
+ ultralytics/solutions/distance_calculation.py,sha256=3D5qj9g-XGt_QPEu5IQI2ubTC0n2pmISDrNPl__JK9M,3373
190
185
  ultralytics/solutions/heatmap.py,sha256=2C4s_rVFcOc5oSWxb0pNxNoCawe4lxajpTDNFd4tVL8,3850
191
- ultralytics/solutions/object_counter.py,sha256=1Nsivk-cyGBM1G6eWe11_vdDWTdbJwaUFMJ1A7OK-Qg,5495
186
+ ultralytics/solutions/object_counter.py,sha256=7s3Q--CAFHr_uXzeq6epXgl5YSinc6q-VThPBx1Gj3Y,5485
192
187
  ultralytics/solutions/parking_management.py,sha256=VgYyhoSEo7fnPegIhNUqnFL0jlMEevALx0QQbzJ3vGI,9049
193
188
  ultralytics/solutions/queue_management.py,sha256=5d1RURQiqffAoET8S66gHimK0l3gKNAfuPO5U6_08jc,2716
194
189
  ultralytics/solutions/solutions.py,sha256=qWKGlwlH9858GfAdZkcu_QXbrzjTFStDvg16Eky0oyo,3541
@@ -213,10 +208,10 @@ ultralytics/utils/errors.py,sha256=GqP_Jgj_n0paxn8OMhn3DTCgoNkB2WjUcUaqs-M6SQk,8
213
208
  ultralytics/utils/files.py,sha256=uiXQSVABJRoI5ImnM6ndEBIFbECfksmWNEldBg8GnSo,8224
214
209
  ultralytics/utils/instance.py,sha256=QSms7mPHZ5e8JGuJYLohLWltzI0aBE8dob2rOUK4RtM,16249
215
210
  ultralytics/utils/loss.py,sha256=SW3FVFFp8Ki_LCT8wIdFbm6KmyPcQn3RmKNcvVAhMQI,34174
216
- ultralytics/utils/metrics.py,sha256=UgLGudWp57uXDMlMUJy4gsz6cfVjcq7tYmHeto3TqvM,53927
211
+ ultralytics/utils/metrics.py,sha256=msPaXc244ndc0NPBhnNlHsKkVhdc-TMgFn5NATlZZVI,53918
217
212
  ultralytics/utils/ops.py,sha256=dsXNdyrYx_p6io6zezig9p84dxS7U-10vceHNVu2IL0,32888
218
213
  ultralytics/utils/patches.py,sha256=J-iOwIRbfUs-inBZerhnXby5tUKjYcOIyvhLTS352JE,3270
219
- ultralytics/utils/plotting.py,sha256=aozAEwcbc447ume9bQrEBTU04AzyiZZrnzcTzA2S6j0,61165
214
+ ultralytics/utils/plotting.py,sha256=RYTdMJtWOO5qPowca1a8izfasoIyGxzmfp9VGB_g0xE,61092
220
215
  ultralytics/utils/tal.py,sha256=ECsu95xEqOItmxMDN4YTD3FsUiIsQNWy0pZC3TfvFfk,16877
221
216
  ultralytics/utils/torch_utils.py,sha256=gVN-KSrAzJC1rW3woQd4FsTT693GD8rXiccToL2m4kM,30059
222
217
  ultralytics/utils/triton.py,sha256=gg1finxno_tY2Ge9PMhmu7PI9wvoFZoiicdT4Bhqv3w,3936
@@ -232,9 +227,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=IbGQfEltamUKXJt93uSLQFn8c2rYh3DMTg
232
227
  ultralytics/utils/callbacks/raytune.py,sha256=ODVYzy-CoM4Uge0zjkh3Hnh9nF2M0vhDrSenXnvcizw,705
233
228
  ultralytics/utils/callbacks/tensorboard.py,sha256=bv4fkkesdgmZv_E2MU6wuaMBwEV5iI2G53RHPyD9quw,4170
234
229
  ultralytics/utils/callbacks/wb.py,sha256=upfbF8-LLXueUvulLaMDmKDhKCl_PWbNa_87PQ0L0Rc,6752
235
- ultralytics-8.3.11.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
236
- ultralytics-8.3.11.dist-info/METADATA,sha256=bD92haGae0_AYj7u5Z8EE6OoUn5sAhwau3R0S-z6pVQ,34700
237
- ultralytics-8.3.11.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
238
- ultralytics-8.3.11.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
239
- ultralytics-8.3.11.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
240
- ultralytics-8.3.11.dist-info/RECORD,,
230
+ ultralytics-8.3.13.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
231
+ ultralytics-8.3.13.dist-info/METADATA,sha256=lwU4l_KFBx8TzMo_vwe4pyKOaepvZ3mPvWwm6Y951_A,34660
232
+ ultralytics-8.3.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
233
+ ultralytics-8.3.13.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
234
+ ultralytics-8.3.13.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
235
+ ultralytics-8.3.13.dist-info/RECORD,,
@@ -1,5 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
2
-
3
- from .utils import plot_query_result
4
-
5
- __all__ = ["plot_query_result"]
@@ -1,460 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
2
-
3
- from io import BytesIO
4
- from pathlib import Path
5
- from typing import Any, List, Tuple, Union
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from matplotlib import pyplot as plt
11
- from PIL import Image
12
- from tqdm import tqdm
13
-
14
- from ultralytics.data.augment import Format
15
- from ultralytics.data.dataset import YOLODataset
16
- from ultralytics.data.utils import check_det_dataset
17
- from ultralytics.models.yolo.model import YOLO
18
- from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks
19
-
20
- from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
21
-
22
-
23
- class ExplorerDataset(YOLODataset):
24
- """Extends YOLODataset for advanced data exploration and manipulation in model training workflows."""
25
-
26
- def __init__(self, *args, data: dict = None, **kwargs) -> None:
27
- """Initializes the ExplorerDataset with the provided data arguments, extending the YOLODataset class."""
28
- super().__init__(*args, data=data, **kwargs)
29
-
30
- def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
31
- """Loads 1 image from dataset index 'i' without any resize ops."""
32
- im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
33
- if im is None: # not cached in RAM
34
- if fn.exists(): # load npy
35
- im = np.load(fn)
36
- else: # read image
37
- im = cv2.imread(f) # BGR
38
- if im is None:
39
- raise FileNotFoundError(f"Image Not Found {f}")
40
- h0, w0 = im.shape[:2] # orig hw
41
- return im, (h0, w0), im.shape[:2]
42
-
43
- return self.ims[i], self.im_hw0[i], self.im_hw[i]
44
-
45
- def build_transforms(self, hyp: IterableSimpleNamespace = None):
46
- """Creates transforms for dataset images without resizing."""
47
- return Format(
48
- bbox_format="xyxy",
49
- normalize=False,
50
- return_mask=self.use_segments,
51
- return_keypoint=self.use_keypoints,
52
- batch_idx=True,
53
- mask_ratio=hyp.mask_ratio,
54
- mask_overlap=hyp.overlap_mask,
55
- )
56
-
57
-
58
- class Explorer:
59
- """Utility class for image embedding, table creation, and similarity querying using LanceDB and YOLO models."""
60
-
61
- def __init__(
62
- self,
63
- data: Union[str, Path] = "coco128.yaml",
64
- model: str = "yolov8n.pt",
65
- uri: str = USER_CONFIG_DIR / "explorer",
66
- ) -> None:
67
- """Initializes the Explorer class with dataset path, model, and URI for database connection."""
68
- # Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
69
- checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
70
- import lancedb
71
-
72
- self.connection = lancedb.connect(uri)
73
- self.table_name = f"{Path(data).name.lower()}_{model.lower()}"
74
- self.sim_idx_base_name = (
75
- f"{self.table_name}_sim_idx".lower()
76
- ) # Use this name and append thres and top_k to reuse the table
77
- self.model = YOLO(model)
78
- self.data = data # None
79
- self.choice_set = None
80
-
81
- self.table = None
82
- self.progress = 0
83
-
84
- def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
85
- """
86
- Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
87
- already exists. Pass force=True to overwrite the existing table.
88
-
89
- Args:
90
- force (bool): Whether to overwrite the existing table or not. Defaults to False.
91
- split (str): Split of the dataset to use. Defaults to 'train'.
92
-
93
- Example:
94
- ```python
95
- exp = Explorer()
96
- exp.create_embeddings_table()
97
- ```
98
- """
99
- if self.table is not None and not force:
100
- LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
101
- return
102
- if self.table_name in self.connection.table_names() and not force:
103
- LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
104
- self.table = self.connection.open_table(self.table_name)
105
- self.progress = 1
106
- return
107
- if self.data is None:
108
- raise ValueError("Data must be provided to create embeddings table")
109
-
110
- data_info = check_det_dataset(self.data)
111
- if split not in data_info:
112
- raise ValueError(
113
- f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
114
- )
115
-
116
- choice_set = data_info[split]
117
- choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
118
- self.choice_set = choice_set
119
- dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
120
-
121
- # Create the table schema
122
- batch = dataset[0]
123
- vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
124
- table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
125
- table.add(
126
- self._yield_batches(
127
- dataset,
128
- data_info,
129
- self.model,
130
- exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
131
- )
132
- )
133
-
134
- self.table = table
135
-
136
- def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
137
- """Generates batches of data for embedding, excluding specified keys."""
138
- for i in tqdm(range(len(dataset))):
139
- self.progress = float(i + 1) / len(dataset)
140
- batch = dataset[i]
141
- for k in exclude_keys:
142
- batch.pop(k, None)
143
- batch = sanitize_batch(batch, data_info)
144
- batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
145
- yield [batch]
146
-
147
- def query(
148
- self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
149
- ) -> Any: # pyarrow.Table
150
- """
151
- Query the table for similar images. Accepts a single image or a list of images.
152
-
153
- Args:
154
- imgs (str or list): Path to the image or a list of paths to the images.
155
- limit (int): Number of results to return.
156
-
157
- Returns:
158
- (pyarrow.Table): An arrow table containing the results. Supports converting to:
159
- - pandas dataframe: `result.to_pandas()`
160
- - dict of lists: `result.to_pydict()`
161
-
162
- Example:
163
- ```python
164
- exp = Explorer()
165
- exp.create_embeddings_table()
166
- similar = exp.query(img="https://ultralytics.com/images/zidane.jpg")
167
- ```
168
- """
169
- if self.table is None:
170
- raise ValueError("Table is not created. Please create the table first.")
171
- if isinstance(imgs, str):
172
- imgs = [imgs]
173
- assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
174
- embeds = self.model.embed(imgs)
175
- # Get avg if multiple images are passed (len > 1)
176
- embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
177
- return self.table.search(embeds).limit(limit).to_arrow()
178
-
179
- def sql_query(
180
- self, query: str, return_type: str = "pandas"
181
- ) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table
182
- """
183
- Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
184
-
185
- Args:
186
- query (str): SQL query to run.
187
- return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
188
-
189
- Returns:
190
- (pyarrow.Table): An arrow table containing the results.
191
-
192
- Example:
193
- ```python
194
- exp = Explorer()
195
- exp.create_embeddings_table()
196
- query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
197
- result = exp.sql_query(query)
198
- ```
199
- """
200
- assert return_type in {
201
- "pandas",
202
- "arrow",
203
- }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
204
- import duckdb
205
-
206
- if self.table is None:
207
- raise ValueError("Table is not created. Please create the table first.")
208
-
209
- # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
210
- table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
211
- if not query.startswith("SELECT") and not query.startswith("WHERE"):
212
- raise ValueError(
213
- f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
214
- f"clause. found {query}"
215
- )
216
- if query.startswith("WHERE"):
217
- query = f"SELECT * FROM 'table' {query}"
218
- LOGGER.info(f"Running query: {query}")
219
-
220
- rs = duckdb.sql(query)
221
- if return_type == "arrow":
222
- return rs.arrow()
223
- elif return_type == "pandas":
224
- return rs.df()
225
-
226
- def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
227
- """
228
- Plot the results of a SQL-Like query on the table.
229
-
230
- Args:
231
- query (str): SQL query to run.
232
- labels (bool): Whether to plot the labels or not.
233
-
234
- Returns:
235
- (PIL.Image): Image containing the plot.
236
-
237
- Example:
238
- ```python
239
- exp = Explorer()
240
- exp.create_embeddings_table()
241
- query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
242
- result = exp.plot_sql_query(query)
243
- ```
244
- """
245
- result = self.sql_query(query, return_type="arrow")
246
- if len(result) == 0:
247
- LOGGER.info("No results found.")
248
- return None
249
- img = plot_query_result(result, plot_labels=labels)
250
- return Image.fromarray(img)
251
-
252
- def get_similar(
253
- self,
254
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
255
- idx: Union[int, List[int]] = None,
256
- limit: int = 25,
257
- return_type: str = "pandas",
258
- ) -> Any: # pandas.DataFrame or pyarrow.Table
259
- """
260
- Query the table for similar images. Accepts a single image or a list of images.
261
-
262
- Args:
263
- img (str or list): Path to the image or a list of paths to the images.
264
- idx (int or list): Index of the image in the table or a list of indexes.
265
- limit (int): Number of results to return. Defaults to 25.
266
- return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
267
-
268
- Returns:
269
- (pandas.DataFrame): A dataframe containing the results.
270
-
271
- Example:
272
- ```python
273
- exp = Explorer()
274
- exp.create_embeddings_table()
275
- similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg")
276
- ```
277
- """
278
- assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
279
- img = self._check_imgs_or_idxs(img, idx)
280
- similar = self.query(img, limit=limit)
281
-
282
- if return_type == "arrow":
283
- return similar
284
- elif return_type == "pandas":
285
- return similar.to_pandas()
286
-
287
- def plot_similar(
288
- self,
289
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
290
- idx: Union[int, List[int]] = None,
291
- limit: int = 25,
292
- labels: bool = True,
293
- ) -> Image.Image:
294
- """
295
- Plot the similar images. Accepts images or indexes.
296
-
297
- Args:
298
- img (str or list): Path to the image or a list of paths to the images.
299
- idx (int or list): Index of the image in the table or a list of indexes.
300
- labels (bool): Whether to plot the labels or not.
301
- limit (int): Number of results to return. Defaults to 25.
302
-
303
- Returns:
304
- (PIL.Image): Image containing the plot.
305
-
306
- Example:
307
- ```python
308
- exp = Explorer()
309
- exp.create_embeddings_table()
310
- similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg")
311
- ```
312
- """
313
- similar = self.get_similar(img, idx, limit, return_type="arrow")
314
- if len(similar) == 0:
315
- LOGGER.info("No results found.")
316
- return None
317
- img = plot_query_result(similar, plot_labels=labels)
318
- return Image.fromarray(img)
319
-
320
- def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame
321
- """
322
- Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
323
- are max_dist or closer to the image in the embedding space at a given index.
324
-
325
- Args:
326
- max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
327
- top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit.
328
- vector search. Defaults: None.
329
- force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
330
-
331
- Returns:
332
- (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
333
- and columns include indices of similar images and their respective distances.
334
-
335
- Example:
336
- ```python
337
- exp = Explorer()
338
- exp.create_embeddings_table()
339
- sim_idx = exp.similarity_index()
340
- ```
341
- """
342
- if self.table is None:
343
- raise ValueError("Table is not created. Please create the table first.")
344
- sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
345
- if sim_idx_table_name in self.connection.table_names() and not force:
346
- LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
347
- return self.connection.open_table(sim_idx_table_name).to_pandas()
348
-
349
- if top_k and not (1.0 >= top_k >= 0.0):
350
- raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
351
- if max_dist < 0.0:
352
- raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
353
-
354
- top_k = int(top_k * len(self.table)) if top_k else len(self.table)
355
- top_k = max(top_k, 1)
356
- features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
357
- im_files = features["im_file"]
358
- embeddings = features["vector"]
359
-
360
- sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
361
-
362
- def _yield_sim_idx():
363
- """Generates a dataframe with similarity indices and distances for images."""
364
- for i in tqdm(range(len(embeddings))):
365
- sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
366
- yield [
367
- {
368
- "idx": i,
369
- "im_file": im_files[i],
370
- "count": len(sim_idx),
371
- "sim_im_files": sim_idx["im_file"].tolist(),
372
- }
373
- ]
374
-
375
- sim_table.add(_yield_sim_idx())
376
- self.sim_index = sim_table
377
- return sim_table.to_pandas()
378
-
379
- def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
380
- """
381
- Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
382
- max_dist or closer to the image in the embedding space at a given index.
383
-
384
- Args:
385
- max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
386
- top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
387
- running vector search. Defaults to 0.01.
388
- force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
389
-
390
- Returns:
391
- (PIL.Image): Image containing the plot.
392
-
393
- Example:
394
- ```python
395
- exp = Explorer()
396
- exp.create_embeddings_table()
397
-
398
- similarity_idx_plot = exp.plot_similarity_index()
399
- similarity_idx_plot.show() # view image preview
400
- similarity_idx_plot.save("path/to/save/similarity_index_plot.png") # save contents to file
401
- ```
402
- """
403
- sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
404
- sim_count = sim_idx["count"].tolist()
405
- sim_count = np.array(sim_count)
406
-
407
- indices = np.arange(len(sim_count))
408
-
409
- # Create the bar plot
410
- plt.bar(indices, sim_count)
411
-
412
- # Customize the plot (optional)
413
- plt.xlabel("data idx")
414
- plt.ylabel("Count")
415
- plt.title("Similarity Count")
416
- buffer = BytesIO()
417
- plt.savefig(buffer, format="png")
418
- buffer.seek(0)
419
-
420
- # Use Pillow to open the image from the buffer
421
- return Image.fromarray(np.array(Image.open(buffer)))
422
-
423
- def _check_imgs_or_idxs(
424
- self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
425
- ) -> List[np.ndarray]:
426
- """Determines whether to fetch images or indexes based on provided arguments and returns image paths."""
427
- if img is None and idx is None:
428
- raise ValueError("Either img or idx must be provided.")
429
- if img is not None and idx is not None:
430
- raise ValueError("Only one of img or idx must be provided.")
431
- if idx is not None:
432
- idx = idx if isinstance(idx, list) else [idx]
433
- img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
434
-
435
- return img if isinstance(img, list) else [img]
436
-
437
- def ask_ai(self, query):
438
- """
439
- Ask AI a question.
440
-
441
- Args:
442
- query (str): Question to ask.
443
-
444
- Returns:
445
- (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
446
-
447
- Example:
448
- ```python
449
- exp = Explorer()
450
- exp.create_embeddings_table()
451
- answer = exp.ask_ai("Show images with 1 person and 2 dogs")
452
- ```
453
- """
454
- result = prompt_sql_query(query)
455
- try:
456
- return self.sql_query(result)
457
- except Exception as e:
458
- LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
459
- LOGGER.error(e)
460
- return None
@@ -1 +0,0 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license