supervisely 6.73.390__py3-none-any.whl → 6.73.391__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/experiment_selector/experiment_selector.py +20 -3
- supervisely/app/widgets/experiment_selector/template.html +49 -70
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
- supervisely/decorators/profile.py +20 -0
- supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
- supervisely/nn/experiments.py +4 -0
- supervisely/nn/inference/gui/serving_gui_template.py +71 -11
- supervisely/nn/inference/inference.py +108 -6
- supervisely/nn/training/gui/classes_selector.py +246 -27
- supervisely/nn/training/gui/gui.py +318 -234
- supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
- supervisely/nn/training/gui/model_selector.py +42 -1
- supervisely/nn/training/gui/tags_selector.py +1 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
- supervisely/nn/training/gui/training_artifacts.py +10 -1
- supervisely/nn/training/gui/training_process.py +17 -1
- supervisely/nn/training/train_app.py +227 -72
- supervisely/template/__init__.py +2 -0
- supervisely/template/base_generator.py +90 -0
- supervisely/template/experiment/__init__.py +0 -0
- supervisely/template/experiment/experiment.html.jinja +537 -0
- supervisely/template/experiment/experiment_generator.py +996 -0
- supervisely/template/experiment/header.html.jinja +154 -0
- supervisely/template/experiment/sidebar.html.jinja +240 -0
- supervisely/template/experiment/sly-style.css +397 -0
- supervisely/template/experiment/template.html.jinja +18 -0
- supervisely/template/extensions.py +172 -0
- supervisely/template/template_renderer.py +253 -0
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/METADATA +3 -1
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/RECORD +34 -23
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/LICENSE +0 -0
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/WHEEL +0 -0
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/top_level.txt +0 -0
|
@@ -5,9 +5,12 @@ This module provides the `TrainGUI` class that handles the graphical user interf
|
|
|
5
5
|
training workflows in Supervisely.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
import os
|
|
9
|
+
from os import environ, getenv
|
|
9
10
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
10
11
|
|
|
12
|
+
from supervisely import logger
|
|
13
|
+
import supervisely.io.fs as sly_fs
|
|
11
14
|
import supervisely.io.env as sly_env
|
|
12
15
|
import supervisely.io.json as sly_json
|
|
13
16
|
from supervisely import Api, ProjectMeta
|
|
@@ -29,6 +32,7 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
|
|
|
29
32
|
from supervisely.nn.training.gui.training_process import TrainingProcess
|
|
30
33
|
from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
|
|
31
34
|
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
35
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
32
36
|
|
|
33
37
|
|
|
34
38
|
class StepFlow:
|
|
@@ -249,7 +253,7 @@ class TrainGUI:
|
|
|
249
253
|
self.hyperparameters = hyperparameters
|
|
250
254
|
self.app_options = app_options
|
|
251
255
|
self.collapsable = self.app_options.get("collapsable", False)
|
|
252
|
-
self.need_convert_shapes_for_bm = False
|
|
256
|
+
self.need_convert_shapes = False
|
|
253
257
|
|
|
254
258
|
self.team_id = sly_env.team_id(raise_not_found=False)
|
|
255
259
|
self.workspace_id = sly_env.workspace_id(raise_not_found=False)
|
|
@@ -289,32 +293,32 @@ class TrainGUI:
|
|
|
289
293
|
self.input_selector = InputSelector(self.project_info, self.app_options)
|
|
290
294
|
self.steps.append(self.input_selector.card)
|
|
291
295
|
|
|
292
|
-
# 2.
|
|
293
|
-
self.
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
)
|
|
298
|
-
self.steps.append(self.train_val_splits_selector.card)
|
|
296
|
+
# 2. Model selector
|
|
297
|
+
self.model_selector = ModelSelector(
|
|
298
|
+
self._api, self.framework_name, self.models, self.app_options
|
|
299
|
+
)
|
|
300
|
+
if self.show_model_selector:
|
|
301
|
+
self.steps.append(self.model_selector.card)
|
|
299
302
|
|
|
300
|
-
# 3.
|
|
303
|
+
# 3. Classes selector
|
|
301
304
|
self.classes_selector = None
|
|
302
305
|
if self.show_classes_selector:
|
|
303
|
-
self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
|
|
306
|
+
self.classes_selector = ClassesSelector(self.project_id, [], self.model_selector, self.app_options)
|
|
304
307
|
self.steps.append(self.classes_selector.card)
|
|
305
308
|
|
|
306
|
-
# 4.
|
|
309
|
+
# 4. Tags selector
|
|
307
310
|
self.tags_selector = None
|
|
308
311
|
if self.show_tags_selector:
|
|
309
312
|
self.tags_selector = TagsSelector(self.project_id, [], self.app_options)
|
|
310
313
|
self.steps.append(self.tags_selector.card)
|
|
311
314
|
|
|
312
|
-
# 5.
|
|
313
|
-
self.
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
315
|
+
# 5. Train/Val splits selector
|
|
316
|
+
self.train_val_splits_selector = None
|
|
317
|
+
if self.show_train_val_splits_selector:
|
|
318
|
+
self.train_val_splits_selector = TrainValSplitsSelector(
|
|
319
|
+
self._api, self.project_id, self.app_options
|
|
320
|
+
)
|
|
321
|
+
self.steps.append(self.train_val_splits_selector.card)
|
|
318
322
|
|
|
319
323
|
# 6. Training parameters (yaml)
|
|
320
324
|
self.hyperparameters_selector = HyperparametersSelector(
|
|
@@ -360,89 +364,36 @@ class TrainGUI:
|
|
|
360
364
|
self.training_process.set_experiment_name(experiment_name)
|
|
361
365
|
|
|
362
366
|
def need_convert_class_shapes() -> bool:
|
|
363
|
-
if self.hyperparameters_selector.run_model_benchmark_checkbox is not
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
else:
|
|
368
|
-
task_type = self.model_selector.get_selected_task_type()
|
|
369
|
-
|
|
370
|
-
def _need_convert(shape):
|
|
371
|
-
if task_type == TaskType.OBJECT_DETECTION:
|
|
372
|
-
return shape != Rectangle.geometry_name()
|
|
373
|
-
elif task_type in [
|
|
374
|
-
TaskType.INSTANCE_SEGMENTATION,
|
|
375
|
-
TaskType.SEMANTIC_SEGMENTATION,
|
|
376
|
-
]:
|
|
377
|
-
return shape == Polygon.geometry_name()
|
|
378
|
-
return False
|
|
379
|
-
|
|
380
|
-
if self.classes_selector is not None:
|
|
381
|
-
data = self.classes_selector.classes_table._table_data
|
|
382
|
-
selected_classes = set(
|
|
383
|
-
self.classes_selector.classes_table.get_selected_classes()
|
|
384
|
-
)
|
|
385
|
-
empty = set(
|
|
386
|
-
r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
|
|
387
|
-
)
|
|
388
|
-
need_convert = set(
|
|
389
|
-
r[0]["data"] for r in data if _need_convert(r[1]["data"])
|
|
390
|
-
)
|
|
391
|
-
else:
|
|
392
|
-
# Set project meta classes when classes selector is disabled
|
|
393
|
-
selected_classes = set(cls.name for cls in self.project_meta.obj_classes)
|
|
394
|
-
need_convert = set(
|
|
395
|
-
obj_class.name
|
|
396
|
-
for obj_class in self.project_meta.obj_classes
|
|
397
|
-
if _need_convert(obj_class.geometry_type)
|
|
398
|
-
)
|
|
399
|
-
empty = set()
|
|
400
|
-
|
|
401
|
-
if need_convert.intersection(selected_classes - empty):
|
|
402
|
-
self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
|
|
403
|
-
self.need_convert_shapes_for_bm = True
|
|
404
|
-
else:
|
|
405
|
-
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
406
|
-
self.need_convert_shapes_for_bm = False
|
|
407
|
-
else:
|
|
408
|
-
self.need_convert_shapes_for_bm = False
|
|
367
|
+
if self.hyperparameters_selector.run_model_benchmark_checkbox is None or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
|
|
368
|
+
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
369
|
+
self.need_convert_shapes = False
|
|
370
|
+
return False
|
|
409
371
|
|
|
410
|
-
def validate_class_shape_for_model_task():
|
|
411
372
|
task_type = self.model_selector.get_selected_task_type()
|
|
412
|
-
if self.classes_selector is not None:
|
|
413
|
-
classes = self.classes_selector.get_selected_classes()
|
|
414
|
-
else:
|
|
415
|
-
classes = list(self.project_meta.obj_classes.keys())
|
|
416
373
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
TaskType.SEMANTIC_SEGMENTATION: {Polygon, Bitmap},
|
|
420
|
-
TaskType.POSE_ESTIMATION: {GraphNodes},
|
|
421
|
-
}
|
|
422
|
-
task_specific_texts = {
|
|
423
|
-
TaskType.INSTANCE_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
|
|
424
|
-
TaskType.SEMANTIC_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
|
|
425
|
-
TaskType.POSE_ESTIMATION: "Only keypoint (graph) shape is supported for pose estimation task",
|
|
426
|
-
}
|
|
374
|
+
if self.classes_selector is not None:
|
|
375
|
+
wrong_shapes = set(self.classes_selector.get_wrong_shape_classes(task_type))
|
|
427
376
|
|
|
428
|
-
|
|
429
|
-
|
|
377
|
+
# Exclude classes with no annotations to avoid unnecessary conversion
|
|
378
|
+
data = self.classes_selector.classes_table._table_data
|
|
379
|
+
empty_classes = {r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0}
|
|
380
|
+
need_conversion = bool(wrong_shapes - empty_classes)
|
|
381
|
+
else:
|
|
382
|
+
# Classes selector disabled – check entire project meta
|
|
383
|
+
if task_type == TaskType.OBJECT_DETECTION:
|
|
384
|
+
need_conversion = any(obj_cls.geometry_type != Rectangle for obj_cls in self.project_meta.obj_classes)
|
|
385
|
+
elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
|
|
386
|
+
need_conversion = any(obj_cls.geometry_type == Polygon for obj_cls in self.project_meta.obj_classes)
|
|
387
|
+
else:
|
|
388
|
+
need_conversion = False
|
|
430
389
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
not in required_geometries[task_type]
|
|
436
|
-
]
|
|
437
|
-
|
|
438
|
-
if wrong_shape_classes:
|
|
439
|
-
specific_text = task_specific_texts[task_type]
|
|
440
|
-
message_text = f"Model task type is {task_type}. {specific_text}. Selected classes have wrong shapes for the model task: {', '.join(wrong_shape_classes)}"
|
|
441
|
-
self.model_selector.validator_text.set(
|
|
442
|
-
text=message_text,
|
|
443
|
-
status="warning",
|
|
444
|
-
)
|
|
390
|
+
if need_conversion:
|
|
391
|
+
self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
|
|
392
|
+
else:
|
|
393
|
+
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
445
394
|
|
|
395
|
+
self.need_convert_shapes = need_conversion
|
|
396
|
+
return need_conversion
|
|
446
397
|
# ------------------------------------------------- #
|
|
447
398
|
|
|
448
399
|
self.step_flow = StepFlow(self.stepper, self.app_options)
|
|
@@ -460,17 +411,17 @@ class TrainGUI:
|
|
|
460
411
|
)
|
|
461
412
|
position += 1
|
|
462
413
|
|
|
463
|
-
# 2.
|
|
464
|
-
if self.
|
|
414
|
+
# 2. Model selector
|
|
415
|
+
if self.show_model_selector:
|
|
465
416
|
self.step_flow.register_step(
|
|
466
|
-
"
|
|
467
|
-
self.
|
|
468
|
-
self.
|
|
469
|
-
self.
|
|
470
|
-
self.
|
|
471
|
-
self.
|
|
472
|
-
position=position
|
|
473
|
-
)
|
|
417
|
+
"model_selector",
|
|
418
|
+
self.model_selector.card,
|
|
419
|
+
self.model_selector.button,
|
|
420
|
+
self.model_selector.widgets_to_disable,
|
|
421
|
+
self.model_selector.validator_text,
|
|
422
|
+
self.model_selector.validate_step,
|
|
423
|
+
position=position
|
|
424
|
+
).add_on_select_actions("model_selector", [set_experiment_name])
|
|
474
425
|
position += 1
|
|
475
426
|
|
|
476
427
|
# 3. Classes selector
|
|
@@ -483,7 +434,7 @@ class TrainGUI:
|
|
|
483
434
|
self.classes_selector.validator_text,
|
|
484
435
|
self.classes_selector.validate_step,
|
|
485
436
|
position=position,
|
|
486
|
-
)
|
|
437
|
+
).add_on_select_actions("classes_selector", [need_convert_class_shapes])
|
|
487
438
|
position += 1
|
|
488
439
|
|
|
489
440
|
# 4. Tags selector
|
|
@@ -499,23 +450,16 @@ class TrainGUI:
|
|
|
499
450
|
)
|
|
500
451
|
position += 1
|
|
501
452
|
|
|
502
|
-
# 5.
|
|
503
|
-
if self.
|
|
453
|
+
# 5. Train/Val splits selector
|
|
454
|
+
if self.show_train_val_splits_selector and self.train_val_splits_selector is not None:
|
|
504
455
|
self.step_flow.register_step(
|
|
505
|
-
"
|
|
506
|
-
self.
|
|
507
|
-
self.
|
|
508
|
-
self.
|
|
509
|
-
self.
|
|
510
|
-
self.
|
|
456
|
+
"train_val_splits",
|
|
457
|
+
self.train_val_splits_selector.card,
|
|
458
|
+
self.train_val_splits_selector.button,
|
|
459
|
+
self.train_val_splits_selector.widgets_to_disable,
|
|
460
|
+
self.train_val_splits_selector.validator_text,
|
|
461
|
+
self.train_val_splits_selector.validate_step,
|
|
511
462
|
position=position,
|
|
512
|
-
).add_on_select_actions(
|
|
513
|
-
"model_selector",
|
|
514
|
-
[
|
|
515
|
-
set_experiment_name,
|
|
516
|
-
need_convert_class_shapes,
|
|
517
|
-
validate_class_shape_for_model_task,
|
|
518
|
-
],
|
|
519
463
|
)
|
|
520
464
|
position += 1
|
|
521
465
|
|
|
@@ -570,34 +514,28 @@ class TrainGUI:
|
|
|
570
514
|
)
|
|
571
515
|
|
|
572
516
|
# Set dependencies between steps
|
|
573
|
-
|
|
574
|
-
self.show_train_val_splits_selector and self.train_val_splits_selector is not None
|
|
575
|
-
)
|
|
517
|
+
has_model_selector = self.show_model_selector and self.model_selector is not None
|
|
576
518
|
has_classes_selector = self.show_classes_selector and self.classes_selector is not None
|
|
577
519
|
has_tags_selector = self.show_tags_selector and self.tags_selector is not None
|
|
520
|
+
has_train_val_splits = self.show_train_val_splits_selector and self.train_val_splits_selector is not None
|
|
578
521
|
|
|
579
522
|
# Set step dependency chain
|
|
580
|
-
# 1. Input selector
|
|
581
523
|
prev_step = "input_selector"
|
|
582
|
-
if
|
|
583
|
-
self.step_flow.set_next_steps(prev_step, ["
|
|
584
|
-
prev_step = "
|
|
524
|
+
if has_model_selector:
|
|
525
|
+
self.step_flow.set_next_steps(prev_step, ["model_selector"])
|
|
526
|
+
prev_step = "model_selector"
|
|
585
527
|
if has_classes_selector:
|
|
586
528
|
self.step_flow.set_next_steps(prev_step, ["classes_selector"])
|
|
587
529
|
prev_step = "classes_selector"
|
|
588
530
|
if has_tags_selector:
|
|
589
531
|
self.step_flow.set_next_steps(prev_step, ["tags_selector"])
|
|
590
532
|
prev_step = "tags_selector"
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
# Model selector -> hyperparameters
|
|
595
|
-
self.step_flow.set_next_steps("model_selector", ["hyperparameters_selector"])
|
|
596
|
-
prev_step = "model_selector"
|
|
597
|
-
else:
|
|
598
|
-
self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
|
|
533
|
+
if has_train_val_splits:
|
|
534
|
+
self.step_flow.set_next_steps(prev_step, ["train_val_splits"])
|
|
535
|
+
prev_step = "train_val_splits"
|
|
599
536
|
|
|
600
537
|
# 6. Hyperparameters selector -> 7. Training process
|
|
538
|
+
self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
|
|
601
539
|
self.step_flow.set_next_steps("hyperparameters_selector", ["training_process"])
|
|
602
540
|
|
|
603
541
|
# 7. Training process -> 8. Training logs
|
|
@@ -633,13 +571,17 @@ class TrainGUI:
|
|
|
633
571
|
@self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
|
|
634
572
|
def show_mb_speedtest(is_checked: bool):
|
|
635
573
|
self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
|
|
636
|
-
need_convert_class_shapes()
|
|
637
|
-
|
|
638
574
|
# ------------------------------------------------- #
|
|
639
575
|
|
|
640
576
|
self.layout: Widget = self.stepper
|
|
641
577
|
|
|
642
|
-
#
|
|
578
|
+
# Run from experiment page
|
|
579
|
+
train_task_id = getenv("modal.state.trainTaskId", None)
|
|
580
|
+
train_mode = getenv("modal.state.trainMode", None)
|
|
581
|
+
if train_task_id is not None and train_mode is not None:
|
|
582
|
+
self._run_from_experiment(train_task_id, train_mode)
|
|
583
|
+
# ----------------------------------------- #
|
|
584
|
+
|
|
643
585
|
|
|
644
586
|
def set_next_step(self):
|
|
645
587
|
current_step = self.stepper.get_active_step()
|
|
@@ -831,12 +773,16 @@ class TrainGUI:
|
|
|
831
773
|
)
|
|
832
774
|
return app_state
|
|
833
775
|
|
|
834
|
-
def load_from_app_state(self, app_state: Union[str, dict]) -> None:
|
|
776
|
+
def load_from_app_state(self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True) -> None:
|
|
835
777
|
"""
|
|
836
|
-
Load the GUI state from app state dictionary.
|
|
778
|
+
Load the GUI state from app state dictionary or path to the state file.
|
|
837
779
|
|
|
838
|
-
:param app_state: The state dictionary.
|
|
839
|
-
:type app_state: dict
|
|
780
|
+
:param app_state: The state dictionary or path to the state file.
|
|
781
|
+
:type app_state: Union[str, dict]
|
|
782
|
+
:param click_cb: Automatically click the callback functions to set the GUI state.
|
|
783
|
+
:type click_cb: bool
|
|
784
|
+
:param validate_steps: Validate the steps. If False, the steps will not be validated.
|
|
785
|
+
:type validate_steps: bool
|
|
840
786
|
|
|
841
787
|
app_state example:
|
|
842
788
|
|
|
@@ -847,21 +793,29 @@ class TrainGUI:
|
|
|
847
793
|
"percent": 90
|
|
848
794
|
},
|
|
849
795
|
"classes": ["apple"],
|
|
796
|
+
# Pretrained model
|
|
850
797
|
"model": {
|
|
851
798
|
"source": "Pretrained models",
|
|
852
799
|
"model_name": "rtdetr_r50vd_coco_objects365"
|
|
853
800
|
},
|
|
801
|
+
# Custom model
|
|
802
|
+
# "model": {
|
|
803
|
+
# "source": "Custom models",
|
|
804
|
+
# "task_id": 555,
|
|
805
|
+
# "checkpoint": "checkpoint_10.pth"
|
|
806
|
+
# },
|
|
854
807
|
"hyperparameters": hyperparameters, # yaml string
|
|
855
808
|
"options": {
|
|
809
|
+
"convert_class_shapes": True,
|
|
856
810
|
"model_benchmark": {
|
|
857
811
|
"enable": True,
|
|
858
812
|
"speed_test": True
|
|
859
813
|
},
|
|
860
814
|
"cache_project": True,
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
815
|
+
"export": {
|
|
816
|
+
"enable": True,
|
|
817
|
+
"ONNXRuntime": True,
|
|
818
|
+
"TensorRT": True
|
|
865
819
|
},
|
|
866
820
|
},
|
|
867
821
|
"experiment_name": "my_experiment",
|
|
@@ -869,27 +823,43 @@ class TrainGUI:
|
|
|
869
823
|
"""
|
|
870
824
|
if isinstance(app_state, str):
|
|
871
825
|
app_state = sly_json.load_json_file(app_state)
|
|
826
|
+
|
|
872
827
|
app_state = self.validate_app_state(app_state)
|
|
873
|
-
|
|
874
828
|
options = app_state.get("options", {})
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
tags_settings = app_state.get("tags", [])
|
|
879
|
-
model_settings = app_state["model"]
|
|
880
|
-
hyperparameters_settings = app_state["hyperparameters"]
|
|
881
|
-
experiment_name = app_state.get("experiment_name", None)
|
|
882
|
-
|
|
883
|
-
self._init_input(input_settings, options)
|
|
884
|
-
self._init_train_val_splits(train_val_splits_settings, options)
|
|
885
|
-
self._init_classes(classes_settings, options)
|
|
886
|
-
self._init_tags(tags_settings, options)
|
|
887
|
-
self._init_model(model_settings, options)
|
|
888
|
-
self._init_hyperparameters(hyperparameters_settings, options)
|
|
829
|
+
|
|
830
|
+
# Set experiment name
|
|
831
|
+
experiment_name = app_state.get("experiment_name")
|
|
889
832
|
if experiment_name is not None:
|
|
890
833
|
self.training_process.set_experiment_name(experiment_name)
|
|
891
834
|
|
|
892
|
-
|
|
835
|
+
# Run init-steps and stop on validation failure
|
|
836
|
+
def _run_step(init_fn, settings) -> bool:
|
|
837
|
+
if not init_fn(settings, options, click_cb, validate_steps):
|
|
838
|
+
return False
|
|
839
|
+
return True
|
|
840
|
+
|
|
841
|
+
# GUI init steps
|
|
842
|
+
_steps = [
|
|
843
|
+
(self._init_input, app_state.get("input"), "Input project"),
|
|
844
|
+
(self._init_model, app_state["model"], "Select Model"),
|
|
845
|
+
(self._init_classes, app_state.get("classes", []), "Classes Selector"),
|
|
846
|
+
(self._init_tags, app_state.get("tags", []), "Tags Selector"),
|
|
847
|
+
(self._init_train_val_splits, app_state.get("train_val_split", {}), "Train/Val Splits"),
|
|
848
|
+
(self._init_hyperparameters, app_state["hyperparameters"], "Hyperparameters"),
|
|
849
|
+
]
|
|
850
|
+
|
|
851
|
+
for idx, (init_fn, settings, step_name) in enumerate(_steps, start=1):
|
|
852
|
+
if not _run_step(init_fn, settings):
|
|
853
|
+
if validate_steps:
|
|
854
|
+
logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
|
|
855
|
+
return
|
|
856
|
+
if validate_steps:
|
|
857
|
+
logger.info(f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully")
|
|
858
|
+
if validate_steps:
|
|
859
|
+
logger.info(f"All steps have been validated successfully")
|
|
860
|
+
# ------------------------------------------------------------------ #
|
|
861
|
+
|
|
862
|
+
def _init_input(self, input_settings: Union[dict, None], options: dict, click_cb: bool = True, validate: bool = True) -> bool:
|
|
893
863
|
"""
|
|
894
864
|
Initialize the input selector with the given settings.
|
|
895
865
|
|
|
@@ -897,13 +867,123 @@ class TrainGUI:
|
|
|
897
867
|
:type input_settings: dict
|
|
898
868
|
:param options: The application options.
|
|
899
869
|
:type options: dict
|
|
870
|
+
:param click_cb: Click the callback function.
|
|
871
|
+
:type click_cb: bool
|
|
872
|
+
:param validate: Validate the step.
|
|
873
|
+
:type validate: bool
|
|
900
874
|
"""
|
|
901
875
|
# Set Input
|
|
902
876
|
self.input_selector.set_cache(options.get("cache_project", True))
|
|
903
|
-
|
|
877
|
+
is_valid = True
|
|
878
|
+
if validate:
|
|
879
|
+
is_valid = self.input_selector.validate_step()
|
|
880
|
+
if is_valid and click_cb:
|
|
881
|
+
self.input_selector_cb()
|
|
882
|
+
self.set_next_step()
|
|
883
|
+
return is_valid
|
|
904
884
|
# ----------------------------------------- #
|
|
905
885
|
|
|
906
|
-
def
|
|
886
|
+
def _init_model(self, model_settings: dict, options: dict = None, click_cb: bool = True, validate: bool = True) -> bool:
|
|
887
|
+
"""
|
|
888
|
+
Initialize the model selector with the given settings.
|
|
889
|
+
|
|
890
|
+
:param model_settings: The model settings.
|
|
891
|
+
:type model_settings: dict
|
|
892
|
+
:param options: The application options.
|
|
893
|
+
:type options: dict
|
|
894
|
+
:param click_cb: Click the callback function.
|
|
895
|
+
:type click_cb: bool
|
|
896
|
+
:param validate: Validate the step.
|
|
897
|
+
:type validate: bool
|
|
898
|
+
"""
|
|
899
|
+
|
|
900
|
+
# Pretrained
|
|
901
|
+
if model_settings["source"] == ModelSource.PRETRAINED:
|
|
902
|
+
self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
|
|
903
|
+
self.model_selector.pretrained_models_table.set_by_model_name(
|
|
904
|
+
model_settings["model_name"]
|
|
905
|
+
)
|
|
906
|
+
|
|
907
|
+
# Custom
|
|
908
|
+
elif model_settings["source"] == ModelSource.CUSTOM:
|
|
909
|
+
self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
|
|
910
|
+
self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
|
|
911
|
+
active_row = self.model_selector.experiment_selector.get_selected_row()
|
|
912
|
+
if model_settings["checkpoint"] not in active_row.checkpoints_names:
|
|
913
|
+
raise ValueError(
|
|
914
|
+
f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
|
|
915
|
+
)
|
|
916
|
+
|
|
917
|
+
active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
|
|
918
|
+
|
|
919
|
+
is_valid = True
|
|
920
|
+
if validate:
|
|
921
|
+
is_valid = self.model_selector.validate_step()
|
|
922
|
+
if is_valid and click_cb:
|
|
923
|
+
self.model_selector_cb()
|
|
924
|
+
self.set_next_step()
|
|
925
|
+
return is_valid
|
|
926
|
+
# ----------------------------------------- #
|
|
927
|
+
|
|
928
|
+
def _init_classes(self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
|
|
929
|
+
"""
|
|
930
|
+
Initialize the classes selector with the given settings.
|
|
931
|
+
|
|
932
|
+
:param classes_settings: The classes settings.
|
|
933
|
+
:type classes_settings: list
|
|
934
|
+
:param options: The application options.
|
|
935
|
+
:type options: dict
|
|
936
|
+
:param click_cb: Click the callback function.
|
|
937
|
+
:type click_cb: bool
|
|
938
|
+
:param validate: Validate the step.
|
|
939
|
+
:type validate: bool
|
|
940
|
+
"""
|
|
941
|
+
if self.classes_selector is None:
|
|
942
|
+
return True # Selector disabled by app options
|
|
943
|
+
|
|
944
|
+
convert_class_shapes = options.get("convert_class_shapes", True)
|
|
945
|
+
if convert_class_shapes:
|
|
946
|
+
self.classes_selector.convert_class_shapes_checkbox.check()
|
|
947
|
+
|
|
948
|
+
# Set Classes
|
|
949
|
+
self.classes_selector.set_classes(classes_settings)
|
|
950
|
+
is_valid = True
|
|
951
|
+
if validate:
|
|
952
|
+
is_valid = self.classes_selector.validate_step()
|
|
953
|
+
if is_valid and click_cb:
|
|
954
|
+
self.classes_selector_cb()
|
|
955
|
+
self.set_next_step()
|
|
956
|
+
return is_valid
|
|
957
|
+
# ----------------------------------------- #
|
|
958
|
+
|
|
959
|
+
def _init_tags(self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
|
|
960
|
+
"""
|
|
961
|
+
Initialize the tags selector with the given settings.
|
|
962
|
+
|
|
963
|
+
:param tags_settings: The tags settings.
|
|
964
|
+
:type tags_settings: list
|
|
965
|
+
:param options: The application options.
|
|
966
|
+
:type options: dict
|
|
967
|
+
:param click_cb: Click the callback function.
|
|
968
|
+
:type click_cb: bool
|
|
969
|
+
:param validate: Validate the step.
|
|
970
|
+
:type validate: bool
|
|
971
|
+
"""
|
|
972
|
+
if self.tags_selector is None:
|
|
973
|
+
return True # Selector disabled by app options
|
|
974
|
+
|
|
975
|
+
# Set Tags
|
|
976
|
+
self.tags_selector.set_tags(tags_settings)
|
|
977
|
+
is_valid = True
|
|
978
|
+
if validate:
|
|
979
|
+
is_valid = self.tags_selector.validate_step()
|
|
980
|
+
if is_valid and click_cb:
|
|
981
|
+
self.tags_selector_cb()
|
|
982
|
+
self.set_next_step()
|
|
983
|
+
return is_valid
|
|
984
|
+
# ----------------------------------------- #
|
|
985
|
+
|
|
986
|
+
def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
|
|
907
987
|
"""
|
|
908
988
|
Initialize the train/val splits selector with the given settings.
|
|
909
989
|
|
|
@@ -911,9 +991,13 @@ class TrainGUI:
|
|
|
911
991
|
:type train_val_splits_settings: dict
|
|
912
992
|
:param options: The application options.
|
|
913
993
|
:type options: dict
|
|
994
|
+
:param click_cb: Click the callback function.
|
|
995
|
+
:type click_cb: bool
|
|
996
|
+
:param validate: Validate the step.
|
|
997
|
+
:type validate: bool
|
|
914
998
|
"""
|
|
915
999
|
if self.train_val_splits_selector is None:
|
|
916
|
-
return
|
|
1000
|
+
return True # Selector disabled by app options
|
|
917
1001
|
|
|
918
1002
|
if train_val_splits_settings == {}:
|
|
919
1003
|
available_methods = self.app_options.get("train_val_splits_methods", [])
|
|
@@ -965,74 +1049,16 @@ class TrainGUI:
|
|
|
965
1049
|
self.train_val_splits_selector.train_val_splits.set_collections_splits(
|
|
966
1050
|
train_collections, val_collections
|
|
967
1051
|
)
|
|
968
|
-
self.train_val_splits_selector_cb()
|
|
969
1052
|
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
if self.classes_selector is None:
|
|
980
|
-
return # Selector disabled by app options
|
|
981
|
-
|
|
982
|
-
# Set Classes
|
|
983
|
-
self.classes_selector.set_classes(classes_settings)
|
|
984
|
-
self.classes_selector_cb()
|
|
985
|
-
# ----------------------------------------- #
|
|
986
|
-
|
|
987
|
-
def _init_tags(self, tags_settings: list, options: dict) -> None:
|
|
988
|
-
"""
|
|
989
|
-
Initialize the tags selector with the given settings.
|
|
990
|
-
|
|
991
|
-
:param tags_settings: The tags settings.
|
|
992
|
-
:type tags_settings: list
|
|
993
|
-
:param options: The application options.
|
|
994
|
-
:type options: dict
|
|
995
|
-
"""
|
|
996
|
-
if self.tags_selector is None:
|
|
997
|
-
return # Selector disabled by app options
|
|
998
|
-
|
|
999
|
-
# Set Tags
|
|
1000
|
-
self.tags_selector.set_tags(tags_settings)
|
|
1001
|
-
self.tags_selector_cb()
|
|
1002
|
-
# ----------------------------------------- #
|
|
1003
|
-
|
|
1004
|
-
def _init_model(self, model_settings: dict, options: dict) -> None:
|
|
1005
|
-
"""
|
|
1006
|
-
Initialize the model selector with the given settings.
|
|
1007
|
-
|
|
1008
|
-
:param model_settings: The model settings.
|
|
1009
|
-
:type model_settings: dict
|
|
1010
|
-
:param options: The application options.
|
|
1011
|
-
:type options: dict
|
|
1012
|
-
"""
|
|
1013
|
-
|
|
1014
|
-
# Pretrained
|
|
1015
|
-
if model_settings["source"] == ModelSource.PRETRAINED:
|
|
1016
|
-
self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
|
|
1017
|
-
self.model_selector.pretrained_models_table.set_by_model_name(
|
|
1018
|
-
model_settings["model_name"]
|
|
1019
|
-
)
|
|
1020
|
-
|
|
1021
|
-
# Custom
|
|
1022
|
-
elif model_settings["source"] == ModelSource.CUSTOM:
|
|
1023
|
-
self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
|
|
1024
|
-
self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
|
|
1025
|
-
active_row = self.model_selector.experiment_selector.get_selected_row()
|
|
1026
|
-
if model_settings["checkpoint"] not in active_row.checkpoints_names:
|
|
1027
|
-
raise ValueError(
|
|
1028
|
-
f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
|
|
1029
|
-
)
|
|
1030
|
-
|
|
1031
|
-
active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
|
|
1032
|
-
self.model_selector_cb()
|
|
1033
|
-
# ----------------------------------------- #
|
|
1034
|
-
|
|
1035
|
-
def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict) -> None:
|
|
1053
|
+
is_valid = True
|
|
1054
|
+
if validate:
|
|
1055
|
+
is_valid = self.train_val_splits_selector.validate_step()
|
|
1056
|
+
if is_valid and click_cb:
|
|
1057
|
+
self.train_val_splits_selector_cb()
|
|
1058
|
+
self.set_next_step()
|
|
1059
|
+
return is_valid
|
|
1060
|
+
|
|
1061
|
+
def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
|
|
1036
1062
|
"""
|
|
1037
1063
|
Initialize the hyperparameters selector with the given settings.
|
|
1038
1064
|
|
|
@@ -1040,6 +1066,10 @@ class TrainGUI:
|
|
|
1040
1066
|
:type hyperparameters_settings: dict
|
|
1041
1067
|
:param options: The application options.
|
|
1042
1068
|
:type options: dict
|
|
1069
|
+
:param click_cb: Click the callback function.
|
|
1070
|
+
:type click_cb: bool
|
|
1071
|
+
:param validate: Validate the step.
|
|
1072
|
+
:type validate: bool
|
|
1043
1073
|
"""
|
|
1044
1074
|
self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)
|
|
1045
1075
|
|
|
@@ -1061,6 +1091,60 @@ class TrainGUI:
|
|
|
1061
1091
|
self.hyperparameters_selector.set_export_tensorrt_checkbox_value(
|
|
1062
1092
|
export_weights_settings.get(RuntimeType.TENSORRT, False)
|
|
1063
1093
|
)
|
|
1064
|
-
self.hyperparameters_selector_cb()
|
|
1065
1094
|
|
|
1095
|
+
is_valid = True
|
|
1096
|
+
if validate:
|
|
1097
|
+
is_valid = self.hyperparameters_selector.validate_step()
|
|
1098
|
+
if is_valid and click_cb:
|
|
1099
|
+
self.hyperparameters_selector_cb()
|
|
1100
|
+
self.set_next_step()
|
|
1101
|
+
return is_valid
|
|
1102
|
+
# ----------------------------------------- #
|
|
1103
|
+
|
|
1104
|
+
# Run from experiment page
|
|
1105
|
+
def _download_experiment_state(self, experiment_info: ExperimentInfo) -> dict:
|
|
1106
|
+
local_app_state_path = f"./app_state.json"
|
|
1107
|
+
remote_app_state_path = os.path.join(experiment_info.artifacts_dir, "app_state.json")
|
|
1108
|
+
self._api.file.download(self.team_id, remote_app_state_path, local_app_state_path)
|
|
1109
|
+
app_state = sly_json.load_json_file(local_app_state_path)
|
|
1110
|
+
sly_fs.silent_remove(local_app_state_path)
|
|
1111
|
+
return app_state
|
|
1112
|
+
|
|
1113
|
+
def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
|
|
1114
|
+
local_hparams_path = f"./{experiment_info.hyperparameters}"
|
|
1115
|
+
remote_hparams_path = os.path.join(experiment_info.artifacts_dir, experiment_info.hyperparameters)
|
|
1116
|
+
self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
|
|
1117
|
+
with open(local_hparams_path, "r") as f:
|
|
1118
|
+
hparams = f.read()
|
|
1119
|
+
sly_fs.silent_remove(local_hparams_path)
|
|
1120
|
+
return hparams
|
|
1121
|
+
|
|
1122
|
+
def _run_from_experiment(self, train_task_id: int, train_mode: str):
|
|
1123
|
+
experiment_info = self._api.nn.get_experiment_info(train_task_id)
|
|
1124
|
+
experiment_state = experiment_info.app_state
|
|
1125
|
+
|
|
1126
|
+
if train_mode == "continue":
|
|
1127
|
+
model_settings = {
|
|
1128
|
+
"source": ModelSource.CUSTOM,
|
|
1129
|
+
"task_id": train_task_id,
|
|
1130
|
+
"checkpoint": experiment_info.best_checkpoint
|
|
1131
|
+
}
|
|
1132
|
+
|
|
1133
|
+
if experiment_state is not None:
|
|
1134
|
+
self.input_selector.validator_text.set(f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.", "success")
|
|
1135
|
+
self.input_selector.validator_text.show()
|
|
1136
|
+
experiment_state = self._download_experiment_state(experiment_info)
|
|
1137
|
+
if train_mode == "continue":
|
|
1138
|
+
experiment_state["model"] = model_settings
|
|
1139
|
+
self.load_from_app_state(experiment_state, click_cb=False, validate_steps=False)
|
|
1140
|
+
else:
|
|
1141
|
+
self.input_selector.validator_text.set(
|
|
1142
|
+
f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
|
|
1143
|
+
"warning"
|
|
1144
|
+
)
|
|
1145
|
+
self.input_selector.validator_text.show()
|
|
1146
|
+
hparams = self._download_experiment_hparams(experiment_info)
|
|
1147
|
+
self.hyperparameters_selector.set_hyperparameters(hparams)
|
|
1148
|
+
if train_mode == "continue":
|
|
1149
|
+
self._init_model(model_settings, {}, click_cb=False, validate=False)
|
|
1066
1150
|
# ----------------------------------------- #
|