supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203) hide show
  1. supervisely/__init__.py +137 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/annotation.py +8 -2
  4. supervisely/annotation/json_geometries_map.py +14 -11
  5. supervisely/annotation/label.py +80 -3
  6. supervisely/api/annotation_api.py +14 -11
  7. supervisely/api/api.py +59 -38
  8. supervisely/api/app_api.py +11 -2
  9. supervisely/api/dataset_api.py +74 -12
  10. supervisely/api/entities_collection_api.py +10 -0
  11. supervisely/api/entity_annotation/figure_api.py +52 -4
  12. supervisely/api/entity_annotation/object_api.py +3 -3
  13. supervisely/api/entity_annotation/tag_api.py +63 -12
  14. supervisely/api/guides_api.py +210 -0
  15. supervisely/api/image_api.py +72 -1
  16. supervisely/api/labeling_job_api.py +83 -1
  17. supervisely/api/labeling_queue_api.py +33 -7
  18. supervisely/api/module_api.py +9 -0
  19. supervisely/api/project_api.py +71 -26
  20. supervisely/api/storage_api.py +3 -1
  21. supervisely/api/task_api.py +13 -2
  22. supervisely/api/team_api.py +4 -3
  23. supervisely/api/video/video_annotation_api.py +119 -3
  24. supervisely/api/video/video_api.py +65 -14
  25. supervisely/api/video/video_figure_api.py +24 -11
  26. supervisely/app/__init__.py +1 -1
  27. supervisely/app/content.py +23 -7
  28. supervisely/app/development/development.py +18 -2
  29. supervisely/app/fastapi/__init__.py +1 -0
  30. supervisely/app/fastapi/custom_static_files.py +1 -1
  31. supervisely/app/fastapi/multi_user.py +105 -0
  32. supervisely/app/fastapi/subapp.py +88 -42
  33. supervisely/app/fastapi/websocket.py +77 -9
  34. supervisely/app/singleton.py +21 -0
  35. supervisely/app/v1/app_service.py +18 -2
  36. supervisely/app/v1/constants.py +7 -1
  37. supervisely/app/widgets/__init__.py +6 -0
  38. supervisely/app/widgets/activity_feed/__init__.py +0 -0
  39. supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
  40. supervisely/app/widgets/activity_feed/style.css +78 -0
  41. supervisely/app/widgets/activity_feed/template.html +22 -0
  42. supervisely/app/widgets/card/card.py +20 -0
  43. supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
  44. supervisely/app/widgets/classes_list_selector/template.html +60 -93
  45. supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
  46. supervisely/app/widgets/classes_table/classes_table.py +1 -0
  47. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  48. supervisely/app/widgets/dialog/dialog.py +12 -0
  49. supervisely/app/widgets/dialog/template.html +2 -1
  50. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
  51. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  52. supervisely/app/widgets/fast_table/fast_table.py +184 -60
  53. supervisely/app/widgets/fast_table/template.html +1 -1
  54. supervisely/app/widgets/heatmap/__init__.py +0 -0
  55. supervisely/app/widgets/heatmap/heatmap.py +564 -0
  56. supervisely/app/widgets/heatmap/script.js +533 -0
  57. supervisely/app/widgets/heatmap/style.css +233 -0
  58. supervisely/app/widgets/heatmap/template.html +21 -0
  59. supervisely/app/widgets/modal/__init__.py +0 -0
  60. supervisely/app/widgets/modal/modal.py +198 -0
  61. supervisely/app/widgets/modal/template.html +10 -0
  62. supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
  63. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  64. supervisely/app/widgets/radio_tabs/template.html +1 -0
  65. supervisely/app/widgets/select/select.py +6 -3
  66. supervisely/app/widgets/select_class/__init__.py +0 -0
  67. supervisely/app/widgets/select_class/select_class.py +363 -0
  68. supervisely/app/widgets/select_class/template.html +50 -0
  69. supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
  70. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  71. supervisely/app/widgets/select_tag/__init__.py +0 -0
  72. supervisely/app/widgets/select_tag/select_tag.py +352 -0
  73. supervisely/app/widgets/select_tag/template.html +64 -0
  74. supervisely/app/widgets/select_team/select_team.py +37 -4
  75. supervisely/app/widgets/select_team/template.html +4 -5
  76. supervisely/app/widgets/select_user/__init__.py +0 -0
  77. supervisely/app/widgets/select_user/select_user.py +270 -0
  78. supervisely/app/widgets/select_user/template.html +13 -0
  79. supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
  80. supervisely/app/widgets/select_workspace/template.html +9 -12
  81. supervisely/app/widgets/table/table.py +68 -13
  82. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  83. supervisely/aug/aug.py +6 -2
  84. supervisely/convert/base_converter.py +1 -0
  85. supervisely/convert/converter.py +2 -2
  86. supervisely/convert/image/csv/csv_converter.py +24 -15
  87. supervisely/convert/image/image_converter.py +3 -1
  88. supervisely/convert/image/image_helper.py +48 -4
  89. supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
  90. supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
  91. supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
  92. supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
  93. supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
  94. supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
  95. supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
  96. supervisely/convert/pointcloud/las/las_converter.py +13 -1
  97. supervisely/convert/pointcloud/las/las_helper.py +110 -11
  98. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
  99. supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
  100. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
  101. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
  102. supervisely/convert/video/__init__.py +1 -0
  103. supervisely/convert/video/multi_view/__init__.py +0 -0
  104. supervisely/convert/video/multi_view/multi_view.py +543 -0
  105. supervisely/convert/video/sly/sly_video_converter.py +359 -3
  106. supervisely/convert/video/video_converter.py +24 -4
  107. supervisely/convert/volume/dicom/dicom_converter.py +13 -5
  108. supervisely/convert/volume/dicom/dicom_helper.py +30 -18
  109. supervisely/geometry/constants.py +1 -0
  110. supervisely/geometry/geometry.py +4 -0
  111. supervisely/geometry/helpers.py +5 -1
  112. supervisely/geometry/oriented_bbox.py +676 -0
  113. supervisely/geometry/polyline_3d.py +110 -0
  114. supervisely/geometry/rectangle.py +2 -1
  115. supervisely/io/env.py +76 -1
  116. supervisely/io/fs.py +21 -0
  117. supervisely/nn/benchmark/base_evaluator.py +104 -11
  118. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
  119. supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
  120. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
  121. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
  122. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
  123. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
  124. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
  125. supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
  126. supervisely/nn/inference/cache.py +43 -18
  127. supervisely/nn/inference/gui/serving_gui_template.py +5 -2
  128. supervisely/nn/inference/inference.py +916 -222
  129. supervisely/nn/inference/inference_request.py +55 -10
  130. supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
  131. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  132. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  133. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  134. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  135. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  136. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  137. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  138. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  139. supervisely/nn/inference/session.py +43 -35
  140. supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
  141. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  142. supervisely/nn/inference/tracking/tracker_interface.py +10 -1
  143. supervisely/nn/inference/uploader.py +139 -12
  144. supervisely/nn/live_training/__init__.py +7 -0
  145. supervisely/nn/live_training/api_server.py +111 -0
  146. supervisely/nn/live_training/artifacts_utils.py +243 -0
  147. supervisely/nn/live_training/checkpoint_utils.py +229 -0
  148. supervisely/nn/live_training/dynamic_sampler.py +44 -0
  149. supervisely/nn/live_training/helpers.py +14 -0
  150. supervisely/nn/live_training/incremental_dataset.py +146 -0
  151. supervisely/nn/live_training/live_training.py +497 -0
  152. supervisely/nn/live_training/loss_plateau_detector.py +111 -0
  153. supervisely/nn/live_training/request_queue.py +52 -0
  154. supervisely/nn/model/model_api.py +9 -0
  155. supervisely/nn/model/prediction.py +2 -1
  156. supervisely/nn/model/prediction_session.py +26 -14
  157. supervisely/nn/prediction_dto.py +19 -1
  158. supervisely/nn/tracker/base_tracker.py +11 -1
  159. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  160. supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
  161. supervisely/nn/tracker/botsort_tracker.py +94 -65
  162. supervisely/nn/tracker/utils.py +4 -5
  163. supervisely/nn/tracker/visualize.py +93 -93
  164. supervisely/nn/training/gui/classes_selector.py +16 -1
  165. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  166. supervisely/nn/training/train_app.py +46 -31
  167. supervisely/project/data_version.py +115 -51
  168. supervisely/project/download.py +1 -1
  169. supervisely/project/pointcloud_episode_project.py +37 -8
  170. supervisely/project/pointcloud_project.py +30 -2
  171. supervisely/project/project.py +14 -2
  172. supervisely/project/project_meta.py +27 -1
  173. supervisely/project/project_settings.py +32 -18
  174. supervisely/project/versioning/__init__.py +1 -0
  175. supervisely/project/versioning/common.py +20 -0
  176. supervisely/project/versioning/schema_fields.py +35 -0
  177. supervisely/project/versioning/video_schema.py +221 -0
  178. supervisely/project/versioning/volume_schema.py +87 -0
  179. supervisely/project/video_project.py +717 -15
  180. supervisely/project/volume_project.py +623 -5
  181. supervisely/template/experiment/experiment.html.jinja +4 -4
  182. supervisely/template/experiment/experiment_generator.py +14 -21
  183. supervisely/template/live_training/__init__.py +0 -0
  184. supervisely/template/live_training/header.html.jinja +96 -0
  185. supervisely/template/live_training/live_training.html.jinja +51 -0
  186. supervisely/template/live_training/live_training_generator.py +464 -0
  187. supervisely/template/live_training/sly-style.css +402 -0
  188. supervisely/template/live_training/template.html.jinja +18 -0
  189. supervisely/versions.json +28 -26
  190. supervisely/video/sampling.py +39 -20
  191. supervisely/video/video.py +41 -12
  192. supervisely/video_annotation/video_figure.py +38 -4
  193. supervisely/video_annotation/video_object.py +29 -4
  194. supervisely/volume/stl_converter.py +2 -0
  195. supervisely/worker_api/agent_rpc.py +24 -1
  196. supervisely/worker_api/rpc_servicer.py +31 -7
  197. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
  198. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
  199. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
  200. supervisely_lib/__init__.py +6 -1
  201. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
  202. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
  203. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,243 @@
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+ from typing import Optional, Dict
5
+ from datetime import datetime
6
+ import supervisely as sly
7
+ from supervisely import logger
8
+ from supervisely.template.live_training.live_training_generator import LiveTrainingGenerator
9
+ import supervisely.io.json as sly_json
10
+ from supervisely.nn.live_training.helpers import ClassMap
11
+ import yaml
12
+
13
+
14
def upload_artifacts(
    api: sly.Api,
    session_info: dict,
    artifacts: dict,
) -> str:
    """
    Upload artifacts to Team Files and generate experiment report.

    The checkpoint, its optional ``*_state.json`` companion, the config file,
    ``model_meta.json``, ``hyperparameters.yaml``, an ``open_app.lnk`` file and
    the optional logs directory are staged in ``<work_dir>/upload_artifacts``
    (``work_dir`` is the grandparent directory of the checkpoint file) and
    uploaded in one call. A report is then generated and the experiment info
    is attached to the task output.

    Args:
        api: Supervisely API instance used for all remote calls.
        session_info: Training session context
            - team_id: Team ID
            - task_id: Task ID
            - project_id: Project ID
            - framework_name: Framework name
            - task_type: Task type
            - class_map: Model class map
            - start_time: Training start time string
            - train_size: Final dataset size
            - initial_samples: Number of initial samples (optional, default 0)

        artifacts: Framework-specific artifacts
            - checkpoint_path: Path to checkpoint file
            - checkpoint_info: Dict with {name, iteration, loss}
            - config_path: Path to config file
            - logs_dir: Path to TensorBoard logs or None
            - model_config: Model configuration dict; its ``model_name`` key
              is overwritten in place
            - loss_history: Dict with loss history (optional)

    Returns:
        report_url: URL to experiment report

    Raises:
        KeyError: If a required key is missing from ``session_info`` or
            ``artifacts``.
    """
    logger.info("Starting artifacts upload")

    # Unpack session_info
    team_id = session_info['team_id']
    task_id = session_info['task_id']
    project_id = session_info['project_id']
    framework_name = session_info['framework_name']
    task_type = session_info['task_type']
    class_map: ClassMap = session_info['class_map']
    model_meta = sly.ProjectMeta(obj_classes=class_map.obj_classes)
    start_time = session_info['start_time']
    train_size = session_info['train_size']
    initial_samples = session_info.get('initial_samples', 0)

    # Unpack artifacts
    checkpoint_path = artifacts['checkpoint_path']
    checkpoint_info = artifacts['checkpoint_info']
    config_path = artifacts['config_path']
    logs_dir = artifacts.get('logs_dir')
    model_config = artifacts['model_config']

    # Stage everything in a local directory so it can be uploaded in one call.
    work_dir = Path(os.path.dirname(checkpoint_path)).parent
    output_dir = work_dir / "upload_artifacts"
    output_dir.mkdir(parents=True, exist_ok=True)

    project_info = api.project.get_info_by_id(project_id)
    project_name = project_info.name if project_info else "unknown"
    model_name = f"Live training - {project_name}"
    model_config['model_name'] = model_name
    remote_dir = f"/experiments/live_training/{project_id}_{project_name}/{task_id}_{framework_name}/"
    logger.info(f"Remote directory: {remote_dir}")

    checkpoints_dir = output_dir / "checkpoints"
    checkpoints_dir.mkdir(exist_ok=True)
    checkpoint_dest = checkpoints_dir / checkpoint_info["name"]
    shutil.copy2(checkpoint_path, checkpoint_dest)

    # The state file path is derived with the same '.pth' -> '_state.json'
    # substitution that is used when the state file is written, so the two
    # must stay in sync.
    state_json_src = checkpoint_path.replace('.pth', '_state.json')
    if os.path.exists(state_json_src):
        state_json_dest = str(checkpoint_dest).replace('.pth', '_state.json')
        shutil.copy2(state_json_src, state_json_dest)

    has_config = bool(config_path) and os.path.exists(config_path)
    if has_config:
        config_dest = output_dir / os.path.basename(config_path)
        shutil.copy2(config_path, config_dest)
        model_files = {"config": config_dest.name}
    else:
        model_files = {}

    sly_json.dump_json_file(
        model_meta.to_json(),
        str(output_dir / "model_meta.json")
    )

    # Hyperparameters are best-effort: a parse failure leaves them empty.
    hyperparams = {}
    if has_config:
        try:
            hyperparams = LiveTrainingGenerator.parse_hyperparameters(config_path)
        except Exception as e:
            logger.warning(f"Failed to parse hyperparameters: {e}")

    with open(output_dir / "hyperparameters.yaml", 'w') as f:
        yaml.dump(hyperparams, f, default_flow_style=False)

    with open(output_dir / "open_app.lnk", 'w') as f:
        f.write(f"/apps/sessions/{task_id}")

    if logs_dir and os.path.exists(logs_dir):
        logs_dest = output_dir / "logs"
        if logs_dest.exists():
            # copytree refuses to write into an existing directory
            shutil.rmtree(logs_dest)
        shutil.copytree(logs_dir, logs_dest)
        logger.info(f"Logs copied from {logs_dir}")
        has_logs = True
    else:
        logger.warning("No logs provided")
        has_logs = False

    logger.info("Uploading to Team Files")
    api.file.upload_directory_fast(
        team_id=team_id,
        local_dir=str(output_dir),
        remote_dir=remote_dir
    )

    experiment_info = {
        "experiment_name": f"Live Training {task_type.capitalize()} - Task {task_id}",
        "framework_name": framework_name,
        "model_name": model_name,
        "base_checkpoint": None,
        "base_checkpoint_link": None,
        "task_type": task_type,
        "project_id": project_id,
        "project_version": project_info.version if project_info else None,
        "task_id": task_id,
        "model_files": model_files,
        "checkpoints": [f"checkpoints/{checkpoint_info['name']}"],
        "best_checkpoint": checkpoint_info['name'],
        "export": {},
        "model_meta": "model_meta.json",
        "hyperparameters": "hyperparameters.yaml",
        "hyperparameters_id": None,
        "artifacts_dir": remote_dir,
        "datetime": start_time,
        "experiment_report_id": None,
        "evaluation_report_link": None,
        "evaluation_metrics": {},
        "primary_metric": None,
        "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"} if has_logs else None,
        "device": get_device_name(),
        "training_duration": calculate_duration(start_time),
        "train_collection_id": None,
        "val_collection_id": None,
        "project_preview": project_info.image_preview_url if project_info else None,
        "train_size": train_size,
        "initial_samples": initial_samples,
        "val_size": 0,
    }

    checkpoints_info = [checkpoint_info]
    loss_history = artifacts.get('loss_history', {})

    # Kept under its own name so the `session_info` argument is not shadowed.
    # "initial_samples" used to appear twice in this literal (0, then the real
    # value); only the effective value is kept, in its original position.
    report_session_info = {
        "session_id": task_id,
        "session_name": experiment_info["experiment_name"],
        "project_id": project_id,
        "start_time": experiment_info["datetime"],
        "duration": experiment_info["training_duration"],
        "artifacts_dir": remote_dir,
        "logs_dir": f"{remote_dir}logs/" if has_logs else None,
        "checkpoints": checkpoints_info,
        "loss_history": loss_history,
        "hyperparameters": hyperparams,
        "status": "completed",
        "device": experiment_info["device"],
        "dataset_size": train_size,
        "initial_samples": initial_samples,
        # NOTE(review): always reported as 0 even when train_size exceeds
        # initial_samples - confirm whether this is intended.
        "samples_added": 0,
        "final_size": train_size,
        "train_size": train_size,
        "val_size": 0,
    }

    report_dir = work_dir / "live_training_report"
    report_dir.mkdir(exist_ok=True)

    generator = LiveTrainingGenerator(
        api=api,
        session_info=report_session_info,
        model_config=model_config,
        model_meta=model_meta,
        output_dir=str(report_dir),
        team_id=team_id,
        task_type=task_type,
    )

    generator.generate()
    file_info = generator.upload_to_artifacts(os.path.join(remote_dir, "visualization"))

    # upload_to_artifacts may hand back either a bare id or an object with `.id`
    report_id = file_info if isinstance(file_info, int) else getattr(file_info, 'id', file_info)
    report_url = f"{api.server_address}/nn/experiments/{report_id}"

    logger.info(f"Report URL: {report_url}")

    experiment_info["has_report"] = True
    experiment_info["experiment_report_id"] = int(report_url.split('/')[-1])
    api.task.set_output_experiment(task_id, experiment_info)

    return report_url
