supervisely 6.73.374__py3-none-any.whl → 6.73.375__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
+<div>
+{{{widget._content}}}
+</div>
@@ -1,4 +1,5 @@
 import os
+from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union

 import supervisely as sly
@@ -17,6 +18,7 @@ from supervisely.app.widgets import (
 from supervisely.app.widgets.random_splits_table.random_splits_table import (
     RandomSplitsTable,
 )
+from supervisely.app.widgets.select_collection.select_collection import SelectCollection
 from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
     SelectDatasetTree,
 )
@@ -41,28 +43,27 @@ class TrainValSplits(Widget):
         tags_splits: Optional[bool] = True,
         datasets_splits: Optional[bool] = True,
         widget_id: Optional[int] = None,
+        collections_splits: Optional[bool] = False,
     ):
         self._project_id = project_id
         self._project_fs = project_fs

-        if project_fs is not None and project_id is not None:
-            raise ValueError(
-                "You can not provide both project_id and project_fs parameters to TrainValSplits widget."
-            )
-        if project_fs is None and project_id is None:
-            raise ValueError(
-                "You should provide at least one of: project_id or project_fs parameters to TrainValSplits widget."
-            )
-
         self._project_info = None
+        self._project_type = None
+        self._project_class = None
+        self._api = None
         if project_id is not None:
             self._api = Api()
             self._project_info = self._api.project.get_info_by_id(
                 self._project_id, raise_error=True
             )

-        self._project_type = project_fs.type if project_id is None else self._project_info.type
-        self._project_class = get_project_class(self._project_type)
+        if project_fs is not None:
+            self._project_type = project_fs.type
+        elif self._project_info is not None:
+            self._project_type = self._project_info.type
+        if self._project_type is not None:
+            self._project_class = get_project_class(self._project_type)

         self._random_splits_table: RandomSplitsTable = None
         self._train_tag_select: SelectTagMeta = None
@@ -70,6 +71,8 @@ class TrainValSplits(Widget):
         self._untagged_select: SelectString = None
         self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
         self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
+        self._train_collections_select: SelectCollection = None
+        self._val_collections_select: SelectCollection = None
         self._split_methods = []

         contents = []
@@ -80,12 +83,18 @@ class TrainValSplits(Widget):
             contents.append(self._get_random_content())
         if tags_splits:
             self._split_methods.append("Based on item tags")
-            tabs_descriptions.append(f"{self._project_type.capitalize()} should have assigned train or val tag")
+            tabs_descriptions.append(
+                f"{self._project_type.capitalize()} should have assigned train or val tag"
+            )
             contents.append(self._get_tags_content())
         if datasets_splits:
             self._split_methods.append("Based on datasets")
             tabs_descriptions.append("Select one or several datasets for every split")
             contents.append(self._get_datasets_content())
+        if collections_splits:
+            self._split_methods.append("Based on collections")
+            tabs_descriptions.append("Select one or several collections for every split")
+            contents.append(self._get_collections_content())
         if not self._split_methods:
             raise ValueError(
                 "Any of split methods [random_splits, tags_splits, datasets_splits] must be specified in TrainValSplits."
@@ -216,6 +225,32 @@ class TrainValSplits(Widget):
             widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
         )

+    def _get_collections_content(self):
+        notification_box = NotificationBox(
+            title="Notice: How to make equal splits",
+            description="Choose the same collection(s) for train/validation to make splits equal. Can be used for debug and for tiny projects",
+            box_type="info",
+        )
+
+        self._train_collections_select = SelectCollection(multiselect=True, compact=True)
+        self._val_collections_select = SelectCollection(multiselect=True, compact=True)
+        if self._project_id is not None:
+            self._train_collections_select.set_project_id(self._project_id)
+            self._val_collections_select.set_project_id(self._project_id)
+        train_field = Field(
+            self._train_collections_select,
+            title="Train collection(s)",
+            description="all images in selected collection(s) are considered as training set",
+        )
+        val_field = Field(
+            self._val_collections_select,
+            title="Validation collection(s)",
+            description="all images in selected collection(s) are considered as validation set",
+        )
+        return Container(
+            widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
+        )
+
     def get_json_data(self):
         return {}

@@ -223,6 +258,8 @@ class TrainValSplits(Widget):
         return {}

     def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
+        if self._project_id is None and self._project_fs is None:
+            raise ValueError("Both project_id and project_fs are None.")
         split_method = self._content.get_active_tab()
         tmp_project_dir = None
         train_set, val_set = [], []
@@ -276,18 +313,35 @@ class TrainValSplits(Widget):
             train_set, val_set = self._project_class.get_train_val_splits_by_dataset(
                 project_dir, train_ds_names, val_ds_names
             )
+        elif split_method == "Based on collections":
+            if self._project_id is None:
+                raise ValueError(
+                    "You can not use collections_splits parameter without project_id parameter."
+                )
+            train_collections = self._train_collections_select.get_selected_ids()
+            val_collections = self._val_collections_select.get_selected_ids()
+
+            train_set, val_set = self._project_class.get_train_val_splits_by_collections(
+                project_dir,
+                train_collections,
+                val_collections,
+                self._project_id,
+                self._api,
+            )

         if tmp_project_dir is not None:
             remove_dir(tmp_project_dir)
         return train_set, val_set

-    def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
+    def set_split_method(self, split_method: Literal["random", "tags", "datasets", "collections"]):
         if split_method == "random":
             split_method = "Random"
         elif split_method == "tags":
             split_method = "Based on item tags"
         elif split_method == "datasets":
             split_method = "Based on datasets"
+        elif split_method == "collections":
+            split_method = "Based on collections"
         self._content.set_active_tab(split_method)
         StateJson().send_changes()
         DataJson().send_changes()
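Note that the new branch in get_splits() resolves items through the API rather than the local project copy, so collection-based splitting requires the widget to know the project id. A rough sketch of the resulting flow, continuing the construction example above:

    # Assumes `splits` was created with project_id and collections_splits=True,
    # and that train/val collections are already selected (in the UI or via the setters below).
    splits.set_split_method("collections")
    train_items, val_items = splits.get_splits()  # splits items by the selected collections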
@@ -337,6 +391,42 @@ class TrainValSplits(Widget):
     def get_val_dataset_ids(self) -> List[int]:
         return self._val_ds_select.get_selected_ids()

+    def set_project_id_for_collections(self, project_id: int):
+        if not isinstance(project_id, int):
+            raise ValueError("Project ID must be an integer.")
+        self._project_id = project_id
+        self._project_type = None
+        if self._api is None:
+            self._api = Api()
+        self._project_info = self._api.project.get_info_by_id(self._project_id, raise_error=True)
+        self._project_type = self._project_info.type
+        self._project_class = get_project_class(self._project_type)
+        if not self._train_collections_select or not self._val_collections_select:
+            raise ValueError("Collections select widgets are not initialized.")
+        self._train_collections_select.set_project_id(project_id)
+        self._val_collections_select.set_project_id(project_id)
+
+    def get_train_collections_ids(self) -> List[int]:
+        return self._train_collections_select.get_selected_ids() or []
+
+    def get_val_collections_ids(self) -> List[int]:
+        return self._val_collections_select.get_selected_ids() or []
+
+    def set_collections_splits(self, train_collections: List[int], val_collections: List[int]):
+        self._content.set_active_tab("Based on collections")
+        self.set_collections_splits_by_ids("train", train_collections)
+        self.set_collections_splits_by_ids("val", val_collections)
+
+    def set_collections_splits_by_ids(
+        self, split: Literal["train", "val"], collection_ids: List[int]
+    ):
+        if split == "train":
+            self._train_collections_select.set_collections(collection_ids)
+        elif split == "val":
+            self._val_collections_select.set_collections(collection_ids)
+        else:
+            raise ValueError("Split value must be 'train' or 'val'")
+
     def get_untagged_action(self) -> str:
         return self._untagged_select.get_value()

@@ -355,6 +445,10 @@ class TrainValSplits(Widget):
         if self._val_ds_select is not None:
             self._val_ds_select.disable()
         self._disabled = True
+        if self._train_collections_select is not None:
+            self._train_collections_select.disable()
+        if self._val_collections_select is not None:
+            self._val_collections_select.disable()
         DataJson()[self.widget_id]["disabled"] = self._disabled
         DataJson().send_changes()

@@ -373,5 +467,9 @@ class TrainValSplits(Widget):
         if self._val_ds_select is not None:
             self._val_ds_select.enable()
         self._disabled = False
+        if self._train_collections_select is not None:
+            self._train_collections_select.enable()
+        if self._val_collections_select is not None:
+            self._val_collections_select.enable()
         DataJson()[self.widget_id]["disabled"] = self._disabled
         DataJson().send_changes()
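Taken together, the additions to this widget expose a small programmatic API for collection-based splits; a sketch of how an app could drive it (collection ids are placeholders):

    # Point both collection selectors at a project, set the splits in code, read them back.
    splits.set_project_id_for_collections(12345)
    splits.set_collections_splits(train_collections=[101], val_collections=[102])
    print(splits.get_train_collections_ids())  # expected [101]
    print(splits.get_val_collections_ids())    # expected [102]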
@@ -815,6 +815,20 @@ class TrainGUI:
                 raise ValueError("split must be 'train' or 'val'")
             if not isinstance(percent, int) or not 0 < percent < 100:
                 raise ValueError("percent must be an integer in range 1 to 99")
+        elif train_val_splits_settings.get("method") == "collections":
+            train_collections = train_val_splits_settings.get("train_collections", [])
+            val_collections = train_val_splits_settings.get("val_collections", [])
+            collection_ids = set()
+            for collection in self._api.entities_collection.get_list(self.project_id):
+                collection_ids.add(collection.id)
+            missing_collections_ids = set(train_collections + val_collections) - collection_ids
+            if missing_collections_ids:
+                missing_collections_text = ", ".join(
+                    [str(collection_id) for collection_id in missing_collections_ids]
+                )
+                raise ValueError(
+                    f"Collections with ids: {missing_collections_text} not found in the project"
+                )
         return app_state

     def load_from_app_state(self, app_state: Union[str, dict]) -> None:
@@ -849,7 +863,8 @@ class TrainGUI:
                     "ONNXRuntime": True,
                     "TensorRT": True
                 },
-            }
+            },
+            "experiment_name": "my_experiment",
         }
         """
         if isinstance(app_state, str):
@@ -863,6 +878,7 @@ class TrainGUI:
         tags_settings = app_state.get("tags", [])
         model_settings = app_state["model"]
         hyperparameters_settings = app_state["hyperparameters"]
+        experiment_name = app_state.get("experiment_name", None)

         self._init_input(input_settings, options)
         self._init_train_val_splits(train_val_splits_settings, options)
@@ -870,6 +886,8 @@ class TrainGUI:
         self._init_tags(tags_settings, options)
         self._init_model(model_settings, options)
         self._init_hyperparameters(hyperparameters_settings, options)
+        if experiment_name is not None:
+            self.training_process.set_experiment_name(experiment_name)

     def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
         """
@@ -938,6 +956,15 @@ class TrainGUI:
             self.train_val_splits_selector.train_val_splits.set_datasets_splits(
                 train_datasets, val_datasets
             )
+        elif split_method == "collections":
+            train_collections = train_val_splits_settings["train_collections"]
+            val_collections = train_val_splits_settings["val_collections"]
+            self.train_val_splits_selector.train_val_splits.set_project_id_for_collections(
+                self.project_id
+            )
+            self.train_val_splits_selector.train_val_splits.set_collections_splits(
+                train_collections, val_collections
+            )
         self.train_val_splits_selector_cb()

     def _init_classes(self, classes_settings: list, options: dict) -> None:
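For reference, the corresponding app_state fragment would look roughly like this (the "train_val_split" key name and the ids are assumptions; "method", "train_collections" and "val_collections" are the keys read by the code above):

    app_state = {
        # ... "input", "model", "hyperparameters", "options" as in the docstring example ...
        "train_val_split": {
            "method": "collections",
            "train_collections": [101, 102],
            "val_collections": [103],
        },
        "experiment_name": "my_experiment",
    }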
@@ -26,53 +26,80 @@ class TrainValSplitsSelector:
         # GUI Components
         split_methods = self.app_options.get("train_val_split_methods", [])
         if len(split_methods) == 0:
-            split_methods = ["Random", "Based on tags", "Based on datasets"]
+            split_methods = ["Random", "Based on tags", "Based on datasets", "Based on collections"]
         random_split = "Random" in split_methods
         tag_split = "Based on tags" in split_methods
         ds_split = "Based on datasets" in split_methods
+        coll_split = "Based on collections" in split_methods

-        self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split)
+        self.train_val_splits = TrainValSplits(
+            project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
+        )
+
+        # check for collections with "train" and "val" prefixes
+        all_collections = self.api.entities_collection.get_list(self.project_id)
+        train_collections = []
+        val_collections = []
+        collections_found = False
+        for collection in all_collections:
+            if collection.name.lower().startswith("train"):
+                train_collections.append(collection.id)
+            elif collection.name.lower().startswith("val"):
+                val_collections.append(collection.id)
+
+        if len(train_collections) > 0 and len(val_collections) > 0:
+            self.train_val_splits.set_collections_splits(train_collections, val_collections)
+            self.validator_text = Text(
+                "Train and val collections are detected", status="info"
+            )
+            self.validator_text.show()
+            collections_found = True
+        else:
+            self.validator_text = Text("")
+            self.validator_text.hide()
+

         def _extend_with_nested(root_ds):
             nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
             nested_ids = [ds.id for ds in nested]
             return [root_ds.id] + nested_ids

-        train_val_dataset_ids = {"train": set(), "val": set()}
-        for _, dataset in self.api.dataset.tree(self.project_id):
-            ds_name = dataset.name.lower()
+        if not collections_found:
+            train_val_dataset_ids = {"train": set(), "val": set()}
+            for _, dataset in self.api.dataset.tree(self.project_id):
+                ds_name = dataset.name.lower()

-            if ds_name in {"train", "training"}:
-                for _id in _extend_with_nested(dataset):
-                    train_val_dataset_ids["train"].add(_id)
+                if ds_name in {"train", "training"}:
+                    for _id in _extend_with_nested(dataset):
+                        train_val_dataset_ids["train"].add(_id)

-            elif ds_name in {"val", "validation", "test", "testing"}:
-                for _id in _extend_with_nested(dataset):
-                    train_val_dataset_ids["val"].add(_id)
+                elif ds_name in {"val", "validation", "test", "testing"}:
+                    for _id in _extend_with_nested(dataset):
+                        train_val_dataset_ids["val"].add(_id)

-        train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
-        train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
+            train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
+            train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])

-        train_count = len(train_val_dataset_ids["train"])
-        val_count = len(train_val_dataset_ids["val"])
+            train_count = len(train_val_dataset_ids["train"])
+            val_count = len(train_val_dataset_ids["val"])

-        if train_count > 0 and val_count > 0:
-            self.train_val_splits.set_datasets_splits(
-                train_val_dataset_ids["train"], train_val_dataset_ids["val"]
-            )
+            if train_count > 0 and val_count > 0:
+                self.train_val_splits.set_datasets_splits(
+                    train_val_dataset_ids["train"], train_val_dataset_ids["val"]
+                )

-        if train_count > 0 and val_count > 0:
-            if train_count == val_count == 1:
-                self.validator_text = Text("train and val datasets are detected", status="info")
+            if train_count > 0 and val_count > 0:
+                if train_count == val_count == 1:
+                    self.validator_text = Text("train and val datasets are detected", status="info")
+                else:
+                    self.validator_text = Text(
+                        "Multiple train and val datasets are detected. Check manually if selection is correct",
+                        status="info",
+                    )
+                self.validator_text.show()
             else:
-                self.validator_text = Text(
-                    "Multiple train and val datasets are detected. Check manually if selection is correct",
-                    status="info",
-                )
-                self.validator_text.show()
-        else:
-            self.validator_text = Text("")
-            self.validator_text.hide()
+                self.validator_text = Text("")
+                self.validator_text.hide()

         self.button = Button("Select")
         self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
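In short, the selector now pre-fills the collections tab whenever the project already has collections whose names start with "train" and "val" (case-insensitive). The matching rule, shown on a few hypothetical collection names:

    names = ["train_2024", "Train-lot-2", "val_2024", "validation_extra", "review"]
    train = [n for n in names if n.lower().startswith("train")]  # first two match
    val = [n for n in names if n.lower().startswith("val")]      # next two match; "review" is ignored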
@@ -238,6 +265,68 @@ class TrainValSplitsSelector:
             self.validator_text.set("Train and val datasets are selected", status="success")
             return True

+        def validate_based_on_collections():
+            train_collection_id = self.train_val_splits.get_train_collections_ids()
+            val_collection_id = self.train_val_splits.get_val_collections_ids()
+            if train_collection_id is None and val_collection_id is None:
+                self.validator_text.set("No collections are selected", status="error")
+                return False
+            if len(train_collection_id) == 0 or len(val_collection_id) == 0:
+                self.validator_text.set("Collections are not selected", status="error")
+                return False
+            if set(train_collection_id) == set(val_collection_id):
+                self.validator_text.set(
+                    text=f"Same collections are selected for both train and val splits. {ensure_text} {warning_text}",
+                    status="warning",
+                )
+                return True
+            from supervisely.api.entities_collection_api import CollectionTypeFilter
+
+            train_items = set()
+            empty_train_collections = []
+            for collection_id in train_collection_id:
+                items = self.api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=self.project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                train_items.update([item.id for item in items])
+                if len(items) == 0:
+                    empty_train_collections.append(collection_id)
+            val_items = set()
+            empty_val_collections = []
+            for collection_id in val_collection_id:
+                items = self.api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=self.project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                val_items.update([item.id for item in items])
+                if len(items) == 0:
+                    empty_val_collections.append(collection_id)
+            if len(train_items) == 0 and len(val_items) == 0:
+                self.validator_text.set(
+                    text="All selected collections are empty. ",
+                    status="error",
+                )
+                return False
+            if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
+                empty_collections_text = "Selected collections are empty. "
+                if len(empty_train_collections) > 0:
+                    empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
+                if len(empty_val_collections) > 0:
+                    empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
+                empty_collections_text += f"{ensure_text}"
+                self.validator_text.set(
+                    text=empty_collections_text,
+                    status="error",
+                )
+                return True
+
+            else:
+                self.validator_text.set("Train and val collections are selected", status="success")
+                return True
+
         if split_method == "Random":
             is_valid = validate_random_split()

@@ -246,6 +335,8 @@ class TrainValSplitsSelector:

         elif split_method == "Based on datasets":
             is_valid = validate_based_on_datasets()
+        elif split_method == "Based on collections":
+            is_valid = validate_based_on_collections()

         # @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
         self.validator_text.show()
@@ -268,3 +359,15 @@ class TrainValSplitsSelector:

     def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
         self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
+
+    def get_train_collection_ids(self) -> List[int]:
+        return self.train_val_splits._train_collections_select.get_selected_ids()
+
+    def set_train_collection_ids(self, collection_ids: List[int]) -> None:
+        self.train_val_splits._train_collections_select.set_selected_ids(collection_ids)
+
+    def get_val_collection_ids(self) -> List[int]:
+        return self.train_val_splits._val_collections_select.get_selected_ids()
+
+    def set_val_collection_ids(self, collection_ids: List[int]) -> None:
+        self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
@@ -208,16 +208,41 @@ class TrainApp:
         def _train_from_api(response: Response, request: Request):
             try:
                 state = request.state.state
+                wait = state.get("wait", True)
                 app_state = state["app_state"]
                 self.gui.load_from_app_state(app_state)

-                self._wrapped_start_training()
+                if wait:
+                    self._wrapped_start_training()
+                else:
+                    import threading
+
+                    training_thread = threading.Thread(
+                        target=self._wrapped_start_training,
+                        daemon=True,
+                    )
+                    training_thread.start()
+                    return {"result": "model training started"}

                 return {"result": "model was successfully trained"}
             except Exception as e:
                 self.gui.training_process.start_button.loading = False
                 raise e

+        # # Get training status
+        # @self._server.post("/train_status")
+        # def _train_status(response: Response, request: Request):
+        #     """Returns the current training status."""
+        #     status = self.gui.training_process.validator_text.get_value()
+        #     if status == "Training is in progress...":
+        #         try:
+        #             total_epochs = self.progress_bar_main.total
+        #             current_epoch = self.progress_bar_main.current
+        #             status += f" (Epoch {current_epoch}/{total_epochs})"
+        #         except Exception:
+        #             pass
+        #     return {"status": status}
+
     def _register_routes(self):
         """
         Registers API routes for TensorBoard and training endpoints.
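The new wait flag lets a client start training and return immediately instead of blocking until training finishes. A hedged sketch of such a call (the endpoint path, base URL and payload envelope are assumptions inferred from the handler; check _register_routes for the actual route):

    import requests

    TASK_URL = "http://localhost:8000"  # assumed address of the running training app
    payload = {"state": {"wait": False, "app_state": app_state}}  # app_state dict as sketched earlier
    resp = requests.post(f"{TASK_URL}/train_from_api", json=payload, timeout=60)
    print(resp.json())  # {"result": "model training started"} when wait is False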
@@ -1870,6 +1895,13 @@ class TrainApp:
                     "val_datasets": self.gui.train_val_splits_selector.train_val_splits.get_val_dataset_ids(),
                 }
             )
+        elif split_method == "Based on collections":
+            train_val_splits.update(
+                {
+                    "train_collections": self.gui.train_val_splits_selector.get_train_collection_ids(),
+                    "val_collections": self.gui.train_val_splits_selector.get_val_collection_ids(),
+                }
+            )
         return train_val_splits

     def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
@@ -2110,7 +2142,7 @@ class TrainApp:
         ]
         task_type = experiment_info["task_type"]
         if task_type not in supported_task_types:
-            logger.warn(
+            logger.warning(
                 f"Task type: '{task_type}' is not supported for Model Benchmark. "
                 f"Supported tasks: {', '.join(task_type)}"
             )
@@ -495,6 +495,22 @@ class PointcloudEpisodeProject(PointcloudProject):
         _add_items_to_list(project, val_datasets, val_items)
         return train_items, val_items

+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> None:
+        """
+        Not available for PointcloudEpisodeProject class.
+        :raises: :class:`NotImplementedError` in all cases.
+        """
+        raise NotImplementedError(
+            f"Static method 'get_train_val_splits_by_collections()' is not supported for PointcloudEpisodeProject class now."
+        )
+
     @staticmethod
     def download(
         api: Api,
@@ -725,6 +725,22 @@ class PointcloudProject(VideoProject):
         _add_items_to_list(project, val_datasets, val_items)
         return train_items, val_items

+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> None:
+        """
+        Not available for PointcloudProject class.
+        :raises: :class:`NotImplementedError` in all cases.
+        """
+        raise NotImplementedError(
+            f"Static method 'get_train_val_splits_by_collections()' is not supported for PointcloudProject class now."
+        )
+
     @staticmethod
     def download(
         api: Api,
@@ -3277,6 +3277,63 @@ class Project:
         _add_items_to_list(project, val_datasets, val_items)
         return train_items, val_items

+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> Tuple[List[ItemInfo], List[ItemInfo]]:
+        """
+        Get train and val items information from project by given train and val collections IDs.
+
+        :param project_dir: Path to project directory.
+        :type project_dir: :class:`str`
+        :param train_collections: List of train collections IDs.
+        :type train_collections: :class:`list` [ :class:`int` ]
+        :param val_collections: List of val collections IDs.
+        :type val_collections: :class:`list` [ :class:`int` ]
+        :param project_id: Project ID.
+        :type project_id: :class:`int`
+        :param api: Supervisely API address and token.
+        :type api: :class:`Api<supervisely.api.api.Api>`
+        :raises: :class:`KeyError` if collection ID not found in project
+        :return: Tuple with lists of train items information and val items information
+        :rtype: :class:`list` [ :class:`ItemInfo<ItemInfo>` ], :class:`list` [ :class:`ItemInfo<ItemInfo>` ]
+        """
+        from supervisely.api.entities_collection_api import CollectionTypeFilter
+
+        project = Project(project_dir, OpenMode.READ)
+
+        ds_id_to_name = {}
+        for parents, ds_info in api.dataset.tree(project_id):
+            full_name = "/".join(parents + [ds_info.name])
+            ds_id_to_name[ds_info.id] = full_name
+
+        train_items = []
+        val_items = []
+
+        for collection_ids, items_dict in [
+            (train_collections, train_items),
+            (val_collections, val_items),
+        ]:
+            for collection_id in collection_ids:
+                collection_items = api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                for item in collection_items:
+                    ds_name = ds_id_to_name.get(item.dataset_id)
+                    ds = project.datasets.get(ds_name)
+                    img_path, ann_path = ds.get_item_paths(item.name)
+                    info = ItemInfo(ds_name, item.name, img_path, ann_path)
+                    items_dict.append(info)
+
+        return train_items, val_items
+
+
     @staticmethod
     def download(
         api: Api,
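Finally, a short usage sketch of the new static method on an image Project (paths and ids are placeholders; the project must already be downloaded to project_dir so item paths can be resolved locally):

    import supervisely as sly

    api = sly.Api.from_env()           # needs SERVER_ADDRESS / API_TOKEN
    train_items, val_items = sly.Project.get_train_val_splits_by_collections(
        "/data/my_project",            # local project copy downloaded beforehand
        train_collections=[101, 102],  # placeholder collection ids
        val_collections=[103],
        project_id=12345,
        api=api,
    )
    print(len(train_items), len(val_items))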