supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/template/experiment/experiment.html.jinja +74 -17
  66. supervisely/template/experiment/experiment_generator.py +258 -112
  67. supervisely/template/experiment/header.html.jinja +31 -13
  68. supervisely/template/experiment/sly-style.css +7 -2
  69. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
  70. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
  71. supervisely/app/widgets/experiment_selector/style.css +0 -27
  72. supervisely/app/widgets/experiment_selector/template.html +0 -61
  73. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
  74. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
  75. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
  76. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@ from typing import List
3
3
  from supervisely import Api, Project
4
4
  from supervisely.app.widgets import Button, Card, Container, Text, TrainValSplits
5
5
  from supervisely.api.module_api import ApiField
6
+ from supervisely.api.entities_collection_api import EntitiesCollectionInfo
6
7
 
7
8
  class TrainValSplitsSelector:
8
9
  title = "Train / Val Splits"
@@ -18,6 +19,13 @@ class TrainValSplitsSelector:
18
19
  self.card = None
19
20
  # -------------------------------- #
20
21
 
22
+ # Automated Splits
23
+ self._all_train_collections = []
24
+ self._all_val_collections = []
25
+ self._latest_train_collection = None
26
+ self._latest_val_collection = None
27
+ # -------------------------------- #
28
+
21
29
  self.display_widgets = []
22
30
  self.app_options = app_options
23
31
  self.api = api
@@ -32,75 +40,9 @@ class TrainValSplitsSelector:
32
40
  ds_split = "Based on datasets" in split_methods
33
41
  coll_split = "Based on collections" in split_methods
34
42
 
35
- self.train_val_splits = TrainValSplits(
36
- project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
37
- )
38
-
39
- # check for collections with "train" and "val" prefixes
40
- all_collections = self.api.entities_collection.get_list(self.project_id)
41
- train_collections = []
42
- val_collections = []
43
- collections_found = False
44
- for collection in all_collections:
45
- if collection.name.lower().startswith("train"):
46
- train_collections.append(collection.id)
47
- elif collection.name.lower().startswith("val"):
48
- val_collections.append(collection.id)
49
-
50
- if len(train_collections) > 0 and len(val_collections) > 0:
51
- self.train_val_splits.set_collections_splits(train_collections, val_collections)
52
- self.validator_text = Text(
53
- "Train and val collections are detected", status="info"
54
- )
55
- self.validator_text.show()
56
- collections_found = True
57
- else:
58
- self.validator_text = Text("")
59
- self.validator_text.hide()
60
-
61
-
62
- def _extend_with_nested(root_ds):
63
- nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
64
- nested_ids = [ds.id for ds in nested]
65
- return [root_ds.id] + nested_ids
66
-
67
- if not collections_found:
68
- train_val_dataset_ids = {"train": set(), "val": set()}
69
- for _, dataset in self.api.dataset.tree(self.project_id):
70
- ds_name = dataset.name.lower()
71
-
72
- if ds_name in {"train", "training"}:
73
- for _id in _extend_with_nested(dataset):
74
- train_val_dataset_ids["train"].add(_id)
75
-
76
- elif ds_name in {"val", "validation", "test", "testing"}:
77
- for _id in _extend_with_nested(dataset):
78
- train_val_dataset_ids["val"].add(_id)
79
-
80
- train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
81
- train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
82
-
83
- train_count = len(train_val_dataset_ids["train"])
84
- val_count = len(train_val_dataset_ids["val"])
85
-
86
- if train_count > 0 and val_count > 0:
87
- self.train_val_splits.set_datasets_splits(
88
- train_val_dataset_ids["train"], train_val_dataset_ids["val"]
89
- )
90
-
91
- if train_count > 0 and val_count > 0:
92
- if train_count == val_count == 1:
93
- self.validator_text = Text("train and val datasets are detected", status="info")
94
- else:
95
- self.validator_text = Text(
96
- "Multiple train and val datasets are detected. Check manually if selection is correct",
97
- status="info",
98
- )
99
- self.validator_text.show()
100
- else:
101
- self.validator_text = Text("")
102
- self.validator_text.hide()
43
+ self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split)
103
44
 
