supervisely 6.73.359__py3-none-any.whl → 6.73.360__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ from supervisely.project.download import is_cached
12
12
 
13
13
  class InputSelector:
14
14
  title = "Input project"
15
- description = "Selected project from which images and annotations will be downloaded"
15
+ description = "Selected project from which items and annotations will be downloaded"
16
16
  lock_message = None
17
17
 
18
18
  def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
@@ -36,6 +36,13 @@ class ModelSelector:
36
36
  self.display_widgets = []
37
37
  self.app_options = app_options
38
38
 
39
+ model_selector_opts = self.app_options.get("model_selector", {})
40
+ if not isinstance(model_selector_opts, dict):
41
+ model_selector_opts = {}
42
+
43
+ self.show_pretrained = True
44
+ self.show_custom = model_selector_opts.get("show_custom", True)
45
+
39
46
  self.team_id = sly_env.team_id()
40
47
  self.models = models
41
48
 
@@ -51,15 +58,28 @@ class ModelSelector:
51
58
  logger.warning(f"Legacy checkpoints are not available for '{framework}'")
52
59
 
53
60
  self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
61
+
62
+ tab_titles = []
63
+ tab_descriptions = []
64
+ tab_contents = []
65
+ if self.show_pretrained:
66
+ tab_titles.append(ModelSource.PRETRAINED)
67
+ tab_descriptions.append("Publicly available models")
68
+ tab_contents.append(self.pretrained_models_table)
69
+ if self.show_custom:
70
+ tab_titles.append(ModelSource.CUSTOM)
71
+ tab_descriptions.append("Models trained by you in Supervisely")
72
+ tab_contents.append(self.experiment_selector)
73
+
54
74
  self.model_source_tabs = RadioTabs(
55
- titles=[ModelSource.PRETRAINED, ModelSource.CUSTOM],
56
- descriptions=[
57
- "Publicly available models",
58
- "Models trained by you in Supervisely",
59
- ],
60
- contents=[self.pretrained_models_table, self.experiment_selector],
75
+ titles=tab_titles,
76
+ descriptions=tab_descriptions,
77
+ contents=tab_contents,
61
78
  )
62
79
 
80
+ if len(tab_titles) > 0:
81
+ self.model_source_tabs.set_active_tab(tab_titles[0])
82
+
63
83
  self.validator_text = Text("")
64
84
  self.validator_text.hide()
65
85
  self.button = Button("Select")