216
+
217
+
218
def calculate_duration(start_time: str) -> str:
    """
    Calculate training duration in 'Xh Ym' format.

    Args:
        start_time: Start timestamp formatted as "%Y-%m-%d %H:%M:%S".

    Returns:
        Elapsed time between ``start_time`` and now as e.g. "2h 5m", or "N/A"
        when ``start_time`` is not a string in the expected format. A start
        time in the future is reported as "0h 0m".
    """
    try:
        start_dt = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError):
        # TypeError: not a string (e.g. None); ValueError: wrong format.
        return "N/A"

    # Clamp at zero so clock skew never produces a negative duration.
    elapsed_sec = max(0, int((datetime.now() - start_dt).total_seconds()))
    hours, remainder = divmod(elapsed_sec, 3600)
    return f"{hours}h {remainder // 60}m"
228
+
229
+
230
def get_device_name() -> str:
    """
    Get GPU device name or 'cpu'.

    Returns:
        "cpu" when ``/dev/nvidia0`` does not exist; otherwise the name torch
        reports for the current CUDA device, or the generic "cuda" when the
        name cannot be determined.
    """
    if not os.path.exists("/dev/nvidia0"):
        return "cpu"

    try:
        # Imported lazily so the CPU path above does not require torch.
        import torch  # pylint: disable=import-error
        if torch.cuda.is_available():
            # torch numbers devices relative to CUDA_VISIBLE_DEVICES, so the
            # physical id from that variable is not a valid torch index; ask
            # for the current device instead.
            return torch.cuda.get_device_name()
    except Exception as e:
        logger.warning(f"Failed to get GPU name: {e}")

    return "cuda"
