supervisely-6.73.278-py3-none-any.whl → supervisely-6.73.280-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

@@ -11,6 +11,9 @@ import supervisely.io.env as sly_env
11
11
  from supervisely import Api, ProjectMeta
12
12
  from supervisely._utils import is_production
13
13
  from supervisely.app.widgets import Stepper, Widget
14
+ from supervisely.geometry.polygon import Polygon
15
+ from supervisely.geometry.rectangle import Rectangle
16
+ from supervisely.nn.task_type import TaskType
14
17
  from supervisely.nn.training.gui.classes_selector import ClassesSelector
15
18
  from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
16
19
  from supervisely.nn.training.gui.input_selector import InputSelector
@@ -61,6 +64,7 @@ class TrainGUI:
61
64
  self.hyperparameters = hyperparameters
62
65
  self.app_options = app_options
63
66
  self.collapsable = app_options.get("collapsable", False)
67
+ self.need_convert_shapes_for_bm = False
64
68
 
65
69
  self.team_id = sly_env.team_id(raise_not_found=False)
66
70
  self.workspace_id = sly_env.workspace_id(raise_not_found=False)
@@ -137,6 +141,35 @@ class TrainGUI:
137
141
  return
138
142
  self.training_process.set_experiment_name(experiment_name)
139
143
 
144
+ def need_convert_class_shapes() -> bool:
145
+ if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
146
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
147
+ self.need_convert_shapes_for_bm = False
148
+ else:
149
+ task_type = self.model_selector.get_selected_task_type()
150
+
151
+ def _need_convert(shape):
152
+ if task_type == TaskType.OBJECT_DETECTION:
153
+ return shape != Rectangle.geometry_name()
154
+ elif task_type in [
155
+ TaskType.INSTANCE_SEGMENTATION,
156
+ TaskType.SEMANTIC_SEGMENTATION,
157
+ ]:
158
+ return shape == Polygon.geometry_name()
159
+ return
160
+
161
+ data = self.classes_selector.classes_table._table_data
162
+ selected_classes = set(self.classes_selector.classes_table.get_selected_classes())
163
+ empty = set(r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0)
164
+ need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
165
+
166
+ if need_convert.intersection(selected_classes - empty):
167
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
168
+ self.need_convert_shapes_for_bm = True
169
+ else:
170
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
171
+ self.need_convert_shapes_for_bm = False
172
+
140
173
  # ------------------------------------------------- #
141
174
 
142
175
  # Wrappers
@@ -168,7 +201,7 @@ class TrainGUI:
168
201
  callback=self.hyperparameters_selector_cb,
169
202
  validation_text=self.model_selector.validator_text,
170
203
  validation_func=self.model_selector.validate_step,
171
- on_select_click=[set_experiment_name],
204
+ on_select_click=[set_experiment_name, need_convert_class_shapes],
172
205
  collapse_card=(self.model_selector.card, self.collapsable),
173
206
  )
174
207
 
@@ -276,6 +309,7 @@ class TrainGUI:
276
309
  @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
277
310
  def show_mb_speedtest(is_checked: bool):
278
311
  self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
312
+ need_convert_class_shapes()
279
313
 
280
314
  # ------------------------------------------------- #
281
315
 
@@ -44,8 +44,18 @@ class HyperparametersSelector:
44
44
  self.model_benchmark_learn_more = Text(
45
45
  f"Learn more about Model Benchmark in the {docs_link}.", status="info"
46
46
  )
47
+ self.model_benchmark_auto_convert_warning = Text(
48
+ text="Project will be automatically converted according to CV task for Model Evaluation.",
49
+ status="warning",
50
+ )
51
+ self.model_benchmark_auto_convert_warning.hide()
52
+
47
53
  self.display_widgets.extend(
48
- [self.model_benchmark_field, self.model_benchmark_learn_more]
54
+ [
55
+ self.model_benchmark_field,
56
+ self.model_benchmark_learn_more,
57
+ self.model_benchmark_auto_convert_warning,
58
+ ]
49
59
  )
50
60
  # -------------------------------- #
51
61
 
@@ -12,9 +12,9 @@ from supervisely.app.widgets import (
12
12
  RadioTabs,
13
13
  Text,
14
14
  )
15
+ from supervisely.nn.artifacts.utils import FrameworkMapper
15
16
  from supervisely.nn.experiments import get_experiment_infos
16
17
  from supervisely.nn.utils import ModelSource