45
+ self._detect_splits(coll_split, ds_split)
104
46
  self.button = Button("Select")
105
47
  self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
106
48
  # -------------------------------- #
@@ -115,6 +57,22 @@ class TrainValSplitsSelector:
115
57
  )
116
58
  self.card.lock()
117
59
 
60
+ @property
61
+ def all_train_collections(self) -> List[EntitiesCollectionInfo]:
62
+ return self._all_train_collections
63
+
64
+ @property
65
+ def all_val_collections(self) -> List[EntitiesCollectionInfo]:
66
+ return self._all_val_collections
67
+
68
+ @property
69
+ def latest_train_collection(self) -> EntitiesCollectionInfo:
70
+ return self._latest_train_collection
71
+
72
+ @property
73
+ def latest_val_collection(self) -> EntitiesCollectionInfo:
74
+ return self._latest_val_collection
75
+
118
76
  @property
119
77
  def widgets_to_disable(self) -> list:
120
78
  return [self.train_val_splits]
@@ -313,10 +271,11 @@ class TrainValSplitsSelector:
313
271
  return False
314
272
  if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
315
273
  empty_collections_text = "Selected collections are empty. "
274
+ # @TODO: Use collection names instead of ids
316
275
  if len(empty_train_collections) > 0:
317
- empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
276
+ empty_collections_text += f"train: {', '.join([str(collection_id) for collection_id in empty_train_collections])}. "
318
277
  if len(empty_val_collections) > 0:
319
- empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
278
+ empty_collections_text += f"val: {', '.join([str(collection_id) for collection_id in empty_val_collections])}. "
320
279
  empty_collections_text += f"{ensure_text}"
