supervisely-6.73.254-py3-none-any.whl → supervisely-6.73.256-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. See the registry page for this release for more details.

Files changed (61):
  1. supervisely/api/api.py +16 -8
  2. supervisely/api/file_api.py +16 -5
  3. supervisely/api/task_api.py +4 -2
  4. supervisely/app/widgets/field/field.py +10 -7
  5. supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
  6. supervisely/io/network_exceptions.py +14 -2
  7. supervisely/nn/benchmark/base_benchmark.py +33 -35
  8. supervisely/nn/benchmark/base_evaluator.py +27 -1
  9. supervisely/nn/benchmark/base_visualizer.py +8 -11
  10. supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
  17. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
  18. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
  19. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
  20. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
  21. supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
  22. supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
  23. supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
  24. supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
  25. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
  26. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
  27. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
  28. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
  29. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
  30. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
  31. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
  32. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
  33. supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
  34. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
  35. supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
  36. supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
  37. supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
  38. supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
  39. supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
  40. supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
  41. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
  42. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
  43. supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
  44. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
  45. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
  46. supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
  47. supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
  48. supervisely/nn/benchmark/visualization/renderer.py +25 -10
  49. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
  50. supervisely/nn/inference/inference.py +1 -0
  51. supervisely/nn/training/gui/gui.py +32 -10
  52. supervisely/nn/training/gui/training_artifacts.py +145 -0
  53. supervisely/nn/training/gui/training_process.py +3 -19
  54. supervisely/nn/training/train_app.py +179 -70
  55. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/METADATA +1 -1
  56. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/RECORD +60 -48
  57. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
  58. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/LICENSE +0 -0
  59. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/WHEEL +0 -0
  60. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/entry_points.txt +0 -0
  61. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,11 @@
1
1
  from supervisely.nn.benchmark.semantic_segmentation.base_vis_metric import (
2
2
  SemanticSegmVisMetric,
3
3
  )
4
- from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
4
+ from supervisely.nn.benchmark.visualization.widgets import (
5
+ ChartWidget,
6
+ MarkdownWidget,
7
+ TableWidget,
8
+ )
5
9
 
6
10
 
7
11
  class KeyMetrics(SemanticSegmVisMetric):
@@ -14,6 +18,32 @@ class KeyMetrics(SemanticSegmVisMetric):
14
18
  text=self.vis_texts.markdown_key_metrics,
15
19
  )
16
20
 
21
+ @property
22
+ def table(self) -> TableWidget:
23
+ columns = ["metrics", "values"]
24
+ content = []
25
+
26
+ metrics = self.eval_result.mp.key_metrics().copy()
27
+ metrics["mPixel accuracy"] = round(metrics["mPixel accuracy"] * 100, 2)
28
+
29
+ for metric, value in metrics.items():
30
+ row = [metric, round(value, 2)]
31
+ dct = {"row": row, "id": metric, "items": row}
32
+ content.append(dct)
33
+
34
+ columns_options = [{"disableSort": True}, {"disableSort": True}]
35
+ data = {"columns": columns, "columnsOptions": columns_options, "content": content}
36
+
37
+ table = TableWidget(
38
+ name="table_key_metrics",
39
+ data=data,
40
+ fix_columns=1,
41
+ width="60%",
42
+ show_header_controls=False,
43
+ main_column=columns[0],
44
+ )
45
+ return table
46
+
17
47
  @property
18
48
  def chart(self) -> ChartWidget:
19
49
  return ChartWidget("base_metrics_chart", self.get_figure())
@@ -26,9 +26,7 @@ class Overview(SemanticSegmVisMetric):
26
26
  link_text = link_text.replace("_", "\_")
27
27
 
28
28
  model_name = self.eval_result.inference_info.get("model_name") or "Custom"
29
- checkpoint_name = self.eval_result.inference_info.get("deploy_params", {}).get(
30
- "checkpoint_name", ""
31
- )
29
+ checkpoint_name = self.eval_result.checkpoint_name
32
30
 
33
31
  # Note about validation dataset
34
32
  classes_str, note_about_images, starter_app_info = self._get_overview_info()
@@ -93,6 +93,7 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
93
93
  # key metrics
94
94
  key_metrics = KeyMetrics(self.vis_texts, self.eval_result)
95
95
  self.key_metrics_md = key_metrics.md
96
+ self.key_metrics_table = key_metrics.table
96
97
  self.key_metrics_chart = key_metrics.chart
97
98
 
