supervisely-6.73.393-py3-none-any.whl → supervisely-6.73.395-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/entity_annotation/entity_annotation_api.py +3 -1
- supervisely/api/entity_annotation/figure_api.py +25 -16
- supervisely/api/module_api.py +2 -0
- supervisely/api/volume/volume_annotation_api.py +4 -2
- supervisely/api/volume/volume_figure_api.py +36 -7
- supervisely/convert/base_converter.py +2 -2
- supervisely/convert/volume/nii/nii_planes_volume_converter.py +51 -13
- supervisely/convert/volume/nii/nii_volume_converter.py +1 -1
- supervisely/convert/volume/nii/nii_volume_helper.py +96 -36
- supervisely/convert/volume/sly/sly_volume_converter.py +32 -3
- supervisely/nn/inference/inference.py +274 -35
- supervisely/nn/training/train_app.py +19 -20
- supervisely/project/volume_project.py +6 -0
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +1 -1
- supervisely/volume_annotation/volume_figure.py +45 -1
- supervisely/volume_annotation/volume_object.py +23 -6
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/METADATA +1 -1
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/RECORD +23 -23
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/LICENSE +0 -0
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/WHEEL +0 -0
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.393.dist-info → supervisely-6.73.395.dist-info}/top_level.txt +0 -0
supervisely/convert/volume/sly/sly_volume_converter.py

@@ -4,6 +4,7 @@ from typing import List
 import supervisely.convert.volume.sly.sly_volume_helper as sly_volume_helper
 from supervisely.volume_annotation.volume_annotation import VolumeAnnotation
 from supervisely import ProjectMeta, logger
+from supervisely.project.volume_project import VolumeProject, VolumeDataset
 from supervisely.convert.base_converter import AvailableVolumeConverters
 from supervisely.convert.volume.volume_converter import VolumeConverter
 from supervisely.io.fs import JUNK_FILES, get_file_ext, get_file_name

@@ -40,7 +41,7 @@ class SLYVolumeConverter(VolumeConverter):
             ann = VolumeAnnotation.from_json(ann_json, meta)  # , KeyIdMap())
             return True
         except Exception as e:
-            logger.
+            logger.warning(f"Failed to validate annotation: {repr(e)}")
             return False

     def validate_key_file(self, key_file_path: str) -> bool:

@@ -51,6 +52,9 @@ class SLYVolumeConverter(VolumeConverter):
         return False

     def validate_format(self) -> bool:
+        if self.read_sly_project(self._input_data):
+            return True
+
         detected_ann_cnt = 0
         vol_list, stl_dict, ann_dict, mask_dict = [], {}, {}, {}
         for root, _, files in os.walk(self._input_data):

@@ -70,7 +74,7 @@ class SLYVolumeConverter(VolumeConverter):
                         ProjectMeta.from_json(meta_json)
                     )
                 except Exception as e:
-                    logger.
+                    logger.warning(f"Failed to merge meta: {repr(e)}")
                     continue

             elif file in JUNK_FILES:  # add better check

@@ -139,5 +143,30 @@ class SLYVolumeConverter(VolumeConverter):
             ann_json = sly_volume_helper.rename_in_json(ann_json, renamed_classes, renamed_tags)
             return VolumeAnnotation.from_json(ann_json, meta)  # , KeyIdMap())
         except Exception as e:
-            logger.
+            logger.warning(f"Failed to read annotation: {repr(e)}")
             return item.create_empty_annotation()
+
+    def read_sly_project(self, input_data: str) -> bool:
+        try:
+            project_fs = VolumeProject.read_single(input_data)
+            self._meta = project_fs.meta
+            self._items = []
+
+            for dataset_fs in project_fs.datasets:
+                dataset_fs: VolumeDataset
+
+                for item_name in dataset_fs:
+                    img_path, ann_path = dataset_fs.get_item_paths(item_name)
+                    item = self.Item(
+                        item_path=img_path,
+                        ann_data=ann_path,
+                        shape=None,
+                        interpolation_dir=dataset_fs.get_interpolation_dir(item_name),
+                        mask_dir=dataset_fs.get_mask_dir(item_name),
+                    )
+                    self._items.append(item)
+            return True
+
+        except Exception as e:
+            logger.info(f"Failed to read Supervisely project: {repr(e)}")
+            return False
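Note: the new `read_sly_project` fast path lets the converter ingest a directory that is already a Supervisely volume project without scanning it file by file. A minimal usage sketch, assuming the converter takes the input directory as its first constructor argument (that signature is not shown in this diff):

    # Sketch: validate_format() now short-circuits for ready-made projects.
    from supervisely.convert.volume.sly.sly_volume_converter import SLYVolumeConverter

    converter = SLYVolumeConverter("/data/my_volume_project")  # assumed signature
    if converter.validate_format():
        # read_sly_project() already filled converter._meta and converter._items
        # via VolumeProject.read_single(), so no os.walk() scan was needed.
        print(f"Detected {len(converter._items)} volume items")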
supervisely/nn/inference/inference.py

@@ -93,6 +93,7 @@ from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
+from supervisely.nn.model.model_api import ModelAPI

 try:
     from typing import Literal
@@ -150,12 +151,15 @@ class Inference:
         use_gui: Optional[bool] = False,
         multithread_inference: Optional[bool] = True,
         use_serving_gui_template: Optional[bool] = False,
+        model: Optional[str] = None,
+        device: Optional[str] = None,
+        runtime: Optional[str] = None,
     ):

         self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
-        self._args, self.
+        self._args, self._is_cli_deploy = self._parse_cli_deploy_args()
         if model_dir is None:
-            if self.
+            if self._is_cli_deploy is True:
                 try:
                     model_dir = get_data_dir()
                 except:

@@ -165,8 +169,12 @@
             sly_fs.mkdir(model_dir)

         self.autorestart = None
+        _deploy_model = model
+        _deploy_device = device
+        _deploy_runtime = runtime
         self.device: str = None
         self.runtime: str = None
+        self._is_quick_deploy = False
         self.model_precision: str = None
         self.model_source: str = None
         self.checkpoint_info: CheckpointInfo = None
@@ -208,12 +216,14 @@

         self.load_model = LOAD_MODEL_DECORATOR(self.load_model)

-        if self.
+        if self._is_cli_deploy:
            self._use_gui = False
            deploy_params, need_download = self._get_deploy_params_from_args()
            if need_download:
+                logger.debug("Downloading model files")
                local_model_files = self._download_model_files(deploy_params, False)
                deploy_params["model_files"] = local_model_files
+            logger.debug("Loading model...")
            self._load_model_headless(**deploy_params)

        if self._use_gui:
@@ -307,6 +317,97 @@
         )

         self.inference_requests_manager = InferenceRequestsManager(executor=self._executor)
+        if _deploy_model is not None and not self._model_served:
+            self._is_quick_deploy = True
+            self.serve()
+            self._deploy_headless(_deploy_model, _deploy_device, _deploy_runtime)
+
+    def __call__(
+        self,
+        input: Union[np.ndarray, str, os.PathLike, list] = None,
+        image_id: Union[List[int], int] = None,
+        video_id: Union[List[int], int] = None,
+        dataset_id: Union[List[int], int] = None,
+        project_id: Union[List[int], int] = None,
+        batch_size: int = None,
+        conf: float = None,
+        img_size: int = None,
+        classes: List[str] = None,
+        upload_mode: str = None,
+        recursive: bool = None,
+        **kwargs,
+    ):
+        return ModelAPI(api=self._api, url="http://0.0.0.0:8000").predict(
+            input,
+            image_id,
+            video_id,
+            dataset_id,
+            project_id,
+            batch_size,
+            conf,
+            img_size,
+            classes,
+            upload_mode,
+            recursive,
+            **kwargs,
+        )
+
+    def _deploy_headless(self, model: str, device: str, runtime: Optional[str] = None):
+        """Deploy model immediately from constructor arguments."""
+        # Clean model_dir before deploying
+        sly_fs.mkdir(self._model_dir, True)
+
+        def get_runtime(runtime: Optional[str]) -> str:
+            if runtime is None:
+                runtime = RuntimeType.PYTORCH
+            else:
+                if runtime.lower() in ["torch", RuntimeType.PYTORCH.lower()]:
+                    runtime = RuntimeType.PYTORCH
+                elif runtime.lower() in ["onnx", RuntimeType.ONNXRUNTIME.lower()]:
+                    runtime = RuntimeType.ONNXRUNTIME
+                elif runtime.lower() in ["trt", RuntimeType.TENSORRT.lower()]:
+                    runtime = RuntimeType.TENSORRT
+                else:
+                    raise ValueError(f"Invalid runtime: {runtime}. Please use one of the following values: {RuntimeType.PYTORCH}, {RuntimeType.ONNXRUNTIME}, {RuntimeType.TENSORRT}.")
+            return runtime
+
+        def get_pretrained_model(model: str) -> dict:
+            if self.pretrained_models is not None:
+                for m in self.pretrained_models:
+                    m_name = _get_model_name(m)
+                    if m_name and m_name.lower() == model.lower():
+                        return m
+            return None
+
+        runtime = get_runtime(runtime)
+        logger.debug(f"Runtime: {runtime}")
+
+        # Pretrained models
+        selected_pretrained = get_pretrained_model(model)
+        if selected_pretrained is not None:
+            logger.debug("Pretrained model found")
+            model_files_remote = selected_pretrained["meta"]["model_files"]
+            model_files_local = self._download_pretrained_model(model_files_remote, headless=True)
+
+            deploy_params = {
+                "model_source": ModelSource.PRETRAINED,
+                "model_files": model_files_local,
+                "model_info": selected_pretrained,
+                "device": device,
+                "runtime": runtime,
+            }
+            logger.debug(f"Deploying pretrained model '{model}' ...")
+            logger.debug(f"Deploy parameters: {deploy_params}")
+            self._load_model_headless(**deploy_params)
+            return self
+
+        # Custom Models
+        checkpoint_path = model
+        checkpoint_name = sly_fs.get_file_name_with_ext(checkpoint_path)
+        deploy_params = self._get_deploy_parameters_from_custom_checkpoint(checkpoint_path, device, runtime)
+        logger.debug(f"Deploying custom model '{checkpoint_name}'...")
+        self._load_model_headless(**deploy_params)
+        return self

     def get_batch_size(self):
         if self.max_batch_size is not None:
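Note: these three pieces wire up a "quick deploy" path — passing `model=` to the constructor triggers `serve()` plus `_deploy_headless(...)`, and `__call__` forwards to `ModelAPI.predict` against the local server on port 8000. A usage sketch, assuming a framework-specific subclass (the class and model names below are placeholders, not from this diff):

    # Placeholder subclass of the Inference class patched above.
    from my_framework.serve import CustomModelInference

    # Constructor-driven quick deploy: the model is served on a background
    # Uvicorn server (see _run_server_in_thread below) during __init__.
    m = CustomModelInference(model="my-pretrained-model", device="cuda:0", runtime="onnx")

    # Calling the instance proxies to ModelAPI(url="http://0.0.0.0:8000").predict(...)
    predictions = m(input="image.jpg", conf=0.5)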
@@ -808,6 +909,101 @@
         if log_progress:
             self.gui.download_progress.hide()
         return local_model_files
+
+    def _get_deploy_parameters_from_custom_checkpoint(self, checkpoint_path: str, device: str, runtime: str) -> dict:
+        def _read_experiment_info(artifacts_dir: str) -> Optional[dict]:
+            exp_path = os.path.join(artifacts_dir, "experiment_info.json")
+            if sly_fs.file_exists(exp_path):
+                return self._load_json_file(exp_path)
+            return None
+
+        def _compose_model_files(artifacts_dir: str, checkpoint_path: str, files_map: dict) -> dict:
+            model_files = {k: os.path.join(artifacts_dir, v) for k, v in files_map.items()}
+            model_files["checkpoint"] = checkpoint_path
+            return model_files
+
+        is_local = sly_fs.file_exists(checkpoint_path)
+        if not is_local:
+            team_id = sly_env.team_id()
+            if self.api is None:
+                raise ValueError(
+                    f"File: '{checkpoint_path}' not found in local storage. "
+                    "Initialize API by providing 'API_TOKEN' and 'SERVER_ADDRESS' "
+                    "environment variables to use remote checkpoint."
+                )
+            if not self.api.file.exists(team_id, checkpoint_path):
+                raise FileNotFoundError(
+                    f"Checkpoint '{checkpoint_path}' not found locally and remotely. "
+                    "Make sure you have provided correct checkpoint path."
+                )
+
+        artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
+        if not is_local:
+            logger.debug("Remote checkpoint found")
+            # --- REMOTE ---
+            # experiment_info.json
+            logger.debug("Getting experiment_info.json...")
+            remote_exp_info = os.path.join(artifacts_dir, "experiment_info.json")
+            local_exp_info = os.path.join(self.model_dir, "experiment_info.json")
+            self.download(remote_exp_info, local_exp_info)
+            experiment_info = self._load_json_file(local_exp_info)
+
+            # model_meta.json
+            logger.debug("Getting model_meta.json...")
+            meta_name = experiment_info.get("model_meta") or "model_meta.json"
+            remote_meta = os.path.join(artifacts_dir, meta_name)
+            local_meta = os.path.join(self.model_dir, meta_name)
+            self.download(remote_meta, local_meta)
+            model_meta = self._load_json_file(local_meta)
+            experiment_info["model_meta"] = model_meta
+
+            # model files
+            logger.debug("Getting model files...")
+            remote_files_rel = experiment_info.get("model_files", {})
+            remote_files_full = _compose_model_files(artifacts_dir, checkpoint_path, remote_files_rel)
+            local_model_files = self._download_custom_model(remote_files_full, False)
+            model_files = local_model_files
+            model_info = experiment_info
+            logger.debug("Deploy parameters extracted from remote checkpoint successfully")
+        else:
+            logger.debug("Local checkpoint found")
+            # --- LOCAL ---
+            try:
+                logger.debug("Reading state dict...")
+                import torch  # pylint: disable=import-error
+                ckpt = torch.load(checkpoint_path, map_location="cpu")
+                model_info = ckpt.get("model_info", {})
+                model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
+                model_files["checkpoint"] = checkpoint_path
+                meta_path = os.path.join(self.model_dir, "model_meta.json")
+                if sly_fs.file_exists(meta_path):
+                    model_info["model_meta"] = self._load_json_file(meta_path)
+                logger.debug("Deploy parameters extracted from state dict successfully")
+            except:
+                logger.debug(f"Failed to read model metadata from checkpoint '{checkpoint_path}'. Trying to find local files...")
+                experiment_info = _read_experiment_info(artifacts_dir)
+                if experiment_info:
+                    logger.debug("Reading experiment_info.json...")
+                    model_files = _compose_model_files(artifacts_dir, checkpoint_path, experiment_info.get("model_files", {}))
+                    meta_name = experiment_info.get("model_meta") or "model_meta.json"
+                    meta_path = os.path.join(artifacts_dir, meta_name)
+                    if not sly_fs.file_exists(meta_path):
+                        raise FileNotFoundError(f"Model meta file not found: '{meta_path}'")
+                    experiment_info["model_meta"] = self._load_json_file(meta_path)
+                    model_info = experiment_info
+                    logger.debug("Deploy parameters extracted from experiment_info.json successfully")
+                else:
+                    raise FileNotFoundError(f"'experiment_info.json' not found in '{artifacts_dir}'")
+
+        deploy_params = {
+            "model_source": ModelSource.CUSTOM,
+            "model_files": model_files,
+            "model_info": model_info,
+            "device": device,
+            "runtime": runtime,
+        }
+        logger.debug(f"Deploy parameters: {deploy_params}")
+        return deploy_params

     def _extract_model_files_from_checkpoint(self, checkpoint_path: str) -> dict:
         extracted_files: dict = {}
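Note: `_get_deploy_parameters_from_custom_checkpoint` takes `os.path.dirname(os.path.dirname(checkpoint_path))` as the artifacts directory, so both branches assume a training-artifacts layout like the sketch below (file names come from the code above; the folder and checkpoint names are illustrative):

    my_experiment/                 # artifacts_dir
    ├── experiment_info.json       # becomes model_info (+ "model_files" map)
    ├── model_meta.json            # project meta: classes and tags
    └── checkpoints/
        └── best.pth               # checkpoint_path supplied by the user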
@@ -816,7 +1012,6 @@
             return extracted_files

         import torch  # pylint: disable=import-error
-
         logger.debug(f"Reading checkpoint: {checkpoint_path}")
         checkpoint = torch.load(checkpoint_path, map_location="cpu")

@@ -1008,11 +1203,12 @@
         model_files = deploy_params.get("model_files", {})
         if model_info:
             checkpoint_name = os.path.basename(model_files.get("checkpoint"))
-
-
-
+            artifacts_dir = model_info.get("artifacts_dir")
+            if artifacts_dir is None:
+                artifacts_dir = os.path.dirname(os.path.dirname(model_files.get("checkpoint")))
+            checkpoint_file_path = os.path.join(artifacts_dir, "checkpoints", checkpoint_name)
             checkpoint_file_info = None
-            if not self.
+            if not self._is_cli_deploy:
                 checkpoint_file_info = self.api.file.get_info_by_path(
                     sly_env.team_id(), checkpoint_file_path
                 )
@@ -1115,7 +1311,7 @@
     def api(self) -> Api:
         if self._api is None:
             if (
-                self.
+                self._is_cli_deploy
                 and os.getenv("SERVER_ADDRESS") is None
                 and os.getenv("API_TOKEN") is None
             ):
@@ -2538,7 +2734,7 @@
         inference_request.done(len(results))

     def serve(self):
-        if not self._use_gui and not self.
+        if not self._use_gui and not self._is_cli_deploy:
             Progress("Deploying model ...", 1)

         if is_debug_with_sly_net():
@@ -2553,30 +2749,31 @@
             self._task_id = task["id"]
             os.environ["TASK_ID"] = str(self._task_id)
         else:
-            if not self.
+            if not self._is_cli_deploy:
                 self._task_id = sly_env.task_id() if is_production() else None

-        if self.
+        if self._is_cli_deploy:
             # Predict and shutdown
-            if self._args.mode == "predict"
+            if self._args.mode == "predict":
+                if any(
                     [
                         self._args.input,
                         self._args.project_id,
                         self._args.dataset_id,
                         self._args.image_id,
                     ]
-
-
-
-
-
+                ):
+                    self._parse_inference_settings_from_args()
+                    self._inference_by_cli_deploy_args()
+                    exit(0)
+                else:
+                    logger.error("Predict mode requires one of the following arguments: --input, --project_id, --dataset_id, --image_id")
+                    exit(0)

         if isinstance(self.gui, GUI.InferenceGUI):
             self._app = Application(layout=self.get_ui())
         elif isinstance(self.gui, GUI.ServingGUI):
             self._app = Application(layout=self._app_layout)
-            # elif isinstance(self.gui, GUI.InferenceGUI):
-            #     self._app = Application(layout=self.get_ui())
         else:
             self._app = Application(layout=self.get_ui())
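Note: with this change, running the entrypoint in `predict` mode either runs inference once and exits, or fails fast with the error above — e.g. something like `python main.py predict --model ./my_experiment/checkpoints/best.pth --image_id 123 --device cuda:0` (the script name is a placeholder; the flag spellings are inferred from the `self._args.*` attributes used in this diff). Without one of `--input` / `--project_id` / `--dataset_id` / `--image_id`, the process now logs an error and exits instead of hanging.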
@@ -3283,10 +3480,12 @@
         }

         # Local deploy without predict args
-        if self.
+        if self._is_cli_deploy:
             self._run_server()
+        elif self._is_quick_deploy:
+            self._run_server_in_thread()

-    def
+    def _parse_cli_deploy_args(self):
         parser = argparse.ArgumentParser(description="Run Inference Serving")

         # Positional args
@@ -3403,6 +3602,7 @@
         return args, True

     def _parse_inference_settings_from_args(self):
+        logger.debug("Parsing inference settings from args")
         def parse_value(value: str):
             if value.lower() in ("true", "false"):
                 return value.lower() == "true"

@@ -3438,6 +3638,7 @@
         args.settings = settings_dict
         args.settings = self._read_settings(args.settings)
         self._validate_settings(args.settings)
+        logger.debug("Inference settings were successfully parsed from args")

     def _get_pretrained_model_params_from_args(self):
         model_files = None
@@ -3473,8 +3674,12 @@
     def _get_custom_model_params_from_args(self):
         def _load_experiment_info(artifacts_dir):
             experiment_path = os.path.join(artifacts_dir, "experiment_info.json")
+            if not os.path.exists(experiment_path):
+                raise ValueError(f"Experiment info file not found in {artifacts_dir}")
             model_info = self._load_json_file(experiment_path)
             model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
+            if not os.path.exists(model_meta_path):
+                raise ValueError(f"Model meta file not found in {artifacts_dir}")
             model_info["model_meta"] = self._load_json_file(model_meta_path)
             original_model_files = model_info.get("model_files")
             return model_info, original_model_files
@@ -3500,6 +3705,7 @@
         else:
             loop.run_until_complete(coro)

+        logger.debug("Getting custom model params from args")
         model_source = ModelSource.CUSTOM
         need_download = False
         checkpoint_path = self._args.model
@@ -3510,6 +3716,8 @@
             raise ValueError(
                 "Team ID not found in env. Required for remote custom checkpoints."
             )
+        if self.api is None:
+            raise ValueError("API is not initialized. Please provide .env file with 'API_TOKEN' and 'SERVER_ADDRESS' environment variables.")
         file_info = self.api.file.get_info_by_path(team_id, checkpoint_path)
         if not file_info:
             raise ValueError(
@@ -3517,26 +3725,46 @@
             )
             need_download = True

+        if not need_download:
+            try:
+                # Read data from checkpoint
+                logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
+                import torch  # pylint: disable=import-error
+                checkpoint = torch.load(checkpoint_path)
+                model_info = checkpoint["model_info"]
+                model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
+                model_meta = os.path.join(self.model_dir, "model_meta.json")
+                model_info["model_meta"] = self._load_json_file(model_meta)
+                model_files["checkpoint"] = checkpoint_path
+                need_download = False
+                return model_files, model_source, model_info, need_download
+            except Exception as e:
+                logger.debug(f"Failed to read data from checkpoint: {repr(e)}")
+
         artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
         if not need_download:
+            logger.debug(f"Looking for data in artifacts: '{artifacts_dir}'")
             model_info, original_model_files = _load_experiment_info(artifacts_dir)
             model_files = _prepare_local_model_files(
                 artifacts_dir, checkpoint_path, original_model_files
             )
-
+            logger.debug(f"Data was found in artifacts directory: '{artifacts_dir}'")
         else:
+            logger.debug(f"Downloading data from remote directory: '{artifacts_dir}'")
             local_artifacts_dir = os.path.join(
-                self.model_dir, "
+                self.model_dir, "cli_deploy", os.path.basename(artifacts_dir)
             )
             _download_remote_files(team_id, artifacts_dir, local_artifacts_dir)
-
             model_info, original_model_files = _load_experiment_info(local_artifacts_dir)
             model_files = _prepare_local_model_files(
                 local_artifacts_dir, checkpoint_path, original_model_files
             )
+            logger.debug(f"Data was downloaded from remote directory: '{artifacts_dir}'")
+        logger.debug("Custom model params were successfully parsed from args")
         return model_files, model_source, model_info, need_download

     def _get_deploy_params_from_args(self):
+        logger.debug("Getting deploy params from args")
         # Ensure model directory exists
         device = self._args.device if self._args.device else "cuda:0"
         runtime = self._args.runtime if self._args.runtime else RuntimeType.PYTORCH
@@ -3563,7 +3791,6 @@
             "device": device,
             "runtime": runtime,
         }
-
         logger.debug(f"Deploy parameters: {deploy_params}")
         return deploy_params, need_download

@@ -3572,6 +3799,18 @@
         self._uvicorn_server = uvicorn.Server(config)
         self._uvicorn_server.run()

+    def _run_server_in_thread(self):
+        """Run Uvicorn server in a separate thread so that this method doesn't block the caller."""
+        import threading
+
+        def _serve():
+            config = uvicorn.Config(app=self._app, host="0.0.0.0", port=8000, ws="websockets")
+            self._uvicorn_server = uvicorn.Server(config)
+            self._uvicorn_server.run()
+
+        self._server_thread = threading.Thread(target=_serve, daemon=True)
+        self._server_thread.start()
+
     def _read_settings(self, settings: Union[str, Dict[str, Any]]):
         if isinstance(settings, dict):
             return settings
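Note: because `_run_server_in_thread` starts Uvicorn on a daemon thread, a quick-deploy script keeps its main thread free for `__call__` predictions, and the server is torn down automatically when the process exits; by contrast, `_run_server` (used for CLI deploy) blocks until shutdown.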
@@ -3598,7 +3837,8 @@
                 f"Inference settings doesn't have key: '{key}'. Available keys are: '{acceptable_keys}'"
             )

-    def
+    def _inference_by_cli_deploy_args(self):
+        logger.debug("Starting inference by CLI deploy args")
         missing_env_message = "Set 'SERVER_ADDRESS' and 'API_TOKEN' environment variables to predict data on Supervisely platform."

         def predict_project_id_by_args(
@@ -3612,14 +3852,13 @@
         ):
            if self.api is None:
                raise ValueError(missing_env_message)
+            if draw:
+                raise ValueError("Draw visualization is not supported for project inference")

            if dataset_ids:
-                logger.info(f"Predicting
+                logger.info(f"Predicting Dataset(s) by ID(s): '{', '.join(str(dataset_id) for dataset_id in dataset_ids)}'")
            else:
-                logger.info(f"Predicting
-
-            if draw:
-                raise ValueError("Draw visualization is not supported for project inference")
+                logger.info(f"Predicting Project by ID: {project_id}")

            state = {
                "projectId": project_id,
@@ -3659,6 +3898,7 @@
            draw: bool = False,
            upload: bool = False,
        ):
+            logger.info(f"Predicting Dataset(s) by ID(s): {', '.join(str(dataset_id) for dataset_id in dataset_ids)}")
            if draw:
                raise ValueError("Draw visualization is not supported for dataset inference")
            if self.api is None:
@@ -3682,7 +3922,7 @@
            if self.api is None:
                raise ValueError(missing_env_message)

-            logger.info(f"Predicting
+            logger.info(f"Predicting Image by ID: {image_id}")

            def predict_image_np(image_np):
                anns, _ = self._inference_auto([image_np], settings)
@@ -3719,8 +3959,7 @@
            output_dir: str = "./predictions",
            draw: bool = False,
        ):
-            logger.info(f"Predicting
-
+            logger.info(f"Predicting Local Data: {input_path}")
            def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
                image_name = sly_fs.get_file_name_with_ext(image_path)
                if pred_dir is not None:
supervisely/nn/training/train_app.py

@@ -1667,26 +1667,25 @@ class TrainApp:
         checkpoint_name = sly_fs.get_file_name_with_ext(checkpoint_path)
         new_checkpoint_path = join(self._output_checkpoints_dir, checkpoint_name)
         shutil.move(checkpoint_path, new_checkpoint_path)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            continue
+        try:
+            # pylint: disable=import-error
+            import torch
+            state_dict = torch.load(new_checkpoint_path)
+            state_dict["model_info"] = {
+                "task_id": self.task_id,
+                "model_name": experiment_info["model_name"],
+                "framework": self.framework_name,
+                "checkpoint": checkpoint_name,
+                "experiment": self.gui.training_process.get_experiment_name(),
+            }
+            state_dict["model_meta"] = model_meta.to_json()
+            state_dict["model_files"] = ckpt_files
+            torch.save(state_dict, new_checkpoint_path)
+        except Exception as e:
+            logger.warning(
+                f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
+            )
+            continue

         new_checkpoint_paths.append(new_checkpoint_path)
         if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
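Note: with `model_info`, `model_meta`, and `model_files` embedded into every exported checkpoint, a checkpoint becomes self-describing — this is exactly what `Inference._get_deploy_parameters_from_custom_checkpoint` reads back on the serving side. A minimal inspection sketch (the checkpoint file name is illustrative):

    import torch

    # Keys written by TrainApp above.
    state_dict = torch.load("best.pth", map_location="cpu")
    print(state_dict["model_info"])   # task_id, model_name, framework, checkpoint, experiment
    print(state_dict["model_meta"])   # project meta (classes and tags) as JSON
    print(state_dict["model_files"])  # auxiliary model files map (e.g. configs)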
supervisely/project/volume_project.py

@@ -478,6 +478,12 @@ def download_volume_project(
                 figure_path = dataset_fs.get_interpolation_path(volume_name, sf)
                 mesh_paths.append(figure_path)

+            figs = api.volume.figure.download(dataset.id, [volume_id], skip_geometry=True)[volume_id]
+            figs_ids_map = {fig.id: fig for fig in figs}
+            for ann_fig in ann.figures + ann.spatial_figures:
+                fig = figs_ids_map.get(ann_fig.geometry.sly_id)
+                ann_fig.custom_data.update(fig.custom_data)
+
             api.volume.figure.download_stl_meshes(mesh_ids, mesh_paths)
             api.volume.figure.download_sf_geometries(mask_ids, mask_paths)
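Note: `download_volume_project` now re-attaches server-side `custom_data` to regular and spatial figures after annotations are parsed, using a second `figure.download` call with `skip_geometry=True` so geometries are not transferred twice. A hedged sketch of the round trip (IDs and paths are placeholders, and the top-level re-export of `download_volume_project` is assumed):

    import supervisely as sly

    api = sly.Api.from_env()
    sly.download_volume_project(api, 123, "./volumes")  # project_id and dest_dir are placeholders
    project = sly.VolumeProject("./volumes", sly.OpenMode.READ)
    # Figures parsed from the downloaded annotations now keep the custom_data
    # that was stored on the platform.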