@@ -0,0 +1,229 @@
1
+ import os
2
+ import re
3
+ from typing import Tuple, List, Optional
4
+ import json
5
+ from supervisely import logger
6
+ from supervisely.nn.live_training.helpers import ClassMap
7
+ from supervisely.io.json import load_json_file
8
+ import supervisely as sly
9
+
10
+
11
def validate_classes_exact_match(saved_classes: List[str], current_classes: List[str]):
    """
    Check that the checkpoint and the current run use the same class names.

    The comparison is done on sets, so ordering and duplicates are ignored.

    Raises:
        ValueError: If the two sets of names differ.
    """
    if set(saved_classes) == set(current_classes):
        return

    message = (
        "Class names in checkpoint do not match current class names. "
        f"Saved: {saved_classes}, Current: {current_classes}"
    )
    raise ValueError(message)
21
+
22
+
23
def reorder_class_map(saved_classes: List[str], project_meta) -> ClassMap:
    """
    Build a ClassMap whose class order follows the checkpoint.

    Args:
        saved_classes: Class names in the order stored with the checkpoint.
        project_meta: Object whose ``get_obj_class(name)`` returns the class
            for a name, or None when it is unknown.

    Returns:
        class_map: New ClassMap with classes in ``saved_classes`` order

    Raises:
        ValueError: If a saved class name is not present in ``project_meta``.
    """
    # The name -> index mapping is only logged; it is not part of the result.
    name_to_index = dict(zip(saved_classes, range(len(saved_classes))))
    logger.info(f"Class mapping: {name_to_index}")

    def _lookup(name):
        found = project_meta.get_obj_class(name)
        if found is None:
            raise ValueError(f"Class '{name}' not found in project metadata")
        return found

    # Resolve names in checkpoint order; the first unknown name aborts.
    return ClassMap([_lookup(name) for name in saved_classes])
