supervisely 6.73.393__py3-none-any.whl → 6.73.395__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ from typing import List
4
4
  import supervisely.convert.volume.sly.sly_volume_helper as sly_volume_helper
5
5
  from supervisely.volume_annotation.volume_annotation import VolumeAnnotation
6
6
  from supervisely import ProjectMeta, logger
7
+ from supervisely.project.volume_project import VolumeProject, VolumeDataset
7
8
  from supervisely.convert.base_converter import AvailableVolumeConverters
8
9
  from supervisely.convert.volume.volume_converter import VolumeConverter
9
10
  from supervisely.io.fs import JUNK_FILES, get_file_ext, get_file_name
@@ -40,7 +41,7 @@ class SLYVolumeConverter(VolumeConverter):
40
41
  ann = VolumeAnnotation.from_json(ann_json, meta) # , KeyIdMap())
41
42
  return True
42
43
  except Exception as e:
43
- logger.warn(f"Failed to validate annotation: {repr(e)}")
44
+ logger.warning(f"Failed to validate annotation: {repr(e)}")
44
45
  return False
45
46
 
46
47
  def validate_key_file(self, key_file_path: str) -> bool:
@@ -51,6 +52,9 @@ class SLYVolumeConverter(VolumeConverter):
51
52
  return False
52
53
 
53
54
  def validate_format(self) -> bool:
55
+ if self.read_sly_project(self._input_data):
56
+ return True
57
+
54
58
  detected_ann_cnt = 0
55
59
  vol_list, stl_dict, ann_dict, mask_dict = [], {}, {}, {}
56
60
  for root, _, files in os.walk(self._input_data):
@@ -70,7 +74,7 @@ class SLYVolumeConverter(VolumeConverter):
70
74
  ProjectMeta.from_json(meta_json)
71
75
  )
72
76
  except Exception as e:
73
- logger.warn(f"Failed to merge meta: {repr(e)}")
77
+ logger.warning(f"Failed to merge meta: {repr(e)}")
74
78
  continue
75
79
 
76
80
  elif file in JUNK_FILES: # add better check
@@ -139,5 +143,30 @@ class SLYVolumeConverter(VolumeConverter):
139
143
  ann_json = sly_volume_helper.rename_in_json(ann_json, renamed_classes, renamed_tags)
140
144
  return VolumeAnnotation.from_json(ann_json, meta) # , KeyIdMap())
141
145
  except Exception as e:
142
- logger.warn(f"Failed to read annotation: {repr(e)}")
146
+ logger.warning(f"Failed to read annotation: {repr(e)}")
143
147
  return item.create_empty_annotation()
148
+
149
def read_sly_project(self, input_data: str) -> bool:
    """Try to read ``input_data`` as a single Supervisely volume project.

    On success, ``self._meta`` is set to the project meta and ``self._items``
    to one ``self.Item`` per volume found in the project's datasets, and
    ``True`` is returned. On failure the reason is logged and ``False`` is
    returned; in that case ``self._meta`` and ``self._items`` are left
    untouched, so the caller can fall back to another detection strategy
    without inheriting a half-populated state.

    :param input_data: Path to the directory that may contain the project.
    :return: ``True`` if the project was read, ``False`` otherwise.
    """
    try:
        project_fs = VolumeProject.read_single(input_data)
        meta = project_fs.meta
        items = []

        for dataset_fs in project_fs.datasets:
            dataset_fs: VolumeDataset

            for item_name in dataset_fs:
                volume_path, ann_path = dataset_fs.get_item_paths(item_name)
                items.append(
                    self.Item(
                        item_path=volume_path,
                        ann_data=ann_path,
                        shape=None,
                        interpolation_dir=dataset_fs.get_interpolation_dir(item_name),
                        mask_dir=dataset_fs.get_mask_dir(item_name),
                    )
                )
    except Exception as e:
        logger.info(f"Failed to read Supervisely project: {repr(e)}")
        return False

    # Commit only after everything was read successfully: a failure halfway
    # through must not leave partially populated meta/items behind.
    self._meta = meta
    self._items = items
    return True
@@ -93,6 +93,7 @@ from supervisely.project.project_meta import ProjectMeta
93
93
  from supervisely.sly_logger import logger
94
94
  from supervisely.task.progress import Progress
95
95
  from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
96
+ from supervisely.nn.model.model_api import ModelAPI
96
97
 
97
98
  try:
