supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203) hide show
  1. supervisely/__init__.py +137 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/annotation.py +8 -2
  4. supervisely/annotation/json_geometries_map.py +14 -11
  5. supervisely/annotation/label.py +80 -3
  6. supervisely/api/annotation_api.py +14 -11
  7. supervisely/api/api.py +59 -38
  8. supervisely/api/app_api.py +11 -2
  9. supervisely/api/dataset_api.py +74 -12
  10. supervisely/api/entities_collection_api.py +10 -0
  11. supervisely/api/entity_annotation/figure_api.py +52 -4
  12. supervisely/api/entity_annotation/object_api.py +3 -3
  13. supervisely/api/entity_annotation/tag_api.py +63 -12
  14. supervisely/api/guides_api.py +210 -0
  15. supervisely/api/image_api.py +72 -1
  16. supervisely/api/labeling_job_api.py +83 -1
  17. supervisely/api/labeling_queue_api.py +33 -7
  18. supervisely/api/module_api.py +9 -0
  19. supervisely/api/project_api.py +71 -26
  20. supervisely/api/storage_api.py +3 -1
  21. supervisely/api/task_api.py +13 -2
  22. supervisely/api/team_api.py +4 -3
  23. supervisely/api/video/video_annotation_api.py +119 -3
  24. supervisely/api/video/video_api.py +65 -14
  25. supervisely/api/video/video_figure_api.py +24 -11
  26. supervisely/app/__init__.py +1 -1
  27. supervisely/app/content.py +23 -7
  28. supervisely/app/development/development.py +18 -2
  29. supervisely/app/fastapi/__init__.py +1 -0
  30. supervisely/app/fastapi/custom_static_files.py +1 -1
  31. supervisely/app/fastapi/multi_user.py +105 -0
  32. supervisely/app/fastapi/subapp.py +88 -42
  33. supervisely/app/fastapi/websocket.py +77 -9
  34. supervisely/app/singleton.py +21 -0
  35. supervisely/app/v1/app_service.py +18 -2
  36. supervisely/app/v1/constants.py +7 -1
  37. supervisely/app/widgets/__init__.py +6 -0
  38. supervisely/app/widgets/activity_feed/__init__.py +0 -0
  39. supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
  40. supervisely/app/widgets/activity_feed/style.css +78 -0
  41. supervisely/app/widgets/activity_feed/template.html +22 -0
  42. supervisely/app/widgets/card/card.py +20 -0
  43. supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
  44. supervisely/app/widgets/classes_list_selector/template.html +60 -93
  45. supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
  46. supervisely/app/widgets/classes_table/classes_table.py +1 -0
  47. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  48. supervisely/app/widgets/dialog/dialog.py +12 -0
  49. supervisely/app/widgets/dialog/template.html +2 -1
  50. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
  51. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  52. supervisely/app/widgets/fast_table/fast_table.py +184 -60
  53. supervisely/app/widgets/fast_table/template.html +1 -1
  54. supervisely/app/widgets/heatmap/__init__.py +0 -0
  55. supervisely/app/widgets/heatmap/heatmap.py +564 -0
  56. supervisely/app/widgets/heatmap/script.js +533 -0
  57. supervisely/app/widgets/heatmap/style.css +233 -0
  58. supervisely/app/widgets/heatmap/template.html +21 -0
  59. supervisely/app/widgets/modal/__init__.py +0 -0
  60. supervisely/app/widgets/modal/modal.py +198 -0
  61. supervisely/app/widgets/modal/template.html +10 -0
  62. supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
  63. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  64. supervisely/app/widgets/radio_tabs/template.html +1 -0
  65. supervisely/app/widgets/select/select.py +6 -3
  66. supervisely/app/widgets/select_class/__init__.py +0 -0
  67. supervisely/app/widgets/select_class/select_class.py +363 -0
  68. supervisely/app/widgets/select_class/template.html +50 -0
  69. supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
  70. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  71. supervisely/app/widgets/select_tag/__init__.py +0 -0
  72. supervisely/app/widgets/select_tag/select_tag.py +352 -0
  73. supervisely/app/widgets/select_tag/template.html +64 -0
  74. supervisely/app/widgets/select_team/select_team.py +37 -4
  75. supervisely/app/widgets/select_team/template.html +4 -5
  76. supervisely/app/widgets/select_user/__init__.py +0 -0
  77. supervisely/app/widgets/select_user/select_user.py +270 -0
  78. supervisely/app/widgets/select_user/template.html +13 -0
  79. supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
  80. supervisely/app/widgets/select_workspace/template.html +9 -12
  81. supervisely/app/widgets/table/table.py +68 -13
  82. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  83. supervisely/aug/aug.py +6 -2
  84. supervisely/convert/base_converter.py +1 -0
  85. supervisely/convert/converter.py +2 -2
  86. supervisely/convert/image/csv/csv_converter.py +24 -15
  87. supervisely/convert/image/image_converter.py +3 -1
  88. supervisely/convert/image/image_helper.py +48 -4
  89. supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
  90. supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
  91. supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
  92. supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
  93. supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
  94. supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
  95. supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
  96. supervisely/convert/pointcloud/las/las_converter.py +13 -1
  97. supervisely/convert/pointcloud/las/las_helper.py +110 -11
  98. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
  99. supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
  100. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
  101. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
  102. supervisely/convert/video/__init__.py +1 -0
  103. supervisely/convert/video/multi_view/__init__.py +0 -0
  104. supervisely/convert/video/multi_view/multi_view.py +543 -0
  105. supervisely/convert/video/sly/sly_video_converter.py +359 -3
  106. supervisely/convert/video/video_converter.py +24 -4
  107. supervisely/convert/volume/dicom/dicom_converter.py +13 -5
  108. supervisely/convert/volume/dicom/dicom_helper.py +30 -18
  109. supervisely/geometry/constants.py +1 -0
  110. supervisely/geometry/geometry.py +4 -0
  111. supervisely/geometry/helpers.py +5 -1
  112. supervisely/geometry/oriented_bbox.py +676 -0
  113. supervisely/geometry/polyline_3d.py +110 -0
  114. supervisely/geometry/rectangle.py +2 -1
  115. supervisely/io/env.py +76 -1
  116. supervisely/io/fs.py +21 -0
  117. supervisely/nn/benchmark/base_evaluator.py +104 -11
  118. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
  119. supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
  120. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
  121. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
  122. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
  123. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
  124. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
  125. supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
  126. supervisely/nn/inference/cache.py +43 -18
  127. supervisely/nn/inference/gui/serving_gui_template.py +5 -2
  128. supervisely/nn/inference/inference.py +916 -222
  129. supervisely/nn/inference/inference_request.py +55 -10
  130. supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
  131. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  132. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  133. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  134. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  135. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  136. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  137. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  138. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  139. supervisely/nn/inference/session.py +43 -35
  140. supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
  141. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  142. supervisely/nn/inference/tracking/tracker_interface.py +10 -1
  143. supervisely/nn/inference/uploader.py +139 -12
  144. supervisely/nn/live_training/__init__.py +7 -0
  145. supervisely/nn/live_training/api_server.py +111 -0
  146. supervisely/nn/live_training/artifacts_utils.py +243 -0
  147. supervisely/nn/live_training/checkpoint_utils.py +229 -0
  148. supervisely/nn/live_training/dynamic_sampler.py +44 -0
  149. supervisely/nn/live_training/helpers.py +14 -0
  150. supervisely/nn/live_training/incremental_dataset.py +146 -0
  151. supervisely/nn/live_training/live_training.py +497 -0
  152. supervisely/nn/live_training/loss_plateau_detector.py +111 -0
  153. supervisely/nn/live_training/request_queue.py +52 -0
  154. supervisely/nn/model/model_api.py +9 -0
  155. supervisely/nn/model/prediction.py +2 -1
  156. supervisely/nn/model/prediction_session.py +26 -14
  157. supervisely/nn/prediction_dto.py +19 -1
  158. supervisely/nn/tracker/base_tracker.py +11 -1
  159. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  160. supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
  161. supervisely/nn/tracker/botsort_tracker.py +94 -65
  162. supervisely/nn/tracker/utils.py +4 -5
  163. supervisely/nn/tracker/visualize.py +93 -93
  164. supervisely/nn/training/gui/classes_selector.py +16 -1
  165. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  166. supervisely/nn/training/train_app.py +46 -31
  167. supervisely/project/data_version.py +115 -51
  168. supervisely/project/download.py +1 -1
  169. supervisely/project/pointcloud_episode_project.py +37 -8
  170. supervisely/project/pointcloud_project.py +30 -2
  171. supervisely/project/project.py +14 -2
  172. supervisely/project/project_meta.py +27 -1
  173. supervisely/project/project_settings.py +32 -18
  174. supervisely/project/versioning/__init__.py +1 -0
  175. supervisely/project/versioning/common.py +20 -0
  176. supervisely/project/versioning/schema_fields.py +35 -0
  177. supervisely/project/versioning/video_schema.py +221 -0
  178. supervisely/project/versioning/volume_schema.py +87 -0
  179. supervisely/project/video_project.py +717 -15
  180. supervisely/project/volume_project.py +623 -5
  181. supervisely/template/experiment/experiment.html.jinja +4 -4
  182. supervisely/template/experiment/experiment_generator.py +14 -21
  183. supervisely/template/live_training/__init__.py +0 -0
  184. supervisely/template/live_training/header.html.jinja +96 -0
  185. supervisely/template/live_training/live_training.html.jinja +51 -0
  186. supervisely/template/live_training/live_training_generator.py +464 -0
  187. supervisely/template/live_training/sly-style.css +402 -0
  188. supervisely/template/live_training/template.html.jinja +18 -0
  189. supervisely/versions.json +28 -26
  190. supervisely/video/sampling.py +39 -20
  191. supervisely/video/video.py +41 -12
  192. supervisely/video_annotation/video_figure.py +38 -4
  193. supervisely/video_annotation/video_object.py +29 -4
  194. supervisely/volume/stl_converter.py +2 -0
  195. supervisely/worker_api/agent_rpc.py +24 -1
  196. supervisely/worker_api/rpc_servicer.py +31 -7
  197. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
  198. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
  199. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
  200. supervisely_lib/__init__.py +6 -1
  201. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
  202. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
  203. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,497 @@