42
+
43
+
44
def remove_classification_head(checkpoint_path: str) -> str:
    """
    Remove classification head weights from checkpoint and save modified version.

    Every ``state_dict`` entry whose key contains ``decode_head`` or
    ``auxiliary_head`` is dropped. The result is written next to the original
    with ``_headless`` inserted before the file extension, so the source
    checkpoint is never overwritten.

    Args:
        checkpoint_path: Path to the checkpoint file to read.

    Returns:
        modified_path: Path to checkpoint without classification head
    """
    import torch  # pylint: disable=import-error

    # SECURITY: weights_only=False unpickles arbitrary objects. Only use this
    # on checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('state_dict', {})

    # Collect first, delete afterwards: a dict cannot change size while it is
    # being iterated.
    head_keys = [
        key for key in state_dict
        if 'decode_head' in key or 'auxiliary_head' in key
    ]
    for key in head_keys:
        del state_dict[key]

    logger.info(f"Removed {len(head_keys)} classification head parameters")

    # Touch only the extension: str.replace('.pth', ...) would also rewrite a
    # '.pth' inside a directory name, and would leave the path unchanged (and
    # overwrite the input) when the file does not end in '.pth'.
    root, ext = os.path.splitext(checkpoint_path)
    modified_path = f"{root}_headless{ext}"
    torch.save(checkpoint, modified_path)

    return modified_path
