ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137) hide show
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -1,87 +1,122 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  import time
2
4
  from threading import Thread
3
5
 
6
+ import pandas as pd
7
+
4
8
  from ultralytics import Explorer
5
- from ultralytics.utils import ROOT
9
+ from ultralytics.utils import ROOT, SETTINGS
6
10
  from ultralytics.utils.checks import check_requirements
7
11
 
8
- check_requirements('streamlit')
9
- check_requirements('streamlit-select>=0.2')
12
+ check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
13
+
10
14
  import streamlit as st
11
15
  from streamlit_select import image_select
12
16
 
13
17
 
14
18
  def _get_explorer():
15
- exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
16
- thread = Thread(target=exp.create_embeddings_table,
17
- kwargs={'force': st.session_state.get('force_recreate_embeddings')})
19
+ """Initializes and returns an instance of the Explorer class."""
20
+ exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
21
+ thread = Thread(
22
+ target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
23
+ )
18
24
  thread.start()
19
- progress_bar = st.progress(0, text='Creating embeddings table...')
25
+ progress_bar = st.progress(0, text="Creating embeddings table...")
20
26
  while exp.progress < 1:
21
27
  time.sleep(0.1)
22
- progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
28
+ progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
23
29
  thread.join()
24
- st.session_state['explorer'] = exp
30
+ st.session_state["explorer"] = exp
25
31
  progress_bar.empty()
26
32
 
27
33
 
28
34
  def init_explorer_form():
29
- datasets = ROOT / 'cfg' / 'datasets'
30
- ds = [d.name for d in datasets.glob('*.yaml')]
35
+ """Initializes an Explorer instance and creates embeddings table with progress tracking."""
36
+ datasets = ROOT / "cfg" / "datasets"
37
+ ds = [d.name for d in datasets.glob("*.yaml")]
31
38
  models = [
32
- 'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
33
- 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
34
- 'yolov8l-pose.pt', 'yolov8x-pose.pt']
35
- with st.form(key='explorer_init_form'):
39
+ "yolov8n.pt",
40
+ "yolov8s.pt",
41
+ "yolov8m.pt",
42
+ "yolov8l.pt",
43
+ "yolov8x.pt",
44
+ "yolov8n-seg.pt",
45
+ "yolov8s-seg.pt",
46
+ "yolov8m-seg.pt",
47
+ "yolov8l-seg.pt",
48
+ "yolov8x-seg.pt",
49
+ "yolov8n-pose.pt",
50
+ "yolov8s-pose.pt",
51
+ "yolov8m-pose.pt",
52
+ "yolov8l-pose.pt",
53
+ "yolov8x-pose.pt",
54
+ ]
55
+ with st.form(key="explorer_init_form"):
36
56
  col1, col2 = st.columns(2)
37
57
  with col1:
38
- dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
58
+ st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
39
59
  with col2:
40
- model = st.selectbox('Select model', models, key='model')
41
- st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
60
+ st.selectbox("Select model", models, key="model")
61
+ st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
42
62
 
43
- st.form_submit_button('Explore', on_click=_get_explorer)
63
+ st.form_submit_button("Explore", on_click=_get_explorer)
44
64
 
45
65
 
46
66
  def query_form():
47
- with st.form('query_form'):
67
+ """Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
68
+ with st.form("query_form"):
69
+ col1, col2 = st.columns([0.8, 0.2])
70
+ with col1:
71
+ st.text_input(
72
+ "Query",
73
+ "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
74
+ label_visibility="collapsed",
75
+ key="query",
76
+ )
77
+ with col2:
78
+ st.form_submit_button("Query", on_click=run_sql_query)
79
+
80
+
81
+ def ai_query_form():
82
+ """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
83
+ with st.form("ai_query_form"):
48
84
  col1, col2 = st.columns([0.8, 0.2])
49
85
  with col1:
50
- query = st.text_input('Query', '', label_visibility='collapsed', key='query')
86
+ st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
51
87
  with col2:
52
- st.form_submit_button('Query', on_click=run_sql_query)
88
+ st.form_submit_button("Ask AI", on_click=run_ai_query)
53
89
 
54
90
 
55
91
  def find_similar_imgs(imgs):