1
+ import os
2
+ import numpy as np
3
+ from .api_server import start_api_server
4
+ from .request_queue import RequestQueue, RequestType
5
+ from .incremental_dataset import IncrementalDataset
6
+ from .helpers import ClassMap
7
+ import supervisely as sly
8
+ from supervisely import logger
9
+ from supervisely.nn import TaskType
10
+ from datetime import datetime
11
+ import signal
12
+ import sys
13
+ import time
14
+ from .checkpoint_utils import resolve_checkpoint, save_state_json
15
+ from .artifacts_utils import upload_artifacts
16
+ from .loss_plateau_detector import LossPlateauDetector
17
+ from pathlib import Path
18
+
19
class Phase:
    """String constants naming the stages a live-training session reports."""

    # Session created; the START request has not been handled yet.
    READY_TO_START = "ready_to_start"

    # Collecting samples until the initial sample count is reached.
    WAITING_FOR_SAMPLES = "waiting_for_samples"

    # Enough samples collected; first training stage.
    INITIAL_TRAINING = "initial_training"

    # Regular training stage.
    TRAINING = "training"
24
+
25
class LiveTraining:
    """Base class for a live-training session driven by queued API requests.

    The constructor reads project/team/task ids from the environment, starts
    the API server thread and builds the class map from the project meta.
    ``run()`` then waits for a START request and dispatches to one of the
    checkpoint modes ("continue", "finetune" or from scratch).

    Subclasses must set ``task_type`` and ``framework_name`` and implement
    ``train()``, ``predict()`` and ``prepare_artifacts()``; ``save_checkpoint()``
    is a no-op here and is expected to be overridden. A dataset has to be
    provided through ``register_dataset()`` before samples are waited for or
    added, because those code paths call ``len(self.dataset)``.
    """

    # Imported in the class body so that ``nn`` is resolvable by the method
    # annotations below without importing torch at module import time.
    from torch import nn  # pylint: disable=import-error

    task_type: str = None  # Should be set in subclass
    framework_name: str = None  # Should be set in subclass

    # Geometry types accepted for each task type when filtering classes.
    _task2geometries = {
        TaskType.OBJECT_DETECTION: [sly.Rectangle],
        TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon],
        TaskType.SEMANTIC_SEGMENTATION: [sly.Bitmap, sly.Polygon],
    }

    def __init__(
        self,
        initial_samples: int = 2,
        filter_classes_by_task: bool = True,
    ):
        """Set up the session state, API client and request server.

        Args:
            initial_samples: number of samples required before training starts.
            filter_classes_by_task: keep only classes whose geometry matches
                ``task_type`` (see ``_task2geometries``).

        Raises:
            ValueError: if ``framework_name`` is not set, or ``task_type`` is
                not set while ``filter_classes_by_task`` is True.
        """
        from torch import nn  # pylint: disable=import-error
        self.initial_samples = initial_samples
        self.filter_classes_by_task = filter_classes_by_task
        if self.task_type is None and self.filter_classes_by_task:
            raise ValueError("task_type must be set in subclass if filter_classes_by_task is set to True")
        if self.framework_name is None:
            raise ValueError("framework_name must be set in subclass")

        self.project_id = sly.env.project_id()
        self.team_id = sly.env.team_id()
        self.task_id = sly.env.task_id(raise_not_found=False)
        self.app = sly.Application()
        self.api = sly.Api()
        self.request_queue = RequestQueue()

        if os.getenv("DEVELOP_AND_DEBUG") and not sly.is_production():
            logger.info(f"🔧 Initializing Develop & Debug application for project {self.project_id}...")
            sly.app.development.supervisely_vpn_network(action="up")
            debug_task = sly.app.development.create_debug_task(self.team_id, port="8000", project_id=self.project_id)
            self.task_id = debug_task['id']
        self._api_thread = start_api_server(self.app, self.request_queue)
        self.phase = Phase.READY_TO_START
        self.iter = 0
        self._loss = None
        self._is_paused = False
        self._should_pause_after_continue = False
        self.initial_iters = 60  # TODO: remove later
        self.project_meta = self._fetch_project_meta(self.project_id)
        self.class_map = self._init_class_map(self.project_meta)
        self.dataset: IncrementalDataset = None
        self.model: nn.Module = None
        self.loss_plateau_detector = self._init_loss_plateau_detector()
        self.work_dir = 'app_data'
        self.latest_checkpoint_path = f"{self.work_dir}/checkpoints/latest.pth"

        self.checkpoint_mode = os.getenv("modal.state.checkpointMode", "scratch")
        selected_task_id_env = os.getenv("modal.state.selectedExperimentTaskId")
        self.selected_experiment_task_id = int(selected_task_id_env) if selected_task_id_env else None

        self.training_start_time = None
        self._upload_in_progress = False

    @property
    def ready_to_predict(self):
        """True once more than ``initial_iters`` training iterations were done."""
        return self.iter > self.initial_iters

    def status(self):
        """Return a dict snapshot of the session state for STATUS responses."""
        return {
            'phase': self.phase,
            'samples_count': len(self.dataset) if self.dataset is not None else 0,
            'waiting_samples': self.initial_samples,
            'task_type': self.task_type,
            'iteration': self.iter,
            'loss': self._loss,
            'training_paused': self._is_paused,
            'ready_to_predict': self.ready_to_predict,
            'initial_iters': self.initial_iters,
        }

    def run(self):
        """Run the whole session: wait for START, train, then save and upload.

        In non-production mode any training exception is re-raised. In
        production it is logged and the final checkpoint/state/artifacts are
        still saved and uploaded.
        """
        self.training_start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self._add_shutdown_callback()

        work_dir_path = Path(self.work_dir)
        work_dir_path.mkdir(parents=True, exist_ok=True)
        model_meta_path = work_dir_path / "model_meta.json"
        sly.json.dump_json_file(self.project_meta.to_json(), str(model_meta_path))

        try:
            self.phase = Phase.READY_TO_START
            self._wait_for_start()
            if self.checkpoint_mode == 'continue':
                self._run_continue()
            elif self.checkpoint_mode == 'finetune':
                self._run_finetune()
            else:
                self._run_from_scratch()
        except Exception as e:
            if not sly.is_production():
                raise
            logger.error(f"Live training failed: {e}", exc_info=True)

        final_checkpoint = self.latest_checkpoint_path
        self.save_checkpoint(final_checkpoint)
        save_state_json(self.state(), final_checkpoint)
        self._upload_artifacts()

    def _run_from_scratch(self):
        """Wait for the initial samples, then train without a checkpoint."""
        self.phase = Phase.WAITING_FOR_SAMPLES
        self._wait_for_initial_samples()
        self.train(checkpoint_path=None)

    def _run_continue(self):
        """Restore saved state and dataset, then resume from the checkpoint."""
        checkpoint_path, state = self._load_checkpoint()
        self.load_state(state)
        image_ids = state.get('image_ids', [])
        if image_ids:
            self._restore_dataset(image_ids)
        self.train(checkpoint_path=checkpoint_path)

    def _run_finetune(self):
        """Resolve a checkpoint, wait for initial samples, then train from it."""
        checkpoint_path, _ = self._load_checkpoint()
        self.phase = Phase.WAITING_FOR_SAMPLES
        self._wait_for_initial_samples()
        self.train(checkpoint_path=checkpoint_path)

    def _wait_for_start(self):
        """Block until a START request arrives.

        STATUS requests are answered while waiting; any other request type is
        rejected with an exception set on its future.
        """
        request = self.request_queue.get()
        while request.type != RequestType.START:
            if request.type == RequestType.STATUS:
                status = self.status()
                request.future.set_result(status)
            else:
                request.future.set_exception(Exception(f"Unexpected request {request.type} while waiting for START"))
            request = self.request_queue.get()
        # When START is received
        status = self.status()
        status['phase'] = Phase.WAITING_FOR_SAMPLES
        request.future.set_result(status)

    def _wait_until_samples_added(
        self,
        samples_needed: int,
        max_wait_time: int = None,
    ):
        """Poll until the dataset grew by ``samples_needed`` samples.

        Pending requests are processed between polls.

        Args:
            samples_needed: required increase of ``len(self.dataset)``.
            max_wait_time: timeout in seconds, or None to wait indefinitely.

        Raises:
            RuntimeError: if ``max_wait_time`` elapses first.
        """
        sleep_interval = 0.5
        # Measure real elapsed time so that time spent handling requests
        # (e.g. predictions) also counts towards the timeout.
        started_at = time.monotonic()
        samples_before = len(self.dataset)

        while len(self.dataset) - samples_before < samples_needed:
            if max_wait_time is not None and time.monotonic() - started_at >= max_wait_time:
                raise RuntimeError("Timeout waiting for samples")

            if not self.request_queue.is_empty():
                self._process_pending_requests()

            time.sleep(sleep_interval)

    def _wait_for_initial_samples(self):
        """Pause until the dataset holds at least ``initial_samples`` samples."""
        if len(self.dataset) >= self.initial_samples:
            return

        self.phase = Phase.WAITING_FOR_SAMPLES
        self._is_paused = True

        samples_needed = self.initial_samples - len(self.dataset)
        logger.info(f"Waiting for {samples_needed} initial samples")
        self._wait_until_samples_added(
            samples_needed=samples_needed,
            max_wait_time=3600,
        )

        self._is_paused = False

    def _process_pending_requests(self):
        """Drain the request queue and resolve every request's future.

        PREDICT, ADD_SAMPLE and STATUS are handled; any other request type is
        rejected with an exception so its caller is not left waiting on a
        future that would never be resolved.
        """
        requests = self.request_queue.get_all()
        if not requests:
            return

        for request in requests:
            try:
                if request.type == RequestType.PREDICT:
                    result = self._handle_predict(request.data)
                    request.future.set_result(result)

                elif request.type == RequestType.ADD_SAMPLE:
                    result = self._handle_add_sample(request.data)
                    request.future.set_result(result)

                elif request.type == RequestType.STATUS:
                    result = self.status()
                    request.future.set_result(result)

                else:
                    # Mirrors _wait_for_start(): unexpected requests get an error.
                    logger.warning(f"Unexpected request {request.type} while training")
                    request.future.set_exception(Exception(f"Unexpected request {request.type} while training"))

            except Exception as e:
                logger.error(f"Error processing request {request.type}: {e}", exc_info=True)
                request.future.set_exception(e)

    def train(self, checkpoint_path: str = None):
        """
        Main training loop. Implement framework-specific training logic here.
        Prepare model config, set hyperparameters and run training.
        Handle phases: initial training, training
        """
        raise NotImplementedError

    def predict(self, model: nn.Module, image_np, image_info) -> list:
        """
        Run inference on a single image and return predictions as a list of sly figures in json format.
        """
        raise NotImplementedError

    def _handle_predict(self, data: dict):
        """Run ``predict()`` in eval mode and restore the previous mode after.

        Reads ``data['image']`` and ``data['image_id']``.
        """
        image_np = data['image']
        image_info = {'id': data['image_id']}
        model = self.model
        was_training = model.training
        model.eval()
        try:
            objects = self.predict(self.model, image_np=image_np, image_info=image_info)
            return {
                'objects': objects,
                'image_id': data['image_id'],
                'status': self.status(),
            }
        finally:
            # Restore training mode
            if was_training:
                model.train()

    def add_sample(
        self,
        image_id: int,
        image_np: np.ndarray,
        annotation: sly.Annotation,
        image_name: str
    ) -> dict:
        """Add a sample to the dataset, or update it; delegates to the dataset."""
        return self.dataset.add_or_update(image_id, image_np, annotation, image_name)

    def _handle_add_sample(self, data: dict):
        """Filter and parse the annotation, store the sample, update the phase.

        Reads ``data['annotation']``, ``data['image_id']``, ``data['image']``
        and ``data['image_name']``.
        """
        ann_json = data['annotation']
        ann_json = self._filter_annotation(ann_json)
        sly_ann = sly.Annotation.from_json(ann_json, self.project_meta)
        self.add_sample(
            image_id=data['image_id'],
            image_np=data['image'],
            annotation=sly_ann,
            image_name=data['image_name']
        )
        if (len(self.dataset) >= self.initial_samples) and self.phase == Phase.WAITING_FOR_SAMPLES:
            self.phase = Phase.INITIAL_TRAINING
        return {
            'image_id': data['image_id'],
            'status': self.status(),
        }

    def _fetch_project_meta(self, project_id: int) -> sly.ProjectMeta:
        """Download the project meta JSON and parse it into ``sly.ProjectMeta``."""
        project_meta = self.api.project.get_meta(project_id)
        project_meta = sly.ProjectMeta.from_json(project_meta)
        return project_meta

    def _init_class_map(self, project_meta: sly.ProjectMeta) -> ClassMap:
        """Build the class map from the project classes.

        For semantic segmentation a ``_background_`` bitmap class is inserted
        first. When ``filter_classes_by_task`` is set, only classes whose
        geometry is allowed for ``task_type`` are kept.
        """
        obj_classes = list(project_meta.obj_classes)

        if self.task_type == TaskType.SEMANTIC_SEGMENTATION:
            obj_classes.insert(0, sly.ObjClass(name='_background_', geometry_type=sly.Bitmap))

        if self.filter_classes_by_task:
            allowed_geometries = self._task2geometries[self.task_type]
            obj_classes = [
                obj_class for obj_class in obj_classes
                if obj_class.geometry_type in allowed_geometries
            ]

        return ClassMap(obj_classes)

    def _filter_annotation(self, ann_json: dict) -> dict:
        """Keep only objects whose ``classId`` is in the class map.

        The input dict is modified in place and returned.
        """
        # Filter objects according to class_map
        # Important: Must be filtered before sly.Annotation.from_json due to static project meta
        filtered_objects = []
        for obj in ann_json['objects']:
            sly_id = obj['classId']
            if sly_id in self.class_map.sly_ids:
                filtered_objects.append(obj)
        ann_json['objects'] = filtered_objects
        return ann_json

    def after_train_step(self, loss: float):
        """Per-iteration hook: update counters, handle pauses, serve requests.

        Pauses until one new sample is added when the restored state was
        paused, and again whenever the plateau detector reports a plateau.
        """
        self.iter += 1
        self._loss = loss
        if self._should_pause_after_continue:
            self._is_paused = True
            logger.info("Training was paused. Waiting for 1 new sample before resuming...")
            self._wait_until_samples_added(samples_needed=1, max_wait_time=None)
            self._should_pause_after_continue = False
            logger.info("New sample added. Resuming training...")
            self._is_paused = False
        if self.loss_plateau_detector is not None:
            is_plateau = self.loss_plateau_detector.step(loss, self.iter)
            if is_plateau:
                self._is_paused = True
                self._wait_until_samples_added(
                    samples_needed=1,
                    max_wait_time=None,
                )
                self._is_paused = False
                self.loss_plateau_detector.reset()
        self._process_pending_requests()

    def register_model(self, model: nn.Module):
        """Store the model used by ``_handle_predict``."""
        self.model = model

    def register_dataset(self, dataset: IncrementalDataset):
        """Store the dataset that receives added samples."""
        assert hasattr(dataset, 'add_or_update'), "Dataset must implement add_or_update method. Consider inheriting from IncrementalDataset."
        self.dataset = dataset

    def _load_checkpoint(self) -> tuple:
        """Resolve and configure checkpoint based on checkpoint_mode.

        Pending requests are processed before and after resolving. Replaces
        ``self.class_map`` with the one returned by ``resolve_checkpoint``.

        Returns:
            ``(checkpoint_path, state)``.
        """
        self._process_pending_requests()
        checkpoint_path, class_map, state = resolve_checkpoint(
            checkpoint_mode=self.checkpoint_mode,
            selected_experiment_task_id=self.selected_experiment_task_id,
            class_map=self.class_map,
            project_meta=self.project_meta,
            api=self.api,
            team_id=self.team_id,
            work_dir=self.work_dir
        )

        self.class_map = class_map
        self._process_pending_requests()
        return checkpoint_path, state

    def state(self):
        """Return the state dict that is saved next to the checkpoint."""
        state = {
            'phase': self.phase,
            'iter': self.iter,
            'loss': self._loss,
            # NOTE(review): 'clases' looks like a typo for 'classes'; kept as is
            # because readers of previously saved state may rely on this key.
            'clases': [cls.name for cls in self.class_map.obj_classes],
            'image_ids': self.dataset.get_image_ids() if self.dataset else [],
            'dataset_size': len(self.dataset) if self.dataset else 0,
            'is_paused': self._is_paused
        }
        return state

    def load_state(self, state: dict):
        """Restore phase, iteration, loss and image ids from a state dict.

        If the saved session was paused, training pauses again after the next
        step until one new sample is added (see ``after_train_step``).
        """
        self.phase = state.get('phase', Phase.READY_TO_START)
        self.iter = state.get('iter', 0)
        self._loss = state.get('loss', None)
        self.image_ids = state.get('image_ids', [])
        if state.get('is_paused', False):
            self._should_pause_after_continue = True

    def _restore_dataset(self, image_ids: list):
        """Re-download the given images and annotations into the dataset.

        Images whose info cannot be found are skipped with a warning.
        """
        if not image_ids:
            return

        logger.info(f"Restoring {len(image_ids)} images from Supervisely...")

        restored_count = 0
        for img_id in image_ids:
            img_info = self.api.image.get_info_by_id(img_id)

            if img_info is None:
                logger.warning(f"Image {img_id} not found, skipping")
                continue

            image_np = self.api.image.download_np(img_id)
            ann_json = self.api.annotation.download_json(img_id)
            # NOTE(review): unlike _handle_add_sample(), the annotation is not
            # passed through _filter_annotation() here - confirm this is intended.
            ann = sly.Annotation.from_json(ann_json, self.project_meta)

            self.dataset.add_or_update(
                image_id=img_id,
                image_np=image_np,
                annotation=ann,
                image_name=img_info.name
            )

            restored_count += 1

            if restored_count % 10 == 0:
                logger.info(f"Restored {restored_count}/{len(image_ids)}")

        logger.info(f"Restored {restored_count} images")

    def prepare_artifacts(self) -> dict:
        """
        Prepare all artifacts for upload (framework-specific).

        Returns:
            Dict with:
                - checkpoint_path: path to checkpoint file
                - checkpoint_info: dict with {name, iteration, loss}
                - config_path: path to config file
                - logs_dir: path to logs directory or None
                - model_name: model name
                - model_config: model configuration dict
                - loss_history: dict with loss history
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement prepare_artifacts()"
        )

    def _get_session_info(self) -> dict:
        """Collect training session context"""
        return {
            'team_id': self.team_id,
            'task_id': self.task_id,
            'project_id': self.project_id,
            'framework_name': self.framework_name,
            'task_type': self.task_type,
            'class_map': self.class_map,
            'start_time': self.training_start_time,
            'train_size': len(self.dataset) if self.dataset else 0,
            'initial_samples': self.initial_samples
        }

    def _upload_artifacts(self):
        """Upload artifacts once; failures are logged, not raised.

        Does nothing if an upload is already in progress.
        """
        if self._upload_in_progress:
            return

        self._upload_in_progress = True

        try:
            session_info = self._get_session_info()
            artifacts = self.prepare_artifacts()

            report_url = upload_artifacts(
                api=self.api,
                session_info=session_info,
                artifacts=artifacts
            )

            logger.info(f"Report: {report_url}")

        except Exception as e:
            logger.error(f"Upload failed: {e}", exc_info=True)

        finally:
            self._upload_in_progress = False

    def save_checkpoint(self, checkpoint_path: str):
        """Save a checkpoint to ``checkpoint_path``. No-op in the base class."""
        pass

    def _init_loss_plateau_detector(self):
        """Create the plateau detector and hook checkpoint saving into it."""
        loss_plateau_detector = LossPlateauDetector()

        def save_latest_checkpoint():
            # The detector invokes its callback without arguments, while
            # save_checkpoint() requires a path. The path attribute is read at
            # call time because it is assigned after this method runs in __init__.
            self.save_checkpoint(self.latest_checkpoint_path)

        loss_plateau_detector.register_save_checkpoint_callback(save_latest_checkpoint)
        return loss_plateau_detector

    def _add_shutdown_callback(self):
        """Setup graceful shutdown: save experiment on SIGINT/SIGTERM"""
        self._upload_in_progress = False

        def signal_handler(signum, frame):
            if self._upload_in_progress:
                # Already uploading - force exit on second signal
                signal.signal(signal.SIGINT, lambda s, f: sys.exit(1))
                signal.signal(signal.SIGTERM, lambda s, f: sys.exit(1))
                return

            # Save checkpoint and state before upload
            logger.info("Received shutdown signal, saving checkpoint...")
            self.save_checkpoint(self.latest_checkpoint_path)
            save_state_json(self.state(), self.latest_checkpoint_path)
            self._upload_artifacts()
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+ from typing import Callable
3
+
4
+
5
class LossPlateauDetector:
    """
    Detect a plateau in training loss by comparing two moving averages.

    The mean of the last ``window_size`` losses is compared with the mean of
    the ``window_size`` losses before them. A plateau signal is raised when
    ``previous_mean - current_mean`` is below ``threshold``.

    NOTE(review): the comparison uses the absolute difference of the two
    means (not a ratio), and the signal counter is not cleared when a check
    finds no plateau - confirm both are intended.

    Args:
        window_size: Number of iterations in each moving-average window
        threshold: Upper bound on the drop between window means that still
            counts as a plateau
        patience: Number of plateau signals needed before action
        check_interval: Check frequency (every N iterations)
    """

    def __init__(
        self,
        window_size: int = 20,
        threshold: float = 0.005,
        patience: int = 1,
        check_interval: int = 1,
    ):
        # Configuration
        self.window_size = window_size
        self.threshold = threshold
        self.patience = patience
        self.check_interval = check_interval
        # Two full windows are required before any comparison is possible.
        self._min_iterations = window_size * 2

        # Mutable detector state
        self._save_checkpoint_fn = None
        self.consecutive_plateau_count = 0
        self.loss_history = []

    def register_save_checkpoint_callback(self, fn: Callable[[], None]):
        """Set the zero-argument function called when a plateau is confirmed."""
        self._save_checkpoint_fn = fn

    def reset(self):
        """Forget the recorded losses and the plateau signal counter."""
        self.consecutive_plateau_count = 0
        self.loss_history = []

    def step(self, loss: float, current_iter: int) -> bool:
        """
        Record one loss value and run the plateau check when due.

        Args:
            loss: Current loss value
            current_iter: Current iteration number

        Returns:
            True when the number of plateau signals reached ``patience``
            (the registered callback, if any, has been called); False otherwise
        """
        self.loss_history.append(loss)

        # Skip iterations that are off the check interval or lack history.
        off_interval = (current_iter + 1) % self.check_interval != 0
        if off_interval or len(self.loss_history) < self._min_iterations:
            return False

        plateau_now, details = self._check_plateau(current_iter)
        if not plateau_now:
            return False

        self.consecutive_plateau_count += 1
        signal_message = (
            f'[Plateau Detection] Iteration {current_iter}: '
            f'Signal {self.consecutive_plateau_count}/{self.patience} '
            f'(change: {details["metric"]:.6f}, threshold: {self.threshold})'
        )
        print(signal_message)

        # Not enough signals yet to act.
        if self.consecutive_plateau_count < self.patience:
            return False

        print('[Plateau Detection] Plateau confirmed, saving checkpoint...')
        callback = self._save_checkpoint_fn
        if callback is None:
            print('[Plateau Detection] No callback registered')
        else:
            callback()
            print('[Plateau Detection] Checkpoint saved')

        self.consecutive_plateau_count = 0
        return True

    def _check_plateau(self, current_iter: int) -> tuple:
        """Compare the two most recent windows.

        Returns:
            ``(is_plateau, info)`` where ``info`` holds the iteration, the
            measured change, the threshold and both window means.
        """
        size = self.window_size
        history = self.loss_history

        # Mean over the newest window, then over the window preceding it.
        recent_mean = np.mean(history[-size:])
        earlier_mean = np.mean(history[-2 * size:-size])

        drop = earlier_mean - recent_mean

        details = {
            'iter': current_iter,
            'metric': drop,
            'threshold': self.threshold,
            'previous_avg': earlier_mean,
            'current_avg': recent_mean,
        }
        return drop < self.threshold, details
@@ -0,0 +1,52 @@
1
+ import queue
2
+ import asyncio
3
+ from typing import Any, Optional, List
4
+ from enum import Enum
5
+
6
+
7
class RequestType(Enum):
    """Kinds of API request; each member's value is its string identifier."""

    START = "start"
    PREDICT = "predict"
    ADD_SAMPLE = "add-sample"
    STATUS = "status"
