supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +137 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +14 -11
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +14 -11
- supervisely/api/api.py +59 -38
- supervisely/api/app_api.py +11 -2
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +52 -4
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +72 -1
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +9 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +24 -4
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +916 -222
- supervisely/nn/inference/inference_request.py +55 -10
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +10 -1
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction.py +2 -1
- supervisely/nn/model/prediction_session.py +26 -14
- supervisely/nn/prediction_dto.py +19 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/utils.py +4 -5
- supervisely/nn/tracker/visualize.py +93 -93
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +46 -31
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- supervisely_lib/__init__.py +6 -1
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,497 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .api_server import start_api_server
|
|
4
|
+
from .request_queue import RequestQueue, RequestType
|
|
5
|
+
from .incremental_dataset import IncrementalDataset
|
|
6
|
+
from .helpers import ClassMap
|
|
7
|
+
import supervisely as sly
|
|
8
|
+
from supervisely import logger
|
|
9
|
+
from supervisely.nn import TaskType
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
import signal
|
|
12
|
+
import sys
|
|
13
|
+
import time
|
|
14
|
+
from .checkpoint_utils import resolve_checkpoint, save_state_json
|
|
15
|
+
from .artifacts_utils import upload_artifacts
|
|
16
|
+
from .loss_plateau_detector import LossPlateauDetector
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
class Phase:
    """Names of the stages a live-training session moves through.

    The values are plain strings: they are reported to clients as
    ``status()["phase"]`` and written into the saved training state.
    """

    # Session created; the loop is waiting for a START request.
    READY_TO_START = "ready_to_start"

    # Waiting until enough annotated samples have been added to the dataset.
    WAITING_FOR_SAMPLES = "waiting_for_samples"

    # Set once the dataset reaches ``initial_samples`` while waiting for samples.
    INITIAL_TRAINING = "initial_training"

    # Regular training stage (presumably set by a subclass's ``train()`` — not set in this file).
    TRAINING = "training"