98
99
  # explore predictions
@@ -143,15 +144,14 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
143
144
  self.acknowledgement_md = acknowledgement.md
144
145
 
145
146
  # SpeedTest
146
- self.speedtest_present = False
147
- self.speedtest_multiple_batch_sizes = False
148
147
  speedtest = Speedtest(self.vis_texts, self.eval_result)
149
- if not speedtest.is_empty():
150
- self.speedtest_present = True
148
+ self.speedtest_present = not speedtest.is_empty()
149
+ self.speedtest_multiple_batch_sizes = False
150
+ if self.speedtest_present:
151
151
  self.speedtest_md_intro = speedtest.intro_md
152
152
  self.speedtest_intro_table = speedtest.intro_table
153
- if speedtest.multiple_batche_sizes():
154
- self.speedtest_multiple_batch_sizes = True
153
+ self.speedtest_multiple_batch_sizes = speedtest.multiple_batche_sizes()
154
+ if self.speedtest_multiple_batch_sizes:
155
155
  self.speedtest_batch_inference_md = speedtest.batch_size_md
156
156
  self.speedtest_chart = speedtest.chart
157
157
 
@@ -166,6 +166,7 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
166
166
  (0, self.header),
167
167
  (1, self.overview_md),
168
168
  (1, self.key_metrics_md),
169
+ (0, self.key_metrics_table),
169
170
  (0, self.key_metrics_chart),
170
171
  (1, self.explore_predictions_md),
171
172
  (0, self.explore_predictions_gallery),
@@ -63,29 +63,11 @@ class Evaluator:
63
63
  :param boundary_implementation: Choose "exact" for the euclidean pixel distance.
64
64
  The Boundary IoU paper uses the L1 distance ("fast").