@@ -0,0 +1,105 @@
1
+ from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
2
+ from supervisely.app.widgets import Button, Card, Container, TagsTable, Text
3
+
4
+
5
class TagsSelector:
    """GUI step that lets the user pick the project tags used for training.

    The step is rendered as a ``Card`` that contains a "QA & Stats" link, a
    ``TagsTable`` bound to the project, a validation message and a "Select"
    button. The card is locked right after construction.
    """

    title = "Tags Selector"
    description = "Select tags that will be used for training"
    lock_message = "Select training and validation splits to unlock"

    def __init__(self, project_id: int, tags: list, app_options: dict = None):
        """Build the widgets of the step.

        Args:
            project_id: ID of the project whose tags are listed in the table.
            tags: Tags to pre-select. When empty (or ``None``) every tag in
                the table is selected instead.
            app_options: Optional application options; only the
                ``"collapsable"`` key is read here. Defaults to a fresh empty
                dict per instance (a shared ``{}`` default would leak state
                between instances because it is stored on ``self``).
        """
        # Init widgets
        self.qa_stats_text = None
        self.tags_table = None
        self.validator_text = None
        self.button = None
        self.container = None
        self.card = None
        # -------------------------------- #

        self.display_widgets = []
        self.app_options = {} if app_options is None else app_options

        # GUI Components
        # In development / debug mode the link is made absolute via abs_url();
        # otherwise a server-relative path is used.
        if is_development() or is_debug_with_sly_net():
            qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
        else:
            qa_stats_link = f"/projects/{project_id}/stats/datasets"
        self.qa_stats_text = Text(
            text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
        )

        self.tags_table = TagsTable(project_id=project_id)
        # Truthiness check also covers tags=None, which len() would reject.
        if tags:
            self.tags_table.select_tags(tags)
        else:
            self.tags_table.select_all()

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        self.display_widgets.extend(
            [self.qa_stats_text, self.tags_table, self.validator_text, self.button]
        )
        # -------------------------------- #

        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
            collapsable=self.app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step that are returned for disabling: the tags table."""
        return [self.tags_table]

    def get_selected_tags(self) -> list:
        """Return the tags currently selected in the table."""
        return self.tags_table.get_selected_tags()

    def set_tags(self, tags) -> None:
        """Select the given tags in the table."""
        self.tags_table.select_tags(tags)

    def select_all_tags(self) -> None:
        """Select every tag in the table."""
        self.tags_table.select_all()

    def validate_step(self) -> bool:
        """Validate the current selection and show the outcome in the GUI.

        Returns:
            ``False`` when the project has no tags or nothing is selected,
            ``True`` otherwise (including the "no annotations" warning case).
        """
        self.validator_text.hide()

        project_tags = self.tags_table.project_meta.tag_metas
        if len(project_tags) == 0:
            self.validator_text.set(text="Project has no tags", status="error")
            self.validator_text.show()
            return False

        selected_tags = self.tags_table.get_selected_tags()
        n_tags = len(selected_tags)
        if n_tags == 0:
            # Nothing selected: no need to scan the table for empty tags.
            self.validator_text.set(text="Please select at least one tag", status="error")
            self.validator_text.show()
            return False

        # NOTE(review): relies on the private TagsTable._table_data layout.
        # Column 0 is compared against the selected tags; columns 2 and 3 are
        # presumably per-tag usage counts (images / objects) - confirm against
        # the TagsTable implementation.
        # dict.fromkeys() removes duplicates while keeping table order, so the
        # message is deterministic (joining a set is not).
        empty_tags = list(
            dict.fromkeys(
                row[0]["data"]
                for row in self.tags_table._table_data
                if row[0]["data"] in selected_tags
                and row[2]["data"] == 0
                and row[3]["data"] == 0
            )
        )

        tag_text = "tag" if n_tags == 1 else "tags"
        message = f"Selected {n_tags} {tag_text}"
        status = "success"
        if empty_tags:
            tag_text = "tag" if len(empty_tags) == 1 else "tags"
            message += f". Selected {tag_text} have no annotations: {', '.join(empty_tags)}"
            status = "warning"

        self.validator_text.set(text=message, status=status)
        self.validator_text.show()
        return True
@@ -24,26 +24,51 @@ class TrainValSplitsSelector:
24
24
  self.project_id = project_id
25
25
 
26
26
  # GUI Components
27
- self.train_val_splits = TrainValSplits(project_id)
28
- train_val_dataset_ids = {"train": [], "val": []}
29
- for _, dataset in api.dataset.tree(project_id):
30
- if dataset.name.lower() == "train" or dataset.name.lower() == "training":
31
- if dataset.items_count > 0:
32
- train_val_dataset_ids["train"].append(dataset.id)
33
- elif dataset.name.lower() == "val" or dataset.name.lower() == "validation":
34
- if dataset.items_count > 0:
35
- train_val_dataset_ids["val"].append(dataset.id)
36
-
37
- # Check nested dataset names
27
+ split_methods = self.app_options.get("train_val_split_methods", [])
28
+ if len(split_methods) == 0:
29
+ split_methods = ["Random", "Based on tags", "Based on datasets"]
30
+ random_split = "Random" in split_methods
31
+ tag_split = "Based on tags" in split_methods
32
+ ds_split = "Based on datasets" in split_methods
33
+
34
+ self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split)
35
+
36
+ def _extend_with_nested(root_ds):
37
+ nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
38
+ nested_ids = [ds.id for ds in nested]
39
+ return [root_ds.id] + nested_ids
40
+
41
+ train_val_dataset_ids = {"train": set(), "val": set()}
42
+ for _, dataset in self.api.dataset.tree(self.project_id):
43
+ ds_name = dataset.name.lower()
44
+
45
+ if ds_name in {"train", "training"}:
46
+ for _id in _extend_with_nested(dataset):
47
+ train_val_dataset_ids["train"].add(_id)
48
+
49
+ elif ds_name in {"val", "validation", "test", "testing"}:
50
+ for _id in _extend_with_nested(dataset):
51
+ train_val_dataset_ids["val"].add(_id)
52
+
53
+ train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
54
+ train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
55
+
38
56
  train_count = len(train_val_dataset_ids["train"])
39
57
  val_count = len(train_val_dataset_ids["val"])
58
+
40
59
  if train_count > 0 and val_count > 0:
41
60
  self.train_val_splits.set_datasets_splits(
42
61
  train_val_dataset_ids["train"], train_val_dataset_ids["val"]
43
62
  )
44
63
 
45
64
  if train_count > 0 and val_count > 0:
46
- self.validator_text = Text("Train and val datasets are detected", status="info")
65
+ if train_count == val_count == 1:
66
+ self.validator_text = Text("train and val datasets are detected", status="info")
67
+ else:
68
+ self.validator_text = Text(
69
+ "Multiple train and val datasets are detected. Check manually if selection is correct",
70
+ status="info",
71
+ )
47
72
  self.validator_text.show()
48
73
  else:
49
74
  self.validator_text = Text("")
@@ -71,8 +96,9 @@ class TrainValSplitsSelector:
71
96
  split_method = self.train_val_splits.get_split_method()
72
97
  warning_text = "Using the same data for training and validation leads to overfitting, poor generalization and biased model selection."
73
98
  ensure_text = "Ensure this is intentional."
99
+ is_valid = False
74
100
 
75
- if split_method == "Random":
101
+ def validate_random_split():
76
102
  train_ratio = self.train_val_splits.get_train_split_percent()
77
103
  val_ratio = self.train_val_splits.get_val_split_percent()
78
104
 
@@ -109,8 +135,9 @@ class TrainValSplitsSelector:
109
135
  text="Train and validation splits are selected.",
110
136
  status="success",
111
137
  )
138
+ return True
112
139
 
113
- elif split_method == "Based on tags":
140
+ def validate_based_on_tags():
114
141
  train_tag = self.train_val_splits.get_train_tag()
115
142
  val_tag = self.train_val_splits.get_val_tag()
116
143
 
@@ -130,29 +157,42 @@ class TrainValSplitsSelector:
130
157
  else:
131
158
  tags_count[tag_name] = tag_total
132
159
 
133
- # @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
134
160
  if tags_count[train_tag] == 0:
135
161
  self.validator_text.set(
136
162
  text=f"Train tag '{train_tag}' is not present in any images. {ensure_text}",
137
163
  status="error",
138
164
  )
165
+ return False
139
166
  elif tags_count[val_tag] == 0:
140
167
  self.validator_text.set(
141
168
  text=f"Val tag '{val_tag}' is not present in any images. {ensure_text}",
142
169
  status="error",
143
170
  )
144
-
171
+ return False
145
172
  elif train_tag == val_tag:
146
173
  self.validator_text.set(
147
174
  text=f"Train and val tags are the same. {ensure_text} {warning_text}",
148
175
  status="warning",
149
176
  )
177
+ return True
150
178
  else:
151
179
  self.validator_text.set("Train and val tags are selected", status="success")
180
+ return True
152
181
 
153
- elif split_method == "Based on datasets":
182
+ def validate_based_on_datasets():
154
183
  train_dataset_id = self.get_train_dataset_ids()
155
184
  val_dataset_id = self.get_val_dataset_ids()
185
+ if train_dataset_id is None and val_dataset_id is None:
186
+ self.validator_text.set("No datasets are selected", status="error")
187
+ return False
188
+
189
+ if train_dataset_id is None:
190
+ self.validator_text.set("No train dataset is selected", status="error")
191
+ return False
192
+
193
+ if val_dataset_id is None:
194
+ self.validator_text.set("No val dataset is selected", status="error")
195
+ return False
156
196
 
157
197
  # Check if datasets are not empty
158
198
  stats = self.api.project.get_stats(self.project_id)
@@ -169,6 +209,14 @@ class TrainValSplitsSelector:
169
209
  empty_dataset_names.append(datasets_count[dataset_id]["name"])
170
210
 
171
211
  if len(empty_dataset_names) > 0:
212
+ if len(empty_dataset_names) == len(train_dataset_id + val_dataset_id):
213
+ empty_ds_text = f"All selected datasets are empty. {ensure_text}"
214
+ self.validator_text.set(
215
+ text=empty_ds_text,
216
+ status="error",
217
+ )
218
+ return False
219
+
172
220
  if len(empty_dataset_names) == 1:
173
221
  empty_ds_text = f"Selected dataset: {', '.join(empty_dataset_names)} is empty. {ensure_text}"
174
222
  else:
@@ -178,16 +226,30 @@ class TrainValSplitsSelector:
178
226
  text=empty_ds_text,
179
227
  status="error",
180
228
  )
229
+ return True
181
230
 
182
231
  elif train_dataset_id == val_dataset_id:
183
232
  self.validator_text.set(
184
233
  text=f"Same datasets are selected for both train and val splits. {ensure_text} {warning_text}",
185
234
  status="warning",
186
235
  )
236
+ return True
187
237
  else:
188
238
  self.validator_text.set("Train and val datasets are selected", status="success")
239
+ return True
240
+
241
+ if split_method == "Random":
242
+ is_valid = validate_random_split()
243
+
244
+ elif split_method == "Based on tags":
245
+ is_valid = validate_based_on_tags()
246
+
247
+ elif split_method == "Based on datasets":
248
+ is_valid = validate_based_on_datasets()
249
+
250
+ # @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
189
251
  self.validator_text.show()
190
- return True
252
+ return is_valid
191
253
 
192
254
  def set_sly_project(self, project: Project) -> None:
193
255
  self.train_val_splits._project_fs = project