98
99
  from typing import Literal
@@ -150,12 +151,15 @@ class Inference:
150
151
  use_gui: Optional[bool] = False,
151
152
  multithread_inference: Optional[bool] = True,
152
153
  use_serving_gui_template: Optional[bool] = False,
154
+ model: Optional[str] = None,
155
+ device: Optional[str] = None,
156
+ runtime: Optional[str] = None,
153
157
  ):
154
158
 
155
159
  self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
156
- self._args, self._is_local_deploy = self._parse_local_deploy_args()
160
+ self._args, self._is_cli_deploy = self._parse_cli_deploy_args()
157
161
  if model_dir is None:
158
- if self._is_local_deploy is True:
162
+ if self._is_cli_deploy is True:
159
163
  try:
160
164
  model_dir = get_data_dir()
161
165
  except:
@@ -165,8 +169,12 @@ class Inference:
165
169
  sly_fs.mkdir(model_dir)
166
170
 
167
171
  self.autorestart = None
172
+ _deploy_model = model
173
+ _deploy_device = device
174
+ _deploy_runtime = runtime
168
175
  self.device: str = None
169
176
  self.runtime: str = None
177
+ self._is_quick_deploy = False
170
178
  self.model_precision: str = None
171
179
  self.model_source: str = None
172
180
  self.checkpoint_info: CheckpointInfo = None
@@ -208,12 +216,14 @@ class Inference:
208
216
 
209
217
  self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
210
218
 
211
- if self._is_local_deploy:
219
+ if self._is_cli_deploy:
212
220
  self._use_gui = False
213
221
  deploy_params, need_download = self._get_deploy_params_from_args()
214
222
  if need_download:
223
+ logger.debug("Downloading model files")
215
224
  local_model_files = self._download_model_files(deploy_params, False)
216
225
  deploy_params["model_files"] = local_model_files
226
+ logger.debug("Loading model...")
217
227
  self._load_model_headless(**deploy_params)
218
228
 
219
229
  if self._use_gui:
@@ -307,6 +317,97 @@ class Inference:
307
317
  )
308
318
 
309
319
  self.inference_requests_manager = InferenceRequestsManager(executor=self._executor)
320
+ if _deploy_model is not None and not self._model_served:
321
+ self._is_quick_deploy = True
322
+ self.serve()
323
+ self._deploy_headless(_deploy_model, _deploy_device, _deploy_runtime)
324
+
325
def __call__(
    self,
    input: Union[np.ndarray, str, os.PathLike, list] = None,
    image_id: Union[List[int], int] = None,
    video_id: Union[List[int], int] = None,
    dataset_id: Union[List[int], int] = None,
    project_id: Union[List[int], int] = None,
    batch_size: int = None,
    conf: float = None,
    img_size: int = None,
    classes: List[str] = None,
    upload_mode: str = None,
    recursive: bool = None,
    **kwargs,
):
    """Predict with the model served by this instance.

    Thin convenience wrapper: every argument is forwarded unchanged to
    ``ModelAPI.predict`` of the local inference server (the quick-deploy
    server listens on port 8000), and its result is returned as-is.

    :return: Whatever ``ModelAPI.predict`` returns for the given arguments.
    """
    # The server binds to 0.0.0.0, but that is a bind address, not a valid
    # destination on every OS (connecting to it fails on Windows). The loopback
    # address reaches the same server portably. Port must match the one used
    # when starting the server.
    model_api = ModelAPI(api=self._api, url="http://127.0.0.1:8000")
    # Forward by keyword so a reordering of predict()'s parameters cannot
    # silently shift values into the wrong arguments.
    return model_api.predict(
        input=input,
        image_id=image_id,
        video_id=video_id,
        dataset_id=dataset_id,
        project_id=project_id,
        batch_size=batch_size,
        conf=conf,
        img_size=img_size,
        classes=classes,
        upload_mode=upload_mode,
        recursive=recursive,
        **kwargs,
    )