65
65
  """
66
- global torch, np, GPU
66
+ global torch, np, GPU, numpy
67
67
  import torch # pylint: disable=import-error
68
68
 
69
- if torch.cuda.is_available():
70
- GPU = True
71
- logger.info("Using GPU for evaluation.")
72
- try:
73
- # gpu-compatible numpy analogue
74
- import cupy as np # pylint: disable=import-error
75
-
76
- global numpy
77
- import numpy as numpy
78
- except:
79
- logger.warning(
80
- "Failed to import cupy. Use cupy official documentation to install this "
81
- "module: https://docs.cupy.dev/en/stable/install.html"
82
- )
83
- else:
84
- GPU = False
85
- import numpy as np
86
-
87
- global numpy
88
- numpy = np
69
+ numpy = np
70
+ GPU = False
89
71
 
90
72
  self.progress = progress or tqdm_sly
91
73
  self.class_names = class_names
@@ -20,6 +20,7 @@ class Renderer:
20
20
  layout: BaseWidget,
21
21
  base_dir: str = "./output",
22
22
  template: str = None,
23
+ report_name: str = "Model Evaluation Report.lnk",
23
24
  ) -> None:
24
25
  if template is None:
25
26
  template = (
@@ -28,6 +29,9 @@ class Renderer:
28
29
  self.main_template = template
29
30
  self.layout = layout
30
31
  self.base_dir = base_dir
32
+ self.report_name = report_name
33
+ self._report = None
34
+ self._lnk = None
31
35
 
32
36
  if Path(base_dir).exists():
33
37
  if not dir_empty(base_dir):
@@ -81,20 +85,31 @@ class Renderer:
81
85
  change_name_if_conflict=True,
82
86
  progress_size_cb=pbar,
83
87
  )
84
- src = self.save_report_link(api, team_id, remote_dir)
85
- api.file.upload(team_id=team_id, src=src, dst=remote_dir.rstrip("/") + "/open.lnk")
88
+ src = self._save_report_link(api, team_id, remote_dir)
89
+ dst = Path(remote_dir).joinpath(self.report_name)
90
+ self._lnk = api.file.upload(team_id=team_id, src=src, dst=str(dst))
86
91
  return remote_dir
87
92
 
88
- def save_report_link(self, api: Api, team_id: int, remote_dir: str):
89
- report_link = self.get_report_link(api, team_id, remote_dir)
90
- pth = Path(self.base_dir).joinpath("open.lnk")
93
+ def _save_report_link(self, api: Api, team_id: int, remote_dir: str):
94
+ report_link = self._get_report_path(api, team_id, remote_dir)
95
+ pth = Path(self.base_dir).joinpath(self.report_name)
91
96
  with open(pth, "w") as f:
92
97
  f.write(report_link)
93
98
  return str(pth)
94
99
 
95
- def get_report_link(self, api: Api, team_id: int, remote_dir: str):
96
- template_path = remote_dir.rstrip("/") + "/" + "template.vue"
97
- vue_template_info = api.file.get_info_by_path(team_id, template_path)
100
+ def _get_report_link(self, api: Api, team_id: int, remote_dir: str):
101
+ path = self._get_report_path(api, team_id, remote_dir)
102
+ return f"{api.server_address}{path}"
98
103
 
99
- report_link = "/model-benchmark?id=" + str(vue_template_info.id)
100
- return report_link
104
+ def _get_report_path(self, api: Api, team_id: int, remote_dir: str):
105
+ template_path = Path(remote_dir).joinpath("template.vue")
106
+ self._report = api.file.get_info_by_path(team_id, str(template_path))
107
+ return "/model-benchmark?id=" + str(self._report.id)
108
+
109
+ @property
110
+ def report(self):
111
+ return self._report
112
+
113
+ @property
114
+ def lnk(self):
115
+ return self._lnk
@@ -79,6 +79,7 @@ class GalleryWidget(BaseWidget):
79
79
  column_index=idx % self.columns_number,
80
80
  project_meta=project_metas[idx % self.columns_number],
81
81
  ignore_tags_filtering=skip_tags_filtering[idx % self.columns_number],
82
+ call_update=idx == len(image_infos) - 1,
82
83
  )
83
84
 
84
85
  def _get_init_data(self):
@@ -133,6 +133,7 @@ class Inference:
133
133
  if self.INFERENCE_SETTINGS is not None:
134
134
  custom_inference_settings = self.INFERENCE_SETTINGS
135
135
  else:
136
+ logger.debug("Custom inference settings are not provided.")
136
137
  custom_inference_settings = {}
137
138
  if isinstance(custom_inference_settings, str):
138
139
  if fs.file_exists(custom_inference_settings):
@@ -14,6 +14,7 @@ from supervisely.nn.training.gui.hyperparameters_selector import Hyperparameters
14
14
  from supervisely.nn.training.gui.input_selector import InputSelector
15
15
  from supervisely.nn.training.gui.model_selector import ModelSelector
16
16
  from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
17
+ from supervisely.nn.training.gui.training_artifacts import TrainingArtifacts
17
18
  from supervisely.nn.training.gui.training_logs import TrainingLogs
18
19
  from supervisely.nn.training.gui.training_process import TrainingProcess
19
20
  from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
@@ -50,7 +51,9 @@ class TrainGUI:
50
51
  if is_production():
51
52
  self.task_id = sly_env.task_id()
52
53
  else:
53
- self.task_id = "debug-session"
54
+ self.task_id = sly_env.task_id(raise_not_found=False)
55
+ if self.task_id is None:
56
+ self.task_id = "debug-session"
54
57
 
55
58
  self.framework_name = framework_name
56
59
  self.models = models
@@ -86,17 +89,22 @@ class TrainGUI:
86
89
  # 7. Training logs
87
90
  self.training_logs = TrainingLogs(self.app_options)
88
91
 
92
+ # 8. Training Artifacts
93
+ self.training_artifacts = TrainingArtifacts(self.app_options)
94
+
89
95
  # Stepper layout
96
+ self.steps = [
97
+ self.input_selector.card,
98
+ self.train_val_splits_selector.card,
99
+ self.classes_selector.card,
100
+ self.model_selector.card,
101
+ self.hyperparameters_selector.card,
102
+ self.training_process.card,
103
+ self.training_logs.card,
104
+ self.training_artifacts.card,
105
+ ]
90
106
  self.stepper = Stepper(
91
- widgets=[
92
- self.input_selector.card,
93
- self.train_val_splits_selector.card,
94
- self.classes_selector.card,
95
- self.model_selector.card,
96
- self.hyperparameters_selector.card,
97
- self.training_process.card,
98
- self.training_logs.card,
99
- ],
107
+ widgets=self.steps,
100
108
  )
101
109
  # ------------------------------------------------- #
102
110
 
@@ -265,6 +273,20 @@ class TrainGUI:
265
273
 
266
274
  self.layout: Widget = self.stepper
267
275
 
276
+ def set_next_step(self):
277
+ current_step = self.stepper.get_active_step()
278
+ self.stepper.set_active_step(current_step + 1)
279
+
280
+ def set_previous_step(self):
281
+ current_step = self.stepper.get_active_step()
282
+ self.stepper.set_active_step(current_step - 1)
283
+
284
+ def set_first_step(self):
285
+ self.stepper.set_active_step(1)
286
+
287
+ def set_last_step(self):
288
+ self.stepper.set_active_step(len(self.steps))
289
+
268
290
  def enable_select_buttons(self):
269
291
  """
