ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/data/explorer/gui/dash.py +123 -89
@@ -1,3 +1,5 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  import time
2
4
  from threading import Thread
3
5
 
@@ -7,95 +9,114 @@ from ultralytics import Explorer
7
9
  from ultralytics.utils import ROOT, SETTINGS
8
10
  from ultralytics.utils.checks import check_requirements
9
11
 
10
- check_requirements('streamlit>=1.29.0')
11
- check_requirements('streamlit-select>=0.2')
12
+ check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
13
+
12
14
  import streamlit as st
13
15
  from streamlit_select import image_select
14
16
 
15
17
 
16
18
  def _get_explorer():
17
- exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
18
- thread = Thread(target=exp.create_embeddings_table,
19
- kwargs={'force': st.session_state.get('force_recreate_embeddings')})
19
+ """Initializes and returns an instance of the Explorer class."""
20
+ exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
21
+ thread = Thread(
22
+ target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
23
+ )
20
24
  thread.start()
21
- progress_bar = st.progress(0, text='Creating embeddings table...')
25
+ progress_bar = st.progress(0, text="Creating embeddings table...")
22
26
  while exp.progress < 1:
23
27
  time.sleep(0.1)
24
- progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
28
+ progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
25
29
  thread.join()
26
- st.session_state['explorer'] = exp
30
+ st.session_state["explorer"] = exp
27
31
  progress_bar.empty()
28
32
 
29
33
 
30
34
  def init_explorer_form():
31
- datasets = ROOT / 'cfg' / 'datasets'
32
- ds = [d.name for d in datasets.glob('*.yaml')]
35
+ """Initializes an Explorer instance and creates embeddings table with progress tracking."""
36
+ datasets = ROOT / "cfg" / "datasets"
37
+ ds = [d.name for d in datasets.glob("*.yaml")]
33
38
  models = [
34
- 'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
35
- 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
36
- 'yolov8l-pose.pt', 'yolov8x-pose.pt']
37
- with st.form(key='explorer_init_form'):
39
+ "yolov8n.pt",
40
+ "yolov8s.pt",
41
+ "yolov8m.pt",
42
+ "yolov8l.pt",
43
+ "yolov8x.pt",
44
+ "yolov8n-seg.pt",
45
+ "yolov8s-seg.pt",
46
+ "yolov8m-seg.pt",
47
+ "yolov8l-seg.pt",
48
+ "yolov8x-seg.pt",
49
+ "yolov8n-pose.pt",
50
+ "yolov8s-pose.pt",
51
+ "yolov8m-pose.pt",
52
+ "yolov8l-pose.pt",
53
+ "yolov8x-pose.pt",
54
+ ]
55
+ with st.form(key="explorer_init_form"):
38
56
  col1, col2 = st.columns(2)
39
57
  with col1:
40
- st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
58
+ st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
41
59
  with col2:
42
- st.selectbox('Select model', models, key='model')
43
- st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
60
+ st.selectbox("Select model", models, key="model")
61
+ st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
44
62
 
45
- st.form_submit_button('Explore', on_click=_get_explorer)
63
+ st.form_submit_button("Explore", on_click=_get_explorer)
46
64
 
47
65
 
48
66
  def query_form():
49
- with st.form('query_form'):
67
+ """Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
68
+ with st.form("query_form"):
50
69
  col1, col2 = st.columns([0.8, 0.2])
51
70
  with col1:
52
- st.text_input('Query',
53
- "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
54
- label_visibility='collapsed',
55
- key='query')
71
+ st.text_input(
72
+ "Query",
73
+ "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
74
+ label_visibility="collapsed",
75
+ key="query",
76
+ )
56
77
  with col2:
57
- st.form_submit_button('Query', on_click=run_sql_query)
78
+ st.form_submit_button("Query", on_click=run_sql_query)
58
79
 
59
80
 
60
81
  def ai_query_form():
61
- with st.form('ai_query_form'):
82
+ """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
83
+ with st.form("ai_query_form"):
62
84
  col1, col2 = st.columns([0.8, 0.2])
63
85
  with col1:
64
- st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
86
+ st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
65
87
  with col2:
66
- st.form_submit_button('Ask AI', on_click=run_ai_query)
88
+ st.form_submit_button("Ask AI", on_click=run_ai_query)
67
89
 
68
90
 
69
91
  def find_similar_imgs(imgs):
70
- exp = st.session_state['explorer']
71
- similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
72
- paths = similar.to_pydict()['im_file']
73
- st.session_state['imgs'] = paths
92
+ """Initializes a Streamlit form for AI-based image querying with custom input."""
93
+ exp = st.session_state["explorer"]
94
+ similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
95
+ paths = similar.to_pydict()["im_file"]
96
+ st.session_state["imgs"] = paths
74
97
 
75
98
 
76
99
  def similarity_form(selected_imgs):
77
- st.write('Similarity Search')
78
- with st.form('similarity_form'):
100
+ """Initializes a form for AI-based image querying with custom input in Streamlit."""
101
+ st.write("Similarity Search")
102
+ with st.form("similarity_form"):
79
103
  subcol1, subcol2 = st.columns([1, 1])
80
104
  with subcol1:
81
- st.number_input('limit',
82
- min_value=None,
83
- max_value=None,
84
- value=25,
85
- label_visibility='collapsed',
86
- key='limit')
105
+ st.number_input(
106
+ "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
107
+ )
87
108
 
88
109
  with subcol2:
89
110
  disabled = not len(selected_imgs)
90
- st.write('Selected: ', len(selected_imgs))
111
+ st.write("Selected: ", len(selected_imgs))
91
112
  st.form_submit_button(
92
- 'Search',
113
+ "Search",
93
114
  disabled=disabled,
94
115
  on_click=find_similar_imgs,
95
- args=(selected_imgs, ),
116
+ args=(selected_imgs,),
96
117
  )
97
118
  if disabled:
98
- st.error('Select at least one image to search.')
119
+ st.error("Select at least one image to search.")
99
120
 
100
121
 
101
122
  # def persist_reset_form():
@@ -109,96 +130,109 @@ def similarity_form(selected_imgs):
109
130
 
110
131
 
111
132
  def run_sql_query():
112
- st.session_state['error'] = None
113
- query = st.session_state.get('query')
133
+ """Executes an SQL query and returns the results."""
134
+ st.session_state["error"] = None
135
+ query = st.session_state.get("query")
114
136
  if query.rstrip().lstrip():
115
- exp = st.session_state['explorer']
116
- res = exp.sql_query(query, return_type='arrow')
117
- st.session_state['imgs'] = res.to_pydict()['im_file']
137
+ exp = st.session_state["explorer"]
138
+ res = exp.sql_query(query, return_type="arrow")
139
+ st.session_state["imgs"] = res.to_pydict()["im_file"]
118
140
 
119
141
 
120
142
  def run_ai_query():
121
- if not SETTINGS['openai_api_key']:
143
+ """Execute SQL query and update session state with query results."""
144
+ if not SETTINGS["openai_api_key"]:
122
145
  st.session_state[
123
- 'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
146
+ "error"
147
+ ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
124
148
  return
125
- st.session_state['error'] = None
126
- query = st.session_state.get('ai_query')
149
+ st.session_state["error"] = None
150
+ query = st.session_state.get("ai_query")
127
151
  if query.rstrip().lstrip():
128
- exp = st.session_state['explorer']
152
+ exp = st.session_state["explorer"]
129
153
  res = exp.ask_ai(query)
130
154
  if not isinstance(res, pd.DataFrame) or res.empty:
131
- st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
155
+ st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
132
156
  return
133
- st.session_state['imgs'] = res['im_file'].to_list()
157
+ st.session_state["imgs"] = res["im_file"].to_list()
134
158
 
135
159
 
136
160
  def reset_explorer():
137
- st.session_state['explorer'] = None
138
- st.session_state['imgs'] = None
139
- st.session_state['error'] = None
161
+ """Resets the explorer to its initial state by clearing session variables."""
162
+ st.session_state["explorer"] = None
163
+ st.session_state["imgs"] = None
164
+ st.session_state["error"] = None
140
165
 
141
166
 
142
167
  def utralytics_explorer_docs_callback():
168
+ """Resets the explorer to its initial state by clearing session variables."""
143
169
  with st.container(border=True):
144
- st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
145
- width=100)
170
+ st.image(
171
+ "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
172
+ width=100,
173
+ )
146
174
  st.markdown(
147
175
  "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
148
176
  unsafe_allow_html=True,
149
- help=None)
150
- st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
177
+ help=None,
178
+ )
179
+ st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
151
180
 