|
|
24
|
+
|
|
25
|
+
class LiveTraining:
    """Base class for an interactive ("live") training session.

    Subclasses provide the framework-specific parts: ``train``, ``predict``,
    ``prepare_artifacts`` and, optionally, ``save_checkpoint``. They must also
    set ``framework_name`` and, when ``filter_classes_by_task`` is True,
    ``task_type``.

    This base class owns the request loop. It serves STATUS / PREDICT /
    ADD_SAMPLE requests that the API server puts into a ``RequestQueue``, and
    it waits for enough samples before training. On a loss plateau it pauses
    training until a new sample arrives. When training ends, or on
    SIGINT/SIGTERM, it saves the latest checkpoint and uploads artifacts.
    """

    from torch import nn  # pylint: disable=import-error

    task_type: str = None  # Should be set in subclass
    framework_name: str = None  # Should be set in subclass

    # Geometry types kept per task when ``filter_classes_by_task`` is True.
    _task2geometries = {
        TaskType.OBJECT_DETECTION: [sly.Rectangle],
        TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon],
        TaskType.SEMANTIC_SEGMENTATION: [sly.Bitmap, sly.Polygon],
    }

    def __init__(
        self,
        initial_samples: int = 2,
        filter_classes_by_task: bool = True,
    ):
        """Set up the session from environment variables and start the API server.

        Args:
            initial_samples: number of samples required before training starts.
            filter_classes_by_task: keep only project classes whose geometry
                suits ``task_type`` (see ``_task2geometries``).

        Raises:
            ValueError: if required subclass attributes are not set.
        """
        from torch import nn  # pylint: disable=import-error

        self.initial_samples = initial_samples
        self.filter_classes_by_task = filter_classes_by_task
        if self.task_type is None and self.filter_classes_by_task:
            raise ValueError("task_type must be set in subclass if filter_classes_by_task is set to True")
        if self.framework_name is None:
            raise ValueError("framework_name must be set in subclass")

        self.project_id = sly.env.project_id()
        self.team_id = sly.env.team_id()
        self.task_id = sly.env.task_id(raise_not_found=False)
        self.app = sly.Application()
        self.api = sly.Api()
        self.request_queue = RequestQueue()

        if os.getenv("DEVELOP_AND_DEBUG") and not sly.is_production():
            # Local debugging: connect to the VPN and register a debug task so the
            # platform can reach this process.
            logger.info(f"🔧 Initializing Develop & Debug application for project {self.project_id}...")
            sly.app.development.supervisely_vpn_network(action="up")
            debug_task = sly.app.development.create_debug_task(self.team_id, port="8000", project_id=self.project_id)
            self.task_id = debug_task['id']
        self._api_thread = start_api_server(self.app, self.request_queue)
        self.phase = Phase.READY_TO_START
        self.iter = 0
        self._loss = None
        self._is_paused = False
        self._should_pause_after_continue = False
        self.initial_iters = 60  # TODO: remove later
        self.project_meta = self._fetch_project_meta(self.project_id)
        self.class_map = self._init_class_map(self.project_meta)
        self.dataset: IncrementalDataset = None
        self.model: nn.Module = None
        self.loss_plateau_detector = self._init_loss_plateau_detector()
        self.work_dir = 'app_data'
        self.latest_checkpoint_path = f"{self.work_dir}/checkpoints/latest.pth"

        # Checkpoint source chosen in the app's modal window: 'scratch' (default),
        # 'continue' or 'finetune'.
        self.checkpoint_mode = os.getenv("modal.state.checkpointMode", "scratch")
        selected_task_id_env = os.getenv("modal.state.selectedExperimentTaskId")
        self.selected_experiment_task_id = int(selected_task_id_env) if selected_task_id_env else None

        self.training_start_time = None
        self._upload_in_progress = False

    @property
    def ready_to_predict(self):
        """True once more than ``initial_iters`` training iterations have run."""
        return self.iter > self.initial_iters

    def status(self):
        """Return a JSON-friendly snapshot of the session for API clients."""
        return {
            'phase': self.phase,
            'samples_count': len(self.dataset) if self.dataset is not None else 0,
            'waiting_samples': self.initial_samples,
            'task_type': self.task_type,
            'iteration': self.iter,
            'loss': self._loss,
            'training_paused': self._is_paused,
            'ready_to_predict': self.ready_to_predict,
            'initial_iters': self.initial_iters,
        }

    def run(self):
        """Run the whole session: wait for START, train, then save and upload.

        The training path depends on ``checkpoint_mode``: 'continue',
        'finetune', or anything else for training from scratch. In production,
        a training failure is logged and the final checkpoint, state and
        artifacts are still saved and uploaded. Outside production, the
        exception propagates and nothing is saved.
        """
        self.training_start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self._add_shutdown_callback()

        work_dir_path = Path(self.work_dir)
        work_dir_path.mkdir(parents=True, exist_ok=True)
        model_meta_path = work_dir_path / "model_meta.json"
        sly.json.dump_json_file(self.project_meta.to_json(), str(model_meta_path))

        try:
            self.phase = Phase.READY_TO_START
            self._wait_for_start()
            if self.checkpoint_mode == 'continue':
                self._run_continue()
            elif self.checkpoint_mode == 'finetune':
                self._run_finetune()
            else:
                self._run_from_scratch()
        except Exception as e:
            if not sly.is_production():
                raise e
            else:
                logger.error(f"Live training failed: {e}", exc_info=True)
        final_checkpoint = self.latest_checkpoint_path
        self.save_checkpoint(final_checkpoint)
        save_state_json(self.state(), final_checkpoint)
        self._upload_artifacts()

    def _run_from_scratch(self):
        """Wait for the initial samples, then train without a checkpoint."""
        self.phase = Phase.WAITING_FOR_SAMPLES
        self._wait_for_initial_samples()
        self.train(checkpoint_path=None)

    def _run_continue(self):
        """Resume a previous session: restore its state and samples, then train."""
        checkpoint_path, state = self._load_checkpoint()
        self.load_state(state)
        image_ids = state.get('image_ids', [])
        if image_ids:
            self._restore_dataset(image_ids)
        self.train(checkpoint_path=checkpoint_path)

    def _run_finetune(self):
        """Start from a checkpoint's weights, but with a fresh session state."""
        checkpoint_path, _ = self._load_checkpoint()
        self.phase = Phase.WAITING_FOR_SAMPLES
        self._wait_for_initial_samples()
        self.train(checkpoint_path=checkpoint_path)

    def _wait_for_start(self):
        """Serve STATUS requests until a START request arrives.

        Any other request received meanwhile is rejected with an exception on
        its future. The START request is answered with the current status,
        with its phase reported as WAITING_FOR_SAMPLES.
        """
        request = self.request_queue.get()
        while request.type != RequestType.START:
            if request.type == RequestType.STATUS:
                status = self.status()
                request.future.set_result(status)
            else:
                request.future.set_exception(Exception(f"Unexpected request {request.type} while waiting for START"))
            request = self.request_queue.get()
        # When START is received
        status = self.status()
        status['phase'] = Phase.WAITING_FOR_SAMPLES
        request.future.set_result(status)

    def _wait_until_samples_added(
        self,
        samples_needed: int,
        max_wait_time: int = None,
    ):
        """Block until ``samples_needed`` new samples have been added to the dataset.

        Pending requests are served while waiting, and ADD_SAMPLE requests are
        what grow the dataset.

        Args:
            samples_needed: number of samples to wait for, counted from the
                dataset size at the moment of the call.
            max_wait_time: timeout in seconds, or None to wait indefinitely.

        Raises:
            RuntimeError: if ``max_wait_time`` elapses before enough samples arrive.
        """
        sleep_interval = 0.5
        samples_before = len(self.dataset)
        # Use a real clock: serving requests (e.g. predictions) can take much longer
        # than the sleep interval, so summing sleep intervals under-counts the wait.
        deadline = None if max_wait_time is None else time.monotonic() + max_wait_time

        while len(self.dataset) - samples_before < samples_needed:
            if deadline is not None and time.monotonic() >= deadline:
                raise RuntimeError("Timeout waiting for samples")

            if not self.request_queue.is_empty():
                self._process_pending_requests()

            time.sleep(sleep_interval)

    def _wait_for_initial_samples(self):
        """Pause until the dataset holds at least ``initial_samples`` samples.

        Waits at most one hour. On timeout, the RuntimeError from
        ``_wait_until_samples_added`` propagates.
        """
        if len(self.dataset) >= self.initial_samples:
            return

        self.phase = Phase.WAITING_FOR_SAMPLES
        self._is_paused = True

        samples_needed = self.initial_samples - len(self.dataset)
        logger.info(f"Waiting for {samples_needed} initial samples")
        self._wait_until_samples_added(
            samples_needed=samples_needed,
            max_wait_time=3600,
        )

        self._is_paused = False

    def _process_pending_requests(self):
        """Drain the request queue and resolve every request's future.

        Errors raised while handling a request are logged and delivered to that
        request's future, so one bad request never stops the training loop.
        """
        requests = self.request_queue.get_all()
        if not requests:
            return

        for request in requests:
            try:
                if request.type == RequestType.PREDICT:
                    result = self._handle_predict(request.data)
                elif request.type == RequestType.ADD_SAMPLE:
                    result = self._handle_add_sample(request.data)
                elif request.type == RequestType.STATUS:
                    result = self.status()
                else:
                    # Reject unexpected requests (e.g. a repeated START) instead of
                    # dropping them; otherwise the client waits on the future forever.
                    raise RuntimeError(f"Unexpected request {request.type} during training")
                request.future.set_result(result)
            except Exception as e:
                logger.error(f"Error processing request {request.type}: {e}", exc_info=True)
                request.future.set_exception(e)

    def train(self, checkpoint_path: str = None):
        """
        Main training loop. Implement framework-specific training logic here.
        Prepare model config, set hyperparameters and run training.
        Handle phases: initial training, training
        """
        raise NotImplementedError

    def predict(self, model: nn.Module, image_np, image_info) -> list:
        """
        Run inference on a single image and return predictions as a list of sly figures in json format.
        """
        raise NotImplementedError

    def _handle_predict(self, data: dict):
        """Handle a PREDICT request with the registered model.

        The model is switched to eval mode for inference and back to train
        mode afterwards, if it was training.

        Raises:
            RuntimeError: if no model has been registered yet.
        """
        model = self.model
        if model is None:
            raise RuntimeError("Model is not registered yet. Call register_model() before predicting.")
        image_np = data['image']
        image_info = {'id': data['image_id']}
        was_training = model.training
        model.eval()
        try:
            objects = self.predict(model, image_np=image_np, image_info=image_info)
            return {
                'objects': objects,
                'image_id': data['image_id'],
                'status': self.status(),
            }
        finally:
            # Restore training mode
            if was_training:
                model.train()

    def add_sample(
        self,
        image_id: int,
        image_np: np.ndarray,
        annotation: sly.Annotation,
        image_name: str
    ) -> dict:
        """Add a sample to the dataset, or replace the one with the same ``image_id``."""
        return self.dataset.add_or_update(image_id, image_np, annotation, image_name)

    def _handle_add_sample(self, data: dict):
        """Handle an ADD_SAMPLE request.

        The annotation is filtered to the known classes before parsing. The
        phase moves to INITIAL_TRAINING once enough samples are collected.
        """
        ann_json = data['annotation']
        ann_json = self._filter_annotation(ann_json)
        sly_ann = sly.Annotation.from_json(ann_json, self.project_meta)
        self.add_sample(
            image_id=data['image_id'],
            image_np=data['image'],
            annotation=sly_ann,
            image_name=data['image_name']
        )
        if (len(self.dataset) >= self.initial_samples) and self.phase == Phase.WAITING_FOR_SAMPLES:
            self.phase = Phase.INITIAL_TRAINING
        return {
            'image_id': data['image_id'],
            'status': self.status(),
        }

    def _fetch_project_meta(self, project_id: int) -> sly.ProjectMeta:
        """Download and parse the project's meta from the platform."""
        project_meta = self.api.project.get_meta(project_id)
        project_meta = sly.ProjectMeta.from_json(project_meta)
        return project_meta

    def _init_class_map(self, project_meta: sly.ProjectMeta) -> ClassMap:
        """Build the class map used for training.

        For semantic segmentation, a ``_background_`` Bitmap class is inserted
        at index 0. When ``filter_classes_by_task`` is set, classes whose
        geometry is not allowed for ``task_type`` are dropped.
        """
        obj_classes = list(project_meta.obj_classes)

        if self.task_type == TaskType.SEMANTIC_SEGMENTATION:
            obj_classes.insert(0, sly.ObjClass(name='_background_', geometry_type=sly.Bitmap))

        if self.filter_classes_by_task:
            allowed_geometries = self._task2geometries[self.task_type]
            obj_classes = [
                obj_class for obj_class in obj_classes
                if obj_class.geometry_type in allowed_geometries
            ]

        return ClassMap(obj_classes)

    def _filter_annotation(self, ann_json: dict) -> dict:
        """Return a copy of ``ann_json`` keeping only objects whose class is in ``class_map``.

        Filtering must happen before ``sly.Annotation.from_json``, because the
        project meta is static. The input dict is not modified. A missing
        ``objects`` key is treated as an empty list.
        """
        known_ids = self.class_map.sly_ids
        filtered_objects = [obj for obj in ann_json.get('objects', []) if obj['classId'] in known_ids]
        return {**ann_json, 'objects': filtered_objects}

    def after_train_step(self, loss: float):
        """Hook to call after every training step.

        Records the iteration and loss. It blocks until a new sample arrives
        in two cases: when a resumed session was saved while paused, or when
        the plateau detector reports a plateau. Finally, it serves pending
        requests.
        """
        self.iter += 1
        self._loss = loss
        if self._should_pause_after_continue:
            self._is_paused = True
            logger.info("Training was paused. Waiting for 1 new sample before resuming...")
            self._wait_until_samples_added(samples_needed=1, max_wait_time=None)
            self._should_pause_after_continue = False
            logger.info("New sample added. Resuming training...")
            self._is_paused = False
        if self.loss_plateau_detector is not None:
            is_plateau = self.loss_plateau_detector.step(loss, self.iter)
            if is_plateau:
                self._is_paused = True
                self._wait_until_samples_added(
                    samples_needed=1,
                    max_wait_time=None,
                )
                self._is_paused = False
                self.loss_plateau_detector.reset()
        self._process_pending_requests()

    def register_model(self, model: nn.Module):
        """Register the model used for predictions and checkpoints."""
        self.model = model

    def register_dataset(self, dataset: IncrementalDataset):
        """Register the dataset that collects samples.

        Raises:
            TypeError: if ``dataset`` has no ``add_or_update`` method.
        """
        if not hasattr(dataset, 'add_or_update'):
            raise TypeError("Dataset must implement add_or_update method. Consider inheriting from IncrementalDataset.")
        self.dataset = dataset

    def _load_checkpoint(self) -> tuple:
        """Resolve and configure checkpoint based on checkpoint_mode.

        Requests are served before and after resolving, which may download
        files. ``class_map`` is replaced with the one returned by
        ``resolve_checkpoint``.

        Returns:
            ``(checkpoint_path, state)``.
        """
        self._process_pending_requests()
        checkpoint_path, class_map, state = resolve_checkpoint(
            checkpoint_mode=self.checkpoint_mode,
            selected_experiment_task_id=self.selected_experiment_task_id,
            class_map=self.class_map,
            project_meta=self.project_meta,
            api=self.api,
            team_id=self.team_id,
            work_dir=self.work_dir
        )

        self.class_map = class_map
        self._process_pending_requests()
        return checkpoint_path, state

    def state(self):
        """Return the session state saved next to the checkpoint."""
        state = {
            'phase': self.phase,
            'iter': self.iter,
            'loss': self._loss,
            # NOTE(review): key is misspelled ('clases'); kept as-is because saved
            # state files (and possibly readers outside this file) may rely on it.
            'clases': [cls.name for cls in self.class_map.obj_classes],
            'image_ids': self.dataset.get_image_ids() if self.dataset else [],
            'dataset_size': len(self.dataset) if self.dataset else 0,
            'is_paused': self._is_paused
        }
        return state

    def load_state(self, state: dict):
        """Restore counters and phase from a state dict produced by ``state()``.

        If the saved session was paused, training pauses again after the first
        step until a new sample arrives.
        """
        self.phase = state.get('phase', Phase.READY_TO_START)
        self.iter = state.get('iter', 0)
        self._loss = state.get('loss', None)
        self.image_ids = state.get('image_ids', [])
        if state.get('is_paused', False):
            self._should_pause_after_continue = True

    def _restore_dataset(self, image_ids: list):
        """Re-download the given images and annotations from Supervisely into the dataset.

        Annotations are filtered to the known classes, just like live
        ADD_SAMPLE requests. Images that no longer exist are skipped with a
        warning.
        """
        if not image_ids:
            return

        logger.info(f"Restoring {len(image_ids)} images from Supervisely...")

        restored_count = 0
        for img_id in image_ids:
            img_info = self.api.image.get_info_by_id(img_id)

            if img_info is None:
                logger.warning(f"Image {img_id} not found, skipping")
                continue

            image_np = self.api.image.download_np(img_id)
            ann_json = self.api.annotation.download_json(img_id)
            # Must be filtered before from_json (static project meta), as in _handle_add_sample.
            ann_json = self._filter_annotation(ann_json)
            ann = sly.Annotation.from_json(ann_json, self.project_meta)

            self.dataset.add_or_update(
                image_id=img_id,
                image_np=image_np,
                annotation=ann,
                image_name=img_info.name
            )

            restored_count += 1

            if restored_count % 10 == 0:
                logger.info(f"Restored {restored_count}/{len(image_ids)}")

        logger.info(f"Restored {restored_count} images")

    def prepare_artifacts(self) -> dict:
        """
        Prepare all artifacts for upload (framework-specific).

        Returns:
            Dict with:
            - checkpoint_path: path to checkpoint file
            - checkpoint_info: dict with {name, iteration, loss}
            - config_path: path to config file
            - logs_dir: path to logs directory or None
            - model_name: model name
            - model_config: model configuration dict
            - loss_history: dict with loss history
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement prepare_artifacts()"
        )

    def _get_session_info(self) -> dict:
        """Collect training session context"""
        return {
            'team_id': self.team_id,
            'task_id': self.task_id,
            'project_id': self.project_id,
            'framework_name': self.framework_name,
            'task_type': self.task_type,
            'class_map': self.class_map,
            'start_time': self.training_start_time,
            'train_size': len(self.dataset) if self.dataset else 0,
            'initial_samples': self.initial_samples
        }

    def _upload_artifacts(self):
        """Upload artifacts and log the report URL. Best effort.

        Failures are logged, not raised. A call made while an upload is
        already in progress returns immediately.
        """
        if self._upload_in_progress:
            return

        self._upload_in_progress = True

        try:
            session_info = self._get_session_info()
            artifacts = self.prepare_artifacts()

            report_url = upload_artifacts(
                api=self.api,
                session_info=session_info,
                artifacts=artifacts
            )

            logger.info(f"Report: {report_url}")

        except Exception as e:
            logger.error(f"Upload failed: {e}", exc_info=True)

        finally:
            self._upload_in_progress = False

    def save_checkpoint(self, checkpoint_path: str):
        """Save the model checkpoint to ``checkpoint_path``. No-op by default; override in subclasses."""
        pass

    def _init_loss_plateau_detector(self):
        """Create the plateau detector, wired to ``save_checkpoint``."""
        loss_plateau_detector = LossPlateauDetector()
        loss_plateau_detector.register_save_checkpoint_callback(self.save_checkpoint)
        return loss_plateau_detector

    def _add_shutdown_callback(self):
        """Setup graceful shutdown: save experiment on SIGINT/SIGTERM"""
        self._upload_in_progress = False

        def signal_handler(signum, frame):
            if self._upload_in_progress:
                # An upload is already running: don't start another one. Instead,
                # install handlers so that the NEXT SIGINT/SIGTERM exits immediately.
                signal.signal(signal.SIGINT, lambda s, f: sys.exit(1))
                signal.signal(signal.SIGTERM, lambda s, f: sys.exit(1))
                return

            # Save checkpoint and state before upload
            logger.info("Received shutdown signal, saving checkpoint...")
            self.save_checkpoint(self.latest_checkpoint_path)
            save_state_json(self.state(), self.latest_checkpoint_path)
            self._upload_artifacts()
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LossPlateauDetector:
    """
    Detect plateau in training loss using moving average comparison.

    When a check runs:
        A check happens on every ``check_interval``-th iteration, once at least
        ``2 * window_size`` losses have been recorded. It compares the mean of
        the last ``window_size`` losses with the mean of the ``window_size``
        losses just before them.

    What counts as a plateau signal:
        The relative improvement ``(previous_avg - current_avg) / |previous_avg|``
        is computed (the plain difference if ``previous_avg`` is exactly 0).
        A value below ``threshold`` is a plateau signal; a loss that went up
        also counts.

    What happens on a signal:
        After ``patience`` consecutive signals, the registered callback is
        invoked and the counter starts over. A check without a plateau signal
        resets the counter.

    Args:
        window_size: Number of iterations for moving average (>= 1)
        threshold: Relative change threshold (e.g., 0.005 = 0.5%)
        patience: Number of consecutive plateau detections before action
        check_interval: Check frequency (every N iterations, >= 1)

    Raises:
        ValueError: If ``window_size`` or ``check_interval`` is less than 1.
    """

    def __init__(
        self,
        window_size: int = 20,
        threshold: float = 0.005,
        patience: int = 1,
        check_interval: int = 1,
    ):
        # window_size < 1 would slice an empty previous window (NaN mean) and
        # check_interval 0 would divide by zero in step(); fail fast instead.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if check_interval < 1:
            raise ValueError(f"check_interval must be >= 1, got {check_interval}")

        self.window_size = window_size
        self.threshold = threshold
        self.check_interval = check_interval
        self.patience = patience
        # Two full windows are needed before the averages can be compared.
        self._min_iterations = 2 * window_size

        # State
        # Every loss passed to step(), in order.
        self.loss_history = []
        self.consecutive_plateau_count = 0
        self._save_checkpoint_fn = None

    def register_save_checkpoint_callback(self, fn: Callable[[], None]):
        """Register callback function to save checkpoint when plateau detected.

        The callback is invoked with no arguments.
        """
        self._save_checkpoint_fn = fn

    def reset(self):
        """Reset detector state (loss history and plateau counter)."""
        self.loss_history = []
        self.consecutive_plateau_count = 0

    def step(self, loss: float, current_iter: int) -> bool:
        """
        Process one training iteration.

        Args:
            loss: Current loss value
            current_iter: Current iteration number (0-based; checks run when
                ``(current_iter + 1) % check_interval == 0``)

        Returns:
            True if the plateau was confirmed on this step (the callback, if
            registered, has been invoked), False otherwise.
        """
        self.loss_history.append(loss)

        # Check only at specified intervals
        if (current_iter + 1) % self.check_interval != 0:
            return False

        # Need enough data
        if len(self.loss_history) < self._min_iterations:
            return False

        is_plateau, info = self._check_plateau(current_iter)
        if not is_plateau:
            # "Consecutive" means uninterrupted: any improving check restarts the count.
            self.consecutive_plateau_count = 0
            return False

        self.consecutive_plateau_count += 1
        print(
            f'[Plateau Detection] Iteration {current_iter}: '
            f'Signal {self.consecutive_plateau_count}/{self.patience} '
            f'(change: {info["metric"]:.6f}, threshold: {self.threshold})'
        )

        # Trigger action only when patience reached
        if self.consecutive_plateau_count < self.patience:
            return False

        print('[Plateau Detection] Plateau confirmed, saving checkpoint...')
        if self._save_checkpoint_fn is not None:
            self._save_checkpoint_fn()
            print('[Plateau Detection] Checkpoint saved')
        else:
            print('[Plateau Detection] No callback registered')

        self.consecutive_plateau_count = 0
        return True

    def _check_plateau(self, current_iter: int) -> tuple:
        """Compare the last two windows of losses.

        Returns:
            ``(is_plateau, info)``. ``info["metric"]`` is the relative
            improvement that was compared with ``threshold``, and
            ``info["abs_change"]`` is ``previous_avg - current_avg``.
        """
        w = self.window_size
        current_avg = float(np.mean(self.loss_history[-w:]))
        previous_avg = float(np.mean(self.loss_history[-2 * w:-w]))

        abs_change = previous_avg - current_avg
        # Relative to the previous level so the threshold does not depend on the
        # loss scale (as documented); fall back to the absolute change at 0.
        if previous_avg != 0:
            change = abs_change / abs(previous_avg)
        else:
            change = abs_change
        is_plateau = change < self.threshold

        info = {
            'iter': current_iter,
            'metric': change,
            'abs_change': abs_change,
            'threshold': self.threshold,
            'previous_avg': previous_avg,
            'current_avg': current_avg,
        }

        return is_plateau, info
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import queue
|
|
2
|
+
import asyncio
|
|
3
|
+
from typing import Any, Optional, List
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RequestType(Enum):
    """Supported kinds of API request; each member's value is its string identifier."""

    START = "start"
    PREDICT = "predict"
    ADD_SAMPLE = "add-sample"
    STATUS = "status"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Request:
    """A simple representation of an API request.

    Attributes:
        type: the ``RequestType`` of the request.
        data: optional payload dict (``None`` when omitted).
        future: optional ``asyncio.Future`` attached to the request.
    """

    def __init__(self, request_type: RequestType, data: Optional[dict] = None, future: Optional[asyncio.Future] = None):
        self.type, self.data, self.future = request_type, data, future

    def to_tuple(self):
        """Return the request as ``(type, data, future)``."""
        fields = (self.type, self.data, self.future)
        return fields
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class RequestQueue:
    """Thread-safe queue for API requests.

    Requests are stored in a ``queue.Queue``, so ``put`` and ``get`` may be
    called from different threads. Each request carries an ``asyncio.Future``
    for its result.

    NOTE(review): ``asyncio.Future`` itself is not thread-safe. A consumer
    resolving it from another thread should go through
    ``loop.call_soon_threadsafe`` — TODO confirm the consumers do.
    """

    def __init__(self):
        self._queue = queue.Queue()

    def put(self, request_type: RequestType, data: Optional[dict] = None) -> asyncio.Future:
        """Add request and return future for result.

        Inside a running event loop the future is created with
        ``loop.create_future()``, the form recommended by the asyncio docs.
        Without a running loop it falls back to ``asyncio.Future()``.
        """
        try:
            future = asyncio.get_running_loop().create_future()
        except RuntimeError:
            future = asyncio.Future()
        self._queue.put(Request(request_type, data, future))
        return future

    def get_all(self) -> List[Request]:
        """Get all pending requests (non-blocking), oldest first."""
        requests = []
        while True:
            try:
                requests.append(self._queue.get_nowait())
            except queue.Empty:
                return requests

    def is_empty(self) -> bool:
        """Return True if no request is currently queued (a snapshot, may race with producers)."""
        return self._queue.empty()

    def get(self, timeout: Optional[float] = None) -> Request:
        """Get a single request from the queue.

        Blocks until a request is available. With ``timeout`` given, raises
        ``queue.Empty`` after waiting that many seconds.
        """
        return self._queue.get(timeout=timeout)