69
+
70
+
71
def resolve_checkpoint(
    checkpoint_mode: str,
    selected_experiment_task_id: Optional[int],
    class_map: ClassMap,
    project_meta,
    api,
    team_id: int,
    work_dir: str
) -> Tuple[Optional[str], ClassMap, Optional[dict]]:
    """
    Main orchestrator function to resolve checkpoint loading based on mode.

    For 'finetune' and 'continue' the ``latest.pth`` checkpoint and the model
    meta of the selected experiment are downloaded into
    ``<work_dir>/downloaded_checkpoints``. 'continue' additionally downloads
    and loads the training state JSON stored next to the checkpoint.

    Args:
        checkpoint_mode: One of 'scratch', 'finetune', 'continue'
        selected_experiment_task_id: Task ID to load checkpoint from (required for finetune/continue)
        class_map: Current ClassMap
        project_meta: Project metadata
        api: Supervisely API instance
        team_id: Team ID
        work_dir: Working directory for downloaded files

    Returns:
        checkpoint_path: Path to checkpoint file (None for scratch mode)
        class_map: Updated ClassMap (reordered in finetune mode when the class
            names match but their order differs)
        state: Training state dict (only for continue mode)

    Raises:
        ValueError: If ``checkpoint_mode`` is not a valid mode, if
            ``selected_experiment_task_id`` is missing for finetune/continue,
            if classes do not match in continue mode, or if the state file is
            missing in continue mode.
    """
    checkpoint_name = "latest.pth"
    current_classes = [cls.name for cls in class_map.obj_classes]
    logger.info(f"Checkpoint mode: {checkpoint_mode}")

    if checkpoint_mode == "scratch":
        logger.info("Starting from pretrained weights (scratch mode)")
        return None, class_map, None

    # Fail fast: reject an unknown mode before any network download happens.
    if checkpoint_mode not in ("finetune", "continue"):
        raise ValueError(
            f"Invalid checkpoint_mode='{checkpoint_mode}'. Valid: 'scratch', 'finetune', 'continue'"
        )

    if selected_experiment_task_id is None:
        raise ValueError(
            f"selected_experiment_task_id must be provided when checkpoint_mode='{checkpoint_mode}'"
        )

    # Get experiment info (expects meta.output.experiment.data in the task info)
    task_info = api.task.get_info_by_id(selected_experiment_task_id)
    experiment_info = task_info["meta"]["output"]["experiment"]["data"]

    artifacts_dir = experiment_info["artifacts_dir"]
    # Remote paths below are built by plain concatenation, so make sure the
    # directory ends with exactly the separator they rely on.
    if not artifacts_dir.endswith("/"):
        artifacts_dir += "/"
    model_meta_filename = experiment_info.get("model_meta", "model_meta.json")

    # Setup local paths
    local_dir = os.path.join(work_dir, 'downloaded_checkpoints')
    os.makedirs(local_dir, exist_ok=True)

    # Download checkpoint
    remote_checkpoint = f"{artifacts_dir}checkpoints/{checkpoint_name}"
    local_checkpoint = os.path.join(local_dir, checkpoint_name)

    logger.info(f"Downloading checkpoint from {remote_checkpoint}")
    api.file.download(team_id, remote_checkpoint, local_checkpoint)
    logger.info(f"Checkpoint downloaded to {local_checkpoint}")

    # Download model_meta.json
    remote_model_meta = f"{artifacts_dir}{model_meta_filename}"
    local_model_meta = os.path.join(local_dir, 'model_meta.json')

    logger.info(f"Downloading model_meta from {remote_model_meta}")
    api.file.download(team_id, remote_model_meta, local_model_meta)

    # Load saved classes
    model_meta_json = load_json_file(local_model_meta)
    saved_project_meta = sly.ProjectMeta.from_json(model_meta_json)
    saved_classes = [cls.name for cls in saved_project_meta.obj_classes]

    logger.info(f"Saved classes: {saved_classes}")
    logger.info(f"Current classes: {current_classes}")

    # Finetune mode - flexible class handling. The checkpoint is used in every
    # case; only the class order may be adjusted.
    if checkpoint_mode == "finetune":
        if set(saved_classes) == set(current_classes):
            if saved_classes != current_classes:
                logger.info("Class order differs. Reordering classes")
                class_map = reorder_class_map(saved_classes, project_meta)
            else:
                logger.info("Class names match exactly")
        elif len(saved_classes) == len(current_classes):
            # The classification head is intentionally not stripped here.
            logger.warning("Class names differ but count matches. Classification head will be kept as is")
        else:
            logger.info("Classes differ completely. Starting from checkpoint")
        return local_checkpoint, class_map, None

    # Continue mode - strict matching required
    logger.info(f"Continue mode: loading from task_id={selected_experiment_task_id}")
    validate_classes_exact_match(saved_classes, current_classes)

    # Download state JSON. The local path uses the same '.pth' ->
    # '_state.json' substitution that load_state_json applies to the
    # checkpoint path, so the two must stay in sync.
    state_filename = checkpoint_name.replace('.pth', '_state.json')
    remote_state = f"{artifacts_dir}checkpoints/{state_filename}"
    local_state = local_checkpoint.replace('.pth', '_state.json')

    logger.info(f"Downloading state from {remote_state}")
    api.file.download(team_id, remote_state, local_state)

    state = load_state_json(local_checkpoint)

    logger.info(f"State loaded with {state.get('dataset_size', 0)} samples at iter {state.get('iter', 0)}")
    return local_checkpoint, class_map, state
