supervisely 6.73.419__py3-none-any.whl → 6.73.421__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/project/project.py +211 -73
  66. supervisely/template/experiment/experiment.html.jinja +74 -17
  67. supervisely/template/experiment/experiment_generator.py +258 -112
  68. supervisely/template/experiment/header.html.jinja +31 -13
  69. supervisely/template/experiment/sly-style.css +7 -2
  70. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
  71. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/RECORD +75 -57
  72. supervisely/app/widgets/experiment_selector/style.css +0 -27
  73. supervisely/app/widgets/experiment_selector/template.html +0 -61
  74. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
  75. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
  76. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
  77. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
1
+ from typing import Any, Dict
2
+ from supervisely.app.widgets import Button, Card, Container, TagsTable, Text
3
+
4
+
5
class TagsSelector:
    """Inference-wizard step for choosing which project tags to use.

    Renders a locked :class:`Card` holding a ``TagsTable`` (hidden by default),
    a validation message and a "Select" button.
    """

    title = "Tags Selector"
    description = "Select tags that will be used for inference"
    lock_message = "Select previous step to unlock"

    def __init__(self):
        # Widgets rendered inside the card, in display order.
        self.display_widgets = []

        # Step widgets: the tags table starts hidden; while it is hidden,
        # ``validate_step`` treats this step as not applicable.
        self.tags_table = TagsTable()
        self.tags_table.hide()
        self.display_widgets.append(self.tags_table)

        # Base widgets: validation message and the confirmation button.
        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        self.display_widgets.extend([self.validator_text, self.button])

        # Card layout; locked until the previous step unlocks it.
        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step that the wizard toggles (disable/enable) as a group."""
        return [self.tags_table]

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Restore the tag selection from ``data["tags"]`` when that key is present."""
        if "tags" in data:
            self.set_tags(data["tags"])

    def get_selected_tags(self) -> list:
        """Return the tags currently selected in the tags table."""
        return self.tags_table.get_selected_tags()

    def set_tags(self, tags) -> None:
        """Select ``tags`` in the tags table."""
        self.tags_table.select_tags(tags)

    def select_all_tags(self) -> None:
        """Select every tag in the tags table."""
        self.tags_table.select_all()

    def get_settings(self) -> Dict[str, Any]:
        """Return this step's settings as ``{"tags": [...selected tags...]}``."""
        return {"tags": self.get_selected_tags()}

    @staticmethod
    def _find_empty_tags(table_data, selected_tags) -> list:
        """Return selected tag names whose table row has zero counts in columns 2 and 3.

        Rows are expected to look like ``[{"data": name}, ..., {"data": n2}, {"data": n3}]``.
        Columns 2 and 3 presumably hold per-tag usage counts; verify against ``TagsTable``.
        The result keeps table order and contains no duplicates, so messages built
        from it are deterministic (a ``set`` would give an arbitrary order).
        """
        selected = set(selected_tags)
        empty = [
            row[0]["data"]
            for row in table_data
            if row[0]["data"] in selected and row[2]["data"] == 0 and row[3]["data"] == 0
        ]
        return list(dict.fromkeys(empty))

    def validate_step(self) -> bool:
        """Validate the current tag selection and show the outcome in ``validator_text``.

        :return: ``True`` if the step may proceed: either the tags table is hidden
            (step not applicable) or at least one tag is selected.
            Selected tags without annotations only produce a warning.
        """
        if self.tags_table.is_hidden():
            return True

        self.validator_text.hide()

        if len(self.tags_table.project_meta.tag_metas) == 0:
            self.validator_text.set(text="Project has no tags", status="error")
            self.validator_text.show()
            return False

        selected_tags = self.tags_table.get_selected_tags()
        n_tags = len(selected_tags)
        if n_tags == 0:
            message = "Please select at least one tag"
            status = "error"
        else:
            tag_text = "tag" if n_tags == 1 else "tags"
            message = f"Selected {n_tags} {tag_text}"
            status = "success"
            empty_tags = self._find_empty_tags(self.tags_table._table_data, selected_tags)
            if empty_tags:
                tag_text = "tag" if len(empty_tags) == 1 else "tags"
                message += (
                    f". Selected {tag_text} have no annotations: "
                    f"{', '.join(map(str, empty_tags))}"
                )
                status = "warning"

        self.validator_text.set(text=message, status=status)
        self.validator_text.show()
        return n_tags > 0