152
181
 
153
182
  def layout():
154
- st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
183
+ """Resets explorer session variables and provides documentation with a link to API docs."""
184
+ st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
155
185
  st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
156
186
 
157
- if st.session_state.get('explorer') is None:
187
+ if st.session_state.get("explorer") is None:
158
188
  init_explorer_form()
159
189
  return
160
190
 
161
- st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
162
- exp = st.session_state.get('explorer')
163
- col1, col2 = st.columns([0.75, 0.25], gap='small')
191
+ st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
192
+ exp = st.session_state.get("explorer")
193
+ col1, col2 = st.columns([0.75, 0.25], gap="small")
164
194
  imgs = []
165
- if st.session_state.get('error'):
166
- st.error(st.session_state['error'])
195
+ if st.session_state.get("error"):
196
+ st.error(st.session_state["error"])
167
197
  else:
168
- imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
198
+ imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
169
199
  total_imgs, selected_imgs = len(imgs), []
170
200
  with col1:
171
201
  subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
172
202
  with subcol1:
173
- st.write('Max Images Displayed:')
203
+ st.write("Max Images Displayed:")
174
204
  with subcol2:
175
- num = st.number_input('Max Images Displayed',
176
- min_value=0,
177
- max_value=total_imgs,
178
- value=min(500, total_imgs),
179
- key='num_imgs_displayed',
180
- label_visibility='collapsed')
205
+ num = st.number_input(
206
+ "Max Images Displayed",
207
+ min_value=0,
208
+ max_value=total_imgs,
209
+ value=min(500, total_imgs),
210
+ key="num_imgs_displayed",
211
+ label_visibility="collapsed",
212
+ )
181
213
  with subcol3:
182
- st.write('Start Index:')
214
+ st.write("Start Index:")
183
215
  with subcol4:
184
- start_idx = st.number_input('Start Index',
185
- min_value=0,
186
- max_value=total_imgs,
187
- value=0,
188
- key='start_index',
189
- label_visibility='collapsed')
216
+ start_idx = st.number_input(
217
+ "Start Index",
218
+ min_value=0,
219
+ max_value=total_imgs,
220
+ value=0,
221
+ key="start_index",
222
+ label_visibility="collapsed",
223
+ )
190
224
  with subcol5:
191
- reset = st.button('Reset', use_container_width=False, key='reset')
225
+ reset = st.button("Reset", use_container_width=False, key="reset")
192
226
  if reset:
193
- st.session_state['imgs'] = None
227
+ st.session_state["imgs"] = None
194
228
  st.experimental_rerun()
195
229
 
196
230
  query_form()
197
231
  ai_query_form()
198
232
  if total_imgs:
199
- imgs_displayed = imgs[start_idx:start_idx + num]
233
+ imgs_displayed = imgs[start_idx : start_idx + num]
200
234
  selected_imgs = image_select(
201
- f'Total samples: {total_imgs}',
235
+ f"Total samples: {total_imgs}",
202
236
  images=imgs_displayed,
203
237
  use_container_width=False,
204
238
  # indices=[i for i in range(num)] if select_all else None,
@@ -210,5 +244,5 @@ def layout():
210
244
  utralytics_explorer_docs_callback()
211
245
 
212
246
 
213
- if __name__ == '__main__':
247
+ if __name__ == "__main__":
214
248
  layout()
ultralytics/data/explorer/utils.py +37 -39
@@ -1,3 +1,5 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  import getpass
2
4
  from typing import List
3
5
 
@@ -14,6 +16,7 @@ from ultralytics.utils.plotting import plot_images
14
16
 
15
17
 
16
18
  def get_table_schema(vector_size):
19
+ """Extracts and returns the schema of a database table."""
17
20
  from lancedb.pydantic import LanceModel, Vector
18
21
 
19
22
  class Schema(LanceModel):
@@ -29,6 +32,7 @@ def get_table_schema(vector_size):
29
32
 
30
33
 
31
34
  def get_sim_index_schema():
35
+ """Returns a LanceModel schema for a database table with specified vector size."""
32
36
  from lancedb.pydantic import LanceModel
33
37
 
34
38
  class Schema(LanceModel):
@@ -41,14 +45,14 @@ def get_sim_index_schema():
41
45
 
42
46
 
43
47
  def sanitize_batch(batch, dataset_info):
44
- batch['cls'] = batch['cls'].flatten().int().tolist()
45
- box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
46
- batch['bboxes'] = [box for box, _ in box_cls_pair]
47
- batch['cls'] = [cls for _, cls in box_cls_pair]
48
- batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
49
- batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
50
- batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
51
-
48
+ """Sanitizes input batch for inference, ensuring correct format and dimensions."""
49
+ batch["cls"] = batch["cls"].flatten().int().tolist()
50
+ box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
51
+ batch["bboxes"] = [box for box, _ in box_cls_pair]
52
+ batch["cls"] = [cls for _, cls in box_cls_pair]
53
+ batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
54
+ batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
55
+ batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
52
56
  return batch
53
57
 
54
58
 
@@ -60,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
60
64
  similar_set (list): Pyarrow or pandas object containing the similar data points
61
65
  plot_labels (bool): Whether to plot labels or not
62
66
  """
63
- similar_set = similar_set.to_dict(
64
- orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
67
+ similar_set = (
68
+ similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
69
+ )
65
70
  empty_masks = [[[]]]
66
71
  empty_boxes = [[]]
67
- images = similar_set.get('im_file', [])
68
- bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
69
- masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
70
- kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
71
- cls = similar_set.get('cls', [])
72
+ images = similar_set.get("im_file", [])
73
+ bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
74
+ masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
75
+ kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
76
+ cls = similar_set.get("cls", [])
72
77
 
73
78
  plot_size = 640
74
79
  imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@@ -99,33 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
99
104
  batch_idx = np.concatenate(batch_idx, axis=0)
100
105
  cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
101
106
 
102
- return plot_images(imgs,
103
- batch_idx,
104
- cls,
105
- bboxes=boxes,
106
- masks=masks,
107
- kpts=kpts,
108
- max_subplots=len(images),
109
- save=False,
110
- threaded=False)
107
+ return plot_images(
108
+ imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
109
+ )
111
110
 
112
111
 
113
112
  def prompt_sql_query(query):
114
- check_requirements('openai>=1.6.1')
113
+ """Plots images with optional labels from a similar data set."""
114
+ check_requirements("openai>=1.6.1")
115
115
  from openai import OpenAI
116
116
 
117
- if not SETTINGS['openai_api_key']:
118
- logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
119
- openai_api_key = getpass.getpass('OpenAI API key: ')
120
- SETTINGS.update({'openai_api_key': openai_api_key})
121
- openai = OpenAI(api_key=SETTINGS['openai_api_key'])
117
+ if not SETTINGS["openai_api_key"]:
118
+ logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
119
+ openai_api_key = getpass.getpass("OpenAI API key: ")
120
+ SETTINGS.update({"openai_api_key": openai_api_key})
121
+ openai = OpenAI(api_key=SETTINGS["openai_api_key"])
122
122
 
123
123
  messages = [
124
124
  {
125
- 'role':
126
- 'system',
127
- 'content':
128
- '''
125
+ "role": "system",
126
+ "content": """
129
127
  You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
130
128
  the following schema and a user request. You only need to output the format with fixed selection
131
129
  statement that selects everything from "'table'", like `SELECT * from 'table'`
@@ -159,10 +157,10 @@ def prompt_sql_query(query):
159
157
  request - Get all data points that contain 2 or more people and at least one dog
160
158
  correct query-
161
159
  SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
162
- '''},
163
- {
164
- 'role': 'user',
165
- 'content': f'{query}'}, ]
160
+ """,
161
+ },
162
+ {"role": "user", "content": f"{query}"},
163
+ ]
166
164
 
167
- response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
165
+ response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
168
166
  return response.choices[0].message.content