191
+
192
def save_state_json(state: dict, checkpoint_path: str):
    """
    Write the training state to a JSON file alongside the checkpoint.

    The target path is ``checkpoint_path`` with '.pth' replaced by
    '_state.json'.

    Args:
        state: State dict from LiveTraining.state()
        checkpoint_path: Path to .pth checkpoint file
    """
    target = checkpoint_path.replace('.pth', '_state.json')

    with open(target, 'w') as stream:
        json.dump(state, stream, indent=2)

    logger.info(f"State saved to {target}")
206
+
207
def load_state_json(checkpoint_path: str) -> dict:
    """
    Read the training state from the JSON file alongside the checkpoint.

    The state path is ``checkpoint_path`` with '.pth' replaced by
    '_state.json'.

    Args:
        checkpoint_path: Path to .pth checkpoint file

    Returns:
        state: State dict for LiveTraining.load_state()

    Raises:
        ValueError: If the state file does not exist.
    """
    source = checkpoint_path.replace('.pth', '_state.json')

    if os.path.exists(source):
        with open(source, 'r') as stream:
            loaded = json.load(stream)
        logger.info(f"State loaded from {source}")
        return loaded

    raise ValueError(
        f"State file not found: {source}. "
        f"This checkpoint may not support 'continue' mode."
    )
