supervisely-6.73.410-py3-none-any.whl → supervisely-6.73.470-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic.

Files changed (190)
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -56,7 +56,7 @@ from supervisely.nn.benchmark import (
     SemanticSegmentationEvaluator,
 )
 from supervisely.nn.inference import RuntimeType, SessionJSON
-from supervisely.nn.inference.inference import Inference
+from supervisely.nn.inference.inference import Inference, torch_load_safe
 from supervisely.nn.task_type import TaskType
 from supervisely.nn.training.gui.gui import TrainGUI
 from supervisely.nn.training.gui.utils import generate_task_check_function_js
@@ -72,6 +72,7 @@ from supervisely.project.download import (
     is_cached,
 )
 from supervisely.template.experiment.experiment_generator import ExperimentGenerator
+from supervisely.api.entities_collection_api import EntitiesCollectionInfo


 class TrainApp:
@@ -159,8 +160,14 @@ class TrainApp:
         self.sly_project = None
         # -------------------------- #

-        self._train_split = None
-        self._val_split = None
+        # Train Val Splits
+        self._train_split = []
+        self._train_split_item_ids = set()
+        self._train_collection_id = None
+
+        self._val_split = []
+        self._val_split_item_ids = set()
+        self._val_collection_id = None
         # -------------------------- #

         # Input
@@ -232,19 +239,33 @@ class TrainApp:
                 self.gui.training_process.start_button.loading = False
                 raise e

-        # # Get training status
-        # @self._server.post("/train_status")
-        # def _train_status(response: Response, request: Request):
-        #     """Returns the current training status."""
-        #     status = self.gui.training_process.validator_text.get_value()
-        #     if status == "Training is in progress...":
-        #         try:
-        #             total_epochs = self.progress_bar_main.total
-        #             current_epoch = self.progress_bar_main.current
-        #             status += f" (Epoch {current_epoch}/{total_epochs})"
-        #         except Exception:
-        #             pass
-        #     return {"status": status}
+        @self._server.post("/train_status")
+        def _train_status(response: Response, request: Request):
+            """Returns the current training status."""
+            status = self.gui.training_process.validator_text.get_value()
+            if status == "Training is in progress...":
+                try:
+                    total_epochs = getattr(self.progress_bar_main, "total", None)
+                    current_epoch = getattr(self.progress_bar_main, "current", None)
+                    if total_epochs is not None and current_epoch is not None:
+                        status += f" (Epoch {current_epoch}/{total_epochs})"
+                except Exception:
+                    pass
+            return {"status": status}
+
+        # Read GUI State when launched from experiment modal
+        state = self.gui._extract_state_from_env()
+        logger.debug(f"State: {state}")
+        gui_state_raw = state.get("guiState")
+        if gui_state_raw is not None:
+            logger.info("Loading GUI from state")
+            logger.debug(f"GUI State: {gui_state_raw}")
+            try:
+                self.gui.load_from_app_state(gui_state_raw)
+                logger.info("Successfully loaded GUI from state")
+            except Exception as e:
+                raise e
+        # ----------------------------------------- #

     def _register_routes(self):
         """
@@ -284,6 +305,12 @@

     # Properties
     # General
+    @property
+    def auto_start(self) -> bool:
+        """
+        If True, the training will start automatically after the GUI is loaded and train server is started.
+        """
+        return self.gui._start_training
     # ----------------------------------------- #

     # Input Data
@@ -399,14 +426,14 @@
         :rtype: str
         """
         return self.gui.training_process.get_device()
-
+
     @property
     def base_checkpoint(self) -> str:
         """
         Returns the name of the base checkpoint.
         """
         return self.gui.model_selector.get_checkpoint_name()
-
+
     @property
     def base_checkpoint_link(self) -> str:
         """
@@ -596,13 +623,18 @@
         if self.gui.classes_selector is not None:
             if self.gui.classes_selector.is_convert_class_shapes_enabled():
                 self._convert_project_to_model_task()
+
         # Step 4. Split Project
         self._split_project()
         # Step 5. Remove classes except selected
         if self.sly_project.type == ProjectType.IMAGES.value:
             self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
             self._read_project()
-        # Step 6. Download Model files
+
+        # Step 6. Create collections
+        self._create_collection_splits()
+
+        # Step 7. Download Model files
         self._download_model()

     def _finalize(self, experiment_info: dict) -> None:
@@ -702,11 +734,7 @@
         self._upload_demo_files(remote_dir)

         # Step 10. Generate training output
-        need_generate_report = self._app_options.get("generate_report", True)
-        if need_generate_report:
-            output_file_info = self._generate_experiment_report(experiment_info, model_meta)
-        else:  # output artifacts directory
-            output_file_info = session_link_file_info
+        output_file_info, experiment_info = self._generate_experiment_output(experiment_info, model_meta, session_link_file_info)

         # Step 11. Set output widgets
         self._set_text_status("reset")
@@ -851,7 +879,7 @@
                     "TensorRT": True
                 },
             },
-            "experiment_name": "my_experiment",
+            "experiment_name": "My Experiment",
         }
         """
         self.gui.load_from_app_state(app_state)
@@ -1074,6 +1102,7 @@
             project_id=self.project_id,
             dest_dir=self.project_dir,
             dataset_ids=[ds_info.id for ds_info in dataset_infos],
+            save_image_info=True,
             log_progress=True,
             progress_cb=pbar.update,
         )
@@ -1166,13 +1195,15 @@
             self._train_val_split_file = None
             self._train_split = []
             self._val_split = []
+            self._train_split_item_ids = set()
+            self._val_split_item_ids = set()
             return

         # Load splits
         self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
-        self._train_split, self._val_split = (
-            self.gui.train_val_splits_selector.train_val_splits.get_splits()
-        )
+        self._train_split, self._val_split = self.gui.train_val_splits_selector.train_val_splits.get_splits()
+        self._train_split_ids, self._val_split_ids = [], []
+        self._train_split_item_ids, self._val_split_item_ids = set(), set()

         # Prepare paths
         project_split_path = join(self.work_dir, "splits")
@@ -1181,11 +1212,13 @@
                 "split_path": join(project_split_path, "train"),
                 "img_dir": join(project_split_path, "train", "img"),
                 "ann_dir": join(project_split_path, "train", "ann"),
+                "img_info_dir": join(project_split_path, "train", "img_info"),
             },
             "val": {
                 "split_path": join(project_split_path, "val"),
                 "img_dir": join(project_split_path, "val", "img"),
                 "ann_dir": join(project_split_path, "val", "ann"),
+                "img_info_dir": join(project_split_path, "val", "img_info"),
             },
         }

@@ -1203,7 +1236,7 @@
         }

         # Utility function to move files
-        def move_files(split, paths, img_name_format, pbar):
+        def move_files(split, split_name, paths, img_name_format, pbar):
             """
             Move files to the appropriate directories.
             """
@@ -1212,6 +1245,19 @@
                 ann_name = f"{item_name}.json"
                 shutil.copy(item.img_path, join(paths["img_dir"], item_name))
                 shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
+
+                # Move img_info
+                img_info_name = f"{sly_fs.get_file_name_with_ext(item.img_path)}.json"
+                img_info_path = join(dirname(dirname(item.img_path)), "img_info", img_info_name)
+                # shutil.copy(img_info_path, join(paths["img_info_dir"], ann_name))
+
+                # Write split ids to img_info
+                img_info = sly_json.load_json_file(img_info_path)
+                if split_name == "train":
+                    self._train_split_item_ids.add(img_info["id"])
+                else:
+                    self._val_split_item_ids.add(img_info["id"])
+
                 pbar.update(1)

         # Main split processing
@@ -1220,12 +1266,16 @@
         ) as main_pbar:
             self.progress_bar_main.show()
             for dataset in ["train", "val"]:
-                split = self._train_split if dataset == "train" else self._val_split
+                split_name = dataset
+                if split_name == "train":
+                    split = self._train_split
+                else:
+                    split = self._val_split
                 with self.progress_bar_secondary(
                     message=f"Preparing '{dataset}'", total=len(split)
                 ) as second_pbar:
                     self.progress_bar_secondary.show()
-                    move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
+                    move_files(split, split_name, paths[dataset], image_name_formats[dataset], second_pbar)
                 main_pbar.update(1)
                 self.progress_bar_secondary.hide()
             self.progress_bar_main.hide()
@@ -1245,10 +1295,7 @@
         with self.progress_bar_main(message="Processing splits", total=2) as pbar:
             self.progress_bar_main.show()
             for dataset in ["train", "val"]:
-                shutil.move(
-                    paths[dataset]["split_path"],
-                    train_ds_path if dataset == "train" else val_ds_path,
-                )
+                shutil.move(paths[dataset]["split_path"], train_ds_path if dataset == "train" else val_ds_path)
                 pbar.update(1)
             self.progress_bar_main.hide()

@@ -1551,13 +1598,18 @@
         project_id = self.project_id

         dataset_infos = [dataset for _, dataset in self._api.dataset.tree(project_id)]
+        id_to_info = {ds.id: ds for ds in dataset_infos}
         ds_infos_dict = {}
         for dataset in dataset_infos:
-            if dataset.parent_id is not None:
-                parent_ds = self._api.dataset.get_info_by_id(dataset.parent_id)
-                dataset_name = f"{parent_ds.name}/{dataset.name}"
-            else:
-                dataset_name = dataset.name
+            name_parts = [dataset.name]
+            parent_id = dataset.parent_id
+            while parent_id is not None:
+                parent_ds = id_to_info.get(parent_id)
+                if parent_ds is None:
+                    parent_ds = self._api.dataset.get_info_by_id(parent_id)
+                name_parts.append(parent_ds.name)
+                parent_id = parent_ds.parent_id
+            dataset_name = "/".join(reversed(name_parts))
             ds_infos_dict[dataset_name] = dataset

         def get_image_infos_by_split(ds_infos_dict: dict, split: list):
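
Note: the rewritten loop resolves arbitrarily deep dataset nesting (the old code only looked one parent level up) and avoids repeated API calls by first indexing the infos it already has. The same idea in isolation, with a hypothetical DsInfo record standing in for DatasetInfo:

    from collections import namedtuple

    DsInfo = namedtuple("DsInfo", ["id", "name", "parent_id"])  # stand-in for DatasetInfo

    datasets = [DsInfo(1, "root", None), DsInfo(2, "mid", 1), DsInfo(3, "leaf", 2)]
    id_to_info = {ds.id: ds for ds in datasets}

    def full_name(ds):
        # Walk parent_id links upward, collecting names from leaf to root.
        parts = [ds.name]
        parent_id = ds.parent_id
        while parent_id is not None:
            parent = id_to_info[parent_id]
            parts.append(parent.name)
            parent_id = parent.parent_id
        return "/".join(reversed(parts))

    print(full_name(datasets[2]))  # -> "root/mid/leaf"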
@@ -1671,7 +1723,8 @@
                 try:
                     # pylint: disable=import-error
                     import torch
-                    state_dict = torch.load(new_checkpoint_path)
+
+                    state_dict = torch_load_safe(new_checkpoint_path)
                     state_dict["model_info"] = {
                         "task_id": self.task_id,
                         "model_name": experiment_info["model_name"],
@@ -1683,9 +1736,7 @@
                     state_dict["model_files"] = ckpt_files
                     torch.save(state_dict, new_checkpoint_path)
                 except Exception as e:
-                    logger.warning(
-                        f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
-                    )
+                    logger.warning(f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}")
                     continue

                 new_checkpoint_paths.append(new_checkpoint_path)
@@ -1821,7 +1872,6 @@
         :type export_weights: dict
         """
         logger.debug("Updating experiment info")
-
         experiment_info = {
             "experiment_name": self.gui.training_process.get_experiment_name(),
             "framework_name": self.framework_name,
@@ -1830,6 +1880,7 @@
             "base_checkpoint_link": self.base_checkpoint_link,
             "task_type": experiment_info["task_type"],
             "project_id": self.project_info.id,
+            "project_version": self.project_info.version,
             "task_id": self.task_id,
             "model_files": experiment_info["model_files"],
             "checkpoints": experiment_info["checkpoints"],
@@ -1847,6 +1898,8 @@
             "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"},
             "device": self.gui.training_process.get_device_name(),
             "training_duration": self._training_duration,
+            "train_collection_id": self._train_collection_id,
+            "val_collection_id": self._val_collection_id,
         }

         if self._has_splits_selector:
@@ -1986,6 +2039,37 @@

         self.progress_bar_main.hide()

+    def _generate_experiment_output(self, experiment_info: dict, model_meta: ProjectMeta, session_link_file_info: FileInfo) -> tuple:
+        """
+        Generates and uploads the experiment page to the output directory, if report generation is successful.
+        Otherwise, artifacts directory link will be used for output.
+
+        :param experiment_info: Information about the experiment results.
+        :type experiment_info: dict
+        :param model_meta: Model meta with object classes.
+        :type model_meta: ProjectMeta
+        :param session_link_file_info: Artifacts directory link, used if report is not generated.
+        :type session_link_file_info: FileInfo
+        :return: Output file info and experiment info.
+        :rtype: tuple
+        """
+        need_generate_report = self._app_options.get("generate_report", False)
+        if need_generate_report:  # link to experiment page
+            try:
+                output_file_info = self._generate_experiment_report(experiment_info, model_meta)
+                experiment_info["has_report"] = True
+                experiment_info["experiment_report_id"] = output_file_info.id
+            except Exception as e:
+                logger.error(f"Error generating experiment report: {e}")
+                output_file_info = session_link_file_info
+                experiment_info["has_report"] = False
+                experiment_info["experiment_report_id"] = None
+        else:  # link to artifacts directory
+            output_file_info = session_link_file_info
+            experiment_info["has_report"] = False
+            experiment_info["experiment_report_id"] = None
+        return output_file_info, experiment_info
+
     def _get_train_val_splits_for_app_state(self) -> Dict:
         """
         Gets the train and val splits information for app_state.json.
@@ -2719,6 +2803,12 @@
             train_logger.add_on_step_finished_callback(step_callback)

     # ----------------------------------------- #
+    def start_in_thread(self):
+        def auto_train():
+            import threading
+            threading.Thread(target=self._wrapped_start_training, daemon=True).start()
+        self._server.add_event_handler("startup", auto_train)
+
     def _wrapped_start_training(self):
         """
         Wrapper function to wrap the training process.
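
Note: start_in_thread hooks the server's "startup" event, so training begins only after the HTTP routes are serving, and runs in a daemon thread that keeps them responsive. The same pattern in isolation (hypothetical app and job):

    import threading
    from fastapi import FastAPI

    app = FastAPI()

    def job():
        print("long-running work here")

    def auto_start():
        # Daemon thread: does not block process shutdown.
        threading.Thread(target=job, daemon=True).start()

    # Runs once the event loop and server are ready.
    app.add_event_handler("startup", auto_start)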
@@ -3059,4 +3149,93 @@
         # 4. Match splits with original project
         gt_split_data = self._postprocess_splits(gt_project_info.id)
         return gt_project_info.id, gt_split_data
-        return gt_project_info.id, gt_split_data
+
+    def _create_collection_splits(self):
+        def _check_match(current_selected_collection_ids: List[int], all_split_collections: List[EntitiesCollectionInfo]):
+            if len(current_selected_collection_ids) > 0:
+                if len(current_selected_collection_ids) == 1:
+                    current_selected_collection_id = current_selected_collection_ids[0]
+                    for collection in all_split_collections:
+                        if collection.id == current_selected_collection_id:
+                            return True
+            return False
+
+        # Case 1: Use existing collections for training. No need to create new collections
+        split_method = self.gui.train_val_splits_selector.get_split_method()
+        all_train_collections = self.gui.train_val_splits_selector.all_train_collections
+        all_val_collections = self.gui.train_val_splits_selector.all_val_collections
+        if split_method == "Based on collections":
+            current_selected_train_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_train_collections_ids()
+            train_match = _check_match(current_selected_train_collection_ids, all_train_collections)
+            if train_match:
+                current_selected_val_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_val_collections_ids()
+                val_match = _check_match(current_selected_val_collection_ids, all_val_collections)
+                if val_match:
+                    self._train_collection_id = current_selected_train_collection_ids[0]
+                    self._val_collection_id = current_selected_val_collection_ids[0]
+                    self._update_project_custom_data(self._train_collection_id, self._val_collection_id)
+                    return
+        # ------------------------------------------------------------ #
+
+        # Case 2: Create new collections for selected train val splits. Need to create new collections
+        item_type = self.project_info.type
+        experiment_name = self.gui.training_process.get_experiment_name()
+
+        train_collection_idx = 1
+        val_collection_idx = 1
+
+        def _extract_index_from_col_name(name: str, expected_prefix: str) -> Optional[int]:
+            parts = name.split("_")
+            if len(parts) == 2 and parts[0] == expected_prefix and parts[1].isdigit():
+                return int(parts[1])
+            return None
+
+        # Get train collection with max idx
+        if len(all_train_collections) > 0:
+            train_indices = [_extract_index_from_col_name(collection.name, "train") for collection in all_train_collections]
+            train_indices = [idx for idx in train_indices if idx is not None]
+            if len(train_indices) > 0:
+                train_collection_idx = max(train_indices) + 1
+
+        # Get val collection with max idx
+        if len(all_val_collections) > 0:
+            val_indices = [_extract_index_from_col_name(collection.name, "val") for collection in all_val_collections]
+            val_indices = [idx for idx in val_indices if idx is not None]
+            if len(val_indices) > 0:
+                val_collection_idx = max(val_indices) + 1
+        # -------------------------------- #
+
+        # Create Train Collection
+        train_img_ids = list(self._train_split_item_ids)
+        train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
+        train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx:03d}", train_collection_description)
+        train_collection_id = getattr(train_collection, "id", None)
+        if train_collection_id is None:
+            raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
+        self._api.entities_collection.add_items(train_collection_id, train_img_ids)
+        self._train_collection_id = train_collection_id
+
+        # Create Val Collection
+        val_img_ids = list(self._val_split_item_ids)
+        val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
+        val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx:03d}", val_collection_description)
+        val_collection_id = getattr(val_collection, "id", None)
+        if val_collection_id is None:
+            raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
+        self._api.entities_collection.add_items(val_collection_id, val_img_ids)
+        self._val_collection_id = val_collection_id
+
+        # Update Project Custom Data
+        self._update_project_custom_data(train_collection_id, val_collection_id)
+
+    def _update_project_custom_data(self, train_collection_id: int, val_collection_id: int):
+        train_info = {
+            "task_id": self.task_id,
+            "framework_name": self.framework_name,
+            "splits": {"train_collection": train_collection_id, "val_collection": val_collection_id}
+        }
+        custom_data = self._api.project.get_info_by_id(self.project_id).custom_data
+        train_info_list = custom_data.get("train_info", [])
+        train_info_list.append(train_info)
+        custom_data.update({"train_info": train_info_list})
+        self._api.project.update_custom_data(self.project_id, custom_data)
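
Note: _create_collection_splits persists each split as an entities collection (train_001, val_001, ...) so a later run can reproduce exactly the same train/val items. A minimal sketch of the underlying API calls used above, with hypothetical IDs and an authenticated client:

    import supervisely as sly

    api = sly.Api.from_env()  # expects SERVER_ADDRESS and API_TOKEN in the environment
    project_id = 123             # hypothetical
    image_ids = [101, 102, 103]  # hypothetical item ids from a split

    # Same calls the method relies on: create a collection, then attach items.
    collection = api.entities_collection.create(project_id, "train_001", "example split")
    api.entities_collection.add_items(collection.id, image_ids)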
@@ -740,15 +740,19 @@ def download_pointcloud_episode_project(
     if progress_cb is not None:
         log_progress = False

-    datasets_infos = []
+    filter_fn = lambda x: True
     if dataset_ids is not None:
-        for ds_id in dataset_ids:
-            datasets_infos.append(api.dataset.get_info_by_id(ds_id))
-    else:
-        datasets_infos = api.dataset.get_list(project_id)
-
-    for dataset in datasets_infos:
-        dataset_fs: PointcloudEpisodeDataset = project_fs.create_dataset(dataset.name)
+        filter_fn = lambda ds: ds.id in dataset_ids
+
+    for parents, dataset in api.dataset.tree(project_id):
+        if not filter_fn(dataset):
+            continue
+        dataset_path = None
+        if parents:
+            dataset_path = "/datasets/".join(parents + [dataset.name])
+        dataset_fs: PointcloudEpisodeDataset = project_fs.create_dataset(
+            dataset.name, ds_path=dataset_path
+        )
         pointclouds = api.pointcloud_episode.get_list(dataset.id)

         # Download annotation to project_path/dataset_path/annotation.json
@@ -994,15 +994,19 @@ def download_pointcloud_project(
     if progress_cb is not None:
         log_progress = False

-    datasets_infos = []
+    filter_fn = lambda ds: True
     if dataset_ids is not None:
-        for ds_id in dataset_ids:
-            datasets_infos.append(api.dataset.get_info_by_id(ds_id))
-    else:
-        datasets_infos = api.dataset.get_list(project_id)
-
-    for dataset in datasets_infos:
-        dataset_fs: PointcloudDataset = project_fs.create_dataset(dataset.name)
+        filter_fn = lambda ds: ds.id in dataset_ids
+
+    for parents, dataset in api.dataset.tree(project_id):
+        if not filter_fn(dataset):
+            continue
+        dataset_path = None
+        if parents:
+            dataset_path = "/datasets/".join(parents + [dataset.name])
+        dataset_fs: PointcloudDataset = project_fs.create_dataset(
+            ds_name=dataset.name, ds_path=dataset_path
+        )
         pointclouds = api.pointcloud.get_list(dataset.id)

         ds_progress = progress_cb
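
Note: both pointcloud downloaders now walk api.dataset.tree, so nested datasets are materialized under a "datasets" subfolder of their parent (e.g. parent/datasets/child) instead of being flattened. A hedged usage sketch with hypothetical IDs, assuming the import path matches the file listed above:

    import supervisely as sly
    from supervisely.project.pointcloud_project import download_pointcloud_project

    api = sly.Api.from_env()

    download_pointcloud_project(
        api,
        project_id=123,            # hypothetical
        dest_dir="/tmp/pcd_project",
        dataset_ids=None,          # None downloads the whole dataset tree
        log_progress=True,
    )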