supervisely 6.73.457__py3-none-any.whl → 6.73.458__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +24 -1
- supervisely/api/image_api.py +4 -0
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +41 -1
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +45 -11
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/inference/inference.py +364 -73
- supervisely/nn/inference/inference_request.py +3 -2
- supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort_tracker.py +14 -7
- supervisely/nn/tracker/visualize.py +70 -72
- supervisely/video/video.py +15 -1
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/RECORD +40 -40
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/predict_app/gui/tags_selector.py

@@ -1,4 +1,5 @@
 from typing import Any, Dict
+
 from supervisely.app.widgets import Button, Card, Container, TagsTable, Text


@@ -45,7 +46,6 @@ class TagsSelector:
 content=self.container,
 lock_message=self.lock_message,
 )
-self.card.lock()
 # -------------------------------- #

 @property
supervisely/nn/inference/predict_app/gui/utils.py

@@ -1,12 +1,22 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple,
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
 from supervisely import logger
 from supervisely.api.api import Api
-from supervisely.api.project_api import ProjectInfo
 from supervisely.api.dataset_api import DatasetInfo
-from supervisely.
-from supervisely.
+from supervisely.api.image_api import ImageInfo
+from supervisely.api.project_api import ProjectInfo
 from supervisely.app import DataJson
-from supervisely.app.widgets import Button, Card, Stepper, Text, Widget
+from supervisely.app.widgets import Button, Card, Progress, Stepper, Text, Widget
+from supervisely.nn.model.prediction import Prediction
+from supervisely.project.project import ProjectType
+from supervisely.project.project_meta import ProjectMeta
+from supervisely.project.video_project import VideoInfo
+from supervisely.video_annotation.frame import Frame
+from supervisely.video_annotation.frame_collection import FrameCollection
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+from supervisely.video_annotation.video_figure import VideoFigure
+from supervisely.video_annotation.video_object import VideoObject
+from supervisely.video_annotation.video_object_collection import VideoObjectCollection

 button_clicked = {}

@@ -81,7 +91,7 @@ def wrap_button_click(
 bid = button.widget_id
 button_clicked[bid] = False

-def button_click(button_clicked_value: Optional[bool] = None):
+def button_click(button_clicked_value: Optional[bool] = None, suppress_actions: bool = False):
 if button_clicked_value is None or button_clicked_value is False:
 if validation_func is not None:
 success = validation_func()
@@ -95,12 +105,12 @@ def wrap_button_click(

 if button_clicked[bid] and upd_params:
 update_custom_button_params(button, reselect_params)
-if on_select_click is not None:
+if not suppress_actions and on_select_click is not None:
 for func in on_select_click:
 func()
 else:
 update_custom_button_params(button, select_params)
-if on_reselect_click is not None:
+if not suppress_actions and on_reselect_click is not None:
 for func in on_reselect_click:
 func()
 validation_text.hide()
@@ -115,7 +125,7 @@ def wrap_button_click(
 disable=button_clicked[bid],
 )
 if callback is not None and not button_clicked[bid]:
-callback(False)
+callback(False, True)

 if collapse_card is not None:
 card, collapse = collapse_card
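The `wrap_button_click` change above threads a new `suppress_actions` flag through the wrapped callback, so a step can be reset programmatically (`callback(False, True)`) without firing its `on_select_click` / `on_reselect_click` hooks. A minimal, self-contained toy of that pattern (not the library code; all names below are placeholders):

```python
from typing import Callable, List, Optional


def wrap_click(on_select: List[Callable]) -> Callable:
    """Toy wrapper mirroring the pattern: the returned callback accepts a
    suppress_actions flag that skips downstream hooks when True."""
    state = {"selected": False}

    def click(value: Optional[bool] = None, suppress_actions: bool = False) -> None:
        # Toggle on a plain click, or force the state when a value is passed.
        state["selected"] = (not state["selected"]) if value is None else value
        if not suppress_actions:
            for hook in on_select:
                hook()

    return click


click = wrap_click(on_select=[lambda: print("hook fired")])
click()             # user click: hooks run
click(False, True)  # programmatic reset, as in callback(False, True): hooks skipped
```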
@@ -154,129 +164,236 @@ def find_parents_in_tree(
 return _dfs(tree, [])


-def
+def _copy_items_to_dataset(
 api: Api,
-
-
-
-dataset_ids: List[int] = [],
+src_dataset_id: int,
+dst_dataset: DatasetInfo,
+project_type: str,
 with_annotations: bool = True,
+progress_cb: Callable = None,
 progress: Progress = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-:type with_annotations: bool
-:param progress: Progress callback
-:type progress: Progress
-:return: Created project
-:rtype: ProjectInfo
-"""
+items_infos: List[Union[ImageInfo, VideoInfo]] = None,
+) -> Union[List[ImageInfo], List[VideoInfo]]:
+if progress is None:
+progress = Progress()
+
+def combined_progress(n):
+progress_cb(n)
+pbar.update(n)
+
+if project_type == ProjectType.IMAGES:
+if items_infos is None:
+items_infos = api.image.get_list(src_dataset_id)
+with progress(
+message=f"Copying items from dataset: {dst_dataset.name}", total=len(items_infos)
+) as pbar:

-
-
-
-
-type=ProjectType.IMAGES,
-change_name_if_conflict=True,
-)
-if with_annotations:
-api.project.merge_metas(src_project_id=project_id, dst_project_id=created_project.id)
-return created_project
-
-def _copy_full_project(
-created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]
-):
-src_dst_ds_id_map: Dict[int, int] = {}
-
-def _create_full_tree(ds_tree: Dict[DatasetInfo, Dict], parent_id: int = None):
-for src_ds, nested_src_ds_tree in ds_tree.items():
-dst_ds = api.dataset.create(
-project_id=created_project.id,
-name=src_ds.name,
-description=src_ds.description,
-change_name_if_conflict=True,
-parent_id=parent_id,
-)
-src_dst_ds_id_map[src_ds.id] = dst_ds
-
-# Preserve dataset custom data
-info_ds = api.dataset.get_info_by_id(src_ds.id)
-if info_ds.custom_data:
-api.dataset.update_custom_data(dst_ds.id, info_ds.custom_data)
-_create_full_tree(nested_src_ds_tree, parent_id=dst_ds.id)
-
-_create_full_tree(src_datasets_tree)
-
-for src_ds_id, dst_ds in src_dst_ds_id_map.items():
-_copy_items(src_ds_id, dst_ds)
-
-def _copy_datasets(created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]):
-created_datasets: Dict[int, DatasetInfo] = {}
-processed_copy: Set[int] = set()
-
-for dataset_id in dataset_ids:
-chain = find_parents_in_tree(src_datasets_tree, dataset_id, with_self=True)
-if not chain:
-logger.warning(
-f"Dataset id {dataset_id} not found in project {project_id}. Skipping."
-)
-continue
+if progress_cb:
+_progress_cb = combined_progress
+else:
+_progress_cb = pbar.update

-
-
-
-
-
+progress.show()
+copied = api.image.copy_batch_optimized(
+src_dataset_id=src_dataset_id,
+src_image_infos=items_infos,
+dst_dataset_id=dst_dataset.id,
+with_annotations=with_annotations,
+progress_cb=_progress_cb,
+)
+progress.hide()
+elif project_type == ProjectType.VIDEOS:
+if items_infos is None:
+items_infos = api.video.get_list(src_dataset_id)

-created_ds = api.dataset.create(
-created_project.id,
-ds_info.name,
-description=ds_info.description,
-change_name_if_conflict=False,
-parent_id=parent_created_id,
-)
-created_datasets[ds_info.id] = created_ds
-src_info = api.dataset.get_info_by_id(ds_info.id)
-if src_info.custom_data:
-api.dataset.update_custom_data(created_ds.id, src_info.custom_data)
-parent_created_id = created_ds.id
-
-if dataset_id not in processed_copy:
-_copy_items(dataset_id, created_datasets[dataset_id])
-processed_copy.add(dataset_id)
-
-def _copy_items(src_ds_id: int, dst_ds: DatasetInfo):
-input_img_infos = api.image.get_list(src_ds_id)
 with progress(
-message=f"Copying items from dataset: {
+message=f"Copying items from dataset: {dst_dataset.name}", total=len(items_infos)
 ) as pbar:
+if progress_cb:
+_progress_cb = combined_progress
+else:
+_progress_cb = pbar.update
 progress.show()
-api.
-
-
-dst_dataset_id=dst_ds.id,
+copied = api.video.copy_batch(
+dst_dataset_id=dst_dataset.id,
+ids=[info.id for info in items_infos],
 with_annotations=with_annotations,
-progress_cb=
+progress_cb=_progress_cb,
 )
 progress.hide()
+else:
+raise NotImplementedError(f"Copy not implemented for project type {project_type}")
+return copied

-created_project = _create_project()
-src_datasets_tree = api.dataset.get_tree(project_id)

-
-
+def get_items_infos(
+api: Api, items_ids: List[int], project_type: str
+) -> List[Union[ImageInfo, VideoInfo]]:
+if project_type == ProjectType.IMAGES:
+items_infos: List[ImageInfo] = api.image.get_info_by_id_batch(items_ids)
+elif project_type == ProjectType.VIDEOS:
+items_infos: List[VideoInfo] = api.video.get_info_by_id_batch(items_ids)
 else:
-
+raise NotImplementedError(f"Items of type {project_type} are not supported")
+return items_infos
+
+
+def copy_items_to_project(
+api: Api,
+src_project_id: int,
+items: Union[List[ImageInfo], List[VideoInfo]],
+dst_project_id: int,
+with_annotations: bool = True,
+progress_cb: Progress = None,
+ds_progress: Progress = None,
+project_type: str = None,
+src_datasets_tree: Dict[DatasetInfo, Dict] = None,
+) -> Union[List[ImageInfo], List[VideoInfo]]:
+if project_type is None:
+dst_project_info = api.project.get_info_by_id(src_project_id)
+project_type = dst_project_info.type
+if len(items) == 0:
+return []
+if len(set(info.project_id for info in items)) != 1:
+raise ValueError("Items must belong to the same project")
+
+items_by_dataset: Dict[int, List[Union[ImageInfo, VideoInfo]]] = {}
+for item_info in items:
+items_by_dataset.setdefault(item_info.dataset_id, []).append(item_info)
+
+if src_datasets_tree is None:
+src_datasets_tree = api.dataset.get_tree(src_project_id)
+
+created_datasets: Dict[int, DatasetInfo] = {}
+processed_copy: Set[int] = set()
+
+copied_items = {}
+for dataset_id, items_infos in items_by_dataset.items():
+chain = find_parents_in_tree(src_datasets_tree, dataset_id, with_self=True)
+if not chain:
+logger.warning(f"Dataset id {dataset_id} not found in project. Skipping")
+continue
+
+parent_created_id = None
+for ds_info in chain:
+if ds_info.id in created_datasets:
+parent_created_id = created_datasets[ds_info.id].id
+continue
+
+created_ds = api.dataset.create(
+dst_project_id,
+ds_info.name,
+description=ds_info.description,
+change_name_if_conflict=False,
+parent_id=parent_created_id,
+)
+if ds_info.custom_data:
+created_ds = api.dataset.update_custom_data(created_ds.id, ds_info.custom_data)
+created_datasets[ds_info.id] = created_ds
+parent_created_id = created_ds.id
+
+if dataset_id not in processed_copy:
+copied_ds_items = _copy_items_to_dataset(
+api=api,
+src_dataset_id=dataset_id,
+dst_dataset=created_datasets[dataset_id],
+project_type=project_type,
+with_annotations=with_annotations,
+progress_cb=progress_cb,
+progress=ds_progress,
+items_infos=items_infos,
+)
+for src_info, dst_info in zip(items_infos, copied_ds_items):
+copied_items[src_info.id] = dst_info
+processed_copy.add(dataset_id)
+return [copied_items[item.id] for item in items]
+
+
+def create_project(
+api: Api,
+project_id: int,
+project_name: str,
+workspace_id: int,
+copy_meta: bool = False,
+project_type: str = None,
+) -> ProjectInfo:
+if project_type is None:
+project_info = api.project.get_info_by_id(project_id)
+project_type = project_info.type
+created_project = api.project.create(
+workspace_id,
+project_name,
+type=project_type,
+change_name_if_conflict=True,
+)
+if copy_meta:
+api.project.merge_metas(src_project_id=project_id, dst_project_id=created_project.id)
 return created_project
+
+
+def copy_project(
+api: Api,
+project_id: int,
+workspace_id: int,
+project_name: str,
+items_ids: List[int] = None,
+with_annotations: bool = True,
+progress: Progress = None,
+) -> ProjectInfo:
+dst_project = create_project(
+api, project_id, project_name, workspace_id=workspace_id, copy_meta=True
+)
+items = []
+if items_ids is None:
+project_type = dst_project.type
+datasets = api.dataset.get_list(project_id, recursive=True)
+if project_type == ProjectType.IMAGES:
+get_items_f = api.image.get_list
+elif project_type == ProjectType.VIDEOS:
+get_items_f = api.video.get_list
+else:
+raise NotImplementedError(f"Project type {project_type} is not supported")
+for ds in datasets:
+ds_items = get_items_f(dataset_id=ds.id)
+if ds_items:
+items.extend(ds_items)
+else:
+items = get_items_infos(api, items_ids, dst_project.type)
+copy_items_to_project(
+api=api,
+src_project_id=project_id,
+items=items,
+dst_project_id=dst_project.id,
+with_annotations=with_annotations,
+ds_progress=progress,
+project_type=dst_project.type,
+)
+return dst_project
+
+
+def video_annotation_from_predictions(
+predictions: List[Prediction], project_meta: ProjectMeta, frame_size: Tuple[int, int]
+) -> VideoAnnotation:
+objects = {}
+frames = []
+for i, prediction in enumerate(predictions):
+figures = []
+for label in prediction.annotation.labels:
+obj_name = label.obj_class.name
+if not obj_name in objects:
+obj_class = project_meta.get_obj_class(obj_name)
+if obj_class is None:
+continue
+objects[obj_name] = VideoObject(obj_class)
+
+vid_object = objects[obj_name]
+if vid_object:
+figures.append(VideoFigure(vid_object, label.geometry, frame_index=i))
+frame = Frame(i, figures=figures)
+frames.append(frame)
+return VideoAnnotation(
+img_size=frame_size,
+frames_count=len(frames),
+objects=VideoObjectCollection(list(objects.values())),
+frames=FrameCollection(frames),
+)
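The hunk above replaces the old monolithic copy routine with composable helpers: `_copy_items_to_dataset`, `get_items_infos`, `copy_items_to_project`, `create_project`, `copy_project`, and `video_annotation_from_predictions`. A usage sketch of the two copy helpers, assuming a configured API connection and placeholder IDs; note these are internal Predict App GUI utilities, imported here only to illustrate the new signatures:

```python
import supervisely as sly
from supervisely.nn.inference.predict_app.gui.utils import (
    copy_items_to_project,
    copy_project,
)

api = sly.Api.from_env()  # expects SERVER_ADDRESS and API_TOKEN in the environment

# Copy a whole project (nested dataset tree preserved) into another workspace.
dst_project = copy_project(
    api,
    project_id=123,             # placeholder: source project id
    workspace_id=7,             # placeholder: destination workspace id
    project_name="copy-of-123",
    with_annotations=True,
)

# Or copy only selected items; destination datasets mirror the source tree.
items = api.image.get_info_by_id_batch([1001, 1002, 1003])  # placeholder image ids
copied = copy_items_to_project(
    api,
    src_project_id=123,
    items=items,
    dst_project_id=dst_project.id,
)
```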
supervisely/nn/inference/predict_app/predict_app.py

@@ -1,4 +1,3 @@
-import os
 from typing import Dict, List, Optional

 from fastapi import BackgroundTasks, Request
@@ -24,7 +23,8 @@ class PredictApp:
 @self.gui.output_selector.start_button.click
 def start_prediction():
 if self.gui.output_selector.validate_step():
-
+widgets_to_disable = self.gui.output_selector.widgets_to_disable + [self.gui.settings_selector.preview.run_button]
+disable_enable(widgets_to_disable, True)
 self.gui.run()
 self.shutdown_serving_app()
 self.shutdown_predict_app()
supervisely/nn/model/model_api.py

@@ -72,6 +72,15 @@ class ModelAPI:
 else:
 return self._post("get_custom_inference_settings", {})["settings"]

+def get_tracking_settings(self):
+# @TODO: botsort hardcoded
+# Add dropdown selector for tracking algorithms later
+if self.task_id is not None:
+return self.api.task.send_request(self.task_id, "get_tracking_settings", {})["botsort"]
+else:
+return self._post("get_tracking_settings", {})["botsort"]
+
+
 def get_model_meta(self):
 if self.task_id is not None:
 return ProjectMeta.from_json(
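`ModelAPI` gains a `get_tracking_settings()` call that returns the server-side tracking defaults (currently the hard-coded `"botsort"` section, per the TODO above). A sketch, assuming your SDK version exposes `api.nn.connect(task_id)` for attaching to an already deployed serving task; otherwise obtain the `ModelAPI` handle from whatever deployment flow you use:

```python
import supervisely as sly

api = sly.Api.from_env()

# Assumption: task 12345 is a running model-serving task and api.nn.connect
# returns a ModelAPI handle for it.
model = api.nn.connect(12345)

tracking_settings = model.get_tracking_settings()  # the "botsort" settings dict
print(tracking_settings)
```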
supervisely/nn/tracker/base_tracker.py

@@ -36,9 +36,19 @@ class BaseTracker:
 def video_annotation(self) -> VideoAnnotation:
 """Return the accumulated VideoAnnotation."""
 raise NotImplementedError("This method should be overridden by subclasses.")
+
+@classmethod
+def get_default_params(cls) -> Dict[str, Any]:
+"""
+Get default configurable parameters for this tracker.
+Must be implemented in subclass.
+"""
+raise NotImplementedError(
+f"Method get_default_params() must be implemented in {cls.__name__}"
+)

 def _validate_device(self) -> None:
 if self.device != 'cpu' and not self.device.startswith('cuda'):
 raise ValueError(
 f"Invalid device '{self.device}'. Supported devices are 'cpu' or 'cuda'."
-)
+)
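`BaseTracker` now declares `get_default_params()` as a classmethod that every tracker must implement, so callers can read a tracker's default configuration without instantiating it (and without triggering device validation or model loading). A sketch of a hypothetical subclass; the parameter names and values are placeholders:

```python
from typing import Any, Dict

from supervisely.nn.tracker.base_tracker import BaseTracker


class MyTracker(BaseTracker):
    """Hypothetical tracker, shown only to illustrate the new contract."""

    @classmethod
    def get_default_params(cls) -> Dict[str, Any]:
        # Defaults are available without building the tracker.
        return {"match_thresh": 0.8, "track_buffer": 30}  # placeholder values


print(MyTracker.get_default_params())  # {'match_thresh': 0.8, 'track_buffer': 30}
```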
supervisely/nn/tracker/botsort_tracker.py

@@ -54,6 +54,7 @@ class BotSortTracker(BaseTracker):
 self.settings.update(settings)

 args = SimpleNamespace(**self.settings)
+args.name = "BotSORT"
 args.device = self.device

 self.tracker = BoTSORT(args=args)
@@ -66,12 +67,8 @@ class BotSortTracker(BaseTracker):
 self.frame_shape = ()

 def _load_default_settings(self) -> dict:
-"""
-
-config_path = current_dir / "botsort/botsort_config.yaml"
-
-with open(config_path, 'r', encoding='utf-8') as file:
-return yaml.safe_load(file)
+"""Internal method: calls classmethod"""
+return self.get_default_params()

 def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
 """Update tracker and return list of matches for current frame."""
@@ -263,4 +260,14 @@ class BotSortTracker(BaseTracker):
 )
 raise ValueError(error_msg)

-return self._create_video_annotation()
+return self._create_video_annotation()
+
+@classmethod
+def get_default_params(cls) -> Dict[str, Any]:
+"""Public API: get default params WITHOUT creating instance."""
+current_dir = Path(__file__).parent
+config_path = current_dir / "botsort/botsort_config.yaml"
+
+with open(config_path, 'r', encoding='utf-8') as file:
+return yaml.safe_load(file)
+
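With the config loading moved into the `get_default_params()` classmethod, BoT-SORT defaults can now be read straight from the bundled `botsort_config.yaml` without constructing a tracker. A sketch (requires the tracking dependencies to be installed; the constructor call is commented out because its exact arguments are assumed rather than shown in this diff):

```python
from supervisely.nn.tracker.botsort_tracker import BotSortTracker

# No instance needed: defaults come from botsort/botsort_config.yaml in the package.
defaults = BotSortTracker.get_default_params()
print(defaults)

# Assumed constructor usage for overriding a default; verify against BotSortTracker.__init__:
# tracker = BotSortTracker(settings={**defaults, "track_buffer": 60}, device="cpu")
```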