56
- exp = st.session_state['explorer']
57
- similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
58
- paths = similar.to_pydict()['im_file']
59
- st.session_state['imgs'] = paths
92
+ """Initializes a Streamlit form for AI-based image querying with custom input."""
93
+ exp = st.session_state["explorer"]
94
+ similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
95
+ paths = similar.to_pydict()["im_file"]
96
+ st.session_state["imgs"] = paths
60
97
 
61
98
 
62
99
  def similarity_form(selected_imgs):
63
- st.write('Similarity Search')
64
- with st.form('similarity_form'):
100
+ """Initializes a form for AI-based image querying with custom input in Streamlit."""
101
+ st.write("Similarity Search")
102
+ with st.form("similarity_form"):
65
103
  subcol1, subcol2 = st.columns([1, 1])
66
104
  with subcol1:
67
- limit = st.number_input('limit',
68
- min_value=None,
69
- max_value=None,
70
- value=25,
71
- label_visibility='collapsed',
72
- key='limit')
105
+ st.number_input(
106
+ "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
107
+ )
73
108
 
74
109
  with subcol2:
75
110
  disabled = not len(selected_imgs)
76
- st.write('Selected: ', len(selected_imgs))
111
+ st.write("Selected: ", len(selected_imgs))
77
112
  st.form_submit_button(
78
- 'Search',
113
+ "Search",
79
114
  disabled=disabled,
80
115
  on_click=find_similar_imgs,
81
- args=(selected_imgs, ),
116
+ args=(selected_imgs,),
82
117
  )
83
118
  if disabled:
84
- st.error('Select at least one image to search.')
119
+ st.error("Select at least one image to search.")
85
120
 
86
121
 
87
122
  # def persist_reset_form():
@@ -95,74 +130,109 @@ def similarity_form(selected_imgs):
95
130
 
96
131
 
97
132
  def run_sql_query():
98
- query = st.session_state.get('query')
133
+ """Executes an SQL query and returns the results."""
134
+ st.session_state["error"] = None
135
+ query = st.session_state.get("query")
136
+ if query.rstrip().lstrip():
137
+ exp = st.session_state["explorer"]
138
+ res = exp.sql_query(query, return_type="arrow")
139
+ st.session_state["imgs"] = res.to_pydict()["im_file"]
140
+
141
+
142
+ def run_ai_query():
143
+ """Execute SQL query and update session state with query results."""
144
+ if not SETTINGS["openai_api_key"]:
145
+ st.session_state[
146
+ "error"
147
+ ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
148
+ return
149
+ st.session_state["error"] = None
150
+ query = st.session_state.get("ai_query")
99
151
  if query.rstrip().lstrip():
100
- exp = st.session_state['explorer']
101
- res = exp.sql_query(query, return_type='arrow')
102
- st.session_state['imgs'] = res.to_pydict()['im_file']
152
+ exp = st.session_state["explorer"]
153
+ res = exp.ask_ai(query)
154
+ if not isinstance(res, pd.DataFrame) or res.empty:
155
+ st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
156
+ return
157
+ st.session_state["imgs"] = res["im_file"].to_list()
103
158
 
104
159
 
105
160
  def reset_explorer():
106
- st.session_state['explorer'] = None
107
- st.session_state['imgs'] = None
161
+ """Resets the explorer to its initial state by clearing session variables."""
162
+ st.session_state["explorer"] = None
163
+ st.session_state["imgs"] = None
164
+ st.session_state["error"] = None
108
165
 
109
166
 
110
167
  def utralytics_explorer_docs_callback():
168
+ """Resets the explorer to its initial state by clearing session variables."""
111
169
  with st.container(border=True):
112
- st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
113
- width=100)
170
+ st.image(
171
+ "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
172
+ width=100,
173
+ )
114
174
  st.markdown(
115
- "<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
175
+ "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
116
176
  unsafe_allow_html=True,
117
- help=None)
118
- st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
177
+ help=None,
178
+ )
179
+ st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
119
180
 
120
181
 
121
182
  def layout():
122
- st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
183
+ """Resets explorer session variables and provides documentation with a link to API docs."""
184
+ st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
123
185
  st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
124
186
 
125
- if st.session_state.get('explorer') is None:
187
+ if st.session_state.get("explorer") is None:
126
188
  init_explorer_form()
127
189
  return
128
190
 