@@ -0,0 +1,44 @@
1
+ import random
2
+ import time
3
+ from collections.abc import Sized
4
+
5
+
6
class DynamicSampler:
    """
    A sampler that dynamically adjusts to the size of a dataset that grows over time.
    Implements torch.utils.data.Sampler interface.

    Iteration never ends: indices of newly appended samples are queued as soon
    as they are seen, and once the queue is empty all current indices are
    queued again. Indices are taken from the end of the queue, so the most
    recently added samples are served first.

    Args:
        dataset: Any object supporting ``len()``; it may grow during iteration.
        shuffle: Shuffle each batch of queued indices when True.
        seed: Seed for the private random generator used for shuffling, so the
            order is reproducible for a given seed and growth pattern.
    """

    def __init__(self, dataset: Sized, shuffle: bool = True, seed: int = 0):
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        # A private generator honours `seed` and leaves the global `random`
        # state untouched.
        rng = random.Random(self.seed)
        pending = []
        seen_len = 0

        # Wait for first samples
        while len(self.dataset) == 0:
            time.sleep(0.1)

        while True:
            current_len = len(self.dataset)

            if current_len > seen_len:
                # Queue only the samples appended since the last check.
                fresh = list(range(seen_len, current_len))
                if self.shuffle:
                    rng.shuffle(fresh)
                pending.extend(fresh)
                seen_len = current_len

            if not pending:
                # Reshuffle existing data
                pending = list(range(current_len))
                if self.shuffle:
                    rng.shuffle(pending)

            yield pending.pop()

    def __len__(self):
        return len(self.dataset)
@@ -0,0 +1,14 @@
1
+ from typing import List, Union
2
+ import supervisely as sly
3
+
4
+
5
class ClassMap:
    """
    Index-based view over a sequence of object classes.

    Attributes are derived from ``obj_classes`` in iteration order:
    ``classes`` (names), ``class2idx`` (name -> position), ``idx2class``
    (position -> name) and ``sly_ids`` (each class's ``sly_id``).
    """

    def __init__(self, obj_classes: Union[sly.ObjClassCollection, List[sly.ObjClass]]):
        self.obj_classes = obj_classes
        names = [entry.name for entry in self.obj_classes]
        self.class2idx = dict(zip(names, range(len(names))))
        self.idx2class = dict(enumerate(names))
        self.classes = names
        self.sly_ids = [entry.sly_id for entry in self.obj_classes]

    def __len__(self):
        return len(self.obj_classes)