354
+
355
def _deploy_headless(self, model: str, device: str, runtime: Optional[str] = None):
    """Load ``model`` immediately, using the arguments given to the constructor.

    ``model`` is first matched (case-insensitively) against the names of the
    entries in ``self.pretrained_models``; if nothing matches, it is treated
    as a path to a custom checkpoint.

    :param model: Pretrained model name or custom checkpoint path.
    :param device: Device to load the model on; passed through unchanged.
    :param runtime: Runtime alias, case-insensitive: "torch", "onnx", "trt"
        or one of the ``RuntimeType`` values. ``None`` selects PyTorch.
    :raises ValueError: If ``runtime`` is not one of the accepted values.
    :return: ``self``.
    """
    # Start from a clean model directory.
    sly_fs.mkdir(self._model_dir, True)

    # --- resolve runtime alias ---
    if runtime is None:
        resolved_runtime = RuntimeType.PYTORCH
    else:
        runtime_aliases = {
            "torch": RuntimeType.PYTORCH,
            RuntimeType.PYTORCH.lower(): RuntimeType.PYTORCH,
            "onnx": RuntimeType.ONNXRUNTIME,
            RuntimeType.ONNXRUNTIME.lower(): RuntimeType.ONNXRUNTIME,
            "trt": RuntimeType.TENSORRT,
            RuntimeType.TENSORRT.lower(): RuntimeType.TENSORRT,
        }
        resolved_runtime = runtime_aliases.get(runtime.lower())
        if resolved_runtime is None:
            raise ValueError(f"Invalid runtime: {runtime}. Please use one of the following values: {RuntimeType.PYTORCH}, {RuntimeType.ONNXRUNTIME}, {RuntimeType.TENSORRT}.")
    logger.debug(f"Runtime: {resolved_runtime}")

    # --- look the name up among pretrained models ---
    selected_pretrained = None
    if self.pretrained_models is not None:
        for candidate in self.pretrained_models:
            candidate_name = _get_model_name(candidate)
            if candidate_name and candidate_name.lower() == model.lower():
                selected_pretrained = candidate
                break

    if selected_pretrained is None:
        # No pretrained match: treat `model` as a custom checkpoint path.
        checkpoint_name = sly_fs.get_file_name_with_ext(model)
        deploy_params = self._get_deploy_parameters_from_custom_checkpoint(model, device, resolved_runtime)
        logger.debug(f"Deploying custom model '{checkpoint_name}'...")
    else:
        logger.debug("Pretrained model found")
        remote_files = selected_pretrained["meta"]["model_files"]
        local_files = self._download_pretrained_model(remote_files, headless=True)
        deploy_params = {
            "model_source": ModelSource.PRETRAINED,
            "model_files": local_files,
            "model_info": selected_pretrained,
            "device": device,
            "runtime": resolved_runtime,
        }
        logger.debug(f"Deploying pretrained model '{model}' ...")
        logger.debug(f"Deploy parameters: {deploy_params}")

    self._load_model_headless(**deploy_params)
    return self
310
411
 
311
412
  def get_batch_size(self):
312
413
  if self.max_batch_size is not None:
@@ -808,6 +909,101 @@ class Inference:
808
909
  if log_progress:
809
910
  self.gui.download_progress.hide()
810
911
  return local_model_files