17
- from supervisely.nn.artifacts.utils import FrameworkMapper, FrameworkName
18
18
 
19
19
 
20
20
  class ModelSelector:
@@ -37,7 +37,7 @@ class ModelSelector:
37
37
  legacy_experiment_infos = framework_cls.get_list_experiment_info()
38
38
  experiment_infos = experiment_infos + legacy_experiment_infos
39
39
  except:
40
- logger.warn(f"Legacy checkpoints are not available for '{framework}'")
40
+ logger.warning(f"Legacy checkpoints are not available for '{framework}'")
41
41
 
42
42
  self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
43
43
  self.model_source_tabs = RadioTabs(
@@ -106,3 +106,9 @@ class ModelSelector:
106
106
  self.validator_text.set(text="Model is selected", status="success")
107
107
  self.validator_text.show()
108
108
  return True
109
+
110
+ def get_selected_task_type(self) -> str:
111
+ if self.get_model_source() == ModelSource.PRETRAINED:
112
+ return self.pretrained_models_table.get_selected_task_type()
113
+ else:
114
+ return self.experiment_selector.get_selected_task_type()
@@ -25,6 +25,7 @@ import supervisely.io.json as sly_json
25
25
  from supervisely import (
26
26
  Api,
27
27
  Application,
28
+ Dataset,
28
29
  DatasetInfo,
29
30
  OpenMode,
30
31
  Project,
@@ -32,6 +33,7 @@ from supervisely import (
32
33
  ProjectMeta,
33
34
  WorkflowMeta,
34
35
  WorkflowSettings,
36
+ batched,
35
37
  download_project,
36
38
  is_development,
37
39
  is_production,
@@ -340,22 +342,6 @@ class TrainApp:
340
342
  """
341
343
  return self.gui.model_selector.get_model_info()
342
344
 
343
- @property
344
- def model_meta(self) -> ProjectMeta:
345
- """
346
- Returns the model metadata.
347
-
348
- :return: Model metadata.
349
- :rtype: dict
350
- """
351
- project_meta_json = self.project_meta.to_json()
352
- model_meta = {
353
- "classes": [
354
- item for item in project_meta_json["classes"] if item["title"] in self.classes
355
- ]
356
- }
357
- return ProjectMeta.from_json(model_meta)
358
-
359
345
  @property
360
346
  def device(self) -> str:
361
347
  """
@@ -496,7 +482,6 @@ class TrainApp:
496
482
  downloading project and model data.
497
483
  """
498
484
  logger.info("Preparing for training")
499
- self.gui.disable_select_buttons()
500
485
 
501
486
  # Step 1. Workflow Input
502
487
  if is_production():
@@ -505,8 +490,7 @@ class TrainApp:
505
490
  self._download_project()
506
491
  # Step 3. Split Project
507
492
  self._split_project()
508
- # Step 4. Convert Supervisely to X format
509
- # Step 5. Download Model files
493
+ # Step 4. Download Model files
510
494
  self._download_model()
511
495
 
512
496
  def _finalize(self, experiment_info: dict) -> None:
@@ -518,31 +502,41 @@ class TrainApp:
518
502
  :type experiment_info: dict
519
503
  """
520
504
  logger.info("Finalizing training")
505
+ # Step 1. Validate experiment TaskType
506
+ experiment_info = self._validate_experiment_task_type(experiment_info)
521
507
 
522
- # Step 1. Validate experiment_info
508
+ # Step 2. Validate experiment_info
523
509
  success, reason = self._validate_experiment_info(experiment_info)
524
510
  if not success:
525
511
  raise ValueError(f"{reason}. Failed to upload artifacts")
526
512
 
527
- # Step 2. Preprocess artifacts
513
+ # Step 3. Preprocess artifacts
528
514
  experiment_info = self._preprocess_artifacts(experiment_info)
529
515
 
530
- # Step3. Postprocess splits
531
- splits_data = self._postprocess_splits()
516
+ # Step 4. Postprocess splits
517
+ train_splits_data = self._postprocess_splits()
532
518
 
533
- # Step 3. Upload artifacts
519
+ # Step 5. Upload artifacts
534
520
  self._set_text_status("uploading")
535
521
  remote_dir, file_info = self._upload_artifacts()
536
522
 
537
- # Step 4. Run Model Benchmark
538
- mb_eval_lnk_file_info, mb_eval_report, mb_eval_report_id, eval_metrics = (
539
- None,
540
- None,
541
- None,
542
- {},
543
- )
523
+ # Step 6. Create model meta according to model CV task type
524
+ model_meta = self.create_model_meta(experiment_info["task_type"])
525
+
526
+ # Step 7. [Optional] Run Model Benchmark
527
+ mb_eval_lnk_file_info, mb_eval_report = None, None
528
+ mb_eval_report_id, eval_metrics = None, {}
544
529
  if self.is_model_benchmark_enabled:
545
530
  try:
531
+ # Convert GT project
532
+ gt_project_id, bm_splits_data = None, train_splits_data
533
+ if self._app_options.get("auto_convert_classes", True):
534
+ if self.gui.need_convert_shapes_for_bm:
535
+ self._set_text_status("convert_gt_project")
536
+ gt_project_id, bm_splits_data = self._convert_and_split_gt_project(
537
+ experiment_info["task_type"]
538
+ )
539
+
546
540
  self._set_text_status("benchmark")
547
541
  (
548
542
  mb_eval_lnk_file_info,
@@ -550,12 +544,17 @@ class TrainApp:
550
544
  mb_eval_report_id,
551
545
  eval_metrics,
552
546
  ) = self._run_model_benchmark(
553
- self.output_dir, remote_dir, experiment_info, splits_data
547
+ self.output_dir,
548
+ remote_dir,
549
+ experiment_info,
550
+ bm_splits_data,
551
+ model_meta,
552
+ gt_project_id,
554
553
  )
555
554
  except Exception as e:
556
555
  logger.error(f"Model benchmark failed: {e}")
557
556
 
558
- # Step 5. [Optional] Convert weights
557
+ # Step 8. [Optional] Convert weights
559
558
  export_weights = {}
560
559
  if self.gui.hyperparameters_selector.is_export_required():
561
560
  try:
@@ -564,23 +563,23 @@ class TrainApp:
564
563
  except Exception as e:
565
564
  logger.error(f"Export weights failed: {e}")
566
565
 
567
- # Step 6. Generate and upload additional files
566
+ # Step 9. Generate and upload additional files
568
567
  self._set_text_status("metadata")
569
568
  self._generate_experiment_info(
570
569
  remote_dir, experiment_info, eval_metrics, mb_eval_report_id, export_weights
571
570
  )
572
571
  self._generate_app_state(remote_dir, experiment_info)
573
572
  self._generate_hyperparameters(remote_dir, experiment_info)
574
- self._generate_train_val_splits(remote_dir, splits_data)
575
- self._generate_model_meta(remote_dir, experiment_info)
573
+ self._generate_train_val_splits(remote_dir, train_splits_data)
574
+ self._generate_model_meta(remote_dir, model_meta)
576
575
  self._upload_demo_files(remote_dir)
577
576
 
578
- # Step 7. Set output widgets
577
+ # Step 10. Set output widgets
579
578
  self._set_text_status("reset")
580
579
  self._set_training_output(remote_dir, file_info, mb_eval_report)
581
580
  self._set_ws_progress_status("completed")
582
581
 
583
- # Step 8. Workflow output
582
+ # Step 11. Workflow output
584
583
  if is_production():
585
584
  self._workflow_output(remote_dir, file_info, mb_eval_lnk_file_info, mb_eval_report_id)
586
585
 
@@ -1120,6 +1119,24 @@ class TrainApp:
1120
1119
  # ----------------------------------------- #
1121
1120
 
1122
1121
  # Postprocess
1122
+ def _validate_experiment_task_type(self, experiment_info: dict) -> dict:
1123
+ """
1124
+ Checks if the task_type key if returned from the user's training function.
1125
+ If not, it will be set to the task type of the model selected in the model selector.
1126
+
1127
+ :param experiment_info: Information about the experiment results.
1128
+ :type experiment_info: dict
1129
+ :return: Experiment info with task_type key.
1130
+ :rtype: dict
1131
+ """
1132
+ task_type = experiment_info.get("task_type", None)
1133
+ if task_type is None:
1134
+ logger.debug(
1135
+ "Task type not found in experiment_info. Task type from model config will be used."
1136
+ )
1137
+ task_type = self.gui.model_selector.get_selected_task_type()
1138
+ experiment_info["task_type"] = task_type
1139
+ return experiment_info
1123
1140
 
1124
1141
  def _validate_experiment_info(self, experiment_info: dict) -> tuple:
1125
1142
  """
@@ -1200,9 +1217,14 @@ class TrainApp:
1200
1217
  logger.debug("Validation successful")
1201
1218
  return True, None
1202
1219
 
1203
- def _postprocess_splits(self) -> dict:
1220
+ def _postprocess_splits(self, project_id: Optional[int] = None) -> dict:
1204
1221
  """
1205
1222
  Processes the train and val splits to generate the necessary data for the experiment_info.json file.
1223
+
1224
+ :param project_id: ID of the ground truth project for model benchmark. Provide only when cv task convertion is required.
1225
+ :type project_id: Optional[int]
1226
+ :return: Splits data.
1227
+ :rtype: dict
1206
1228
  """
1207
1229
  val_dataset_ids = None
1208
1230
  val_images_ids = None
@@ -1212,10 +1234,30 @@ class TrainApp:
1212
1234
  split_method = self.gui.train_val_splits_selector.get_split_method()
1213
1235
  train_set, val_set = self._train_split, self._val_split
1214
1236
  if split_method == "Based on datasets":
1215
- val_dataset_ids = self.gui.train_val_splits_selector.get_val_dataset_ids()
1216
- train_dataset_ids = self.gui.train_val_splits_selector.get_train_dataset_ids
1237
+ if project_id is None:
1238
+ val_dataset_ids = self.gui.train_val_splits_selector.get_val_dataset_ids()
1239
+ train_dataset_ids = self.gui.train_val_splits_selector.get_train_dataset_ids()
1240
+ else:
1241
+ src_datasets_map = {
1242
+ dataset.id: dataset
1243
+ for _, dataset in self._api.dataset.tree(self.project_info.id)
1244
+ }
1245
+ val_dataset_ids = self.gui.train_val_splits_selector.get_val_dataset_ids()
1246
+ train_dataset_ids = self.gui.train_val_splits_selector.get_train_dataset_ids()
1247
+
1248
+ train_dataset_names = [src_datasets_map[ds_id].name for ds_id in train_dataset_ids]
1249
+ val_dataset_names = [src_datasets_map[ds_id].name for ds_id in val_dataset_ids]
1250
+
1251
+ gt_datasets_map = {
1252
+ dataset.name: dataset.id for _, dataset in self._api.dataset.tree(project_id)
1253
+ }
1254
+ train_dataset_ids = [gt_datasets_map[ds_name] for ds_name in train_dataset_names]
1255
+ val_dataset_ids = [gt_datasets_map[ds_name] for ds_name in val_dataset_names]
1217
1256
  else:
1218
- dataset_infos = [dataset for _, dataset in self._api.dataset.tree(self.project_id)]
1257
+ if project_id is None:
1258
+ project_id = self.project_id
1259
+
1260
+ dataset_infos = [dataset for _, dataset in self._api.dataset.tree(project_id)]
1219
1261
  ds_infos_dict = {}
1220
1262
  for dataset in dataset_infos:
1221
1263
  if dataset.parent_id is not None:
@@ -1232,18 +1274,19 @@ class TrainApp:
1232
1274
  image_infos = []
1233
1275
  for dataset_name, image_names in image_names_per_dataset.items():
1234
1276
  ds_info = ds_infos_dict[dataset_name]
1235
- image_infos.extend(
1236
- self._api.image.get_list(
1237
- ds_info.id,
1238
- filters=[
1239
- {
1240
- "field": "name",
1241
- "operator": "in",
1242
- "value": image_names,
1243
- }
1244
- ],
1277
+ for names_batch in batched(image_names, 200):
1278
+ image_infos.extend(
1279
+ self._api.image.get_list(
1280
+ ds_info.id,
1281
+ filters=[
1282
+ {
1283
+ "field": "name",
1284
+ "operator": "in",
1285
+ "value": names_batch,
1286
+ }
1287
+ ],
1288
+ )
1245
1289
  )
1246
- )
1247
1290
  return image_infos
1248
1291
 
1249
1292
  val_image_infos = get_image_infos_by_split(ds_infos_dict, val_set)
@@ -1373,7 +1416,7 @@ class TrainApp:
1373
1416
  f"Uploading '{self._train_val_split_file}' to Team Files",
1374
1417
  )
1375
1418
 
1376
- def _generate_model_meta(self, remote_dir: str, experiment_info: dict) -> None:
1419
+ def _generate_model_meta(self, remote_dir: str, model_meta: ProjectMeta) -> None:
1377
1420
  """
1378
1421
  Generates and uploads the model_meta.json file to the output directory.
1379
1422
 
@@ -1382,17 +1425,31 @@ class TrainApp:
1382
1425
  :param experiment_info: Information about the experiment results.
1383
1426
  :type experiment_info: dict
1384
1427
  """
1385
- # @TODO: Handle tags for classification tasks
1386
1428
  local_path = join(self.output_dir, self._model_meta_file)
1387
1429
  remote_path = join(remote_dir, self._model_meta_file)
1388
1430
 
1389
- sly_json.dump_json_file(self.model_meta.to_json(), local_path)
1431
+ sly_json.dump_json_file(model_meta.to_json(), local_path)
1390
1432
  self._upload_file_to_team_files(
1391
1433
  local_path,
1392
1434
  remote_path,
1393
1435
  f"Uploading '{self._model_meta_file}' to Team Files",
1394
1436
  )
1395
1437
 
1438
+ def create_model_meta(self, task_type: str):
1439
+ """
1440
+ Convert project meta according to task type.
1441
+ """
1442
+ names_to_delete = [
1443
+ c.name for c in self.project_meta.obj_classes if c.name not in self.classes
1444
+ ]
1445
+ model_meta = self.project_meta.delete_obj_classes(names_to_delete)
1446
+
1447
+ if task_type == TaskType.OBJECT_DETECTION:
1448
+ model_meta, _ = model_meta.to_detection_task(True)
1449
+ elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
1450
+ model_meta, _ = model_meta.to_segmentation_task() # @TODO: check background class
1451
+ return model_meta
1452
+
1396
1453
  def _generate_experiment_info(
1397
1454
  self,
1398
1455
  remote_dir: str,
@@ -1740,6 +1797,8 @@ class TrainApp:
1740
1797
  remote_artifacts_dir: str,
1741
1798
  experiment_info: dict,
1742
1799
  splits_data: dict,
1800
+ model_meta: ProjectInfo,
1801
+ gt_project_id: int = None,
1743
1802
  ) -> tuple:
1744
1803
  """
1745
1804
  Runs the Model Benchmark evaluation process. Model benchmark runs only in production mode.
@@ -1752,6 +1811,10 @@ class TrainApp:
1752
1811
  :type experiment_info: dict
1753
1812
  :param splits_data: Information about the train and val splits.
1754
1813
  :type splits_data: dict
1814
+ :param model_meta: Model meta with object classes.
1815
+ :type model_meta: ProjectInfo
1816
+ :param gt_project_id: Ground truth project ID with converted shapes.
1817
+ :type gt_project_id: int
1755
1818
  :return: Evaluation report, report ID and evaluation metrics.
1756
1819
  :rtype: tuple
1757
1820
  """
@@ -1767,6 +1830,7 @@ class TrainApp:
1767
1830
  supported_task_types = [
1768
1831
  TaskType.OBJECT_DETECTION,
1769
1832
  TaskType.INSTANCE_SEGMENTATION,
1833
+ TaskType.SEMANTIC_SEGMENTATION,
1770
1834
  ]
1771
1835
  task_type = experiment_info["task_type"]
1772
1836
  if task_type not in supported_task_types:
@@ -1807,7 +1871,7 @@ class TrainApp:
1807
1871
  "artifacts_dir": remote_artifacts_dir,
1808
1872
  "model_name": experiment_info["model_name"],
1809
1873
  "framework_name": self.framework_name,
1810
- "model_meta": self.model_meta.to_json(),
1874
+ "model_meta": model_meta.to_json(),
1811
1875
  }
1812
1876
 
1813
1877
  logger.info(f"Deploy parameters: {self._benchmark_params}")
@@ -1827,12 +1891,15 @@ class TrainApp:
1827
1891
  train_images_ids = splits_data["train"]["images_ids"]
1828
1892
 
1829
1893
  bm = None
1894
+ if gt_project_id is None:
1895
+ gt_project_id = self.project_info.id
1896
+
1830
1897
  if task_type == TaskType.OBJECT_DETECTION:
1831
1898
  eval_params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
1832
1899
  eval_params = yaml.safe_load(eval_params)
1833
1900
  bm = ObjectDetectionBenchmark(
1834
1901
  self._api,
1835
- self.project_info.id,
1902
+ gt_project_id,
1836
1903
  output_dir=benchmark_dir,
1837
1904
  gt_dataset_ids=benchmark_dataset_ids,
1838
1905
  gt_images_ids=benchmark_images_ids,
@@ -1846,7 +1913,7 @@ class TrainApp:
1846
1913
  eval_params = yaml.safe_load(eval_params)
1847
1914
  bm = InstanceSegmentationBenchmark(
1848
1915
  self._api,
1849
- self.project_info.id,
1916
+ gt_project_id,
1850
1917
  output_dir=benchmark_dir,
1851
1918
  gt_dataset_ids=benchmark_dataset_ids,
1852
1919
  gt_images_ids=benchmark_images_ids,
@@ -1860,7 +1927,7 @@ class TrainApp:
1860
1927
  eval_params = yaml.safe_load(eval_params)
1861
1928
  bm = SemanticSegmentationBenchmark(
1862
1929
  self._api,
1863
- self.project_info.id,
1930
+ gt_project_id,
1864
1931
  output_dir=benchmark_dir,
1865
1932
  gt_dataset_ids=benchmark_dataset_ids,
1866
1933
  gt_images_ids=benchmark_images_ids,
@@ -2168,6 +2235,7 @@ class TrainApp:
2168
2235
  Wrapper function to wrap the training process.
2169
2236
  """
2170
2237
  experiment_info = None
2238
+ check_logs_text = "Please check the logs for more details."
2171
2239
 
2172
2240
  try:
2173
2241
  self._set_train_widgets_state_on_start()
@@ -2176,7 +2244,7 @@ class TrainApp:
2176
2244
  self._prepare_working_dir()
2177
2245
  self._init_logger()
2178
2246
  except Exception as e:
2179
- message = "Error occurred during training initialization. Please check the logs for more details."
2247
+ message = f"Error occurred during training initialization. {check_logs_text}"
2180
2248
  self._show_error(message, e)
2181
2249
  self._restore_train_widgets_state_on_error()
2182
2250
  self._set_ws_progress_status("reset")
@@ -2187,9 +2255,7 @@ class TrainApp:
2187
2255
  self._set_ws_progress_status("preparing")
2188
2256
  self._prepare()
2189
2257
  except Exception as e:
2190
- message = (
2191
- "Error occurred during data preparation. Please check the logs for more details."
2192
- )
2258
+ message = f"Error occurred during data preparation. {check_logs_text}"
2193
2259
  self._show_error(message, e)
2194
2260
  self._restore_train_widgets_state_on_error()
2195
2261
  self._set_ws_progress_status("reset")
@@ -2200,8 +2266,18 @@ class TrainApp:
2200
2266
  if self._app_options.get("train_logger", None) is None:
2201
2267
  self._set_ws_progress_status("training")
2202
2268
  experiment_info = self._train_func()
2269
+ except ZeroDivisionError as e:
2270
+ message = (
2271
+ "'ZeroDivisionError' occurred during training. "
2272
+ "The error was caused by an insufficient dataset size relative to the specified batch size in hyperparameters. "
2273
+ "Please check input data and hyperparameters."
2274
+ )
2275
+ self._show_error(message, e)
2276
+ self._restore_train_widgets_state_on_error()
2277
+ self._set_ws_progress_status("reset")
2278
+ return
2203
2279
  except Exception as e:
2204
- message = "Error occurred during training. Please check the logs for more details."
2280
+ message = f"Error occurred during training. {check_logs_text}"
2205
2281
  self._show_error(message, e)
2206
2282
  self._restore_train_widgets_state_on_error()
2207
2283
  self._set_ws_progress_status("reset")
@@ -2213,7 +2289,7 @@ class TrainApp:
2213
2289
  self._finalize(experiment_info)
2214
2290
  self.gui.training_process.start_button.loading = False
2215
2291
  except Exception as e:
2216
- message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
2292
+ message = f"Error occurred during finalizing and uploading training artifacts. {check_logs_text}"
2217
2293
  self._show_error(message, e)
2218
2294
  self._restore_train_widgets_state_on_error()
2219
2295
  self._set_ws_progress_status("reset")
@@ -2230,6 +2306,7 @@ class TrainApp:
2230
2306
  self._restore_train_widgets_state_on_error()
2231
2307
 
2232
2308
  def _set_train_widgets_state_on_start(self):
2309
+ self.gui.disable_select_buttons()
2233
2310
  self.gui.training_artifacts.validator_text.hide()
2234
2311
  self._validate_experiment_name()
2235
2312
  self.gui.training_process.experiment_name_input.disable()
@@ -2255,6 +2332,7 @@ class TrainApp:
2255
2332
  if self._app_options.get("device_selector", False):
2256
2333
  self.gui.training_process.select_device._select.enable()
2257
2334
  self.gui.training_process.select_device.enable()
2335
+ self.gui.enable_select_buttons()
2258
2336
 
2259
2337
  def _validate_experiment_name(self) -> bool:
2260
2338
  experiment_name = self.gui.training_process.get_experiment_name()
@@ -2280,6 +2358,7 @@ class TrainApp:
2280
2358
  "metadata",
2281
2359
  "export_onnx",
2282
2360
  "export_trt",
2361
+ "convert_gt_project",
2283
2362
  ],
2284
2363
  ):
2285
2364
 
@@ -2313,6 +2392,8 @@ class TrainApp:
2313
2392
  self.gui.training_process.validator_text.set("Validating experiment...", "info")
2314
2393
  elif status == "metadata":
2315
2394
  self.gui.training_process.validator_text.set("Generating training metadata...", "info")
2395
+ elif status == "convert_gt_project":
2396
+ self.gui.training_process.validator_text.set("Converting GT project...", "info")
2316
2397
 
2317
2398
  def _set_ws_progress_status(
2318
2399
  self,
@@ -2391,3 +2472,55 @@ class TrainApp:
2391
2472
  for runtime, path in export_weights.items()
2392
2473
  }
2393
2474
  return remote_export_weights
2475
+
2476
+ def _convert_and_split_gt_project(self, task_type: str):
2477
+ # 1. Convert GT project to cv task
2478
+ Project.download(
2479
+ self._api, self.project_info.id, "tmp_project", save_images=False, save_image_info=True
2480
+ )
2481
+ project = Project("tmp_project", OpenMode.READ)
2482
+
2483
+ pr_prefix = ""
2484
+ if task_type == TaskType.OBJECT_DETECTION:
2485
+ Project.to_detection_task(project.directory, inplace=True)
2486
+ pr_prefix = "[detection]: "
2487
+ # @TODO: dont convert segmentation?
2488
+ elif (
2489
+ task_type == TaskType.INSTANCE_SEGMENTATION
2490
+ or task_type == TaskType.SEMANTIC_SEGMENTATION
2491
+ ):
2492
+ Project.to_segmentation_task(project.directory, inplace=True)
2493
+ pr_prefix = "[segmentation]: "
2494
+
2495
+ gt_project_info = self._api.project.create(
2496
+ self.workspace_id,
2497
+ f"{pr_prefix}{self.project_info.name}",
2498
+ description=(
2499
+ f"Converted ground truth project for trainig session: '{self.task_id}'. "
2500
+ f"Original project id: '{self.project_info.id}. "
2501
+ "Removing this project will affect model benchmark evaluation report."
2502
+ ),
2503
+ change_name_if_conflict=True,
2504
+ )
2505
+
2506
+ # 3. Upload converted gt project
2507
+ project = Project("tmp_project", OpenMode.READ)
2508
+ self._api.project.update_meta(gt_project_info.id, project.meta)
2509
+ for dataset in project.datasets:
2510
+ dataset: Dataset
2511
+ ds_info = self._api.dataset.create(
2512
+ gt_project_info.id, dataset.name, change_name_if_conflict=True
2513
+ )
2514
+ for batch_names in batched(dataset.get_items_names(), 100):
2515
+ img_infos = [dataset.get_item_info(name) for name in batch_names]
2516
+ img_ids = [img_info.id for img_info in img_infos]
2517
+ anns = [dataset.get_ann(name, project.meta) for name in batch_names]
2518
+
2519
+ img_infos = self._api.image.copy_batch(ds_info.id, img_ids)
2520
+ img_ids = [img_info.id for img_info in img_infos]
2521
+ self._api.annotation.upload_anns(img_ids, anns)
2522
+ sly_fs.remove_dir(project.directory)
2523
+
2524
+ # 4. Match splits with original project
2525
+ gt_split_data = self._postprocess_splits(gt_project_info.id)
2526
+ return gt_project_info.id, gt_split_data
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: supervisely
3
- Version: 6.73.278
3
+ Version: 6.73.280
4
4
  Summary: Supervisely Python SDK.
5
5
  Home-page: https://github.com/supervisely/supervisely
6
6
  Author: Supervisely
@@ -968,13 +968,13 @@ supervisely/nn/tracker/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
968
968
  supervisely/nn/tracker/utils/gmc.py,sha256=3JX8979H3NA-YHNaRQyj9Z-xb9qtyMittPEjGw8y2Jo,11557
969
969
  supervisely/nn/tracker/utils/kalman_filter.py,sha256=eSFmCjM0mikHCAFvj-KCVzw-0Jxpoc3Cfc2NWEjJC1Q,17268
970
970
  supervisely/nn/training/__init__.py,sha256=gY4PCykJ-42MWKsqb9kl-skemKa8yB6t_fb5kzqR66U,111
971
- supervisely/nn/training/train_app.py,sha256=ropUF_M9RfijQ3XheqEtYl0Soix-69CgZeOnYiCIuI4,95088
971
+ supervisely/nn/training/train_app.py,sha256=PZ4zWMYRvOFj97vy2rOofCBYqnpkDtmouzFTjs9UyN4,101747
972
972
  supervisely/nn/training/gui/__init__.py,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
973
973
  supervisely/nn/training/gui/classes_selector.py,sha256=8UgzA4aogOAr1s42smwEcDbgaBj_i0JLhjwlZ9bFdIA,3772
974
- supervisely/nn/training/gui/gui.py,sha256=nj4EVppoV9ZjLN0rVO0GKxmI56d6Qpp0qwnJJ6srT6w,23712
975
- supervisely/nn/training/gui/hyperparameters_selector.py,sha256=2qryuBss0bLcZJV8PNJ6_hKZM5Dbj2FIxTb3EULHQrE,6670
974
+ supervisely/nn/training/gui/gui.py,sha256=CnT_QhihrxdSHKybpI0pXhPLwCaXEana_qdn0DhXByg,25558
975
+ supervisely/nn/training/gui/hyperparameters_selector.py,sha256=UAXZYyhuUOY7d2ZKAx4R5Kz-KQaiFZ7AnY8BDoj3_30,7071
976
976
  supervisely/nn/training/gui/input_selector.py,sha256=Jp9PnVVADv1fhndPuZdMlKuzWTOBQZogrOks5dwATlc,2179
977
- supervisely/nn/training/gui/model_selector.py,sha256=QTFHMf-8-rREYPk64QKoRvE4zKPC8V6tcP4H4N6nyt0,4082
977
+ supervisely/nn/training/gui/model_selector.py,sha256=n2Xn6as60bNPtSlImJtyrVEo0gjKnvHLT3yq_m39TXk,4334
978
978
  supervisely/nn/training/gui/train_val_splits_selector.py,sha256=MLryFD2Tj_RobkFzZOeQXzXpch0eGiVFisq3FGA3dFg,8549
979
979
  supervisely/nn/training/gui/training_artifacts.py,sha256=UpKI68S0h_nT_CEEKxBi1oeRsYVnocxRZZD4kUEnQ80,9584
980
980
  supervisely/nn/training/gui/training_logs.py,sha256=1CBqnL0l5kiZVaegJ-NLgOVI1T4EDB_rLAtumuw18Jo,3222
@@ -1070,9 +1070,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
1070
1070
  supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
1071
1071
  supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
1072
1072
  supervisely_lib/__init__.py,sha256=7-3QnN8Zf0wj8NCr2oJmqoQWMKKPKTECvjH9pd2S5vY,159
1073
- supervisely-6.73.278.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1074
- supervisely-6.73.278.dist-info/METADATA,sha256=hb6AM2qZI9n04Q_wTnLDxK9HzaYVWV1LK2Cp8hBcy7o,33573
1075
- supervisely-6.73.278.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
1076
- supervisely-6.73.278.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
1077
- supervisely-6.73.278.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
1078
- supervisely-6.73.278.dist-info/RECORD,,
1073
+ supervisely-6.73.280.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1074
+ supervisely-6.73.280.dist-info/METADATA,sha256=xY-ujb2oWVk6XMCZER18NgoGnfZYef8Lt6UAzVtDvkI,33573
1075
+ supervisely-6.73.280.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
1076
+ supervisely-6.73.280.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
1077
+ supervisely-6.73.280.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
1078
+ supervisely-6.73.280.dist-info/RECORD,,