12
+
13
+
14
class Request:
    """Plain holder for one API request: its type, optional payload and optional future."""

    def __init__(
        self,
        request_type: RequestType,
        data: Optional[dict] = None,
        future: Optional[asyncio.Future] = None,
    ):
        # Values are stored as given; nothing is validated or copied.
        self.type, self.data, self.future = request_type, data, future

    def to_tuple(self):
        """Return the request as a ``(type, data, future)`` tuple."""
        return self.type, self.data, self.future
23
+
24
+
25
class RequestQueue:
    """Queue of API requests built on the thread-safe ``queue.Queue``."""

    def __init__(self):
        self._queue = queue.Queue()

    def put(self, request_type: RequestType, data: Optional[dict] = None) -> asyncio.Future:
        """Enqueue a request and return the future attached to it."""
        result_future = asyncio.Future()
        pending = Request(request_type, data, result_future)
        self._queue.put(pending)
        return result_future

    def get_all(self) -> List[Request]:
        """Drain and return every request currently queued, without blocking."""
        drained = []
        try:
            while not self._queue.empty():
                drained.append(self._queue.get_nowait())
        except queue.Empty:
            # Another consumer took the last item between the emptiness
            # check and the get; return what was collected so far.
            pass
        return drained

    def is_empty(self) -> bool:
        """Return True when no request is waiting in the queue."""
        return self._queue.empty()

    def get(self, timeout: float = None) -> Request:
        """Block until a request is available and return it.

        With a ``timeout`` (seconds), ``queue.Empty`` is raised if nothing
        arrives in time; with ``None`` the call waits indefinitely.
        """
        return self._queue.get(block=True, timeout=timeout)