270
292
  Makes all select buttons in the GUI available for interaction.
@@ -0,0 +1,145 @@
1
+ from typing import Any, Dict
2
+
3
+ from supervisely import Api
4
+ from supervisely.app.widgets import (
5
+ Card,
6
+ Container,
7
+ Empty,
8
+ Field,
9
+ Flexbox,
10
+ FolderThumbnail,
11
+ ReportThumbnail,
12
+ Text,
13
+ )
14
+
15
+ PYTORCH_ICON = "https://img.icons8.com/?size=100&id=jH4BpkMnRrU5&format=png&color=000000"
16
+ ONNX_ICON = "https://artwork.lfaidata.foundation/projects/onnx/icon/color/onnx-icon-color.png"
17
+ TRT_ICON = "https://img.icons8.com/?size=100&id=yqf95864UzeQ&format=png&color=000000"
18
+
19
+
20
+ class TrainingArtifacts:
21
+ title = "Training Artifacts"
22
+ description = "All outputs of the training process will appear here"
23
+ lock_message = "Artifacts will be available after training is completed"
24
+
25
+ def __init__(self, app_options: Dict[str, Any]):
26
+ self.display_widgets = []
27
+ self.success_message_text = (
28
+ "Training completed. Training artifacts were uploaded to Team Files. "
29
+ "You can find and open tensorboard logs in the artifacts folder via the "
30
+ "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
31
+ )
32
+ self.app_options = app_options
33
+
34
+ # GUI Components
35
+ self.validator_text = Text("")
36
+ self.validator_text.hide()
37
+ self.display_widgets.extend([self.validator_text])
38
+
39
+ # Outputs
40
+ self.artifacts_thumbnail = FolderThumbnail()
41
+ self.artifacts_thumbnail.hide()
42
+
43
+ self.artifacts_field = Field(
44
+ title="Artifacts",
45
+ description="Contains all outputs of the training process",
46
+ content=self.artifacts_thumbnail,
47
+ )
48
+ self.artifacts_field.hide()
49
+ self.display_widgets.extend([self.artifacts_field])
50
+
51
+ # Optional Model Benchmark
52
+ if app_options.get("model_benchmark", False):
53
+ self.model_benchmark_report_thumbnail = ReportThumbnail()
54
+ self.model_benchmark_report_thumbnail.hide()
55
+
56
+ self.model_benchmark_fail_text = Text(
57
+ text="Model evaluation did not finish successfully. Please check the app logs for details.",
58
+ status="error",
59
+ )
60
+ self.model_benchmark_fail_text.hide()
61
+
62
+ self.model_benchmark_widgets = Container(
63
+ [self.model_benchmark_report_thumbnail, self.model_benchmark_fail_text]
64
+ )
65
+
66
+ self.model_benchmark_report_field = Field(
67
+ title="Model Benchmark",
68
+ description="Evaluation report of the trained model",
69
+ content=self.model_benchmark_widgets,
70
+ )
71
+ self.model_benchmark_report_field.hide()
72
+ self.display_widgets.extend([self.model_benchmark_report_field])
73
+ # -------------------------------- #
74
+
75
+ # PyTorch, ONNX, TensorRT demo
76
+ self.inference_demo_field = []
77
+ model_demo = self.app_options.get("demo", None)
78
+ if model_demo is not None:
79
+ pytorch_demo_link = model_demo.get("pytorch", None)
80
+ if pytorch_demo_link is not None:
81
+ pytorch_icon = Field.Icon(image_url=PYTORCH_ICON, bg_color_rgb=[255, 255, 255])
82
+ self.pytorch_instruction = Field(
83
+ title="PyTorch",
84
+ description="Open file",
85
+ description_url=pytorch_demo_link,
86
+ icon=pytorch_icon,
87
+ content=Empty(),
88
+ )
89
+ self.pytorch_instruction.hide()
90
+ self.inference_demo_field.extend([self.pytorch_instruction])
91
+
92
+ onnx_demo_link = model_demo.get("onnx", None)
93
+ if onnx_demo_link is not None:
94
+ if self.app_options.get("export_onnx_supported", False):
95
+ onnx_icon = Field.Icon(image_url=ONNX_ICON, bg_color_rgb=[255, 255, 255])
96
+ self.onnx_instruction = Field(
97
+ title="ONNX",
98
+ description="Open file",
99
+ description_url=onnx_demo_link,
100
+ icon=onnx_icon,
101
+ content=Empty(),
102
+ )
103
+ self.onnx_instruction.hide()
104
+ self.inference_demo_field.extend([self.onnx_instruction])
105
+
106
+ trt_demo_link = model_demo.get("tensorrt", None)
107
+ if trt_demo_link is not None:
108
+ if self.app_options.get("export_tensorrt_supported", False):
109
+ trt_icon = Field.Icon(image_url=TRT_ICON, bg_color_rgb=[255, 255, 255])
110
+ self.trt_instruction = Field(
111
+ title="TensorRT",
112
+ description="Open file",
113
+ description_url=trt_demo_link,
114
+ icon=trt_icon,
115
+ content=Empty(),
116
+ )
117
+ self.trt_instruction.hide()
118
+ self.inference_demo_field.extend([self.trt_instruction])
119
+
120
+ demo_overview_link = model_demo.get("overview", None)
121
+ self.inference_demo_field = Field(
122
+ title="How to run inference",
123
+ description="Instructions on how to use your checkpoints outside of Supervisely Platform",
124
+ content=Flexbox(self.inference_demo_field),
125
+ title_url=demo_overview_link,
126
+ )
127
+ self.inference_demo_field.hide()
128
+ self.display_widgets.extend([self.inference_demo_field])
129
+ # -------------------------------- #
130
+
131
+ self.container = Container(self.display_widgets)
132
+ self.card = Card(
133
+ title=self.title,
134
+ description=self.description,
135
+ content=self.container,
136
+ lock_message=self.lock_message,
137
+ )
138
+ self.card.lock()
139
+
140
+ @property
141
+ def widgets_to_disable(self) -> list:
142
+ return []
143
+
144
+ def validate_step(self) -> bool:
145
+ return True
@@ -23,23 +23,18 @@ class TrainingProcess:
23
23
 