321
280
  self.validator_text.set(
322
281
  text=empty_collections_text,
@@ -372,3 +331,96 @@ class TrainValSplitsSelector:
372
331
 
373
332
  def set_val_collection_ids(self, collection_ids: List[int]) -> None:
374
333
  self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
334
+
335
+ def _detect_splits(self, collections_split: bool, datasets_split: bool) -> bool:
336
+ """Detect splits based on the selected method"""
337
+ splits_found = False
338
+ if collections_split:
339
+ splits_found = self._detect_collections()
340
+ if not splits_found and datasets_split:
341
+ splits_found = self._detect_datasets()
342
+ return splits_found
343
+
344
+ def _detect_collections(self) -> bool:
345
+ """Find collections with train and val prefixes and set them to train_val_splits"""
346
+ def _get_latest_collection(collections: List[EntitiesCollectionInfo]) -> EntitiesCollectionInfo:
347
+ curr_collection = None
348
+ curr_idx = 0
349
+ for collection in collections:
350
+ collection_idx = int(collection.name.rsplit('_', 1)[-1])
351
+ if collection_idx > curr_idx:
352
+ curr_idx = collection_idx
353
+ curr_collection = collection
354
+ return curr_collection
355
+
356
+ all_collections = self.api.entities_collection.get_list(self.project_id)
357
+ train_collections = []
358
+ val_collections = []
359
+ collections_found = False
360
+ for collection in all_collections:
361
+ if collection.name.lower().startswith("train_"):
362
+ train_collections.append(collection)
363
+ elif collection.name.lower().startswith("val_"):
364
+ val_collections.append(collection)
365
+
366
+ train_collection = _get_latest_collection(train_collections)
367
+ val_collection = _get_latest_collection(val_collections)
368
+ if train_collection is not None and val_collection is not None:
369
+ self.train_val_splits.set_collections_splits([train_collection.id], [val_collection.id])
370
+ self.validator_text = Text("Train and val collections are detected", status="info")
371
+ self.validator_text.show()
372
+ collections_found = True
373
+ self._all_train_collections = train_collections
374
+ self._all_val_collections = val_collections
375
+ self._latest_train_collection = train_collection
376
+ self._latest_val_collection = val_collection
377
+ else:
378
+ self.validator_text = Text("")
379
+ self.validator_text.hide()
380
+ collections_found = False
381
+ return collections_found
382
+
383
+ def _detect_datasets(self) -> bool:
384
+ """Find datasets with train and val prefixes and set them to train_val_splits"""
385
+ def _extend_with_nested(root_ds):
386
+ nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
387
+ nested_ids = [ds.id for ds in nested]
388
+ return [root_ds.id] + nested_ids
389
+
390
+ datasets_found = False
391
+ train_val_dataset_ids = {"train": set(), "val": set()}
392
+ for _, dataset in self.api.dataset.tree(self.project_id):
393
+ ds_name = dataset.name.lower()
394
+
395
+ if ds_name in {"train", "training"}:
396
+ for _id in _extend_with_nested(dataset):
397
+ train_val_dataset_ids["train"].add(_id)
398
+
399
+ elif ds_name in {"val", "validation", "test", "testing"}:
400
+ for _id in _extend_with_nested(dataset):
401
+ train_val_dataset_ids["val"].add(_id)
402
+
403
+ train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
404
+ train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
405
+
406
+ train_count = len(train_val_dataset_ids["train"])
407
+ val_count = len(train_val_dataset_ids["val"])
408
+
409
+ if train_count > 0 and val_count > 0:
410
+ self.train_val_splits.set_datasets_splits(train_val_dataset_ids["train"], train_val_dataset_ids["val"])
411
+ datasets_found = True
412
+
413
+ if train_count > 0 and val_count > 0:
414
+ if train_count == val_count == 1:
415
+ message = "train and val datasets are detected"
416
+ else:
417
+ message = "Multiple train and val datasets are detected. Check manually if selection is correct"
418
+
419
+ self.validator_text = Text(message, status="info")
420
+ self.validator_text.show()
421
+ datasets_found = True
422
+ else:
423
+ self.validator_text = Text("")
424
+ self.validator_text.hide()
425
+ datasets_found = False
426
+ return datasets_found
@@ -65,11 +65,6 @@ class TrainingArtifacts:
65
65
 
66
66
  # Outputs
67
67
  need_generate_report = self.app_options.get("generate_report", False)
68
-
69
- # @TODO: temporary code to generate report for dev only
70
- is_dev = "dev.internal" in api.server_address
71
- if not is_dev:
72
- need_generate_report = False
73
68
  # ------------------------------------------------------------ #
74
69
 
75
70
  if need_generate_report:
@@ -56,7 +56,7 @@ from supervisely.nn.benchmark import (
56
56
  SemanticSegmentationEvaluator,
57
57
  )
58
58
  from supervisely.nn.inference import RuntimeType, SessionJSON
59
- from supervisely.nn.inference.inference import Inference
59
+ from supervisely.nn.inference.inference import Inference, torch_load_safe
60
60
  from supervisely.nn.task_type import TaskType
61
61
  from supervisely.nn.training.gui.gui import TrainGUI
62
62
  from supervisely.nn.training.gui.utils import generate_task_check_function_js
@@ -72,6 +72,7 @@ from supervisely.project.download import (
72
72
  is_cached,
73
73
  )
74
74
  from supervisely.template.experiment.experiment_generator import ExperimentGenerator
75
+ from supervisely.api.entities_collection_api import EntitiesCollectionInfo
75
76
 
76
77
 
77
78
  class TrainApp:
@@ -159,8 +160,14 @@ class TrainApp:
159
160
  self.sly_project = None
160
161
  # -------------------------- #
161
162
 
162
- self._train_split = None
163
- self._val_split = None
163
+ # Train Val Splits
164
+ self._train_split = []
165
+ self._train_split_item_ids = set()
166
+ self._train_collection_id = None
167
+
168
+ self._val_split = []
169
+ self._val_split_item_ids = set()
170
+ self._val_collection_id = None
164
171
  # -------------------------- #
165
172
 
166
173
  # Input
@@ -232,19 +239,33 @@ class TrainApp:
232
239
  self.gui.training_process.start_button.loading = False
233
240
  raise e
234
241
 
235
- # # Get training status
236
- # @self._server.post("/train_status")
237
- # def _train_status(response: Response, request: Request):
238
- # """Returns the current training status."""
239
- # status = self.gui.training_process.validator_text.get_value()
240
- # if status == "Training is in progress...":
241
- # try:
242
- # total_epochs = self.progress_bar_main.total
243
- # current_epoch = self.progress_bar_main.current
244
- # status += f" (Epoch {current_epoch}/{total_epochs})"
245
- # except Exception:
246
- # pass
247
- # return {"status": status}
242
+ @self._server.post("/train_status")
243
+ def _train_status(response: Response, request: Request):
244
+ """Returns the current training status."""
245
+ status = self.gui.training_process.validator_text.get_value()
246
+ if status == "Training is in progress...":
247
+ try:
248
+ total_epochs = getattr(self.progress_bar_main, "total", None)
249
+ current_epoch = getattr(self.progress_bar_main, "current", None)
250
+ if total_epochs is not None and current_epoch is not None:
251
+ status += f" (Epoch {current_epoch}/{total_epochs})"
252
+ except Exception:
253
+ pass
254
+ return {"status": status}
255
+
256
+ # Read GUI State when launched from experiment modal
257
+ state = self.gui._extract_state_from_env()
258
+ logger.debug(f"State: {state}")
259
+ gui_state_raw = state.get("guiState")
260
+ if gui_state_raw is not None:
261
+ logger.info("Loading GUI from state")
262
+ logger.debug(f"GUI State: {gui_state_raw}")
263
+ try:
264
+ self.gui.load_from_app_state(gui_state_raw)
265
+ logger.info("Successfully loaded GUI from state")
266
+ except Exception as e:
267
+ raise e
268
+ # ----------------------------------------- #
248
269
 
249
270
  def _register_routes(self):
250
271
  """
@@ -399,14 +420,14 @@ class TrainApp:
399
420
  :rtype: str
400
421
  """
401
422
  return self.gui.training_process.get_device()
402
-
423
+
403
424
  @property
404
425
  def base_checkpoint(self) -> str:
405
426
  """
406
427
  Returns the name of the base checkpoint.
407
428
  """
408
429
  return self.gui.model_selector.get_checkpoint_name()
409
-
430
+
410
431
  @property
411
432
  def base_checkpoint_link(self) -> str:
412
433
  """
@@ -596,13 +617,18 @@ class TrainApp:
596
617
  if self.gui.classes_selector is not None:
597
618
  if self.gui.classes_selector.is_convert_class_shapes_enabled():
598
619
  self._convert_project_to_model_task()
620
+
599
621
  # Step 4. Split Project
600
622
  self._split_project()
601
623
  # Step 5. Remove classes except selected
602
624
  if self.sly_project.type == ProjectType.IMAGES.value:
603
625
  self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
604
626
  self._read_project()
605
- # Step 6. Download Model files
627
+
628
+ # Step 6. Create collections
629
+ self._create_collection_splits()
630
+
631
+ # Step 7. Download Model files
606
632
  self._download_model()
607
633
 
608
634
  def _finalize(self, experiment_info: dict) -> None:
@@ -847,7 +873,7 @@ class TrainApp:
847
873
  "TensorRT": True
848
874
  },
849
875
  },
850
- "experiment_name": "my_experiment",
876
+ "experiment_name": "My Experiment",
851
877
  }
852
878
  """
853
879
  self.gui.load_from_app_state(app_state)
@@ -1162,13 +1188,15 @@ class TrainApp:
1162
1188
  self._train_val_split_file = None
1163
1189
  self._train_split = []
1164
1190
  self._val_split = []
1191
+ self._train_split_item_ids = set()
1192
+ self._val_split_item_ids = set()
1165
1193
  return
1166
1194
 
1167
1195
  # Load splits
1168
1196
  self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
1169
- self._train_split, self._val_split = (
1170
- self.gui.train_val_splits_selector.train_val_splits.get_splits()
1171
- )
1197
+ self._train_split, self._val_split = self.gui.train_val_splits_selector.train_val_splits.get_splits()
1198
+ self._train_split_ids, self._val_split_ids = [], []
1199
+ self._train_split_item_ids, self._val_split_item_ids = set(), set()
1172
1200
 
1173
1201
  # Prepare paths
1174
1202
  project_split_path = join(self.work_dir, "splits")
@@ -1177,11 +1205,13 @@ class TrainApp:
1177
1205
  "split_path": join(project_split_path, "train"),
1178
1206
  "img_dir": join(project_split_path, "train", "img"),
1179
1207
  "ann_dir": join(project_split_path, "train", "ann"),
1208
+ "img_info_dir": join(project_split_path, "train", "img_info"),
1180
1209
  },
