supervisely 6.73.243__py3-none-any.whl → 6.73.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (56) hide show
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2038 @@
1
+ """
2
+ TrainApp module.
3
+
4
+ This module contains the `TrainApp` class and related functionality to facilitate
5
+ training workflows in a Supervisely application.
6
+ """
7
+
8
+ import shutil
9
+ import subprocess
10
+ from datetime import datetime
11
+ from os import listdir
12
+ from os.path import basename, isdir, isfile, join
13
+ from typing import Any, Dict, List, Optional, Union
14
+ from urllib.request import urlopen
15
+
16
+ import httpx
17
+ import yaml
18
+ from fastapi import Request, Response
19
+ from fastapi.responses import StreamingResponse
20
+ from starlette.background import BackgroundTask
21
+
22
+ import supervisely.io.env as sly_env
23
+ import supervisely.io.fs as sly_fs
24
+ import supervisely.io.json as sly_json
25
+ from supervisely import (
26
+ Api,
27
+ Application,
28
+ DatasetInfo,
29
+ OpenMode,
30
+ Project,
31
+ ProjectInfo,
32
+ ProjectMeta,
33
+ WorkflowMeta,
34
+ WorkflowSettings,
35
+ download_project,
36
+ is_development,
37
+ is_production,
38
+ logger,
39
+ )
40
+ from supervisely._utils import get_filename_from_headers
41
+ from supervisely.api.file_api import FileInfo
42
+ from supervisely.app import get_synced_data_dir
43
+ from supervisely.app.widgets import Progress
44
+ from supervisely.nn.benchmark import (
45
+ InstanceSegmentationBenchmark,
46
+ InstanceSegmentationEvaluator,
47
+ ObjectDetectionBenchmark,
48
+ ObjectDetectionEvaluator,
49
+ SemanticSegmentationBenchmark,
50
+ SemanticSegmentationEvaluator,
51
+ )
52
+ from supervisely.nn.inference import RuntimeType, SessionJSON
53
+ from supervisely.nn.task_type import TaskType
54
+ from supervisely.nn.training.gui.gui import TrainGUI
55
+ from supervisely.nn.training.loggers.tensorboard_logger import tb_logger
56
+ from supervisely.nn.utils import ModelSource
57
+ from supervisely.output import set_directory
58
+ from supervisely.project.download import (
59
+ copy_from_cache,
60
+ download_to_cache,
61
+ get_cache_size,
62
+ is_cached,
63
+ )
64
+
65
+
66
+ class TrainApp:
67
+ """
68
+ A class representing the training application.
69
+
70
+ This class initializes and manages the training workflow, including
71
+ handling inputs, hyperparameters, project management, and output artifacts.
72
+
73
+ :param framework_name: Name of the ML framework used.
74
+ :type framework_name: str
75
+ :param models: List of model configurations.
76
+ :type models: List[Dict[str, Any]]
77
+ :param hyperparameters: Path or string content of hyperparameters in YAML format.
78
+ :type hyperparameters: str
79
+ :param app_options: Options for the application layout and behavior.
80
+ :type app_options: Optional[Dict[str, Any]]
81
+ :param work_dir: Path to the working directory for storing intermediate files.
82
+ :type work_dir: Optional[str]
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ framework_name: str,
88
+ models: Union[str, List[Dict[str, Any]]],
89
+ hyperparameters: str,
90
+ app_options: Union[str, Dict[str, Any]] = None,
91
+ work_dir: str = None,
92
+ ):
93
+
94
+ # Init
95
+ self._api = Api.from_env()
96
+
97
+ # Constants
98
+ self._experiment_json_file = "experiment_info.json"
99
+ self._app_state_file = "app_state.json"
100
+ self._train_val_split_file = "train_val_split.json"
101
+ self._hyperparameters_file = "hyperparameters.yaml"
102
+ self._model_meta_file = "model_meta.json"
103
+
104
+ self._sly_project_dir_name = "sly_project"
105
+ self._model_dir_name = "model"
106
+ self._log_dir_name = "logs"
107
+ self._output_dir_name = "output"
108
+ self._output_checkpoints_dir_name = "checkpoints"
109
+ self._remote_checkpoints_dir_name = "checkpoints"
110
+ self._experiments_dir_name = "experiments"
111
+ self._default_work_dir_name = "work_dir"
112
+ self._tensorboard_port = 6006
113
+
114
+ if is_production():
115
+ self.task_id = sly_env.task_id()
116
+ else:
117
+ self.task_id = "debug-session"
118
+ logger.info("TrainApp is running in debug mode")
119
+
120
+ self.framework_name = framework_name
121
+ self._team_id = sly_env.team_id()
122
+ self._workspace_id = sly_env.workspace_id()
123
+ self._app_name = sly_env.app_name(raise_not_found=False)
124
+
125
+ # TODO: read files
126
+ self._models = self._load_models(models)
127
+ self._hyperparameters = self._load_hyperparameters(hyperparameters)
128
+ self._app_options = self._load_app_options(app_options)
129
+ self._inference_class = None
130
+ # ----------------------------------------- #
131
+
132
+ # Directories
133
+ if work_dir is not None:
134
+ self.work_dir = work_dir
135
+ else:
136
+ self.work_dir = join(get_synced_data_dir(), self._default_work_dir_name)
137
+ self.output_dir = join(self.work_dir, self._output_dir_name)
138
+ self._output_checkpoints_dir = join(self.output_dir, self._output_checkpoints_dir_name)
139
+ self.project_dir = join(self.work_dir, self._sly_project_dir_name)
140
+ self.train_dataset_dir = join(self.project_dir, "train")
141
+ self.val_dataset_dir = join(self.project_dir, "val")
142
+ self.sly_project = None
143
+ self.train_split, self.val_split = None, None
144
+ # -------------------------- #
145
+
146
+ # Input
147
+ # ----------------------------------------- #
148
+
149
+ # Classes
150
+ # ----------------------------------------- #
151
+
152
+ # Model
153
+ self.model_files = {}
154
+ self.model_dir = join(self.work_dir, self._model_dir_name)
155
+ self.log_dir = join(self.work_dir, self._log_dir_name)
156
+ # ----------------------------------------- #
157
+
158
+ # Hyperparameters
159
+ # ----------------------------------------- #
160
+
161
+ # Layout
162
+ self.gui: TrainGUI = TrainGUI(
163
+ self.framework_name, self._models, self._hyperparameters, self._app_options
164
+ )
165
+ self.app = Application(layout=self.gui.layout)
166
+ self._server = self.app.get_server()
167
+ self._train_func = None
168
+
169
+ # Benchmark parameters
170
+ if self.is_model_benchmark_enabled:
171
+ self._benchmark_params = {
172
+ "model_files": {},
173
+ "model_source": ModelSource.CUSTOM,
174
+ "model_info": {},
175
+ "device": None,
176
+ "runtime": RuntimeType.PYTORCH,
177
+ }
178
+ # -------------------------- #
179
+
180
+ # Train endpoints
181
+ @self._server.post("/train_from_api")
182
+ def _train_from_api(response: Response, request: Request):
183
+ try:
184
+ state = request.state.state
185
+ app_state = state["app_state"]
186
+ self.gui.load_from_app_state(app_state)
187
+
188
+ self._wrapped_start_training()
189
+
190
+ return {"result": "model was successfully trained"}
191
+ except Exception as e:
192
+ self.gui.training_process.start_button.loading = False
193
+ raise e
194
+
195
+ def _register_routes(self):
196
+ """
197
+ Registers API routes for TensorBoard and training endpoints.
198
+
199
+ These routes enable communication with the application for training
200
+ and visualizing logs in TensorBoard.
201
+ """
202
+ client = httpx.AsyncClient(base_url=f"http://127.0.0.1:{self._tensorboard_port}/")
203
+
204
+ @self._server.post("/tensorboard/{path:path}")
205
+ @self._server.get("/tensorboard/{path:path}")
206
+ async def _proxy_tensorboard(path: str, request: Request):
207
+ url = httpx.URL(path=path, query=request.url.query.encode("utf-8"))
208
+ headers = [(k, v) for k, v in request.headers.raw if k != b"host"]
209
+ req = client.build_request(
210
+ request.method, url, headers=headers, content=request.stream()
211
+ )
212
+ r = await client.send(req, stream=True)
213
+ return StreamingResponse(
214
+ r.aiter_raw(),
215
+ status_code=r.status_code,
216
+ headers=r.headers,
217
+ background=BackgroundTask(r.aclose),
218
+ )
219
+
220
+ def _prepare_working_dir(self):
221
+ """
222
+ Prepares the working directory by creating required subdirectories.
223
+ """
224
+ sly_fs.mkdir(self.work_dir, True)
225
+ sly_fs.mkdir(self.output_dir, True)
226
+ sly_fs.mkdir(self._output_checkpoints_dir, True)
227
+ sly_fs.mkdir(self.project_dir, True)
228
+ sly_fs.mkdir(self.model_dir, True)
229
+ sly_fs.mkdir(self.log_dir, True)
230
+
231
+ # Properties
232
+ # General
233
+ # ----------------------------------------- #
234
+
235
+ # Input Data
236
+ @property
237
+ def project_id(self) -> int:
238
+ """
239
+ Returns the ID of the project.
240
+
241
+ :return: Project ID.
242
+ :rtype: int
243
+ """
244
+ return self.gui.project_id
245
+
246
+ @property
247
+ def project_name(self) -> str:
248
+ """
249
+ Returns the name of the project.
250
+
251
+ :return: Project name.
252
+ :rtype: str
253
+ """
254
+ return self.gui.project_info.name
255
+
256
+ @property
257
+ def project_info(self) -> ProjectInfo:
258
+ """
259
+ Returns ProjectInfo object, which contains information about the project.
260
+
261
+ :return: Project name.
262
+ :rtype: str
263
+ """
264
+ return self.gui.project_info
265
+
266
+ # ----------------------------------------- #
267
+
268
+ # Model
269
+ @property
270
+ def model_source(self) -> str:
271
+ """
272
+ Return whether the model is pretrained or custom.
273
+
274
+ :return: Model source.
275
+ :rtype: str
276
+ """
277
+ return self.gui.model_selector.get_model_source()
278
+
279
+ @property
280
+ def model_name(self) -> str:
281
+ """
282
+ Returns the name of the model.
283
+
284
+ :return: Model name.
285
+ :rtype: str
286
+ """
287
+ return self.gui.model_selector.get_model_name()
288
+
289
+ @property
290
+ def model_info(self) -> dict:
291
+ """
292
+ Returns a selected row in dict format from the models table.
293
+
294
+ :return: Model name.
295
+ :rtype: str
296
+ """
297
+ return self.gui.model_selector.get_model_info()
298
+
299
+ @property
300
+ def model_meta(self) -> ProjectMeta:
301
+ """
302
+ Returns the model metadata.
303
+
304
+ :return: Model metadata.
305
+ :rtype: dict
306
+ """
307
+ project_meta_json = self.sly_project.meta.to_json()
308
+ model_meta = {
309
+ "classes": [
310
+ item for item in project_meta_json["classes"] if item["title"] in self.classes
311
+ ]
312
+ }
313
+ return ProjectMeta.from_json(model_meta)
314
+
315
+ @property
316
+ def device(self) -> str:
317
+ """
318
+ Returns the selected device for training.
319
+
320
+ :return: Device name.
321
+ :rtype: str
322
+ """
323
+ return self.gui.training_process.get_device()
324
+
325
+ # Classes
326
+ @property
327
+ def classes(self) -> List[str]:
328
+ """
329
+ Returns the selected classes for training.
330
+
331
+ :return: List of selected classes.
332
+ :rtype: List[str]
333
+ """
334
+ return self.gui.classes_selector.get_selected_classes()
335
+
336
+ @property
337
+ def num_classes(self) -> int:
338
+ """
339
+ Returns the number of selected classes for training.
340
+
341
+ :return: Number of selected classes.
342
+ :rtype: int
343
+ """
344
+ return len(self.gui.classes_selector.get_selected_classes())
345
+
346
+ # Hyperparameters
347
+ @property
348
+ def hyperparameters(self) -> Dict[str, Any]:
349
+ """
350
+ Returns the selected hyperparameters for training in dict format.
351
+
352
+ :return: Hyperparameters in dict format.
353
+ :rtype: Dict[str, Any]
354
+ """
355
+ return yaml.safe_load(self.hyperparameters_yaml)
356
+
357
+ @property
358
+ def hyperparameters_yaml(self) -> str:
359
+ """
360
+ Returns the selected hyperparameters for training in raw format as a string.
361
+
362
+ :return: Hyperparameters in raw format.
363
+ :rtype: str
364
+ """
365
+ return self.gui.hyperparameters_selector.get_hyperparameters()
366
+
367
+ # Train Process
368
+ @property
369
+ def progress_bar_main(self) -> Progress:
370
+ """
371
+ Returns the main progress bar widget.
372
+
373
+ :return: Main progress bar widget.
374
+ :rtype: Progress
375
+ """
376
+ return self.gui.training_logs.progress_bar_main
377
+
378
+ @property
379
+ def progress_bar_secondary(self) -> Progress:
380
+ """
381
+ Returns the secondary progress bar widget.
382
+
383
+ :return: Secondary progress bar widget.
384
+ :rtype: Progress
385
+ """
386
+ return self.gui.training_logs.progress_bar_secondary
387
+
388
+ @property
389
+ def is_model_benchmark_enabled(self) -> bool:
390
+ """
391
+ Checks if model benchmarking is enabled based on application options and GUI settings.
392
+
393
+ :return: True if model benchmarking is enabled, False otherwise.
394
+ :rtype: bool
395
+ """
396
+ return (
397
+ self._app_options.get("model_benchmark", True)
398
+ and self.gui.hyperparameters_selector.get_model_benchmark_checkbox_value()
399
+ )
400
+
401
+ # Output
402
+ # ----------------------------------------- #
403
+
404
+ # region TRAIN START
405
+ @property
406
+ def start(self):
407
+ """
408
+ Decorator for the training function defined by user.
409
+ It wraps user-defined training function and prepares and finalizes the training process.
410
+ """
411
+
412
+ def decorator(func):
413
+ self._train_func = func
414
+ self.gui.training_process.start_button.click(self._wrapped_start_training)
415
+ return func
416
+
417
+ return decorator
418
+
419
+ def _prepare(self) -> None:
420
+ """
421
+ Prepares the environment for training by setting up directories,
422
+ downloading project and model data.
423
+ """
424
+ logger.info("Preparing for training")
425
+ self.gui.disable_select_buttons()
426
+
427
+ # Step 1. Workflow Input
428
+ if is_production():
429
+ self._workflow_input()
430
+ # Step 2. Download Project
431
+ self._download_project()
432
+ # Step 3. Split Project
433
+ self._split_project()
434
+ # Step 4. Convert Supervisely to X format
435
+ # Step 5. Download Model files
436
+ self._download_model()
437
+
438
+ def _finalize(self, experiment_info: dict) -> None:
439
+ """
440
+ Finalizes the training process by validating outputs, uploading artifacts,
441
+ and updating the UI.
442
+
443
+ :param experiment_info: Information about the experiment results that should be returned in user's training function.
444
+ :type experiment_info: dict
445
+ """
446
+ logger.info("Finalizing training")
447
+
448
+ # Step 1. Validate experiment_info
449
+ success, reason = self._validate_experiment_info(experiment_info)
450
+ if not success:
451
+ raise ValueError(f"{reason}. Failed to upload artifacts")
452
+
453
+ # Step 2. Preprocess artifacts
454
+ self._preprocess_artifacts(experiment_info)
455
+
456
+ # Step3. Postprocess splits
457
+ splits_data = self._postprocess_splits()
458
+
459
+ # Step 3. Upload artifacts
460
+ remote_dir, file_info = self._upload_artifacts()
461
+
462
+ # Step 4. Run Model Benchmark
463
+ mb_eval_report, mb_eval_report_id = None, None
464
+
465
+ if self.is_model_benchmark_enabled:
466
+ try:
467
+ mb_eval_report, mb_eval_report_id = self._run_model_benchmark(
468
+ self.output_dir, remote_dir, experiment_info, splits_data
469
+ )
470
+ except Exception as e:
471
+ logger.error(f"Model benchmark failed: {e}")
472
+
473
+ # Step 4. Generate and upload additional files
474
+ self._generate_experiment_info(remote_dir, experiment_info, mb_eval_report_id)
475
+ self._generate_app_state(remote_dir, experiment_info)
476
+ self._generate_hyperparameters(remote_dir, experiment_info)
477
+ self._generate_train_val_splits(remote_dir, splits_data)
478
+ self._generate_model_meta(remote_dir, experiment_info)
479
+
480
+ # Step 5. Set output widgets
481
+ self._set_training_output(remote_dir, file_info)
482
+
483
+ # Step 6. Workflow output
484
+ if is_production():
485
+ self._workflow_output(remote_dir, file_info, mb_eval_report)
486
+
487
+ # region TRAIN END
488
+
489
+ def register_inference_class(self, inference_class: Any, inference_settings: dict = {}) -> None:
490
+ """
491
+ Registers an inference class for the training application to do model benchmarking.
492
+
493
+ :param inference_class: Inference class to be registered inherited from `supervisely.nn.inference.Inference`.
494
+ :type inference_class: Any
495
+ :param inference_settings: Settings for the inference class.
496
+ :type inference_settings: dict
497
+ """
498
+ self._inference_class = inference_class
499
+ self._inference_settings = inference_settings
500
+
501
+ def get_app_state(self, experiment_info: dict = None) -> dict:
502
+ """
503
+ Returns the current state of the application.
504
+
505
+ :return: Application state.
506
+ :rtype: dict
507
+ """
508
+ input_data = {"project_id": self.project_id}
509
+ train_val_splits = self._get_train_val_splits_for_app_state()
510
+ model = self._get_model_config_for_app_state(experiment_info)
511
+
512
+ options = {
513
+ "model_benchmark": {
514
+ "enable": self.gui.hyperparameters_selector.get_model_benchmark_checkbox_value(),
515
+ "speed_test": self.gui.hyperparameters_selector.get_speedtest_checkbox_value(),
516
+ },
517
+ "cache_project": self.gui.input_selector.get_cache_value(),
518
+ }
519
+
520
+ app_state = {
521
+ "input": input_data,
522
+ "train_val_split": train_val_splits,
523
+ "classes": self.classes,
524
+ "model": model,
525
+ "hyperparameters": self.hyperparameters_yaml,
526
+ "options": options,
527
+ }
528
+ return app_state
529
+
530
+ def load_app_state(self, app_state: dict) -> None:
531
+ """
532
+ Load the GUI state from app state dictionary.
533
+
534
+ :param app_state: The state dictionary.
535
+ :type app_state: dict
536
+
537
+ app_state example:
538
+
539
+ app_state = {
540
+ "input": {"project_id": 55555},
541
+ "train_val_splits": {
542
+ "method": "random",
543
+ "split": "train",
544
+ "percent": 90
545
+ },
546
+ "classes": ["apple"],
547
+ "model": {
548
+ "source": "Pretrained models",
549
+ "model_name": "rtdetr_r50vd_coco_objects365"
550
+ },
551
+ "hyperparameters": hyperparameters, # yaml string
552
+ "options": {
553
+ "model_benchmark": {
554
+ "enable": True,
555
+ "speed_test": True
556
+ },
557
+ "cache_project": True
558
+ }
559
+ }
560
+ """
561
+ self.gui.load_from_app_state(app_state)
562
+
563
+ # Loaders
564
+ def _load_models(self, models: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
565
+ """
566
+ Loads models from the provided file or list of model configurations.
567
+ """
568
+ if isinstance(models, str):
569
+ if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
570
+ models = sly_json.load_json_file(models)
571
+ else:
572
+ raise ValueError(
573
+ "Invalid models file. Please provide a valid '.json' file or a list of model configurations."
574
+ )
575
+
576
+ if not isinstance(models, list):
577
+ raise ValueError("models parameters must be a list of dicts")
578
+ for item in models:
579
+ if not isinstance(item, dict):
580
+ raise ValueError(f"Each item in models must be a dict.")
581
+ model_meta = item.get("meta")
582
+ if model_meta is None:
583
+ raise ValueError(
584
+ "Model metadata not found. Please update provided models parameter to include key 'meta'."
585
+ )
586
+ model_files = model_meta.get("model_files")
587
+ if model_files is None:
588
+ raise ValueError(
589
+ "Model files not found in model metadata. "
590
+ "Please update provided models oarameter to include key 'model_files' in 'meta' key."
591
+ )
592
+ return models
593
+
594
+ def _load_hyperparameters(self, hyperparameters: str) -> dict:
595
+ """
596
+ Loads hyperparameters from file path.
597
+
598
+ :param hyperparameters: Path to hyperparameters file.
599
+ :type hyperparameters: str
600
+ :return: Hyperparameters in dict format.
601
+ :rtype: dict
602
+ """
603
+ if not isinstance(hyperparameters, str):
604
+ raise ValueError(
605
+ f"Expected a string with config or path for hyperparameters, but got {type(hyperparameters).__name__}"
606
+ )
607
+ if hyperparameters.endswith((".yml", ".yaml")):
608
+ try:
609
+ with open(hyperparameters, "r") as file:
610
+ return file.read()
611
+ except Exception as e:
612
+ raise ValueError(f"Failed to load YAML file: {hyperparameters}. Error: {e}")
613
+ return hyperparameters
614
+
615
+ def _load_app_options(self, app_options: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
616
+ """
617
+ Loads the app_options parameter to ensure it is in the correct format.
618
+ """
619
+ if app_options is None:
620
+ return {}
621
+
622
+ if isinstance(app_options, str):
623
+ if sly_fs.file_exists(app_options) and sly_fs.get_file_ext(app_options) in [
624
+ ".yaml",
625
+ ".yml",
626
+ ]:
627
+ app_options = self._load_yaml(app_options)
628
+ else:
629
+ raise ValueError(
630
+ "Invalid app_options file provided. Please provide a valid '.yaml' or '.yml' file or a dictionary with app_options."
631
+ )
632
+ if not isinstance(app_options, dict):
633
+ raise ValueError("app_options must be a dict")
634
+ return app_options
635
+
636
+ def _load_yaml(self, path: str) -> dict:
637
+ """
638
+ Load a YAML file from the specified path.
639
+
640
+ :param path: Path to the YAML file.
641
+ :type path: str
642
+ :return: YAML file contents.
643
+ :rtype: dict
644
+ """
645
+ with open(path, "r") as file:
646
+ return yaml.safe_load(file)
647
+
648
+ # ----------------------------------------- #
649
+
650
+ # Preprocess
651
+ # Download Project
652
+ def _download_project(self) -> None:
653
+ """
654
+ Downloads the project data from Supervisely.
655
+ If the cache is enabled, it will attempt to retrieve the project from the cache.
656
+ """
657
+ dataset_infos = [dataset for _, dataset in self._api.dataset.tree(self.project_id)]
658
+
659
+ if self.gui.train_val_splits_selector.get_split_method() == "Based on datasets":
660
+ selected_ds_ids = (
661
+ self.gui.train_val_splits_selector.get_train_dataset_ids()
662
+ + self.gui.train_val_splits_selector.get_val_dataset_ids()
663
+ )
664
+ dataset_infos = [ds_info for ds_info in dataset_infos if ds_info.id in selected_ds_ids]
665
+
666
+ total_images = sum(ds_info.images_count for ds_info in dataset_infos)
667
+ if not self.gui.input_selector.get_cache_value() or is_development():
668
+ self._download_no_cache(dataset_infos, total_images)
669
+ self.sly_project = Project(self.project_dir, OpenMode.READ)
670
+ return
671
+
672
+ try:
673
+ self._download_with_cache(dataset_infos, total_images)
674
+ except Exception:
675
+ logger.warning(
676
+ "Failed to retrieve project from cache. Downloading it",
677
+ exc_info=True,
678
+ )
679
+ if sly_fs.dir_exists(self.project_dir):
680
+ sly_fs.clean_dir(self.project_dir)
681
+ self._download_no_cache(dataset_infos, total_images)
682
+ finally:
683
+ self.sly_project = Project(self.project_dir, OpenMode.READ)
684
+ logger.info(f"Project downloaded successfully to: '{self.project_dir}'")
685
+
686
+ def _download_no_cache(self, dataset_infos: List[DatasetInfo], total_images: int) -> None:
687
+ """
688
+ Downloads the project data from Supervisely without using the cache.
689
+
690
+ :param dataset_infos: List of dataset information objects.
691
+ :type dataset_infos: List[DatasetInfo]
692
+ :param total_images: Total number of images to download.
693
+ :type total_images: int
694
+ """
695
+ with self.progress_bar_main(message="Downloading input data", total=total_images) as pbar:
696
+ self.progress_bar_main.show()
697
+ download_project(
698
+ api=self._api,
699
+ project_id=self.project_id,
700
+ dest_dir=self.project_dir,
701
+ dataset_ids=[ds_info.id for ds_info in dataset_infos],
702
+ log_progress=True,
703
+ progress_cb=pbar.update,
704
+ )
705
+ self.progress_bar_main.hide()
706
+
707
+ def _download_with_cache(
708
+ self,
709
+ dataset_infos: List[DatasetInfo],
710
+ total_images: int,
711
+ ) -> None:
712
+ """
713
+ Downloads the project data from Supervisely using the cache.
714
+
715
+ :param dataset_infos: List of dataset information objects.
716
+ :type dataset_infos: List[DatasetInfo]
717
+ :param total_images: Total number of images to download.
718
+ :type total_images: int
719
+ """
720
+ to_download = [
721
+ info for info in dataset_infos if not is_cached(self.project_info.id, info.name)
722
+ ]
723
+ cached = [info for info in dataset_infos if is_cached(self.project_info.id, info.name)]
724
+
725
+ logger.info(self._get_cache_log_message(cached, to_download))
726
+ with self.progress_bar_main(message="Downloading input data", total=total_images) as pbar:
727
+ self.progress_bar_main.show()
728
+ download_to_cache(
729
+ api=self._api,
730
+ project_id=self.project_info.id,
731
+ dataset_infos=dataset_infos,
732
+ log_progress=True,
733
+ progress_cb=pbar.update,
734
+ )
735
+
736
+ total_cache_size = sum(
737
+ get_cache_size(self.project_info.id, ds.name) for ds in dataset_infos
738
+ )
739
+ with self.progress_bar_main(
740
+ message="Retrieving data from cache",
741
+ total=total_cache_size,
742
+ unit="B",
743
+ unit_scale=True,
744
+ unit_divisor=1024,
745
+ ) as pbar:
746
+ copy_from_cache(
747
+ project_id=self.project_info.id,
748
+ dest_dir=self.project_dir,
749
+ dataset_names=[ds_info.name for ds_info in dataset_infos],
750
+ progress_cb=pbar.update,
751
+ )
752
+ self.progress_bar_main.hide()
753
+
754
+ def _get_cache_log_message(self, cached: bool, to_download: List[DatasetInfo]) -> str:
755
+ """
756
+ Utility method to generate a log message for cache status.
757
+ """
758
+ if not cached:
759
+ log_msg = "No cached datasets found"
760
+ else:
761
+ log_msg = "Using cached datasets: " + ", ".join(
762
+ f"{ds_info.name} ({ds_info.id})" for ds_info in cached
763
+ )
764
+
765
+ if not to_download:
766
+ log_msg += ". All datasets are cached. No datasets to download"
767
+ else:
768
+ log_msg += ". Downloading datasets: " + ", ".join(
769
+ f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
770
+ )
771
+
772
+ return log_msg
773
+
774
+ # Split Project
775
+ def _split_project(self) -> None:
776
+ """
777
+ Split the project into training and validation sets.
778
+ All images and annotations will be renamed and moved to the appropriate directories.
779
+ Assigns self.sly_project to the new project, which contains only 2 datasets: train and val.
780
+ """
781
+ # Load splits
782
+ self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
783
+ self._train_split, self._val_split = (
784
+ self.gui.train_val_splits_selector.train_val_splits.get_splits()
785
+ )
786
+
787
+ # Prepare paths
788
+ project_split_path = join(self.work_dir, "splits")
789
+ paths = {
790
+ "train": {
791
+ "split_path": join(project_split_path, "train"),
792
+ "img_dir": join(project_split_path, "train", "img"),
793
+ "ann_dir": join(project_split_path, "train", "ann"),
794
+ },
795
+ "val": {
796
+ "split_path": join(project_split_path, "val"),
797
+ "img_dir": join(project_split_path, "val", "img"),
798
+ "ann_dir": join(project_split_path, "val", "ann"),
799
+ },
800
+ }
801
+
802
+ # Create necessary directories (only once)
803
+ for dataset_paths in paths.values():
804
+ for path in dataset_paths.values():
805
+ sly_fs.mkdir(path, True)
806
+
807
+ # Format for image names
808
+ items_count = max(len(self._train_split), len(self._val_split))
809
+ num_digits = len(str(items_count))
810
+ image_name_formats = {
811
+ "train": f"train_img_{{:0{num_digits}d}}",
812
+ "val": f"val_img_{{:0{num_digits}d}}",
813
+ }
814
+
815
+ # Utility function to move files
816
+ def move_files(split, paths, img_name_format, pbar):
817
+ """
818
+ Move files to the appropriate directories.
819
+ """
820
+ for idx, item in enumerate(split, start=1):
821
+ item_name = img_name_format.format(idx) + sly_fs.get_file_ext(item.name)
822
+ ann_name = f"{item_name}.json"
823
+ shutil.copy(item.img_path, join(paths["img_dir"], item_name))
824
+ shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
825
+ pbar.update(1)
826
+
827
+ # Main split processing
828
+ with self.progress_bar_main(
829
+ message="Applying train/val splits to project", total=2
830
+ ) as main_pbar:
831
+ self.progress_bar_main.show()
832
+ for dataset in ["train", "val"]:
833
+ split = self._train_split if dataset == "train" else self._val_split
834
+ with self.progress_bar_secondary(
835
+ message=f"Preparing '{dataset}'", total=len(split)
836
+ ) as second_pbar:
837
+ self.progress_bar_secondary.show()
838
+ move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
839
+ main_pbar.update(1)
840
+ self.progress_bar_secondary.hide()
841
+ self.progress_bar_main.hide()
842
+
843
+ # Clean up project directory
844
+ project_datasets = [
845
+ join(self.project_dir, item)
846
+ for item in listdir(self.project_dir)
847
+ if isdir(join(self.project_dir, item))
848
+ ]
849
+ for dataset in project_datasets:
850
+ sly_fs.remove_dir(dataset)
851
+
852
+ # Move processed splits to final destination
853
+ train_ds_path = join(self.project_dir, "train")
854
+ val_ds_path = join(self.project_dir, "val")
855
+ with self.progress_bar_main(message="Processing splits", total=2) as pbar:
856
+ self.progress_bar_main.show()
857
+ for dataset in ["train", "val"]:
858
+ shutil.move(
859
+ paths[dataset]["split_path"],
860
+ train_ds_path if dataset == "train" else val_ds_path,
861
+ )
862
+ pbar.update(1)
863
+ self.progress_bar_main.hide()
864
+
865
+ # Clean up temporary directory
866
+ sly_fs.remove_dir(project_split_path)
867
+ self.sly_project = Project(self.project_dir, OpenMode.READ)
868
+
869
+ # ----------------------------------------- #
870
+
871
+ # ----------------------------------------- #
872
+ # Download Model
873
+ def _download_model(self) -> None:
874
+ """
875
+ Downloads the model data from the selected source.
876
+ - Checkpoint and config keys inside the model_files dict can be provided as local paths.
877
+
878
+ For Pretrained models:
879
+ - The files that will be downloaded are specified in the `meta` key under `model_files`.
880
+ - All files listed in the `model_files` key will be downloaded by provided link.
881
+ Example of a pretrained model entry:
882
+ [
883
+ {
884
+ "Model": "example_model",
885
+ "dataset": "COCO",
886
+ "AP_val": 46.4,
887
+ "Params(M)": 20,
888
+ "FPS(T4)": 217,
889
+ "meta": {
890
+ "task_type": "object detection",
891
+ "model_name": "example_model",
892
+ "model_files": {
893
+ # For remote files provide as links
894
+ "checkpoint": "https://example.com/checkpoint.pth",
895
+ "config": "https://example.com/config.yaml"
896
+
897
+ # For local files provide as paths
898
+ # "checkpoint": "/path/to/checkpoint.pth",
899
+ # "config": "/path/to/config.yaml"
900
+ }
901
+ }
902
+ },
903
+ ...
904
+ ]
905
+
906
+ For Custom models:
907
+ - All custom models trained inside Supervisely are managed automatically by this class.
908
+ """
909
+ if self.model_source == ModelSource.PRETRAINED:
910
+ self._download_pretrained_model()
911
+
912
+ else:
913
+ self._download_custom_model()
914
+ logger.info(f"Model files have been downloaded successfully to: '{self.model_dir}'")
915
+
916
    def _download_pretrained_model(self):
        """
        Downloads the pretrained model data.

        Populates ``self.model_files`` with a mapping of file key -> local path.
        Entries in ``model_files`` whose value starts with "http" are downloaded
        into ``self.model_dir``; all other values are treated as already-local
        paths and stored as-is.
        """
        # General
        self.model_files = {}
        model_meta = self.model_info["meta"]
        model_files = model_meta["model_files"]

        with self.progress_bar_main(
            message="Downloading model files",
            total=len(model_files),
        ) as model_download_main_pbar:
            self.progress_bar_main.show()
            for file in model_files:
                file_url = model_files[file]
                # Default destination uses the dict key; overwritten below for
                # remote files once the real filename is known from headers.
                file_path = join(self.model_dir, file)

                if file_url.startswith("http"):
                    # Open the URL first to learn the content length so the
                    # secondary progress bar can show bytes.
                    with urlopen(file_url) as f:
                        file_size = f.length
                        file_name = get_filename_from_headers(file_url)
                        file_path = join(self.model_dir, file_name)
                        with self.progress_bar_secondary(
                            message=f"Downloading '{file_name}' ",
                            total=file_size,
                            unit="bytes",
                            unit_scale=True,
                        ) as model_download_secondary_pbar:
                            self.progress_bar_secondary.show()
                            sly_fs.download(
                                url=file_url,
                                save_path=file_path,
                                progress=model_download_secondary_pbar.update,
                            )
                    self.model_files[file] = file_path
                else:
                    # Local path: nothing to download, just record it.
                    self.model_files[file] = file_url
                model_download_main_pbar.update(1)

        self.progress_bar_main.hide()
        self.progress_bar_secondary.hide()
+ def _download_custom_model(self):
960
+ """
961
+ Downloads the custom model data.
962
+ """
963
+ # General
964
+ self.model_files = {}
965
+
966
+ # Need to merge file_url with arts dir
967
+ artifacts_dir = self.model_info["artifacts_dir"]
968
+ model_files = self.model_info["model_files"]
969
+ remote_paths = {name: join(artifacts_dir, file) for name, file in model_files.items()}
970
+
971
+ # Add selected checkpoint to model_files
972
+ checkpoint = self.gui.model_selector.experiment_selector.get_selected_checkpoint_path()
973
+ remote_paths["checkpoint"] = checkpoint
974
+
975
+ with self.progress_bar_main(
976
+ message="Downloading model files",
977
+ total=len(model_files),
978
+ ) as model_download_main_pbar:
979
+ self.progress_bar_main.show()
980
+ for name, remote_path in remote_paths.items():
981
+ file_info = self._api.file.get_info_by_path(self._team_id, remote_path)
982
+ file_name = basename(remote_path)
983
+ local_path = join(self.model_dir, file_name)
984
+ file_size = file_info.sizeb
985
+
986
+ with self.progress_bar_secondary(
987
+ message=f"Downloading {file_name}",
988
+ total=file_size,
989
+ unit="bytes",
990
+ unit_scale=True,
991
+ ) as model_download_secondary_pbar:
992
+ self.progress_bar_secondary.show()
993
+ self._api.file.download(
994
+ self._team_id,
995
+ remote_path,
996
+ local_path,
997
+ progress_cb=model_download_secondary_pbar.update,
998
+ )
999
+ model_download_main_pbar.update(1)
1000
+ self.model_files[name] = local_path
1001
+
1002
+ self.progress_bar_main.hide()
1003
+ self.progress_bar_secondary.hide()
1004
+
1005
+ # ----------------------------------------- #
1006
+
1007
+ # Postprocess
1008
+
1009
+ def _validate_experiment_info(self, experiment_info: dict) -> tuple:
1010
+ """
1011
+ Validates the experiment_info parameter to ensure it is in the correct format.
1012
+ experiment_info is returned by the user's training function.
1013
+
1014
+ experiment_info should contain the following keys:
1015
+ - model_name": str
1016
+ - task_type": str
1017
+ - model_files": dict
1018
+ - checkpoints": list
1019
+ - best_checkpoint": str
1020
+
1021
+ Other keys are generated by the TrainApp class automatically
1022
+
1023
+ :param experiment_info: Information about the experiment results.
1024
+ :type experiment_info: dict
1025
+ :return: Tuple of success status and reason for failure.
1026
+ :rtype: tuple
1027
+ """
1028
+ if not isinstance(experiment_info, dict):
1029
+ reason = f"Validation failed: 'experiment_info' must be a dictionary not '{type(experiment_info)}'"
1030
+ return False, reason
1031
+
1032
+ logger.debug("Starting validation of 'experiment_info'")
1033
+ required_keys = {
1034
+ "model_name": str,
1035
+ "task_type": str,
1036
+ "model_files": dict,
1037
+ "checkpoints": (list, str),
1038
+ "best_checkpoint": str,
1039
+ }
1040
+
1041
+ for key, expected_type in required_keys.items():
1042
+ if key not in experiment_info:
1043
+ reason = f"Validation failed: Missing required key '{key}'"
1044
+ return False, reason
1045
+
1046
+ if not isinstance(experiment_info[key], expected_type):
1047
+ reason = (
1048
+ f"Validation failed: Key '{key}' should be of type {expected_type.__name__}"
1049
+ )
1050
+ return False, reason
1051
+
1052
+ if isinstance(experiment_info["checkpoints"], list):
1053
+ for checkpoint in experiment_info["checkpoints"]:
1054
+ if not isinstance(checkpoint, str):
1055
+ reason = "Validation failed: All items in 'checkpoints' list must be strings"
1056
+ return False, reason
1057
+ if not sly_fs.file_exists(checkpoint):
1058
+ reason = f"Validation failed: Checkpoint file: '{checkpoint}' does not exist"
1059
+ return False, reason
1060
+
1061
+ best_checkpoint = experiment_info["best_checkpoint"]
1062
+ checkpoints = experiment_info["checkpoints"]
1063
+ if isinstance(checkpoints, list):
1064
+ checkpoints = [sly_fs.get_file_name_with_ext(checkpoint) for checkpoint in checkpoints]
1065
+ if best_checkpoint not in checkpoints:
1066
+ reason = (
1067
+ f"Validation failed: Best checkpoint file: '{best_checkpoint}' does not exist"
1068
+ )
1069
+ return False, reason
1070
+ elif isinstance(checkpoints, str):
1071
+ checkpoints = [
1072
+ sly_fs.get_file_name_with_ext(checkpoint)
1073
+ for checkpoint in sly_fs.list_dir_recursively(checkpoints, [".pt", ".pth"])
1074
+ ]
1075
+ if best_checkpoint not in checkpoints:
1076
+ reason = (
1077
+ f"Validation failed: Best checkpoint file: '{best_checkpoint}' does not exist"
1078
+ )
1079
+ return False, reason
1080
+ else:
1081
+ reason = "Validation failed: 'checkpoints' should be a list of paths or a path to directory with checkpoints"
1082
+ return False, reason
1083
+
1084
+ logger.debug("Validation successful")
1085
+ return True, None
1086
+
1087
+ def _postprocess_splits(self) -> dict:
1088
+ """
1089
+ Processes the train and val splits to generate the necessary data for the experiment_info.json file.
1090
+ """
1091
+ val_dataset_ids = None
1092
+ val_images_ids = None
1093
+ train_dataset_ids = None
1094
+ train_images_ids = None
1095
+
1096
+ split_method = self.gui.train_val_splits_selector.get_split_method()
1097
+ train_set, val_set = self._train_split, self._val_split
1098
+ if split_method == "Based on datasets":
1099
+ val_dataset_ids = self.gui.train_val_splits_selector.get_val_dataset_ids()
1100
+ train_dataset_ids = self.gui.train_val_splits_selector.get_train_dataset_ids
1101
+ else:
1102
+ dataset_infos = [dataset for _, dataset in self._api.dataset.tree(self.project_id)]
1103
+ ds_infos_dict = {}
1104
+ for dataset in dataset_infos:
1105
+ if dataset.parent_id is not None:
1106
+ parent_ds = self._api.dataset.get_info_by_id(dataset.parent_id)
1107
+ dataset_name = f"{parent_ds.name}/{dataset.name}"
1108
+ else:
1109
+ dataset_name = dataset.name
1110
+ ds_infos_dict[dataset_name] = dataset
1111
+
1112
+ def get_image_infos_by_split(ds_infos_dict: dict, split: list):
1113
+ image_names_per_dataset = {}
1114
+ for item in split:
1115
+ image_names_per_dataset.setdefault(item.dataset_name, []).append(item.name)
1116
+ image_infos = []
1117
+ for dataset_name, image_names in image_names_per_dataset.items():
1118
+ ds_info = ds_infos_dict[dataset_name]
1119
+ image_infos.extend(
1120
+ self._api.image.get_list(
1121
+ ds_info.id,
1122
+ filters=[
1123
+ {
1124
+ "field": "name",
1125
+ "operator": "in",
1126
+ "value": image_names,
1127
+ }
1128
+ ],
1129
+ )
1130
+ )
1131
+ return image_infos
1132
+
1133
+ val_image_infos = get_image_infos_by_split(ds_infos_dict, val_set)
1134
+ train_image_infos = get_image_infos_by_split(ds_infos_dict, train_set)
1135
+ val_images_ids = [img_info.id for img_info in val_image_infos]
1136
+ train_images_ids = [img_info.id for img_info in train_image_infos]
1137
+
1138
+ splits_data = {
1139
+ "train": {
1140
+ "dataset_ids": train_dataset_ids,
1141
+ "images_ids": train_images_ids,
1142
+ },
1143
+ "val": {
1144
+ "dataset_ids": val_dataset_ids,
1145
+ "images_ids": val_images_ids,
1146
+ },
1147
+ }
1148
+ return splits_data
1149
+
1150
+ def _preprocess_artifacts(self, experiment_info: dict) -> None:
1151
+ """
1152
+ Preprocesses and move the artifacts generated by the training process to output directories.
1153
+
1154
+ :param experiment_info: Information about the experiment results.
1155
+ :type experiment_info: dict
1156
+ """
1157
+ # Preprocess artifacts
1158
+ logger.debug("Preprocessing artifacts")
1159
+ if "model_files" not in experiment_info:
1160
+ experiment_info["model_files"] = {}
1161
+ else:
1162
+ # Move model files to output directory except config, config will be processed next
1163
+ files = {k: v for k, v in experiment_info["model_files"].items() if k != "config"}
1164
+ for file in files:
1165
+ if isfile:
1166
+ shutil.move(
1167
+ file,
1168
+ join(self.output_dir, sly_fs.get_file_name_with_ext(file)),
1169
+ )
1170
+ elif isdir:
1171
+ shutil.move(file, join(self.output_dir, basename(file)))
1172
+
1173
+ # Preprocess config
1174
+ logger.debug("Preprocessing config")
1175
+ config = experiment_info["model_files"].get("config")
1176
+ if config is not None:
1177
+ config_name = sly_fs.get_file_name_with_ext(experiment_info["model_files"]["config"])
1178
+ output_config_path = join(self.output_dir, config_name)
1179
+ shutil.move(experiment_info["model_files"]["config"], output_config_path)
1180
+ if self.is_model_benchmark_enabled:
1181
+ self._benchmark_params["model_files"]["config"] = output_config_path
1182
+
1183
+ # Prepare checkpoints
1184
+ checkpoints = experiment_info["checkpoints"]
1185
+ # If checkpoints returned as directory
1186
+ if isinstance(checkpoints, str):
1187
+ checkpoint_paths = []
1188
+ for checkpoint_path in sly_fs.list_files_recursively(checkpoints, [".pt", ".pth"]):
1189
+ checkpoint_paths.append(checkpoint_path)
1190
+ elif isinstance(checkpoints, list):
1191
+ checkpoint_paths = checkpoints
1192
+ else:
1193
+ raise ValueError(
1194
+ "Checkpoints should be a list of paths or a path to directory with checkpoints"
1195
+ )
1196
+
1197
+ best_checkpoints_name = experiment_info["best_checkpoint"]
1198
+ for checkpoint_path in checkpoint_paths:
1199
+ new_checkpoint_path = join(
1200
+ self._output_checkpoints_dir,
1201
+ sly_fs.get_file_name_with_ext(checkpoint_path),
1202
+ )
1203
+ shutil.move(checkpoint_path, new_checkpoint_path)
1204
+ if self.is_model_benchmark_enabled:
1205
+ if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
1206
+ self._benchmark_params["model_files"]["checkpoint"] = new_checkpoint_path
1207
+
1208
+ # Prepare logs
1209
+ if sly_fs.dir_exists(self.log_dir):
1210
+ logs_dir = join(self.output_dir, "logs")
1211
+ shutil.move(self.log_dir, logs_dir)
1212
+
1213
+ # Generate experiment_info.json and app_state.json
1214
    def _upload_file_to_team_files(self, local_path: str, remote_path: str, message: str) -> None:
        """
        Helper function to upload a file with progress.

        :param local_path: Path of the file on local disk.
        :type local_path: str
        :param remote_path: Destination path in Team Files.
        :type remote_path: str
        :param message: Text shown on the main progress bar during the upload.
        :type message: str
        """
        logger.debug(f"Uploading '{local_path}' to Supervisely")
        # Byte-level progress: total is the file size so the bar shows bytes.
        total_size = sly_fs.get_file_size(local_path)
        with self.progress_bar_main(
            message=message, total=total_size, unit="bytes", unit_scale=True
        ) as upload_artifacts_pbar:
            self.progress_bar_main.show()
            self._api.file.upload(
                self._team_id,
                local_path,
                remote_path,
                progress_cb=upload_artifacts_pbar,
            )
            self.progress_bar_main.hide()
+ def _generate_train_val_splits(self, remote_dir: str, splits_data: dict) -> None:
1231
+ """
1232
+ Generates and uploads the train and val splits to the output directory.
1233
+
1234
+ :param remote_dir: Remote directory path.
1235
+ :type remote_dir: str
1236
+ """
1237
+ local_train_val_split_path = join(self.output_dir, self._train_val_split_file)
1238
+ remote_train_val_split_path = join(remote_dir, self._train_val_split_file)
1239
+
1240
+ data = {
1241
+ "train": splits_data["train"]["images_ids"],
1242
+ "val": splits_data["val"]["images_ids"],
1243
+ }
1244
+
1245
+ sly_json.dump_json_file(data, local_train_val_split_path)
1246
+ self._upload_file_to_team_files(
1247
+ local_train_val_split_path,
1248
+ remote_train_val_split_path,
1249
+ f"Uploading '{self._train_val_split_file}' to Team Files",
1250
+ )
1251
+
1252
+ def _generate_model_meta(self, remote_dir: str, experiment_info: dict) -> None:
1253
+ """
1254
+ Generates and uploads the model_meta.json file to the output directory.
1255
+
1256
+ :param remote_dir: Remote directory path.
1257
+ :type remote_dir: str
1258
+ :param experiment_info: Information about the experiment results.
1259
+ :type experiment_info: dict
1260
+ """
1261
+ # @TODO: Handle tags for classification tasks
1262
+ local_path = join(self.output_dir, self._model_meta_file)
1263
+ remote_path = join(remote_dir, self._model_meta_file)
1264
+
1265
+ sly_json.dump_json_file(self.model_meta.to_json(), local_path)
1266
+ self._upload_file_to_team_files(
1267
+ local_path,
1268
+ remote_path,
1269
+ f"Uploading '{self._model_meta_file}' to Team Files",
1270
+ )
1271
+
1272
+ def _generate_experiment_info(
1273
+ self,
1274
+ remote_dir: str,
1275
+ experiment_info: Dict,
1276
+ evaluation_report_id: Optional[int] = None,
1277
+ ) -> None:
1278
+ """
1279
+ Generates and uploads the experiment_info.json file to the output directory.
1280
+
1281
+ :param remote_dir: Remote directory path.
1282
+ :type remote_dir: str
1283
+ :param experiment_info: Information about the experiment results.
1284
+ :type experiment_info: dict
1285
+ :param evaluation_report_id: Evaluation report file ID.
1286
+ :type evaluation_report_id: int
1287
+ """
1288
+ logger.debug("Updating experiment info")
1289
+
1290
+ experiment_info = {
1291
+ "experiment_name": self.gui.training_process.get_experiment_name(),
1292
+ "framework_name": self.framework_name,
1293
+ "model_name": experiment_info["model_name"],
1294
+ "task_type": experiment_info["task_type"],
1295
+ "project_id": self.project_info.id,
1296
+ "task_id": self.task_id,
1297
+ "model_files": experiment_info["model_files"],
1298
+ "checkpoints": experiment_info["checkpoints"],
1299
+ "best_checkpoint": experiment_info["best_checkpoint"],
1300
+ "app_state": self._app_state_file,
1301
+ "model_meta": self._model_meta_file,
1302
+ "train_val_split": self._train_val_split_file,
1303
+ "hyperparameters": self._hyperparameters_file,
1304
+ "artifacts_dir": remote_dir,
1305
+ "datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
1306
+ "evaluation_report_id": evaluation_report_id,
1307
+ "eval_metrics": {},
1308
+ }
1309
+
1310
+ remote_checkpoints_dir = join(remote_dir, self._remote_checkpoints_dir_name)
1311
+ checkpoint_files = self._api.file.list(
1312
+ self._team_id, remote_checkpoints_dir, return_type="fileinfo"
1313
+ )
1314
+ experiment_info["checkpoints"] = [
1315
+ f"checkpoints/{checkpoint.name}" for checkpoint in checkpoint_files
1316
+ ]
1317
+
1318
+ experiment_info["best_checkpoint"] = sly_fs.get_file_name_with_ext(
1319
+ experiment_info["best_checkpoint"]
1320
+ )
1321
+ experiment_info["model_files"]["config"] = sly_fs.get_file_name_with_ext(
1322
+ experiment_info["model_files"]["config"]
1323
+ )
1324
+
1325
+ local_path = join(self.output_dir, self._experiment_json_file)
1326
+ remote_path = join(remote_dir, self._experiment_json_file)
1327
+ sly_json.dump_json_file(experiment_info, local_path)
1328
+ self._upload_file_to_team_files(
1329
+ local_path,
1330
+ remote_path,
1331
+ f"Uploading '{self._experiment_json_file}' to Team Files",
1332
+ )
1333
+
1334
+ def _generate_hyperparameters(self, remote_dir: str, experiment_info: Dict) -> None:
1335
+ """
1336
+ Generates and uploads the hyperparameters.yaml file to the output directory.
1337
+
1338
+ :param remote_dir: Remote directory path.
1339
+ :type remote_dir: str
1340
+ :param experiment_info: Information about the experiment results.
1341
+ :type experiment_info: dict
1342
+ """
1343
+ local_path = join(self.output_dir, self._hyperparameters_file)
1344
+ remote_path = join(remote_dir, self._hyperparameters_file)
1345
+
1346
+ with open(local_path, "w") as file:
1347
+ file.write(self.hyperparameters_yaml)
1348
+
1349
+ self._upload_file_to_team_files(
1350
+ local_path,
1351
+ remote_path,
1352
+ f"Uploading '{self._hyperparameters_file}' to Team Files",
1353
+ )
1354
+
1355
+ def _generate_app_state(self, remote_dir: str, experiment_info: Dict) -> None:
1356
+ """
1357
+ Generates and uploads the app_state.json file to the output directory.
1358
+
1359
+ :param remote_dir: Remote directory path.
1360
+ :type remote_dir: str
1361
+ :param experiment_info: Information about the experiment results.
1362
+ :type experiment_info: dict
1363
+ """
1364
+ app_state = self.get_app_state(experiment_info)
1365
+
1366
+ local_path = join(self.output_dir, self._app_state_file)
1367
+ remote_path = join(remote_dir, self._app_state_file)
1368
+ sly_json.dump_json_file(app_state, local_path)
1369
+ self._upload_file_to_team_files(
1370
+ local_path, remote_path, f"Uploading '{self._app_state_file}' to Team Files"
1371
+ )
1372
+
1373
+ def _get_train_val_splits_for_app_state(self) -> Dict:
1374
+ """
1375
+ Gets the train and val splits information for app_state.json.
1376
+
1377
+ :return: Train and val splits information based on selected split method.
1378
+ :rtype: dict
1379
+ """
1380
+ split_method = self.gui.train_val_splits_selector.get_split_method()
1381
+ train_val_splits = {"method": split_method.lower()}
1382
+ if split_method == "Random":
1383
+ train_val_splits.update(
1384
+ {
1385
+ "split": "train",
1386
+ "percent": self.gui.train_val_splits_selector.train_val_splits.get_train_split_percent(),
1387
+ }
1388
+ )
1389
+ elif split_method == "Based on tags":
1390
+ train_val_splits.update(
1391
+ {
1392
+ "train_tag": self.gui.train_val_splits_selector.train_val_splits.get_train_tag(),
1393
+ "val_tag": self.gui.train_val_splits_selector.train_val_splits.get_val_tag(),
1394
+ "untagged_action": self.gui.train_val_splits_selector.train_val_splits.get_untagged_action(),
1395
+ }
1396
+ )
1397
+ elif split_method == "Based on datasets":
1398
+ train_val_splits.update(
1399
+ {
1400
+ "train_datasets": self.gui.train_val_splits_selector.train_val_splits.get_train_dataset_ids(),
1401
+ "val_datasets": self.gui.train_val_splits_selector.train_val_splits.get_val_dataset_ids(),
1402
+ }
1403
+ )
1404
+ return train_val_splits
1405
+
1406
+ def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
1407
+ """
1408
+ Gets the model configuration information for app_state.json.
1409
+
1410
+ :param experiment_info: Information about the experiment results.
1411
+ :type experiment_info: dict
1412
+ """
1413
+ experiment_info = experiment_info or {}
1414
+
1415
+ if self.model_source == ModelSource.PRETRAINED:
1416
+ model_name = experiment_info.get("model_name") or self.model_info.get("meta", {}).get(
1417
+ "model_name"
1418
+ )
1419
+ return {
1420
+ "source": ModelSource.PRETRAINED,
1421
+ "model_name": model_name,
1422
+ }
1423
+ elif self.model_source == ModelSource.CUSTOM:
1424
+ return {
1425
+ "source": ModelSource.CUSTOM,
1426
+ "task_id": self.task_id,
1427
+ "checkpoint": "checkpoint.pth",
1428
+ }
1429
+
1430
+ # ----------------------------------------- #
1431
+
1432
+ # Upload artifacts
1433
    def _upload_artifacts(self) -> tuple:
        """
        Uploads the training artifacts to Supervisely.
        Path is generated based on the project ID, task ID, and framework name.

        Path: /experiments/{project_id}_{project_name}/{task_id}_{framework_name}/
        Example path: /experiments/43192_Apples/68271_rt-detr/

        :return: Remote artifacts directory and the FileInfo of 'open_app.lnk'.
        :rtype: tuple
        """
        logger.info(f"Uploading directory: '{self.output_dir}' to Supervisely")
        task_id = self.task_id

        remote_artifacts_dir = f"/{self._experiments_dir_name}/{self.project_id}_{self.project_name}/{task_id}_{self.framework_name}/"

        # Clean debug directory if exists
        if task_id == "debug-session":
            if self._api.file.dir_exists(self._team_id, f"{remote_artifacts_dir}/", True):
                with self.progress_bar_main(
                    message=f"[Debug] Cleaning train artifacts: '{remote_artifacts_dir}/'",
                    total=1,
                ) as upload_artifacts_pbar:
                    self.progress_bar_main.show()
                    self._api.file.remove_dir(self._team_id, f"{remote_artifacts_dir}", True)
                    upload_artifacts_pbar.update(1)
                    self.progress_bar_main.hide()

        # Generate link file
        if is_production():
            app_url = f"/apps/sessions/{task_id}"
        else:
            app_url = "This is a debug session. No link available."
        app_link_path = join(self.output_dir, "open_app.lnk")
        with open(app_link_path, "w") as text_file:
            print(app_url, file=text_file)

        # Total byte size drives the upload progress bar.
        local_files = sly_fs.list_files_recursively(self.output_dir)
        total_size = sum([sly_fs.get_file_size(file_path) for file_path in local_files])
        with self.progress_bar_main(
            message="Uploading train artifacts to Team Files",
            total=total_size,
            unit="bytes",
            unit_scale=True,
        ) as upload_artifacts_pbar:
            self.progress_bar_main.show()
            remote_dir = self._api.file.upload_directory(
                self._team_id,
                self.output_dir,
                remote_artifacts_dir,
                progress_size_cb=upload_artifacts_pbar,
            )
            self.progress_bar_main.hide()

        file_info = self._api.file.get_info_by_path(self._team_id, join(remote_dir, "open_app.lnk"))
        return remote_dir, file_info
    def _set_training_output(self, remote_dir: str, file_info: FileInfo) -> None:
        """
        Sets the training output in the GUI.

        :param remote_dir: Remote directory with the uploaded artifacts.
        :type remote_dir: str
        :param file_info: FileInfo of the uploaded 'open_app.lnk' shown as the thumbnail.
        :type file_info: FileInfo
        """
        logger.info("All training artifacts uploaded successfully")
        # Training is finished — freeze the control buttons.
        self.gui.training_process.start_button.loading = False
        self.gui.training_process.start_button.disable()
        self.gui.training_process.stop_button.disable()
        self.gui.training_logs.tensorboard_button.disable()

        # Point the task output at the artifacts directory and reveal results.
        set_directory(remote_dir)
        self.gui.training_process.artifacts_thumbnail.set(file_info)
        self.gui.training_process.artifacts_thumbnail.show()
        self.gui.training_process.success_message.show()
+ def _get_eval_results_dir_name(self) -> str:
1504
+ """
1505
+ Returns the evaluation results path.
1506
+ """
1507
+ task_info = self._api.task.get_info_by_id(self.task_id)
1508
+ task_dir = f"{self.task_id}_{task_info['meta']['app']['name']}"
1509
+ eval_res_dir = f"/model-benchmark/evaluation/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
1510
+ eval_res_dir = self._api.storage.get_free_dir_name(self._team_id(), eval_res_dir)
1511
+ return eval_res_dir
1512
+
1513
+ def _run_model_benchmark(
1514
+ self,
1515
+ local_artifacts_dir: str,
1516
+ remote_artifacts_dir: str,
1517
+ experiment_info: dict,
1518
+ splits_data: dict,
1519
+ ) -> tuple:
1520
+ """
1521
+ Runs the Model Benchmark evaluation process. Model benchmark runs only in production mode.
1522
+
1523
+ :param local_artifacts_dir: Local directory path.
1524
+ :type local_artifacts_dir: str
1525
+ :param remote_artifacts_dir: Remote directory path.
1526
+ :type remote_artifacts_dir: str
1527
+ :param experiment_info: Information about the experiment results.
1528
+ :type experiment_info: dict
1529
+ :param splits_data: Information about the train and val splits.
1530
+ :type splits_data: dict
1531
+ :return: Evaluation report and report ID.
1532
+ :rtype: tuple
1533
+ """
1534
+ report, report_id = None, None
1535
+ if self._inference_class is None:
1536
+ logger.warn(
1537
+ "Inference class is not registered, model benchmark disabled. "
1538
+ "Use 'register_inference_class' method to register inference class."
1539
+ )
1540
+ return report, report_id
1541
+
1542
+ # Can't get task type from session. requires before session init
1543
+ supported_task_types = [
1544
+ TaskType.OBJECT_DETECTION,
1545
+ TaskType.INSTANCE_SEGMENTATION,
1546
+ ]
1547
+ task_type = experiment_info["task_type"]
1548
+ if task_type not in supported_task_types:
1549
+ logger.warn(
1550
+ f"Task type: '{task_type}' is not supported for Model Benchmark. "
1551
+ f"Supported tasks: {', '.join(task_type)}"
1552
+ )
1553
+ return report, report_id
1554
+
1555
+ logger.info("Running Model Benchmark evaluation")
1556
+ try:
1557
+ remote_checkpoints_dir = join(remote_artifacts_dir, "checkpoints")
1558
+ best_checkpoint = experiment_info.get("best_checkpoint", None)
1559
+ best_filename = sly_fs.get_file_name_with_ext(best_checkpoint)
1560
+ remote_best_checkpoint = join(remote_checkpoints_dir, best_filename)
1561
+
1562
+ config_path = experiment_info["model_files"].get("config")
1563
+ if config_path is not None:
1564
+ remote_config_path = join(
1565
+ remote_artifacts_dir, sly_fs.get_file_name_with_ext(config_path)
1566
+ )
1567
+ else:
1568
+ remote_config_path = None
1569
+
1570
+ logger.info(f"Creating the report for the best model: {best_filename!r}")
1571
+ self.gui.training_process.model_benchmark_report_text.show()
1572
+ self.progress_bar_main(message="Starting Model Benchmark evaluation", total=1)
1573
+ self.progress_bar_main.show()
1574
+
1575
+ # 0. Serve trained model
1576
+ m = self._inference_class(
1577
+ model_dir=self.model_dir,
1578
+ use_gui=False,
1579
+ custom_inference_settings=self._inference_settings,
1580
+ )
1581
+
1582
+ logger.info(f"Using device: {self.device}")
1583
+
1584
+ self._benchmark_params["device"] = self.device
1585
+ self._benchmark_params["model_info"] = {
1586
+ "artifacts_dir": remote_artifacts_dir,
1587
+ "model_name": experiment_info["model_name"],
1588
+ "framework_name": self.framework_name,
1589
+ "model_meta": self.model_meta.to_json(),
1590
+ }
1591
+
1592
+ logger.info(f"Deploy parameters: {self._benchmark_params}")
1593
+
1594
+ m._load_model_headless(**self._benchmark_params)
1595
+ m.serve()
1596
+
1597
+ port = 8000
1598
+ session = SessionJSON(self._api, session_url=f"http://localhost:{port}")
1599
+ benchmark_dir = join(local_artifacts_dir, "benchmark")
1600
+ sly_fs.mkdir(benchmark_dir, True)
1601
+
1602
+ # 1. Init benchmark
1603
+ benchmark_dataset_ids = splits_data["val"]["dataset_ids"]
1604
+ benchmark_images_ids = splits_data["val"]["images_ids"]
1605
+ train_dataset_ids = splits_data["train"]["dataset_ids"]
1606
+ train_images_ids = splits_data["train"]["images_ids"]
1607
+
1608
+ bm = None
1609
+ if task_type == TaskType.OBJECT_DETECTION:
1610
+ eval_params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
1611
+ eval_params = yaml.safe_load(eval_params)
1612
+ bm = ObjectDetectionBenchmark(
1613
+ self._api,
1614
+ self.project_info.id,
1615
+ output_dir=benchmark_dir,
1616
+ gt_dataset_ids=benchmark_dataset_ids,
1617
+ gt_images_ids=benchmark_images_ids,
1618
+ progress=self.progress_bar_main,
1619
+ progress_secondary=self.progress_bar_secondary,
1620
+ classes_whitelist=self.classes,
1621
+ evaluation_params=eval_params,
1622
+ )
1623
+ elif task_type == TaskType.INSTANCE_SEGMENTATION:
1624
+ eval_params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
1625
+ eval_params = yaml.safe_load(eval_params)
1626
+ bm = InstanceSegmentationBenchmark(
1627
+ self._api,
1628
+ self.project_info.id,
1629
+ output_dir=benchmark_dir,
1630
+ gt_dataset_ids=benchmark_dataset_ids,
1631
+ gt_images_ids=benchmark_images_ids,
1632
+ progress=self.progress_bar_main,
1633
+ progress_secondary=self.progress_bar_secondary,
1634
+ classes_whitelist=self.classes,
1635
+ evaluation_params=eval_params,
1636
+ )
1637
+ elif task_type == TaskType.SEMANTIC_SEGMENTATION:
1638
+ eval_params = SemanticSegmentationEvaluator.load_yaml_evaluation_params()
1639
+ eval_params = yaml.safe_load(eval_params)
1640
+ bm = SemanticSegmentationBenchmark(
1641
+ self._api,
1642
+ self.project_info.id,
1643
+ output_dir=benchmark_dir,
1644
+ gt_dataset_ids=benchmark_dataset_ids,
1645
+ gt_images_ids=benchmark_images_ids,
1646
+ progress=self.progress_bar_main,
1647
+ progress_secondary=self.progress_bar_secondary,
1648
+ classes_whitelist=self.classes,
1649
+ evaluation_params=eval_params,
1650
+ )
1651
+ else:
1652
+ raise ValueError(f"Task type: '{task_type}' is not supported for Model Benchmark")
1653
+
1654
+ if self.gui.train_val_splits_selector.get_split_method() == "Based on datasets":
1655
+ train_info = {
1656
+ "app_session_id": self.task_id,
1657
+ "train_dataset_ids": train_dataset_ids,
1658
+ "train_images_ids": None,
1659
+ "images_count": len(self._train_split),
1660
+ }
1661
+ else:
1662
+ train_info = {
1663
+ "app_session_id": self.task_id,
1664
+ "train_dataset_ids": None,
1665
+ "train_images_ids": train_images_ids,
1666
+ "images_count": len(self._train_split),
1667
+ }
1668
+ bm.train_info = train_info
1669
+
1670
+ # 2. Run inference
1671
+ bm.run_inference(session)
1672
+
1673
+ # 3. Pull results from the server
1674
+ gt_project_path, dt_project_path = bm.download_projects(save_images=False)
1675
+
1676
+ # 4. Evaluate
1677
+ bm._evaluate(gt_project_path, dt_project_path)
1678
+
1679
+ # 5. Upload evaluation results
1680
+ eval_res_dir = self._get_eval_results_dir_name()
1681
+ bm.upload_eval_results(eval_res_dir + "/evaluation/")
1682
+
1683
+ # 6. Speed test
1684
+ if self.gui.hyperparameters_selector.get_speedtest_checkbox_value() is True:
1685
+ bm.run_speedtest(session, self.project_info.id)
1686
+ self.progress_bar_secondary.hide() # @TODO: add progress bar
1687
+ bm.upload_speedtest_results(eval_res_dir + "/speedtest/")
1688
+
1689
+ # 7. Prepare visualizations, report and upload
1690
+ bm.visualize()
1691
+ remote_dir = bm.upload_visualizations(eval_res_dir + "/visualizations/")
1692
+ report = bm.upload_report_link(remote_dir)
1693
+ report_id = report.id
1694
+
1695
+ # 8. UI updates
1696
+ benchmark_report_template = self._api.file.get_info_by_path(
1697
+ self._team_id(), remote_dir + "template.vue"
1698
+ )
1699
+
1700
+ self.gui.training_process.model_benchmark_report_text.hide()
1701
+ self.gui.training_process.model_benchmark_report_thumbnail.set(
1702
+ benchmark_report_template
1703
+ )
1704
+ self.gui.training_process.model_benchmark_report_thumbnail.show()
1705
+ self.progress_bar_main.hide()
1706
+ self.progress_bar_secondary.hide()
1707
+ logger.info("Model benchmark evaluation completed successfully")
1708
+ logger.info(
1709
+ f"Predictions project name: {bm.dt_project_info.name}. Workspace_id: {bm.dt_project_info.workspace_id}"
1710
+ )
1711
+ logger.info(
1712
+ f"Differences project name: {bm.diff_project_info.name}. Workspace_id: {bm.diff_project_info.workspace_id}"
1713
+ )
1714
+ except Exception as e:
1715
+ logger.error(f"Model benchmark failed. {repr(e)}", exc_info=True)
1716
+ self.gui.training_process.model_benchmark_report_text.hide()
1717
+ self.progress_bar_main.hide()
1718
+ self.progress_bar_secondary.hide()
1719
+ try:
1720
+ if bm.dt_project_info:
1721
+ self._api.project.remove(bm.dt_project_info.id)
1722
+ if bm.diff_project_info:
1723
+ self._api.project.remove(bm.diff_project_info.id)
1724
+ except Exception as e2:
1725
+ return report, report_id
1726
+ return report, report_id
1727
+
1728
+ # ----------------------------------------- #
1729
+
1730
+ # Workflow
1731
def _workflow_input(self):
    """
    Adds the input data (project version and, for custom models, the
    checkpoint file) to the workflow.

    All failures are logged and swallowed: workflow tracking must never
    break the training run itself.
    """
    try:
        project_version_id = self._api.project.version.create(
            self.project_info,
            self._app_name,
            f"This backup was created automatically by Supervisely before the {self._app_name} task with ID: {self._api.task_id}",
        )
    except Exception as e:
        logger.warning(f"Failed to create a project version: {repr(e)}")
        project_version_id = None

    try:
        if project_version_id is None:
            # Fall back to the project's existing version, if any.
            project_version_id = (
                self.project_info.version.get("id", None) if self.project_info.version else None
            )
        self._api.app.workflow.add_input_project(
            self.project_info.id, version_id=project_version_id
        )

        # Fix: file_info was previously unbound when the model source is not
        # CUSTOM, making the debug log below raise a NameError that was
        # silently swallowed by the except clause.
        file_info = None
        if self.model_source == ModelSource.CUSTOM:
            file_info = self._api.file.get_info_by_path(
                self._team_id,
                self.gui.model_selector.experiment_selector.get_selected_checkpoint_path(),
            )
            if file_info is not None:
                self._api.app.workflow.add_input_file(file_info, model_weight=True)
        logger.debug(
            f"Workflow Input: Project ID - {self.project_info.id}, Project Version ID - {project_version_id}, Input File - {True if file_info else False}"
        )
    except Exception as e:
        logger.debug(f"Failed to add input to the workflow: {repr(e)}")
1766
+
1767
def _workflow_output(
    self,
    team_files_dir: str,
    file_info: FileInfo,
    model_benchmark_report: Optional[FileInfo] = None,
):
    """
    Adds the output data to the workflow.

    :param team_files_dir: Team Files directory with training artifacts (logged only).
    :param file_info: file info of the uploaded artifacts folder; may be falsy.
    :param model_benchmark_report: file info of the benchmark report, if produced.
    """
    try:
        task_info = self._api.task.get_info_by_id(self._api.task_id)
        module_id = task_info.get("meta", {}).get("app", {}).get("id")
        logger.debug(f"Workflow Output: Model artifacts - '{team_files_dir}'")

        if module_id:
            session_url = f"/apps/{module_id}/sessions/{self._api.task_id}"
        else:
            session_url = f"apps/sessions/{self._api.task_id}"
        node_settings = WorkflowSettings(
            title=self._app_name,
            url=session_url,
            url_title="Show Results",
        )

        if not file_info:
            logger.debug(
                f"File with checkpoints not found in Team Files. Cannot set workflow output."
            )
        else:
            artifacts_relation = WorkflowSettings(
                title="Train Artifacts",
                icon="folder",
                icon_color="#FFA500",
                icon_bg_color="#FFE8BE",
                url=f"/files/{file_info.id}/true",
                url_title="Open Folder",
            )
            artifacts_meta = WorkflowMeta(
                relation_settings=artifacts_relation, node_settings=node_settings
            )
            logger.debug(f"Workflow Output: meta \n {artifacts_meta}")
            self._api.app.workflow.add_output_file(
                file_info, model_weight=True, meta=artifacts_meta
            )

        if not model_benchmark_report:
            logger.debug(
                f"File with model benchmark report not found in Team Files. Cannot set workflow output."
            )
        else:
            report_relation = WorkflowSettings(
                title="Model Benchmark",
                icon="assignment",
                icon_color="#674EA7",
                icon_bg_color="#CCCCFF",
                url=f"/model-benchmark?id={model_benchmark_report.id}",
                url_title="Open Report",
            )
            report_meta = WorkflowMeta(
                relation_settings=report_relation, node_settings=node_settings
            )
            self._api.app.workflow.add_output_file(model_benchmark_report, meta=report_meta)
    except Exception as e:
        logger.debug(f"Failed to add output to the workflow: {repr(e)}")
1834
+ # ----------------------------------------- #
1835
+
1836
+ # Logger
1837
+ def _init_logger(self):
1838
+ """
1839
+ Initialize training logger. Set up Tensorboard and callbacks.
1840
+ """
1841
+ train_logger = self._app_options.get("train_logger", "")
1842
+ if train_logger.lower() == "tensorboard":
1843
+ tb_logger.set_log_dir(self.log_dir)
1844
+ self._setup_logger_callbacks()
1845
+ self._init_tensorboard()
1846
+
1847
def _init_tensorboard(self):
    """
    Launch a local Tensorboard server over self.log_dir and enable the
    'open tensorboard' button in the GUI.
    """
    self._register_routes()
    cmd = [
        "tensorboard",
        "--logdir",
        self.log_dir,
        "--host=localhost",
        f"--port={self._tensorboard_port}",
        "--load_fast=true",
        "--reload_multifile=true",
    ]
    self._tensorboard_process = subprocess.Popen(cmd)
    # Ensure the server is terminated when the app shuts down.
    self.app.call_before_shutdown(self.stop_tensorboard)
    print(f"Tensorboard server has been started")
    self.gui.training_logs.tensorboard_button.enable()
1862
+
1863
def start_tensorboard(self, log_dir: str, port: int = None):
    """
    Method to manually start Tensorboard in the user's training code.
    Tensorboard is started automatically if the 'train_logger' is set to
    'tensorboard' in app_options.yaml file.

    :param log_dir: Directory path to the log files.
    :type log_dir: str
    :param port: Port number for Tensorboard, defaults to None
    :type port: int, optional
    """
    self.log_dir = log_dir
    if port is not None:
        self._tensorboard_port = port
    self._init_tensorboard()
1877
+
1878
def stop_tensorboard(self):
    """Stop Tensorboard server"""
    if self._tensorboard_process is None:
        print("Tensorboard server is not running")
        return
    self._tensorboard_process.terminate()
    self._tensorboard_process = None
    print(f"Tensorboard server has been stopped")
1886
+
1887
def _setup_logger_callbacks(self):
    """
    Set up callbacks for the training logger.

    Wires tb_logger train/epoch/step events to the GUI progress bars:
    the main bar tracks epochs, the secondary bar tracks steps.
    """
    epoch_pbar = None
    step_pbar = None

    def start_training_callback(total_epochs: int):
        """Called when training starts: show the epoch progress bar."""
        nonlocal epoch_pbar
        logger.info(f"Training started for {total_epochs} epochs")
        pbar_widget = self.progress_bar_main
        pbar_widget.show()
        epoch_pbar = pbar_widget(message=f"Epochs", total=total_epochs)

    def finish_training_callback():
        """Called when training finishes: hide bars and close the logger."""
        self.progress_bar_main.hide()
        self.progress_bar_secondary.hide()

        train_logger = self._app_options.get("train_logger", "")
        # Fix: compare case-insensitively, consistent with _init_logger
        # (which uses .lower()); otherwise e.g. "Tensorboard" starts the
        # logger but never closes it here.
        if train_logger.lower() == "tensorboard":
            tb_logger.close()

    def start_epoch_callback(total_steps: int):
        """Called when a new epoch starts: show the step progress bar."""
        nonlocal step_pbar
        logger.info(f"Epoch started. Total steps: {total_steps}")
        pbar_widget = self.progress_bar_secondary
        pbar_widget.show()
        step_pbar = pbar_widget(message=f"Steps", total=total_steps)

    def finish_epoch_callback():
        """Called when an epoch finishes: advance the epoch bar."""
        epoch_pbar.update(1)

    def step_callback():
        """Called on every completed step iteration: advance the step bar."""
        step_pbar.update(1)

    tb_logger.add_on_train_started_callback(start_training_callback)
    tb_logger.add_on_train_finish_callback(finish_training_callback)

    tb_logger.add_on_epoch_started_callback(start_epoch_callback)
    tb_logger.add_on_epoch_finish_callback(finish_epoch_callback)

    tb_logger.add_on_step_callback(step_callback)
1944
+
1945
+ # ----------------------------------------- #
1946
def _wrapped_start_training(self):
    """
    Wrapper function to wrap the training process.

    Runs initialization, data preparation, the user-supplied train
    function and finalization in sequence, stopping at the first stage
    that fails. Error reporting and widget-state restoration are handled
    by _show_error.
    """
    experiment_info = None

    try:
        self._set_train_widgets_state_on_start()
        if self._train_func is None:
            raise ValueError("Train function is not defined")
        self._prepare_working_dir()
        self._init_logger()
    except Exception as e:
        message = "Error occurred during training initialization. Please check the logs for more details."
        # _show_error already restores the widgets state; calling
        # _restore_train_widgets_state_on_error again is not safe — each
        # call steps the GUI stepper back by one.
        self._show_error(message, e)
        # Fix: previously execution fell through and continued with data
        # preparation even though initialization had failed.
        return

    try:
        self.gui.training_process.validator_text.set("Preparing data for training...", "info")
        self._prepare()
    except Exception as e:
        message = (
            "Error occurred during data preparation. Please check the logs for more details."
        )
        self._show_error(message, e)
        return

    try:
        self.gui.training_process.validator_text.set("Training is in progress...", "info")
        experiment_info = self._train_func()
    except Exception as e:
        message = "Error occurred during training. Please check the logs for more details."
        self._show_error(message, e)
        return

    try:
        self.gui.training_process.validator_text.set(
            "Finalizing and uploading training artifacts...", "info"
        )
        self._finalize(experiment_info)
        self.gui.training_process.start_button.loading = False
        self.gui.training_process.validator_text.set(
            self.gui.training_process.success_message_text, "success"
        )
    except Exception as e:
        message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
        self._show_error(message, e)
        return
1997
+
1998
def _show_error(self, message: str, e=None):
    """
    Log an error, surface it in the GUI validator text, stop the start
    button spinner and restore the widgets state.

    :param message: user-facing error message.
    :param e: optional exception to include in the log (with traceback).
    """
    if e is None:
        logger.error(message)
    else:
        logger.error(f"{message}: {repr(e)}", exc_info=True)
    validator = self.gui.training_process.validator_text
    validator.set(message, "error")
    validator.show()
    self.gui.training_process.start_button.loading = False
    self._restore_train_widgets_state_on_error()
2007
+
2008
def _set_train_widgets_state_on_start(self):
    """
    Validate inputs and lock the settings widgets when training starts:
    disables the experiment-name input (and device selector if enabled),
    unlocks the logs card, advances the stepper and shows a status line.

    Raises ValueError (via _validate_experiment_name) on a bad name.
    """
    self._validate_experiment_name()
    training_process = self.gui.training_process
    training_process.experiment_name_input.disable()
    if self._app_options.get("device_selector", False):
        training_process.select_device._select.disable()
        training_process.select_device.disable()

    self.gui.training_logs.card.unlock()
    self.gui.stepper.set_active_step(7)
    training_process.validator_text.set("Training is started...", "info")
    training_process.validator_text.show()
    training_process.start_button.loading = True
2020
+
2021
def _restore_train_widgets_state_on_error(self):
    """
    Undo _set_train_widgets_state_on_start after a failure: lock the logs
    card, step the stepper back by one and re-enable the inputs.

    NOTE(review): not idempotent — each call decrements the active step,
    so callers must invoke it at most once per failure.
    """
    self.gui.training_logs.card.lock()
    previous_step = self.gui.stepper.get_active_step() - 1
    self.gui.stepper.set_active_step(previous_step)
    self.gui.training_process.experiment_name_input.enable()
    if self._app_options.get("device_selector", False):
        self.gui.training_process.select_device._select.enable()
        self.gui.training_process.select_device.enable()
2028
+
2029
+ def _validate_experiment_name(self) -> bool:
2030
+ experiment_name = self.gui.training_process.get_experiment_name()
2031
+ if not experiment_name:
2032
+ logger.error("Experiment name is empty")
2033
+ raise ValueError("Experiment name is empty")
2034
+ invalid_chars = r"\/"
2035
+ if any(char in experiment_name for char in invalid_chars):
2036
+ logger.error(f"Experiment name contains invalid characters: {invalid_chars}")
2037
+ raise ValueError(f"Experiment name contains invalid characters: {invalid_chars}")
2038
+ return True