|
|
@@ -72,6 +72,15 @@ class ModelAPI:
|
|
|
72
72
|
else:
|
|
73
73
|
return self._post("get_custom_inference_settings", {})["settings"]
|
|
74
74
|
|
|
75
|
+
def get_tracking_settings(self, tracker: str = "botsort"):
    """Return the tracking settings for one tracking algorithm.

    Sends ``get_tracking_settings`` through ``self.api.task.send_request``
    when ``self.task_id`` is set, otherwise through ``self._post``. The
    response is indexed by algorithm name and the entry for ``tracker`` is
    returned.

    Args:
        tracker: Name of the tracking algorithm in the response. Defaults to
            ``"botsort"``, which was previously the only (hard-coded) choice.

    Raises:
        KeyError: If the response has no entry for ``tracker``.
    """
    # @TODO: add dropdown selector for tracking algorithms later
    if self.task_id is not None:
        settings = self.api.task.send_request(self.task_id, "get_tracking_settings", {})
    else:
        settings = self._post("get_tracking_settings", {})
    return settings[tracker]
|
|
82
|
+
|
|
83
|
+
|
|
75
84
|
def get_model_meta(self):
|
|
76
85
|
if self.task_id is not None:
|
|
77
86
|
return ProjectMeta.from_json(
|
|
@@ -59,6 +59,7 @@ class Prediction:
|
|
|
59
59
|
self.source = source
|
|
60
60
|
if isinstance(annotation_json, Annotation):
|
|
61
61
|
annotation_json = annotation_json.to_json()
|
|
62
|
+
|
|
62
63
|
self.annotation_json = annotation_json
|
|
63
64
|
self.model_meta = model_meta
|
|
64
65
|
if isinstance(self.model_meta, dict):
|
|
@@ -157,7 +158,7 @@ class Prediction:
|
|
|
157
158
|
|
|
158
159
|
@property
|
|
159
160
|
def annotation(self) -> Annotation:
|
|
160
|
-
if self._annotation is None:
|
|
161
|
+
if self._annotation is None and self.annotation_json is not None:
|
|
161
162
|
if self.model_meta is None:
|
|
162
163
|
raise ValueError("Model meta is not provided. Cannot create annotation.")
|
|
163
164
|
model_meta = get_meta_from_annotation(self.annotation_json, self.model_meta)
|