@@ -0,0 +1,282 @@
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Set
2
+ from supervisely import logger
3
+ from supervisely.api.api import Api
4
+ from supervisely.api.project_api import ProjectInfo
5
+ from supervisely.api.dataset_api import DatasetInfo
6
+ from supervisely.project.project import ProjectType
7
+ from supervisely.app.widgets import Progress
8
+ from supervisely.app import DataJson
9
+ from supervisely.app.widgets import Button, Card, Stepper, Text, Widget
10
+
11
# Toggle state of each wizard button, keyed by the button's ``widget_id``.
# Written by ``wrap_button_click`` and read by ``set_stepper_step``.
button_clicked = dict()
12
+
13
+
14
def update_custom_params(
    button: Button,
    params_dct: Dict[str, Any],
) -> None:
    """Overwrite selected fields of a widget's ``DataJson`` entry and send the changes.

    Every key in ``params_dct`` must already exist in ``button.get_json_data()``.
    All keys are validated *before* anything is written. An unknown key therefore
    leaves ``DataJson`` untouched instead of half-updated and unsent.

    :param button: Widget whose ``DataJson`` entry is updated.
    :param params_dct: Mapping of field name to new value.
    :raises AttributeError: If any key is not a field of the widget's JSON data.
    """
    button_state = button.get_json_data()
    for key in params_dct:
        if key not in button_state:
            raise AttributeError(f"Parameter {key} doesn't exists.")

    for key, value in params_dct.items():
        DataJson()[button.widget_id][key] = value
    DataJson().send_changes()
25
+
26
+
27
def update_custom_button_params(
    button: Button,
    params_dct: Dict[str, Any],
) -> None:
    """Apply ``params_dct`` to ``button``, rendering a non-None ``icon`` as HTML.

    A non-None ``"icon"`` value is treated as a CSS class string. It is wrapped in an
    ``<i>`` element with the button's icon gap as right margin. The caller's dict is
    not modified; a copy is updated and forwarded to ``update_custom_params``.
    """
    params = params_dct.copy()
    icon_class = params.get("icon")
    if icon_class is not None:
        params["icon"] = (
            f'<i class="{icon_class}" style="margin-right: {button._icon_gap}px"></i>'
        )
    update_custom_params(button, params)
36
+
37
+
38
def disable_enable(widgets: List[Widget], disable: bool = True):
    """Call ``disable()`` on every widget, or ``enable()`` when ``disable`` is False."""
    for widget in widgets:
        toggle = widget.disable if disable else widget.enable
        toggle()
44
+
45
+
46
def unlock_lock(cards: List[Card], unlock: bool = True, message: str = None):
    """Unlock every card, or lock each one with ``message`` when ``unlock`` is False.

    Only the lock state changes here; collapsing is handled by ``collapse_uncollapse``.
    """
    for card in cards:
        if not unlock:
            card.lock(message)
            continue
        card.unlock()
54
+
55
+
56
def collapse_uncollapse(cards: List[Card], collapse: bool = True):
    """Collapse every card, or expand (``uncollapse``) them when ``collapse`` is False."""
    for card in cards:
        action = card.collapse if collapse else card.uncollapse
        action()