1181
1210
  "val": {
1182
1211
  "split_path": join(project_split_path, "val"),
1183
1212
  "img_dir": join(project_split_path, "val", "img"),
1184
1213
  "ann_dir": join(project_split_path, "val", "ann"),
1214
+ "img_info_dir": join(project_split_path, "val", "img_info"),
1185
1215
  },
1186
1216
  }
1187
1217
 
@@ -1199,7 +1229,7 @@ class TrainApp:
1199
1229
  }
1200
1230
 
1201
1231
  # Utility function to move files
1202
- def move_files(split, paths, img_name_format, pbar):
1232
+ def move_files(split, split_name, paths, img_name_format, pbar):
1203
1233
  """
1204
1234
  Move files to the appropriate directories.
1205
1235
  """
@@ -1208,6 +1238,19 @@ class TrainApp:
1208
1238
  ann_name = f"{item_name}.json"
1209
1239
  shutil.copy(item.img_path, join(paths["img_dir"], item_name))
1210
1240
  shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
1241
+
1242
+ # Move img_info
1243
+ img_info_name = f"{sly_fs.get_file_name_with_ext(item.img_path)}.json"
1244
+ img_info_path = join(dirname(dirname(item.img_path)), "img_info", img_info_name)
1245
+ # shutil.copy(img_info_path, join(paths["img_info_dir"], ann_name))
1246
+
1247
+ # Write split ids to img_info
1248
+ img_info = sly_json.load_json_file(img_info_path)
1249
+ if split_name == "train":
1250
+ self._train_split_item_ids.add(img_info["id"])
1251
+ else:
1252
+ self._val_split_item_ids.add(img_info["id"])
1253
+
1211
1254
  pbar.update(1)
1212
1255
 
1213
1256
  # Main split processing
@@ -1216,12 +1259,16 @@ class TrainApp:
1216
1259
  ) as main_pbar:
1217
1260
  self.progress_bar_main.show()
1218
1261
  for dataset in ["train", "val"]:
1219
- split = self._train_split if dataset == "train" else self._val_split
1262
+ split_name = dataset
1263
+ if split_name == "train":
1264
+ split = self._train_split
1265
+ else:
1266
+ split = self._val_split
1220
1267
  with self.progress_bar_secondary(
1221
1268
  message=f"Preparing '{dataset}'", total=len(split)
1222
1269
  ) as second_pbar:
1223
1270
  self.progress_bar_secondary.show()
1224
- move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
1271
+ move_files(split, split_name, paths[dataset], image_name_formats[dataset], second_pbar)
1225
1272
  main_pbar.update(1)
1226
1273
  self.progress_bar_secondary.hide()
1227
1274
  self.progress_bar_main.hide()
@@ -1241,10 +1288,7 @@ class TrainApp:
1241
1288
  with self.progress_bar_main(message="Processing splits", total=2) as pbar:
1242
1289
  self.progress_bar_main.show()
1243
1290
  for dataset in ["train", "val"]:
1244
- shutil.move(
1245
- paths[dataset]["split_path"],
1246
- train_ds_path if dataset == "train" else val_ds_path,
1247
- )
1291
+ shutil.move(paths[dataset]["split_path"], train_ds_path if dataset == "train" else val_ds_path)
1248
1292
  pbar.update(1)
1249
1293
  self.progress_bar_main.hide()
1250
1294
 
@@ -1667,7 +1711,8 @@ class TrainApp:
1667
1711
  try:
1668
1712
  # pylint: disable=import-error
1669
1713
  import torch
1670
- state_dict = torch.load(new_checkpoint_path)
1714
+
1715
+ state_dict = torch_load_safe(new_checkpoint_path)
1671
1716
  state_dict["model_info"] = {
1672
1717
  "task_id": self.task_id,
1673
1718
  "model_name": experiment_info["model_name"],
@@ -1679,9 +1724,7 @@ class TrainApp:
1679
1724
  state_dict["model_files"] = ckpt_files
1680
1725
  torch.save(state_dict, new_checkpoint_path)
1681
1726
  except Exception as e:
1682
- logger.warning(
1683
- f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
1684
- )
1727
+ logger.warning(f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}")
1685
1728
  continue
1686
1729
 
1687
1730
  new_checkpoint_paths.append(new_checkpoint_path)
@@ -1817,7 +1860,6 @@ class TrainApp:
1817
1860
  :type export_weights: dict
1818
1861
  """
1819
1862
  logger.debug("Updating experiment info")
1820
-
1821
1863
  experiment_info = {
1822
1864
  "experiment_name": self.gui.training_process.get_experiment_name(),
1823
1865
  "framework_name": self.framework_name,
@@ -1826,6 +1868,7 @@ class TrainApp:
1826
1868
  "base_checkpoint_link": self.base_checkpoint_link,
1827
1869
  "task_type": experiment_info["task_type"],
1828
1870
  "project_id": self.project_info.id,
1871
+ "project_version": self.project_info.version,
1829
1872
  "task_id": self.task_id,
1830
1873
  "model_files": experiment_info["model_files"],
1831
1874
  "checkpoints": experiment_info["checkpoints"],
@@ -1843,6 +1886,8 @@ class TrainApp:
1843
1886
  "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"},
1844
1887
  "device": self.gui.training_process.get_device_name(),
1845
1888
  "training_duration": self._training_duration,
1889
+ "train_collection_id": self._train_collection_id,
1890
+ "val_collection_id": self._val_collection_id,
1846
1891
  }
1847
1892
 
1848
1893
  if self._has_splits_selector:
@@ -1997,13 +2042,7 @@ class TrainApp:
1997
2042
  :rtype: tuple
1998
2043
  """
1999
2044
  need_generate_report = self._app_options.get("generate_report", False)
2000
- # @TODO: temporary code to generate report for dev only
2001
- is_dev = "dev.internal" in self._api.server_address
2002
- if not is_dev:
2003
- need_generate_report = False
2004
- # ------------------------------------------------------------ #
2005
-
2006
- if need_generate_report: # link to experiment page
2045
+ if need_generate_report: # link to experiment page
2007
2046
  try:
2008
2047
  output_file_info = self._generate_experiment_report(experiment_info, model_meta)
2009
2048
  experiment_info["has_report"] = True
@@ -2011,7 +2050,7 @@ class TrainApp:
2011
2050
  logger.error(f"Error generating experiment report: {e}")
2012
2051
  output_file_info = session_link_file_info
2013
2052
  experiment_info["has_report"] = False
2014
- else: # link to artifacts directory
2053
+ else: # link to artifacts directory
2015
2054
  output_file_info = session_link_file_info
2016
2055
  experiment_info["has_report"] = False
2017
2056
  return output_file_info, experiment_info
@@ -3089,4 +3128,82 @@ class TrainApp:
3089
3128
  # 4. Match splits with original project
3090
3129
  gt_split_data = self._postprocess_splits(gt_project_info.id)
3091
3130
  return gt_project_info.id, gt_split_data
3092
- return gt_project_info.id, gt_split_data
3131
+
3132
+ def _create_collection_splits(self):
3133
+ def _check_match(current_selected_collection_ids: List[int], all_split_collections: List[EntitiesCollectionInfo]):
3134
+ if len(current_selected_collection_ids) > 0:
3135
+ if len(current_selected_collection_ids) == 1:
3136
+ current_selected_collection_id = current_selected_collection_ids[0]
3137
+ for collection in all_split_collections:
3138
+ if collection.id == current_selected_collection_id:
3139
+ return True
3140
+ return False
3141
+
3142
+ # Case 1: Use existing collections for training. No need to create new collections
3143
+ split_method = self.gui.train_val_splits_selector.get_split_method()
3144
+ all_train_collections = self.gui.train_val_splits_selector.all_train_collections
3145
+ all_val_collections = self.gui.train_val_splits_selector.all_val_collections
3146
+ if split_method == "Based on collections":
3147
+ current_selected_train_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_train_collections_ids()
3148
+ train_match = _check_match(current_selected_train_collection_ids, all_train_collections)
3149
+ if train_match:
3150
+ current_selected_val_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_val_collections_ids()
3151
+ val_match = _check_match(current_selected_val_collection_ids, all_val_collections)
3152
+ if val_match:
3153
+ self._train_collection_id = current_selected_train_collection_ids[0]
3154
+ self._val_collection_id = current_selected_val_collection_ids[0]
3155
+ self._update_project_custom_data(self._train_collection_id, self._val_collection_id)
3156
+ return
3157
+ # ------------------------------------------------------------ #
3158
+
3159
+ # Case 2: Create new collections for selected train val splits. Need to create new collections
3160
+ item_type = self.project_info.type
3161
+ experiment_name = self.gui.training_process.get_experiment_name()
3162
+
3163
+ train_collection_idx = 1
3164
+ val_collection_idx = 1
3165
+
3166
+ # Get train collection with max idx
3167
+ if len(all_train_collections) > 0:
3168
+ train_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_train_collections])
3169
+ train_collection_idx += 1
3170
+ # Get val collection with max idx
3171
+ if len(all_val_collections) > 0:
3172
+ val_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_val_collections])
3173
+ val_collection_idx += 1
3174
+ # -------------------------------- #
3175
+
3176
+ # Create Train Collection
3177
+ train_img_ids = list(self._train_split_item_ids)
3178
+ train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
3179
+ train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx}", train_collection_description)
3180
+ train_collection_id = getattr(train_collection, "id", None)
3181
+ if train_collection_id is None:
3182
+ raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
3183
+ self._api.entities_collection.add_items(train_collection_id, train_img_ids)
3184
+ self._train_collection_id = train_collection_id
3185
+
3186
+ # Create Val Collection
3187
+ val_img_ids = list(self._val_split_item_ids)
3188
+ val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
3189
+ val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx}", val_collection_description)
3190
+ val_collection_id = getattr(val_collection, "id", None)
3191
+ if val_collection_id is None:
3192
+ raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
3193
+ self._api.entities_collection.add_items(val_collection_id, val_img_ids)
3194
+ self._val_collection_id = val_collection_id
3195
+
3196
+ # Update Project Custom Data
3197
+ self._update_project_custom_data(train_collection_id, val_collection_id)
3198
+
3199
+ def _update_project_custom_data(self, train_collection_id: int, val_collection_id: int):
3200
+ train_info = {
3201
+ "task_id": self.task_id,
3202
+ "framework_name": self.framework_name,
3203
+ "splits": {"train_collection": train_collection_id, "val_collection": val_collection_id}
3204
+ }
3205
+ custom_data = self._api.project.get_info_by_id(self.project_id).custom_data
3206
+ train_info_list = custom_data.get("train_info", [])
3207
+ train_info_list.append(train_info)
3208
+ custom_data.update({"train_info": train_info_list})
3209
+ self._api.project.update_custom_data(self.project_id, custom_data)