@@ -72,6 +72,15 @@ class ModelAPI:
72
72
  else:
73
73
  return self._post("get_custom_inference_settings", {})["settings"]
74
74
 
75
def get_tracking_settings(self):
    """Request tracking settings from the model and return the ``botsort`` entry.

    The request goes through the task API when ``task_id`` is set,
    otherwise through ``_post``.
    """
    # @TODO: botsort hardcoded
    # Add dropdown selector for tracking algorithms later
    method = "get_tracking_settings"
    if self.task_id is None:
        response = self._post(method, {})
    else:
        response = self.api.task.send_request(self.task_id, method, {})
    return response["botsort"]
82
+
83
+
75
84
  def get_model_meta(self):
76
85
  if self.task_id is not None:
77
86
  return ProjectMeta.from_json(
@@ -59,6 +59,7 @@ class Prediction:
59
59
  self.source = source
60
60
  if isinstance(annotation_json, Annotation):
61
61
  annotation_json = annotation_json.to_json()
62
+
62
63
  self.annotation_json = annotation_json
63
64
  self.model_meta = model_meta
64
65
  if isinstance(self.model_meta, dict):
@@ -157,7 +158,7 @@ class Prediction:
157
158
 
158
159
  @property
159
160
  def annotation(self) -> Annotation:
160
- if self._annotation is None:
161
+ if self._annotation is None and self.annotation_json is not None:
161
162
  if self.model_meta is None:
162
163
  raise ValueError("Model meta is not provided. Cannot create annotation.")
163
164
  model_meta = get_meta_from_annotation(self.annotation_json, self.model_meta)