62
+
63
+
64
def wrap_button_click(
    button: Button,
    cards_to_unlock: List[Card],
    widgets_to_disable: List[Widget],
    callback: Optional[Callable] = None,
    lock_msg: str = None,
    upd_params: bool = True,
    validation_text: Text = None,
    validation_func: Optional[Callable] = None,
    on_select_click: Optional[Callable] = None,
    on_reselect_click: Optional[Callable] = None,
    collapse_card: Tuple[Card, bool] = None,
) -> Callable[[Optional[bool]], None]:
    """Build a "Select / Reselect" toggle handler for a wizard step button.

    The toggle state is stored in the module-level ``button_clicked`` dict under
    ``button.widget_id`` and is reset to ``False`` here.

    The returned ``button_click(button_clicked_value=None)`` handles the click:

    * Validation. When called with ``None`` or ``False``, ``validation_func`` (if
      given) runs first. A falsy result aborts the click.
    * State. ``None`` toggles the stored state; ``True``/``False`` set it explicitly.
    * Selected state. The button switches to its "Reselect" look (only if
      ``upd_params``), ``on_select_click`` runs, ``cards_to_unlock`` are unlocked
      and ``widgets_to_disable`` are disabled.
    * Unselected state. The button switches to its "Select" look (only if
      ``upd_params``), ``on_reselect_click`` runs, ``validation_text`` (if given) is
      hidden, cards are locked with ``lock_msg``, widgets are enabled and
      ``callback(False)`` is called.
    * ``collapse_card``. A ``(card, flag)`` pair; the card is collapsed after every
      click when ``flag`` is true.

    ``on_select_click`` / ``on_reselect_click`` accept either a single callable or
    an iterable of callables.
    """
    global button_clicked

    select_params = {"icon": None, "plain": False, "text": "Select"}
    reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}
    bid = button.widget_id
    button_clicked[bid] = False

    def _run_all(funcs) -> None:
        # Accept a single callable as well as a list/tuple of callables.
        if funcs is None:
            return
        if callable(funcs):
            funcs()
            return
        for func in funcs:
            func()

    def button_click(button_clicked_value: Optional[bool] = None):
        # NOTE(review): validation runs for toggles and for an explicit False,
        # but is skipped for an explicit True (programmatic selection).
        if button_clicked_value is None or button_clicked_value is False:
            if validation_func is not None:
                success = validation_func()
                if not success:
                    return

        if button_clicked_value is not None:
            button_clicked[bid] = button_clicked_value
        else:
            button_clicked[bid] = not button_clicked[bid]

        if button_clicked[bid]:
            if upd_params:
                update_custom_button_params(button, reselect_params)
            _run_all(on_select_click)
        else:
            if upd_params:
                update_custom_button_params(button, select_params)
            _run_all(on_reselect_click)
            # validation_text is optional (defaults to None).
            if validation_text is not None:
                validation_text.hide()

        unlock_lock(
            cards_to_unlock,
            unlock=button_clicked[bid],
            message=lock_msg,
        )
        disable_enable(
            widgets_to_disable,
            disable=button_clicked[bid],
        )
        if callback is not None and not button_clicked[bid]:
            callback(False)

        if collapse_card is not None:
            card, collapse = collapse_card
            if collapse:
                collapse_uncollapse([card], collapse)

    return button_click
126
+
127
+
128
def set_stepper_step(stepper: Stepper, button: Button, next_pos: int):
    """Activate step ``next_pos`` if ``button`` is selected, otherwise step ``next_pos - 1``.

    The selection state is read from the module-level ``button_clicked`` registry.
    """
    is_selected = button_clicked[button.widget_id] is True
    target_step = next_pos if is_selected else next_pos - 1
    stepper.set_active_step(target_step)
134
+
135
+
136
def find_parents_in_tree(
    tree: Dict[DatasetInfo, Dict], dataset_id: int, with_self: bool = False
) -> Optional[List[DatasetInfo]]:
    """
    Find all parent datasets in the tree for a given dataset ID.

    :param tree: Nested mapping ``{dataset_info: {child_info: {...}, ...}, ...}``.
    :param dataset_id: ID of the dataset to look for.
    :param with_self: Also append the matched dataset itself to the result.
    :return: Ancestors ordered from the root down (plus the dataset itself if
        ``with_self``), or ``None`` if no dataset with that ID is in the tree.
    """
    # Explicit-stack depth-first search visiting nodes in the same pre-order as a
    # recursive walk. Each frame is (iterator over one level, ancestors of that level).
    frames = [(iter(tree.items()), [])]
    while frames:
        level_iter, ancestors = frames[-1]
        entry = next(level_iter, None)
        if entry is None:
            frames.pop()
            continue
        node, subtree = entry
        if node.id == dataset_id:
            return (ancestors + [node]) if with_self else ancestors
        frames.append((iter(subtree.items()), ancestors + [node]))
    return None
