supervisely-6.73.420-py3-none-any.whl → supervisely-6.73.421-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/RECORD +74 -56
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/supervisely/nn/inference/predict_app/gui/tags_selector.py
@@ -0,0 +1,110 @@
+from typing import Any, Dict
+from supervisely.app.widgets import Button, Card, Container, TagsTable, Text
+
+
+class TagsSelector:
+    title = "Tags Selector"
+    description = "Select tags that will be used for inference"
+    lock_message = "Select previous step to unlock"
+
+    def __init__(self):
+        # Init Step
+        self.display_widgets = []
+        # -------------------------------- #
+
+        # Init Base Widgets
+        self.validator_text = None
+        self.button = None
+        self.container = None
+        self.card = None
+        # -------------------------------- #
+
+        # Init Step Widgets
+        self.tags_table = None
+        # -------------------------------- #
+
+        # Tags
+        self.tags_table = TagsTable()
+        self.tags_table.hide()
+        # Add widgets to display ------------ #
+        self.display_widgets.extend([self.tags_table])
+        # ----------------------------------- #
+
+        # Base Widgets
+        self.validator_text = Text("")
+        self.validator_text.hide()
+        self.button = Button("Select")
+        self.display_widgets.extend([self.validator_text, self.button])
+        # -------------------------------- #
+
+        # Card Layout
+        self.container = Container(self.display_widgets)
+        self.card = Card(
+            title=self.title,
+            description=self.description,
+            content=self.container,
+            lock_message=self.lock_message,
+        )
+        self.card.lock()
+        # -------------------------------- #
+
+    @property
+    def widgets_to_disable(self) -> list:
+        return [self.tags_table]
+
+    def load_from_json(self, data: Dict[str, Any]) -> None:
+        if "tags" in data:
+            self.set_tags(data["tags"])
+
+    def get_selected_tags(self) -> list:
+        return self.tags_table.get_selected_tags()
+
+    def set_tags(self, tags) -> None:
+        self.tags_table.select_tags(tags)
+
+    def select_all_tags(self) -> None:
+        self.tags_table.select_all()
+
+    def get_settings(self) -> Dict[str, Any]:
+        return {"tags": self.get_selected_tags()}
+
+    def validate_step(self) -> bool:
+        if self.tags_table.is_hidden():
+            return True
+
+        self.validator_text.hide()
+
+        project_tags = self.tags_table.project_meta.tag_metas
+        if len(project_tags) == 0:
+            self.validator_text.set(text="Project has no tags", status="error")
+            self.validator_text.show()
+            return False
+
+        selected_tags = self.tags_table.get_selected_tags()
+        table_data = self.tags_table._table_data
+        empty_tags = [
+            row[0]["data"]
+            for row in table_data
+            if row[0]["data"] in selected_tags and row[2]["data"] == 0 and row[3]["data"] == 0
+        ]
+
+        n_tags = len(selected_tags)
+        if n_tags == 0:
+            message = "Please select at least one tag"
+            status = "error"
+        else:
+            tag_text = "tag" if n_tags == 1 else "tags"
+            message = f"Selected {n_tags} {tag_text}"
+            status = "success"
+            if empty_tags:
+                intersections = set(selected_tags).intersection(empty_tags)
+                if intersections:
+                    tag_text = "tag" if len(intersections) == 1 else "tags"
+                    message += (
+                        f". Selected {tag_text} have no annotations: {', '.join(intersections)}"
+                    )
+                    status = "warning"
+
+        self.validator_text.set(text=message, status=status)
+        self.validator_text.show()
+        return n_tags > 0
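The TagsSelector step above exposes a small programmatic surface: `load_from_json` restores a saved selection, `get_settings` returns it as `{"tags": [...]}`, and `validate_step` drives the validator text. A minimal usage sketch, assuming the step is wired into a running Supervisely app; the tag names and the click handler are illustrative, and only the methods shown in the diff come from the source.

```python
from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector

tags_step = TagsSelector()

# Restore a previously saved selection (hypothetical tag names).
tags_step.load_from_json({"tags": ["defect", "scratch"]})

@tags_step.button.click
def on_select():
    # validate_step() fills validator_text with an error/warning/success
    # message and returns True only when at least one tag is selected.
    if tags_step.validate_step():
        print("Step settings:", tags_step.get_settings())  # {"tags": [...]}
```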
--- /dev/null
+++ b/supervisely/nn/inference/predict_app/gui/utils.py
@@ -0,0 +1,282 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Set
+from supervisely import logger
+from supervisely.api.api import Api
+from supervisely.api.project_api import ProjectInfo
+from supervisely.api.dataset_api import DatasetInfo
+from supervisely.project.project import ProjectType
+from supervisely.app.widgets import Progress
+from supervisely.app import DataJson
+from supervisely.app.widgets import Button, Card, Stepper, Text, Widget
+
+button_clicked = {}
+
+
+def update_custom_params(
+    button: Button,
+    params_dct: Dict[str, Any],
+) -> None:
+    button_state = button.get_json_data()
+    for key in params_dct.keys():
+        if key not in button_state:
+            raise AttributeError(f"Parameter {key} doesn't exists.")
+        else:
+            DataJson()[button.widget_id][key] = params_dct[key]
+    DataJson().send_changes()
+
+
+def update_custom_button_params(
+    button: Button,
+    params_dct: Dict[str, Any],
+) -> None:
+    params = params_dct.copy()
+    if "icon" in params and params["icon"] is not None:
+        new_icon = f'<i class="{params["icon"]}" style="margin-right: {button._icon_gap}px"></i>'
+        params["icon"] = new_icon
+    update_custom_params(button, params)
+
+
+def disable_enable(widgets: List[Widget], disable: bool = True):
+    for w in widgets:
+        if disable:
+            w.disable()
+        else:
+            w.enable()
+
+
+def unlock_lock(cards: List[Card], unlock: bool = True, message: str = None):
+    for w in cards:
+        if unlock:
+            w.unlock()
+            # w.uncollapse()
+        else:
+            w.lock(message)
+            # w.collapse()
+
+
+def collapse_uncollapse(cards: List[Card], collapse: bool = True):
+    for w in cards:
+        if collapse:
+            w.collapse()
+        else:
+            w.uncollapse()
+
+
+def wrap_button_click(
+    button: Button,
+    cards_to_unlock: List[Card],
+    widgets_to_disable: List[Widget],
+    callback: Optional[Callable] = None,
+    lock_msg: str = None,
+    upd_params: bool = True,
+    validation_text: Text = None,
+    validation_func: Optional[Callable] = None,
+    on_select_click: Optional[Callable] = None,
+    on_reselect_click: Optional[Callable] = None,
+    collapse_card: Tuple[Card, bool] = None,
+) -> Callable[[Optional[bool]], None]:
+    global button_clicked
+
+    select_params = {"icon": None, "plain": False, "text": "Select"}
+    reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}
+    bid = button.widget_id
+    button_clicked[bid] = False
+
+    def button_click(button_clicked_value: Optional[bool] = None):
+        if button_clicked_value is None or button_clicked_value is False:
+            if validation_func is not None:
+                success = validation_func()
+                if not success:
+                    return
+
+        if button_clicked_value is not None:
+            button_clicked[bid] = button_clicked_value
+        else:
+            button_clicked[bid] = not button_clicked[bid]
+
+        if button_clicked[bid] and upd_params:
+            update_custom_button_params(button, reselect_params)
+            if on_select_click is not None:
+                for func in on_select_click:
+                    func()
+        else:
+            update_custom_button_params(button, select_params)
+            if on_reselect_click is not None:
+                for func in on_reselect_click:
+                    func()
+            validation_text.hide()
+
+        unlock_lock(
+            cards_to_unlock,
+            unlock=button_clicked[bid],
+            message=lock_msg,
+        )
+        disable_enable(
+            widgets_to_disable,
+            disable=button_clicked[bid],
+        )
+        if callback is not None and not button_clicked[bid]:
+            callback(False)
+
+        if collapse_card is not None:
+            card, collapse = collapse_card
+            if collapse:
+                collapse_uncollapse([card], collapse)
+
+    return button_click
+
+
+def set_stepper_step(stepper: Stepper, button: Button, next_pos: int):
+    bid = button.widget_id
+    if button_clicked[bid] is True:
+        stepper.set_active_step(next_pos)
+    else:
+        stepper.set_active_step(next_pos - 1)
+
+
+def find_parents_in_tree(
+    tree: Dict[DatasetInfo, Dict], dataset_id: int, with_self: bool = False
+) -> Optional[List[DatasetInfo]]:
+    """
+    Find all parent datasets in the tree for a given dataset ID.
+    """
+
+    def _dfs(subtree: Dict[DatasetInfo, Dict], parents: List[DatasetInfo]):
+        for dataset_info, children in subtree.items():
+            if dataset_info.id == dataset_id:
+                if with_self:
+                    return parents + [dataset_info]
+                return parents
+            res = _dfs(children, parents + [dataset_info])
+            if res is not None:
+                return res
+        return None
+
+    return _dfs(tree, [])
+
+
+def copy_project(
+    api: Api,
+    project_name: str,
+    workspace_id: int,
+    project_id: int,
+    dataset_ids: List[int] = [],
+    with_annotations: bool = True,
+    progress: Progress = None,
+):
+    """
+    Copy a project
+
+    :param api: Supervisely API
+    :type api: Api
+    :param project_name: Name of the new project
+    :type project_name: str
+    :param workspace_id: ID of the workspace
+    :type workspace_id: int
+    :param project_id: ID of the project to copy
+    :type project_id: int
+    :param dataset_ids: List of dataset IDs to copy. If empty, all datasets from the project will be copied.
+    :type dataset_ids: List[int]
+    :param with_annotations: Whether to copy annotations
+    :type with_annotations: bool
+    :param progress: Progress callback
+    :type progress: Progress
+    :return: Created project
+    :rtype: ProjectInfo
+    """
+
+    def _create_project() -> ProjectInfo:
+        created_project = api.project.create(
+            workspace_id,
+            project_name,
+            type=ProjectType.IMAGES,
+            change_name_if_conflict=True,
+        )
+        if with_annotations:
+            api.project.merge_metas(src_project_id=project_id, dst_project_id=created_project.id)
+        return created_project
+
+    def _copy_full_project(
+        created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]
+    ):
+        src_dst_ds_id_map: Dict[int, int] = {}
+
+        def _create_full_tree(ds_tree: Dict[DatasetInfo, Dict], parent_id: int = None):
+            for src_ds, nested_src_ds_tree in ds_tree.items():
+                dst_ds = api.dataset.create(
+                    project_id=created_project.id,
+                    name=src_ds.name,
+                    description=src_ds.description,
+                    change_name_if_conflict=True,
+                    parent_id=parent_id,
+                )
+                src_dst_ds_id_map[src_ds.id] = dst_ds
+
+                # Preserve dataset custom data
+                info_ds = api.dataset.get_info_by_id(src_ds.id)
+                if info_ds.custom_data:
+                    api.dataset.update_custom_data(dst_ds.id, info_ds.custom_data)
+                _create_full_tree(nested_src_ds_tree, parent_id=dst_ds.id)
+
+        _create_full_tree(src_datasets_tree)
+
+        for src_ds_id, dst_ds in src_dst_ds_id_map.items():
+            _copy_items(src_ds_id, dst_ds)
+
+    def _copy_datasets(created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]):
+        created_datasets: Dict[int, DatasetInfo] = {}
+        processed_copy: Set[int] = set()
+
+        for dataset_id in dataset_ids:
+            chain = find_parents_in_tree(src_datasets_tree, dataset_id, with_self=True)
+            if not chain:
+                logger.warning(
+                    f"Dataset id {dataset_id} not found in project {project_id}. Skipping."
+                )
+                continue
+
+            parent_created_id = None
+            for ds_info in chain:
+                if ds_info.id in created_datasets:
+                    parent_created_id = created_datasets[ds_info.id].id
+                    continue
+
+                created_ds = api.dataset.create(
+                    created_project.id,
+                    ds_info.name,
+                    description=ds_info.description,
+                    change_name_if_conflict=False,
+                    parent_id=parent_created_id,
+                )
+                created_datasets[ds_info.id] = created_ds
+                src_info = api.dataset.get_info_by_id(ds_info.id)
+                if src_info.custom_data:
+                    api.dataset.update_custom_data(created_ds.id, src_info.custom_data)
+                parent_created_id = created_ds.id
+
+            if dataset_id not in processed_copy:
+                _copy_items(dataset_id, created_datasets[dataset_id])
+                processed_copy.add(dataset_id)
+
+    def _copy_items(src_ds_id: int, dst_ds: DatasetInfo):
+        input_img_infos = api.image.get_list(src_ds_id)
+        with progress(
+            message=f"Copying items from dataset: {dst_ds.name}", total=len(input_img_infos)
+        ) as pbar:
+            progress.show()
+            api.image.copy_batch_optimized(
+                src_dataset_id=src_ds_id,
+                src_image_infos=input_img_infos,
+                dst_dataset_id=dst_ds.id,
+                with_annotations=with_annotations,
+                progress_cb=pbar.update,
+            )
+        progress.hide()
+
+    created_project = _create_project()
+    src_datasets_tree = api.dataset.get_tree(project_id)
+
+    if not dataset_ids:
+        _copy_full_project(created_project, src_datasets_tree)
+    else:
+        _copy_datasets(created_project, src_datasets_tree)
+    return created_project
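`copy_project` is the workhorse of this new module: it creates the destination project, merges metas when `with_annotations=True`, recreates either the full dataset tree or only the parent chains of the requested datasets, and copies images via `copy_batch_optimized`. A hedged usage sketch; all IDs are placeholders and the `Progress` widget normally lives inside an app layout.

```python
from supervisely.api.api import Api
from supervisely.app.widgets import Progress
from supervisely.nn.inference.predict_app.gui.utils import copy_project

api = Api.from_env()
progress = Progress()  # in the Predict App this widget is part of the layout

# Copy two datasets (together with their parent chain) into a new project.
new_project = copy_project(
    api,
    project_name="Predictions",
    workspace_id=7,          # placeholder
    project_id=123,          # placeholder: source project
    dataset_ids=[456, 789],  # empty list -> copy the whole project
    with_annotations=True,
    progress=progress,
)
print(new_project.id, new_project.name)
```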
--- /dev/null
+++ b/supervisely/nn/inference/predict_app/predict_app.py
@@ -0,0 +1,184 @@
+import os
+from typing import Dict, List, Optional
+
+from fastapi import BackgroundTasks, Request
+
+from supervisely._utils import logger
+from supervisely.api.api import Api
+from supervisely.app.fastapi.subapp import Application
+from supervisely.nn.inference.predict_app.gui.gui import PredictAppGui
+from supervisely.nn.model.prediction import Prediction
+from supervisely.nn.inference.predict_app.gui.utils import disable_enable
+import supervisely.io.fs as sly_fs
+
+
+class PredictApp:
+    def __init__(self, api: Api):
+        _static_dir = "static"
+        sly_fs.mkdir(_static_dir, True)
+        self.api = api
+        self.gui = PredictAppGui(api, static_dir=_static_dir)
+        self.app = Application(self.gui.layout, static_dir=_static_dir)
+        self._add_endpoints()
+
+        @self.gui.output_selector.start_button.click
+        def start_prediction():
+            if self.gui.output_selector.validate_step():
+                disable_enable(self.gui.output_selector.widgets_to_disable, True)
+                self.gui.run()
+                self.shutdown_serving_app()
+                self.shutdown_predict_app()
+
+    def shutdown_serving_app(self):
+        if self.gui.output_selector.should_stop_serving_on_finish():
+            logger.info("Stopping serving app...")
+            self.gui.model_selector.model.stop()
+
+    def shutdown_predict_app(self):
+        if self.gui.output_selector.should_stop_self_on_finish():
+            self.gui.output_selector.start_button.disable()
+            logger.info("Stopping Predict App...")
+            self.app.stop()
+        else:
+            disable_enable(self.gui.output_selector.widgets_to_disable, False)
+            self.gui.output_selector.start_button.enable()
+
+    def run(self, run_parameters: Optional[Dict] = None) -> List[Prediction]:
+        return self.gui.run(run_parameters)
+
+    def stop(self):
+        self.gui.stop()
+
+    def shutdown_model(self):
+        self.gui.shutdown_model()
+
+    def load_from_json(self, data):
+        self.gui.load_from_json(data)
+        if data.get("run", False):
+            try:
+                self.run()
+            except Exception as e:
+                raise
+            finally:
+                if data.get("stop_after_run", False):
+                    self.shutdown_model()
+                    self.app.stop()
+
+    def get_inference_settings(self):
+        return self.gui.settings_selector.get_inference_settings()
+
+    def get_run_parameters(self):
+        return self.gui.get_run_parameters()
+
+    def _add_endpoints(self):
+        server = self.app.get_server()
+
+        @server.post("/load")
+        def load(request: Request, background_tasks: BackgroundTasks):
+            """
+            Load the model state from a JSON object.
+            This endpoint initializes the model with the provided state.
+            All the fields are optional
+
+            Example state:
+            state = {
+                "model": {
+                    "mode": "connect",
+                    "session_id": "12345"
+                    # "mode": "pretrained",
+                    # "framework: "YOLO",
+                    # "model_name": "YOLO11m-seg",
+                    # "mode": "custom",
+                    # "train_task_id": 123
+                },
+                "items": {
+                    "project_id": 123,
+                    # "dataset_ids": [...],
+                    # "video_id": 123
+                },
+                "inference_settings": {
+                    "confidence_threshold": 0.5
+                },
+                "output": {
+                    "mode": "create",
+                    "project_name": "Predictions",
+                    # "mode": "append",
+                    # "mode": "replace",
+                    # "mode": "iou_merge",
+                    # "iou_merge_threshold": 0.5
+                }
+            }
+            """
+            state = request.state.state
+            stop_after_run = state.get("stop_after_run", False)
+            if stop_after_run:
+                state["stop_after_run"] = False
+            self.load_from_json(state)
+            if stop_after_run:
+                self.shutdown_model()
+                background_tasks.add_task(self.app.stop)
+
+        @server.post("/deploy")
+        def deploy(request: Request):
+            """
+            Deploy the model for inference.
+            This endpoint prepares the model for running predictions.
+            """
+            self.gui.model_selector.model._deploy()
+
+        @server.get("/inference_settings")
+        def get_inference_settings():
+            """
+            Get the inference settings for the model.
+            This endpoint returns the current inference settings.
+            """
+            return self.get_inference_settings()
+
+        @server.get("/run_parameters")
+        def get_run_parameters():
+            """
+            Get the run parameters for the model.
+            This endpoint returns the parameters needed to run the model.
+            """
+            return self.get_run_parameters()
+
+        @server.post("/predict")
+        def predict(request: Request):
+            """
+            Run the model prediction.
+            This endpoint processes the request data and runs the model prediction.
+
+            Example data:
+            data = {
+                "inference_settings": {
+                    "conf": 0.6,
+                },
+                "item": {
+                    # "project_id": ...,
+                    # "dataset_ids": [...],
+                    "image_ids": [1148679, 1148675],
+                },
+                "output": {"mode": "iou_merge", "iou_merge_threshold": 0.5},
+            }
+            """
+            state = request.state.state
+            run_parameters = {
+                "item": state["item"],
+            }
+            if "inference_settings" in state:
+                run_parameters["inference_settings"] = state["inference_settings"]
+            if "output" in state:
+                run_parameters["output"] = state["output"]
+            else:
+                run_parameters["output"] = {"mode": None}
+
+            predictions = self.run(run_parameters)
+            return [prediction.to_json() for prediction in predictions]
+
+        @server.post("/run")
+        def run(request: Request):
+            """
+            Run the model prediction.
+            """
+            predicitons = self.run()
+            return [prediction.to_json() for prediction in predicitons]
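The `/load` and `/predict` endpoints accept exactly the JSON shapes shown in the docstrings above. Below is a sketch of driving a running Predict App task from a script; the task id and image ids are placeholders, and using `api.task.send_request` as the transport is an assumption — any POST carrying the same state to the app's endpoint should behave the same way.

```python
from supervisely.api.api import Api

api = Api.from_env()
task_id = 12345  # placeholder: id of the running Predict App task

# Payload mirrors the example in the /predict docstring.
predictions = api.task.send_request(
    task_id,
    "predict",
    data={
        "inference_settings": {"conf": 0.6},
        "item": {"image_ids": [1148679, 1148675]},
        "output": {"mode": "iou_merge", "iou_merge_threshold": 0.5},
    },
)
print(f"{len(predictions)} predictions returned")
```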
--- a/supervisely/nn/inference/uploader.py
+++ b/supervisely/nn/inference/uploader.py
@@ -105,10 +105,6 @@ class Uploader:
                 self.stop()
                 return
             except Exception as e:
-                try:
-                    raise RuntimeError("Error in upload loop") from e
-                except RuntimeError as e_:
-                    e = e_
                 if self._logger is not None:
                     self._logger.error("Error in upload loop: %s", str(e), exc_info=True)
                 if not self._exception_event.is_set():
@@ -152,7 +148,9 @@ class Uploader:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop()
         try:
-            self.join(timeout=
+            self.join(timeout=30)
+            if self._upload_thread.is_alive():
+                raise TimeoutError("Uploader thread didn't finish in time")
        except TimeoutError:
            _logger = logger
            if self._logger is not None:
@@ -161,4 +159,10 @@ class Uploader:
         if exc_type is not None:
             exc = exc_val.with_traceback(exc_tb)
             return self._exception_handler(exc)
+        if self.has_exception():
+            exc = self.exception
+            try:
+                raise RuntimeError(f"Error in uploader loop: {str(exc)}") from exc
+            except Exception as exc:
+                return self._exception_handler(exc)
         return False
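The reworked `__exit__` now times out the join explicitly and, when the upload thread has stored an exception, re-raises it through a local `try/except` so that `_exception_handler` receives a chained exception with a real traceback. The same `raise ... from` pattern in isolation (plain Python, independent of the Uploader class):

```python
# Minimal illustration of the raise-from pattern used above: wrap a stored
# worker exception so the handler sees a chained exception with a traceback.
def handle(exc: BaseException) -> bool:
    print(f"handled: {exc!r} (caused by {exc.__cause__!r})")
    return True  # True suppresses the error, mirroring __exit__ semantics

stored = ValueError("upload worker failed")  # stands in for Uploader.exception

try:
    raise RuntimeError(f"Error in uploader loop: {stored}") from stored
except Exception as wrapped:
    handle(wrapped)
```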
--- a/supervisely/nn/model/prediction.py
+++ b/supervisely/nn/model/prediction.py
@@ -240,6 +240,8 @@ class Prediction:
         if self.image_id is not None:
             try:
                 if api is None:
+                    # TODO: raise more clarifying error in case of failing of api init
+                    # what a user should do to fix it?
                     api = Api()
                 return api.image.download_np(self.image_id)
             except Exception as e:
--- a/supervisely/nn/model/prediction_session.py
+++ b/supervisely/nn/model/prediction_session.py
@@ -132,6 +132,14 @@ class PredictionSession:
         self.tracker = None
         self.tracker_settings = None
 
+        if "classes" in kwargs:
+            self.inference_settings["classes"] = kwargs["classes"]
+        # TODO: remove "settings", it is the same as inference_settings
+        if "settings" in kwargs:
+            self.inference_settings.update(kwargs["settings"])
+        if "inference_settings" in kwargs:
+            self.inference_settings.update(kwargs["inference_settings"])
+
         # extra input args
         image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
         video_ids = self._set_var_from_kwargs("video_ids", kwargs, video_id)
@@ -159,7 +167,6 @@ class PredictionSession:
             input = [input]
         if isinstance(input[0], np.ndarray):
             # input is numpy array
-            kwargs = get_valid_kwargs(kwargs, self._predict_images, exclude=["images"])
             self._predict_images(input, **kwargs)
         elif isinstance(input[0], (str, PathLike)):
             if len(input) > 1:
@@ -288,6 +295,8 @@ class PredictionSession:
         body["state"]["settings"] = self.inference_settings
         if self.api_token is not None:
             body["api_token"] = self.api_token
+        if "model_prediction_suffix" in self.kwargs:
+            body["state"]["model_prediction_suffix"] = self.kwargs["model_prediction_suffix"]
         return body
 
     def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
@@ -562,7 +571,11 @@ class PredictionSession:
         return self._predict_images_bytes(images, batch_size=batch_size)
 
     def _predict_images_ids(
-        self,
+        self,
+        images: List[int],
+        batch_size: int = None,
+        upload_mode: str = None,
+        output_project_id: int = None,
     ):
         method = "inference_batch_ids_async"
         json_body = self._get_json_body()
@@ -572,6 +585,8 @@ class PredictionSession:
             state["batch_size"] = batch_size
         if upload_mode is not None:
             state["upload_mode"] = upload_mode
+        if output_project_id is not None:
+            state["output_project_id"] = output_project_id
         return self._start_inference(method, json=json_body)
 
     def _predict_videos(
@@ -647,6 +662,7 @@ class PredictionSession:
         upload_mode: str = None,
         iou_merge_threshold: float = None,
         cache_project_on_model: bool = None,
+        output_project_id: int = None,
     ):
         if len(project_ids) != 1:
             raise ValueError("Only one project can be processed at a time.")
@@ -664,7 +680,8 @@ class PredictionSession:
         state["iou_merge_threshold"] = iou_merge_threshold
         if cache_project_on_model is not None:
             state["cache_project_on_model"] = cache_project_on_model
-
+        if output_project_id is not None:
+            state["output_project_id"] = output_project_id
         return self._start_inference(method, json=json_body)
 
     def _predict_datasets(
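Taken together, the PredictionSession changes let callers pass class filters and settings through keyword arguments, send an optional `model_prediction_suffix` in the request body, and route results into an existing project via `output_project_id`. The snippet below only reproduces the constructor's kwarg-merging logic from the first hunk with placeholder values; how these kwargs are forwarded from the public predict API is not shown in this diff.

```python
# Reproducing the merge logic from the constructor hunk in isolation;
# keys come from the diff above, values are placeholders.
kwargs = {
    "classes": ["person", "car"],
    "settings": {"conf": 0.5},               # legacy alias, see the TODO above
    "inference_settings": {"iou": 0.7},
    "model_prediction_suffix": " (yolo11)",  # added to the request body
    "output_project_id": 321,                # consumed by _predict_images_ids / _predict_projects
}

inference_settings = {}
if "classes" in kwargs:
    inference_settings["classes"] = kwargs["classes"]
if "settings" in kwargs:
    inference_settings.update(kwargs["settings"])
if "inference_settings" in kwargs:
    inference_settings.update(kwargs["inference_settings"])

print(inference_settings)
# {'classes': ['person', 'car'], 'conf': 0.5, 'iou': 0.7}
```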
|