ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
1
3
|
import time
|
|
2
4
|
from threading import Thread
|
|
3
5
|
|
|
@@ -7,95 +9,114 @@ from ultralytics import Explorer
|
|
|
7
9
|
from ultralytics.utils import ROOT, SETTINGS
|
|
8
10
|
from ultralytics.utils.checks import check_requirements
|
|
9
11
|
|
|
10
|
-
check_requirements(
|
|
11
|
-
|
|
12
|
+
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
|
|
13
|
+
|
|
12
14
|
import streamlit as st
|
|
13
15
|
from streamlit_select import image_select
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def _get_explorer():
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
19
|
+
"""Initializes and returns an instance of the Explorer class."""
|
|
20
|
+
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
|
|
21
|
+
thread = Thread(
|
|
22
|
+
target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
|
|
23
|
+
)
|
|
20
24
|
thread.start()
|
|
21
|
-
progress_bar = st.progress(0, text=
|
|
25
|
+
progress_bar = st.progress(0, text="Creating embeddings table...")
|
|
22
26
|
while exp.progress < 1:
|
|
23
27
|
time.sleep(0.1)
|
|
24
|
-
progress_bar.progress(exp.progress, text=f
|
|
28
|
+
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
|
|
25
29
|
thread.join()
|
|
26
|
-
st.session_state[
|
|
30
|
+
st.session_state["explorer"] = exp
|
|
27
31
|
progress_bar.empty()
|
|
28
32
|
|
|
29
33
|
|
|
30
34
|
def init_explorer_form():
|
|
31
|
-
|
|
32
|
-
|
|
35
|
+
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
|
|
36
|
+
datasets = ROOT / "cfg" / "datasets"
|
|
37
|
+
ds = [d.name for d in datasets.glob("*.yaml")]
|
|
33
38
|
models = [
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
39
|
+
"yolov8n.pt",
|
|
40
|
+
"yolov8s.pt",
|
|
41
|
+
"yolov8m.pt",
|
|
42
|
+
"yolov8l.pt",
|
|
43
|
+
"yolov8x.pt",
|
|
44
|
+
"yolov8n-seg.pt",
|
|
45
|
+
"yolov8s-seg.pt",
|
|
46
|
+
"yolov8m-seg.pt",
|
|
47
|
+
"yolov8l-seg.pt",
|
|
48
|
+
"yolov8x-seg.pt",
|
|
49
|
+
"yolov8n-pose.pt",
|
|
50
|
+
"yolov8s-pose.pt",
|
|
51
|
+
"yolov8m-pose.pt",
|
|
52
|
+
"yolov8l-pose.pt",
|
|
53
|
+
"yolov8x-pose.pt",
|
|
54
|
+
]
|
|
55
|
+
with st.form(key="explorer_init_form"):
|
|
38
56
|
col1, col2 = st.columns(2)
|
|
39
57
|
with col1:
|
|
40
|
-
st.selectbox(
|
|
58
|
+
st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
|
|
41
59
|
with col2:
|
|
42
|
-
st.selectbox(
|
|
43
|
-
st.checkbox(
|
|
60
|
+
st.selectbox("Select model", models, key="model")
|
|
61
|
+
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
|
|
44
62
|
|
|
45
|
-
st.form_submit_button(
|
|
63
|
+
st.form_submit_button("Explore", on_click=_get_explorer)
|
|
46
64
|
|
|
47
65
|
|
|
48
66
|
def query_form():
|
|
49
|
-
with
|
|
67
|
+
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
|
|
68
|
+
with st.form("query_form"):
|
|
50
69
|
col1, col2 = st.columns([0.8, 0.2])
|
|
51
70
|
with col1:
|
|
52
|
-
st.text_input(
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
71
|
+
st.text_input(
|
|
72
|
+
"Query",
|
|
73
|
+
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
|
74
|
+
label_visibility="collapsed",
|
|
75
|
+
key="query",
|
|
76
|
+
)
|
|
56
77
|
with col2:
|
|
57
|
-
st.form_submit_button(
|
|
78
|
+
st.form_submit_button("Query", on_click=run_sql_query)
|
|
58
79
|
|
|
59
80
|
|
|
60
81
|
def ai_query_form():
|
|
61
|
-
with
|
|
82
|
+
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
|
|
83
|
+
with st.form("ai_query_form"):
|
|
62
84
|
col1, col2 = st.columns([0.8, 0.2])
|
|
63
85
|
with col1:
|
|
64
|
-
st.text_input(
|
|
86
|
+
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
|
|
65
87
|
with col2:
|
|
66
|
-
st.form_submit_button(
|
|
88
|
+
st.form_submit_button("Ask AI", on_click=run_ai_query)
|
|
67
89
|
|
|
68
90
|
|
|
69
91
|
def find_similar_imgs(imgs):
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
92
|
+
"""Initializes a Streamlit form for AI-based image querying with custom input."""
|
|
93
|
+
exp = st.session_state["explorer"]
|
|
94
|
+
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
|
|
95
|
+
paths = similar.to_pydict()["im_file"]
|
|
96
|
+
st.session_state["imgs"] = paths
|
|
74
97
|
|
|
75
98
|
|
|
76
99
|
def similarity_form(selected_imgs):
|
|
77
|
-
|
|
78
|
-
|
|
100
|
+
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
|
|
101
|
+
st.write("Similarity Search")
|
|
102
|
+
with st.form("similarity_form"):
|
|
79
103
|
subcol1, subcol2 = st.columns([1, 1])
|
|
80
104
|
with subcol1:
|
|
81
|
-
st.number_input(
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
value=25,
|
|
85
|
-
label_visibility='collapsed',
|
|
86
|
-
key='limit')
|
|
105
|
+
st.number_input(
|
|
106
|
+
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
|
|
107
|
+
)
|
|
87
108
|
|
|
88
109
|
with subcol2:
|
|
89
110
|
disabled = not len(selected_imgs)
|
|
90
|
-
st.write(
|
|
111
|
+
st.write("Selected: ", len(selected_imgs))
|
|
91
112
|
st.form_submit_button(
|
|
92
|
-
|
|
113
|
+
"Search",
|
|
93
114
|
disabled=disabled,
|
|
94
115
|
on_click=find_similar_imgs,
|
|
95
|
-
args=(selected_imgs,
|
|
116
|
+
args=(selected_imgs,),
|
|
96
117
|
)
|
|
97
118
|
if disabled:
|
|
98
|
-
st.error(
|
|
119
|
+
st.error("Select at least one image to search.")
|
|
99
120
|
|
|
100
121
|
|
|
101
122
|
# def persist_reset_form():
|
|
@@ -109,96 +130,109 @@ def similarity_form(selected_imgs):
|
|
|
109
130
|
|
|
110
131
|
|
|
111
132
|
def run_sql_query():
|
|
112
|
-
|
|
113
|
-
|
|
133
|
+
"""Executes an SQL query and returns the results."""
|
|
134
|
+
st.session_state["error"] = None
|
|
135
|
+
query = st.session_state.get("query")
|
|
114
136
|
if query.rstrip().lstrip():
|
|
115
|
-
exp = st.session_state[
|
|
116
|
-
res = exp.sql_query(query, return_type=
|
|
117
|
-
st.session_state[
|
|
137
|
+
exp = st.session_state["explorer"]
|
|
138
|
+
res = exp.sql_query(query, return_type="arrow")
|
|
139
|
+
st.session_state["imgs"] = res.to_pydict()["im_file"]
|
|
118
140
|
|
|
119
141
|
|
|
120
142
|
def run_ai_query():
|
|
121
|
-
|
|
143
|
+
"""Execute SQL query and update session state with query results."""
|
|
144
|
+
if not SETTINGS["openai_api_key"]:
|
|
122
145
|
st.session_state[
|
|
123
|
-
|
|
146
|
+
"error"
|
|
147
|
+
] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
|
124
148
|
return
|
|
125
|
-
st.session_state[
|
|
126
|
-
query = st.session_state.get(
|
|
149
|
+
st.session_state["error"] = None
|
|
150
|
+
query = st.session_state.get("ai_query")
|
|
127
151
|
if query.rstrip().lstrip():
|
|
128
|
-
exp = st.session_state[
|
|
152
|
+
exp = st.session_state["explorer"]
|
|
129
153
|
res = exp.ask_ai(query)
|
|
130
154
|
if not isinstance(res, pd.DataFrame) or res.empty:
|
|
131
|
-
st.session_state[
|
|
155
|
+
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
|
|
132
156
|
return
|
|
133
|
-
st.session_state[
|
|
157
|
+
st.session_state["imgs"] = res["im_file"].to_list()
|
|
134
158
|
|
|
135
159
|
|
|
136
160
|
def reset_explorer():
|
|
137
|
-
|
|
138
|
-
st.session_state[
|
|
139
|
-
st.session_state[
|
|
161
|
+
"""Resets the explorer to its initial state by clearing session variables."""
|
|
162
|
+
st.session_state["explorer"] = None
|
|
163
|
+
st.session_state["imgs"] = None
|
|
164
|
+
st.session_state["error"] = None
|
|
140
165
|
|
|
141
166
|
|
|
142
167
|
def utralytics_explorer_docs_callback():
|
|
168
|
+
"""Resets the explorer to its initial state by clearing session variables."""
|
|
143
169
|
with st.container(border=True):
|
|
144
|
-
st.image(
|
|
145
|
-
|
|
170
|
+
st.image(
|
|
171
|
+
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
|
|
172
|
+
width=100,
|
|
173
|
+
)
|
|
146
174
|
st.markdown(
|
|
147
175
|
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
|
|
148
176
|
unsafe_allow_html=True,
|
|
149
|
-
help=None
|
|
150
|
-
|
|
177
|
+
help=None,
|
|
178
|
+
)
|
|
179
|
+
st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
|
|
151
180
|
|
|
152
181
|
|
|
153
182
|
def layout():
|
|
154
|
-
|
|
183
|
+
"""Resets explorer session variables and provides documentation with a link to API docs."""
|
|
184
|
+
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
|
|
155
185
|
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
|
|
156
186
|
|
|
157
|
-
if st.session_state.get(
|
|
187
|
+
if st.session_state.get("explorer") is None:
|
|
158
188
|
init_explorer_form()
|
|
159
189
|
return
|
|
160
190
|
|
|
161
|
-
st.button(
|
|
162
|
-
exp = st.session_state.get(
|
|
163
|
-
col1, col2 = st.columns([0.75, 0.25], gap=
|
|
191
|
+
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
|
|
192
|
+
exp = st.session_state.get("explorer")
|
|
193
|
+
col1, col2 = st.columns([0.75, 0.25], gap="small")
|
|
164
194
|
imgs = []
|
|
165
|
-
if st.session_state.get(
|
|
166
|
-
st.error(st.session_state[
|
|
195
|
+
if st.session_state.get("error"):
|
|
196
|
+
st.error(st.session_state["error"])
|
|
167
197
|
else:
|
|
168
|
-
imgs = st.session_state.get(
|
|
198
|
+
imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
|
|
169
199
|
total_imgs, selected_imgs = len(imgs), []
|
|
170
200
|
with col1:
|
|
171
201
|
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
|
|
172
202
|
with subcol1:
|
|
173
|
-
st.write(
|
|
203
|
+
st.write("Max Images Displayed:")
|
|
174
204
|
with subcol2:
|
|
175
|
-
num = st.number_input(
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
205
|
+
num = st.number_input(
|
|
206
|
+
"Max Images Displayed",
|
|
207
|
+
min_value=0,
|
|
208
|
+
max_value=total_imgs,
|
|
209
|
+
value=min(500, total_imgs),
|
|
210
|
+
key="num_imgs_displayed",
|
|
211
|
+
label_visibility="collapsed",
|
|
212
|
+
)
|
|
181
213
|
with subcol3:
|
|
182
|
-
st.write(
|
|
214
|
+
st.write("Start Index:")
|
|
183
215
|
with subcol4:
|
|
184
|
-
start_idx = st.number_input(
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
216
|
+
start_idx = st.number_input(
|
|
217
|
+
"Start Index",
|
|
218
|
+
min_value=0,
|
|
219
|
+
max_value=total_imgs,
|
|
220
|
+
value=0,
|
|
221
|
+
key="start_index",
|
|
222
|
+
label_visibility="collapsed",
|
|
223
|
+
)
|
|
190
224
|
with subcol5:
|
|
191
|
-
reset = st.button(
|
|
225
|
+
reset = st.button("Reset", use_container_width=False, key="reset")
|
|
192
226
|
if reset:
|
|
193
|
-
st.session_state[
|
|
227
|
+
st.session_state["imgs"] = None
|
|
194
228
|
st.experimental_rerun()
|
|
195
229
|
|
|
196
230
|
query_form()
|
|
197
231
|
ai_query_form()
|
|
198
232
|
if total_imgs:
|
|
199
|
-
imgs_displayed = imgs[start_idx:start_idx + num]
|
|
233
|
+
imgs_displayed = imgs[start_idx : start_idx + num]
|
|
200
234
|
selected_imgs = image_select(
|
|
201
|
-
f
|
|
235
|
+
f"Total samples: {total_imgs}",
|
|
202
236
|
images=imgs_displayed,
|
|
203
237
|
use_container_width=False,
|
|
204
238
|
# indices=[i for i in range(num)] if select_all else None,
|
|
@@ -210,5 +244,5 @@ def layout():
|
|
|
210
244
|
utralytics_explorer_docs_callback()
|
|
211
245
|
|
|
212
246
|
|
|
213
|
-
if __name__ ==
|
|
247
|
+
if __name__ == "__main__":
|
|
214
248
|
layout()
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
1
3
|
import getpass
|
|
2
4
|
from typing import List
|
|
3
5
|
|
|
@@ -14,6 +16,7 @@ from ultralytics.utils.plotting import plot_images
|
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def get_table_schema(vector_size):
|
|
19
|
+
"""Extracts and returns the schema of a database table."""
|
|
17
20
|
from lancedb.pydantic import LanceModel, Vector
|
|
18
21
|
|
|
19
22
|
class Schema(LanceModel):
|
|
@@ -29,6 +32,7 @@ def get_table_schema(vector_size):
|
|
|
29
32
|
|
|
30
33
|
|
|
31
34
|
def get_sim_index_schema():
|
|
35
|
+
"""Returns a LanceModel schema for a database table with specified vector size."""
|
|
32
36
|
from lancedb.pydantic import LanceModel
|
|
33
37
|
|
|
34
38
|
class Schema(LanceModel):
|
|
@@ -41,14 +45,14 @@ def get_sim_index_schema():
|
|
|
41
45
|
|
|
42
46
|
|
|
43
47
|
def sanitize_batch(batch, dataset_info):
|
|
44
|
-
batch
|
|
45
|
-
|
|
46
|
-
batch[
|
|
47
|
-
batch[
|
|
48
|
-
batch[
|
|
49
|
-
batch[
|
|
50
|
-
batch[
|
|
51
|
-
|
|
48
|
+
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
|
|
49
|
+
batch["cls"] = batch["cls"].flatten().int().tolist()
|
|
50
|
+
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
|
|
51
|
+
batch["bboxes"] = [box for box, _ in box_cls_pair]
|
|
52
|
+
batch["cls"] = [cls for _, cls in box_cls_pair]
|
|
53
|
+
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
|
|
54
|
+
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
|
|
55
|
+
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
|
|
52
56
|
return batch
|
|
53
57
|
|
|
54
58
|
|
|
@@ -60,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
|
|
|
60
64
|
similar_set (list): Pyarrow or pandas object containing the similar data points
|
|
61
65
|
plot_labels (bool): Whether to plot labels or not
|
|
62
66
|
"""
|
|
63
|
-
similar_set =
|
|
64
|
-
orient=
|
|
67
|
+
similar_set = (
|
|
68
|
+
similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
|
|
69
|
+
)
|
|
65
70
|
empty_masks = [[[]]]
|
|
66
71
|
empty_boxes = [[]]
|
|
67
|
-
images = similar_set.get(
|
|
68
|
-
bboxes = similar_set.get(
|
|
69
|
-
masks = similar_set.get(
|
|
70
|
-
kpts = similar_set.get(
|
|
71
|
-
cls = similar_set.get(
|
|
72
|
+
images = similar_set.get("im_file", [])
|
|
73
|
+
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
|
|
74
|
+
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
|
|
75
|
+
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
|
|
76
|
+
cls = similar_set.get("cls", [])
|
|
72
77
|
|
|
73
78
|
plot_size = 640
|
|
74
79
|
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
|
|
@@ -99,33 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
|
|
|
99
104
|
batch_idx = np.concatenate(batch_idx, axis=0)
|
|
100
105
|
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
|
|
101
106
|
|
|
102
|
-
return plot_images(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
bboxes=boxes,
|
|
106
|
-
masks=masks,
|
|
107
|
-
kpts=kpts,
|
|
108
|
-
max_subplots=len(images),
|
|
109
|
-
save=False,
|
|
110
|
-
threaded=False)
|
|
107
|
+
return plot_images(
|
|
108
|
+
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
|
|
109
|
+
)
|
|
111
110
|
|
|
112
111
|
|
|
113
112
|
def prompt_sql_query(query):
|
|
114
|
-
|
|
113
|
+
"""Plots images with optional labels from a similar data set."""
|
|
114
|
+
check_requirements("openai>=1.6.1")
|
|
115
115
|
from openai import OpenAI
|
|
116
116
|
|
|
117
|
-
if not SETTINGS[
|
|
118
|
-
logger.warning(
|
|
119
|
-
openai_api_key = getpass.getpass(
|
|
120
|
-
SETTINGS.update({
|
|
121
|
-
openai = OpenAI(api_key=SETTINGS[
|
|
117
|
+
if not SETTINGS["openai_api_key"]:
|
|
118
|
+
logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
|
|
119
|
+
openai_api_key = getpass.getpass("OpenAI API key: ")
|
|
120
|
+
SETTINGS.update({"openai_api_key": openai_api_key})
|
|
121
|
+
openai = OpenAI(api_key=SETTINGS["openai_api_key"])
|
|
122
122
|
|
|
123
123
|
messages = [
|
|
124
124
|
{
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
'content':
|
|
128
|
-
'''
|
|
125
|
+
"role": "system",
|
|
126
|
+
"content": """
|
|
129
127
|
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
|
|
130
128
|
the following schema and a user request. You only need to output the format with fixed selection
|
|
131
129
|
statement that selects everything from "'table'", like `SELECT * from 'table'`
|
|
@@ -159,10 +157,10 @@ def prompt_sql_query(query):
|
|
|
159
157
|
request - Get all data points that contain 2 or more people and at least one dog
|
|
160
158
|
correct query-
|
|
161
159
|
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
160
|
+
""",
|
|
161
|
+
},
|
|
162
|
+
{"role": "user", "content": f"{query}"},
|
|
163
|
+
]
|
|
166
164
|
|
|
167
|
-
response = openai.chat.completions.create(model=
|
|
165
|
+
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
|
|
168
166
|
return response.choices[0].message.content
|