supervisely 6.73.250-py3-none-any.whl → 6.73.252-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic.

Files changed (35)
  1. supervisely/api/dataset_api.py +17 -1
  2. supervisely/api/project_api.py +4 -1
  3. supervisely/api/volume/volume_annotation_api.py +7 -4
  4. supervisely/app/widgets/experiment_selector/experiment_selector.py +16 -8
  5. supervisely/nn/benchmark/base_benchmark.py +17 -2
  6. supervisely/nn/benchmark/base_evaluator.py +28 -6
  7. supervisely/nn/benchmark/instance_segmentation/benchmark.py +1 -1
  8. supervisely/nn/benchmark/instance_segmentation/evaluator.py +14 -0
  9. supervisely/nn/benchmark/object_detection/benchmark.py +1 -1
  10. supervisely/nn/benchmark/object_detection/evaluator.py +43 -13
  11. supervisely/nn/benchmark/object_detection/metric_provider.py +7 -0
  12. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +33 -7
  13. supervisely/nn/benchmark/utils/detection/utlis.py +6 -4
  14. supervisely/nn/experiments.py +23 -16
  15. supervisely/nn/inference/gui/serving_gui_template.py +2 -35
  16. supervisely/nn/inference/inference.py +71 -8
  17. supervisely/nn/training/__init__.py +2 -0
  18. supervisely/nn/training/gui/classes_selector.py +14 -14
  19. supervisely/nn/training/gui/gui.py +28 -13
  20. supervisely/nn/training/gui/hyperparameters_selector.py +90 -41
  21. supervisely/nn/training/gui/input_selector.py +8 -6
  22. supervisely/nn/training/gui/model_selector.py +7 -5
  23. supervisely/nn/training/gui/train_val_splits_selector.py +8 -9
  24. supervisely/nn/training/gui/training_logs.py +17 -17
  25. supervisely/nn/training/gui/training_process.py +41 -36
  26. supervisely/nn/training/loggers/__init__.py +22 -0
  27. supervisely/nn/training/loggers/base_train_logger.py +8 -5
  28. supervisely/nn/training/loggers/tensorboard_logger.py +4 -11
  29. supervisely/nn/training/train_app.py +276 -90
  30. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/METADATA +8 -3
  31. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/RECORD +35 -35
  32. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/LICENSE +0 -0
  33. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/WHEEL +0 -0
  34. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/entry_points.txt +0 -0
  35. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/top_level.txt +0 -0
supervisely/nn/training/train_app.py
@@ -9,8 +9,8 @@ import shutil
 import subprocess
 from datetime import datetime
 from os import listdir
-from os.path import basename, isdir, isfile, join
-from typing import Any, Dict, List, Optional, Union
+from os.path import basename, exists, expanduser, isdir, isfile, join
+from typing import Any, Dict, List, Literal, Optional, Union
 from urllib.request import urlopen

 import httpx
@@ -50,9 +50,10 @@ from supervisely.nn.benchmark import (
     SemanticSegmentationEvaluator,
 )
 from supervisely.nn.inference import RuntimeType, SessionJSON
+from supervisely.nn.inference.inference import Inference
 from supervisely.nn.task_type import TaskType
 from supervisely.nn.training.gui.gui import TrainGUI
-from supervisely.nn.training.loggers.tensorboard_logger import tb_logger
+from supervisely.nn.training.loggers import setup_train_logger, train_logger
 from supervisely.nn.utils import ModelSource
 from supervisely.output import set_directory
 from supervisely.project.download import (
@@ -109,6 +110,7 @@ class TrainApp:
         self._remote_checkpoints_dir_name = "checkpoints"
         self._experiments_dir_name = "experiments"
         self._default_work_dir_name = "work_dir"
+        self._export_dir_name = "export"
         self._tensorboard_port = 6006

         if is_production():
@@ -121,6 +123,7 @@ class TrainApp:
         self._team_id = sly_env.team_id()
         self._workspace_id = sly_env.workspace_id()
         self._app_name = sly_env.app_name(raise_not_found=False)
+        self._tensorboard_process = None

         # TODO: read files
         self._models = self._load_models(models)
@@ -139,7 +142,11 @@ class TrainApp:
         self.project_dir = join(self.work_dir, self._sly_project_dir_name)
         self.train_dataset_dir = join(self.project_dir, "train")
         self.val_dataset_dir = join(self.project_dir, "val")
+        self._model_cache_dir = join(expanduser("~"), ".cache", "supervisely", "checkpoints")
         self.sly_project = None
+        # -------------------------- #
+
+        # Train/Val splits
         self.train_split, self.val_split = None, None
         # -------------------------- #

@@ -166,6 +173,13 @@ class TrainApp:
         self._server = self.app.get_server()
         self._train_func = None

+        self._onnx_supported = self._app_options.get("export_onnx_supported", False)
+        self._tensorrt_supported = self._app_options.get("export_tensorrt_supported", False)
+        if self._onnx_supported:
+            self._convert_onnx_func = None
+        if self._tensorrt_supported:
+            self._convert_tensorrt_func = None
+
         # Benchmark parameters
         if self.is_model_benchmark_enabled:
             self._benchmark_params = {
@@ -263,6 +277,16 @@ class TrainApp:
         """
         return self.gui.project_info

+    @property
+    def project_meta(self) -> ProjectMeta:
+        """
+        Returns the project metadata.
+
+        :return: Project metadata.
+        :rtype: ProjectMeta
+        """
+        return self.gui.project_meta
+
     # ----------------------------------------- #

     # Model
@@ -304,7 +328,7 @@ class TrainApp:
         :return: Model metadata.
         :rtype: dict
         """
-        project_meta_json = self.sly_project.meta.to_json()
+        project_meta_json = self.project_meta.to_json()
         model_meta = {
             "classes": [
                 item for item in project_meta_json["classes"] if item["title"] in self.classes
@@ -331,7 +355,9 @@ class TrainApp:
         :return: List of selected classes.
         :rtype: List[str]
         """
-        return self.gui.classes_selector.get_selected_classes()
+        selected_classes = set(self.gui.classes_selector.get_selected_classes())
+        # remap classes with project_meta order
+        return [x for x in self.project_meta.obj_classes.keys() if x in selected_classes]

     @property
     def num_classes(self) -> int:
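The classes property now returns selections in ProjectMeta order rather than widget order, so downstream class indices stay stable. The same order-preserving filter as a standalone sketch:

def ordered_selection(meta_class_names, selected):
    # Keep ProjectMeta order; drop anything that was not selected.
    chosen = set(selected)
    return [name for name in meta_class_names if name in chosen]

ordered_selection(["car", "person", "tree"], ["tree", "car"])  # -> ['car', 'tree']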
@@ -401,7 +427,7 @@ class TrainApp:
     # Output
     # ----------------------------------------- #

-    # region TRAIN START
+    # Wrappers
     @property
     def start(self):
         """
@@ -416,6 +442,34 @@ class TrainApp:

         return decorator

+    @property
+    def export_onnx(self):
+        """
+        Decorator for the export to ONNX function defined by user.
+        It wraps user-defined export function and prepares and finalizes the training process.
+        """
+
+        def decorator(func):
+            self._convert_onnx_func = func
+            return func
+
+        return decorator
+
+    @property
+    def export_tensorrt(self):
+        """
+        Decorator for the export to TensorRT function defined by user.
+        It wraps user-defined export function and prepares and finalizes the training process.
+        """
+
+        def decorator(func):
+            self._convert_tensorrt_func = func
+            return func
+
+        return decorator
+
+    # ----------------------------------------- #
+
     def _prepare(self) -> None:
         """
         Prepares the environment for training by setting up directories,
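The new export_onnx and export_tensorrt properties mirror the existing start decorator. A minimal sketch of how a training app might register export hooks; the TrainApp constructor arguments and conversion bodies here are assumptions, and per the __init__ hunk above the hooks only take effect when export_onnx_supported / export_tensorrt_supported are set in the app options:

from supervisely.nn.training.train_app import TrainApp

# Placeholder setup; a real app passes its framework name, models list,
# hyperparameters template, and app options.
train = TrainApp(
    "my-framework",
    models=[],
    hyperparameters="",
    app_options={"export_onnx_supported": True, "export_tensorrt_supported": True},
)

@train.export_onnx
def to_onnx(experiment_info: dict) -> str:
    onnx_path = "work_dir/model.onnx"  # framework-specific conversion goes here
    return onnx_path

@train.export_tensorrt
def to_tensorrt(experiment_info: dict) -> str:
    engine_path = "work_dir/model.engine"  # framework-specific engine build goes here
    return engine_path

Each hook receives the experiment_info dict produced by training and must return a local path; _export_weights collects the paths and _upload_export_weights moves them to Team Files.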
@@ -451,7 +505,7 @@ class TrainApp:
             raise ValueError(f"{reason}. Failed to upload artifacts")

         # Step 2. Preprocess artifacts
-        self._preprocess_artifacts(experiment_info)
+        experiment_info = self._preprocess_artifacts(experiment_info)

         # Step3. Postprocess splits
         splits_data = self._postprocess_splits()
@@ -460,33 +514,46 @@ class TrainApp:
         remote_dir, file_info = self._upload_artifacts()

         # Step 4. Run Model Benchmark
-        mb_eval_report, mb_eval_report_id = None, None
-
+        mb_eval_report_file, mb_eval_report_id, eval_metrics = None, None, {}
         if self.is_model_benchmark_enabled:
             try:
-                mb_eval_report, mb_eval_report_id = self._run_model_benchmark(
+                mb_eval_report_file, mb_eval_report_id, eval_metrics = self._run_model_benchmark(
                     self.output_dir, remote_dir, experiment_info, splits_data
                 )
             except Exception as e:
                 logger.error(f"Model benchmark failed: {e}")

-        # Step 4. Generate and upload additional files
-        self._generate_experiment_info(remote_dir, experiment_info, mb_eval_report_id)
+        # Step 5. [Optional] Convert weights
+        export_weights = {}
+        if self.gui.hyperparameters_selector.is_export_required():
+            try:
+                export_weights = self._export_weights(experiment_info)
+                self._set_progress_status("finalizing")
+                export_weights = self._upload_export_weights(export_weights, remote_dir)
+            except Exception as e:
+                logger.error(f"Export weights failed: {e}")
+
+        # Step 6. Generate and upload additional files
+        self._generate_experiment_info(
+            remote_dir, experiment_info, eval_metrics, mb_eval_report_id, export_weights
+        )
         self._generate_app_state(remote_dir, experiment_info)
         self._generate_hyperparameters(remote_dir, experiment_info)
         self._generate_train_val_splits(remote_dir, splits_data)
         self._generate_model_meta(remote_dir, experiment_info)

-        # Step 5. Set output widgets
+        # Step 7. Set output widgets
         self._set_training_output(remote_dir, file_info)

-        # Step 6. Workflow output
+        # Step 8. Workflow output
         if is_production():
-            self._workflow_output(remote_dir, file_info, mb_eval_report)
+            self._workflow_output(remote_dir, file_info, mb_eval_report_file, mb_eval_report_id)

-        # region TRAIN END
+        self._set_progress_status("completed")

-    def register_inference_class(self, inference_class: Any, inference_settings: dict = {}) -> None:
+    def register_inference_class(
+        self, inference_class: Inference, inference_settings: dict = {}
+    ) -> None:
         """
         Registers an inference class for the training application to do model benchmarking.

@@ -878,6 +945,7 @@ class TrainApp:
         For Pretrained models:
             - The files that will be downloaded are specified in the `meta` key under `model_files`.
             - All files listed in the `model_files` key will be downloaded by provided link.
+            - If model files are already cached on agent, they will be copied to the model directory without downloading.
         Example of a pretrained model entry:
             [
                 {
@@ -930,25 +998,39 @@ class TrainApp:
             for file in model_files:
                 file_url = model_files[file]
                 file_path = join(self.model_dir, file)
-
+                file_name = sly_fs.get_file_name_with_ext(file_url)
                 if file_url.startswith("http"):
                     with urlopen(file_url) as f:
                         file_size = f.length
                         file_name = get_filename_from_headers(file_url)
+                        if file_name is None:
+                            file_name = file
                         file_path = join(self.model_dir, file_name)
-                        with self.progress_bar_secondary(
-                            message=f"Downloading '{file_name}' ",
-                            total=file_size,
-                            unit="bytes",
-                            unit_scale=True,
-                        ) as model_download_secondary_pbar:
-                            self.progress_bar_secondary.show()
-                            sly_fs.download(
-                                url=file_url,
-                                save_path=file_path,
-                                progress=model_download_secondary_pbar.update,
-                            )
-                        self.model_files[file] = file_path
+                        cached_path = join(self._model_cache_dir, file_name)
+                        if exists(cached_path):
+                            self.model_files[file] = cached_path
+                            logger.debug(f"Model: '{file_name}' was found in checkpoint cache")
+                            model_download_main_pbar.update(1)
+                            continue
+                        if exists(file_path):
+                            self.model_files[file] = file_path
+                            logger.debug(f"Model: '{file_name}' was found in model dir")
+                            model_download_main_pbar.update(1)
+                            continue
+
+                        with self.progress_bar_secondary(
+                            message=f"Downloading '{file_name}' ",
+                            total=file_size,
+                            unit="bytes",
+                            unit_scale=True,
+                        ) as model_download_secondary_pbar:
+                            self.progress_bar_secondary.show()
+                            sly_fs.download(
+                                url=file_url,
+                                save_path=file_path,
+                                progress=model_download_secondary_pbar.update,
+                            )
+                        self.model_files[file] = file_path
                 else:
                     self.model_files[file] = file_url
                 model_download_main_pbar.update(1)
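The download loop now checks the agent's checkpoint cache (~/.cache/supervisely/checkpoints) and the local model directory before hitting the network. The resolution order, as a standalone sketch:

from os.path import exists, expanduser, join

def resolve_model_file(file_name: str, model_dir: str) -> str:
    cache_dir = join(expanduser("~"), ".cache", "supervisely", "checkpoints")
    cached_path = join(cache_dir, file_name)
    if exists(cached_path):
        return cached_path   # reuse the agent-cached checkpoint, skip download
    local_path = join(model_dir, file_name)
    if exists(local_path):
        return local_path    # already present from an earlier run
    return ""                # empty result: caller falls back to downloading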
@@ -1070,7 +1152,8 @@ class TrainApp:
         elif isinstance(checkpoints, str):
             checkpoints = [
                 sly_fs.get_file_name_with_ext(checkpoint)
-                for checkpoint in sly_fs.list_dir_recursively(checkpoints, [".pt", ".pth"])
+                for checkpoint in listdir(checkpoints)
+                if sly_fs.get_file_ext(checkpoint) in [".pt", ".pth"]
             ]
             if best_checkpoint not in checkpoints:
                 reason = (
@@ -1177,6 +1260,7 @@ class TrainApp:
         config_name = sly_fs.get_file_name_with_ext(experiment_info["model_files"]["config"])
         output_config_path = join(self.output_dir, config_name)
         shutil.move(experiment_info["model_files"]["config"], output_config_path)
+        experiment_info["model_files"]["config"] = output_config_path
         if self.is_model_benchmark_enabled:
             self._benchmark_params["model_files"]["config"] = output_config_path

@@ -1185,8 +1269,10 @@ class TrainApp:
         # If checkpoints returned as directory
         if isinstance(checkpoints, str):
             checkpoint_paths = []
-            for checkpoint_path in sly_fs.list_files_recursively(checkpoints, [".pt", ".pth"]):
-                checkpoint_paths.append(checkpoint_path)
+            for checkpoint_path in listdir(checkpoints):
+                checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+                if checkpoint_ext in [".pt", ".pth"]:
+                    checkpoint_paths.append(join(checkpoints, checkpoint_path))
         elif isinstance(checkpoints, list):
             checkpoint_paths = checkpoints
         else:
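Both checkpoint-handling paths switch from sly_fs.list_files_recursively to a flat os.listdir, so only checkpoints in the top-level directory are picked up. An equivalent standalone form using only the standard library:

from os import listdir
from os.path import join, splitext

def list_checkpoints(checkpoints_dir: str) -> list:
    # Non-recursive: checkpoints in nested subdirectories are intentionally ignored.
    return [
        join(checkpoints_dir, name)
        for name in listdir(checkpoints_dir)
        if splitext(name)[1] in (".pt", ".pth")
    ]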
@@ -1194,6 +1280,7 @@ class TrainApp:
                 "Checkpoints should be a list of paths or a path to directory with checkpoints"
             )

+        new_checkpoint_paths = []
         best_checkpoints_name = experiment_info["best_checkpoint"]
         for checkpoint_path in checkpoint_paths:
             new_checkpoint_path = join(
@@ -1201,14 +1288,18 @@ class TrainApp:
                 sly_fs.get_file_name_with_ext(checkpoint_path),
             )
             shutil.move(checkpoint_path, new_checkpoint_path)
-            if self.is_model_benchmark_enabled:
-                if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
+            new_checkpoint_paths.append(new_checkpoint_path)
+            if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
+                experiment_info["best_checkpoint"] = new_checkpoint_path
+                if self.is_model_benchmark_enabled:
                     self._benchmark_params["model_files"]["checkpoint"] = new_checkpoint_path
+        experiment_info["checkpoints"] = new_checkpoint_paths

         # Prepare logs
         if sly_fs.dir_exists(self.log_dir):
             logs_dir = join(self.output_dir, "logs")
             shutil.copytree(self.log_dir, logs_dir)
+        return experiment_info

     # Generate experiment_info.json and app_state.json
     def _upload_file_to_team_files(self, local_path: str, remote_path: str, message: str) -> None:
@@ -1273,7 +1364,9 @@ class TrainApp:
         self,
         remote_dir: str,
         experiment_info: Dict,
+        eval_metrics: Dict = {},
         evaluation_report_id: Optional[int] = None,
+        export_weights: Dict = {},
     ) -> None:
         """
         Generates and uploads the experiment_info.json file to the output directory.
@@ -1282,8 +1375,12 @@ class TrainApp:
         :type remote_dir: str
         :param experiment_info: Information about the experiment results.
         :type experiment_info: dict
+        :param eval_metrics: Evaluation metrics.
+        :type eval_metrics: dict
         :param evaluation_report_id: Evaluation report file ID.
         :type evaluation_report_id: int
+        :param export_weights: Export data.
+        :type export_weights: dict
         """
         logger.debug("Updating experiment info")

@@ -1296,7 +1393,8 @@ class TrainApp:
             "task_id": self.task_id,
             "model_files": experiment_info["model_files"],
             "checkpoints": experiment_info["checkpoints"],
-            "best_checkpoint": experiment_info["best_checkpoint"],
+            "best_checkpoint": sly_fs.get_file_name_with_ext(experiment_info["best_checkpoint"]),
+            "export": export_weights,
             "app_state": self._app_state_file,
             "model_meta": self._model_meta_file,
             "train_val_split": self._train_val_split_file,
@@ -1304,7 +1402,7 @@ class TrainApp:
             "artifacts_dir": remote_dir,
             "datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
             "evaluation_report_id": evaluation_report_id,
-            "eval_metrics": {},
+            "evaluation_metrics": eval_metrics,
         }

         remote_checkpoints_dir = join(remote_dir, self._remote_checkpoints_dir_name)
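Taken together, these changes reshape experiment_info.json: best_checkpoint is stored as a bare file name, eval_metrics is renamed to evaluation_metrics and now actually populated from the benchmark, and a new export mapping records converted weights. An illustrative excerpt (all values made up; the runtime key strings are assumed from RuntimeType):

experiment_info = {
    "best_checkpoint": "best.pth",        # file name only, no local path
    "export": {                           # new: runtime -> path relative to artifacts_dir
        "ONNXRuntime": "export/best.onnx",
        "TensorRT": "export/best.engine",
    },
    "evaluation_report_id": 123456,
    "evaluation_metrics": {"mAP": 0.52},  # renamed from "eval_metrics"
}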
@@ -1497,7 +1595,9 @@ class TrainApp:
         set_directory(remote_dir)
         self.gui.training_process.artifacts_thumbnail.set(file_info)
         self.gui.training_process.artifacts_thumbnail.show()
-        self.gui.training_process.success_message.show()
+        self.gui.training_process.validator_text.set(
+            self.gui.training_process.success_message_text, "success"
+        )

     # Model Benchmark
     def _get_eval_results_dir_name(self) -> str:
@@ -1506,8 +1606,10 @@ class TrainApp:
         """
         task_info = self._api.task.get_info_by_id(self.task_id)
         task_dir = f"{self.task_id}_{task_info['meta']['app']['name']}"
-        eval_res_dir = f"/model-benchmark/evaluation/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
-        eval_res_dir = self._api.storage.get_free_dir_name(self._team_id(), eval_res_dir)
+        eval_res_dir = (
+            f"/model-benchmark/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
+        )
+        eval_res_dir = self._api.storage.get_free_dir_name(self._team_id, eval_res_dir)
         return eval_res_dir

     def _run_model_benchmark(
@@ -1528,16 +1630,16 @@ class TrainApp:
         :type experiment_info: dict
         :param splits_data: Information about the train and val splits.
         :type splits_data: dict
-        :return: Evaluation report and report ID.
+        :return: Evaluation report, report ID and evaluation metrics.
         :rtype: tuple
         """
-        report, report_id = None, None
+        report_file, report_id, eval_metrics = None, None, {}
         if self._inference_class is None:
             logger.warn(
                 "Inference class is not registered, model benchmark disabled. "
                 "Use 'register_inference_class' method to register inference class."
             )
-            return report, report_id
+            return report_file, report_id, eval_metrics

         # Can't get task type from session. requires before session init
         supported_task_types = [
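Since _run_model_benchmark bails out when no inference class is registered, a typical app registers one right after constructing the TrainApp (continuing the sketch above; the class body and settings are placeholders):

from supervisely.nn.inference import Inference

class MyFrameworkInference(Inference):
    # load_model / predict implementations for the framework go here
    ...

train.register_inference_class(MyFrameworkInference, {"confidence_threshold": 0.4})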
@@ -1550,7 +1652,7 @@ class TrainApp:
                 f"Task type: '{task_type}' is not supported for Model Benchmark. "
                 f"Supported tasks: {', '.join(task_type)}"
             )
-            return report, report_id
+            return report_file, report_id, eval_metrics

         logger.info("Running Model Benchmark evaluation")
         try:
@@ -1569,13 +1671,14 @@ class TrainApp:

             logger.info(f"Creating the report for the best model: {best_filename!r}")
             self.gui.training_process.validator_text.set(
-                f"Creating evaluation report for the best model: {best_filename!r}", "info"
+                f"Creating evaluation report for the best model: {best_filename!r}",
+                "info",
             )
             self.progress_bar_main(message="Starting Model Benchmark evaluation", total=1)
             self.progress_bar_main.show()

             # 0. Serve trained model
-            m = self._inference_class(
+            m: Inference = self._inference_class(
                 model_dir=self.model_dir,
                 use_gui=False,
                 custom_inference_settings=self._inference_settings,
@@ -1677,6 +1780,7 @@ class TrainApp:

             # 4. Evaluate
             bm._evaluate(gt_project_path, dt_project_path)
+            bm._dump_eval_inference_info(bm._eval_inference_info)

             # 5. Upload evaluation results
             eval_res_dir = self._get_eval_results_dir_name()
@@ -1691,8 +1795,9 @@ class TrainApp:
             # 7. Prepare visualizations, report and upload
             bm.visualize()
             remote_dir = bm.upload_visualizations(eval_res_dir + "/visualizations/")
-            report = bm.upload_report_link(remote_dir)
-            report_id = report.id
+            report_file = bm.upload_report_link(remote_dir)
+            report_id = bm.report_id
+            eval_metrics = bm.key_metrics

             # 8. UI updates
             benchmark_report_template = self._api.file.get_info_by_path(
@@ -1725,8 +1830,8 @@ class TrainApp:
                 if bm.diff_project_info:
                     self._api.project.remove(bm.diff_project_info.id)
             except Exception as e2:
-                return report, report_id
-        return report, report_id
+                return report_file, report_id, eval_metrics
+        return report_file, report_id, eval_metrics

     # ----------------------------------------- #

@@ -1772,6 +1877,7 @@ class TrainApp:
         team_files_dir: str,
         file_info: FileInfo,
         model_benchmark_report: Optional[FileInfo] = None,
+        model_benchmark_report_id: Optional[FileInfo] = None,
     ):
         """
         Adds the output data to the workflow.
@@ -1814,24 +1920,25 @@ class TrainApp:
                 f"File with checkpoints not found in Team Files. Cannot set workflow output."
             )

-            if model_benchmark_report:
-                mb_relation_settings = WorkflowSettings(
-                    title="Model Benchmark",
-                    icon="assignment",
-                    icon_color="#674EA7",
-                    icon_bg_color="#CCCCFF",
-                    url=f"/model-benchmark?id={model_benchmark_report.id}",
-                    url_title="Open Report",
-                )
+            if self.is_model_benchmark_enabled:
+                if model_benchmark_report:
+                    mb_relation_settings = WorkflowSettings(
+                        title="Model Benchmark",
+                        icon="assignment",
+                        icon_color="#674EA7",
+                        icon_bg_color="#CCCCFF",
+                        url=f"/model-benchmark?id={model_benchmark_report_id}",
+                        url_title="Open Report",
+                    )

-                meta = WorkflowMeta(
-                    relation_settings=mb_relation_settings, node_settings=node_settings
-                )
-                self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
-            else:
-                logger.debug(
-                    f"File with model benchmark report not found in Team Files. Cannot set workflow output."
-                )
+                    meta = WorkflowMeta(
+                        relation_settings=mb_relation_settings, node_settings=node_settings
+                    )
+                    self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
+                else:
+                    logger.debug(
+                        f"File with model benchmark report not found in Team Files. Cannot set workflow output."
+                    )
         except Exception as e:
             logger.debug(f"Failed to add output to the workflow: {repr(e)}")
     # ----------------------------------------- #
@@ -1841,13 +1948,19 @@ class TrainApp:
         """
         Initialize training logger. Set up Tensorboard and callbacks.
         """
-        train_logger = self._app_options.get("train_logger", "")
-        if train_logger.lower() == "tensorboard":
-            tb_logger.set_log_dir(self.log_dir)
-            self._setup_logger_callbacks()
+        selected_logger = self._app_options.get("train_logger", "")
+        if selected_logger.lower() == "tensorboard":
+            setup_train_logger("tensorboard_logger")
+            train_logger.set_log_dir(self.log_dir)
             self._init_tensorboard()
+        else:
+            setup_train_logger("default_logger")
+        self._setup_logger_callbacks()

     def _init_tensorboard(self):
+        if self._tensorboard_process is not None:
+            logger.debug("Tensorboard server is already running")
+            return
         self._register_routes()
         args = [
             "tensorboard",
@@ -1910,10 +2023,7 @@ class TrainApp:
             """
             self.progress_bar_main.hide()
             self.progress_bar_secondary.hide()
-
-            train_logger = self._app_options.get("train_logger", "")
-            if train_logger == "tensorboard":
-                tb_logger.close()
+            train_logger.close()

         def start_epoch_callback(total_steps: int):
             """
@@ -1937,13 +2047,13 @@ class TrainApp:
             """
             step_pbar.update(1)

-        tb_logger.add_on_train_started_callback(start_training_callback)
-        tb_logger.add_on_train_finish_callback(finish_training_callback)
+        train_logger.add_on_train_started_callback(start_training_callback)
+        train_logger.add_on_train_finish_callback(finish_training_callback)

-        tb_logger.add_on_epoch_started_callback(start_epoch_callback)
-        tb_logger.add_on_epoch_finish_callback(finish_epoch_callback)
+        train_logger.add_on_epoch_started_callback(start_epoch_callback)
+        train_logger.add_on_epoch_finish_callback(finish_epoch_callback)

-        tb_logger.add_on_step_callback(step_callback)
+        train_logger.add_on_step_finished_callback(step_callback)

     # ----------------------------------------- #
     def _wrapped_start_training(self):
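All callbacks now hang off the shared train_logger, and the step hook is renamed to add_on_step_finished_callback. A minimal stand-in for the registration pattern the base train logger implements (simplified; not the SDK's exact code):

class MiniTrainLogger:
    def __init__(self):
        self._on_step_finished = []

    def add_on_step_finished_callback(self, callback):
        self._on_step_finished.append(callback)

    def step_finished(self):
        # Framework integrations call this once per training step.
        for callback in self._on_step_finished:
            callback()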
@@ -1962,9 +2072,11 @@ class TrainApp:
             message = "Error occurred during training initialization. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
+            return

         try:
-            self.gui.training_process.validator_text.set("Preparing data for training...", "info")
+            self._set_progress_status("preparing")
             self._prepare()
         except Exception as e:
             message = (
@@ -1972,30 +2084,30 @@ class TrainApp:
             )
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

         try:
-            self.gui.training_process.validator_text.set("Training is in progress...", "info")
+            self._set_progress_status("training")
+            if self._app_options.get("train_logger", None) is None:
+                self._set_progress_status("training")
             experiment_info = self._train_func()
         except Exception as e:
             message = "Error occurred during training. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

         try:
-            self.gui.training_process.validator_text.set(
-                "Finalizing and uploading training artifacts...", "info"
-            )
+            self._set_progress_status("finalizing")
             self._finalize(experiment_info)
             self.gui.training_process.start_button.loading = False
-            self.gui.training_process.validator_text.set(
-                self.gui.training_process.success_message_text, "success"
-            )
         except Exception as e:
             message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

     def _show_error(self, message: str, e=None):
@@ -2017,7 +2129,7 @@ class TrainApp:

         self.gui.training_logs.card.unlock()
         self.gui.stepper.set_active_step(7)
-        self.gui.training_process.validator_text.set("Training is started...", "info")
+        self.gui.training_process.validator_text.set("Training has been started...", "info")
         self.gui.training_process.validator_text.show()
         self.gui.training_process.start_button.loading = True

@@ -2039,3 +2151,77 @@ class TrainApp:
             logger.error(f"Experiment name contains invalid characters: {invalid_chars}")
             raise ValueError(f"Experiment name contains invalid characters: {invalid_chars}")
         return True
+
+    def _set_progress_status(
+        self, status: Literal["reset", "completed", "training", "finalizing", "preparing"]
+    ):
+        message = ""
+        if status == "reset":
+            message = "Ready for training"
+        elif status == "completed":
+            message = "Training completed"
+        elif status == "training":
+            message = "Training is in progress..."
+        elif status == "finalizing":
+            message = "Finalizing and uploading training artifacts..."
+        elif status == "preparing":
+            message = "Preparing data for training..."
+
+        self.progress_bar_main.hide()
+        self.progress_bar_secondary.hide()
+        with self.progress_bar_main(message=message, total=1) as pbar:
+            pbar.update(1)
+        with self.progress_bar_secondary(message=message, total=1) as pbar:
+            pbar.update(1)
+
+    def _export_weights(self, experiment_info: dict) -> List[str]:
+        export_weights = {}
+        if (
+            self.gui.hyperparameters_selector.get_export_onnx_checkbox_value() is True
+            and self._convert_onnx_func is not None
+        ):
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.ONNXRUNTIME}", "info"
+            )
+            onnx_path = self._convert_onnx_func(experiment_info)
+            export_weights[RuntimeType.ONNXRUNTIME] = onnx_path
+
+        if (
+            self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value() is True
+            and self._convert_tensorrt_func is not None
+        ):
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.TENSORRT}", "info"
+            )
+            tensorrt_path = self._convert_tensorrt_func(experiment_info)
+            export_weights[RuntimeType.TENSORRT] = tensorrt_path
+        return export_weights
+
+    def _upload_export_weights(
+        self, export_weights: Dict[str, str], remote_dir: str
+    ) -> Dict[str, str]:
+        with self.progress_bar_main(
+            message="Uploading export weights",
+            total=len(export_weights),
+        ) as export_upload_main_pbar:
+            self.progress_bar_main.show()
+            for path in export_weights.values():
+                file_name = sly_fs.get_file_name_with_ext(path)
+                file_size = sly_fs.get_file_size(path)
+                with self.progress_bar_secondary(
+                    message=f"Uploading '{file_name}' ",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as export_upload_secondary_pbar:
+                    destination_path = join(remote_dir, self._export_dir_name, file_name)
+                    self._api.file.upload(
+                        self._team_id, path, destination_path, export_upload_secondary_pbar
+                    )
+                export_upload_main_pbar.update(1)
+
+        remote_export_weights = {
+            runtime: join(self._export_dir_name, sly_fs.get_file_name_with_ext(path))
+            for runtime, path in export_weights.items()
+        }
+        return remote_export_weights
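The returned mapping keeps paths relative to the experiment's artifacts directory, so _generate_experiment_info can store them directly under the export key. A worked path composition (directory names hypothetical):

from os.path import join

remote_dir = "/experiments/27_Lemons/265_my-framework/"  # hypothetical artifacts dir
destination = join(remote_dir, "export", "best.onnx")
# -> "/experiments/27_Lemons/265_my-framework/export/best.onnx"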