912
+
913
def _get_deploy_parameters_from_custom_checkpoint(self, checkpoint_path: str, device: str, runtime: str) -> dict:
    """Build headless-deploy parameters for a custom checkpoint.

    If ``checkpoint_path`` does not exist on the local file system, it is
    treated as a Team Files path of the current team (this requires an
    initialized API). The artifacts directory is taken to be two levels above
    the checkpoint file.

    * Remote checkpoint: ``experiment_info.json`` and the model meta file are
      downloaded into ``self.model_dir``. The model files listed in the
      experiment info are fetched via ``_download_custom_model``.
    * Local checkpoint: the ``model_info`` embedded in the checkpoint is used.
      If it is missing or cannot be read, ``experiment_info.json`` and the
      model meta file from the local artifacts directory are used instead.

    :param checkpoint_path: Local path or Team Files path to the checkpoint.
    :param device: Device to deploy on; passed through unchanged.
    :param runtime: Runtime to deploy with; passed through unchanged.
    :raises ValueError: If the checkpoint is not local and the API is not initialized.
    :raises FileNotFoundError: If the checkpoint or required metadata files are missing.
    :return: Dict with ``model_source``, ``model_files``, ``model_info``,
        ``device`` and ``runtime`` keys.
    """

    def _read_experiment_info(artifacts_dir: str) -> Optional[dict]:
        # Parsed experiment_info.json from artifacts_dir, or None if the file is absent.
        exp_path = os.path.join(artifacts_dir, "experiment_info.json")
        if sly_fs.file_exists(exp_path):
            return self._load_json_file(exp_path)
        return None

    def _compose_model_files(artifacts_dir: str, checkpoint_path: str, files_map: dict) -> dict:
        # Resolve relative model-file paths against artifacts_dir and add the checkpoint itself.
        model_files = {k: os.path.join(artifacts_dir, v) for k, v in files_map.items()}
        model_files["checkpoint"] = checkpoint_path
        return model_files

    is_local = sly_fs.file_exists(checkpoint_path)
    if not is_local:
        # Check the API before resolving the team: both the team lookup and the
        # Team Files probe are pointless without it, and this error is the
        # actionable one for the user.
        if self.api is None:
            raise ValueError(
                f"File: '{checkpoint_path}' not found in local storage. "
                "Initialize API by providing 'API_TOKEN' and 'SERVER_ADDRESS' "
                "environment variables to use remote checkpoint."
            )
        team_id = sly_env.team_id()
        if not self.api.file.exists(team_id, checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint '{checkpoint_path}' not found locally and remotely. "
                "Make sure you have provided correct checkpoint path."
            )

    # Expected layout: <artifacts_dir>/checkpoints/<checkpoint file>
    artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
    if not is_local:
        logger.debug("Remote checkpoint found")
        # --- REMOTE ---
        # experiment_info.json
        logger.debug("Getting experiment_info.json...")
        remote_exp_info = os.path.join(artifacts_dir, "experiment_info.json")
        local_exp_info = os.path.join(self.model_dir, "experiment_info.json")
        self.download(remote_exp_info, local_exp_info)
        experiment_info = self._load_json_file(local_exp_info)

        # model_meta.json (the file name may be overridden by experiment_info)
        logger.debug("Getting model_meta.json...")
        meta_name = experiment_info.get("model_meta") or "model_meta.json"
        remote_meta = os.path.join(artifacts_dir, meta_name)
        local_meta = os.path.join(self.model_dir, meta_name)
        self.download(remote_meta, local_meta)
        experiment_info["model_meta"] = self._load_json_file(local_meta)

        # model files
        logger.debug("Getting model files...")
        remote_files_rel = experiment_info.get("model_files", {})
        remote_files_full = _compose_model_files(artifacts_dir, checkpoint_path, remote_files_rel)
        model_files = self._download_custom_model(remote_files_full, False)
        model_info = experiment_info
        logger.debug("Deploy parameters extracted from remote checkpoint successfully")
    else:
        logger.debug("Local checkpoint found")
        # --- LOCAL ---
        try:
            logger.debug("Reading state dict...")
            import torch  # pylint: disable=import-error

            ckpt = torch.load(checkpoint_path, map_location="cpu")
            model_info = ckpt.get("model_info")
            # Only the metadata is needed here; drop the reference to the
            # loaded checkpoint before it is read again below.
            del ckpt
            if not model_info:
                # A checkpoint without embedded model_info must fall back to
                # experiment_info.json rather than deploy with empty metadata.
                raise KeyError("'model_info' is missing in checkpoint")
            model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
            model_files["checkpoint"] = checkpoint_path
            meta_path = os.path.join(self.model_dir, "model_meta.json")
            if sly_fs.file_exists(meta_path):
                model_info["model_meta"] = self._load_json_file(meta_path)
            logger.debug("Deploy parameters extracted from state dict successfully")
        except Exception as e:
            logger.debug(
                f"Failed to read model metadata from checkpoint '{checkpoint_path}' ({repr(e)}). "
                "Trying to find local files..."
            )
            experiment_info = _read_experiment_info(artifacts_dir)
            if not experiment_info:
                raise FileNotFoundError(f"'experiment_info.json' not found in '{artifacts_dir}'")
            logger.debug("Reading experiment_info.json...")
            model_files = _compose_model_files(
                artifacts_dir, checkpoint_path, experiment_info.get("model_files", {})
            )
            meta_name = experiment_info.get("model_meta") or "model_meta.json"
            meta_path = os.path.join(artifacts_dir, meta_name)
            if not sly_fs.file_exists(meta_path):
                raise FileNotFoundError(f"Model meta file not found: '{meta_path}'")
            experiment_info["model_meta"] = self._load_json_file(meta_path)
            model_info = experiment_info
            logger.debug("Deploy parameters extracted from experiment_info.json successfully")

    deploy_params = {
        "model_source": ModelSource.CUSTOM,
        "model_files": model_files,
        "model_info": model_info,
        "device": device,
        "runtime": runtime,
    }
    logger.debug(f"Deploy parameters: {deploy_params}")
    return deploy_params
