supervisely 6.73.268-py3-none-any.whl → 6.73.270-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely has been flagged as possibly problematic.

@@ -1,3 +1,5 @@
+import argparse
+import asyncio
 import inspect
 import json
 import os
@@ -11,12 +13,14 @@ from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
 from functools import partial, wraps
+from pathlib import Path
 from queue import Queue
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from urllib.request import urlopen
 
 import numpy as np
 import requests
+import uvicorn
 import yaml
 from fastapi import Form, HTTPException, Request, Response, UploadFile, status
 from fastapi.responses import JSONResponse
@@ -24,8 +28,7 @@ from requests.structures import CaseInsensitiveDict
 
 import supervisely.app.development as sly_app_development
 import supervisely.imaging.image as sly_image
-import supervisely.io.env as env
-import supervisely.io.fs as fs
+import supervisely.io.env as sly_env
 import supervisely.io.fs as sly_fs
 import supervisely.io.json as sly_json
 import supervisely.nn.inference.gui as GUI
@@ -33,6 +36,7 @@ from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
 from supervisely._utils import (
     add_callback,
     get_filename_from_headers,
+    get_or_create_event_loop,
     is_debug_with_sly_net,
     is_production,
     rand_str,
@@ -43,6 +47,7 @@ from supervisely.annotation.obj_class import ObjClass
 from supervisely.annotation.tag_collection import TagCollection
 from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api
+from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
 from supervisely.app.content import StateJson, get_data_dir
 from supervisely.app.exceptions import DialogWindowError
@@ -106,9 +111,19 @@ class Inference:
         multithread_inference: Optional[bool] = True,
         use_serving_gui_template: Optional[bool] = False,
     ):
+
+        self._args, self._is_local_deploy = self._parse_local_deploy_args()
+
         if model_dir is None:
-            model_dir = os.path.join(get_data_dir(), "models")
-            fs.mkdir(model_dir)
+            if self._is_local_deploy is True:
+                try:
+                    model_dir = get_data_dir()
+                except:
+                    model_dir = Path("~/.cache/supervisely/app_data").expanduser()
+            else:
+                model_dir = os.path.join(get_data_dir(), "models")
+            sly_fs.mkdir(model_dir)
+
         self.device: str = None
         self.runtime: str = None
         self.model_precision: str = None
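
The fallback above lets the constructor run outside a Supervisely task, where `get_data_dir()` can raise because no app data directory is configured. A minimal standalone sketch of the same pattern; `resolve_model_dir` is an illustrative helper, not part of the package:

```python
# Illustrative sketch of the model_dir fallback; get_data_dir and the cache
# path come from the diff above, resolve_model_dir is hypothetical.
import os
from pathlib import Path

from supervisely.app.content import get_data_dir


def resolve_model_dir(is_local_deploy: bool) -> str:
    if is_local_deploy:
        try:
            return get_data_dir()  # raises when no app context is configured
        except Exception:
            return str(Path("~/.cache/supervisely/app_data").expanduser())
    return os.path.join(get_data_dir(), "models")
```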
@@ -128,7 +143,7 @@ class Inference:
         self._autostart_delay_time = 5 * 60  # 5 min
         self._tracker = None
         self._hardware: str = None
-        self.pretrained_models = self._load_models_json(self.MODELS) if self.MODELS else None
+        self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:
                 custom_inference_settings = self.INFERENCE_SETTINGS
@@ -136,7 +151,7 @@ class Inference:
                 logger.debug("Custom inference settings are not provided.")
                 custom_inference_settings = {}
         if isinstance(custom_inference_settings, str):
-            if fs.file_exists(custom_inference_settings):
+            if sly_fs.file_exists(custom_inference_settings):
                 with open(custom_inference_settings, "r") as f:
                     custom_inference_settings = f.read()
             else:
@@ -146,12 +161,24 @@ class Inference:
         self._use_gui = use_gui
         self._use_serving_gui_template = use_serving_gui_template
         self._gui = None
+        self._uvicorn_server = None
 
         self.load_on_device = LOAD_ON_DEVICE_DECORATOR(self.load_on_device)
         self.load_on_device = add_callback(self.load_on_device, self._set_served_callback)
 
         self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
 
+        if self._is_local_deploy:
+            # self._args = self._parse_local_deploy_args()
+            self._use_gui = False
+            deploy_params, need_download = self._get_deploy_params_from_args()
+            if need_download:
+                local_model_files = self._download_model_files(
+                    deploy_params["model_source"], deploy_params["model_files"], False
+                )
+                deploy_params["model_files"] = local_model_files
+            self._load_model_headless(**deploy_params)
+
         if self._use_gui:
             initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
             original_initialize_custom_gui_method = getattr(
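
With the block above in place, constructing an Inference subclass with deploy arguments loads the model headless and skips the GUI entirely. A hypothetical command-line invocation; `serve.py` and the checkpoint path are placeholders, while the flags are the ones registered by `_parse_local_deploy_args` later in this diff:

```python
# Hypothetical local-deploy invocation; "serve.py" stands in for any script
# that instantiates an Inference subclass and calls its serve() method.
import subprocess

subprocess.run(
    [
        "python",
        "serve.py",
        "--model", "/path/to/artifacts/checkpoints/best.pt",  # or a pretrained model name
        "--device", "cuda:0",
        "--predict-image", "./image.jpg",  # optional: predict once, then shut down
    ],
    check=True,
)
```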
@@ -224,10 +251,10 @@ class Inference:
         self.get_info = self._check_serve_before_call(self.get_info)
 
         self.cache = InferenceImageCache(
-            maxsize=env.smart_cache_size(),
-            ttl=env.smart_cache_ttl(),
+            maxsize=sly_env.smart_cache_size(),
+            ttl=sly_env.smart_cache_ttl(),
             is_persistent=True,
-            base_folder=env.smart_cache_container_dir(),
+            base_folder=sly_env.smart_cache_container_dir(),
             log_progress=True,
         )
 
@@ -248,9 +275,18 @@ class Inference:
             )
             device = "cpu"
 
-    def _load_models_json(self, models: str) -> List[Dict[str, Any]]:
+    def _load_json_file(self, file_path: str) -> dict:
+        if isinstance(file_path, str):
+            if sly_fs.file_exists(file_path) and sly_fs.get_file_ext(file_path) == ".json":
+                return sly_json.load_json_file(file_path)
+            else:
+                raise ValueError("File not found or invalid file format.")
+        else:
+            raise ValueError("Invalid file. Please provide a valid '.json' file.")
+
+    def _load_models_json_file(self, models: str) -> List[Dict[str, Any]]:
         """
-        Loads models from the provided file or list of model configurations.
+        Loads dictionary from the provided file.
         """
         if isinstance(models, str):
             if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
@@ -258,9 +294,7 @@ class Inference:
             else:
                 raise ValueError("File not found or invalid file format.")
         else:
-            raise ValueError(
-                "Invalid models file. Please provide a valid '.json' file with list of model configurations."
-            )
+            raise ValueError("Invalid file. Please provide a valid '.json' file.")
 
         if not isinstance(models, list):
             raise ValueError("models parameters must be a list of dicts")
@@ -394,13 +428,13 @@ class Inference:
         else:
             progress = None
 
-        if fs.dir_exists(src_path) or fs.file_exists(
+        if sly_fs.dir_exists(src_path) or sly_fs.file_exists(
             src_path
         ):  # only during debug, has no effect in production
             dst_path = os.path.abspath(src_path)
             logger.info(f"File {dst_path} found.")
         elif src_path.startswith("/"):  # folder from Team Files
-            team_id = env.team_id()
+            team_id = sly_env.team_id()
 
             if src_path.endswith("/") and self.api.file.dir_exists(team_id, src_path):
 
@@ -436,7 +470,7 @@ class Inference:
                 def download_file(team_id, src_path, dst_path, progress_cb=None):
                     self.api.file.download(team_id, src_path, dst_path, progress_cb=progress_cb)
 
-                file_info = self.api.file.get_info_by_path(env.team_id(), src_path)
+                file_info = self.api.file.get_info_by_path(sly_env.team_id(), src_path)
                 if progress is None:
                     download_file(team_id, src_path, dst_path)
                 else:
@@ -451,8 +485,8 @@ class Inference:
                 logger.info(f"📥 File {basename} has been successfully downloaded from Team Files")
                 logger.info(f"File {basename} path: {dst_path}")
         else:  # external url
-            if not fs.dir_exists(os.path.dirname(dst_path)):
-                fs.mkdir(os.path.dirname(dst_path))
+            if not sly_fs.dir_exists(os.path.dirname(dst_path)):
+                sly_fs.mkdir(os.path.dirname(dst_path))
 
             def download_external_file(url, save_path, progress=None):
                 def download_content(save_path, progress_cb=None):
@@ -531,13 +565,15 @@ class Inference:
     def _checkpoints_cache_dir(self):
         return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
 
-    def _download_model_files(self, model_source: str, model_files: List[str]) -> dict:
+    def _download_model_files(
+        self, model_source: str, model_files: List[str], log_progress: bool = True
+    ) -> dict:
         if model_source == ModelSource.PRETRAINED:
-            return self._download_pretrained_model(model_files)
+            return self._download_pretrained_model(model_files, log_progress)
         elif model_source == ModelSource.CUSTOM:
-            return self._download_custom_model(model_files)
+            return self._download_custom_model(model_files, log_progress)
 
-    def _download_pretrained_model(self, model_files: dict):
+    def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the pretrained model data.
         """
@@ -564,54 +600,61 @@ class Inference:
                 logger.debug(f"Model: '{file_name}' was found in model dir")
                 continue
 
-            with self.gui.download_progress(
-                message=f"Downloading: '{file_name}'",
-                total=file_size,
-                unit="bytes",
-                unit_scale=True,
-            ) as download_pbar:
-                self.gui.download_progress.show()
-                sly_fs.download(
-                    url=file_url,
-                    save_path=file_path,
-                    progress=download_pbar.update,
-                )
+            if log_progress:
+                with self.gui.download_progress(
+                    message=f"Downloading: '{file_name}'",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as download_pbar:
+                    self.gui.download_progress.show()
+                    sly_fs.download(
+                        url=file_url, save_path=file_path, progress=download_pbar.update
+                    )
+            else:
+                sly_fs.download(url=file_url, save_path=file_path)
             local_model_files[file] = file_path
         else:
             local_model_files[file] = file_url
-        self.gui.download_progress.hide()
+
+        if log_progress:
+            self.gui.download_progress.hide()
         return local_model_files
 
-    def _download_custom_model(self, model_files: dict):
+    def _download_custom_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the custom model data.
         """
-
-        team_id = env.team_id()
+        team_id = sly_env.team_id()
         local_model_files = {}
-
         for file in model_files:
             file_url = model_files[file]
             file_info = self.api.file.get_info_by_path(team_id, file_url)
             if file_info is None:
-                raise FileNotFoundError(
-                    f"File '{file_url}' not found in Team Files. Make sure the file exists."
-                )
+                if sly_fs.file_exists(file_url):
+                    local_model_files[file] = file_url
+                    continue
+                else:
+                    raise FileNotFoundError(f"File '{file_url}' not found in Team Files")
             file_size = file_info.sizeb
             file_name = os.path.basename(file_url)
             file_path = os.path.join(self.model_dir, file_name)
-            with self.gui.download_progress(
-                message=f"Downloading: '{file_name}'",
-                total=file_size,
-                unit="bytes",
-                unit_scale=True,
-            ) as download_pbar:
-                self.gui.download_progress.show()
-                self.api.file.download(
-                    team_id, file_url, file_path, progress_cb=download_pbar.update
-                )
+            if log_progress:
+                with self.gui.download_progress(
+                    message=f"Downloading: '{file_name}'",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as download_pbar:
+                    self.gui.download_progress.show()
+                    self.api.file.download(
+                        team_id, file_url, file_path, progress_cb=download_pbar.update
+                    )
+            else:
+                self.api.file.download(team_id, file_url, file_path)
             local_model_files[file] = file_path
-        self.gui.download_progress.hide()
+        if log_progress:
+            self.gui.download_progress.hide()
        return local_model_files
 
    def _load_model(self, deploy_params: dict):
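
Both download helpers now accept a `log_progress` flag so the headless local-deploy path can skip the GUI progress widget. A condensed sketch of the pattern, with `tqdm` standing in for the GUI progress bar; the `fetch` helper is illustrative, only `sly_fs.download` and its `progress` callback come from the diff:

```python
# Sketch of the optional-progress download pattern; tqdm is an assumed stand-in
# for the GUI widget used by the real code.
import supervisely.io.fs as sly_fs
from tqdm import tqdm


def fetch(url: str, save_path: str, log_progress: bool = True) -> None:
    if log_progress:
        with tqdm(desc=f"Downloading: '{save_path}'", unit="B", unit_scale=True) as pbar:
            sly_fs.download(url=url, save_path=save_path, progress=pbar.update)
    else:
        sly_fs.download(url=url, save_path=save_path)
```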
@@ -647,6 +690,13 @@ class Inference:
         if model_source == ModelSource.CUSTOM:
             self._set_model_meta_custom_model(model_info)
             self._set_checkpoint_info_custom_model(deploy_params)
+
+        try:
+            if is_production():
+                self._add_workflow_input(model_source, model_files, model_info)
+        except Exception as e:
+            logger.warning(f"Failed to add input to the workflow: {repr(e)}")
+
         self._load_model(deploy_params)
         if self._model_meta is None:
             self._set_model_meta_from_classes()
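
The guarded block above registers the served checkpoint as a workflow input only on production instances and downgrades any failure to a warning instead of blocking deployment. A standalone sketch using the same API calls as `_add_workflow_input` (defined later in this diff); the Team Files path is a placeholder, and `Api()` assumes `SERVER_ADDRESS` and `API_TOKEN` are set in the environment:

```python
# Standalone sketch of the workflow-input registration; the checkpoint path
# below is a placeholder, the calls are the ones used by _add_workflow_input.
import supervisely.io.env as sly_env
from supervisely.api.api import Api
from supervisely.api.app_api import WorkflowMeta, WorkflowSettings

api = Api()
checkpoint_url = "/experiments/42/checkpoints/best.pt"  # placeholder path
meta = WorkflowMeta(node_settings=WorkflowSettings(title="Serve my-app"))
if api.file.exists(sly_env.team_id(), checkpoint_url):
    api.app.workflow.add_input_file(checkpoint_url, model_weight=True, meta=meta)
```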
@@ -679,7 +729,7 @@ class Inference:
                 model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
             )
             checkpoint_file_info = self.api.file.get_info_by_path(
-                env.team_id(), checkpoint_file_path
+                sly_env.team_id(), checkpoint_file_path
             )
             if checkpoint_file_info is None:
                 checkpoint_url = None
@@ -1253,18 +1303,18 @@ class Inference:
         logger.debug("Inferring image_url...", extra={"state": state})
         settings = self._get_inference_settings(state)
         image_url = state["image_url"]
-        ext = fs.get_file_ext(image_url)
+        ext = sly_fs.get_file_ext(image_url)
         if ext == "":
             ext = ".jpg"
         image_path = os.path.join(get_data_dir(), rand_str(15) + ext)
-        fs.download(image_url, image_path)
+        sly_fs.download(image_url, image_path)
         logger.debug("Inference settings:", extra=settings)
         logger.debug(f"Downloaded path: {image_path}")
         anns, slides_data = self._inference_auto(
             [image_path],
             settings=settings,
         )
-        fs.silent_remove(image_path)
+        sly_fs.silent_remove(image_path)
         return self._format_output(anns, slides_data)[0]
 
     def _inference_video_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
@@ -2208,28 +2258,46 @@ class Inference:
 
         if is_debug_with_sly_net():
             # advanced debug for Supervisely Team
-            logger.warn(
+            logger.warning(
                 "Serving is running in advanced development mode with Supervisely VPN Network"
             )
-            team_id = env.team_id()
+            team_id = sly_env.team_id()
             # sly_app_development.supervisely_vpn_network(action="down") # for debug
             sly_app_development.supervisely_vpn_network(action="up")
             task = sly_app_development.create_debug_task(team_id, port="8000")
             self._task_id = task["id"]
             os.environ["TASK_ID"] = str(self._task_id)
         else:
-            self._task_id = env.task_id() if is_production() else None
+            if not self._is_local_deploy:
+                self._task_id = sly_env.task_id() if is_production() else None
 
-        if isinstance(self.gui, GUI.InferenceGUI):
+        if isinstance(self.gui, GUI.ServingGUITemplate):
             self._app = Application(layout=self.get_ui())
         elif isinstance(self.gui, GUI.ServingGUI):
             self._app = Application(layout=self._app_layout)
+        # elif isinstance(self.gui, GUI.InferenceGUI):
+        #     self._app = Application(layout=self.get_ui())
         else:
             self._app = Application(layout=self.get_ui())
 
         server = self._app.get_server()
         self._app.set_ready_check_function(self.is_model_deployed)
 
+        if self._is_local_deploy:
+            # Predict and shutdown
+            if any(
+                [
+                    self._args.predict_project,
+                    self._args.predict_dataset,
+                    self._args.predict_dir,
+                    self._args.predict_image,
+                ]
+            ):
+                self._inference_by_local_deploy_args()
+                # Gracefully shut down the server
+                self._app.shutdown()
+            # else: run server after endpoints
+
         @call_on_autostart()
         def autostart_func():
             gpu_count = get_gpu_count()
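
The gate above fires when any of the four `--predict-*` flags was supplied at startup. A runnable demo of the `parse_known_args` pattern those flags rely on, including the comma-separated dataset syntax from `_parse_local_deploy_args` (defined in the next hunk):

```python
# Runnable demo of the parse_known_args pattern used by _parse_local_deploy_args:
# unknown flags are left alone, and "1,2,3" becomes a list of ints.
import argparse

parser = argparse.ArgumentParser(description="Run Inference Serving")
parser.add_argument("--model", type=str)
parser.add_argument(
    "--predict-dataset",
    type=lambda x: [int(i) for i in x.split(",")] if "," in x else int(x),
)
args, unknown = parser.parse_known_args(
    ["--model", "my-model", "--predict-dataset", "1,2,3", "--some-app-flag"]
)
print(args.model, args.predict_dataset, unknown)
# my-model [1, 2, 3] ['--some-app-flag']
```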
@@ -2725,6 +2793,285 @@ class Inference:
         def _get_deploy_info():
             return asdict(self._get_deploy_info())
 
+        # Local deploy without predict args
+        if self._is_local_deploy:
+            self._run_server()
+
+    def _parse_local_deploy_args(self):
+        parser = argparse.ArgumentParser(description="Run Inference Serving")
+
+        # Deploy args
+        parser.add_argument(
+            "--model",
+            type=str,
+            help="Name of the pretrained model or path to custom checkpoint file",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
+            default="cuda:0",
+            help="Device to use for inference (default: 'cuda:0')",
+        )
+        parser.add_argument(
+            "--runtime",
+            type=str,
+            choices=[RuntimeType.PYTORCH, RuntimeType.ONNXRUNTIME, RuntimeType.TENSORRT],
+            default=RuntimeType.PYTORCH,
+            help="Runtime type for inference (default: PYTORCH)",
+        )
+        # -------------------------- #
+
+        # Predict args
+        parser.add_argument("--predict-project", type=int, required=False, help="ID of the project")
+        parser.add_argument(
+            "--predict-dataset",
+            type=lambda x: [int(i) for i in x.split(",")] if "," in x else int(x),
+            required=False,
+            help="ID of the dataset or a comma-separated list of dataset IDs",
+        )
+        parser.add_argument(
+            "--predict-dir",
+            type=str,
+            required=False,
+            help="Not implemented yet. Path to the local directory with images",
+        )
+        parser.add_argument(
+            "--predict-image",
+            type=str,
+            required=False,
+            help="Image ID on Supervisely instance or path to local image",
+        )
+        # -------------------------- #
+
+        # Output args
+        parser.add_argument("--output", type=str, required=False, help="Not implemented yet")
+        parser.add_argument("--output-dir", type=str, required=False, help="Not implemented yet")
+        # -------------------------- #
+
+        # Parse arguments
+        args, _ = parser.parse_known_args()
+        if args.model is None:
+            # raise ValueError("Argument '--model' is required for local deployment")
+            return None, False
+        if isinstance(args.predict_dataset, int):
+            args.predict_dataset = [args.predict_dataset]
+        if args.predict_image is not None:
+            if args.predict_image.isdigit():
+                args.predict_image = int(args.predict_image)
+        return args, True
+
+    def _get_pretrained_model_params_from_args(self):
+        model_files = None
+        model_source = None
+        model_info = None
+        need_download = True
+
+        model = self._args.model
+        for m in self.pretrained_models:
+            meta = m.get("meta", None)
+            if meta is None:
+                continue
+            model_name = meta.get("model_name", None)
+            if model_name is None:
+                continue
+            m_files = meta.get("model_files", None)
+            if m_files is None:
+                continue
+            checkpoint = m_files.get("checkpoint", None)
+            if checkpoint is None:
+                continue
+            if model == m["meta"]["model_name"]:
+                model_info = m
+                model_source = ModelSource.PRETRAINED
+                model_files = {"checkpoint": checkpoint}
+                config = m_files.get("config", None)
+                if config is not None:
+                    model_files["config"] = config
+                break
+
+        return model_files, model_source, model_info, need_download
+
+    def _get_custom_model_params_from_args(self):
+        def _load_experiment_info(artifacts_dir):
+            experiment_path = os.path.join(artifacts_dir, "experiment_info.json")
+            model_info = self._load_json_file(experiment_path)
+            original_model_files = model_info.get("model_files")
+            if not original_model_files:
+                raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
+            return model_info, original_model_files
+
+        def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
+            return {k: os.path.join(artifacts_dir, v) for k, v in original_model_files.items()} | {
+                "checkpoint": checkpoint_path
+            }
+
+        def _download_remote_files(team_id, artifacts_dir, local_artifacts_dir):
+            sly_fs.mkdir(local_artifacts_dir, True)
+            file_infos = self.api.file.list(team_id, artifacts_dir, False, "fileinfo")
+            remote_paths = [f.path for f in file_infos if not f.is_dir]
+            local_paths = [
+                os.path.join(local_artifacts_dir, f.name) for f in file_infos if not f.is_dir
+            ]
+
+            coro = self.api.file.download_bulk_async(team_id, remote_paths, local_paths)
+            loop = get_or_create_event_loop()
+            if loop.is_running():
+                future = asyncio.run_coroutine_threadsafe(coro, loop)
+                future.result()
+            else:
+                loop.run_until_complete(coro)
+
+        model_source = ModelSource.CUSTOM
+        need_download = False
+        checkpoint_path = self._args.model
+
+        if not os.path.isfile(checkpoint_path):
+            team_id = sly_env.team_id(raise_not_found=False)
+            if not team_id:
+                raise ValueError(
+                    "Team ID not found in env. Required for remote custom checkpoints."
+                )
+            file_info = self.api.file.get_info_by_path(team_id, checkpoint_path)
+            if not file_info:
+                raise ValueError(
+                    f"Couldn't find: '{checkpoint_path}' locally or remotely in Team ID."
+                )
+            need_download = True
+
+        artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
+        if not need_download:
+            model_info, original_model_files = _load_experiment_info(artifacts_dir)
+            model_files = _prepare_local_model_files(
+                artifacts_dir, checkpoint_path, original_model_files
+            )
+        else:
+            local_artifacts_dir = os.path.join(
+                self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
+            )
+            _download_remote_files(team_id, artifacts_dir, local_artifacts_dir)
+
+            model_info, original_model_files = _load_experiment_info(local_artifacts_dir)
+            model_files = _prepare_local_model_files(
+                local_artifacts_dir, checkpoint_path, original_model_files
+            )
+        return model_files, model_source, model_info, need_download
+
+    def _get_deploy_params_from_args(self):
+        # Ensure model directory exists
+        device = self._args.device if self._args.device else "cuda:0"
+        runtime = self._args.runtime if self._args.runtime else RuntimeType.PYTORCH
+
+        model_files, model_source, model_info, need_download = (
+            self._get_pretrained_model_params_from_args()
+        )
+        if model_source is None:
+            model_files, model_source, model_info, need_download = (
+                self._get_custom_model_params_from_args()
+            )
+
+        if model_source is None:
+            raise ValueError("Couldn't create 'model_source' from args")
+        if model_files is None:
+            raise ValueError("Couldn't create 'model_files' from args")
+        if model_info is None:
+            raise ValueError("Couldn't create 'model_info' from args")
+
+        deploy_params = {
+            "model_files": model_files,
+            "model_source": model_source,
+            "model_info": model_info,
+            "device": device,
+            "runtime": runtime,
+        }
+
+        logger.info(f"Deploy parameters: {deploy_params}")
+        return deploy_params, need_download
+
+    def _run_server(self):
+        config = uvicorn.Config(app=self._app, host="0.0.0.0", port=8000, ws="websockets")
+        self._uvicorn_server = uvicorn.Server(config)
+        self._uvicorn_server.run()
+
+    def _inference_by_local_deploy_args(self):
+        def predict_project_by_args(api: Api, project_id: int, dataset_ids: List[int] = None):
+            source_project = api.project.get_info_by_id(project_id)
+            workspace_id = source_project.workspace_id
+            output_project = api.project.create(
+                workspace_id, f"{source_project.name} predicted", change_name_if_conflict=True
+            )
+            results = self._inference_project_id(
+                api=self.api,
+                state={
+                    "projectId": project_id,
+                    "dataset_ids": dataset_ids,
+                    "output_project_id": output_project.id,
+                },
+            )
+
+        def predict_datasets_by_args(api: Api, dataset_ids: List[int]):
+            dataset_infos = [api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids]
+            project_ids = list(set([dataset_info.project_id for dataset_info in dataset_infos]))
+            if len(project_ids) > 1:
+                raise ValueError("All datasets should belong to the same project")
+            predict_project_by_args(api, project_ids[0], dataset_ids)
+
+        def predict_image_by_args(api: Api, image: Union[str, int]):
+            def predict_image_np(image_np):
+                settings = self._get_inference_settings({})
+                anns, _ = self._inference_auto([image_np], settings)
+                if len(anns) == 0:
+                    return Annotation(img_size=image_np.shape[:2])
+                ann = anns[0]
+                return ann
+
+            if isinstance(image, int):
+                image_np = api.image.download_np(image)
+                ann = predict_image_np(image_np)
+                api.annotation.upload_ann(image, ann)
+            elif isinstance(image, str):
+                if sly_fs.file_exists(self._args.predict):
+                    image_np = sly_image.read(self._args.predict)
+                    ann = predict_image_np(image_np)
+                    pred_ann_path = image + ".json"
+                    sly_json.dump_json_file(ann.to_json(), pred_ann_path)
+                    # Save image for debug
+                    # ann.draw_pretty(image_np)
+                    # pred_path = os.path.join(os.path.dirname(self._args.predict), "pred_" + os.path.basename(self._args.predict))
+                    # sly_image.write(pred_path, image_np)
+
+        if self._args.predict_project is not None:
+            predict_project_by_args(self.api, self._args.predict_project)
+        elif self._args.predict_dataset is not None:
+            predict_datasets_by_args(self.api, self._args.predict_dataset)
+        elif self._args.predict_dir is not None:
+            raise NotImplementedError("Predict from directory is not implemented yet")
+        elif self._args.predict_image is not None:
+            predict_image_by_args(self.api, self._args.predict_image)
+
+    def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
+        if model_source == ModelSource.PRETRAINED:
+            checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
+            checkpoint_name = model_info["meta"]["model_name"]
+        else:
+            checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
+            checkpoint_url = os.path.join(
+                model_info["artifacts_dir"], "checkpoints", checkpoint_name
+            )
+
+        app_name = sly_env.app_name()
+        meta = WorkflowMeta(node_settings=WorkflowSettings(title=f"Serve {app_name}"))
+
+        logger.debug(
+            f"Workflow Input: Checkpoint URL - {checkpoint_url}, Checkpoint Name - {checkpoint_name}"
+        )
+        if checkpoint_url and self.api.file.exists(sly_env.team_id(), checkpoint_url):
+            self.api.app.workflow.add_input_file(checkpoint_url, model_weight=True, meta=meta)
+        else:
+            logger.debug(
+                f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
+            )
+
 
 def _get_log_extra_for_inference_request(inference_request_uuid, inference_request: dict):
     log_extra = {
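
One detail worth noting in the block above: `_download_remote_files` has to invoke the async bulk download from synchronous code, which it does via `get_or_create_event_loop`. The bridge in isolation; `run_sync` is an illustrative name, not a package function:

```python
# Sketch of the sync-over-async bridge used by _download_remote_files;
# get_or_create_event_loop is imported from supervisely._utils in this diff.
import asyncio

from supervisely._utils import get_or_create_event_loop


def run_sync(coro):
    loop = get_or_create_event_loop()
    if loop.is_running():
        # The loop already runs in another thread: schedule the coroutine there
        # and block on the concurrent.futures.Future it returns.
        return asyncio.run_coroutine_threadsafe(coro, loop).result()
    # Otherwise drive the loop to completion ourselves.
    return loop.run_until_complete(coro)
```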
@@ -2998,7 +3345,7 @@ class TempImageWriter:
     def __init__(self, format: str = "png"):
         self.format = format
         self.temp_dir = os.path.join(get_data_dir(), rand_str(10))
-        fs.mkdir(self.temp_dir)
+        sly_fs.mkdir(self.temp_dir)
 
     def write(self, image: np.ndarray):
         image_path = os.path.join(self.temp_dir, f"{rand_str(10)}.{self.format}")
@@ -3006,7 +3353,7 @@ class TempImageWriter:
         return image_path
 
     def clean(self):
-        fs.remove_dir(self.temp_dir)
+        sly_fs.remove_dir(self.temp_dir)
 
 
 def get_hardware_info(device: str) -> str: