ultralytics 8.3.11__py3-none-any.whl → 8.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- tests/test_cli.py +4 -1
- tests/test_cuda.py +13 -1
- ultralytics/__init__.py +1 -3
- ultralytics/cfg/__init__.py +2 -35
- ultralytics/cfg/solutions/default.yaml +1 -0
- ultralytics/engine/exporter.py +9 -1
- ultralytics/models/sam/predict.py +79 -50
- ultralytics/models/yolo/classify/train.py +1 -2
- ultralytics/solutions/analytics.py +151 -264
- ultralytics/solutions/distance_calculation.py +15 -72
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/utils/metrics.py +1 -1
- ultralytics/utils/plotting.py +14 -15
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/METADATA +4 -5
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/RECORD +19 -24
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -460
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -269
- ultralytics/data/explorer/utils.py +0 -167
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.11.dist-info → ultralytics-8.3.13.dist-info}/top_level.txt +0 -0
ultralytics/utils/plotting.py
CHANGED
|
@@ -804,31 +804,30 @@ class Annotator:
|
|
|
804
804
|
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
|
|
805
805
|
)
|
|
806
806
|
|
|
807
|
-
def plot_distance_and_line(
|
|
807
|
+
def plot_distance_and_line(
|
|
808
|
+
self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
|
|
809
|
+
):
|
|
808
810
|
"""
|
|
809
811
|
Plot the distance and line on frame.
|
|
810
812
|
|
|
811
813
|
Args:
|
|
812
814
|
pixels_distance (float): Pixels distance between two bbox centroids.
|
|
813
815
|
centroids (list): Bounding box centroids data.
|
|
814
|
-
line_color (tuple):
|
|
815
|
-
centroid_color (tuple):
|
|
816
|
+
line_color (tuple, optional): Distance line color.
|
|
817
|
+
centroid_color (tuple, optional): Bounding box centroid color.
|
|
816
818
|
"""
|
|
817
819
|
# Get the text size
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
)
|
|
820
|
+
text = f"Pixels Distance: {pixels_distance:.2f}"
|
|
821
|
+
(text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
|
|
821
822
|
|
|
822
823
|
# Define corners with 10-pixel margin and draw rectangle
|
|
823
|
-
|
|
824
|
-
bottom_right = (15 + text_width_m + 20, 25 + text_height_m + 20)
|
|
825
|
-
cv2.rectangle(self.im, top_left, bottom_right, centroid_color, -1)
|
|
824
|
+
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
|
|
826
825
|
|
|
827
826
|
# Calculate the position for the text with a 10-pixel margin and draw text
|
|
828
|
-
text_position = (
|
|
827
|
+
text_position = (25, 25 + text_height_m + 10)
|
|
829
828
|
cv2.putText(
|
|
830
829
|
self.im,
|
|
831
|
-
|
|
830
|
+
text,
|
|
832
831
|
text_position,
|
|
833
832
|
0,
|
|
834
833
|
self.sf,
|
|
@@ -1156,16 +1155,16 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
|
1156
1155
|
save_dir = Path(file).parent if file else Path(dir)
|
|
1157
1156
|
if classify:
|
|
1158
1157
|
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
|
|
1159
|
-
index = [
|
|
1158
|
+
index = [2, 5, 3, 4]
|
|
1160
1159
|
elif segment:
|
|
1161
1160
|
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
|
|
1162
|
-
index = [
|
|
1161
|
+
index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
|
|
1163
1162
|
elif pose:
|
|
1164
1163
|
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
|
|
1165
|
-
index = [
|
|
1164
|
+
index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
|
|
1166
1165
|
else:
|
|
1167
1166
|
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
|
1168
|
-
index = [
|
|
1167
|
+
index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
|
|
1169
1168
|
ax = ax.ravel()
|
|
1170
1169
|
files = list(save_dir.glob("results*.csv"))
|
|
1171
1170
|
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ultralytics
|
|
3
|
-
Version: 8.3.
|
|
3
|
+
Version: 8.3.13
|
|
4
4
|
Summary: Ultralytics YOLO for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
|
|
5
5
|
Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
|
|
6
6
|
Maintainer-email: Ultralytics <hello@ultralytics.com>
|
|
@@ -60,10 +60,6 @@ Requires-Dist: mkdocs-jupyter; extra == "dev"
|
|
|
60
60
|
Requires-Dist: mkdocs-redirects; extra == "dev"
|
|
61
61
|
Requires-Dist: mkdocs-ultralytics-plugin>=0.1.8; extra == "dev"
|
|
62
62
|
Requires-Dist: mkdocs-macros-plugin>=1.0.5; extra == "dev"
|
|
63
|
-
Provides-Extra: explorer
|
|
64
|
-
Requires-Dist: lancedb; extra == "explorer"
|
|
65
|
-
Requires-Dist: duckdb<=0.9.2; extra == "explorer"
|
|
66
|
-
Requires-Dist: streamlit; extra == "explorer"
|
|
67
63
|
Provides-Extra: export
|
|
68
64
|
Requires-Dist: onnx>=1.12.0; extra == "export"
|
|
69
65
|
Requires-Dist: openvino>=2024.0.0; extra == "export"
|
|
@@ -85,6 +81,9 @@ Provides-Extra: logging
|
|
|
85
81
|
Requires-Dist: comet; extra == "logging"
|
|
86
82
|
Requires-Dist: tensorboard>=2.13.0; extra == "logging"
|
|
87
83
|
Requires-Dist: dvclive>=2.12.0; extra == "logging"
|
|
84
|
+
Provides-Extra: solutions
|
|
85
|
+
Requires-Dist: shapely>=2.0.0; extra == "solutions"
|
|
86
|
+
Requires-Dist: streamlit; extra == "solutions"
|
|
88
87
|
|
|
89
88
|
<div align="center">
|
|
90
89
|
<p>
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
tests/__init__.py,sha256=iVH5nXrACTDv0_ZIVRPi-9f6oYBl6g-tCkeR2Hb8MFM,666
|
|
2
2
|
tests/conftest.py,sha256=9PFAiwAy6eeORGspr5dOKxVuFDVKqYg8Nn_RxSJ27UI,2919
|
|
3
|
-
tests/test_cli.py,sha256=
|
|
4
|
-
tests/test_cuda.py,sha256=
|
|
3
|
+
tests/test_cli.py,sha256=G7OJ1ErQYsGy2Dx1zP-0p7EZR4aPoAdtLGiY4Hm7jQM,5006
|
|
4
|
+
tests/test_cuda.py,sha256=cxohrqAjjPCgTaChYhQ4CObQzz7ZyhVGPTxkcxHSf0g,5940
|
|
5
5
|
tests/test_engine.py,sha256=dcEcJsMQh61rDSNv7l4TIAgybLpzjVwerv9JZC_KCM8,4934
|
|
6
6
|
tests/test_exports.py,sha256=fpTKEVBUGLF3WiZPNKRs-IEcIY4cfxgvgKjUNfodjww,8042
|
|
7
7
|
tests/test_integrations.py,sha256=f5-QCUk1SU_-qn4mBCZwS3GN3tXEBIIXo4z2EhExbHw,6126
|
|
8
8
|
tests/test_python.py,sha256=I1RRdCwLdrc3jX06huVxct8HX8ccQOmQgVpuEflRl0U,23560
|
|
9
9
|
tests/test_solutions.py,sha256=dpxWGKO-aJ3Yff4KR7BQGajX9VyFdGTWEtcbmFC3WwE,3005
|
|
10
|
-
ultralytics/__init__.py,sha256=
|
|
10
|
+
ultralytics/__init__.py,sha256=2rg2RMDy6HqtBcSx4b7eBss9eXQ6leZvZ6drzM-8sFI,681
|
|
11
11
|
ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
|
|
12
12
|
ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
|
|
13
|
-
ultralytics/cfg/__init__.py,sha256=
|
|
13
|
+
ultralytics/cfg/__init__.py,sha256=CUg1z4zY3KyR-V4_bghMY8s1xuu-M50gm-v_vpHdXEM,31753
|
|
14
14
|
ultralytics/cfg/default.yaml,sha256=ul49zgSzTegMmc8CFeu9tXkWNvQhETdZMa9EgDNSnY4,8319
|
|
15
15
|
ultralytics/cfg/datasets/Argoverse.yaml,sha256=FyeuJT5CHq_9d4hlfAf0kpZlnbUMO0S--UJ1yIqcdKk,3134
|
|
16
16
|
ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=QVfp_Qp-4rukuicaB4qx86NxSHM8Mrzym8l_fIDo8gw,1195
|
|
@@ -85,7 +85,7 @@ ultralytics/cfg/models/v9/yolov9e.yaml,sha256=dhaR47WxuLOrZWDCceS4bQG00sQdrMc8FQ
|
|
|
85
85
|
ultralytics/cfg/models/v9/yolov9m.yaml,sha256=l6CmivzNu44sRVmkQXk4-tXflbV1nWnk5MSc8su2vhs,1311
|
|
86
86
|
ultralytics/cfg/models/v9/yolov9s.yaml,sha256=lPWcu-6ub1kCBD6zIDFwthYZ3RvdJfODWKy3vEQWRjo,1291
|
|
87
87
|
ultralytics/cfg/models/v9/yolov9t.yaml,sha256=qL__kr6GoefpQWP4jV0jdzwTp46bdFUcqtPRnfDbkY8,1275
|
|
88
|
-
ultralytics/cfg/solutions/default.yaml,sha256=
|
|
88
|
+
ultralytics/cfg/solutions/default.yaml,sha256=zZ_ksoZ-BcAIL5jjw0jzHgraoe7363oxXqfSsg9yopk,1673
|
|
89
89
|
ultralytics/cfg/trackers/botsort.yaml,sha256=8B0xNbnG_E-9DCUpap72PWkUgBb1AjuApEn7gHiVngE,916
|
|
90
90
|
ultralytics/cfg/trackers/bytetrack.yaml,sha256=8vpTZ2x9mhRXJymoJvs1G8kTXo_HxbSwHup2FQALT3A,721
|
|
91
91
|
ultralytics/data/__init__.py,sha256=VGe-ATG7j35F4A4r8Jmzffjlhve4JAJPgRa5ahKTU18,616
|
|
@@ -98,13 +98,8 @@ ultralytics/data/dataset.py,sha256=D556AW0ZEsW3V8c5zJiHM_prc_YfZqymIkDKPw3k9Io,2
|
|
|
98
98
|
ultralytics/data/loaders.py,sha256=Fr70Q9p9t7buLW_8R2_lI_nyCMG033gWSxvwy1M-a-U,28449
|
|
99
99
|
ultralytics/data/split_dota.py,sha256=yOtypHoY5HvIVBKZgFXdfj2tuCLLEBnMwNfAeG94Eik,10680
|
|
100
100
|
ultralytics/data/utils.py,sha256=u6OZ7InLpI1em5aEPz13ZzS9BcO37dcY9_s2btXGZYQ,31076
|
|
101
|
-
ultralytics/data/explorer/__init__.py,sha256=-Y3m1ZedepOQUv_KW82zaGxvU_PSHcuwUTFqG9BhAr4,113
|
|
102
|
-
ultralytics/data/explorer/explorer.py,sha256=JWmLHHhp68h2q3vx4poBou5RYoAX3R89yihR50YLDb0,18881
|
|
103
|
-
ultralytics/data/explorer/utils.py,sha256=EvvukQiQUTBrsZznmMnyEX2EqTuwZo_Geyc8yfi8NIA,7085
|
|
104
|
-
ultralytics/data/explorer/gui/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
|
|
105
|
-
ultralytics/data/explorer/gui/dash.py,sha256=6XOZy9NrkPEXREJPbi0EBkGgu78TAdHpdhSB2HuBOAo,10222
|
|
106
101
|
ultralytics/engine/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
|
|
107
|
-
ultralytics/engine/exporter.py,sha256=
|
|
102
|
+
ultralytics/engine/exporter.py,sha256=lKmVaypzVY3R-RkCs1KQNrSpF6W4jKMZqNBKZ0CfzmA,57670
|
|
108
103
|
ultralytics/engine/model.py,sha256=pvL1uf-wwdWL8Iph7VEAYn1-z7wEHzVug21V_0_gO6M,51456
|
|
109
104
|
ultralytics/engine/predictor.py,sha256=keTelEeo23Dcbs-XvmRWAPIs4pbCNDtsMBz88WM1eK8,17534
|
|
110
105
|
ultralytics/engine/results.py,sha256=BxanBI8PhBCfs-9cSy-GS6naScuiD3hdvUAJWPW2mS0,75043
|
|
@@ -135,7 +130,7 @@ ultralytics/models/sam/__init__.py,sha256=o4_D6y8YJlOXIK7Lwo9RHnIJJ9xoFNi4zK99QS
|
|
|
135
130
|
ultralytics/models/sam/amg.py,sha256=GrmO_8YfIDt_QkPEMF_WFjPZkhwhf7iwx7ig8JgOUnE,8709
|
|
136
131
|
ultralytics/models/sam/build.py,sha256=np9vP7AETCZA2Wdds-uj2eQKVnpHQaVpRrE2-U2uMTI,12153
|
|
137
132
|
ultralytics/models/sam/model.py,sha256=2KFUp8SHiqOgwUjkdqdau0oduJwKQxm4N9GHWjdhUFo,7382
|
|
138
|
-
ultralytics/models/sam/predict.py,sha256=
|
|
133
|
+
ultralytics/models/sam/predict.py,sha256=LSxys7fuQycrAoqf_EFohk9ftu7cq1F2GY9_fuIl5uE,40384
|
|
139
134
|
ultralytics/models/sam/modules/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
|
|
140
135
|
ultralytics/models/sam/modules/blocks.py,sha256=Q-KwhFbdyZhl1tjG_kP2LcQkZbzoNt618i-NRrKNx2Y,45919
|
|
141
136
|
ultralytics/models/sam/modules/decoders.py,sha256=mODsqnTN_CjE3H0Sh9cd8PfTnHANPjGB1bjqHxfezSg,25830
|
|
@@ -152,7 +147,7 @@ ultralytics/models/yolo/__init__.py,sha256=e1cZr9pbSbf3Ya2OvkTjGRwD_E2YZpe610xsk
|
|
|
152
147
|
ultralytics/models/yolo/model.py,sha256=E4TuJZZux0L_SG7sC0SDgxrmeBvuZRpxprPrCC26lvs,4233
|
|
153
148
|
ultralytics/models/yolo/classify/__init__.py,sha256=t-4pUHmgI2gjhc-l3bqNEcEtKD1dO40nD4Vc6Y2xD6o,355
|
|
154
149
|
ultralytics/models/yolo/classify/predict.py,sha256=0CEJ4B4fXbOMUnJy79gRvG-qdszOzTSLOb1xxkgsKek,2444
|
|
155
|
-
ultralytics/models/yolo/classify/train.py,sha256=
|
|
150
|
+
ultralytics/models/yolo/classify/train.py,sha256=3aYzLDqX_03xR1xqlTn1TxA4t58cCIGI8RCtWheTrm0,6273
|
|
156
151
|
ultralytics/models/yolo/classify/val.py,sha256=Tzizhp3ebzPvwJejrE8tb-TuXw4MdkEI9mOANV74eXQ,4909
|
|
157
152
|
ultralytics/models/yolo/detect/__init__.py,sha256=JR8gZJWn7wMBbh-0j_073nxJVZTMFZVWTOG5Wnvk6w0,229
|
|
158
153
|
ultralytics/models/yolo/detect/predict.py,sha256=-uZFLutxGYZX47RANcaxC-LFStRbv0nBv_8-ypadQoI,1471
|
|
@@ -185,10 +180,10 @@ ultralytics/nn/modules/transformer.py,sha256=tGiK8NmPfswwW1rbF21r5ILUkkZQ6Nk4s8j
|
|
|
185
180
|
ultralytics/nn/modules/utils.py,sha256=a88cKl2wz1nMVSEBiajtvaCbDBQIkESWOKTZ_WAJy90,3195
|
|
186
181
|
ultralytics/solutions/__init__.py,sha256=6RDeXWO1QSaMgCq8YrWXaj2xvPw2sJwJL_a0dgjCvz0,648
|
|
187
182
|
ultralytics/solutions/ai_gym.py,sha256=lBAkWV8vrEdKAXcBFVbugPeZZ08MOjGYTdnFlG22vKM,3772
|
|
188
|
-
ultralytics/solutions/analytics.py,sha256=
|
|
189
|
-
ultralytics/solutions/distance_calculation.py,sha256=
|
|
183
|
+
ultralytics/solutions/analytics.py,sha256=w5hnnBNSTQ35tJp6DDeWYw2ASjylp3ZmzrTXcdWwDw8,9319
|
|
184
|
+
ultralytics/solutions/distance_calculation.py,sha256=3D5qj9g-XGt_QPEu5IQI2ubTC0n2pmISDrNPl__JK9M,3373
|
|
190
185
|
ultralytics/solutions/heatmap.py,sha256=2C4s_rVFcOc5oSWxb0pNxNoCawe4lxajpTDNFd4tVL8,3850
|
|
191
|
-
ultralytics/solutions/object_counter.py,sha256=
|
|
186
|
+
ultralytics/solutions/object_counter.py,sha256=7s3Q--CAFHr_uXzeq6epXgl5YSinc6q-VThPBx1Gj3Y,5485
|
|
192
187
|
ultralytics/solutions/parking_management.py,sha256=VgYyhoSEo7fnPegIhNUqnFL0jlMEevALx0QQbzJ3vGI,9049
|
|
193
188
|
ultralytics/solutions/queue_management.py,sha256=5d1RURQiqffAoET8S66gHimK0l3gKNAfuPO5U6_08jc,2716
|
|
194
189
|
ultralytics/solutions/solutions.py,sha256=qWKGlwlH9858GfAdZkcu_QXbrzjTFStDvg16Eky0oyo,3541
|
|
@@ -213,10 +208,10 @@ ultralytics/utils/errors.py,sha256=GqP_Jgj_n0paxn8OMhn3DTCgoNkB2WjUcUaqs-M6SQk,8
|
|
|
213
208
|
ultralytics/utils/files.py,sha256=uiXQSVABJRoI5ImnM6ndEBIFbECfksmWNEldBg8GnSo,8224
|
|
214
209
|
ultralytics/utils/instance.py,sha256=QSms7mPHZ5e8JGuJYLohLWltzI0aBE8dob2rOUK4RtM,16249
|
|
215
210
|
ultralytics/utils/loss.py,sha256=SW3FVFFp8Ki_LCT8wIdFbm6KmyPcQn3RmKNcvVAhMQI,34174
|
|
216
|
-
ultralytics/utils/metrics.py,sha256=
|
|
211
|
+
ultralytics/utils/metrics.py,sha256=msPaXc244ndc0NPBhnNlHsKkVhdc-TMgFn5NATlZZVI,53918
|
|
217
212
|
ultralytics/utils/ops.py,sha256=dsXNdyrYx_p6io6zezig9p84dxS7U-10vceHNVu2IL0,32888
|
|
218
213
|
ultralytics/utils/patches.py,sha256=J-iOwIRbfUs-inBZerhnXby5tUKjYcOIyvhLTS352JE,3270
|
|
219
|
-
ultralytics/utils/plotting.py,sha256=
|
|
214
|
+
ultralytics/utils/plotting.py,sha256=RYTdMJtWOO5qPowca1a8izfasoIyGxzmfp9VGB_g0xE,61092
|
|
220
215
|
ultralytics/utils/tal.py,sha256=ECsu95xEqOItmxMDN4YTD3FsUiIsQNWy0pZC3TfvFfk,16877
|
|
221
216
|
ultralytics/utils/torch_utils.py,sha256=gVN-KSrAzJC1rW3woQd4FsTT693GD8rXiccToL2m4kM,30059
|
|
222
217
|
ultralytics/utils/triton.py,sha256=gg1finxno_tY2Ge9PMhmu7PI9wvoFZoiicdT4Bhqv3w,3936
|
|
@@ -232,9 +227,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=IbGQfEltamUKXJt93uSLQFn8c2rYh3DMTg
|
|
|
232
227
|
ultralytics/utils/callbacks/raytune.py,sha256=ODVYzy-CoM4Uge0zjkh3Hnh9nF2M0vhDrSenXnvcizw,705
|
|
233
228
|
ultralytics/utils/callbacks/tensorboard.py,sha256=bv4fkkesdgmZv_E2MU6wuaMBwEV5iI2G53RHPyD9quw,4170
|
|
234
229
|
ultralytics/utils/callbacks/wb.py,sha256=upfbF8-LLXueUvulLaMDmKDhKCl_PWbNa_87PQ0L0Rc,6752
|
|
235
|
-
ultralytics-8.3.
|
|
236
|
-
ultralytics-8.3.
|
|
237
|
-
ultralytics-8.3.
|
|
238
|
-
ultralytics-8.3.
|
|
239
|
-
ultralytics-8.3.
|
|
240
|
-
ultralytics-8.3.
|
|
230
|
+
ultralytics-8.3.13.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
|
231
|
+
ultralytics-8.3.13.dist-info/METADATA,sha256=lwU4l_KFBx8TzMo_vwe4pyKOaepvZ3mPvWwm6Y951_A,34660
|
|
232
|
+
ultralytics-8.3.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
233
|
+
ultralytics-8.3.13.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
|
|
234
|
+
ultralytics-8.3.13.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
|
|
235
|
+
ultralytics-8.3.13.dist-info/RECORD,,
|
|
@@ -1,460 +0,0 @@
|
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
-
|
|
3
|
-
from io import BytesIO
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import Any, List, Tuple, Union
|
|
6
|
-
|
|
7
|
-
import cv2
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
from matplotlib import pyplot as plt
|
|
11
|
-
from PIL import Image
|
|
12
|
-
from tqdm import tqdm
|
|
13
|
-
|
|
14
|
-
from ultralytics.data.augment import Format
|
|
15
|
-
from ultralytics.data.dataset import YOLODataset
|
|
16
|
-
from ultralytics.data.utils import check_det_dataset
|
|
17
|
-
from ultralytics.models.yolo.model import YOLO
|
|
18
|
-
from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks
|
|
19
|
-
|
|
20
|
-
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class ExplorerDataset(YOLODataset):
|
|
24
|
-
"""Extends YOLODataset for advanced data exploration and manipulation in model training workflows."""
|
|
25
|
-
|
|
26
|
-
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
|
27
|
-
"""Initializes the ExplorerDataset with the provided data arguments, extending the YOLODataset class."""
|
|
28
|
-
super().__init__(*args, data=data, **kwargs)
|
|
29
|
-
|
|
30
|
-
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
|
|
31
|
-
"""Loads 1 image from dataset index 'i' without any resize ops."""
|
|
32
|
-
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
|
33
|
-
if im is None: # not cached in RAM
|
|
34
|
-
if fn.exists(): # load npy
|
|
35
|
-
im = np.load(fn)
|
|
36
|
-
else: # read image
|
|
37
|
-
im = cv2.imread(f) # BGR
|
|
38
|
-
if im is None:
|
|
39
|
-
raise FileNotFoundError(f"Image Not Found {f}")
|
|
40
|
-
h0, w0 = im.shape[:2] # orig hw
|
|
41
|
-
return im, (h0, w0), im.shape[:2]
|
|
42
|
-
|
|
43
|
-
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
|
44
|
-
|
|
45
|
-
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
|
46
|
-
"""Creates transforms for dataset images without resizing."""
|
|
47
|
-
return Format(
|
|
48
|
-
bbox_format="xyxy",
|
|
49
|
-
normalize=False,
|
|
50
|
-
return_mask=self.use_segments,
|
|
51
|
-
return_keypoint=self.use_keypoints,
|
|
52
|
-
batch_idx=True,
|
|
53
|
-
mask_ratio=hyp.mask_ratio,
|
|
54
|
-
mask_overlap=hyp.overlap_mask,
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class Explorer:
|
|
59
|
-
"""Utility class for image embedding, table creation, and similarity querying using LanceDB and YOLO models."""
|
|
60
|
-
|
|
61
|
-
def __init__(
|
|
62
|
-
self,
|
|
63
|
-
data: Union[str, Path] = "coco128.yaml",
|
|
64
|
-
model: str = "yolov8n.pt",
|
|
65
|
-
uri: str = USER_CONFIG_DIR / "explorer",
|
|
66
|
-
) -> None:
|
|
67
|
-
"""Initializes the Explorer class with dataset path, model, and URI for database connection."""
|
|
68
|
-
# Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
|
|
69
|
-
checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
|
|
70
|
-
import lancedb
|
|
71
|
-
|
|
72
|
-
self.connection = lancedb.connect(uri)
|
|
73
|
-
self.table_name = f"{Path(data).name.lower()}_{model.lower()}"
|
|
74
|
-
self.sim_idx_base_name = (
|
|
75
|
-
f"{self.table_name}_sim_idx".lower()
|
|
76
|
-
) # Use this name and append thres and top_k to reuse the table
|
|
77
|
-
self.model = YOLO(model)
|
|
78
|
-
self.data = data # None
|
|
79
|
-
self.choice_set = None
|
|
80
|
-
|
|
81
|
-
self.table = None
|
|
82
|
-
self.progress = 0
|
|
83
|
-
|
|
84
|
-
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
|
85
|
-
"""
|
|
86
|
-
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
|
87
|
-
already exists. Pass force=True to overwrite the existing table.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
force (bool): Whether to overwrite the existing table or not. Defaults to False.
|
|
91
|
-
split (str): Split of the dataset to use. Defaults to 'train'.
|
|
92
|
-
|
|
93
|
-
Example:
|
|
94
|
-
```python
|
|
95
|
-
exp = Explorer()
|
|
96
|
-
exp.create_embeddings_table()
|
|
97
|
-
```
|
|
98
|
-
"""
|
|
99
|
-
if self.table is not None and not force:
|
|
100
|
-
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
|
101
|
-
return
|
|
102
|
-
if self.table_name in self.connection.table_names() and not force:
|
|
103
|
-
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
|
104
|
-
self.table = self.connection.open_table(self.table_name)
|
|
105
|
-
self.progress = 1
|
|
106
|
-
return
|
|
107
|
-
if self.data is None:
|
|
108
|
-
raise ValueError("Data must be provided to create embeddings table")
|
|
109
|
-
|
|
110
|
-
data_info = check_det_dataset(self.data)
|
|
111
|
-
if split not in data_info:
|
|
112
|
-
raise ValueError(
|
|
113
|
-
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
choice_set = data_info[split]
|
|
117
|
-
choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
|
|
118
|
-
self.choice_set = choice_set
|
|
119
|
-
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
|
|
120
|
-
|
|
121
|
-
# Create the table schema
|
|
122
|
-
batch = dataset[0]
|
|
123
|
-
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
|
124
|
-
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
|
125
|
-
table.add(
|
|
126
|
-
self._yield_batches(
|
|
127
|
-
dataset,
|
|
128
|
-
data_info,
|
|
129
|
-
self.model,
|
|
130
|
-
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
|
131
|
-
)
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
self.table = table
|
|
135
|
-
|
|
136
|
-
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
|
137
|
-
"""Generates batches of data for embedding, excluding specified keys."""
|
|
138
|
-
for i in tqdm(range(len(dataset))):
|
|
139
|
-
self.progress = float(i + 1) / len(dataset)
|
|
140
|
-
batch = dataset[i]
|
|
141
|
-
for k in exclude_keys:
|
|
142
|
-
batch.pop(k, None)
|
|
143
|
-
batch = sanitize_batch(batch, data_info)
|
|
144
|
-
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
|
145
|
-
yield [batch]
|
|
146
|
-
|
|
147
|
-
def query(
|
|
148
|
-
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
|
|
149
|
-
) -> Any: # pyarrow.Table
|
|
150
|
-
"""
|
|
151
|
-
Query the table for similar images. Accepts a single image or a list of images.
|
|
152
|
-
|
|
153
|
-
Args:
|
|
154
|
-
imgs (str or list): Path to the image or a list of paths to the images.
|
|
155
|
-
limit (int): Number of results to return.
|
|
156
|
-
|
|
157
|
-
Returns:
|
|
158
|
-
(pyarrow.Table): An arrow table containing the results. Supports converting to:
|
|
159
|
-
- pandas dataframe: `result.to_pandas()`
|
|
160
|
-
- dict of lists: `result.to_pydict()`
|
|
161
|
-
|
|
162
|
-
Example:
|
|
163
|
-
```python
|
|
164
|
-
exp = Explorer()
|
|
165
|
-
exp.create_embeddings_table()
|
|
166
|
-
similar = exp.query(img="https://ultralytics.com/images/zidane.jpg")
|
|
167
|
-
```
|
|
168
|
-
"""
|
|
169
|
-
if self.table is None:
|
|
170
|
-
raise ValueError("Table is not created. Please create the table first.")
|
|
171
|
-
if isinstance(imgs, str):
|
|
172
|
-
imgs = [imgs]
|
|
173
|
-
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
|
|
174
|
-
embeds = self.model.embed(imgs)
|
|
175
|
-
# Get avg if multiple images are passed (len > 1)
|
|
176
|
-
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
|
177
|
-
return self.table.search(embeds).limit(limit).to_arrow()
|
|
178
|
-
|
|
179
|
-
def sql_query(
|
|
180
|
-
self, query: str, return_type: str = "pandas"
|
|
181
|
-
) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table
|
|
182
|
-
"""
|
|
183
|
-
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
|
184
|
-
|
|
185
|
-
Args:
|
|
186
|
-
query (str): SQL query to run.
|
|
187
|
-
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
|
188
|
-
|
|
189
|
-
Returns:
|
|
190
|
-
(pyarrow.Table): An arrow table containing the results.
|
|
191
|
-
|
|
192
|
-
Example:
|
|
193
|
-
```python
|
|
194
|
-
exp = Explorer()
|
|
195
|
-
exp.create_embeddings_table()
|
|
196
|
-
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
|
197
|
-
result = exp.sql_query(query)
|
|
198
|
-
```
|
|
199
|
-
"""
|
|
200
|
-
assert return_type in {
|
|
201
|
-
"pandas",
|
|
202
|
-
"arrow",
|
|
203
|
-
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
|
204
|
-
import duckdb
|
|
205
|
-
|
|
206
|
-
if self.table is None:
|
|
207
|
-
raise ValueError("Table is not created. Please create the table first.")
|
|
208
|
-
|
|
209
|
-
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
|
210
|
-
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
|
211
|
-
if not query.startswith("SELECT") and not query.startswith("WHERE"):
|
|
212
|
-
raise ValueError(
|
|
213
|
-
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
|
|
214
|
-
f"clause. found {query}"
|
|
215
|
-
)
|
|
216
|
-
if query.startswith("WHERE"):
|
|
217
|
-
query = f"SELECT * FROM 'table' {query}"
|
|
218
|
-
LOGGER.info(f"Running query: {query}")
|
|
219
|
-
|
|
220
|
-
rs = duckdb.sql(query)
|
|
221
|
-
if return_type == "arrow":
|
|
222
|
-
return rs.arrow()
|
|
223
|
-
elif return_type == "pandas":
|
|
224
|
-
return rs.df()
|
|
225
|
-
|
|
226
|
-
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
|
227
|
-
"""
|
|
228
|
-
Plot the results of a SQL-Like query on the table.
|
|
229
|
-
|
|
230
|
-
Args:
|
|
231
|
-
query (str): SQL query to run.
|
|
232
|
-
labels (bool): Whether to plot the labels or not.
|
|
233
|
-
|
|
234
|
-
Returns:
|
|
235
|
-
(PIL.Image): Image containing the plot.
|
|
236
|
-
|
|
237
|
-
Example:
|
|
238
|
-
```python
|
|
239
|
-
exp = Explorer()
|
|
240
|
-
exp.create_embeddings_table()
|
|
241
|
-
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
|
242
|
-
result = exp.plot_sql_query(query)
|
|
243
|
-
```
|
|
244
|
-
"""
|
|
245
|
-
result = self.sql_query(query, return_type="arrow")
|
|
246
|
-
if len(result) == 0:
|
|
247
|
-
LOGGER.info("No results found.")
|
|
248
|
-
return None
|
|
249
|
-
img = plot_query_result(result, plot_labels=labels)
|
|
250
|
-
return Image.fromarray(img)
|
|
251
|
-
|
|
252
|
-
def get_similar(
|
|
253
|
-
self,
|
|
254
|
-
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
|
255
|
-
idx: Union[int, List[int]] = None,
|
|
256
|
-
limit: int = 25,
|
|
257
|
-
return_type: str = "pandas",
|
|
258
|
-
) -> Any: # pandas.DataFrame or pyarrow.Table
|
|
259
|
-
"""
|
|
260
|
-
Query the table for similar images. Accepts a single image or a list of images.
|
|
261
|
-
|
|
262
|
-
Args:
|
|
263
|
-
img (str or list): Path to the image or a list of paths to the images.
|
|
264
|
-
idx (int or list): Index of the image in the table or a list of indexes.
|
|
265
|
-
limit (int): Number of results to return. Defaults to 25.
|
|
266
|
-
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
|
267
|
-
|
|
268
|
-
Returns:
|
|
269
|
-
(pandas.DataFrame): A dataframe containing the results.
|
|
270
|
-
|
|
271
|
-
Example:
|
|
272
|
-
```python
|
|
273
|
-
exp = Explorer()
|
|
274
|
-
exp.create_embeddings_table()
|
|
275
|
-
similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg")
|
|
276
|
-
```
|
|
277
|
-
"""
|
|
278
|
-
assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
|
|
279
|
-
img = self._check_imgs_or_idxs(img, idx)
|
|
280
|
-
similar = self.query(img, limit=limit)
|
|
281
|
-
|
|
282
|
-
if return_type == "arrow":
|
|
283
|
-
return similar
|
|
284
|
-
elif return_type == "pandas":
|
|
285
|
-
return similar.to_pandas()
|
|
286
|
-
|
|
287
|
-
def plot_similar(
|
|
288
|
-
self,
|
|
289
|
-
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
|
290
|
-
idx: Union[int, List[int]] = None,
|
|
291
|
-
limit: int = 25,
|
|
292
|
-
labels: bool = True,
|
|
293
|
-
) -> Image.Image:
|
|
294
|
-
"""
|
|
295
|
-
Plot the similar images. Accepts images or indexes.
|
|
296
|
-
|
|
297
|
-
Args:
|
|
298
|
-
img (str or list): Path to the image or a list of paths to the images.
|
|
299
|
-
idx (int or list): Index of the image in the table or a list of indexes.
|
|
300
|
-
labels (bool): Whether to plot the labels or not.
|
|
301
|
-
limit (int): Number of results to return. Defaults to 25.
|
|
302
|
-
|
|
303
|
-
Returns:
|
|
304
|
-
(PIL.Image): Image containing the plot.
|
|
305
|
-
|
|
306
|
-
Example:
|
|
307
|
-
```python
|
|
308
|
-
exp = Explorer()
|
|
309
|
-
exp.create_embeddings_table()
|
|
310
|
-
similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg")
|
|
311
|
-
```
|
|
312
|
-
"""
|
|
313
|
-
similar = self.get_similar(img, idx, limit, return_type="arrow")
|
|
314
|
-
if len(similar) == 0:
|
|
315
|
-
LOGGER.info("No results found.")
|
|
316
|
-
return None
|
|
317
|
-
img = plot_query_result(similar, plot_labels=labels)
|
|
318
|
-
return Image.fromarray(img)
|
|
319
|
-
|
|
320
|
-
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame
|
|
321
|
-
"""
|
|
322
|
-
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
|
323
|
-
are max_dist or closer to the image in the embedding space at a given index.
|
|
324
|
-
|
|
325
|
-
Args:
|
|
326
|
-
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
|
327
|
-
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit.
|
|
328
|
-
vector search. Defaults: None.
|
|
329
|
-
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
|
|
333
|
-
and columns include indices of similar images and their respective distances.
|
|
334
|
-
|
|
335
|
-
Example:
|
|
336
|
-
```python
|
|
337
|
-
exp = Explorer()
|
|
338
|
-
exp.create_embeddings_table()
|
|
339
|
-
sim_idx = exp.similarity_index()
|
|
340
|
-
```
|
|
341
|
-
"""
|
|
342
|
-
if self.table is None:
|
|
343
|
-
raise ValueError("Table is not created. Please create the table first.")
|
|
344
|
-
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
|
|
345
|
-
if sim_idx_table_name in self.connection.table_names() and not force:
|
|
346
|
-
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
|
|
347
|
-
return self.connection.open_table(sim_idx_table_name).to_pandas()
|
|
348
|
-
|
|
349
|
-
if top_k and not (1.0 >= top_k >= 0.0):
|
|
350
|
-
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
|
|
351
|
-
if max_dist < 0.0:
|
|
352
|
-
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
|
|
353
|
-
|
|
354
|
-
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
|
|
355
|
-
top_k = max(top_k, 1)
|
|
356
|
-
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
|
|
357
|
-
im_files = features["im_file"]
|
|
358
|
-
embeddings = features["vector"]
|
|
359
|
-
|
|
360
|
-
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
|
|
361
|
-
|
|
362
|
-
def _yield_sim_idx():
|
|
363
|
-
"""Generates a dataframe with similarity indices and distances for images."""
|
|
364
|
-
for i in tqdm(range(len(embeddings))):
|
|
365
|
-
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
|
|
366
|
-
yield [
|
|
367
|
-
{
|
|
368
|
-
"idx": i,
|
|
369
|
-
"im_file": im_files[i],
|
|
370
|
-
"count": len(sim_idx),
|
|
371
|
-
"sim_im_files": sim_idx["im_file"].tolist(),
|
|
372
|
-
}
|
|
373
|
-
]
|
|
374
|
-
|
|
375
|
-
sim_table.add(_yield_sim_idx())
|
|
376
|
-
self.sim_index = sim_table
|
|
377
|
-
return sim_table.to_pandas()
|
|
378
|
-
|
|
379
|
-
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
|
|
380
|
-
"""
|
|
381
|
-
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
|
|
382
|
-
max_dist or closer to the image in the embedding space at a given index.
|
|
383
|
-
|
|
384
|
-
Args:
|
|
385
|
-
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
|
386
|
-
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
|
|
387
|
-
running vector search. Defaults to 0.01.
|
|
388
|
-
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
|
389
|
-
|
|
390
|
-
Returns:
|
|
391
|
-
(PIL.Image): Image containing the plot.
|
|
392
|
-
|
|
393
|
-
Example:
|
|
394
|
-
```python
|
|
395
|
-
exp = Explorer()
|
|
396
|
-
exp.create_embeddings_table()
|
|
397
|
-
|
|
398
|
-
similarity_idx_plot = exp.plot_similarity_index()
|
|
399
|
-
similarity_idx_plot.show() # view image preview
|
|
400
|
-
similarity_idx_plot.save("path/to/save/similarity_index_plot.png") # save contents to file
|
|
401
|
-
```
|
|
402
|
-
"""
|
|
403
|
-
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
|
404
|
-
sim_count = sim_idx["count"].tolist()
|
|
405
|
-
sim_count = np.array(sim_count)
|
|
406
|
-
|
|
407
|
-
indices = np.arange(len(sim_count))
|
|
408
|
-
|
|
409
|
-
# Create the bar plot
|
|
410
|
-
plt.bar(indices, sim_count)
|
|
411
|
-
|
|
412
|
-
# Customize the plot (optional)
|
|
413
|
-
plt.xlabel("data idx")
|
|
414
|
-
plt.ylabel("Count")
|
|
415
|
-
plt.title("Similarity Count")
|
|
416
|
-
buffer = BytesIO()
|
|
417
|
-
plt.savefig(buffer, format="png")
|
|
418
|
-
buffer.seek(0)
|
|
419
|
-
|
|
420
|
-
# Use Pillow to open the image from the buffer
|
|
421
|
-
return Image.fromarray(np.array(Image.open(buffer)))
|
|
422
|
-
|
|
423
|
-
def _check_imgs_or_idxs(
|
|
424
|
-
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
|
425
|
-
) -> List[np.ndarray]:
|
|
426
|
-
"""Determines whether to fetch images or indexes based on provided arguments and returns image paths."""
|
|
427
|
-
if img is None and idx is None:
|
|
428
|
-
raise ValueError("Either img or idx must be provided.")
|
|
429
|
-
if img is not None and idx is not None:
|
|
430
|
-
raise ValueError("Only one of img or idx must be provided.")
|
|
431
|
-
if idx is not None:
|
|
432
|
-
idx = idx if isinstance(idx, list) else [idx]
|
|
433
|
-
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
|
434
|
-
|
|
435
|
-
return img if isinstance(img, list) else [img]
|
|
436
|
-
|
|
437
|
-
def ask_ai(self, query):
|
|
438
|
-
"""
|
|
439
|
-
Ask AI a question.
|
|
440
|
-
|
|
441
|
-
Args:
|
|
442
|
-
query (str): Question to ask.
|
|
443
|
-
|
|
444
|
-
Returns:
|
|
445
|
-
(pandas.DataFrame): A dataframe containing filtered results to the SQL query.
|
|
446
|
-
|
|
447
|
-
Example:
|
|
448
|
-
```python
|
|
449
|
-
exp = Explorer()
|
|
450
|
-
exp.create_embeddings_table()
|
|
451
|
-
answer = exp.ask_ai("Show images with 1 person and 2 dogs")
|
|
452
|
-
```
|
|
453
|
-
"""
|
|
454
|
-
result = prompt_sql_query(query)
|
|
455
|
-
try:
|
|
456
|
-
return self.sql_query(result)
|
|
457
|
-
except Exception as e:
|
|
458
|
-
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
|
459
|
-
LOGGER.error(e)
|
|
460
|
-
return None
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|