811
1007
 
812
1008
  def _extract_model_files_from_checkpoint(self, checkpoint_path: str) -> dict:
813
1009
  extracted_files: dict = {}
@@ -816,7 +1012,6 @@ class Inference:
816
1012
  return extracted_files
817
1013
 
818
1014
  import torch # pylint: disable=import-error
819
-
820
1015
  logger.debug(f"Reading checkpoint: {checkpoint_path}")
821
1016
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
822
1017
 
@@ -1008,11 +1203,12 @@ class Inference:
1008
1203
  model_files = deploy_params.get("model_files", {})
1009
1204
  if model_info:
1010
1205
  checkpoint_name = os.path.basename(model_files.get("checkpoint"))
1011
- checkpoint_file_path = os.path.join(
1012
- model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
1013
- )
1206
+ artifacts_dir = model_info.get("artifacts_dir")
1207
+ if artifacts_dir is None:
1208
+ artifacts_dir = os.path.dirname(os.path.dirname(model_files.get("checkpoint")))
1209
+ checkpoint_file_path = os.path.join(artifacts_dir, "checkpoints", checkpoint_name)
1014
1210
  checkpoint_file_info = None
1015
- if not self._is_local_deploy:
1211
+ if not self._is_cli_deploy:
1016
1212
  checkpoint_file_info = self.api.file.get_info_by_path(
1017
1213
  sly_env.team_id(), checkpoint_file_path
1018
1214
  )
@@ -1115,7 +1311,7 @@ class Inference:
1115
1311
  def api(self) -> Api:
1116
1312
  if self._api is None:
1117
1313
  if (
1118
- self._is_local_deploy
1314
+ self._is_cli_deploy
1119
1315
  and os.getenv("SERVER_ADDRESS") is None
1120
1316
  and os.getenv("API_TOKEN") is None
1121
1317
  ):
@@ -2538,7 +2734,7 @@ class Inference:
2538
2734
  inference_request.done(len(results))
2539
2735
 
2540
2736
  def serve(self):
2541
- if not self._use_gui and not self._is_local_deploy:
2737
+ if not self._use_gui and not self._is_cli_deploy:
2542
2738
  Progress("Deploying model ...", 1)
2543
2739
 
2544
2740
  if is_debug_with_sly_net():
@@ -2553,30 +2749,31 @@ class Inference:
2553
2749
  self._task_id = task["id"]
2554
2750
  os.environ["TASK_ID"] = str(self._task_id)
2555
2751
  else:
2556
- if not self._is_local_deploy:
2752
+ if not self._is_cli_deploy:
2557
2753
  self._task_id = sly_env.task_id() if is_production() else None
2558
2754
 
2559
- if self._is_local_deploy:
2755
+ if self._is_cli_deploy:
2560
2756
  # Predict and shutdown
2561
- if self._args.mode == "predict" and any(
2757
+ if self._args.mode == "predict":
2758
+ if any(
2562
2759
  [
2563
2760
  self._args.input,
2564
2761
  self._args.project_id,
2565
2762
  self._args.dataset_id,
2566
2763
  self._args.image_id,
2567
2764
  ]
2568
- ):
2569
-
2570
- self._parse_inference_settings_from_args()
2571
- self._inference_by_local_deploy_args()
2572
- exit(0)
2765
+ ):
2766
+ self._parse_inference_settings_from_args()
2767
+ self._inference_by_cli_deploy_args()
2768
+ exit(0)
2769
+ else:
2770
+ logger.error("Predict mode requires one of the following arguments: --input, --project_id, --dataset_id, --image_id")
2771
+ exit(0)
2573
2772
 
2574
2773
  if isinstance(self.gui, GUI.InferenceGUI):
2575
2774
  self._app = Application(layout=self.get_ui())
2576
2775
  elif isinstance(self.gui, GUI.ServingGUI):
2577
2776
  self._app = Application(layout=self._app_layout)
2578
- # elif isinstance(self.gui, GUI.InferenceGUI):
2579
- # self._app = Application(layout=self.get_ui())
2580
2777
  else:
2581
2778
  self._app = Application(layout=self.get_ui())
2582
2779
 
@@ -3283,10 +3480,12 @@ class Inference:
3283
3480
  }
