supervisely 6.73.420__py3-none-any.whl → 6.73.421__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/RECORD +74 -56
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
|
@@ -9,17 +9,17 @@ import os
|
|
|
9
9
|
from os import environ, getenv
|
|
10
10
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
11
11
|
|
|
12
|
-
from supervisely import logger
|
|
13
|
-
import supervisely.io.fs as sly_fs
|
|
14
12
|
import supervisely.io.env as sly_env
|
|
13
|
+
import supervisely.io.fs as sly_fs
|
|
15
14
|
import supervisely.io.json as sly_json
|
|
16
|
-
from supervisely import Api, ProjectMeta
|
|
15
|
+
from supervisely import Api, ProjectMeta, logger
|
|
17
16
|
from supervisely._utils import is_production
|
|
18
17
|
from supervisely.app.widgets import Button, Card, Stepper, Widget
|
|
19
18
|
from supervisely.geometry.bitmap import Bitmap
|
|
20
19
|
from supervisely.geometry.graph import GraphNodes
|
|
21
20
|
from supervisely.geometry.polygon import Polygon
|
|
22
21
|
from supervisely.geometry.rectangle import Rectangle
|
|
22
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
23
23
|
from supervisely.nn.task_type import TaskType
|
|
24
24
|
from supervisely.nn.training.gui.classes_selector import ClassesSelector
|
|
25
25
|
from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
|
|
@@ -32,7 +32,6 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
|
|
|
32
32
|
from supervisely.nn.training.gui.training_process import TrainingProcess
|
|
33
33
|
from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
|
|
34
34
|
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
35
|
-
from supervisely.nn.experiments import ExperimentInfo
|
|
36
35
|
|
|
37
36
|
|
|
38
37
|
class StepFlow:
|
|
@@ -303,7 +302,9 @@ class TrainGUI:
|
|
|
303
302
|
# 3. Classes selector
|
|
304
303
|
self.classes_selector = None
|
|
305
304
|
if self.show_classes_selector:
|
|
306
|
-
self.classes_selector = ClassesSelector(
|
|
305
|
+
self.classes_selector = ClassesSelector(
|
|
306
|
+
self.project_id, [], self.model_selector, self.app_options
|
|
307
|
+
)
|
|
307
308
|
self.steps.append(self.classes_selector.card)
|
|
308
309
|
|
|
309
310
|
# 4. Tags selector
|
|
@@ -355,16 +356,19 @@ class TrainGUI:
|
|
|
355
356
|
experiment_name = "Enter experiment name"
|
|
356
357
|
else:
|
|
357
358
|
if self.task_id == -1:
|
|
358
|
-
experiment_name = f"
|
|
359
|
+
experiment_name = f"debug {self.project_info.name} {model_name}"
|
|
359
360
|
else:
|
|
360
|
-
experiment_name = f"{self.task_id}
|
|
361
|
+
experiment_name = f"{self.task_id} {self.project_info.name} {model_name}"
|
|
361
362
|
|
|
362
363
|
if experiment_name == self.training_process.get_experiment_name():
|
|
363
364
|
return
|
|
364
365
|
self.training_process.set_experiment_name(experiment_name)
|
|
365
366
|
|
|
366
367
|
def need_convert_class_shapes() -> bool:
|
|
367
|
-
if
|
|
368
|
+
if (
|
|
369
|
+
self.hyperparameters_selector.run_model_benchmark_checkbox is None
|
|
370
|
+
or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked()
|
|
371
|
+
):
|
|
368
372
|
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
369
373
|
self.need_convert_shapes = False
|
|
370
374
|
return False
|
|
@@ -376,14 +380,22 @@ class TrainGUI:
|
|
|
376
380
|
|
|
377
381
|
# Exclude classes with no annotations to avoid unnecessary conversion
|
|
378
382
|
data = self.classes_selector.classes_table._table_data
|
|
379
|
-
empty_classes = {
|
|
383
|
+
empty_classes = {
|
|
384
|
+
r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
|
|
385
|
+
}
|
|
380
386
|
need_conversion = bool(wrong_shapes - empty_classes)
|
|
381
387
|
else:
|
|
382
388
|
# Classes selector disabled – check entire project meta
|
|
383
389
|
if task_type == TaskType.OBJECT_DETECTION:
|
|
384
|
-
need_conversion = any(
|
|
390
|
+
need_conversion = any(
|
|
391
|
+
obj_cls.geometry_type != Rectangle
|
|
392
|
+
for obj_cls in self.project_meta.obj_classes
|
|
393
|
+
)
|
|
385
394
|
elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
|
|
386
|
-
need_conversion = any(
|
|
395
|
+
need_conversion = any(
|
|
396
|
+
obj_cls.geometry_type == Polygon
|
|
397
|
+
for obj_cls in self.project_meta.obj_classes
|
|
398
|
+
)
|
|
387
399
|
else:
|
|
388
400
|
need_conversion = False
|
|
389
401
|
|
|
@@ -394,6 +406,7 @@ class TrainGUI:
|
|
|
394
406
|
|
|
395
407
|
self.need_convert_shapes = need_conversion
|
|
396
408
|
return need_conversion
|
|
409
|
+
|
|
397
410
|
# ------------------------------------------------- #
|
|
398
411
|
|
|
399
412
|
self.step_flow = StepFlow(self.stepper, self.app_options)
|
|
@@ -420,7 +433,7 @@ class TrainGUI:
|
|
|
420
433
|
self.model_selector.widgets_to_disable,
|
|
421
434
|
self.model_selector.validator_text,
|
|
422
435
|
self.model_selector.validate_step,
|
|
423
|
-
position=position
|
|
436
|
+
position=position,
|
|
424
437
|
).add_on_select_actions("model_selector", [set_experiment_name])
|
|
425
438
|
position += 1
|
|
426
439
|
|
|
@@ -517,7 +530,9 @@ class TrainGUI:
|
|
|
517
530
|
has_model_selector = self.show_model_selector and self.model_selector is not None
|
|
518
531
|
has_classes_selector = self.show_classes_selector and self.classes_selector is not None
|
|
519
532
|
has_tags_selector = self.show_tags_selector and self.tags_selector is not None
|
|
520
|
-
has_train_val_splits =
|
|
533
|
+
has_train_val_splits = (
|
|
534
|
+
self.show_train_val_splits_selector and self.train_val_splits_selector is not None
|
|
535
|
+
)
|
|
521
536
|
|
|
522
537
|
# Set step dependency chain
|
|
523
538
|
prev_step = "input_selector"
|
|
@@ -571,11 +586,13 @@ class TrainGUI:
|
|
|
571
586
|
@self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
|
|
572
587
|
def show_mb_speedtest(is_checked: bool):
|
|
573
588
|
self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
|
|
589
|
+
|
|
574
590
|
# ------------------------------------------------- #
|
|
575
591
|
|
|
576
592
|
self.layout: Widget = self.stepper
|
|
577
593
|
|
|
578
594
|
# Run from experiment page
|
|
595
|
+
|
|
579
596
|
train_task_id = getenv("modal.state.trainTaskId", None)
|
|
580
597
|
if train_task_id is not None:
|
|
581
598
|
train_task_id = int(train_task_id)
|
|
@@ -584,7 +601,6 @@ class TrainGUI:
|
|
|
584
601
|
self._run_from_experiment(train_task_id, train_mode)
|
|
585
602
|
# ----------------------------------------- #
|
|
586
603
|
|
|
587
|
-
|
|
588
604
|
def set_next_step(self):
|
|
589
605
|
current_step = self.stepper.get_active_step()
|
|
590
606
|
self.stepper.set_active_step(current_step + 1)
|
|
@@ -605,6 +621,8 @@ class TrainGUI:
|
|
|
605
621
|
"""
|
|
606
622
|
if self.input_selector is not None:
|
|
607
623
|
self.input_selector.button.enable()
|
|
624
|
+
if self.model_selector is not None:
|
|
625
|
+
self.model_selector.button.enable()
|
|
608
626
|
if self.train_val_splits_selector is not None:
|
|
609
627
|
self.train_val_splits_selector.button.enable()
|
|
610
628
|
if self.classes_selector is not None:
|
|
@@ -622,6 +640,8 @@ class TrainGUI:
|
|
|
622
640
|
"""
|
|
623
641
|
if self.input_selector is not None:
|
|
624
642
|
self.input_selector.button.disable()
|
|
643
|
+
if self.model_selector is not None:
|
|
644
|
+
self.model_selector.button.disable()
|
|
625
645
|
if self.train_val_splits_selector is not None:
|
|
626
646
|
self.train_val_splits_selector.button.disable()
|
|
627
647
|
if self.classes_selector is not None:
|
|
@@ -775,7 +795,9 @@ class TrainGUI:
|
|
|
775
795
|
)
|
|
776
796
|
return app_state
|
|
777
797
|
|
|
778
|
-
def load_from_app_state(
|
|
798
|
+
def load_from_app_state(
|
|
799
|
+
self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True
|
|
800
|
+
) -> None:
|
|
779
801
|
"""
|
|
780
802
|
Load the GUI state from app state dictionary or path to the state file.
|
|
781
803
|
|
|
@@ -820,15 +842,15 @@ class TrainGUI:
|
|
|
820
842
|
"TensorRT": True
|
|
821
843
|
},
|
|
822
844
|
},
|
|
823
|
-
"experiment_name": "
|
|
845
|
+
"experiment_name": "My Experiment",
|
|
824
846
|
}
|
|
825
847
|
"""
|
|
826
848
|
if isinstance(app_state, str):
|
|
827
849
|
app_state = sly_json.load_json_file(app_state)
|
|
828
|
-
|
|
850
|
+
|
|
829
851
|
app_state = self.validate_app_state(app_state)
|
|
830
852
|
options = app_state.get("options", {})
|
|
831
|
-
|
|
853
|
+
|
|
832
854
|
# Set experiment name
|
|
833
855
|
experiment_name = app_state.get("experiment_name")
|
|
834
856
|
if experiment_name is not None:
|
|
@@ -839,7 +861,7 @@ class TrainGUI:
|
|
|
839
861
|
if not init_fn(settings, options, click_cb, validate_steps):
|
|
840
862
|
return False
|
|
841
863
|
return True
|
|
842
|
-
|
|
864
|
+
|
|
843
865
|
# GUI init steps
|
|
844
866
|
_steps = [
|
|
845
867
|
(self._init_input, app_state.get("input"), "Input project"),
|
|
@@ -856,12 +878,20 @@ class TrainGUI:
|
|
|
856
878
|
logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
|
|
857
879
|
return
|
|
858
880
|
if validate_steps:
|
|
859
|
-
logger.info(
|
|
881
|
+
logger.info(
|
|
882
|
+
f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully"
|
|
883
|
+
)
|
|
860
884
|
if validate_steps:
|
|
861
885
|
logger.info(f"All steps have been validated successfully")
|
|
862
886
|
# ------------------------------------------------------------------ #
|
|
863
887
|
|
|
864
|
-
def _init_input(
|
|
888
|
+
def _init_input(
|
|
889
|
+
self,
|
|
890
|
+
input_settings: Union[dict, None],
|
|
891
|
+
options: dict,
|
|
892
|
+
click_cb: bool = True,
|
|
893
|
+
validate: bool = True,
|
|
894
|
+
) -> bool:
|
|
865
895
|
"""
|
|
866
896
|
Initialize the input selector with the given settings.
|
|
867
897
|
|
|
@@ -885,7 +915,13 @@ class TrainGUI:
|
|
|
885
915
|
return is_valid
|
|
886
916
|
# ----------------------------------------- #
|
|
887
917
|
|
|
888
|
-
def _init_model(
|
|
918
|
+
def _init_model(
|
|
919
|
+
self,
|
|
920
|
+
model_settings: dict,
|
|
921
|
+
options: dict = None,
|
|
922
|
+
click_cb: bool = True,
|
|
923
|
+
validate: bool = True,
|
|
924
|
+
) -> bool:
|
|
889
925
|
"""
|
|
890
926
|
Initialize the model selector with the given settings.
|
|
891
927
|
|
|
@@ -909,14 +945,18 @@ class TrainGUI:
|
|
|
909
945
|
# Custom
|
|
910
946
|
elif model_settings["source"] == ModelSource.CUSTOM:
|
|
911
947
|
self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
|
|
912
|
-
self.model_selector.experiment_selector.
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
948
|
+
self.model_selector.experiment_selector.set_selected_row_by_task_id(
|
|
949
|
+
model_settings["task_id"]
|
|
950
|
+
)
|
|
951
|
+
experiment_info = self.model_selector.experiment_selector.get_selected_experiment_info()
|
|
952
|
+
if model_settings["checkpoint"] not in experiment_info.checkpoints:
|
|
953
|
+
if f"checkpoints/{model_settings['checkpoint']}" not in experiment_info.checkpoints:
|
|
954
|
+
raise ValueError(
|
|
955
|
+
f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
|
|
956
|
+
)
|
|
957
|
+
self.model_selector.experiment_selector.set_selected_checkpoint_by_name(
|
|
958
|
+
model_settings["checkpoint"]
|
|
959
|
+
)
|
|
920
960
|
|
|
921
961
|
is_valid = True
|
|
922
962
|
if validate:
|
|
@@ -926,8 +966,10 @@ class TrainGUI:
|
|
|
926
966
|
self.set_next_step()
|
|
927
967
|
return is_valid
|
|
928
968
|
# ----------------------------------------- #
|
|
929
|
-
|
|
930
|
-
def _init_classes(
|
|
969
|
+
|
|
970
|
+
def _init_classes(
|
|
971
|
+
self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True
|
|
972
|
+
) -> bool:
|
|
931
973
|
"""
|
|
932
974
|
Initialize the classes selector with the given settings.
|
|
933
975
|
|
|
@@ -941,7 +983,7 @@ class TrainGUI:
|
|
|
941
983
|
:type validate: bool
|
|
942
984
|
"""
|
|
943
985
|
if self.classes_selector is None:
|
|
944
|
-
return True
|
|
986
|
+
return True # Selector disabled by app options
|
|
945
987
|
|
|
946
988
|
convert_class_shapes = options.get("convert_class_shapes", True)
|
|
947
989
|
if convert_class_shapes:
|
|
@@ -958,7 +1000,9 @@ class TrainGUI:
|
|
|
958
1000
|
return is_valid
|
|
959
1001
|
# ----------------------------------------- #
|
|
960
1002
|
|
|
961
|
-
def _init_tags(
|
|
1003
|
+
def _init_tags(
|
|
1004
|
+
self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True
|
|
1005
|
+
) -> bool:
|
|
962
1006
|
"""
|
|
963
1007
|
Initialize the tags selector with the given settings.
|
|
964
1008
|
|
|
@@ -972,7 +1016,7 @@ class TrainGUI:
|
|
|
972
1016
|
:type validate: bool
|
|
973
1017
|
"""
|
|
974
1018
|
if self.tags_selector is None:
|
|
975
|
-
return True
|
|
1019
|
+
return True # Selector disabled by app options
|
|
976
1020
|
|
|
977
1021
|
# Set Tags
|
|
978
1022
|
self.tags_selector.set_tags(tags_settings)
|
|
@@ -985,7 +1029,13 @@ class TrainGUI:
|
|
|
985
1029
|
return is_valid
|
|
986
1030
|
# ----------------------------------------- #
|
|
987
1031
|
|
|
988
|
-
def _init_train_val_splits(
|
|
1032
|
+
def _init_train_val_splits(
|
|
1033
|
+
self,
|
|
1034
|
+
train_val_splits_settings: dict,
|
|
1035
|
+
options: dict,
|
|
1036
|
+
click_cb: bool = True,
|
|
1037
|
+
validate: bool = True,
|
|
1038
|
+
) -> bool:
|
|
989
1039
|
"""
|
|
990
1040
|
Initialize the train/val splits selector with the given settings.
|
|
991
1041
|
|
|
@@ -999,7 +1049,7 @@ class TrainGUI:
|
|
|
999
1049
|
:type validate: bool
|
|
1000
1050
|
"""
|
|
1001
1051
|
if self.train_val_splits_selector is None:
|
|
1002
|
-
return True
|
|
1052
|
+
return True # Selector disabled by app options
|
|
1003
1053
|
|
|
1004
1054
|
if train_val_splits_settings == {}:
|
|
1005
1055
|
available_methods = self.app_options.get("train_val_splits_methods", [])
|
|
@@ -1059,8 +1109,14 @@ class TrainGUI:
|
|
|
1059
1109
|
self.train_val_splits_selector_cb()
|
|
1060
1110
|
self.set_next_step()
|
|
1061
1111
|
return is_valid
|
|
1062
|
-
|
|
1063
|
-
def _init_hyperparameters(
|
|
1112
|
+
|
|
1113
|
+
def _init_hyperparameters(
|
|
1114
|
+
self,
|
|
1115
|
+
hyperparameters_settings: dict,
|
|
1116
|
+
options: dict,
|
|
1117
|
+
click_cb: bool = True,
|
|
1118
|
+
validate: bool = True,
|
|
1119
|
+
) -> bool:
|
|
1064
1120
|
"""
|
|
1065
1121
|
Initialize the hyperparameters selector with the given settings.
|
|
1066
1122
|
|
|
@@ -1101,6 +1157,7 @@ class TrainGUI:
|
|
|
1101
1157
|
self.hyperparameters_selector_cb()
|
|
1102
1158
|
self.set_next_step()
|
|
1103
1159
|
return is_valid
|
|
1160
|
+
|
|
1104
1161
|
# ----------------------------------------- #
|
|
1105
1162
|
|
|
1106
1163
|
# Run from experiment page
|
|
@@ -1111,10 +1168,12 @@ class TrainGUI:
|
|
|
1111
1168
|
app_state = sly_json.load_json_file(local_app_state_path)
|
|
1112
1169
|
sly_fs.silent_remove(local_app_state_path)
|
|
1113
1170
|
return app_state
|
|
1114
|
-
|
|
1171
|
+
|
|
1115
1172
|
def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
|
|
1116
1173
|
local_hparams_path = f"./{experiment_info.hyperparameters}"
|
|
1117
|
-
remote_hparams_path = os.path.join(
|
|
1174
|
+
remote_hparams_path = os.path.join(
|
|
1175
|
+
experiment_info.artifacts_dir, experiment_info.hyperparameters
|
|
1176
|
+
)
|
|
1118
1177
|
self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
|
|
1119
1178
|
with open(local_hparams_path, "r") as f:
|
|
1120
1179
|
hparams = f.read()
|
|
@@ -1129,11 +1188,14 @@ class TrainGUI:
|
|
|
1129
1188
|
model_settings = {
|
|
1130
1189
|
"source": ModelSource.CUSTOM,
|
|
1131
1190
|
"task_id": train_task_id,
|
|
1132
|
-
"checkpoint": experiment_info.best_checkpoint
|
|
1191
|
+
"checkpoint": experiment_info.best_checkpoint,
|
|
1133
1192
|
}
|
|
1134
1193
|
|
|
1135
1194
|
if experiment_state is not None:
|
|
1136
|
-
self.input_selector.validator_text.set(
|
|
1195
|
+
self.input_selector.validator_text.set(
|
|
1196
|
+
f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.",
|
|
1197
|
+
"success",
|
|
1198
|
+
)
|
|
1137
1199
|
self.input_selector.validator_text.show()
|
|
1138
1200
|
experiment_state = self._download_experiment_state(experiment_info)
|
|
1139
1201
|
if train_mode == "continue":
|
|
@@ -1142,7 +1204,7 @@ class TrainGUI:
|
|
|
1142
1204
|
else:
|
|
1143
1205
|
self.input_selector.validator_text.set(
|
|
1144
1206
|
f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
|
|
1145
|
-
"warning"
|
|
1207
|
+
"warning",
|
|
1146
1208
|
)
|
|
1147
1209
|
self.input_selector.validator_text.show()
|
|
1148
1210
|
hparams = self._download_experiment_hparams(experiment_info)
|
|
@@ -1150,3 +1212,28 @@ class TrainGUI:
|
|
|
1150
1212
|
if train_mode == "continue":
|
|
1151
1213
|
self._init_model(model_settings, {}, click_cb=False, validate=False)
|
|
1152
1214
|
# ----------------------------------------- #
|
|
1215
|
+
|
|
1216
|
+
def _extract_state_from_env(self):
|
|
1217
|
+
import ast
|
|
1218
|
+
import os
|
|
1219
|
+
|
|
1220
|
+
base = "modal.state"
|
|
1221
|
+
state = {}
|
|
1222
|
+
for key, value in os.environ.items():
|
|
1223
|
+
state_part = state
|
|
1224
|
+
if key.startswith(base):
|
|
1225
|
+
key = key.replace(base + ".", "")
|
|
1226
|
+
parts = key.split(".")
|
|
1227
|
+
while len(parts) > 1:
|
|
1228
|
+
part = parts.pop(0)
|
|
1229
|
+
state_part.setdefault(part, {})
|
|
1230
|
+
state_part = state_part[part]
|
|
1231
|
+
part = parts.pop(0)
|
|
1232
|
+
if value and (value[0] == "[" or value.isdigit()):
|
|
1233
|
+
state_part[part] = ast.literal_eval(value)
|
|
1234
|
+
elif value in ["True", "true", "False", "false"]:
|
|
1235
|
+
state_part[part] = value in ["True", "true"]
|
|
1236
|
+
else:
|
|
1237
|
+
state_part[part] = value
|
|
1238
|
+
return state
|
|
1239
|
+
# ----------------------------------------- #
|
|
@@ -26,6 +26,7 @@ class ModelSelector:
|
|
|
26
26
|
|
|
27
27
|
def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
|
|
28
28
|
# Init widgets
|
|
29
|
+
self.api = api
|
|
29
30
|
self.pretrained_models_table = None
|
|
30
31
|
self.experiment_selector = None
|
|
31
32
|
self.model_source_tabs = None
|
|
@@ -50,7 +51,7 @@ class ModelSelector:
|
|
|
50
51
|
|
|
51
52
|
# GUI Components
|
|
52
53
|
self.pretrained_models_table = PretrainedModelsSelector(self.models)
|
|
53
|
-
experiment_infos = get_experiment_infos(api, self.team_id, framework)
|
|
54
|
+
experiment_infos = get_experiment_infos(self.api, self.team_id, framework)
|
|
54
55
|
if self.app_options.get("legacy_checkpoints", False):
|
|
55
56
|
try:
|
|
56
57
|
framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
|
|
@@ -59,7 +60,7 @@ class ModelSelector:
|
|
|
59
60
|
except:
|
|
60
61
|
logger.warning(f"Legacy checkpoints are not available for '{framework}'")
|
|
61
62
|
|
|
62
|
-
self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
|
|
63
|
+
self.experiment_selector = ExperimentSelector(self.api, self.team_id, experiment_infos)
|
|
63
64
|
|
|
64
65
|
tab_titles = []
|
|
65
66
|
tab_descriptions = []
|
|
@@ -85,6 +86,7 @@ class ModelSelector:
|
|
|
85
86
|
self.validator_text = Text("")
|
|
86
87
|
self.validator_text.hide()
|
|
87
88
|
self.button = Button("Select")
|
|
89
|
+
|
|
88
90
|
self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
|
|
89
91
|
# -------------------------------- #
|
|
90
92
|
|
|
@@ -118,14 +120,14 @@ class ModelSelector:
|
|
|
118
120
|
model_name = _get_model_name(selected_row)
|
|
119
121
|
else:
|
|
120
122
|
selected_row = self.experiment_selector.get_selected_experiment_info()
|
|
121
|
-
model_name = selected_row.
|
|
123
|
+
model_name = selected_row.model_name
|
|
122
124
|
return model_name
|
|
123
125
|
|
|
124
126
|
def get_model_info(self) -> dict:
|
|
125
127
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
126
128
|
return self.pretrained_models_table.get_selected_row()
|
|
127
129
|
else:
|
|
128
|
-
return self.experiment_selector.get_selected_experiment_info()
|
|
130
|
+
return self.experiment_selector.get_selected_experiment_info().to_json()
|
|
129
131
|
|
|
130
132
|
def get_checkpoint_name(self) -> str:
|
|
131
133
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
@@ -146,7 +148,7 @@ class ModelSelector:
|
|
|
146
148
|
else:
|
|
147
149
|
checkpoint_name = self.experiment_selector.get_selected_checkpoint_name()
|
|
148
150
|
return checkpoint_name
|
|
149
|
-
|
|
151
|
+
|
|
150
152
|
def get_checkpoint_link(self) -> str:
|
|
151
153
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
152
154
|
selected_row = self.pretrained_models_table.get_selected_row()
|
|
@@ -182,4 +184,4 @@ class ModelSelector:
|
|
|
182
184
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
183
185
|
return self.pretrained_models_table.get_selected_task_type()
|
|
184
186
|
else:
|
|
185
|
-
return self.experiment_selector.
|
|
187
|
+
return self.experiment_selector.get_selected_experiment_info().task_type
|