supervisely-6.73.250-py3-none-any.whl → supervisely-6.73.252-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of supervisely as potentially problematic.
- supervisely/api/dataset_api.py +17 -1
- supervisely/api/project_api.py +4 -1
- supervisely/api/volume/volume_annotation_api.py +7 -4
- supervisely/app/widgets/experiment_selector/experiment_selector.py +16 -8
- supervisely/nn/benchmark/base_benchmark.py +17 -2
- supervisely/nn/benchmark/base_evaluator.py +28 -6
- supervisely/nn/benchmark/instance_segmentation/benchmark.py +1 -1
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +14 -0
- supervisely/nn/benchmark/object_detection/benchmark.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +43 -13
- supervisely/nn/benchmark/object_detection/metric_provider.py +7 -0
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +33 -7
- supervisely/nn/benchmark/utils/detection/utlis.py +6 -4
- supervisely/nn/experiments.py +23 -16
- supervisely/nn/inference/gui/serving_gui_template.py +2 -35
- supervisely/nn/inference/inference.py +71 -8
- supervisely/nn/training/__init__.py +2 -0
- supervisely/nn/training/gui/classes_selector.py +14 -14
- supervisely/nn/training/gui/gui.py +28 -13
- supervisely/nn/training/gui/hyperparameters_selector.py +90 -41
- supervisely/nn/training/gui/input_selector.py +8 -6
- supervisely/nn/training/gui/model_selector.py +7 -5
- supervisely/nn/training/gui/train_val_splits_selector.py +8 -9
- supervisely/nn/training/gui/training_logs.py +17 -17
- supervisely/nn/training/gui/training_process.py +41 -36
- supervisely/nn/training/loggers/__init__.py +22 -0
- supervisely/nn/training/loggers/base_train_logger.py +8 -5
- supervisely/nn/training/loggers/tensorboard_logger.py +4 -11
- supervisely/nn/training/train_app.py +276 -90
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/METADATA +8 -3
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/RECORD +35 -35
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/LICENSE +0 -0
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/WHEEL +0 -0
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/top_level.txt +0 -0
supervisely/nn/training/train_app.py

@@ -9,8 +9,8 @@ import shutil
 import subprocess
 from datetime import datetime
 from os import listdir
-from os.path import basename, isdir, isfile, join
-from typing import Any, Dict, List, Optional, Union
+from os.path import basename, exists, expanduser, isdir, isfile, join
+from typing import Any, Dict, List, Literal, Optional, Union
 from urllib.request import urlopen

 import httpx
@@ -50,9 +50,10 @@ from supervisely.nn.benchmark import (
     SemanticSegmentationEvaluator,
 )
 from supervisely.nn.inference import RuntimeType, SessionJSON
+from supervisely.nn.inference.inference import Inference
 from supervisely.nn.task_type import TaskType
 from supervisely.nn.training.gui.gui import TrainGUI
-from supervisely.nn.training.loggers
+from supervisely.nn.training.loggers import setup_train_logger, train_logger
 from supervisely.nn.utils import ModelSource
 from supervisely.output import set_directory
 from supervisely.project.download import (
@@ -109,6 +110,7 @@ class TrainApp:
         self._remote_checkpoints_dir_name = "checkpoints"
         self._experiments_dir_name = "experiments"
         self._default_work_dir_name = "work_dir"
+        self._export_dir_name = "export"
         self._tensorboard_port = 6006

         if is_production():
@@ -121,6 +123,7 @@ class TrainApp:
         self._team_id = sly_env.team_id()
         self._workspace_id = sly_env.workspace_id()
         self._app_name = sly_env.app_name(raise_not_found=False)
+        self._tensorboard_process = None

         # TODO: read files
         self._models = self._load_models(models)
@@ -139,7 +142,11 @@ class TrainApp:
         self.project_dir = join(self.work_dir, self._sly_project_dir_name)
         self.train_dataset_dir = join(self.project_dir, "train")
         self.val_dataset_dir = join(self.project_dir, "val")
+        self._model_cache_dir = join(expanduser("~"), ".cache", "supervisely", "checkpoints")
         self.sly_project = None
+        # -------------------------- #
+
+        # Train/Val splits
         self.train_split, self.val_split = None, None
         # -------------------------- #

@@ -166,6 +173,13 @@ class TrainApp:
         self._server = self.app.get_server()
         self._train_func = None

+        self._onnx_supported = self._app_options.get("export_onnx_supported", False)
+        self._tensorrt_supported = self._app_options.get("export_tensorrt_supported", False)
+        if self._onnx_supported:
+            self._convert_onnx_func = None
+        if self._tensorrt_supported:
+            self._convert_tensorrt_func = None
+
         # Benchmark parameters
         if self.is_model_benchmark_enabled:
             self._benchmark_params = {
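The two `export_*_supported` flags above come from the app's options; the same dictionary is read again in a later hunk for `train_logger`. A sketch of the option keys this diff consumes, shown as a Python dict (the on-disk format an app uses may differ):

```python
# Option keys read by TrainApp in this diff; values are illustrative.
app_options = {
    "export_onnx_supported": True,       # enables the ONNX export hook (self._convert_onnx_func)
    "export_tensorrt_supported": True,   # enables the TensorRT export hook
    "train_logger": "tensorboard",       # compared case-insensitively; anything else -> default logger
}
```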
@@ -263,6 +277,16 @@ class TrainApp:
         """
         return self.gui.project_info

+    @property
+    def project_meta(self) -> ProjectMeta:
+        """
+        Returns the project metadata.
+
+        :return: Project metadata.
+        :rtype: ProjectMeta
+        """
+        return self.gui.project_meta
+
     # ----------------------------------------- #

     # Model
@@ -304,7 +328,7 @@ class TrainApp:
         :return: Model metadata.
         :rtype: dict
         """
-        project_meta_json = self.
+        project_meta_json = self.project_meta.to_json()
         model_meta = {
             "classes": [
                 item for item in project_meta_json["classes"] if item["title"] in self.classes
@@ -331,7 +355,9 @@ class TrainApp:
         :return: List of selected classes.
         :rtype: List[str]
         """
-
+        selected_classes = set(self.gui.classes_selector.get_selected_classes())
+        # remap classes with project_meta order
+        return [x for x in self.project_meta.obj_classes.keys() if x in selected_classes]

     @property
     def num_classes(self) -> int:
@@ -401,7 +427,7 @@ class TrainApp:
     # Output
     # ----------------------------------------- #

-    #
+    # Wrappers
     @property
     def start(self):
         """
@@ -416,6 +442,34 @@ class TrainApp:

         return decorator

+    @property
+    def export_onnx(self):
+        """
+        Decorator for the export to ONNX function defined by user.
+        It wraps user-defined export function and prepares and finalizes the training process.
+        """
+
+        def decorator(func):
+            self._convert_onnx_func = func
+            return func
+
+        return decorator
+
+    @property
+    def export_tensorrt(self):
+        """
+        Decorator for the export to TensorRT function defined by user.
+        It wraps user-defined export function and prepares and finalizes the training process.
+        """
+
+        def decorator(func):
+            self._convert_tensorrt_func = func
+            return func
+
+        return decorator
+
+    # ----------------------------------------- #
+
     def _prepare(self) -> None:
         """
         Prepares the environment for training by setting up directories,
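The new `export_onnx` / `export_tensorrt` properties let a framework integration register conversion functions that `_export_weights` later calls. A minimal usage sketch; the `TrainApp(...)` constructor arguments and the conversion body are illustrative assumptions, only the decorator names and the `experiment_info` argument come from this diff:

```python
from supervisely.nn.training.train_app import TrainApp

# Constructor arguments are placeholders for this sketch.
train = TrainApp("my-framework", "models.json", "hyperparameters.yaml", "app_options.yaml")

@train.export_onnx
def to_onnx(experiment_info: dict) -> str:
    # Convert the best checkpoint to ONNX here and return the local file path;
    # TrainApp uploads the returned file under the experiment's "export" directory.
    onnx_path = experiment_info["best_checkpoint"].replace(".pt", ".onnx")
    ...  # framework-specific conversion goes here
    return onnx_path
```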
@@ -451,7 +505,7 @@ class TrainApp:
             raise ValueError(f"{reason}. Failed to upload artifacts")

         # Step 2. Preprocess artifacts
-        self._preprocess_artifacts(experiment_info)
+        experiment_info = self._preprocess_artifacts(experiment_info)

         # Step3. Postprocess splits
         splits_data = self._postprocess_splits()
@@ -460,33 +514,46 @@ class TrainApp:
         remote_dir, file_info = self._upload_artifacts()

         # Step 4. Run Model Benchmark
-
-
+        mb_eval_report_file, mb_eval_report_id, eval_metrics = None, None, {}
         if self.is_model_benchmark_enabled:
             try:
-
+                mb_eval_report_file, mb_eval_report_id, eval_metrics = self._run_model_benchmark(
                     self.output_dir, remote_dir, experiment_info, splits_data
                 )
             except Exception as e:
                 logger.error(f"Model benchmark failed: {e}")

-        # Step
-
+        # Step 5. [Optional] Convert weights
+        export_weights = {}
+        if self.gui.hyperparameters_selector.is_export_required():
+            try:
+                export_weights = self._export_weights(experiment_info)
+                self._set_progress_status("finalizing")
+                export_weights = self._upload_export_weights(export_weights, remote_dir)
+            except Exception as e:
+                logger.error(f"Export weights failed: {e}")
+
+        # Step 6. Generate and upload additional files
+        self._generate_experiment_info(
+            remote_dir, experiment_info, eval_metrics, mb_eval_report_id, export_weights
+        )
         self._generate_app_state(remote_dir, experiment_info)
         self._generate_hyperparameters(remote_dir, experiment_info)
         self._generate_train_val_splits(remote_dir, splits_data)
         self._generate_model_meta(remote_dir, experiment_info)

-        # Step
+        # Step 7. Set output widgets
         self._set_training_output(remote_dir, file_info)

-        # Step
+        # Step 8. Workflow output
         if is_production():
-            self._workflow_output(remote_dir, file_info,
+            self._workflow_output(remote_dir, file_info, mb_eval_report_file, mb_eval_report_id)

-
+        self._set_progress_status("completed")

-    def register_inference_class(
+    def register_inference_class(
+        self, inference_class: Inference, inference_settings: dict = {}
+    ) -> None:
         """
         Registers an inference class for the training application to do model benchmarking.

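With the widened signature above (`inference_class: Inference, inference_settings: dict = {}`), enabling the benchmark step looks roughly like this; `MyModelInference` and the settings dict are placeholders for a framework-specific `Inference` subclass and its options:

```python
from supervisely.nn.inference import Inference

class MyModelInference(Inference):
    ...  # framework-specific model loading and prediction

# Continuing the sketch above.
train.register_inference_class(MyModelInference, {"confidence_threshold": 0.4})
```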
@@ -878,6 +945,7 @@ class TrainApp:
         For Pretrained models:
             - The files that will be downloaded are specified in the `meta` key under `model_files`.
             - All files listed in the `model_files` key will be downloaded by provided link.
+            - If model files are already cached on agent, they will be copied to the model directory without downloading.
         Example of a pretrained model entry:
             [
                 {
@@ -930,25 +998,39 @@ class TrainApp:
             for file in model_files:
                 file_url = model_files[file]
                 file_path = join(self.model_dir, file)
-
+                file_name = sly_fs.get_file_name_with_ext(file_url)
                 if file_url.startswith("http"):
                     with urlopen(file_url) as f:
                         file_size = f.length
                     file_name = get_filename_from_headers(file_url)
+                    if file_name is None:
+                        file_name = file
                     file_path = join(self.model_dir, file_name)
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    cached_path = join(self._model_cache_dir, file_name)
+                    if exists(cached_path):
+                        self.model_files[file] = cached_path
+                        logger.debug(f"Model: '{file_name}' was found in checkpoint cache")
+                        model_download_main_pbar.update(1)
+                        continue
+                    if exists(file_path):
+                        self.model_files[file] = file_path
+                        logger.debug(f"Model: '{file_name}' was found in model dir")
+                        model_download_main_pbar.update(1)
+                        continue
+
+                    with self.progress_bar_secondary(
+                        message=f"Downloading '{file_name}' ",
+                        total=file_size,
+                        unit="bytes",
+                        unit_scale=True,
+                    ) as model_download_secondary_pbar:
+                        self.progress_bar_secondary.show()
+                        sly_fs.download(
+                            url=file_url,
+                            save_path=file_path,
+                            progress=model_download_secondary_pbar.update,
+                        )
+                    self.model_files[file] = file_path
                 else:
                     self.model_files[file] = file_url
                     model_download_main_pbar.update(1)
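Taken together, the added branch gives pretrained weights a three-step resolution order: the agent-wide checkpoint cache (`~/.cache/supervisely/checkpoints`), then the app's model directory, and only then an actual download from the file URL.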
@@ -1070,7 +1152,8 @@ class TrainApp:
         elif isinstance(checkpoints, str):
             checkpoints = [
                 sly_fs.get_file_name_with_ext(checkpoint)
-                for checkpoint in
+                for checkpoint in listdir(checkpoints)
+                if sly_fs.get_file_ext(checkpoint) in [".pt", ".pth"]
             ]
             if best_checkpoint not in checkpoints:
                 reason = (
@@ -1177,6 +1260,7 @@ class TrainApp:
         config_name = sly_fs.get_file_name_with_ext(experiment_info["model_files"]["config"])
         output_config_path = join(self.output_dir, config_name)
         shutil.move(experiment_info["model_files"]["config"], output_config_path)
+        experiment_info["model_files"]["config"] = output_config_path
         if self.is_model_benchmark_enabled:
             self._benchmark_params["model_files"]["config"] = output_config_path

@@ -1185,8 +1269,10 @@ class TrainApp:
         # If checkpoints returned as directory
         if isinstance(checkpoints, str):
             checkpoint_paths = []
-            for checkpoint_path in
-
+            for checkpoint_path in listdir(checkpoints):
+                checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+                if checkpoint_ext in [".pt", ".pth"]:
+                    checkpoint_paths.append(join(checkpoints, checkpoint_path))
         elif isinstance(checkpoints, list):
             checkpoint_paths = checkpoints
         else:
@@ -1194,6 +1280,7 @@ class TrainApp:
                 "Checkpoints should be a list of paths or a path to directory with checkpoints"
             )

+        new_checkpoint_paths = []
         best_checkpoints_name = experiment_info["best_checkpoint"]
         for checkpoint_path in checkpoint_paths:
             new_checkpoint_path = join(
@@ -1201,14 +1288,18 @@ class TrainApp:
                 sly_fs.get_file_name_with_ext(checkpoint_path),
             )
             shutil.move(checkpoint_path, new_checkpoint_path)
-
-
+            new_checkpoint_paths.append(new_checkpoint_path)
+            if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
+                experiment_info["best_checkpoint"] = new_checkpoint_path
+            if self.is_model_benchmark_enabled:
                 self._benchmark_params["model_files"]["checkpoint"] = new_checkpoint_path
+        experiment_info["checkpoints"] = new_checkpoint_paths

         # Prepare logs
         if sly_fs.dir_exists(self.log_dir):
             logs_dir = join(self.output_dir, "logs")
             shutil.copytree(self.log_dir, logs_dir)
+        return experiment_info

     # Generate experiment_info.json and app_state.json
     def _upload_file_to_team_files(self, local_path: str, remote_path: str, message: str) -> None:
@@ -1273,7 +1364,9 @@ class TrainApp:
         self,
         remote_dir: str,
         experiment_info: Dict,
+        eval_metrics: Dict = {},
         evaluation_report_id: Optional[int] = None,
+        export_weights: Dict = {},
     ) -> None:
         """
         Generates and uploads the experiment_info.json file to the output directory.
@@ -1282,8 +1375,12 @@ class TrainApp:
         :type remote_dir: str
         :param experiment_info: Information about the experiment results.
         :type experiment_info: dict
+        :param eval_metrics: Evaluation metrics.
+        :type eval_metrics: dict
         :param evaluation_report_id: Evaluation report file ID.
         :type evaluation_report_id: int
+        :param export_weights: Export data.
+        :type export_weights: dict
         """
         logger.debug("Updating experiment info")

@@ -1296,7 +1393,8 @@ class TrainApp:
             "task_id": self.task_id,
             "model_files": experiment_info["model_files"],
             "checkpoints": experiment_info["checkpoints"],
-            "best_checkpoint": experiment_info["best_checkpoint"],
+            "best_checkpoint": sly_fs.get_file_name_with_ext(experiment_info["best_checkpoint"]),
+            "export": export_weights,
             "app_state": self._app_state_file,
             "model_meta": self._model_meta_file,
             "train_val_split": self._train_val_split_file,
@@ -1304,7 +1402,7 @@ class TrainApp:
             "artifacts_dir": remote_dir,
             "datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
             "evaluation_report_id": evaluation_report_id,
-            "
+            "evaluation_metrics": eval_metrics,
         }

         remote_checkpoints_dir = join(remote_dir, self._remote_checkpoints_dir_name)
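For reference, the payload assembled here now carries the two new keys (`export`, `evaluation_metrics`) alongside the best-checkpoint value, which is stored as a file name rather than a full path. A sketch with illustrative values, listing only keys visible in the hunks above:

```python
# Illustrative shape of the updated experiment_info payload (values are made up;
# the "export" key spelling follows RuntimeType and is shown as an assumption).
experiment_info = {
    "task_id": 12345,
    "model_files": {"config": "model_config.yml"},
    "checkpoints": ["checkpoints/best.pt", "checkpoints/last.pt"],
    "best_checkpoint": "best.pt",                    # now a file name, previously a path
    "export": {"ONNXRuntime": "export/best.onnx"},   # new: runtime -> remote relative path
    "app_state": "app_state.json",
    "model_meta": "model_meta.json",
    "train_val_split": "train_val_split.json",
    "artifacts_dir": "/experiments/...",
    "datetime": "2025-01-01 12:00:00",
    "evaluation_report_id": 67890,
    "evaluation_metrics": {"mAP": 0.52},             # new: key metrics from the benchmark
}
```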
@@ -1497,7 +1595,9 @@ class TrainApp:
         set_directory(remote_dir)
         self.gui.training_process.artifacts_thumbnail.set(file_info)
         self.gui.training_process.artifacts_thumbnail.show()
-        self.gui.training_process.
+        self.gui.training_process.validator_text.set(
+            self.gui.training_process.success_message_text, "success"
+        )

     # Model Benchmark
     def _get_eval_results_dir_name(self) -> str:
@@ -1506,8 +1606,10 @@ class TrainApp:
         """
         task_info = self._api.task.get_info_by_id(self.task_id)
         task_dir = f"{self.task_id}_{task_info['meta']['app']['name']}"
-        eval_res_dir =
-
+        eval_res_dir = (
+            f"/model-benchmark/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
+        )
+        eval_res_dir = self._api.storage.get_free_dir_name(self._team_id, eval_res_dir)
         return eval_res_dir

     def _run_model_benchmark(
@@ -1528,16 +1630,16 @@ class TrainApp:
         :type experiment_info: dict
         :param splits_data: Information about the train and val splits.
         :type splits_data: dict
-        :return: Evaluation report and
+        :return: Evaluation report, report ID and evaluation metrics.
         :rtype: tuple
         """
-
+        report_file, report_id, eval_metrics = None, None, {}
         if self._inference_class is None:
             logger.warn(
                 "Inference class is not registered, model benchmark disabled. "
                 "Use 'register_inference_class' method to register inference class."
             )
-            return
+            return report_file, report_id, eval_metrics

         # Can't get task type from session. requires before session init
         supported_task_types = [
@@ -1550,7 +1652,7 @@ class TrainApp:
                 f"Task type: '{task_type}' is not supported for Model Benchmark. "
                 f"Supported tasks: {', '.join(task_type)}"
             )
-            return
+            return report_file, report_id, eval_metrics

         logger.info("Running Model Benchmark evaluation")
         try:
@@ -1569,13 +1671,14 @@ class TrainApp:

             logger.info(f"Creating the report for the best model: {best_filename!r}")
             self.gui.training_process.validator_text.set(
-                f"Creating evaluation report for the best model: {best_filename!r}",
+                f"Creating evaluation report for the best model: {best_filename!r}",
+                "info",
             )
             self.progress_bar_main(message="Starting Model Benchmark evaluation", total=1)
             self.progress_bar_main.show()

             # 0. Serve trained model
-            m = self._inference_class(
+            m: Inference = self._inference_class(
                 model_dir=self.model_dir,
                 use_gui=False,
                 custom_inference_settings=self._inference_settings,
@@ -1677,6 +1780,7 @@ class TrainApp:

             # 4. Evaluate
             bm._evaluate(gt_project_path, dt_project_path)
+            bm._dump_eval_inference_info(bm._eval_inference_info)

             # 5. Upload evaluation results
             eval_res_dir = self._get_eval_results_dir_name()
@@ -1691,8 +1795,9 @@ class TrainApp:
             # 7. Prepare visualizations, report and upload
             bm.visualize()
             remote_dir = bm.upload_visualizations(eval_res_dir + "/visualizations/")
-
-            report_id =
+            report_file = bm.upload_report_link(remote_dir)
+            report_id = bm.report_id
+            eval_metrics = bm.key_metrics

             # 8. UI updates
             benchmark_report_template = self._api.file.get_info_by_path(
@@ -1725,8 +1830,8 @@ class TrainApp:
                 if bm.diff_project_info:
                     self._api.project.remove(bm.diff_project_info.id)
             except Exception as e2:
-                return
-        return
+                return report_file, report_id, eval_metrics
+        return report_file, report_id, eval_metrics

     # ----------------------------------------- #

@@ -1772,6 +1877,7 @@ class TrainApp:
         team_files_dir: str,
         file_info: FileInfo,
         model_benchmark_report: Optional[FileInfo] = None,
+        model_benchmark_report_id: Optional[FileInfo] = None,
     ):
         """
         Adds the output data to the workflow.
@@ -1814,24 +1920,25 @@ class TrainApp:
                     f"File with checkpoints not found in Team Files. Cannot set workflow output."
                 )

-            if
-
-
-
-
-
-
-
-
+            if self.is_model_benchmark_enabled:
+                if model_benchmark_report:
+                    mb_relation_settings = WorkflowSettings(
+                        title="Model Benchmark",
+                        icon="assignment",
+                        icon_color="#674EA7",
+                        icon_bg_color="#CCCCFF",
+                        url=f"/model-benchmark?id={model_benchmark_report_id}",
+                        url_title="Open Report",
+                    )

-
-
-
-
-
-
-
-
+                    meta = WorkflowMeta(
+                        relation_settings=mb_relation_settings, node_settings=node_settings
+                    )
+                    self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
+                else:
+                    logger.debug(
+                        f"File with model benchmark report not found in Team Files. Cannot set workflow output."
+                    )
         except Exception as e:
             logger.debug(f"Failed to add output to the workflow: {repr(e)}")
     # ----------------------------------------- #
@@ -1841,13 +1948,19 @@ class TrainApp:
         """
         Initialize training logger. Set up Tensorboard and callbacks.
         """
-
-        if
-
-        self.
+        selected_logger = self._app_options.get("train_logger", "")
+        if selected_logger.lower() == "tensorboard":
+            setup_train_logger("tensorboard_logger")
+            train_logger.set_log_dir(self.log_dir)
             self._init_tensorboard()
+        else:
+            setup_train_logger("default_logger")
+        self._setup_logger_callbacks()

     def _init_tensorboard(self):
+        if self._tensorboard_process is not None:
+            logger.debug("Tensorboard server is already running")
+            return
         self._register_routes()
         args = [
             "tensorboard",
@@ -1910,10 +2023,7 @@ class TrainApp:
             """
             self.progress_bar_main.hide()
             self.progress_bar_secondary.hide()
-
-            train_logger = self._app_options.get("train_logger", "")
-            if train_logger == "tensorboard":
-                tb_logger.close()
+            train_logger.close()

         def start_epoch_callback(total_steps: int):
             """
@@ -1937,13 +2047,13 @@ class TrainApp:
             """
             step_pbar.update(1)

-
-
+        train_logger.add_on_train_started_callback(start_training_callback)
+        train_logger.add_on_train_finish_callback(finish_training_callback)

-
-
+        train_logger.add_on_epoch_started_callback(start_epoch_callback)
+        train_logger.add_on_epoch_finish_callback(finish_epoch_callback)

-
+        train_logger.add_on_step_finished_callback(step_callback)

     # ----------------------------------------- #
     def _wrapped_start_training(self):
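The five `add_on_*_callback` registrations above follow a plain observer pattern: the framework-agnostic `train_logger` collects handlers and the concrete logger fires them as training progresses. A minimal standalone illustration of that mechanism (not the supervisely implementation):

```python
class MiniTrainLogger:
    """Bare-bones observer registry mirroring the callback style used above."""

    def __init__(self):
        self._on_step_finished = []

    def add_on_step_finished_callback(self, callback):
        self._on_step_finished.append(callback)

    def step_finished(self):
        # Called by the training loop after every step.
        for callback in self._on_step_finished:
            callback()

mini = MiniTrainLogger()
mini.add_on_step_finished_callback(lambda: print("step done"))
mini.step_finished()  # prints "step done"
```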
@@ -1962,9 +2072,11 @@ class TrainApp:
             message = "Error occurred during training initialization. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
+            return

         try:
-            self.
+            self._set_progress_status("preparing")
             self._prepare()
         except Exception as e:
             message = (
@@ -1972,30 +2084,30 @@ class TrainApp:
             )
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

         try:
-            self.
+            self._set_progress_status("training")
+            if self._app_options.get("train_logger", None) is None:
+                self._set_progress_status("training")
             experiment_info = self._train_func()
         except Exception as e:
             message = "Error occurred during training. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

         try:
-            self.
-            "Finalizing and uploading training artifacts...", "info"
-            )
+            self._set_progress_status("finalizing")
             self._finalize(experiment_info)
             self.gui.training_process.start_button.loading = False
-            self.gui.training_process.validator_text.set(
-                self.gui.training_process.success_message_text, "success"
-            )
         except Exception as e:
             message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
+            self._set_progress_status("reset")
             return

     def _show_error(self, message: str, e=None):
@@ -2017,7 +2129,7 @@ class TrainApp:

         self.gui.training_logs.card.unlock()
         self.gui.stepper.set_active_step(7)
-        self.gui.training_process.validator_text.set("Training
+        self.gui.training_process.validator_text.set("Training has been started...", "info")
         self.gui.training_process.validator_text.show()
         self.gui.training_process.start_button.loading = True

@@ -2039,3 +2151,77 @@ class TrainApp:
             logger.error(f"Experiment name contains invalid characters: {invalid_chars}")
             raise ValueError(f"Experiment name contains invalid characters: {invalid_chars}")
         return True
+
+    def _set_progress_status(
+        self, status: Literal["reset", "completed", "training", "finalizing", "preparing"]
+    ):
+        message = ""
+        if status == "reset":
+            message = "Ready for training"
+        elif status == "completed":
+            message = "Training completed"
+        elif status == "training":
+            message = "Training is in progress..."
+        elif status == "finalizing":
+            message = "Finalizing and uploading training artifacts..."
+        elif status == "preparing":
+            message = "Preparing data for training..."
+
+        self.progress_bar_main.hide()
+        self.progress_bar_secondary.hide()
+        with self.progress_bar_main(message=message, total=1) as pbar:
+            pbar.update(1)
+        with self.progress_bar_secondary(message=message, total=1) as pbar:
+            pbar.update(1)
+
+    def _export_weights(self, experiment_info: dict) -> List[str]:
+        export_weights = {}
+        if (
+            self.gui.hyperparameters_selector.get_export_onnx_checkbox_value() is True
+            and self._convert_onnx_func is not None
+        ):
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.ONNXRUNTIME}", "info"
+            )
+            onnx_path = self._convert_onnx_func(experiment_info)
+            export_weights[RuntimeType.ONNXRUNTIME] = onnx_path
+
+        if (
+            self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value() is True
+            and self._convert_tensorrt_func is not None
+        ):
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.TENSORRT}", "info"
+            )
+            tensorrt_path = self._convert_tensorrt_func(experiment_info)
+            export_weights[RuntimeType.TENSORRT] = tensorrt_path
+        return export_weights
+
+    def _upload_export_weights(
+        self, export_weights: Dict[str, str], remote_dir: str
+    ) -> Dict[str, str]:
+        with self.progress_bar_main(
+            message="Uploading export weights",
+            total=len(export_weights),
+        ) as export_upload_main_pbar:
+            self.progress_bar_main.show()
+            for path in export_weights.values():
+                file_name = sly_fs.get_file_name_with_ext(path)
+                file_size = sly_fs.get_file_size(path)
+                with self.progress_bar_secondary(
+                    message=f"Uploading '{file_name}' ",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as export_upload_secondary_pbar:
+                    destination_path = join(remote_dir, self._export_dir_name, file_name)
+                    self._api.file.upload(
+                        self._team_id, path, destination_path, export_upload_secondary_pbar
+                    )
+                export_upload_main_pbar.update(1)
+
+        remote_export_weights = {
+            runtime: join(self._export_dir_name, sly_fs.get_file_name_with_ext(path))
+            for runtime, path in export_weights.items()
+        }
+        return remote_export_weights
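Net effect of the new helpers: `_export_weights` runs whichever registered converters the user enabled via the hyperparameters checkboxes, `_upload_export_weights` places each produced file under the experiment's `export/` subdirectory in Team Files and returns a runtime-to-relative-path mapping, and `_generate_experiment_info` records that mapping under the `export` key of `experiment_info.json`.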