155
+
156
+
157
def copy_project(
    api: Api,
    project_name: str,
    workspace_id: int,
    project_id: int,
    dataset_ids: Optional[List[int]] = None,
    with_annotations: bool = True,
    progress: Progress = None,
):
    """
    Copy a project

    :param api: Supervisely API
    :type api: Api
    :param project_name: Name of the new project
    :type project_name: str
    :param workspace_id: ID of the workspace
    :type workspace_id: int
    :param project_id: ID of the project to copy
    :type project_id: int
    :param dataset_ids: List of dataset IDs to copy. If empty or None, all datasets
        from the project will be copied. When given, the parents of each listed dataset
        are recreated as well, but only the listed datasets get their items copied.
    :type dataset_ids: Optional[List[int]]
    :param with_annotations: Whether to copy annotations (and the project meta)
    :type with_annotations: bool
    :param progress: Optional progress widget. If None, items are copied without
        progress reporting.
    :type progress: Progress
    :return: Created project
    :rtype: ProjectInfo
    """

    def _create_project() -> ProjectInfo:
        created_project = api.project.create(
            workspace_id,
            project_name,
            type=ProjectType.IMAGES,
            change_name_if_conflict=True,
        )
        if with_annotations:
            api.project.merge_metas(src_project_id=project_id, dst_project_id=created_project.id)
        return created_project

    def _copy_custom_data(src_ds_id: int, dst_ds_id: int) -> None:
        # Preserve dataset custom data (re-read from the source dataset's full info).
        src_info = api.dataset.get_info_by_id(src_ds_id)
        if src_info.custom_data:
            api.dataset.update_custom_data(dst_ds_id, src_info.custom_data)

    def _copy_items(src_ds_id: int, dst_ds: DatasetInfo) -> None:
        # Copy all images (optionally with annotations) from src_ds_id into dst_ds.
        input_img_infos = api.image.get_list(src_ds_id)
        if progress is None:
            api.image.copy_batch_optimized(
                src_dataset_id=src_ds_id,
                src_image_infos=input_img_infos,
                dst_dataset_id=dst_ds.id,
                with_annotations=with_annotations,
            )
            return
        try:
            with progress(
                message=f"Copying items from dataset: {dst_ds.name}",
                total=len(input_img_infos),
            ) as pbar:
                progress.show()
                api.image.copy_batch_optimized(
                    src_dataset_id=src_ds_id,
                    src_image_infos=input_img_infos,
                    dst_dataset_id=dst_ds.id,
                    with_annotations=with_annotations,
                    progress_cb=pbar.update,
                )
        finally:
            # Hide the widget even if the copy fails, so it isn't left on screen.
            progress.hide()

    def _copy_full_project(
        created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]
    ):
        # Source dataset id -> created destination DatasetInfo.
        src_to_dst: Dict[int, DatasetInfo] = {}

        def _create_full_tree(ds_tree: Dict[DatasetInfo, Dict], parent_id: int = None):
            for src_ds, nested_src_ds_tree in ds_tree.items():
                dst_ds = api.dataset.create(
                    project_id=created_project.id,
                    name=src_ds.name,
                    description=src_ds.description,
                    change_name_if_conflict=True,
                    parent_id=parent_id,
                )
                src_to_dst[src_ds.id] = dst_ds
                _copy_custom_data(src_ds.id, dst_ds.id)
                _create_full_tree(nested_src_ds_tree, parent_id=dst_ds.id)

        # Create the whole dataset hierarchy first, then copy items.
        _create_full_tree(src_datasets_tree)

        for src_ds_id, dst_ds in src_to_dst.items():
            _copy_items(src_ds_id, dst_ds)

    def _copy_datasets(created_project: ProjectInfo, src_datasets_tree: Dict[DatasetInfo, Dict]):
        # Source dataset id -> created destination DatasetInfo (shared parents created once).
        created_datasets: Dict[int, DatasetInfo] = {}
        processed_copy: Set[int] = set()

        for dataset_id in dataset_ids:
            chain = find_parents_in_tree(src_datasets_tree, dataset_id, with_self=True)
            if not chain:
                logger.warning(
                    f"Dataset id {dataset_id} not found in project {project_id}. Skipping."
                )
                continue

            # Recreate the root-to-dataset chain, reusing already created ancestors.
            parent_created_id = None
            for ds_info in chain:
                if ds_info.id in created_datasets:
                    parent_created_id = created_datasets[ds_info.id].id
                    continue

                created_ds = api.dataset.create(
                    created_project.id,
                    ds_info.name,
                    description=ds_info.description,
                    change_name_if_conflict=False,
                    parent_id=parent_created_id,
                )
                created_datasets[ds_info.id] = created_ds
                _copy_custom_data(ds_info.id, created_ds.id)
                parent_created_id = created_ds.id

            # Guard against duplicate ids in dataset_ids.
            if dataset_id not in processed_copy:
                _copy_items(dataset_id, created_datasets[dataset_id])
                processed_copy.add(dataset_id)

    created_project = _create_project()
    src_datasets_tree = api.dataset.get_tree(project_id)

    if not dataset_ids:
        _copy_full_project(created_project, src_datasets_tree)
    else:
        _copy_datasets(created_project, src_datasets_tree)
    return created_project