3284
3481
 
3285
3482
  # Local deploy without predict args
3286
- if self._is_local_deploy:
3483
+ if self._is_cli_deploy:
3287
3484
  self._run_server()
3485
+ elif self._is_quick_deploy:
3486
+ self._run_server_in_thread()
3288
3487
 
3289
- def _parse_local_deploy_args(self):
3488
+ def _parse_cli_deploy_args(self):
3290
3489
  parser = argparse.ArgumentParser(description="Run Inference Serving")
3291
3490
 
3292
3491
  # Positional args
@@ -3403,6 +3602,7 @@ class Inference:
3403
3602
  return args, True
3404
3603
 
3405
3604
  def _parse_inference_settings_from_args(self):
3605
+ logger.debug("Parsing inference settings from args")
3406
3606
  def parse_value(value: str):
3407
3607
  if value.lower() in ("true", "false"):
3408
3608
  return value.lower() == "true"
@@ -3438,6 +3638,7 @@ class Inference:
3438
3638
  args.settings = settings_dict
3439
3639
  args.settings = self._read_settings(args.settings)
3440
3640
  self._validate_settings(args.settings)
3641
+ logger.debug("Inference settings were successfully parsed from args")
3441
3642
 
3442
3643
  def _get_pretrained_model_params_from_args(self):
3443
3644
  model_files = None
@@ -3473,8 +3674,12 @@ class Inference:
3473
3674
  def _get_custom_model_params_from_args(self):
3474
3675
  def _load_experiment_info(artifacts_dir):
3475
3676
  experiment_path = os.path.join(artifacts_dir, "experiment_info.json")
3677
+ if not os.path.exists(experiment_path):
3678
+ raise ValueError(f"Experiment info file not found in {artifacts_dir}")
3476
3679
  model_info = self._load_json_file(experiment_path)
3477
3680
  model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
3681
+ if not os.path.exists(model_meta_path):
3682
+ raise ValueError(f"Model meta file not found in {artifacts_dir}")
3478
3683
  model_info["model_meta"] = self._load_json_file(model_meta_path)
3479
3684
  original_model_files = model_info.get("model_files")
3480
3685
  return model_info, original_model_files
@@ -3500,6 +3705,7 @@ class Inference:
3500
3705
  else:
3501
3706
  loop.run_until_complete(coro)
3502
3707
 
3708
+ logger.debug("Getting custom model params from args")
3503
3709
  model_source = ModelSource.CUSTOM
3504
3710
  need_download = False
3505
3711
  checkpoint_path = self._args.model
@@ -3510,6 +3716,8 @@ class Inference:
3510
3716
  raise ValueError(
3511
3717
  "Team ID not found in env. Required for remote custom checkpoints."
3512
3718
  )
3719
+ if self.api is None:
3720
+ raise ValueError("API is not initialized. Please provide .env file with 'API_TOKEN' and 'SERVER_ADDRESS' environment variables.")
3513
3721
  file_info = self.api.file.get_info_by_path(team_id, checkpoint_path)
3514
3722
  if not file_info:
3515
3723
  raise ValueError(
@@ -3517,26 +3725,46 @@ class Inference:
3517
3725
  )
3518
3726
  need_download = True
3519
3727
 
3728
+ if not need_download:
3729
+ try:
3730
+ # Read data from checkpoint
3731
+ logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
3732
+ import torch # pylint: disable=import-error
3733
+ checkpoint = torch.load(checkpoint_path)
3734
+ model_info = checkpoint["model_info"]
3735
+ model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
3736
+ model_meta = os.path.join(self.model_dir, "model_meta.json")
3737
+ model_info["model_meta"] = self._load_json_file(model_meta)
3738
+ model_files["checkpoint"] = checkpoint_path
3739
+ need_download = False
3740
+ return model_files, model_source, model_info, need_download
3741
+ except Exception as e:
3742
+ logger.debug(f"Failed to read data from checkpoint: {repr(e)}")
3743
+
3520
3744
  artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
3521
3745
  if not need_download:
3746
+ logger.debug(f"Looking for data in artifacts: '{artifacts_dir}'")
3522
3747
  model_info, original_model_files = _load_experiment_info(artifacts_dir)
3523
3748
  model_files = _prepare_local_model_files(
3524
3749
  artifacts_dir, checkpoint_path, original_model_files
3525
3750
  )
