supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/template/experiment/experiment.html.jinja +74 -17
  66. supervisely/template/experiment/experiment_generator.py +258 -112
  67. supervisely/template/experiment/header.html.jinja +31 -13
  68. supervisely/template/experiment/sly-style.css +7 -2
  69. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
  70. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
  71. supervisely/app/widgets/experiment_selector/style.css +0 -27
  72. supervisely/app/widgets/experiment_selector/template.html +0 -61
  73. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
  74. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
  75. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
  76. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
@@ -9,17 +9,17 @@ import os
9
9
  from os import environ, getenv
10
10
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
11
 
12
- from supervisely import logger
13
- import supervisely.io.fs as sly_fs
14
12
  import supervisely.io.env as sly_env
13
+ import supervisely.io.fs as sly_fs
15
14
  import supervisely.io.json as sly_json
16
- from supervisely import Api, ProjectMeta
15
+ from supervisely import Api, ProjectMeta, logger
17
16
  from supervisely._utils import is_production
18
17
  from supervisely.app.widgets import Button, Card, Stepper, Widget
19
18
  from supervisely.geometry.bitmap import Bitmap
20
19
  from supervisely.geometry.graph import GraphNodes
21
20
  from supervisely.geometry.polygon import Polygon
22
21
  from supervisely.geometry.rectangle import Rectangle
22
+ from supervisely.nn.experiments import ExperimentInfo
23
23
  from supervisely.nn.task_type import TaskType
24
24
  from supervisely.nn.training.gui.classes_selector import ClassesSelector
25
25
  from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
@@ -32,7 +32,6 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
32
32
  from supervisely.nn.training.gui.training_process import TrainingProcess
33
33
  from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
34
34
  from supervisely.nn.utils import ModelSource, RuntimeType
35
- from supervisely.nn.experiments import ExperimentInfo
36
35
 
37
36
 
38
37
  class StepFlow:
@@ -303,7 +302,9 @@ class TrainGUI:
303
302
  # 3. Classes selector
304
303
  self.classes_selector = None
305
304
  if self.show_classes_selector:
306
- self.classes_selector = ClassesSelector(self.project_id, [], self.model_selector, self.app_options)
305
+ self.classes_selector = ClassesSelector(
306
+ self.project_id, [], self.model_selector, self.app_options
307
+ )
307
308
  self.steps.append(self.classes_selector.card)
308
309
 
309
310
  # 4. Tags selector
@@ -355,16 +356,19 @@ class TrainGUI:
355
356
  experiment_name = "Enter experiment name"
356
357
  else:
357
358
  if self.task_id == -1:
358
- experiment_name = f"debug_{self.project_info.name}_{model_name}"
359
+ experiment_name = f"debug {self.project_info.name} {model_name}"
359
360
  else:
360
- experiment_name = f"{self.task_id}_{self.project_info.name}_{model_name}"
361
+ experiment_name = f"{self.task_id} {self.project_info.name} {model_name}"
361
362
 
362
363
  if experiment_name == self.training_process.get_experiment_name():
363
364
  return
364
365
  self.training_process.set_experiment_name(experiment_name)
365
366
 
366
367
  def need_convert_class_shapes() -> bool:
367
- if self.hyperparameters_selector.run_model_benchmark_checkbox is None or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
368
+ if (
369
+ self.hyperparameters_selector.run_model_benchmark_checkbox is None
370
+ or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked()
371
+ ):
368
372
  self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
369
373
  self.need_convert_shapes = False
370
374
  return False
@@ -376,14 +380,22 @@ class TrainGUI:
376
380
 
377
381
  # Exclude classes with no annotations to avoid unnecessary conversion
378
382
  data = self.classes_selector.classes_table._table_data
379
- empty_classes = {r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0}
383
+ empty_classes = {
384
+ r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
385
+ }
380
386
  need_conversion = bool(wrong_shapes - empty_classes)
381
387
  else:
382
388
  # Classes selector disabled – check entire project meta
383
389
  if task_type == TaskType.OBJECT_DETECTION:
384
- need_conversion = any(obj_cls.geometry_type != Rectangle for obj_cls in self.project_meta.obj_classes)
390
+ need_conversion = any(
391
+ obj_cls.geometry_type != Rectangle
392
+ for obj_cls in self.project_meta.obj_classes
393
+ )
385
394
  elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