@@ -0,0 +1,184 @@
1
+ import os
2
+ from typing import Dict, List, Optional
3
+
4
+ from fastapi import BackgroundTasks, Request
5
+
6
+ from supervisely._utils import logger
7
+ from supervisely.api.api import Api
8
+ from supervisely.app.fastapi.subapp import Application
9
+ from supervisely.nn.inference.predict_app.gui.gui import PredictAppGui
10
+ from supervisely.nn.model.prediction import Prediction
11
+ from supervisely.nn.inference.predict_app.gui.utils import disable_enable
12
+ import supervisely.io.fs as sly_fs
13
+
14
+
15
class PredictApp:
    """Predict application: a GUI wizard (``PredictAppGui``) plus an HTTP API.

    Wraps the GUI layout in an :class:`Application` (serving files from ``static``).
    It registers the ``/load``, ``/deploy``, ``/inference_settings``,
    ``/run_parameters``, ``/predict`` and ``/run`` endpoints on the app's server.
    """

    def __init__(self, api: Api):
        _static_dir = "static"
        sly_fs.mkdir(_static_dir, True)
        self.api = api
        self.gui = PredictAppGui(api, static_dir=_static_dir)
        self.app = Application(self.gui.layout, static_dir=_static_dir)
        self._add_endpoints()

        @self.gui.output_selector.start_button.click
        def start_prediction():
            if not self.gui.output_selector.validate_step():
                return
            widgets = self.gui.output_selector.widgets_to_disable
            disable_enable(widgets, True)
            try:
                self.gui.run()
            except Exception:
                # Re-enable the output controls so the user can adjust and retry
                # instead of being left with a permanently locked form.
                disable_enable(widgets, False)
                raise
            self.shutdown_serving_app()
            self.shutdown_predict_app()

    def shutdown_serving_app(self):
        """Stop the connected serving (model) app if the output step requests it."""
        if self.gui.output_selector.should_stop_serving_on_finish():
            logger.info("Stopping serving app...")
            self.gui.model_selector.model.stop()

    def shutdown_predict_app(self):
        """Stop this app if requested; otherwise unlock the output step for another run."""
        if self.gui.output_selector.should_stop_self_on_finish():
            self.gui.output_selector.start_button.disable()
            logger.info("Stopping Predict App...")
            self.app.stop()
        else:
            disable_enable(self.gui.output_selector.widgets_to_disable, False)
            self.gui.output_selector.start_button.enable()

    def run(self, run_parameters: Optional[Dict] = None) -> List[Prediction]:
        """Run prediction through the GUI; ``run_parameters`` is forwarded to ``gui.run``."""
        return self.gui.run(run_parameters)

    def stop(self):
        """Forward to ``gui.stop()``."""
        self.gui.stop()

    def shutdown_model(self):
        """Forward to ``gui.shutdown_model()``."""
        self.gui.shutdown_model()

    def load_from_json(self, data):
        """Load GUI state from ``data``; optionally run and then stop.

        If ``data["run"]`` is truthy, a prediction run starts immediately. If
        ``data["stop_after_run"]`` is also truthy, the model is shut down and the app
        stopped afterwards, even when the run raises. The exception still propagates.
        """
        self.gui.load_from_json(data)
        if not data.get("run", False):
            return
        try:
            self.run()
        finally:
            if data.get("stop_after_run", False):
                self.shutdown_model()
                self.app.stop()

    def get_inference_settings(self):
        """Return the inference settings currently configured in the GUI."""
        return self.gui.settings_selector.get_inference_settings()

    def get_run_parameters(self):
        """Return the run parameters currently configured in the GUI."""
        return self.gui.get_run_parameters()

    def _add_endpoints(self):
        """Register the HTTP endpoints of the app on its server."""
        server = self.app.get_server()

        @server.post("/load")
        def load(request: Request, background_tasks: BackgroundTasks):
            """
            Load the model state from a JSON object.
            This endpoint initializes the model with the provided state.
            All the fields are optional

            Example state:
                state = {
                    "model": {
                        "mode": "connect",
                        "session_id": "12345",
                        # "mode": "pretrained",
                        # "framework": "YOLO",
                        # "model_name": "YOLO11m-seg",
                        # "mode": "custom",
                        # "train_task_id": 123
                    },
                    "items": {
                        "project_id": 123,
                        # "dataset_ids": [...],
                        # "video_id": 123
                    },
                    "inference_settings": {
                        "confidence_threshold": 0.5
                    },
                    "output": {
                        "mode": "create",
                        "project_name": "Predictions",
                        # "mode": "append",
                        # "mode": "replace",
                        # "mode": "iou_merge",
                        # "iou_merge_threshold": 0.5
                    }
                }
            """
            state = request.state.state
            stop_after_run = state.get("stop_after_run", False)
            if stop_after_run:
                # Handle stopping here so the app stops only after the response is sent.
                state["stop_after_run"] = False
            self.load_from_json(state)
            if stop_after_run:
                self.shutdown_model()
                background_tasks.add_task(self.app.stop)

        @server.post("/deploy")
        def deploy(request: Request):
            """
            Deploy the model for inference.
            This endpoint prepares the model for running predictions.
            """
            self.gui.model_selector.model._deploy()

        @server.get("/inference_settings")
        def get_inference_settings():
            """
            Get the inference settings for the model.
            This endpoint returns the current inference settings.
            """
            return self.get_inference_settings()

        @server.get("/run_parameters")
        def get_run_parameters():
            """
            Get the run parameters for the model.
            This endpoint returns the parameters needed to run the model.
            """
            return self.get_run_parameters()

        @server.post("/predict")
        def predict(request: Request):
            """
            Run the model prediction.
            This endpoint processes the request data and runs the model prediction.

            Example data:
                data = {
                    "inference_settings": {
                        "conf": 0.6,
                    },
                    "item": {
                        # "project_id": ...,
                        # "dataset_ids": [...],
                        "image_ids": [1148679, 1148675],
                    },
                    "output": {"mode": "iou_merge", "iou_merge_threshold": 0.5},
                }
            """
            state = request.state.state
            run_parameters = {
                "item": state["item"],
            }
            if "inference_settings" in state:
                run_parameters["inference_settings"] = state["inference_settings"]
            if "output" in state:
                run_parameters["output"] = state["output"]
            else:
                # No output mode: predictions are only returned, not uploaded (assumed).
                run_parameters["output"] = {"mode": None}

            predictions = self.run(run_parameters)
            return [prediction.to_json() for prediction in predictions]

        @server.post("/run")
        def run(request: Request):
            """
            Run the model prediction.
            """
            predictions = self.run()
            return [prediction.to_json() for prediction in predictions]
@@ -105,10 +105,6 @@ class Uploader:
105
105
  self.stop()
106
106
  return
107
107
  except Exception as e:
108
- try:
109
- raise RuntimeError("Error in upload loop") from e
110
- except RuntimeError as e_:
111
- e = e_
112
108
  if self._logger is not None:
113
109
  self._logger.error("Error in upload loop: %s", str(e), exc_info=True)
114
110
  if not self._exception_event.is_set():
@@ -152,7 +148,9 @@ class Uploader:
152
148
  def __exit__(self, exc_type, exc_val, exc_tb):
153
149
  self.stop()
154
150
  try:
155
- self.join(timeout=5)
151
+ self.join(timeout=30)
152
+ if self._upload_thread.is_alive():
153
+ raise TimeoutError("Uploader thread didn't finish in time")
156
154
  except TimeoutError:
157
155
  _logger = logger
158
156
  if self._logger is not None:
@@ -161,4 +159,10 @@ class Uploader:
161
159
  if exc_type is not None:
162
160
  exc = exc_val.with_traceback(exc_tb)
163
161
  return self._exception_handler(exc)
162
+ if self.has_exception():
163
+ exc = self.exception
164
+ try:
165
+ raise RuntimeError(f"Error in uploader loop: {str(exc)}") from exc
166
+ except Exception as exc:
167
+ return self._exception_handler(exc)
164
168
  return False