3526
-
3751
+ logger.debug(f"Data was found in artifacts directory: '{artifacts_dir}'")
3527
3752
  else:
3753
+ logger.debug(f"Downloading data from remote directory: '{artifacts_dir}'")
3528
3754
  local_artifacts_dir = os.path.join(
3529
- self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
3755
+ self.model_dir, "cli_deploy", os.path.basename(artifacts_dir)
3530
3756
  )
3531
3757
  _download_remote_files(team_id, artifacts_dir, local_artifacts_dir)
3532
-
3533
3758
  model_info, original_model_files = _load_experiment_info(local_artifacts_dir)
3534
3759
  model_files = _prepare_local_model_files(
3535
3760
  local_artifacts_dir, checkpoint_path, original_model_files
3536
3761
  )
3762
+ logger.debug(f"Data was downloaded from remote directory: '{artifacts_dir}'")
3763
+ logger.debug("Custom model params were successfully parsed from args")
3537
3764
  return model_files, model_source, model_info, need_download
3538
3765
 
3539
3766
  def _get_deploy_params_from_args(self):
3767
+ logger.debug("Getting deploy params from args")
3540
3768
  # Ensure model directory exists
3541
3769
  device = self._args.device if self._args.device else "cuda:0"
3542
3770
  runtime = self._args.runtime if self._args.runtime else RuntimeType.PYTORCH
@@ -3563,7 +3791,6 @@ class Inference:
3563
3791
  "device": device,
3564
3792
  "runtime": runtime,
3565
3793
  }
3566
-
3567
3794
  logger.debug(f"Deploy parameters: {deploy_params}")
3568
3795
  return deploy_params, need_download
3569
3796
 
@@ -3572,6 +3799,18 @@ class Inference:
3572
3799
  self._uvicorn_server = uvicorn.Server(config)
3573
3800
  self._uvicorn_server.run()
3574
3801
 
3802
def _run_server_in_thread(self):
    """Run the Uvicorn server on a background daemon thread and return immediately.

    The server object is created in the calling thread, so
    ``self._uvicorn_server`` is already set when this method returns. Only
    ``Server.run`` executes on the worker thread, which is stored in
    ``self._server_thread``.
    """
    import threading

    # Build config/server here rather than inside the thread. Otherwise
    # ``self._uvicorn_server`` would be assigned at some unknown later time
    # (a race with the caller), and a configuration error would die silently
    # inside the daemon thread instead of surfacing to the caller.
    config = uvicorn.Config(app=self._app, host="0.0.0.0", port=8000, ws="websockets")
    self._uvicorn_server = uvicorn.Server(config)

    self._server_thread = threading.Thread(target=self._uvicorn_server.run, daemon=True)
    self._server_thread.start()
3813
+
3575
3814
  def _read_settings(self, settings: Union[str, Dict[str, Any]]):
3576
3815
  if isinstance(settings, dict):
3577
3816
  return settings
@@ -3598,7 +3837,8 @@ class Inference:
3598
3837
  f"Inference settings doesn't have key: '{key}'. Available keys are: '{acceptable_keys}'"
3599
3838
  )
3600
3839
 
3601
- def _inference_by_local_deploy_args(self):
3840
+ def _inference_by_cli_deploy_args(self):
3841
+ logger.debug("Starting inference by CLI deploy args")
3602
3842
  missing_env_message = "Set 'SERVER_ADDRESS' and 'API_TOKEN' environment variables to predict data on Supervisely platform."
3603
3843
 
3604
3844
  def predict_project_id_by_args(
@@ -3612,14 +3852,13 @@ class Inference:
3612
3852
  ):
3613
3853
  if self.api is None:
3614
3854
  raise ValueError(missing_env_message)
3855
+ if draw:
3856
+ raise ValueError("Draw visualization is not supported for project inference")
3615
3857
 
3616
3858
  if dataset_ids:
3617
- logger.info(f"Predicting datasets: '{dataset_ids}'")
3859
+ logger.info(f"Predicting Dataset(s) by ID(s): '{', '.join(str(dataset_id) for dataset_id in dataset_ids)}'")
3618
3860
  else:
3619
- logger.info(f"Predicting project: '{project_id}'")
3620
-
3621
- if draw:
3622
- raise ValueError("Draw visualization is not supported for project inference")
3861
+ logger.info(f"Predicting Project by ID: {project_id}")
3623
3862
 