386
- need_conversion = any(obj_cls.geometry_type == Polygon for obj_cls in self.project_meta.obj_classes)
395
+ need_conversion = any(
396
+ obj_cls.geometry_type == Polygon
397
+ for obj_cls in self.project_meta.obj_classes
398
+ )
387
399
  else:
388
400
  need_conversion = False
389
401
 
@@ -394,6 +406,7 @@ class TrainGUI:
394
406
 
395
407
  self.need_convert_shapes = need_conversion
396
408
  return need_conversion
409
+
397
410
  # ------------------------------------------------- #
398
411
 
399
412
  self.step_flow = StepFlow(self.stepper, self.app_options)
@@ -420,7 +433,7 @@ class TrainGUI:
420
433
  self.model_selector.widgets_to_disable,
421
434
  self.model_selector.validator_text,
422
435
  self.model_selector.validate_step,
423
- position=position
436
+ position=position,
424
437
  ).add_on_select_actions("model_selector", [set_experiment_name])
425
438
  position += 1
426
439
 
@@ -517,7 +530,9 @@ class TrainGUI:
517
530
  has_model_selector = self.show_model_selector and self.model_selector is not None
518
531
  has_classes_selector = self.show_classes_selector and self.classes_selector is not None
519
532
  has_tags_selector = self.show_tags_selector and self.tags_selector is not None
520
- has_train_val_splits = self.show_train_val_splits_selector and self.train_val_splits_selector is not None
533
+ has_train_val_splits = (
534
+ self.show_train_val_splits_selector and self.train_val_splits_selector is not None
535
+ )
521
536
 
522
537
  # Set step dependency chain
523
538
  prev_step = "input_selector"
@@ -571,11 +586,13 @@ class TrainGUI:
571
586
  @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
572
587
  def show_mb_speedtest(is_checked: bool):
573
588
  self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
589
+
574
590
  # ------------------------------------------------- #
575
591
 
576
592
  self.layout: Widget = self.stepper
577
593
 
578
594
  # Run from experiment page
595
+
579
596
  train_task_id = getenv("modal.state.trainTaskId", None)
580
597
  if train_task_id is not None:
581
598
  train_task_id = int(train_task_id)
@@ -584,7 +601,6 @@ class TrainGUI:
584
601
  self._run_from_experiment(train_task_id, train_mode)
585
602
  # ----------------------------------------- #
586
603
 
587
-
588
604
  def set_next_step(self):
589
605
  current_step = self.stepper.get_active_step()
590
606
  self.stepper.set_active_step(current_step + 1)
@@ -605,6 +621,8 @@ class TrainGUI:
605
621
  """
606
622
  if self.input_selector is not None:
607
623
  self.input_selector.button.enable()
624
+ if self.model_selector is not None:
625
+ self.model_selector.button.enable()
608
626
  if self.train_val_splits_selector is not None:
609
627
  self.train_val_splits_selector.button.enable()
610
628
  if self.classes_selector is not None:
@@ -622,6 +640,8 @@ class TrainGUI:
622
640
  """
623
641
  if self.input_selector is not None:
624
642
  self.input_selector.button.disable()
643
+ if self.model_selector is not None:
644
+ self.model_selector.button.disable()
625
645
  if self.train_val_splits_selector is not None:
626
646
  self.train_val_splits_selector.button.disable()
627
647
  if self.classes_selector is not None:
@@ -775,7 +795,9 @@ class TrainGUI:
775
795
  )
776
796
  return app_state
777
797
 
778
- def load_from_app_state(self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True) -> None:
798
+ def load_from_app_state(
799
+ self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True
800
+ ) -> None:
779
801
  """
780
802
  Load the GUI state from app state dictionary or path to the state file.
781
803
 
@@ -820,15 +842,15 @@ class TrainGUI:
820
842
  "TensorRT": True
821
843
  },
822
844
  },
823
- "experiment_name": "my_experiment",
845
+ "experiment_name": "My Experiment",
824
846
  }
825
847
  """
826
848
  if isinstance(app_state, str):
827
849
  app_state = sly_json.load_json_file(app_state)