@@ -240,6 +240,8 @@ class Prediction:
240
240
  if self.image_id is not None:
241
241
  try:
242
242
  if api is None:
243
+ # TODO: raise more clarifying error in case of failing of api init
244
+ # what a user should do to fix it?
243
245
  api = Api()
244
246
  return api.image.download_np(self.image_id)
245
247
  except Exception as e:
@@ -132,6 +132,14 @@ class PredictionSession:
132
132
  self.tracker = None
133
133
  self.tracker_settings = None
134
134
 
135
+ if "classes" in kwargs:
136
+ self.inference_settings["classes"] = kwargs["classes"]
137
+ # TODO: remove "settings", it is the same as inference_settings
138
+ if "settings" in kwargs:
139
+ self.inference_settings.update(kwargs["settings"])
140
+ if "inference_settings" in kwargs:
141
+ self.inference_settings.update(kwargs["inference_settings"])
142
+
135
143
  # extra input args
136
144
  image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
137
145
  video_ids = self._set_var_from_kwargs("video_ids", kwargs, video_id)
@@ -159,7 +167,6 @@ class PredictionSession:
159
167
  input = [input]
160
168
  if isinstance(input[0], np.ndarray):
161
169
  # input is numpy array
162
- kwargs = get_valid_kwargs(kwargs, self._predict_images, exclude=["images"])
163
170
  self._predict_images(input, **kwargs)
164
171
  elif isinstance(input[0], (str, PathLike)):
165
172
  if len(input) > 1:
@@ -288,6 +295,8 @@ class PredictionSession:
288
295
  body["state"]["settings"] = self.inference_settings
289
296
  if self.api_token is not None:
290
297
  body["api_token"] = self.api_token
298
+ if "model_prediction_suffix" in self.kwargs:
299
+ body["state"]["model_prediction_suffix"] = self.kwargs["model_prediction_suffix"]
291
300
  return body
292
301
 
293
302
  def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
@@ -562,7 +571,11 @@ class PredictionSession:
562
571
  return self._predict_images_bytes(images, batch_size=batch_size)
563
572
 
564
573
  def _predict_images_ids(
565
- self, images: List[int], batch_size: int = None, upload_mode: str = None
574
+ self,
575
+ images: List[int],
576
+ batch_size: int = None,
577
+ upload_mode: str = None,
578
+ output_project_id: int = None,
566
579
  ):
567
580
  method = "inference_batch_ids_async"
568
581
  json_body = self._get_json_body()
@@ -572,6 +585,8 @@ class PredictionSession:
572
585
  state["batch_size"] = batch_size
573
586
  if upload_mode is not None:
574
587
  state["upload_mode"] = upload_mode
588
+ if output_project_id is not None:
589
+ state["output_project_id"] = output_project_id
575
590
  return self._start_inference(method, json=json_body)
576
591
 
577
592
  def _predict_videos(
@@ -647,6 +662,7 @@ class PredictionSession:
647
662
  upload_mode: str = None,
648
663
  iou_merge_threshold: float = None,
649
664
  cache_project_on_model: bool = None,
665
+ output_project_id: int = None,
650
666
  ):
651
667
  if len(project_ids) != 1:
652
668
  raise ValueError("Only one project can be processed at a time.")
@@ -664,7 +680,8 @@ class PredictionSession:
664
680
  state["iou_merge_threshold"] = iou_merge_threshold
665
681
  if cache_project_on_model is not None:
666
682
  state["cache_project_on_model"] = cache_project_on_model
667
-
683
+ if output_project_id is not None:
684
+ state["output_project_id"] = output_project_id
668
685
  return self._start_inference(method, json=json_body)
669
686
 
670
687
  def _predict_datasets(