3624
3863
  state = {
3625
3864
  "projectId": project_id,
@@ -3659,6 +3898,7 @@ class Inference:
3659
3898
  draw: bool = False,
3660
3899
  upload: bool = False,
3661
3900
  ):
3901
+ logger.info(f"Predicting Dataset(s) by ID(s): {', '.join(str(dataset_id) for dataset_id in dataset_ids)}")
3662
3902
  if draw:
3663
3903
  raise ValueError("Draw visualization is not supported for dataset inference")
3664
3904
  if self.api is None:
@@ -3682,7 +3922,7 @@ class Inference:
3682
3922
  if self.api is None:
3683
3923
  raise ValueError(missing_env_message)
3684
3924
 
3685
- logger.info(f"Predicting image: '{image_id}'")
3925
+ logger.info(f"Predicting Image by ID: {image_id}")
3686
3926
 
3687
3927
  def predict_image_np(image_np):
3688
3928
  anns, _ = self._inference_auto([image_np], settings)
@@ -3719,8 +3959,7 @@ class Inference:
3719
3959
  output_dir: str = "./predictions",
3720
3960
  draw: bool = False,
3721
3961
  ):
3722
- logger.info(f"Predicting '{input_path}'")
3723
-
3962
+ logger.info(f"Predicting Local Data: {input_path}")
3724
3963
  def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
3725
3964
  image_name = sly_fs.get_file_name_with_ext(image_path)
3726
3965
  if pred_dir is not None:
@@ -1667,26 +1667,25 @@ class TrainApp:
1667
1667
  checkpoint_name = sly_fs.get_file_name_with_ext(checkpoint_path)
1668
1668
  new_checkpoint_path = join(self._output_checkpoints_dir, checkpoint_name)
1669
1669
  shutil.move(checkpoint_path, new_checkpoint_path)
1670
- if len(ckpt_files) > 0:
1671
- try:
1672
- # pylint: disable=import-error
1673
- import torch
1674
-
1675
- state_dict = torch.load(new_checkpoint_path)
1676
- state_dict["model_info"] = {
1677
- "model_name": experiment_info["model_name"],
1678
- "framework": self.framework_name,
1679
- "checkpoint": checkpoint_name,
1680
- "experiment": self.gui.training_process.get_experiment_name(),
1681
- }
1682
- state_dict["model_meta"] = model_meta.to_json()
1683
- state_dict["model_files"] = ckpt_files
1684
- torch.save(state_dict, new_checkpoint_path)
1685
- except Exception as e:
1686
- logger.warning(
1687
- f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
1688
- )
1689
- continue
1670
+ try:
1671
+ # pylint: disable=import-error
1672
+ import torch
1673
+ state_dict = torch.load(new_checkpoint_path)
1674
+ state_dict["model_info"] = {
1675
+ "task_id": self.task_id,
1676
+ "model_name": experiment_info["model_name"],
1677
+ "framework": self.framework_name,
1678
+ "checkpoint": checkpoint_name,
1679
+ "experiment": self.gui.training_process.get_experiment_name(),
1680
+ }
1681
+ state_dict["model_meta"] = model_meta.to_json()
1682
+ state_dict["model_files"] = ckpt_files
1683
+ torch.save(state_dict, new_checkpoint_path)
1684
+ except Exception as e:
1685
+ logger.warning(
1686
+ f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
1687
+ )
1688
+ continue
1690
1689
 
1691
1690
  new_checkpoint_paths.append(new_checkpoint_path)
1692
1691
  if sly_fs.get_file_name_with_ext(checkpoint_path) == best_checkpoints_name:
@@ -478,6 +478,12 @@ def download_volume_project(
478
478
  figure_path = dataset_fs.get_interpolation_path(volume_name, sf)
479
479
  mesh_paths.append(figure_path)
480
480
 
481
+ figs = api.volume.figure.download(dataset.id, [volume_id], skip_geometry=True)[volume_id]
482
+ figs_ids_map = {fig.id: fig for fig in figs}
483
+ for ann_fig in ann.figures + ann.spatial_figures:
484
+ fig = figs_ids_map.get(ann_fig.geometry.sly_id)
485
+ ann_fig.custom_data.update(fig.custom_data)
486
+
481
487
  api.volume.figure.download_stl_meshes(mesh_ids, mesh_paths)
482
488
  api.volume.figure.download_sf_geometries(mask_ids, mask_paths)
483
489