828
-
850
+
829
851
  app_state = self.validate_app_state(app_state)
830
852
  options = app_state.get("options", {})
831
-
853
+
832
854
  # Set experiment name
833
855
  experiment_name = app_state.get("experiment_name")
834
856
  if experiment_name is not None:
@@ -839,7 +861,7 @@ class TrainGUI:
839
861
  if not init_fn(settings, options, click_cb, validate_steps):
840
862
  return False
841
863
  return True
842
-
864
+
843
865
  # GUI init steps
844
866
  _steps = [
845
867
  (self._init_input, app_state.get("input"), "Input project"),
@@ -856,12 +878,20 @@ class TrainGUI:
856
878
  logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
857
879
  return
858
880
  if validate_steps:
859
- logger.info(f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully")
881
+ logger.info(
882
+ f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully"
883
+ )
860
884
  if validate_steps:
861
885
  logger.info(f"All steps have been validated successfully")
862
886
  # ------------------------------------------------------------------ #
863
887
 
864
- def _init_input(self, input_settings: Union[dict, None], options: dict, click_cb: bool = True, validate: bool = True) -> bool:
888
+ def _init_input(
889
+ self,
890
+ input_settings: Union[dict, None],
891
+ options: dict,
892
+ click_cb: bool = True,
893
+ validate: bool = True,
894
+ ) -> bool:
865
895
  """
866
896
  Initialize the input selector with the given settings.
867
897
 
@@ -885,7 +915,13 @@ class TrainGUI:
885
915
  return is_valid
886
916
  # ----------------------------------------- #
887
917
 
888
- def _init_model(self, model_settings: dict, options: dict = None, click_cb: bool = True, validate: bool = True) -> bool:
918
+ def _init_model(
919
+ self,
920
+ model_settings: dict,
921
+ options: dict = None,
922
+ click_cb: bool = True,
923
+ validate: bool = True,
924
+ ) -> bool:
889
925
  """
890
926
  Initialize the model selector with the given settings.
891
927
 
@@ -909,14 +945,18 @@ class TrainGUI:
909
945
  # Custom
910
946
  elif model_settings["source"] == ModelSource.CUSTOM:
911
947
  self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
912
- self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
913
- active_row = self.model_selector.experiment_selector.get_selected_row()
914
- if model_settings["checkpoint"] not in active_row.checkpoints_names:
915
- raise ValueError(
916
- f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
917
- )
918
-
919
- active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
948
+ self.model_selector.experiment_selector.set_selected_row_by_task_id(
949
+ model_settings["task_id"]
950
+ )
951
+ experiment_info = self.model_selector.experiment_selector.get_selected_experiment_info()
952
+ if model_settings["checkpoint"] not in experiment_info.checkpoints:
953
+ if f"checkpoints/{model_settings['checkpoint']}" not in experiment_info.checkpoints:
954
+ raise ValueError(
955
+ f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
956
+ )
957
+ self.model_selector.experiment_selector.set_selected_checkpoint_by_name(
958
+ model_settings["checkpoint"]
959
+ )
920
960
 
921
961
  is_valid = True
922
962
  if validate:
@@ -926,8 +966,10 @@ class TrainGUI:
926
966
  self.set_next_step()
927
967
  return is_valid
928
968
  # ----------------------------------------- #
929
-
930
- def _init_classes(self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
969
+
970
+ def _init_classes(
971
+ self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True
972
+ ) -> bool:
931
973
  """
932
974
  Initialize the classes selector with the given settings.
933
975
 
@@ -941,7 +983,7 @@ class TrainGUI:
941
983
  :type validate: bool
942
984
  """
943
985
  if self.classes_selector is None:
944
- return True # Selector disabled by app options
986
+ return True # Selector disabled by app options
945
987
 
946
988
  convert_class_shapes = options.get("convert_class_shapes", True)
947
989
  if convert_class_shapes:
@@ -958,7 +1000,9 @@ class TrainGUI:
958
1000
  return is_valid
959
1001
  # ----------------------------------------- #
960
1002
 
961
- def _init_tags(self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1003
+ def _init_tags(
1004
+ self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True
1005
+ ) -> bool:
962
1006
  """
963
1007
  Initialize the tags selector with the given settings.
964
1008
 
@@ -972,7 +1016,7 @@ class TrainGUI:
972
1016
  :type validate: bool
973
1017
  """
974
1018
  if self.tags_selector is None:
975
- return True # Selector disabled by app options
1019
+ return True # Selector disabled by app options
976
1020
 
977
1021
  # Set Tags
978
1022
  self.tags_selector.set_tags(tags_settings)
@@ -985,7 +1029,13 @@ class TrainGUI:
985
1029
  return is_valid
986
1030
  # ----------------------------------------- #
987
1031
 
988
- def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1032
+ def _init_train_val_splits(
1033
+ self,
1034
+ train_val_splits_settings: dict,
1035
+ options: dict,
1036
+ click_cb: bool = True,
1037
+ validate: bool = True,
1038
+ ) -> bool:
989
1039
  """
990
1040
  Initialize the train/val splits selector with the given settings.
991
1041
 
@@ -999,7 +1049,7 @@ class TrainGUI:
999
1049
  :type validate: bool
1000
1050
  """
1001
1051
  if self.train_val_splits_selector is None:
1002
- return True # Selector disabled by app options
1052
+ return True # Selector disabled by app options
1003
1053
 
1004
1054
  if train_val_splits_settings == {}:
1005
1055
  available_methods = self.app_options.get("train_val_splits_methods", [])
@@ -1059,8 +1109,14 @@ class TrainGUI:
1059
1109
  self.train_val_splits_selector_cb()
1060
1110
  self.set_next_step()
1061
1111
  return is_valid
1062
-
1063
- def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1112
+
1113
+ def _init_hyperparameters(
1114
+ self,
1115
+ hyperparameters_settings: dict,
1116
+ options: dict,
1117
+ click_cb: bool = True,
1118
+ validate: bool = True,
1119
+ ) -> bool:
1064
1120
  """
1065
1121
  Initialize the hyperparameters selector with the given settings.
1066
1122
 
@@ -1101,6 +1157,7 @@ class TrainGUI:
1101
1157
  self.hyperparameters_selector_cb()
1102
1158
  self.set_next_step()
1103
1159
  return is_valid
1160
+
1104
1161
  # ----------------------------------------- #
1105
1162
 
1106
1163
  # Run from experiment page
@@ -1111,10 +1168,12 @@ class TrainGUI:
1111
1168
  app_state = sly_json.load_json_file(local_app_state_path)
1112
1169
  sly_fs.silent_remove(local_app_state_path)
1113
1170
  return app_state
1114
-
1171
+
1115
1172
  def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
1116
1173
  local_hparams_path = f"./{experiment_info.hyperparameters}"
1117
- remote_hparams_path = os.path.join(experiment_info.artifacts_dir, experiment_info.hyperparameters)
1174
+ remote_hparams_path = os.path.join(
1175
+ experiment_info.artifacts_dir, experiment_info.hyperparameters
1176
+ )
1118
1177
  self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
1119
1178
  with open(local_hparams_path, "r") as f:
1120
1179
  hparams = f.read()
@@ -1129,11 +1188,14 @@ class TrainGUI:
1129
1188
  model_settings = {
1130
1189
  "source": ModelSource.CUSTOM,
1131
1190
  "task_id": train_task_id,
1132
- "checkpoint": experiment_info.best_checkpoint
1191
+ "checkpoint": experiment_info.best_checkpoint,
1133
1192
  }
1134
1193
 
1135
1194
  if experiment_state is not None:
1136
- self.input_selector.validator_text.set(f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.", "success")
1195
+ self.input_selector.validator_text.set(
1196
+ f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.",
1197
+ "success",
1198
+ )
1137
1199
  self.input_selector.validator_text.show()
1138
1200
  experiment_state = self._download_experiment_state(experiment_info)
1139
1201
  if train_mode == "continue":
@@ -1142,7 +1204,7 @@ class TrainGUI:
1142
1204
  else:
1143
1205
  self.input_selector.validator_text.set(
1144
1206
  f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
1145
- "warning"
1207
+ "warning",
1146
1208
  )
1147
1209
  self.input_selector.validator_text.show()
1148
1210
  hparams = self._download_experiment_hparams(experiment_info)
@@ -1150,3 +1212,28 @@ class TrainGUI:
1150
1212
  if train_mode == "continue":
1151
1213
  self._init_model(model_settings, {}, click_cb=False, validate=False)
1152
1214
  # ----------------------------------------- #
1215
+
1216
+ def _extract_state_from_env(self):
1217
+ import ast
1218
+ import os
1219
+
1220
+ base = "modal.state"
1221
+ state = {}
1222
+ for key, value in os.environ.items():
1223
+ state_part = state
1224
+ if key.startswith(base):
1225
+ key = key.replace(base + ".", "")
1226
+ parts = key.split(".")
1227
+ while len(parts) > 1:
1228
+ part = parts.pop(0)
1229
+ state_part.setdefault(part, {})
1230
+ state_part = state_part[part]
1231
+ part = parts.pop(0)
1232
+ if value and (value[0] == "[" or value.isdigit()):
1233
+ state_part[part] = ast.literal_eval(value)
1234
+ elif value in ["True", "true", "False", "false"]:
1235
+ state_part[part] = value in ["True", "true"]
1236
+ else:
1237
+ state_part[part] = value
1238
+ return state
1239
+ # ----------------------------------------- #
@@ -26,6 +26,7 @@ class ModelSelector:
26
26
 
27
27
  def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
28
28
  # Init widgets
29
+ self.api = api
29
30
  self.pretrained_models_table = None
30
31
  self.experiment_selector = None
31
32
  self.model_source_tabs = None
@@ -50,7 +51,7 @@ class ModelSelector:
50
51
 
51
52
  # GUI Components
52
53
  self.pretrained_models_table = PretrainedModelsSelector(self.models)
53
- experiment_infos = get_experiment_infos(api, self.team_id, framework)
54
+ experiment_infos = get_experiment_infos(self.api, self.team_id, framework)
54
55
  if self.app_options.get("legacy_checkpoints", False):
55
56
  try:
56
57
  framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
@@ -59,7 +60,7 @@ class ModelSelector:
59
60
  except:
60
61
  logger.warning(f"Legacy checkpoints are not available for '{framework}'")
61
62
 
62
- self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
63
+ self.experiment_selector = ExperimentSelector(self.api, self.team_id, experiment_infos)
63
64
 
64
65
  tab_titles = []
65
66
  tab_descriptions = []
@@ -85,6 +86,7 @@ class ModelSelector:
85
86
  self.validator_text = Text("")
86
87
  self.validator_text.hide()
87
88
  self.button = Button("Select")
89
+
88
90
  self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
89
91
  # -------------------------------- #
90
92
 
@@ -118,14 +120,14 @@ class ModelSelector:
118
120
  model_name = _get_model_name(selected_row)
119
121
  else:
120
122
  selected_row = self.experiment_selector.get_selected_experiment_info()
121
- model_name = selected_row.get("model_name", None)
123
+ model_name = selected_row.model_name
122
124
  return model_name
123
125
 
124
126
  def get_model_info(self) -> dict:
125
127
  if self.get_model_source() == ModelSource.PRETRAINED:
126
128
  return self.pretrained_models_table.get_selected_row()
127
129
  else:
128
- return self.experiment_selector.get_selected_experiment_info()
130
+ return self.experiment_selector.get_selected_experiment_info().to_json()
129
131
 
130
132
  def get_checkpoint_name(self) -> str:
131
133
  if self.get_model_source() == ModelSource.PRETRAINED:
@@ -146,7 +148,7 @@ class ModelSelector:
146
148
  else:
147
149
  checkpoint_name = self.experiment_selector.get_selected_checkpoint_name()
148
150
  return checkpoint_name
149
-
151
+
150
152
  def get_checkpoint_link(self) -> str:
151
153
  if self.get_model_source() == ModelSource.PRETRAINED:
152
154
  selected_row = self.pretrained_models_table.get_selected_row()
@@ -182,4 +184,4 @@ class ModelSelector:
182
184
  if self.get_model_source() == ModelSource.PRETRAINED:
183
185
  return self.pretrained_models_table.get_selected_task_type()
184
186
  else:
185
- return self.experiment_selector.get_selected_task_type()
187
+ return self.experiment_selector.get_selected_experiment_info().task_type