129
- st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
130
- exp = st.session_state.get('explorer')
131
- col1, col2 = st.columns([0.75, 0.25], gap='small')
132
-
133
- imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
134
- total_imgs = len(imgs)
191
+ st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
192
+ exp = st.session_state.get("explorer")
193
+ col1, col2 = st.columns([0.75, 0.25], gap="small")
194
+ imgs = []
195
+ if st.session_state.get("error"):
196
+ st.error(st.session_state["error"])
197
+ else:
198
+ imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
199
+ total_imgs, selected_imgs = len(imgs), []
135
200
  with col1:
136
201
  subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
137
202
  with subcol1:
138
- st.write('Max Images Displayed:')
203
+ st.write("Max Images Displayed:")
139
204
  with subcol2:
140
- num = st.number_input('Max Images Displayed',
141
- min_value=0,
142
- max_value=total_imgs,
143
- value=min(500, total_imgs),
144
- key='num_imgs_displayed',
145
- label_visibility='collapsed')
205
+ num = st.number_input(
206
+ "Max Images Displayed",
207
+ min_value=0,
208
+ max_value=total_imgs,
209
+ value=min(500, total_imgs),
210
+ key="num_imgs_displayed",
211
+ label_visibility="collapsed",
212
+ )
146
213
  with subcol3:
147
- st.write('Start Index:')
214
+ st.write("Start Index:")
148
215
  with subcol4:
149
- start_idx = st.number_input('Start Index',
150
- min_value=0,
151
- max_value=total_imgs,
152
- value=0,
153
- key='start_index',
154
- label_visibility='collapsed')
216
+ start_idx = st.number_input(
217
+ "Start Index",
218
+ min_value=0,
219
+ max_value=total_imgs,
220
+ value=0,
221
+ key="start_index",
222
+ label_visibility="collapsed",
223
+ )
155
224
  with subcol5:
156
- reset = st.button('Reset', use_container_width=False, key='reset')
225
+ reset = st.button("Reset", use_container_width=False, key="reset")
157
226
  if reset:
158
- st.session_state['imgs'] = None
227
+ st.session_state["imgs"] = None
159
228
  st.experimental_rerun()
160
229
 
161
230
  query_form()
231
+ ai_query_form()
162
232
  if total_imgs:
163
- imgs_displayed = imgs[start_idx:start_idx + num]
233
+ imgs_displayed = imgs[start_idx : start_idx + num]
164
234
  selected_imgs = image_select(
165
- f'Total samples: {total_imgs}',
235
+ f"Total samples: {total_imgs}",
166
236
  images=imgs_displayed,
167
237
  use_container_width=False,
168
238
  # indices=[i for i in range(num)] if select_all else None,
@@ -174,5 +244,5 @@ def layout():
174
244
  utralytics_explorer_docs_callback()
175
245
 
176
246
 
177
- if __name__ == '__main__':
247
+ if __name__ == "__main__":
178
248
  layout()
@@ -1,14 +1,22 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import getpass
1
4
  from typing import List
2
5
 
3
6
  import cv2
4
7
  import numpy as np
8
+ import pandas as pd
5
9
 
6
10
  from ultralytics.data.augment import LetterBox
11
+ from ultralytics.utils import LOGGER as logger
12
+ from ultralytics.utils import SETTINGS
13
+ from ultralytics.utils.checks import check_requirements
7
14
  from ultralytics.utils.ops import xyxy2xywh
8
15
  from ultralytics.utils.plotting import plot_images
9
16
 
10
17
 
11
18
  def get_table_schema(vector_size):
19
+ """Extracts and returns the schema of a database table."""
12
20
  from lancedb.pydantic import LanceModel, Vector
13
21
 
14
22
  class Schema(LanceModel):
@@ -24,6 +32,7 @@ def get_table_schema(vector_size):
24
32
 
25
33
 
26
34
  def get_sim_index_schema():
35
+ """Returns a LanceModel schema for a database table with specified vector size."""
27
36
  from lancedb.pydantic import LanceModel
28
37
 
29
38
  class Schema(LanceModel):
@@ -36,33 +45,35 @@ def get_sim_index_schema():
36
45
 
37
46
 
38
47
  def sanitize_batch(batch, dataset_info):
39
- batch['cls'] = batch['cls'].flatten().int().tolist()
40
- box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
41
- batch['bboxes'] = [box for box, _ in box_cls_pair]
42
- batch['cls'] = [cls for _, cls in box_cls_pair]
43
- batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
44
- batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
45
- batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
46
-
48
+ """Sanitizes input batch for inference, ensuring correct format and dimensions."""
49
+ batch["cls"] = batch["cls"].flatten().int().tolist()
50
+ box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
51
+ batch["bboxes"] = [box for box, _ in box_cls_pair]
52
+ batch["cls"] = [cls for _, cls in box_cls_pair]
53
+ batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
54
+ batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
55
+ batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
47
56
  return batch
48
57
 
49
58
 
50
- def plot_similar_images(similar_set, plot_labels=True):
59
+ def plot_query_result(similar_set, plot_labels=True):
51
60
  """
52
61
  Plot images from the similar set.
53
62
 
54
63
  Args:
55
- similar_set (list): Pyarrow table containing the similar data points
64
+ similar_set (list): Pyarrow or pandas object containing the similar data points
56
65
  plot_labels (bool): Whether to plot labels or not
57
66
  """
58
- similar_set = similar_set.to_pydict()
67
+ similar_set = (
68
+ similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
69
+ )
59
70
  empty_masks = [[[]]]
60
71
  empty_boxes = [[]]
61
- images = similar_set.get('im_file', [])
62
- bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
63
- masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
64
- kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
65
- cls = similar_set.get('cls', [])
72
+ images = similar_set.get("im_file", [])
73
+ bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
74
+ masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
75
+ kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
76
+ cls = similar_set.get("cls", [])
66
77
 
67
78
  plot_size = 640
68
79
  imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@@ -93,12 +104,63 @@ def plot_similar_images(similar_set, plot_labels=True):
93
104
  batch_idx = np.concatenate(batch_idx, axis=0)
94
105
  cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
95
106
 
96
- return plot_images(imgs,
97
- batch_idx,
98
- cls,
99
- bboxes=boxes,
100
- masks=masks,
101
- kpts=kpts,
102
- max_subplots=len(images),
103
- save=False,
104
- threaded=False)
107
+ return plot_images(
108
+ imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
109
+ )
110
+
111
+
112
+ def prompt_sql_query(query):
113
+ """Plots images with optional labels from a similar data set."""
114
+ check_requirements("openai>=1.6.1")
115
+ from openai import OpenAI
116
+
117
+ if not SETTINGS["openai_api_key"]:
118
+ logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
119
+ openai_api_key = getpass.getpass("OpenAI API key: ")
120
+ SETTINGS.update({"openai_api_key": openai_api_key})
121
+ openai = OpenAI(api_key=SETTINGS["openai_api_key"])
122
+
123
+ messages = [
124
+ {
125
+ "role": "system",
126
+ "content": """
127
+ You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
128
+ the following schema and a user request. You only need to output the format with fixed selection
129
+ statement that selects everything from "'table'", like `SELECT * from 'table'`
130
+
131
+ Schema:
132
+ im_file: string not null
133
+ labels: list<item: string> not null
134
+ child 0, item: string
135
+ cls: list<item: int64> not null
136
+ child 0, item: int64
137
+ bboxes: list<item: list<item: double>> not null
138
+ child 0, item: list<item: double>
139
+ child 0, item: double
140
+ masks: list<item: list<item: list<item: int64>>> not null
141
+ child 0, item: list<item: list<item: int64>>
142
+ child 0, item: list<item: int64>
143
+ child 0, item: int64
144
+ keypoints: list<item: list<item: list<item: double>>> not null
145
+ child 0, item: list<item: list<item: double>>
146
+ child 0, item: list<item: double>
147
+ child 0, item: double
148
+ vector: fixed_size_list<item: float>[256] not null
149
+ child 0, item: float
150
+
151
+ Some details about the schema:
152
+ - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
153
+ in each image
154
+ - the "cls" column contains the integer values on these classes that map them the labels
155
+
156
+ Example of a correct query:
157
+ request - Get all data points that contain 2 or more people and at least one dog
158
+ correct query-
159
+ SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
160
+ """,
161
+ },
162
+ {"role": "user", "content": f"{query}"},
163
+ ]
164
+
165
+ response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
166
+ return response.choices[0].message.content