24
24
  def __init__(self, app_options: Dict[str, Any]):
25
25
  self.display_widgets = []
26
- self.success_message_text = (
27
- "Training completed. Training artifacts were uploaded to Team Files. "
28
- "You can find and open tensorboard logs in the artifacts folder via the "
29
- "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
30
- )
31
26
  self.app_options = app_options
32
27
 
33
28
  # GUI Components
34
29
  # Optional Select CUDA device
35
30
  if self.app_options.get("device_selector", False):
36
31
  self.select_device = SelectCudaDevice()
37
- self.select_cuda_device_field = Field(
32
+ self.select_device_field = Field(
38
33
  title="Select CUDA device",
39
34
  description="The device on which the model will be trained",
40
35
  content=self.select_device,
41
36
  )
42
- self.display_widgets.extend([self.select_cuda_device_field])
37
+ self.display_widgets.extend([self.select_device_field])
43
38
  # -------------------------------- #
44
39
 
45
40
  self.experiment_name_input = Input("Enter experiment name")
@@ -63,26 +58,15 @@ class TrainingProcess:
63
58
  self.validator_text = Text("")
64
59
  self.validator_text.hide()
65
60
 
66
- self.artifacts_thumbnail = FolderThumbnail()
67
- self.artifacts_thumbnail.hide()
68
-
69
61
  self.display_widgets.extend(
70
62
  [
71
63
  self.experiment_name_field,
72
64
  button_container,
73
65
  self.validator_text,
74
- self.artifacts_thumbnail,
75
66
  ]
76
67
  )
77
68
  # -------------------------------- #
78
69
 
79
- # Optional Model Benchmark
80
- if app_options.get("model_benchmark", False):
81
- self.model_benchmark_report_thumbnail = ReportThumbnail()
82
- self.model_benchmark_report_thumbnail.hide()
83
- self.display_widgets.extend([self.model_benchmark_report_thumbnail])
84
- # -------------------------------- #
85
-
86
70
  self.container = Container(self.display_widgets)
87
71
  self.card = Card(
88
72
  title=self.title,
@@ -96,7 +80,7 @@ class TrainingProcess:
96
80
  def widgets_to_disable(self) -> list:
97
81
  widgets = [self.experiment_name_input]
98
82
  if self.app_options.get("device_selector", False):
99
- widgets.append(self.experiment_name_input)
83
+ widgets.extend([self.select_device, self.select_device_field])
100
84
  return widgets
101
85
 
102
86
  def validate_step(self) -> bool: