supervisely 6.73.242__py3-none-any.whl → 6.73.244__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2038 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TrainApp module.
|
|
3
|
+
|
|
4
|
+
This module contains the `TrainApp` class and related functionality to facilitate
|
|
5
|
+
training workflows in a Supervisely application.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import shutil
|
|
9
|
+
import subprocess
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from os import listdir
|
|
12
|
+
from os.path import basename, isdir, isfile, join
|
|
13
|
+
from typing import Any, Dict, List, Optional, Union
|
|
14
|
+
from urllib.request import urlopen
|
|
15
|
+
|
|
16
|
+
import httpx
|
|
17
|
+
import yaml
|
|
18
|
+
from fastapi import Request, Response
|
|
19
|
+
from fastapi.responses import StreamingResponse
|
|
20
|
+
from starlette.background import BackgroundTask
|
|
21
|
+
|
|
22
|
+
import supervisely.io.env as sly_env
|
|
23
|
+
import supervisely.io.fs as sly_fs
|
|
24
|
+
import supervisely.io.json as sly_json
|
|
25
|
+
from supervisely import (
|
|
26
|
+
Api,
|
|
27
|
+
Application,
|
|
28
|
+
DatasetInfo,
|
|
29
|
+
OpenMode,
|
|
30
|
+
Project,
|
|
31
|
+
ProjectInfo,
|
|
32
|
+
ProjectMeta,
|
|
33
|
+
WorkflowMeta,
|
|
34
|
+
WorkflowSettings,
|
|
35
|
+
download_project,
|
|
36
|
+
is_development,
|
|
37
|
+
is_production,
|
|
38
|
+
logger,
|
|
39
|
+
)
|
|
40
|
+
from supervisely._utils import get_filename_from_headers
|
|
41
|
+
from supervisely.api.file_api import FileInfo
|
|
42
|
+
from supervisely.app import get_synced_data_dir
|
|
43
|
+
from supervisely.app.widgets import Progress
|
|
44
|
+
from supervisely.nn.benchmark import (
|
|
45
|
+
InstanceSegmentationBenchmark,
|
|
46
|
+
InstanceSegmentationEvaluator,
|
|
47
|
+
ObjectDetectionBenchmark,
|
|
48
|
+
ObjectDetectionEvaluator,
|
|
49
|
+
SemanticSegmentationBenchmark,
|
|
50
|
+
SemanticSegmentationEvaluator,
|
|
51
|
+
)
|
|
52
|
+
from supervisely.nn.inference import RuntimeType, SessionJSON
|
|
53
|
+
from supervisely.nn.task_type import TaskType
|
|
54
|
+
from supervisely.nn.training.gui.gui import TrainGUI
|
|
55
|
+
from supervisely.nn.training.loggers.tensorboard_logger import tb_logger
|
|
56
|
+
from supervisely.nn.utils import ModelSource
|
|
57
|
+
from supervisely.output import set_directory
|
|
58
|
+
from supervisely.project.download import (
|
|
59
|
+
copy_from_cache,
|
|
60
|
+
download_to_cache,
|
|
61
|
+
get_cache_size,
|
|
62
|
+
is_cached,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TrainApp:
|
|
67
|
+
"""
|
|
68
|
+
A class representing the training application.
|
|
69
|
+
|
|
70
|
+
This class initializes and manages the training workflow, including
|
|
71
|
+
handling inputs, hyperparameters, project management, and output artifacts.
|
|
72
|
+
|
|
73
|
+
:param framework_name: Name of the ML framework used.
|
|
74
|
+
:type framework_name: str
|
|
75
|
+
:param models: List of model configurations.
|
|
76
|
+
:type models: List[Dict[str, Any]]
|
|
77
|
+
:param hyperparameters: Path or string content of hyperparameters in YAML format.
|
|
78
|
+
:type hyperparameters: str
|
|
79
|
+
:param app_options: Options for the application layout and behavior.
|
|
80
|
+
:type app_options: Optional[Dict[str, Any]]
|
|
81
|
+
:param work_dir: Path to the working directory for storing intermediate files.
|
|
82
|
+
:type work_dir: Optional[str]
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
    def __init__(
        self,
        framework_name: str,
        models: Union[str, List[Dict[str, Any]]],
        hyperparameters: str,
        app_options: Union[str, Dict[str, Any]] = None,
        work_dir: str = None,
    ):
        """
        Initialize the training application: API client, file/directory layout,
        loaded configuration (models, hyperparameters, app options), the GUI,
        and the `/train_from_api` endpoint.
        """

        # Init
        self._api = Api.from_env()

        # Constants: names of files generated next to the training artifacts.
        self._experiment_json_file = "experiment_info.json"
        self._app_state_file = "app_state.json"
        self._train_val_split_file = "train_val_split.json"
        self._hyperparameters_file = "hyperparameters.yaml"
        self._model_meta_file = "model_meta.json"

        # Constants: sub-directory names inside the working directory / remote storage.
        self._sly_project_dir_name = "sly_project"
        self._model_dir_name = "model"
        self._log_dir_name = "logs"
        self._output_dir_name = "output"
        self._output_checkpoints_dir_name = "checkpoints"
        self._remote_checkpoints_dir_name = "checkpoints"
        self._experiments_dir_name = "experiments"
        self._default_work_dir_name = "work_dir"
        self._tensorboard_port = 6006

        # In production the task id comes from the environment; in development
        # a placeholder is used so the app can run without a Supervisely task.
        if is_production():
            self.task_id = sly_env.task_id()
        else:
            self.task_id = "debug-session"
            logger.info("TrainApp is running in debug mode")

        self.framework_name = framework_name
        self._team_id = sly_env.team_id()
        self._workspace_id = sly_env.workspace_id()
        self._app_name = sly_env.app_name(raise_not_found=False)

        # TODO: read files
        self._models = self._load_models(models)
        self._hyperparameters = self._load_hyperparameters(hyperparameters)
        self._app_options = self._load_app_options(app_options)
        self._inference_class = None
        # ----------------------------------------- #

        # Directories
        if work_dir is not None:
            self.work_dir = work_dir
        else:
            self.work_dir = join(get_synced_data_dir(), self._default_work_dir_name)
        self.output_dir = join(self.work_dir, self._output_dir_name)
        self._output_checkpoints_dir = join(self.output_dir, self._output_checkpoints_dir_name)
        self.project_dir = join(self.work_dir, self._sly_project_dir_name)
        self.train_dataset_dir = join(self.project_dir, "train")
        self.val_dataset_dir = join(self.project_dir, "val")
        # Populated later, after the project is downloaded and split.
        self.sly_project = None
        self.train_split, self.val_split = None, None
        # -------------------------- #

        # Input
        # ----------------------------------------- #

        # Classes
        # ----------------------------------------- #

        # Model
        self.model_files = {}
        self.model_dir = join(self.work_dir, self._model_dir_name)
        self.log_dir = join(self.work_dir, self._log_dir_name)
        # ----------------------------------------- #

        # Hyperparameters
        # ----------------------------------------- #

        # Layout
        self.gui: TrainGUI = TrainGUI(
            self.framework_name, self._models, self._hyperparameters, self._app_options
        )
        self.app = Application(layout=self.gui.layout)
        self._server = self.app.get_server()
        # Set by the `start` decorator; holds the user-defined training function.
        self._train_func = None

        # Benchmark parameters
        if self.is_model_benchmark_enabled:
            self._benchmark_params = {
                "model_files": {},
                "model_source": ModelSource.CUSTOM,
                "model_info": {},
                "device": None,
                "runtime": RuntimeType.PYTORCH,
            }
        # -------------------------- #

        # Train endpoints
        @self._server.post("/train_from_api")
        def _train_from_api(response: Response, request: Request):
            """Start training from an API request carrying an `app_state` payload."""
            try:
                state = request.state.state
                app_state = state["app_state"]
                self.gui.load_from_app_state(app_state)

                self._wrapped_start_training()

                return {"result": "model was successfully trained"}
            except Exception as e:
                # Reset the button spinner before propagating the failure.
                self.gui.training_process.start_button.loading = False
                raise e
|
|
194
|
+
|
|
195
|
+
    def _register_routes(self):
        """
        Registers API routes for TensorBoard and training endpoints.

        These routes enable communication with the application for training
        and visualizing logs in TensorBoard.
        """
        # Reverse-proxy client pointing at the local TensorBoard server.
        client = httpx.AsyncClient(base_url=f"http://127.0.0.1:{self._tensorboard_port}/")

        @self._server.post("/tensorboard/{path:path}")
        @self._server.get("/tensorboard/{path:path}")
        async def _proxy_tensorboard(path: str, request: Request):
            """Forward the incoming request to TensorBoard and stream its response back."""
            url = httpx.URL(path=path, query=request.url.query.encode("utf-8"))
            # Drop the Host header so TensorBoard sees the proxy's own host.
            headers = [(k, v) for k, v in request.headers.raw if k != b"host"]
            req = client.build_request(
                request.method, url, headers=headers, content=request.stream()
            )
            r = await client.send(req, stream=True)
            # Stream the body through; close the upstream response when done.
            return StreamingResponse(
                r.aiter_raw(),
                status_code=r.status_code,
                headers=r.headers,
                background=BackgroundTask(r.aclose),
            )
|
|
219
|
+
|
|
220
|
+
def _prepare_working_dir(self):
|
|
221
|
+
"""
|
|
222
|
+
Prepares the working directory by creating required subdirectories.
|
|
223
|
+
"""
|
|
224
|
+
sly_fs.mkdir(self.work_dir, True)
|
|
225
|
+
sly_fs.mkdir(self.output_dir, True)
|
|
226
|
+
sly_fs.mkdir(self._output_checkpoints_dir, True)
|
|
227
|
+
sly_fs.mkdir(self.project_dir, True)
|
|
228
|
+
sly_fs.mkdir(self.model_dir, True)
|
|
229
|
+
sly_fs.mkdir(self.log_dir, True)
|
|
230
|
+
|
|
231
|
+
# Properties
|
|
232
|
+
# General
|
|
233
|
+
# ----------------------------------------- #
|
|
234
|
+
|
|
235
|
+
# Input Data
|
|
236
|
+
@property
    def project_id(self) -> int:
        """
        Returns the ID of the input project selected in the GUI.

        :return: Project ID.
        :rtype: int
        """
        return self.gui.project_id
|
|
245
|
+
|
|
246
|
+
@property
    def project_name(self) -> str:
        """
        Returns the name of the input project.

        :return: Project name.
        :rtype: str
        """
        return self.gui.project_info.name
|
|
255
|
+
|
|
256
|
+
@property
    def project_info(self) -> ProjectInfo:
        """
        Returns the ProjectInfo object, which contains information about the project.

        :return: Project information object.
        :rtype: ProjectInfo
        """
        return self.gui.project_info
|
|
265
|
+
|
|
266
|
+
# ----------------------------------------- #
|
|
267
|
+
|
|
268
|
+
# Model
|
|
269
|
+
@property
    def model_source(self) -> str:
        """
        Return whether the model is pretrained or custom.

        :return: Model source.
        :rtype: str
        """
        return self.gui.model_selector.get_model_source()
|
|
278
|
+
|
|
279
|
+
@property
    def model_name(self) -> str:
        """
        Returns the name of the selected model.

        :return: Model name.
        :rtype: str
        """
        return self.gui.model_selector.get_model_name()
|
|
288
|
+
|
|
289
|
+
@property
    def model_info(self) -> dict:
        """
        Returns the selected row of the models table as a dict.

        :return: Selected model configuration.
        :rtype: dict
        """
        return self.gui.model_selector.get_model_info()
|
|
298
|
+
|
|
299
|
+
@property
    def model_meta(self) -> ProjectMeta:
        """
        Returns a ProjectMeta restricted to the classes selected for training.

        :return: Project meta containing only the selected classes.
        :rtype: ProjectMeta
        """
        project_meta_json = self.sly_project.meta.to_json()
        # Keep only the classes the user selected in the GUI.
        model_meta = {
            "classes": [
                item for item in project_meta_json["classes"] if item["title"] in self.classes
            ]
        }
        return ProjectMeta.from_json(model_meta)
|
|
314
|
+
|
|
315
|
+
@property
    def device(self) -> str:
        """
        Returns the selected device for training.

        :return: Device name.
        :rtype: str
        """
        return self.gui.training_process.get_device()
|
|
324
|
+
|
|
325
|
+
# Classes
|
|
326
|
+
@property
    def classes(self) -> List[str]:
        """
        Returns the selected classes for training.

        :return: List of selected class names.
        :rtype: List[str]
        """
        return self.gui.classes_selector.get_selected_classes()
|
|
335
|
+
|
|
336
|
+
@property
    def num_classes(self) -> int:
        """
        Returns the number of selected classes for training.

        :return: Number of selected classes.
        :rtype: int
        """
        return len(self.gui.classes_selector.get_selected_classes())
|
|
345
|
+
|
|
346
|
+
# Hyperparameters
|
|
347
|
+
@property
    def hyperparameters(self) -> Dict[str, Any]:
        """
        Returns the selected hyperparameters for training, parsed from YAML into a dict.

        :return: Hyperparameters in dict format.
        :rtype: Dict[str, Any]
        """
        return yaml.safe_load(self.hyperparameters_yaml)
|
|
356
|
+
|
|
357
|
+
@property
    def hyperparameters_yaml(self) -> str:
        """
        Returns the selected hyperparameters for training as a raw YAML string.

        :return: Hyperparameters in raw YAML format.
        :rtype: str
        """
        return self.gui.hyperparameters_selector.get_hyperparameters()
|
|
366
|
+
|
|
367
|
+
# Train Process
|
|
368
|
+
@property
    def progress_bar_main(self) -> Progress:
        """
        Returns the main progress bar widget.

        :return: Main progress bar widget.
        :rtype: Progress
        """
        return self.gui.training_logs.progress_bar_main
|
|
377
|
+
|
|
378
|
+
@property
    def progress_bar_secondary(self) -> Progress:
        """
        Returns the secondary progress bar widget.

        :return: Secondary progress bar widget.
        :rtype: Progress
        """
        return self.gui.training_logs.progress_bar_secondary
|
|
387
|
+
|
|
388
|
+
@property
    def is_model_benchmark_enabled(self) -> bool:
        """
        Checks if model benchmarking is enabled based on application options and GUI settings.

        :return: True if model benchmarking is enabled, False otherwise.
        :rtype: bool
        """
        # Enabled only when the app option allows it (defaults to True) AND the
        # user ticked the benchmark checkbox in the GUI.
        return (
            self._app_options.get("model_benchmark", True)
            and self.gui.hyperparameters_selector.get_model_benchmark_checkbox_value()
        )
|
|
400
|
+
|
|
401
|
+
# Output
|
|
402
|
+
# ----------------------------------------- #
|
|
403
|
+
|
|
404
|
+
# region TRAIN START
|
|
405
|
+
@property
    def start(self):
        """
        Decorator for the training function defined by the user.
        It wraps the user-defined training function and prepares and finalizes
        the training process.

        Usage: ``@train_app.start`` above the user's train function; the wrapped
        entry point is also bound to the GUI start button.
        """

        def decorator(func):
            # Remember the user's function and wire the start button to the wrapper.
            self._train_func = func
            self.gui.training_process.start_button.click(self._wrapped_start_training)
            return func

        return decorator
|
|
418
|
+
|
|
419
|
+
def _prepare(self) -> None:
        """
        Prepares the environment for training by setting up directories,
        downloading project and model data.
        """
        logger.info("Preparing for training")
        # Lock the selection widgets so the configuration cannot change mid-run.
        self.gui.disable_select_buttons()

        # Step 1. Workflow Input (production only — workflow tracking is a platform feature)
        if is_production():
            self._workflow_input()
        # Step 2. Download Project
        self._download_project()
        # Step 3. Split Project
        self._split_project()
        # Step 4. Convert Supervisely to X format
        # Step 5. Download Model files
        self._download_model()
|
|
437
|
+
|
|
438
|
+
def _finalize(self, experiment_info: dict) -> None:
        """
        Finalizes the training process by validating outputs, uploading artifacts,
        and updating the UI.

        :param experiment_info: Information about the experiment results that should be returned in user's training function.
        :type experiment_info: dict
        :raises ValueError: If ``experiment_info`` fails validation.
        """
        logger.info("Finalizing training")

        # Step 1. Validate experiment_info
        success, reason = self._validate_experiment_info(experiment_info)
        if not success:
            raise ValueError(f"{reason}. Failed to upload artifacts")

        # Step 2. Preprocess artifacts
        self._preprocess_artifacts(experiment_info)

        # Step 3. Postprocess splits
        splits_data = self._postprocess_splits()

        # Step 4. Upload artifacts
        remote_dir, file_info = self._upload_artifacts()

        # Step 5. Run Model Benchmark (best-effort: a benchmark failure must
        # not abort finalization of an otherwise successful training run).
        mb_eval_report, mb_eval_report_id = None, None

        if self.is_model_benchmark_enabled:
            try:
                mb_eval_report, mb_eval_report_id = self._run_model_benchmark(
                    self.output_dir, remote_dir, experiment_info, splits_data
                )
            except Exception as e:
                logger.error(f"Model benchmark failed: {e}")

        # Step 6. Generate and upload additional files
        self._generate_experiment_info(remote_dir, experiment_info, mb_eval_report_id)
        self._generate_app_state(remote_dir, experiment_info)
        self._generate_hyperparameters(remote_dir, experiment_info)
        self._generate_train_val_splits(remote_dir, splits_data)
        self._generate_model_meta(remote_dir, experiment_info)

        # Step 7. Set output widgets
        self._set_training_output(remote_dir, file_info)

        # Step 8. Workflow output (production only)
        if is_production():
            self._workflow_output(remote_dir, file_info, mb_eval_report)
|
|
486
|
+
|
|
487
|
+
# region TRAIN END
|
|
488
|
+
|
|
489
|
+
def register_inference_class(self, inference_class: Any, inference_settings: dict = {}) -> None:
|
|
490
|
+
"""
|
|
491
|
+
Registers an inference class for the training application to do model benchmarking.
|
|
492
|
+
|
|
493
|
+
:param inference_class: Inference class to be registered inherited from `supervisely.nn.inference.Inference`.
|
|
494
|
+
:type inference_class: Any
|
|
495
|
+
:param inference_settings: Settings for the inference class.
|
|
496
|
+
:type inference_settings: dict
|
|
497
|
+
"""
|
|
498
|
+
self._inference_class = inference_class
|
|
499
|
+
self._inference_settings = inference_settings
|
|
500
|
+
|
|
501
|
+
def get_app_state(self, experiment_info: dict = None) -> dict:
        """
        Returns the current state of the application as a serializable dict.

        :param experiment_info: Optional experiment results used to resolve the model config.
        :type experiment_info: dict
        :return: Application state.
        :rtype: dict
        """
        input_data = {"project_id": self.project_id}
        split_config = self._get_train_val_splits_for_app_state()
        model_config = self._get_model_config_for_app_state(experiment_info)

        hp_selector = self.gui.hyperparameters_selector
        options = {
            "model_benchmark": {
                "enable": hp_selector.get_model_benchmark_checkbox_value(),
                "speed_test": hp_selector.get_speedtest_checkbox_value(),
            },
            "cache_project": self.gui.input_selector.get_cache_value(),
        }

        return {
            "input": input_data,
            "train_val_split": split_config,
            "classes": self.classes,
            "model": model_config,
            "hyperparameters": self.hyperparameters_yaml,
            "options": options,
        }
|
|
529
|
+
|
|
530
|
+
def load_app_state(self, app_state: dict) -> None:
        """
        Load the GUI state from app state dictionary.

        :param app_state: The state dictionary.
        :type app_state: dict

        app_state example:

        app_state = {
            "input": {"project_id": 55555},
            "train_val_splits": {
                "method": "random",
                "split": "train",
                "percent": 90
            },
            "classes": ["apple"],
            "model": {
                "source": "Pretrained models",
                "model_name": "rtdetr_r50vd_coco_objects365"
            },
            "hyperparameters": hyperparameters, # yaml string
            "options": {
                "model_benchmark": {
                    "enable": True,
                    "speed_test": True
                },
                "cache_project": True
            }
        }
        """
        # NOTE(review): the example above uses key "train_val_splits" while
        # get_app_state() emits "train_val_split" — confirm which key
        # TrainGUI.load_from_app_state actually expects.
        self.gui.load_from_app_state(app_state)
|
|
562
|
+
|
|
563
|
+
# Loaders
|
|
564
|
+
def _load_models(self, models: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
|
565
|
+
"""
|
|
566
|
+
Loads models from the provided file or list of model configurations.
|
|
567
|
+
"""
|
|
568
|
+
if isinstance(models, str):
|
|
569
|
+
if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
|
|
570
|
+
models = sly_json.load_json_file(models)
|
|
571
|
+
else:
|
|
572
|
+
raise ValueError(
|
|
573
|
+
"Invalid models file. Please provide a valid '.json' file or a list of model configurations."
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
if not isinstance(models, list):
|
|
577
|
+
raise ValueError("models parameters must be a list of dicts")
|
|
578
|
+
for item in models:
|
|
579
|
+
if not isinstance(item, dict):
|
|
580
|
+
raise ValueError(f"Each item in models must be a dict.")
|
|
581
|
+
model_meta = item.get("meta")
|
|
582
|
+
if model_meta is None:
|
|
583
|
+
raise ValueError(
|
|
584
|
+
"Model metadata not found. Please update provided models parameter to include key 'meta'."
|
|
585
|
+
)
|
|
586
|
+
model_files = model_meta.get("model_files")
|
|
587
|
+
if model_files is None:
|
|
588
|
+
raise ValueError(
|
|
589
|
+
"Model files not found in model metadata. "
|
|
590
|
+
"Please update provided models oarameter to include key 'model_files' in 'meta' key."
|
|
591
|
+
)
|
|
592
|
+
return models
|
|
593
|
+
|
|
594
|
+
def _load_hyperparameters(self, hyperparameters: str) -> dict:
|
|
595
|
+
"""
|
|
596
|
+
Loads hyperparameters from file path.
|
|
597
|
+
|
|
598
|
+
:param hyperparameters: Path to hyperparameters file.
|
|
599
|
+
:type hyperparameters: str
|
|
600
|
+
:return: Hyperparameters in dict format.
|
|
601
|
+
:rtype: dict
|
|
602
|
+
"""
|
|
603
|
+
if not isinstance(hyperparameters, str):
|
|
604
|
+
raise ValueError(
|
|
605
|
+
f"Expected a string with config or path for hyperparameters, but got {type(hyperparameters).__name__}"
|
|
606
|
+
)
|
|
607
|
+
if hyperparameters.endswith((".yml", ".yaml")):
|
|
608
|
+
try:
|
|
609
|
+
with open(hyperparameters, "r") as file:
|
|
610
|
+
return file.read()
|
|
611
|
+
except Exception as e:
|
|
612
|
+
raise ValueError(f"Failed to load YAML file: {hyperparameters}. Error: {e}")
|
|
613
|
+
return hyperparameters
|
|
614
|
+
|
|
615
|
+
def _load_app_options(self, app_options: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
|
|
616
|
+
"""
|
|
617
|
+
Loads the app_options parameter to ensure it is in the correct format.
|
|
618
|
+
"""
|
|
619
|
+
if app_options is None:
|
|
620
|
+
return {}
|
|
621
|
+
|
|
622
|
+
if isinstance(app_options, str):
|
|
623
|
+
if sly_fs.file_exists(app_options) and sly_fs.get_file_ext(app_options) in [
|
|
624
|
+
".yaml",
|
|
625
|
+
".yml",
|
|
626
|
+
]:
|
|
627
|
+
app_options = self._load_yaml(app_options)
|
|
628
|
+
else:
|
|
629
|
+
raise ValueError(
|
|
630
|
+
"Invalid app_options file provided. Please provide a valid '.yaml' or '.yml' file or a dictionary with app_options."
|
|
631
|
+
)
|
|
632
|
+
if not isinstance(app_options, dict):
|
|
633
|
+
raise ValueError("app_options must be a dict")
|
|
634
|
+
return app_options
|
|
635
|
+
|
|
636
|
+
def _load_yaml(self, path: str) -> dict:
        """
        Load a YAML file from the specified path.

        :param path: Path to the YAML file.
        :type path: str
        :return: YAML file contents.
        :rtype: dict
        """
        with open(path, "r") as fh:
            parsed = yaml.safe_load(fh)
        return parsed
|
|
647
|
+
|
|
648
|
+
# ----------------------------------------- #
|
|
649
|
+
|
|
650
|
+
# Preprocess
|
|
651
|
+
# Download Project
|
|
652
|
+
def _download_project(self) -> None:
        """
        Downloads the project data from Supervisely.
        If the cache is enabled, it will attempt to retrieve the project from the cache,
        falling back to a fresh download on any cache failure.
        """
        dataset_infos = [dataset for _, dataset in self._api.dataset.tree(self.project_id)]

        # For dataset-based splits, only the datasets actually used for
        # train/val need to be downloaded.
        if self.gui.train_val_splits_selector.get_split_method() == "Based on datasets":
            selected_ds_ids = (
                self.gui.train_val_splits_selector.get_train_dataset_ids()
                + self.gui.train_val_splits_selector.get_val_dataset_ids()
            )
            dataset_infos = [ds_info for ds_info in dataset_infos if ds_info.id in selected_ds_ids]

        total_images = sum(ds_info.images_count for ds_info in dataset_infos)
        # Cache is bypassed in development or when disabled in the GUI.
        if not self.gui.input_selector.get_cache_value() or is_development():
            self._download_no_cache(dataset_infos, total_images)
            self.sly_project = Project(self.project_dir, OpenMode.READ)
            return

        try:
            self._download_with_cache(dataset_infos, total_images)
        except Exception:
            logger.warning(
                "Failed to retrieve project from cache. Downloading it",
                exc_info=True,
            )
            # Clean up any partial cache copy before re-downloading.
            if sly_fs.dir_exists(self.project_dir):
                sly_fs.clean_dir(self.project_dir)
            self._download_no_cache(dataset_infos, total_images)
        finally:
            # Open the local project regardless of which path produced it.
            self.sly_project = Project(self.project_dir, OpenMode.READ)
            logger.info(f"Project downloaded successfully to: '{self.project_dir}'")
|
|
685
|
+
|
|
686
|
+
def _download_no_cache(self, dataset_infos: List[DatasetInfo], total_images: int) -> None:
        """
        Downloads the project data from Supervisely without using the cache.

        :param dataset_infos: List of dataset information objects.
        :type dataset_infos: List[DatasetInfo]
        :param total_images: Total number of images to download (drives the progress bar).
        :type total_images: int
        """
        with self.progress_bar_main(message="Downloading input data", total=total_images) as pbar:
            self.progress_bar_main.show()
            download_project(
                api=self._api,
                project_id=self.project_id,
                dest_dir=self.project_dir,
                dataset_ids=[ds_info.id for ds_info in dataset_infos],
                log_progress=True,
                progress_cb=pbar.update,
            )
        self.progress_bar_main.hide()
|
|
706
|
+
|
|
707
|
+
def _download_with_cache(
    self,
    dataset_infos: List[DatasetInfo],
    total_images: int,
) -> None:
    """
    Fetch the selected datasets through the local dataset cache: datasets
    are first synchronized into the cache, then copied into the project
    directory.

    :param dataset_infos: Datasets to retrieve.
    :type dataset_infos: List[DatasetInfo]
    :param total_images: Total image count, used to size the progress bar.
    :type total_images: int
    """
    # Partition datasets by cache status in a single pass (one is_cached call each).
    cached, to_download = [], []
    for info in dataset_infos:
        bucket = cached if is_cached(self.project_info.id, info.name) else to_download
        bucket.append(info)

    logger.info(self._get_cache_log_message(cached, to_download))

    # Sync everything into the cache (already-cached datasets are handled internally).
    with self.progress_bar_main(message="Downloading input data", total=total_images) as pbar:
        self.progress_bar_main.show()
        download_to_cache(
            api=self._api,
            project_id=self.project_info.id,
            dataset_infos=dataset_infos,
            log_progress=True,
            progress_cb=pbar.update,
        )

    # Copy from the cache into the working project directory, tracking bytes.
    total_cache_size = sum(
        get_cache_size(self.project_info.id, ds.name) for ds in dataset_infos
    )
    with self.progress_bar_main(
        message="Retrieving data from cache",
        total=total_cache_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        copy_from_cache(
            project_id=self.project_info.id,
            dest_dir=self.project_dir,
            dataset_names=[ds.name for ds in dataset_infos],
            progress_cb=pbar.update,
        )
    self.progress_bar_main.hide()
|
|
753
|
+
|
|
754
|
+
def _get_cache_log_message(self, cached: "List[DatasetInfo]", to_download: "List[DatasetInfo]") -> str:
    """
    Build a human-readable log message describing which datasets are
    already cached and which still need to be downloaded.

    Fix: the ``cached`` parameter was annotated ``bool`` although the method
    iterates it and reads ``.name``/``.id`` from its elements — it is a list
    of DatasetInfo, same as ``to_download``.

    :param cached: Datasets already present in the local cache.
    :param to_download: Datasets that must be downloaded.
    :return: Summary message for the log.
    """
    if not cached:
        log_msg = "No cached datasets found"
    else:
        log_msg = "Using cached datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in cached
        )

    if not to_download:
        log_msg += ". All datasets are cached. No datasets to download"
    else:
        log_msg += ". Downloading datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
        )

    return log_msg
|
|
773
|
+
|
|
774
|
+
# Split Project
def _split_project(self) -> None:
    """
    Split the downloaded project into training and validation sets.

    All images and annotations are renamed (train_img_001, val_img_001, ...)
    and copied into temporary "splits/train" and "splits/val" trees, the
    original dataset directories are removed, and the two splits are moved
    back in as the only datasets "train" and "val".
    Assigns self.sly_project to the new project, which contains only 2 datasets: train and val.
    """
    # Load splits from the GUI selector (uses the currently opened project)
    self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
    self._train_split, self._val_split = (
        self.gui.train_val_splits_selector.train_val_splits.get_splits()
    )

    # Prepare paths: temporary staging tree under work_dir/splits
    project_split_path = join(self.work_dir, "splits")
    paths = {
        "train": {
            "split_path": join(project_split_path, "train"),
            "img_dir": join(project_split_path, "train", "img"),
            "ann_dir": join(project_split_path, "train", "ann"),
        },
        "val": {
            "split_path": join(project_split_path, "val"),
            "img_dir": join(project_split_path, "val", "img"),
            "ann_dir": join(project_split_path, "val", "ann"),
        },
    }

    # Create necessary directories (only once)
    for dataset_paths in paths.values():
        for path in dataset_paths.values():
            sly_fs.mkdir(path, True)

    # Format for image names: zero-pad indices to the width of the larger split
    items_count = max(len(self._train_split), len(self._val_split))
    num_digits = len(str(items_count))
    image_name_formats = {
        "train": f"train_img_{{:0{num_digits}d}}",
        "val": f"val_img_{{:0{num_digits}d}}",
    }

    # Utility function to move files
    def move_files(split, paths, img_name_format, pbar):
        """
        Copy each split item's image and annotation into the staging dirs,
        renaming them with sequential, zero-padded names.
        """
        for idx, item in enumerate(split, start=1):
            # New image name keeps the original extension
            item_name = img_name_format.format(idx) + sly_fs.get_file_ext(item.name)
            # Supervisely annotation files are "<image name>.json"
            ann_name = f"{item_name}.json"
            shutil.copy(item.img_path, join(paths["img_dir"], item_name))
            shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
            pbar.update(1)

    # Main split processing: outer bar counts the two splits, inner bar counts items
    with self.progress_bar_main(
        message="Applying train/val splits to project", total=2
    ) as main_pbar:
        self.progress_bar_main.show()
        for dataset in ["train", "val"]:
            split = self._train_split if dataset == "train" else self._val_split
            with self.progress_bar_secondary(
                message=f"Preparing '{dataset}'", total=len(split)
            ) as second_pbar:
                self.progress_bar_secondary.show()
                move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
                main_pbar.update(1)
            self.progress_bar_secondary.hide()
        self.progress_bar_main.hide()

    # Clean up project directory: drop all original dataset folders
    project_datasets = [
        join(self.project_dir, item)
        for item in listdir(self.project_dir)
        if isdir(join(self.project_dir, item))
    ]
    for dataset in project_datasets:
        sly_fs.remove_dir(dataset)

    # Move processed splits to final destination inside the project dir
    train_ds_path = join(self.project_dir, "train")
    val_ds_path = join(self.project_dir, "val")
    with self.progress_bar_main(message="Processing splits", total=2) as pbar:
        self.progress_bar_main.show()
        for dataset in ["train", "val"]:
            shutil.move(
                paths[dataset]["split_path"],
                train_ds_path if dataset == "train" else val_ds_path,
            )
            pbar.update(1)
        self.progress_bar_main.hide()

    # Clean up temporary directory and re-open the restructured project
    sly_fs.remove_dir(project_split_path)
    self.sly_project = Project(self.project_dir, OpenMode.READ)
|
|
868
|
+
|
|
869
|
+
# ----------------------------------------- #
|
|
870
|
+
|
|
871
|
+
# ----------------------------------------- #
|
|
872
|
+
# Download Model
|
|
873
|
+
def _download_model(self) -> None:
    """
    Download model files for the currently selected source.

    Pretrained models: every entry of the model's ``meta["model_files"]``
    dict is fetched — http(s) values are downloaded by link, plain values
    are treated as local paths and used as-is. Example entry::

        {
            "Model": "example_model",
            "meta": {
                "task_type": "object detection",
                "model_name": "example_model",
                "model_files": {
                    "checkpoint": "https://example.com/checkpoint.pth",
                    "config": "https://example.com/config.yaml"
                    # or local paths: "/path/to/checkpoint.pth", ...
                }
            }
        }

    Custom models: artifacts of a training session run inside Supervisely
    are resolved and downloaded automatically.
    """
    if self.model_source == ModelSource.PRETRAINED:
        downloader = self._download_pretrained_model
    else:
        downloader = self._download_custom_model
    downloader()
    logger.info(f"Model files have been downloaded successfully to: '{self.model_dir}'")
|
|
915
|
+
|
|
916
|
+
def _download_pretrained_model(self):
    """
    Download the files of the selected pretrained model.

    Iterates ``self.model_info["meta"]["model_files"]``: values starting
    with "http" are downloaded into ``self.model_dir``; any other value is
    assumed to be a local path and is used as-is. Populates
    ``self.model_files`` (file key -> local path).
    """
    # General
    self.model_files = {}
    model_meta = self.model_info["meta"]
    model_files = model_meta["model_files"]

    with self.progress_bar_main(
        message="Downloading model files",
        total=len(model_files),
    ) as model_download_main_pbar:
        self.progress_bar_main.show()
        for file in model_files:
            file_url = model_files[file]
            file_path = join(self.model_dir, file)

            if file_url.startswith("http"):
                # Probe the remote file: size from the response, name from headers
                with urlopen(file_url) as f:
                    file_size = f.length
                    file_name = get_filename_from_headers(file_url)
                    # Prefer the server-reported filename over the dict key
                    file_path = join(self.model_dir, file_name)
                    with self.progress_bar_secondary(
                        message=f"Downloading '{file_name}' ",
                        total=file_size,
                        unit="bytes",
                        unit_scale=True,
                    ) as model_download_secondary_pbar:
                        self.progress_bar_secondary.show()
                        sly_fs.download(
                            url=file_url,
                            save_path=file_path,
                            progress=model_download_secondary_pbar.update,
                        )
                self.model_files[file] = file_path
            else:
                # Not a URL: treat as an already-available local path
                self.model_files[file] = file_url
            model_download_main_pbar.update(1)

    self.progress_bar_main.hide()
    self.progress_bar_secondary.hide()
|
|
958
|
+
|
|
959
|
+
def _download_custom_model(self):
    """
    Download the files of a custom model (artifacts of a previous
    Supervisely training session) from Team Files into ``self.model_dir``.

    Remote paths are built by joining the experiment's ``artifacts_dir``
    with each entry of ``model_files``; the checkpoint selected in the GUI
    is added on top. Populates ``self.model_files`` (name -> local path).
    """
    # General
    self.model_files = {}

    # Need to merge file_url with arts dir
    artifacts_dir = self.model_info["artifacts_dir"]
    model_files = self.model_info["model_files"]
    remote_paths = {name: join(artifacts_dir, file) for name, file in model_files.items()}

    # Add selected checkpoint to model_files
    checkpoint = self.gui.model_selector.experiment_selector.get_selected_checkpoint_path()
    remote_paths["checkpoint"] = checkpoint

    with self.progress_bar_main(
        message="Downloading model files",
        # Fix: the loop iterates remote_paths, which contains one more entry
        # than model_files (the selected checkpoint) — size the bar to match.
        total=len(remote_paths),
    ) as model_download_main_pbar:
        self.progress_bar_main.show()
        for name, remote_path in remote_paths.items():
            file_info = self._api.file.get_info_by_path(self._team_id, remote_path)
            file_name = basename(remote_path)
            local_path = join(self.model_dir, file_name)
            file_size = file_info.sizeb

            with self.progress_bar_secondary(
                message=f"Downloading {file_name}",
                total=file_size,
                unit="bytes",
                unit_scale=True,
            ) as model_download_secondary_pbar:
                self.progress_bar_secondary.show()
                self._api.file.download(
                    self._team_id,
                    remote_path,
                    local_path,
                    progress_cb=model_download_secondary_pbar.update,
                )
            model_download_main_pbar.update(1)
            self.model_files[name] = local_path

    self.progress_bar_main.hide()
    self.progress_bar_secondary.hide()
|
|
1004
|
+
|
|
1005
|
+
# ----------------------------------------- #
|
|
1006
|
+
|
|
1007
|
+
# Postprocess
|
|
1008
|
+
|
|
1009
|
+
def _validate_experiment_info(self, experiment_info: dict) -> tuple:
    """
    Validates the experiment_info parameter to ensure it is in the correct format.
    experiment_info is returned by the user's training function.

    experiment_info should contain the following keys:
        - "model_name": str
        - "task_type": str
        - "model_files": dict
        - "checkpoints": list of checkpoint paths, or a directory path (str)
        - "best_checkpoint": str (checkpoint file name with extension)

    Other keys are generated by the TrainApp class automatically.

    :param experiment_info: Information about the experiment results.
    :type experiment_info: dict
    :return: Tuple of success status and reason for failure (None on success).
    :rtype: tuple
    """
    if not isinstance(experiment_info, dict):
        reason = f"Validation failed: 'experiment_info' must be a dictionary not '{type(experiment_info)}'"
        return False, reason

    logger.debug("Starting validation of 'experiment_info'")
    required_keys = {
        "model_name": str,
        "task_type": str,
        "model_files": dict,
        "checkpoints": (list, str),
        "best_checkpoint": str,
    }

    for key, expected_type in required_keys.items():
        if key not in experiment_info:
            reason = f"Validation failed: Missing required key '{key}'"
            return False, reason

        if not isinstance(experiment_info[key], expected_type):
            # Fix: expected_type may be a tuple of types ("checkpoints");
            # tuple has no __name__, so the old message raised AttributeError.
            if isinstance(expected_type, tuple):
                type_name = " or ".join(t.__name__ for t in expected_type)
            else:
                type_name = expected_type.__name__
            reason = f"Validation failed: Key '{key}' should be of type {type_name}"
            return False, reason

    # When given an explicit list, every checkpoint must be an existing file path.
    if isinstance(experiment_info["checkpoints"], list):
        for checkpoint in experiment_info["checkpoints"]:
            if not isinstance(checkpoint, str):
                reason = "Validation failed: All items in 'checkpoints' list must be strings"
                return False, reason
            if not sly_fs.file_exists(checkpoint):
                reason = f"Validation failed: Checkpoint file: '{checkpoint}' does not exist"
                return False, reason

    # The best checkpoint must be among the checkpoint file names
    # (list form: the given paths; str form: files found in the directory).
    best_checkpoint = experiment_info["best_checkpoint"]
    checkpoints = experiment_info["checkpoints"]
    if isinstance(checkpoints, list):
        checkpoint_names = [sly_fs.get_file_name_with_ext(c) for c in checkpoints]
    elif isinstance(checkpoints, str):
        checkpoint_names = [
            sly_fs.get_file_name_with_ext(c)
            for c in sly_fs.list_dir_recursively(checkpoints, [".pt", ".pth"])
        ]
    else:
        reason = "Validation failed: 'checkpoints' should be a list of paths or a path to directory with checkpoints"
        return False, reason

    if best_checkpoint not in checkpoint_names:
        reason = f"Validation failed: Best checkpoint file: '{best_checkpoint}' does not exist"
        return False, reason

    logger.debug("Validation successful")
    return True, None
|
|
1086
|
+
|
|
1087
|
+
def _postprocess_splits(self) -> dict:
    """
    Processes the train and val splits to generate the necessary data for
    the experiment_info.json file.

    For the "Based on datasets" split method the dataset ids are recorded;
    for every other method the image ids of each split are resolved via the
    API. Keys not applicable to the chosen method remain None.

    :return: {"train"/"val": {"dataset_ids": ..., "images_ids": ...}}
    :rtype: dict
    """
    val_dataset_ids = None
    val_images_ids = None
    train_dataset_ids = None
    train_images_ids = None

    split_method = self.gui.train_val_splits_selector.get_split_method()
    train_set, val_set = self._train_split, self._val_split
    if split_method == "Based on datasets":
        val_dataset_ids = self.gui.train_val_splits_selector.get_val_dataset_ids()
        # Fix: the getter was referenced but not called, which stored a bound
        # method object instead of the list of train dataset ids.
        train_dataset_ids = self.gui.train_val_splits_selector.get_train_dataset_ids()
    else:
        # Build "parent/child" (or plain) dataset-name -> DatasetInfo lookup.
        dataset_infos = [dataset for _, dataset in self._api.dataset.tree(self.project_id)]
        ds_infos_dict = {}
        for dataset in dataset_infos:
            if dataset.parent_id is not None:
                parent_ds = self._api.dataset.get_info_by_id(dataset.parent_id)
                dataset_name = f"{parent_ds.name}/{dataset.name}"
            else:
                dataset_name = dataset.name
            ds_infos_dict[dataset_name] = dataset

        def get_image_infos_by_split(ds_infos_dict: dict, split: list):
            # Group split item names by dataset, then fetch matching
            # ImageInfos with one API call per dataset.
            image_names_per_dataset = {}
            for item in split:
                image_names_per_dataset.setdefault(item.dataset_name, []).append(item.name)
            image_infos = []
            for dataset_name, image_names in image_names_per_dataset.items():
                ds_info = ds_infos_dict[dataset_name]
                image_infos.extend(
                    self._api.image.get_list(
                        ds_info.id,
                        filters=[
                            {
                                "field": "name",
                                "operator": "in",
                                "value": image_names,
                            }
                        ],
                    )
                )
            return image_infos

        val_image_infos = get_image_infos_by_split(ds_infos_dict, val_set)
        train_image_infos = get_image_infos_by_split(ds_infos_dict, train_set)
        val_images_ids = [img_info.id for img_info in val_image_infos]
        train_images_ids = [img_info.id for img_info in train_image_infos]

    splits_data = {
        "train": {
            "dataset_ids": train_dataset_ids,
            "images_ids": train_images_ids,
        },
        "val": {
            "dataset_ids": val_dataset_ids,
            "images_ids": val_images_ids,
        },
    }
    return splits_data
|
|
1149
|
+
|
|
1150
|
+
def _preprocess_artifacts(self, experiment_info: dict) -> None:
    """
    Preprocesses and moves the artifacts generated by the training process
    to the output directories.

    :param experiment_info: Information about the experiment results.
    :type experiment_info: dict
    """
    # Preprocess artifacts
    logger.debug("Preprocessing artifacts")
    if "model_files" not in experiment_info:
        experiment_info["model_files"] = {}
    else:
        # Move model files to output directory except config, config will be processed next.
        # Fixes: iterate the path values (not the dict keys), and actually
        # call isfile()/isdir() — previously the bare function objects were
        # used as conditions, which are always truthy.
        files = {k: v for k, v in experiment_info["model_files"].items() if k != "config"}
        for file in files.values():
            if isfile(file):
                shutil.move(
                    file,
                    join(self.output_dir, sly_fs.get_file_name_with_ext(file)),
                )
            elif isdir(file):
                shutil.move(file, join(self.output_dir, basename(file)))

    # Preprocess config
    logger.debug("Preprocessing config")
    config = experiment_info["model_files"].get("config")
    if config is not None:
        config_name = sly_fs.get_file_name_with_ext(config)
        output_config_path = join(self.output_dir, config_name)
        shutil.move(config, output_config_path)
        if self.is_model_benchmark_enabled:
            self._benchmark_params["model_files"]["config"] = output_config_path

    # Prepare checkpoints
    checkpoints = experiment_info["checkpoints"]
    if isinstance(checkpoints, str):
        # If checkpoints returned as directory, collect checkpoint files from it
        checkpoint_paths = list(sly_fs.list_files_recursively(checkpoints, [".pt", ".pth"]))
    elif isinstance(checkpoints, list):
        checkpoint_paths = checkpoints
    else:
        raise ValueError(
            "Checkpoints should be a list of paths or a path to directory with checkpoints"
        )

    # Move every checkpoint into the output checkpoints dir; remember where
    # the best one landed for the benchmark step.
    best_checkpoints_name = experiment_info["best_checkpoint"]
    for checkpoint_path in checkpoint_paths:
        new_checkpoint_path = join(
            self._output_checkpoints_dir,
            sly_fs.get_file_name_with_ext(checkpoint_path),
        )
        shutil.move(checkpoint_path, new_checkpoint_path)
        if self.is_model_benchmark_enabled:
            if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
                self._benchmark_params["model_files"]["checkpoint"] = new_checkpoint_path

    # Prepare logs
    if sly_fs.dir_exists(self.log_dir):
        logs_dir = join(self.output_dir, "logs")
        shutil.move(self.log_dir, logs_dir)
|
|
1212
|
+
|
|
1213
|
+
# Generate experiment_info.json and app_state.json
|
|
1214
|
+
def _upload_file_to_team_files(self, local_path: str, remote_path: str, message: str) -> None:
    """
    Upload a single file to Team Files, showing byte-level progress.

    :param local_path: Path of the file on disk.
    :param remote_path: Destination path in Team Files.
    :param message: Progress-bar caption.
    """
    logger.debug(f"Uploading '{local_path}' to Supervisely")
    file_size = sly_fs.get_file_size(local_path)
    with self.progress_bar_main(
        message=message, total=file_size, unit="bytes", unit_scale=True
    ) as pbar:
        self.progress_bar_main.show()
        # NOTE(review): the tqdm-like bar object itself is passed as
        # progress_cb here (siblings pass .update) — confirm the API accepts both.
        self._api.file.upload(
            self._team_id,
            local_path,
            remote_path,
            progress_cb=pbar,
        )
    self.progress_bar_main.hide()
|
|
1229
|
+
|
|
1230
|
+
def _generate_train_val_splits(self, remote_dir: str, splits_data: dict) -> None:
    """
    Serialize the train/val image-id splits to JSON and upload the file to
    the experiment's Team Files directory.

    :param remote_dir: Remote directory path.
    :type remote_dir: str
    :param splits_data: Split info as produced by ``_postprocess_splits``.
    :type splits_data: dict
    """
    split_ids = {
        "train": splits_data["train"]["images_ids"],
        "val": splits_data["val"]["images_ids"],
    }

    local_split_path = join(self.output_dir, self._train_val_split_file)
    remote_split_path = join(remote_dir, self._train_val_split_file)
    sly_json.dump_json_file(split_ids, local_split_path)
    self._upload_file_to_team_files(
        local_split_path,
        remote_split_path,
        f"Uploading '{self._train_val_split_file}' to Team Files",
    )
|
|
1251
|
+
|
|
1252
|
+
def _generate_model_meta(self, remote_dir: str, experiment_info: dict) -> None:
    """
    Dump the project's model meta to model_meta.json and upload it.

    :param remote_dir: Remote directory path.
    :type remote_dir: str
    :param experiment_info: Information about the experiment results
        (not used by this method).
    :type experiment_info: dict
    """
    # @TODO: Handle tags for classification tasks
    meta_local_path = join(self.output_dir, self._model_meta_file)
    meta_remote_path = join(remote_dir, self._model_meta_file)

    sly_json.dump_json_file(self.model_meta.to_json(), meta_local_path)
    self._upload_file_to_team_files(
        meta_local_path,
        meta_remote_path,
        f"Uploading '{self._model_meta_file}' to Team Files",
    )
|
|
1271
|
+
|
|
1272
|
+
def _generate_experiment_info(
    self,
    remote_dir: str,
    experiment_info: Dict,
    evaluation_report_id: Optional[int] = None,
) -> None:
    """
    Generates and uploads the experiment_info.json file to the output directory.

    :param remote_dir: Remote directory path.
    :type remote_dir: str
    :param experiment_info: Information about the experiment results.
    :type experiment_info: dict
    :param evaluation_report_id: Evaluation report file ID.
    :type evaluation_report_id: int
    """
    logger.debug("Updating experiment info")

    # Rebuild the record with all generated fields; user-provided values are
    # taken from the incoming experiment_info.
    experiment_info = {
        "experiment_name": self.gui.training_process.get_experiment_name(),
        "framework_name": self.framework_name,
        "model_name": experiment_info["model_name"],
        "task_type": experiment_info["task_type"],
        "project_id": self.project_info.id,
        "task_id": self.task_id,
        "model_files": experiment_info["model_files"],
        "checkpoints": experiment_info["checkpoints"],
        "best_checkpoint": experiment_info["best_checkpoint"],
        "app_state": self._app_state_file,
        "model_meta": self._model_meta_file,
        "train_val_split": self._train_val_split_file,
        "hyperparameters": self._hyperparameters_file,
        "artifacts_dir": remote_dir,
        "datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "evaluation_report_id": evaluation_report_id,
        "eval_metrics": {},
    }

    # Replace checkpoint paths with their uploaded Team Files locations.
    remote_checkpoints_dir = join(remote_dir, self._remote_checkpoints_dir_name)
    checkpoint_files = self._api.file.list(
        self._team_id, remote_checkpoints_dir, return_type="fileinfo"
    )
    experiment_info["checkpoints"] = [
        f"checkpoints/{checkpoint.name}" for checkpoint in checkpoint_files
    ]

    experiment_info["best_checkpoint"] = sly_fs.get_file_name_with_ext(
        experiment_info["best_checkpoint"]
    )
    # Fix: "config" is optional in model_files (frameworks without a config
    # file produce none) — the unconditional lookup raised KeyError.
    config = experiment_info["model_files"].get("config")
    if config is not None:
        experiment_info["model_files"]["config"] = sly_fs.get_file_name_with_ext(config)

    local_path = join(self.output_dir, self._experiment_json_file)
    remote_path = join(remote_dir, self._experiment_json_file)
    sly_json.dump_json_file(experiment_info, local_path)
    self._upload_file_to_team_files(
        local_path,
        remote_path,
        f"Uploading '{self._experiment_json_file}' to Team Files",
    )
|
|
1333
|
+
|
|
1334
|
+
def _generate_hyperparameters(self, remote_dir: str, experiment_info: Dict) -> None:
    """
    Write the hyperparameters YAML string to a file and upload it.

    :param remote_dir: Remote directory path.
    :type remote_dir: str
    :param experiment_info: Information about the experiment results
        (not used by this method).
    :type experiment_info: Dict
    """
    hp_local_path = join(self.output_dir, self._hyperparameters_file)
    hp_remote_path = join(remote_dir, self._hyperparameters_file)

    with open(hp_local_path, "w") as hp_file:
        hp_file.write(self.hyperparameters_yaml)

    self._upload_file_to_team_files(
        hp_local_path,
        hp_remote_path,
        f"Uploading '{self._hyperparameters_file}' to Team Files",
    )
|
|
1354
|
+
|
|
1355
|
+
def _generate_app_state(self, remote_dir: str, experiment_info: Dict) -> None:
    """
    Build the app_state.json snapshot of the GUI state and upload it.

    :param remote_dir: Remote directory path.
    :type remote_dir: str
    :param experiment_info: Information about the experiment results.
    :type experiment_info: Dict
    """
    state = self.get_app_state(experiment_info)

    state_local_path = join(self.output_dir, self._app_state_file)
    state_remote_path = join(remote_dir, self._app_state_file)
    sly_json.dump_json_file(state, state_local_path)
    self._upload_file_to_team_files(
        state_local_path, state_remote_path, f"Uploading '{self._app_state_file}' to Team Files"
    )
|
|
1372
|
+
|
|
1373
|
+
def _get_train_val_splits_for_app_state(self) -> Dict:
    """
    Gets the train and val splits information for app_state.json.

    :return: Train and val splits information based on selected split method.
    :rtype: dict
    """
    selector = self.gui.train_val_splits_selector
    split_method = selector.get_split_method()
    splits = selector.train_val_splits

    result = {"method": split_method.lower()}
    if split_method == "Random":
        result["split"] = "train"
        result["percent"] = splits.get_train_split_percent()
    elif split_method == "Based on tags":
        result["train_tag"] = splits.get_train_tag()
        result["val_tag"] = splits.get_val_tag()
        result["untagged_action"] = splits.get_untagged_action()
    elif split_method == "Based on datasets":
        result["train_datasets"] = splits.get_train_dataset_ids()
        result["val_datasets"] = splits.get_val_dataset_ids()
    return result
|
|
1405
|
+
|
|
1406
|
+
def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
    """
    Gets the model configuration information for app_state.json.

    :param experiment_info: Information about the experiment results.
    :type experiment_info: Dict
    """
    experiment_info = experiment_info or {}

    if self.model_source == ModelSource.PRETRAINED:
        # Prefer the name reported by training; fall back to the selector meta.
        fallback_name = self.model_info.get("meta", {}).get("model_name")
        return {
            "source": ModelSource.PRETRAINED,
            "model_name": experiment_info.get("model_name") or fallback_name,
        }
    if self.model_source == ModelSource.CUSTOM:
        return {
            "source": ModelSource.CUSTOM,
            "task_id": self.task_id,
            "checkpoint": "checkpoint.pth",
        }
    # NOTE(review): implicitly returns None for any other model source —
    # confirm callers handle that case.
|
|
1429
|
+
|
|
1430
|
+
# ----------------------------------------- #
|
|
1431
|
+
|
|
1432
|
+
# Upload artifacts
|
|
1433
|
+
def _upload_artifacts(self) -> tuple:
    """
    Uploads the training artifacts to Supervisely Team Files.
    Path is generated based on the project ID, task ID, and framework name.

    Path: /experiments/{project_id}_{project_name}/{task_id}_{framework_name}/
    Example path: /experiments/43192_Apples/68271_rt-detr/

    :return: Remote artifacts directory and the FileInfo of the uploaded
        'open_app.lnk' link file.
    :rtype: tuple
    """
    # Fix: the annotation said `-> None` although the method returns
    # `(remote_dir, file_info)` and callers unpack that tuple.
    logger.info(f"Uploading directory: '{self.output_dir}' to Supervisely")
    task_id = self.task_id

    remote_artifacts_dir = f"/{self._experiments_dir_name}/{self.project_id}_{self.project_name}/{task_id}_{self.framework_name}/"

    # Clean debug directory if exists
    if task_id == "debug-session":
        if self._api.file.dir_exists(self._team_id, f"{remote_artifacts_dir}/", True):
            with self.progress_bar_main(
                message=f"[Debug] Cleaning train artifacts: '{remote_artifacts_dir}/'",
                total=1,
            ) as upload_artifacts_pbar:
                self.progress_bar_main.show()
                self._api.file.remove_dir(self._team_id, f"{remote_artifacts_dir}", True)
                upload_artifacts_pbar.update(1)
                self.progress_bar_main.hide()

    # Generate link file that opens the app session from Team Files
    if is_production():
        app_url = f"/apps/sessions/{task_id}"
    else:
        app_url = "This is a debug session. No link available."
    app_link_path = join(self.output_dir, "open_app.lnk")
    with open(app_link_path, "w") as text_file:
        print(app_url, file=text_file)

    local_files = sly_fs.list_files_recursively(self.output_dir)
    # Total size in bytes drives the byte-scaled progress bar below.
    total_size = sum(sly_fs.get_file_size(file_path) for file_path in local_files)
    with self.progress_bar_main(
        message="Uploading train artifacts to Team Files",
        total=total_size,
        unit="bytes",
        unit_scale=True,
    ) as upload_artifacts_pbar:
        self.progress_bar_main.show()
        remote_dir = self._api.file.upload_directory(
            self._team_id,
            self.output_dir,
            remote_artifacts_dir,
            progress_size_cb=upload_artifacts_pbar,
        )
        self.progress_bar_main.hide()

    file_info = self._api.file.get_info_by_path(self._team_id, join(remote_dir, "open_app.lnk"))
    return remote_dir, file_info
|
|
1486
|
+
|
|
1487
|
+
def _set_training_output(self, remote_dir: str, file_info: FileInfo) -> None:
    """
    Sets the training output in the GUI.

    :param remote_dir: Remote Team Files directory with the uploaded artifacts.
    :type remote_dir: str
    :param file_info: FileInfo of the uploaded 'open_app.lnk' link file,
        shown as the artifacts thumbnail.
    :type file_info: FileInfo
    """
    logger.info("All training artifacts uploaded successfully")
    # Training is over: freeze start/stop controls and the Tensorboard button.
    self.gui.training_process.start_button.loading = False
    self.gui.training_process.start_button.disable()
    self.gui.training_process.stop_button.disable()
    self.gui.training_logs.tensorboard_button.disable()

    # Publish the artifacts location and reveal the result widgets.
    set_directory(remote_dir)
    self.gui.training_process.artifacts_thumbnail.set(file_info)
    self.gui.training_process.artifacts_thumbnail.show()
    self.gui.training_process.success_message.show()
|
|
1501
|
+
|
|
1502
|
+
# Model Benchmark
|
|
1503
|
+
def _get_eval_results_dir_name(self) -> str:
    """
    Returns a free (non-existing) remote directory path for evaluation results.

    Path template:
    /model-benchmark/evaluation/{project_id}_{project_name}/{task_id}_{app_name}/

    :return: Unique evaluation results directory in Team Files.
    :rtype: str
    """
    task_info = self._api.task.get_info_by_id(self.task_id)
    task_dir = f"{self.task_id}_{task_info['meta']['app']['name']}"
    eval_res_dir = f"/model-benchmark/evaluation/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
    # Fix: `_team_id` is an attribute elsewhere in this class (see
    # _upload_artifacts); calling it here raised TypeError.
    eval_res_dir = self._api.storage.get_free_dir_name(self._team_id, eval_res_dir)
    return eval_res_dir
|
|
1512
|
+
|
|
1513
|
+
def _run_model_benchmark(
    self,
    local_artifacts_dir: str,
    remote_artifacts_dir: str,
    experiment_info: dict,
    splits_data: dict,
) -> tuple:
    """
    Runs the Model Benchmark evaluation process. Model benchmark runs only in production mode.

    :param local_artifacts_dir: Local directory path.
    :type local_artifacts_dir: str
    :param remote_artifacts_dir: Remote directory path.
    :type remote_artifacts_dir: str
    :param experiment_info: Information about the experiment results.
    :type experiment_info: dict
    :param splits_data: Information about the train and val splits.
    :type splits_data: dict
    :return: Evaluation report and report ID (both None when skipped or failed).
    :rtype: tuple
    """
    report, report_id = None, None
    if self._inference_class is None:
        logger.warn(
            "Inference class is not registered, model benchmark disabled. "
            "Use 'register_inference_class' method to register inference class."
        )
        return report, report_id

    # Can't get task type from session. requires before session init
    supported_task_types = [
        TaskType.OBJECT_DETECTION,
        TaskType.INSTANCE_SEGMENTATION,
    ]
    task_type = experiment_info["task_type"]
    if task_type not in supported_task_types:
        logger.warn(
            f"Task type: '{task_type}' is not supported for Model Benchmark. "
            # Fix: previously joined the characters of the `task_type` string
            # instead of the list of supported task types.
            f"Supported tasks: {', '.join(supported_task_types)}"
        )
        return report, report_id

    logger.info("Running Model Benchmark evaluation")
    # Fix: bind `bm` before the try block so the cleanup code in `except`
    # never hits a NameError when a failure happens before `bm` is assigned.
    bm = None
    try:
        remote_checkpoints_dir = join(remote_artifacts_dir, "checkpoints")
        best_checkpoint = experiment_info.get("best_checkpoint", None)
        best_filename = sly_fs.get_file_name_with_ext(best_checkpoint)
        remote_best_checkpoint = join(remote_checkpoints_dir, best_filename)

        config_path = experiment_info["model_files"].get("config")
        if config_path is not None:
            remote_config_path = join(
                remote_artifacts_dir, sly_fs.get_file_name_with_ext(config_path)
            )
        else:
            remote_config_path = None

        logger.info(f"Creating the report for the best model: {best_filename!r}")
        self.gui.training_process.model_benchmark_report_text.show()
        self.progress_bar_main(message="Starting Model Benchmark evaluation", total=1)
        self.progress_bar_main.show()

        # 0. Serve trained model
        m = self._inference_class(
            model_dir=self.model_dir,
            use_gui=False,
            custom_inference_settings=self._inference_settings,
        )

        logger.info(f"Using device: {self.device}")

        self._benchmark_params["device"] = self.device
        self._benchmark_params["model_info"] = {
            "artifacts_dir": remote_artifacts_dir,
            "model_name": experiment_info["model_name"],
            "framework_name": self.framework_name,
            "model_meta": self.model_meta.to_json(),
        }

        logger.info(f"Deploy parameters: {self._benchmark_params}")

        m._load_model_headless(**self._benchmark_params)
        m.serve()

        port = 8000
        session = SessionJSON(self._api, session_url=f"http://localhost:{port}")
        benchmark_dir = join(local_artifacts_dir, "benchmark")
        sly_fs.mkdir(benchmark_dir, True)

        # 1. Init benchmark
        benchmark_dataset_ids = splits_data["val"]["dataset_ids"]
        benchmark_images_ids = splits_data["val"]["images_ids"]
        train_dataset_ids = splits_data["train"]["dataset_ids"]
        train_images_ids = splits_data["train"]["images_ids"]

        if task_type == TaskType.OBJECT_DETECTION:
            eval_params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
            eval_params = yaml.safe_load(eval_params)
            bm = ObjectDetectionBenchmark(
                self._api,
                self.project_info.id,
                output_dir=benchmark_dir,
                gt_dataset_ids=benchmark_dataset_ids,
                gt_images_ids=benchmark_images_ids,
                progress=self.progress_bar_main,
                progress_secondary=self.progress_bar_secondary,
                classes_whitelist=self.classes,
                evaluation_params=eval_params,
            )
        elif task_type == TaskType.INSTANCE_SEGMENTATION:
            eval_params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
            eval_params = yaml.safe_load(eval_params)
            bm = InstanceSegmentationBenchmark(
                self._api,
                self.project_info.id,
                output_dir=benchmark_dir,
                gt_dataset_ids=benchmark_dataset_ids,
                gt_images_ids=benchmark_images_ids,
                progress=self.progress_bar_main,
                progress_secondary=self.progress_bar_secondary,
                classes_whitelist=self.classes,
                evaluation_params=eval_params,
            )
        elif task_type == TaskType.SEMANTIC_SEGMENTATION:
            # NOTE(review): unreachable while SEMANTIC_SEGMENTATION is absent
            # from `supported_task_types` above — kept for when it is enabled.
            eval_params = SemanticSegmentationEvaluator.load_yaml_evaluation_params()
            eval_params = yaml.safe_load(eval_params)
            bm = SemanticSegmentationBenchmark(
                self._api,
                self.project_info.id,
                output_dir=benchmark_dir,
                gt_dataset_ids=benchmark_dataset_ids,
                gt_images_ids=benchmark_images_ids,
                progress=self.progress_bar_main,
                progress_secondary=self.progress_bar_secondary,
                classes_whitelist=self.classes,
                evaluation_params=eval_params,
            )
        else:
            raise ValueError(f"Task type: '{task_type}' is not supported for Model Benchmark")

        if self.gui.train_val_splits_selector.get_split_method() == "Based on datasets":
            train_info = {
                "app_session_id": self.task_id,
                "train_dataset_ids": train_dataset_ids,
                "train_images_ids": None,
                "images_count": len(self._train_split),
            }
        else:
            train_info = {
                "app_session_id": self.task_id,
                "train_dataset_ids": None,
                "train_images_ids": train_images_ids,
                "images_count": len(self._train_split),
            }
        bm.train_info = train_info

        # 2. Run inference
        bm.run_inference(session)

        # 3. Pull results from the server
        gt_project_path, dt_project_path = bm.download_projects(save_images=False)

        # 4. Evaluate
        bm._evaluate(gt_project_path, dt_project_path)

        # 5. Upload evaluation results
        eval_res_dir = self._get_eval_results_dir_name()
        bm.upload_eval_results(eval_res_dir + "/evaluation/")

        # 6. Speed test
        if self.gui.hyperparameters_selector.get_speedtest_checkbox_value() is True:
            bm.run_speedtest(session, self.project_info.id)
            self.progress_bar_secondary.hide()  # @TODO: add progress bar
            bm.upload_speedtest_results(eval_res_dir + "/speedtest/")

        # 7. Prepare visualizations, report and upload
        bm.visualize()
        remote_dir = bm.upload_visualizations(eval_res_dir + "/visualizations/")
        report = bm.upload_report_link(remote_dir)
        report_id = report.id

        # 8. UI updates
        # Fix: `_team_id` is an attribute, not a method (see _upload_artifacts).
        benchmark_report_template = self._api.file.get_info_by_path(
            self._team_id, remote_dir + "template.vue"
        )

        self.gui.training_process.model_benchmark_report_text.hide()
        self.gui.training_process.model_benchmark_report_thumbnail.set(
            benchmark_report_template
        )
        self.gui.training_process.model_benchmark_report_thumbnail.show()
        self.progress_bar_main.hide()
        self.progress_bar_secondary.hide()
        logger.info("Model benchmark evaluation completed successfully")
        logger.info(
            f"Predictions project name: {bm.dt_project_info.name}. Workspace_id: {bm.dt_project_info.workspace_id}"
        )
        logger.info(
            f"Differences project name: {bm.diff_project_info.name}. Workspace_id: {bm.diff_project_info.workspace_id}"
        )
    except Exception as e:
        logger.error(f"Model benchmark failed. {repr(e)}", exc_info=True)
        self.gui.training_process.model_benchmark_report_text.hide()
        self.progress_bar_main.hide()
        self.progress_bar_secondary.hide()
        # Best-effort cleanup of intermediate projects created by the benchmark.
        try:
            if bm is not None and bm.dt_project_info:
                self._api.project.remove(bm.dt_project_info.id)
            if bm is not None and bm.diff_project_info:
                self._api.project.remove(bm.diff_project_info.id)
        except Exception as e2:
            # Fix: log cleanup failures instead of silently swallowing them.
            logger.debug(f"Failed to clean up benchmark projects: {repr(e2)}")
            return report, report_id
    return report, report_id
|
|
1727
|
+
|
|
1728
|
+
# ----------------------------------------- #
|
|
1729
|
+
|
|
1730
|
+
# Workflow
|
|
1731
|
+
def _workflow_input(self):
    """
    Adds the input data (project version and optional custom checkpoint)
    to the workflow. All failures are logged and never abort training.
    """
    try:
        # Back up the project so this training run can be reproduced later.
        project_version_id = self._api.project.version.create(
            self.project_info,
            self._app_name,
            f"This backup was created automatically by Supervisely before the {self._app_name} task with ID: {self._api.task_id}",
        )
    except Exception as e:
        logger.warning(f"Failed to create a project version: {repr(e)}")
        project_version_id = None

    try:
        if project_version_id is None:
            # Fall back to the project's latest existing version, if any.
            project_version_id = (
                self.project_info.version.get("id", None) if self.project_info.version else None
            )
        self._api.app.workflow.add_input_project(
            self.project_info.id, version_id=project_version_id
        )

        # Fix: `file_info` was referenced in the debug log below even when the
        # model source is not CUSTOM, raising a NameError that was swallowed by
        # the except clause and silently skipped the workflow input logging.
        file_info = None
        if self.model_source == ModelSource.CUSTOM:
            file_info = self._api.file.get_info_by_path(
                self._team_id,
                self.gui.model_selector.experiment_selector.get_selected_checkpoint_path(),
            )
            if file_info is not None:
                self._api.app.workflow.add_input_file(file_info, model_weight=True)
        logger.debug(
            f"Workflow Input: Project ID - {self.project_info.id}, Project Version ID - {project_version_id}, Input File - {True if file_info else False}"
        )
    except Exception as e:
        logger.debug(f"Failed to add input to the workflow: {repr(e)}")
|
|
1766
|
+
|
|
1767
|
+
def _workflow_output(
    self,
    team_files_dir: str,
    file_info: FileInfo,
    model_benchmark_report: Optional[FileInfo] = None,
):
    """
    Adds the output data to the workflow.

    :param team_files_dir: Remote Team Files directory with training artifacts
        (used only for logging here).
    :type team_files_dir: str
    :param file_info: FileInfo of the uploaded artifacts link file; skipped
        with a debug message when falsy.
    :type file_info: FileInfo
    :param model_benchmark_report: FileInfo of the benchmark report template,
        if a benchmark was run.
    :type model_benchmark_report: Optional[FileInfo]
    """
    try:
        # Resolve the app module id of the current task to build the session URL.
        module_id = (
            self._api.task.get_info_by_id(self._api.task_id)
            .get("meta", {})
            .get("app", {})
            .get("id")
        )
        logger.debug(f"Workflow Output: Model artifacts - '{team_files_dir}'")

        node_settings = WorkflowSettings(
            title=self._app_name,
            url=(
                f"/apps/{module_id}/sessions/{self._api.task_id}"
                if module_id
                else f"apps/sessions/{self._api.task_id}"
            ),
            url_title="Show Results",
        )

        if file_info:
            # Attach the artifacts folder as a model-weight output node.
            relation_settings = WorkflowSettings(
                title="Train Artifacts",
                icon="folder",
                icon_color="#FFA500",
                icon_bg_color="#FFE8BE",
                url=f"/files/{file_info.id}/true",
                url_title="Open Folder",
            )
            meta = WorkflowMeta(
                relation_settings=relation_settings, node_settings=node_settings
            )
            logger.debug(f"Workflow Output: meta \n {meta}")
            self._api.app.workflow.add_output_file(file_info, model_weight=True, meta=meta)
        else:
            logger.debug(
                f"File with checkpoints not found in Team Files. Cannot set workflow output."
            )

        if model_benchmark_report:
            # Attach the benchmark report as a second output node.
            mb_relation_settings = WorkflowSettings(
                title="Model Benchmark",
                icon="assignment",
                icon_color="#674EA7",
                icon_bg_color="#CCCCFF",
                url=f"/model-benchmark?id={model_benchmark_report.id}",
                url_title="Open Report",
            )

            meta = WorkflowMeta(
                relation_settings=mb_relation_settings, node_settings=node_settings
            )
            self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
        else:
            logger.debug(
                f"File with model benchmark report not found in Team Files. Cannot set workflow output."
            )
    except Exception as e:
        # Workflow bookkeeping is best-effort; never fail the training run.
        logger.debug(f"Failed to add output to the workflow: {repr(e)}")
|
|
1834
|
+
# ----------------------------------------- #
|
|
1835
|
+
|
|
1836
|
+
# Logger
|
|
1837
|
+
def _init_logger(self):
    """
    Initialize the training logger.

    When 'train_logger' in app options is 'tensorboard' (case-insensitive),
    point the TB logger at the log directory, register progress callbacks,
    and launch the Tensorboard server. Otherwise do nothing.
    """
    configured_logger = self._app_options.get("train_logger", "")
    if configured_logger.lower() != "tensorboard":
        return
    tb_logger.set_log_dir(self.log_dir)
    self._setup_logger_callbacks()
    self._init_tensorboard()
|
|
1846
|
+
|
|
1847
|
+
def _init_tensorboard(self):
    """
    Launch a local Tensorboard server for self.log_dir and wire it to the app.
    """
    # Register the app's proxy routes before the server starts.
    self._register_routes()
    args = [
        "tensorboard",
        "--logdir",
        self.log_dir,
        "--host=localhost",
        f"--port={self._tensorboard_port}",
        "--load_fast=true",
        "--reload_multifile=true",
    ]
    self._tensorboard_process = subprocess.Popen(args)
    # Ensure the child process is terminated when the app shuts down.
    self.app.call_before_shutdown(self.stop_tensorboard)
    print(f"Tensorboard server has been started")
    self.gui.training_logs.tensorboard_button.enable()
|
|
1862
|
+
|
|
1863
|
+
def start_tensorboard(self, log_dir: str, port: int = None):
    """
    Method to manually start Tensorboard in the user's training code.
    Tensorboard is started automatically if the 'train_logger' is set to
    'tensorboard' in app_options.yaml file.

    :param log_dir: Directory path to the log files.
    :type log_dir: str
    :param port: Port number for Tensorboard, defaults to None
    :type port: int, optional
    """
    self.log_dir = log_dir
    # Only override the configured port when one is explicitly supplied.
    if port is not None:
        self._tensorboard_port = port
    self._init_tensorboard()
|
|
1877
|
+
|
|
1878
|
+
def stop_tensorboard(self):
    """Terminate the Tensorboard server process if one is running."""
    proc = self._tensorboard_process
    if proc is None:
        print("Tensorboard server is not running")
        return
    proc.terminate()
    # Drop the handle so a second call is a no-op.
    self._tensorboard_process = None
    print(f"Tensorboard server has been stopped")
|
|
1886
|
+
|
|
1887
|
+
def _setup_logger_callbacks(self):
    """
    Set up callbacks for the training logger.

    Registers progress-bar callbacks on the Tensorboard logger so the GUI
    tracks epochs (main bar) and steps (secondary bar) during training.
    """
    # Progress-bar handles shared with the closures below via `nonlocal`.
    epoch_pbar = None
    step_pbar = None

    def start_training_callback(total_epochs: int):
        """
        Callback function that is called when the training process starts.
        """
        nonlocal epoch_pbar
        logger.info(f"Training started for {total_epochs} epochs")
        pbar_widget = self.progress_bar_main
        pbar_widget.show()
        epoch_pbar = pbar_widget(message=f"Epochs", total=total_epochs)

    def finish_training_callback():
        """
        Callback function that is called when the training process finishes.
        """
        self.progress_bar_main.hide()
        self.progress_bar_secondary.hide()

        # Flush/close the TB writer only when it was the configured logger.
        train_logger = self._app_options.get("train_logger", "")
        if train_logger == "tensorboard":
            tb_logger.close()

    def start_epoch_callback(total_steps: int):
        """
        Callback function that is called when a new epoch starts.
        """
        nonlocal step_pbar
        logger.info(f"Epoch started. Total steps: {total_steps}")
        pbar_widget = self.progress_bar_secondary
        pbar_widget.show()
        step_pbar = pbar_widget(message=f"Steps", total=total_steps)

    def finish_epoch_callback():
        """
        Callback function that is called when an epoch finishes.
        """
        epoch_pbar.update(1)

    def step_callback():
        """
        Callback function that is called when a step iteration is completed.
        """
        step_pbar.update(1)

    tb_logger.add_on_train_started_callback(start_training_callback)
    tb_logger.add_on_train_finish_callback(finish_training_callback)

    tb_logger.add_on_epoch_started_callback(start_epoch_callback)
    tb_logger.add_on_epoch_finish_callback(finish_epoch_callback)

    tb_logger.add_on_step_callback(step_callback)
|
|
1944
|
+
|
|
1945
|
+
# ----------------------------------------- #
|
|
1946
|
+
def _wrapped_start_training(self):
    """
    Wrapper function to wrap the training process.

    Runs the stages in order — widget init, data preparation, training,
    finalization/upload — showing the error in the GUI and aborting the
    run when any stage fails.
    """
    experiment_info = None

    try:
        self._set_train_widgets_state_on_start()
        if self._train_func is None:
            raise ValueError("Train function is not defined")
        self._prepare_working_dir()
        self._init_logger()
    except Exception as e:
        message = "Error occurred during training initialization. Please check the logs for more details."
        # _show_error already restores the widgets state; the previous extra
        # explicit restore call here decremented the GUI stepper twice.
        self._show_error(message, e)
        # Fix: abort on initialization failure — previously execution fell
        # through and data preparation ran in a broken state.
        return

    try:
        self.gui.training_process.validator_text.set("Preparing data for training...", "info")
        self._prepare()
    except Exception as e:
        message = (
            "Error occurred during data preparation. Please check the logs for more details."
        )
        self._show_error(message, e)
        return

    try:
        self.gui.training_process.validator_text.set("Training is in progress...", "info")
        experiment_info = self._train_func()
    except Exception as e:
        message = "Error occurred during training. Please check the logs for more details."
        self._show_error(message, e)
        return

    try:
        self.gui.training_process.validator_text.set(
            "Finalizing and uploading training artifacts...", "info"
        )
        self._finalize(experiment_info)
        self.gui.training_process.start_button.loading = False
        self.gui.training_process.validator_text.set(
            self.gui.training_process.success_message_text, "success"
        )
    except Exception as e:
        message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
        self._show_error(message, e)
        return
|
|
1997
|
+
|
|
1998
|
+
def _show_error(self, message: str, e=None):
    """
    Log an error and surface it in the GUI, then restore the widgets state.

    :param message: User-facing error message shown in the validator text.
    :type message: str
    :param e: Original exception; when given, it is logged with a traceback.
    """
    if e is not None:
        logger.error(f"{message}: {repr(e)}", exc_info=True)
    else:
        logger.error(message)
    self.gui.training_process.validator_text.set(message, "error")
    self.gui.training_process.validator_text.show()
    self.gui.training_process.start_button.loading = False
    # Re-lock the logs card and re-enable the input widgets.
    self._restore_train_widgets_state_on_error()
|
|
2007
|
+
|
|
2008
|
+
def _set_train_widgets_state_on_start(self):
    """
    Prepare GUI widgets for a training run.

    Validates the experiment name (raises ValueError when invalid), freezes
    user inputs, unlocks the logs card and advances the stepper.
    """
    self._validate_experiment_name()
    # Freeze the inputs so they cannot change mid-training.
    self.gui.training_process.experiment_name_input.disable()
    if self._app_options.get("device_selector", False):
        self.gui.training_process.select_device._select.disable()
        self.gui.training_process.select_device.disable()

    self.gui.training_logs.card.unlock()
    self.gui.stepper.set_active_step(7)
    self.gui.training_process.validator_text.set("Training is started...", "info")
    self.gui.training_process.validator_text.show()
    self.gui.training_process.start_button.loading = True
|
|
2020
|
+
|
|
2021
|
+
def _restore_train_widgets_state_on_error(self):
    """
    Revert the GUI to its pre-training state after a failure: lock the logs
    card, step the stepper back, and re-enable the frozen input widgets.
    """
    self.gui.training_logs.card.lock()
    # Undo the stepper advance done in _set_train_widgets_state_on_start.
    self.gui.stepper.set_active_step(self.gui.stepper.get_active_step() - 1)
    self.gui.training_process.experiment_name_input.enable()
    if self._app_options.get("device_selector", False):
        self.gui.training_process.select_device._select.enable()
        self.gui.training_process.select_device.enable()
|
|
2028
|
+
|
|
2029
|
+
def _validate_experiment_name(self) -> bool:
    """
    Validate the experiment name entered in the GUI.

    :raises ValueError: when the name is empty or contains a forbidden character.
    :return: True when the name is valid.
    :rtype: bool
    """
    name = self.gui.training_process.get_experiment_name()
    if not name:
        logger.error("Experiment name is empty")
        raise ValueError("Experiment name is empty")
    # Raw string: the backslashes are literal, so '\' itself is forbidden too.
    invalid_chars = r"\/\*&^%@$#"
    if set(name) & set(invalid_chars):
        logger.error(f"Experiment name contains invalid characters: {invalid_chars}")
        raise ValueError(f"Experiment name contains invalid characters: {invalid_chars}")
    return True
|