supervisely 6.73.243-py3-none-any.whl → 6.73.245-py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.

Files changed (56)
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py

@@ -13,6 +13,7 @@ from dataclasses import asdict
 from functools import partial, wraps
 from queue import Queue
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from urllib.request import urlopen
 
 import numpy as np
 import requests
@@ -25,10 +26,13 @@ import supervisely.app.development as sly_app_development
 import supervisely.imaging.image as sly_image
 import supervisely.io.env as env
 import supervisely.io.fs as fs
+import supervisely.io.fs as sly_fs
+import supervisely.io.json as sly_json
 import supervisely.nn.inference.gui as GUI
 from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
 from supervisely._utils import (
     add_callback,
+    get_filename_from_headers,
     is_debug_with_sly_net,
     is_production,
     rand_str,
@@ -59,7 +63,13 @@ from supervisely.geometry.any_geometry import AnyGeometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.nn.inference.cache import InferenceImageCache
 from supervisely.nn.prediction_dto import Prediction
-from supervisely.nn.utils import CheckpointInfo, DeployInfo, ModelPrecision, RuntimeType
+from supervisely.nn.utils import (
+    CheckpointInfo,
+    DeployInfo,
+    ModelPrecision,
+    ModelSource,
+    RuntimeType,
+)
 from supervisely.project import ProjectType
 from supervisely.project.download import download_to_cache, read_from_cached_project
 from supervisely.project.project_meta import ProjectMeta
@@ -74,6 +84,12 @@ except ImportError:
 
 
 class Inference:
+    FRAMEWORK_NAME: str = None
+    """Name of framework to register models in Supervisely"""
+    MODELS: str = None
+    """Path to file with list of models"""
+    APP_OPTIONS: str = None
+    """Path to file with app options"""
     DEFAULT_BATCH_SIZE = 16
 
     def __init__(
  def __init__(
@@ -85,6 +101,7 @@ class Inference:
85
101
  sliding_window_mode: Optional[Literal["basic", "advanced", "none"]] = "basic",
86
102
  use_gui: Optional[bool] = False,
87
103
  multithread_inference: Optional[bool] = True,
104
+ use_serving_gui_template: Optional[bool] = False,
88
105
  ):
89
106
  if model_dir is None:
90
107
  model_dir = os.path.join(get_data_dir(), "models")
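Together with the FRAMEWORK_NAME, MODELS, and APP_OPTIONS class attributes added above, this new flag lets a serving app hand its model-selection UI over to GUI.ServingGUITemplate instead of implementing initialize_custom_gui(). A minimal sketch of how a subclass might opt in (the class name, file paths, and model details are illustrative assumptions, not taken from this diff):

import supervisely as sly

class CustomFrameworkServe(sly.nn.inference.ObjectDetection):  # hypothetical subclass
    FRAMEWORK_NAME = "CustomFramework"       # required: __init__ raises ValueError if left as None
    MODELS = "models/models.json"            # assumed path to the models list file
    APP_OPTIONS = "models/app_options.yaml"  # assumed path to the app options file

model = CustomFrameworkServe(use_serving_gui_template=True)  # selects GUI.ServingGUITemplate in __init__
model.serve()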
@@ -92,8 +109,10 @@ class Inference:
         self.device: str = None
         self.runtime: str = None
         self.model_precision: str = None
+        self.model_source: str = None
         self.checkpoint_info: CheckpointInfo = None
         self.max_batch_size: int = None  # set it only if a model has a limit on the batch size
+        self.classes: List[str] = None
         self._model_dir = model_dir
         self._model_served = False
         self._deploy_params: dict = None
@@ -117,6 +136,7 @@ class Inference:
         self._custom_inference_settings = custom_inference_settings
 
         self._use_gui = use_gui
+        self._use_serving_gui_template = use_serving_gui_template
         self._gui = None
 
         self.load_on_device = LOAD_ON_DEVICE_DECORATOR(self.load_on_device)
@@ -124,20 +144,48 @@
 
         self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
 
-        if use_gui:
+        if self._use_gui:
             initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
             original_initialize_custom_gui_method = getattr(
                 Inference, "initialize_custom_gui", None
             )
-            if initialize_custom_gui_method.__func__ is not original_initialize_custom_gui_method:
+            if self._use_serving_gui_template:
+                if self.FRAMEWORK_NAME is None:
+                    raise ValueError("FRAMEWORK_NAME is not defined")
+                self._gui = GUI.ServingGUITemplate(
+                    self.FRAMEWORK_NAME, self.MODELS, self.APP_OPTIONS
+                )
+                self._user_layout = self._gui.widgets
+                self._user_layout_card = self._gui.card
+            elif initialize_custom_gui_method.__func__ is not original_initialize_custom_gui_method:
                 self._gui = GUI.ServingGUI()
                 self._user_layout = self.initialize_custom_gui()
             else:
-                self.initialize_gui()
+                initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
+                original_initialize_custom_gui_method = getattr(
+                    Inference, "initialize_custom_gui", None
+                )
+                if (
+                    initialize_custom_gui_method.__func__
+                    is not original_initialize_custom_gui_method
+                ):
+                    self._gui = GUI.ServingGUI()
+                    self._user_layout = self.initialize_custom_gui()
+                else:
+                    self.initialize_gui()
 
-            def on_serve_callback(gui: Union[GUI.InferenceGUI, GUI.ServingGUI]):
+            def on_serve_callback(
+                gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
+            ):
                 Progress("Deploying model ...", 1)
-                if isinstance(self.gui, GUI.ServingGUI):
+                if isinstance(self.gui, GUI.ServingGUITemplate):
+                    deploy_params = self.get_params_from_gui()
+                    model_files = self._download_model_files(
+                        deploy_params["model_source"], deploy_params["model_files"]
+                    )
+                    deploy_params["model_files"] = model_files
+                    self._load_model_headless(**deploy_params)
+                elif isinstance(self.gui, GUI.ServingGUI):
                     deploy_params = self.get_params_from_gui()
                     self._load_model(deploy_params)
                 else:  # GUI.InferenceGUI
@@ -146,9 +194,11 @@
                     self.load_on_device(self._model_dir, device)
                 gui.show_deployed_model_info(self)
 
-            def on_change_model_callback(gui: Union[GUI.InferenceGUI, GUI.ServingGUI]):
+            def on_change_model_callback(
+                gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
+            ):
                 self.shutdown_model()
-                if isinstance(self.gui, GUI.ServingGUI):
+                if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
                     self._api_request_model_layout.unlock()
                     self._api_request_model_layout.hide()
                     self.update_gui(self._model_served)
@@ -198,7 +248,7 @@
         raise NotImplementedError("Have to be implemented in child class after inheritance")
 
     def update_gui(self, is_model_deployed: bool = True) -> None:
-        if isinstance(self.gui, GUI.ServingGUI):
+        if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
             if is_model_deployed:
                 self._user_layout_card.lock()
             else:
@@ -211,6 +261,8 @@
             self._api_request_model_layout.show()
 
     def get_params_from_gui(self) -> dict:
+        if isinstance(self.gui, GUI.ServingGUITemplate):
+            return self.gui.get_params_from_gui()
         raise NotImplementedError("Have to be implemented in child class after inheritance")
 
     def initialize_gui(self) -> None:
@@ -237,6 +289,25 @@
         )
 
     def _initialize_app_layout(self):
+        self._api_request_model_info = Editor(
+            height_lines=12,
+            language_mode="json",
+            readonly=True,
+            restore_default_button=False,
+            auto_format=True,
+        )
+        self._api_request_model_layout = Card(
+            title="Model was deployed from API request with the following settings",
+            content=self._api_request_model_info,
+        )
+        self._api_request_model_layout.hide()
+
+        if isinstance(self.gui, GUI.ServingGUITemplate):
+            self._app_layout = Container(
+                [self._user_layout_card, self._api_request_model_layout, self.get_ui()], gap=5
+            )
+            return
+
         if hasattr(self, "_user_layout"):
             self._user_layout_card = Card(
                 title="Select Model",
@@ -251,20 +322,9 @@
                 content=self._gui,
                 lock_message="Model is deployed. To change the model, stop the serving first.",
             )
-        self._api_request_model_info = Editor(
-            height_lines=12,
-            language_mode="json",
-            readonly=True,
-            restore_default_button=False,
-            auto_format=True,
-        )
-        self._api_request_model_layout = Card(
-            title="Model was deployed from API request with the following settings",
-            content=self._api_request_model_info,
-        )
-        self._api_request_model_layout.hide()
+
         self._app_layout = Container(
-            [self._user_layout_card, self._api_request_model_layout, self.get_ui()]
+            [self._user_layout_card, self._api_request_model_layout, self.get_ui()], gap=5
         )
 
     def support_custom_models(self) -> bool:
@@ -427,7 +487,74 @@
     def load_model_meta(self, model_tab: str, local_weights_path: str):
         raise NotImplementedError("Have to be implemented in child class after inheritance")
 
+    def _download_model_files(self, model_source: str, model_files: List[str]) -> dict:
+        if model_source == ModelSource.PRETRAINED:
+            return self._download_pretrained_model(model_files)
+        elif model_source == ModelSource.CUSTOM:
+            return self._download_custom_model(model_files)
+
+    def _download_pretrained_model(self, model_files: dict):
+        """
+        Downloads the pretrained model data.
+        """
+        local_model_files = {}
+
+        for file in model_files:
+            file_url = model_files[file]
+            file_path = os.path.join(self.model_dir, file)
+            if file_url.startswith("http"):
+                with urlopen(file_url) as f:
+                    file_size = f.length
+                file_name = get_filename_from_headers(file_url)
+                file_path = os.path.join(self.model_dir, file_name)
+                if file_name is None:
+                    file_name = file
+                with self.gui.download_progress(
+                    message=f"Downloading: '{file_name}'",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as download_pbar:
+                    self.gui.download_progress.show()
+                    sly_fs.download(
+                        url=file_url, save_path=file_path, progress=download_pbar.update
+                    )
+                local_model_files[file] = file_path
+            else:
+                local_model_files[file] = file_url
+        self.gui.download_progress.hide()
+        return local_model_files
+
+    def _download_custom_model(self, model_files: dict):
+        """
+        Downloads the custom model data.
+        """
+
+        team_id = env.team_id()
+        local_model_files = {}
+
+        for file in model_files:
+            file_url = model_files[file]
+            file_info = self.api.file.get_info_by_path(team_id, file_url)
+            file_size = file_info.sizeb
+            file_name = os.path.basename(file_url)
+            file_path = os.path.join(self.model_dir, file_name)
+            with self.gui.download_progress(
+                message=f"Downloading: '{file_name}'",
+                total=file_size,
+                unit="bytes",
+                unit_scale=True,
+            ) as download_pbar:
+                self.gui.download_progress.show()
+                self.api.file.download(
+                    team_id, file_url, file_path, progress_cb=download_pbar.update
+                )
+            local_model_files[file] = file_path
+        self.gui.download_progress.hide()
+        return local_model_files
+
     def _load_model(self, deploy_params: dict):
+        self.model_source = deploy_params.get("model_source")
         self.device = deploy_params.get("device")
         self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
         self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
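Both download helpers expect model_files to map a file key to its source: for pretrained models an http(s) URL is downloaded into model_dir (the file name is taken from the response headers when available) while any other value is passed through unchanged, and for custom models every value is treated as a Team Files path and fetched via api.file.download. A hedged illustration of the expected shapes (all URLs and paths below are made up):

# pretrained: values may be remote URLs or local/relative paths
model_files = {
    "checkpoint": "https://example.com/weights/model_s.pt",  # downloaded via sly_fs.download
    "config": "configs/model_s.yaml",                        # non-http value is returned as-is
}

# custom: values are Team Files paths inside the experiment artifacts folder
model_files = {
    "checkpoint": "/experiments/297_my_project/checkpoints/best.pt",
    "config": "/experiments/297_my_project/model_config.yaml",
}

# in both cases the helpers return {key: local_path}, with files placed under self.model_dir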
@@ -439,6 +566,64 @@
         self.update_gui(self._model_served)
         self.gui.show_deployed_model_info(self)
 
+    def _load_model_headless(
+        self,
+        model_files: dict,
+        model_source: str,
+        model_info: dict,
+        device: str,
+        runtime: str,
+        **kwargs,
+    ):
+        deploy_params = {
+            "model_files": model_files,
+            "model_source": model_source,
+            "model_info": model_info,
+            "device": device,
+            "runtime": runtime,
+            **kwargs,
+        }
+        if model_source == ModelSource.CUSTOM:
+            self._set_model_meta_custom_model(model_info)
+            self._set_checkpoint_info_custom_model(deploy_params)
+        self._load_model(deploy_params)
+        if self._model_meta is None:
+            self._set_model_meta_from_classes()
+
+    def _set_model_meta_custom_model(self, model_info: dict):
+        model_meta = model_info.get("model_meta")
+        if model_meta is None:
+            return
+        if isinstance(model_meta, dict):
+            self._model_meta = ProjectMeta.from_json(model_meta)
+        elif isinstance(model_meta, str):
+            remote_artifacts_dir = model_info["artifacts_dir"]
+            model_meta_url = os.path.join(remote_artifacts_dir, model_meta)
+            model_meta_path = self.download(model_meta_url)
+            model_meta = sly_json.load_json_file(model_meta_path)
+            self._model_meta = ProjectMeta.from_json(model_meta)
+        else:
+            raise ValueError(
+                "model_meta should be a dict or a name of '.json' file in experiment artifacts folder in Team Files"
+            )
+        self._get_confidence_tag_meta()
+        self.classes = [obj_class.name for obj_class in self._model_meta.obj_classes]
+
+    def _set_checkpoint_info_custom_model(self, deploy_params: dict):
+        model_info = deploy_params.get("model_info", {})
+        model_files = deploy_params.get("model_files", {})
+        if model_info:
+            checkpoint_name = os.path.basename(model_files.get("checkpoint"))
+            self.checkpoint_info = CheckpointInfo(
+                checkpoint_name=checkpoint_name,
+                model_name=model_info.get("model_name"),
+                architecture=model_info.get("framework_name"),
+                custom_checkpoint_path=os.path.join(
+                    model_info.get("artifacts_dir"), checkpoint_name
+                ),
+                model_source=ModelSource.CUSTOM,
+            )
     def shutdown_model(self):
         self._model_served = False
         self.device = None
@@ -453,7 +638,7 @@
         pass
 
     def get_classes(self) -> List[str]:
-        raise NotImplementedError("Have to be implemented in child class after inheritance")
+        return self.classes
 
     def get_info(self) -> Dict[str, Any]:
         num_classes = None
@@ -550,6 +735,14 @@
         self._model_meta = ProjectMeta(classes)
         self._get_confidence_tag_meta()
 
+    def _set_model_meta_from_classes(self):
+        classes = self.get_classes()
+        if not classes:
+            raise ValueError("Can't create model meta. Please, set the `self.classes` attribute.")
+        shape = self._get_obj_class_shape()
+        self._model_meta = ProjectMeta([ObjClass(name, shape) for name in classes])
+        self._get_confidence_tag_meta()
+
     @property
     def task_id(self) -> int:
         return self._task_id
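With get_classes() now returning self.classes by default, a model served through the headless path only needs to populate self.classes; _set_model_meta_from_classes() then derives the ProjectMeta using the task-specific geometry from _get_obj_class_shape(). Roughly, assuming a detection-style Rectangle shape, the helper builds something like this sketch:

from supervisely import ObjClass, ProjectMeta
from supervisely.geometry.rectangle import Rectangle

classes = ["person", "car"]  # e.g. set by a subclass in load_model, or derived from model_meta
model_meta = ProjectMeta([ObjClass(name, Rectangle) for name in classes])  # shape is an assumption here
print([obj_class.name for obj_class in model_meta.obj_classes])  # ['person', 'car']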
@@ -2420,7 +2613,15 @@
             self.shutdown_model()
             state = request.state.state
             deploy_params = state["deploy_params"]
-            self._load_model(deploy_params)
+            if isinstance(self.gui, GUI.ServingGUITemplate):
+                model_files = self._download_model_files(
+                    deploy_params["model_source"], deploy_params["model_files"]
+                )
+                deploy_params["model_files"] = model_files
+                self._load_model_headless(**deploy_params)
+            elif isinstance(self.gui, GUI.ServingGUI):
+                self._load_model(deploy_params)
+
             self.set_params_to_gui(deploy_params)
             # update to set correct device
             device = deploy_params.get("device", "cpu")
supervisely/nn/training/gui/__init__.py (new file)

@@ -0,0 +1 @@
+
supervisely/nn/training/gui/classes_selector.py (new file)

@@ -0,0 +1,100 @@
+from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
+from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
+
+
+class ClassesSelector:
+    title = "Classes Selector"
+    description = (
+        "Select classes that will be used for training. "
+        "Supported shapes are Bitmap, Polygon, Rectangle."
+    )
+    lock_message = "Select training and validation splits to unlock"
+
+    def __init__(self, project_id: int, classes: list, app_options: dict = {}):
+        self.classes_table = ClassesTable(project_id=project_id)  # use dataset_ids
+        if len(classes) > 0:
+            self.classes_table.select_classes(classes)  # from app options
+        else:
+            self.classes_table.select_all()
+
+        if is_development() or is_debug_with_sly_net():
+            qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
+        else:
+            qa_stats_link = f"/projects/{project_id}/stats/datasets"
+
+        qa_stats_text = Text(
+            text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
+        )
+
+        self.validator_text = Text("")
+        self.validator_text.hide()
+        self.button = Button("Select")
+        container = Container(
+            [
+                qa_stats_text,
+                self.classes_table,
+                self.validator_text,
+                self.button,
+            ]
+        )
+        self.card = Card(
+            title=self.title,
+            description=self.description,
+            content=container,
+            lock_message=self.lock_message,
+            collapsable=app_options.get("collapsable", False),
+        )
+        self.card.lock()
+
+    @property
+    def widgets_to_disable(self) -> list:
+        return [self.classes_table]
+
+    def get_selected_classes(self) -> list:
+        return self.classes_table.get_selected_classes()
+
+    def set_classes(self, classes) -> None:
+        self.classes_table.select_classes(classes)
+
+    def select_all_classes(self) -> None:
+        self.classes_table.select_all()
+
+    def validate_step(self) -> bool:
+        self.validator_text.hide()
+
+        if len(self.classes_table.project_meta.obj_classes) == 0:
+            self.validator_text.set(text="Project has no classes", status="error")
+            self.validator_text.show()
+            return False
+
+        selected_classes = self.classes_table.get_selected_classes()
+        table_data = self.classes_table._table_data
+
+        empty_classes = [
+            row[0]["data"]
+            for row in table_data
+            if row[0]["data"] in selected_classes and row[2]["data"] == 0 and row[3]["data"] == 0
+        ]
+
+        n_classes = len(selected_classes)
+        if n_classes == 0:
+            self.validator_text.set(text="Please select at least one class", status="error")
+        else:
+            warning_text = ""
+            status = "success"
+            if empty_classes:
+                intersections = set(selected_classes).intersection(empty_classes)
+                if intersections:
+                    warning_text = (
+                        f". Selected class has no annotations: {', '.join(intersections)}"
+                        if len(intersections) == 1
+                        else f". Selected classes have no annotations: {', '.join(intersections)}"
+                    )
+                    status = "warning"
+
+            class_text = "class" if n_classes == 1 else "classes"
+            self.validator_text.set(
+                text=f"Selected {n_classes} {class_text}{warning_text}", status=status
+            )
+        self.validator_text.show()
+        return n_classes > 0
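The new ClassesSelector step is self-contained: it wraps a ClassesTable in a lockable Card, pre-selects the classes passed in from app options (or all classes when none are given), and exposes a validate_step() check that warns about selected classes without annotations. A hedged usage sketch (the project id is a placeholder; the import path follows the new file location in this release):

from supervisely.nn.training.gui.classes_selector import ClassesSelector

selector = ClassesSelector(project_id=12345, classes=[])  # empty list -> all classes pre-selected
layout = selector.card            # Card widget to place into the training GUI
selector.card.unlock()            # the card starts locked until train/val splits are chosen
if selector.validate_step():      # False if nothing is selected or the project has no classes
    train_classes = selector.get_selected_classes()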