@@ -0,0 +1,146 @@
1
+ from typing import Dict, Any
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import numpy as np
5
+ import supervisely as sly
6
+ import cv2
7
+
8
+
9
class IncrementalDataset:
    """
    On-disk dataset that is filled one annotated image at a time.

    1. Images are saved under ``<data_dir>/images``.
    2. Annotations are stored per sample in COCO format; when
       ``save_masks_as_images`` is set, a single-channel class-index mask is
       additionally written under ``<data_dir>/masks`` (segmentation case).
    3. Samples can be added and their annotations updated by image id.
    4. Samples are kept both by image id (``samples``) and in insertion order
       (``samples_list``).

    NOTE(review): no ``__getitem__`` is defined in this class; index-based
    retrieval is presumably provided by subclasses -- confirm.

    Args:
        class2idx: Mapping from class name to integer class index.
        data_dir: Root directory for the saved images (and masks).
        save_masks_as_images: Also render every annotation to a mask file.
    """

    def __init__(
        self,
        class2idx: dict,
        data_dir: str,
        save_masks_as_images: bool = False,
    ):
        self.class2idx = class2idx
        self.data_dir = Path(data_dir)
        self.save_masks_as_images = save_masks_as_images
        self.images_dir = self.data_dir / "images"
        self.images_dir.mkdir(parents=True, exist_ok=True)
        if self.save_masks_as_images:
            self.masks_dir = self.data_dir / "masks"
            self.masks_dir.mkdir(parents=True, exist_ok=True)
        self.samples: Dict[int, dict] = {}
        self.samples_list = []

    def add(
        self,
        image_id: int,
        image_np: np.ndarray,
        annotation: "sly.Annotation",
        image_name: str
    ) -> dict:
        """
        Save a new image (and optionally its mask) and register its sample.

        Args:
            image_id: Unique id of the image; must not be in the dataset yet.
            image_np: Image pixels; width and height are read from
                ``shape[1]`` and ``shape[0]``.
            annotation: Annotation of the image.
            image_name: Original file name; stored on disk prefixed with the
                image id.

        Returns:
            The sample dict that was stored.

        Raises:
            ValueError: If ``image_id`` is already present.
        """
        if image_id in self.samples:
            raise ValueError(f"Cannot add sample: Image ID {image_id} already exists in the dataset.")
        # Prefix with the id so equal file names from different images do not collide.
        image_name = f"{image_id} {image_name}"
        w, h = image_np.shape[1], image_np.shape[0]
        img_size = (w, h)
        image_path = self._save_img(image_np, image_name)
        mask_path = None
        if self.save_masks_as_images:
            mask_path = self._save_mask(annotation, image_name)
        sample = self._format_sample(
            image_id,
            annotation,
            img_size,
            image_path,
            mask_path
        )
        assert isinstance(sample, dict), "Sample must be a dict."
        # add extra fields for internal use (update() relies on them)
        sample['image_path'] = image_path
        sample['size'] = img_size
        if mask_path is not None:
            sample['mask_path'] = mask_path
        # add to dataset
        self.samples[image_id] = sample
        self.samples_list.append(sample)
        return sample

    def update(
        self,
        image_id: int,
        annotation: "sly.Annotation",
    ) -> dict:
        """
        Replace the annotation of an already stored image.

        The stored image file is reused. When masks are saved as images, the
        mask file is rendered again from the new annotation so that it does
        not keep showing the previous labels.

        Args:
            image_id: Id of an image that is already in the dataset.
            annotation: New annotation of the image.

        Returns:
            The stored sample dict, updated in place.

        Raises:
            ValueError: If ``image_id`` is not present.
        """
        if image_id not in self.samples:
            raise ValueError(f"Cannot update sample: Image ID {image_id} does not exist in the dataset.")
        sample = self.samples[image_id]
        mask_path = sample.get('mask_path')
        if self.save_masks_as_images:
            # The stored image file name yields the same mask file name as in
            # add(), so the stale mask is overwritten in place.
            mask_path = self._save_mask(annotation, Path(sample['image_path']).name)
        new_sample = self._format_sample(
            image_id,
            annotation,
            sample['size'],
            sample['image_path'],
            mask_path
        )
        # Mutate in place: `samples_list` holds a reference to this same dict.
        sample.update(new_sample)
        return sample

    def add_or_update(
        self,
        image_id: int,
        image_np: np.ndarray,
        annotation: "sly.Annotation",
        image_name: str
    ) -> dict:
        """
        Add the image if its id is unknown, otherwise update its annotation.

        ``image_np`` and ``image_name`` are ignored when the image already
        exists. Returns the stored sample dict.
        """
        if image_id not in self.samples:
            return self.add(image_id, image_np, annotation, image_name)
        return self.update(image_id, annotation)

    def _format_sample(
        self,
        image_id: int,
        annotation: "sly.Annotation",
        image_size: tuple,
        image_path: str,
        mask_path: str = None
    ) -> dict:
        """
        Build the sample dict for one image.

        ``image_size`` is ``(width, height)``. The ``'annotations'`` entry is
        the first element returned by ``annotation.to_coco``.
        """
        # Prefer the id carried by the annotation (previous behaviour), but
        # fall back to the explicit one so the COCO records never get a
        # missing image id.
        coco_image_id = annotation.image_id
        if coco_image_id is None:
            coco_image_id = image_id
        sample = {
            'image_id': image_id,
            'width': image_size[0],
            'height': image_size[1],
            'annotations': annotation.to_coco(coco_image_id, self.class2idx)[0],
            'image_path': image_path,
            'mask_path': mask_path
        }
        return sample

    def _save_img(self, image_np: np.ndarray, image_name: str) -> str:
        """Convert the array to an RGB image, save it under ``images_dir`` and return its path."""
        image = Image.fromarray(image_np).convert('RGB')
        image_path = str(self.images_dir / image_name)
        image.save(image_path)
        return image_path

    def _save_mask(self, annotation: "sly.Annotation", image_name: str) -> str:
        """
        Render the annotation to a class-index mask and save it as PNG.

        Labels are first made non-overlapping; each remaining label is drawn
        with the index of its class from ``class2idx``. Labels whose class is
        not in ``class2idx`` are skipped, and untouched pixels stay 0.

        Returns:
            Path of the written mask file.

        Raises:
            OSError: If the mask file could not be written.
        """
        # Identity mapping: keep every class as it is.
        mapping = {label.obj_class: label.obj_class for label in annotation.labels}
        ann_nonoverlap = annotation.to_nonoverlapping_masks(mapping)
        h, w = annotation.img_size
        # NOTE(review): uint8 limits class indices to 0..255, and index 0 is
        # indistinguishable from the background -- confirm this is intended.
        mask = np.zeros((h, w), dtype=np.uint8)

        for label in ann_nonoverlap.labels:
            class_id = self.class2idx.get(label.obj_class.name)
            if class_id is not None:
                label.geometry.draw(mask, color=class_id)

        mask_name = Path(image_name).stem + '.png'
        mask_path = str(self.masks_dir / mask_name)
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(mask_path, mask):
            raise OSError(f"Failed to write mask image to {mask_path}")

        return mask_path

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)

    def get_image_ids(self) -> list:
        """Get list of image IDs in dataset"""
        return list(self.samples.keys())
146
+