supervisely 6.73.268__py3-none-any.whl → 6.73.270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/api/file_api.py +1 -1
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +13 -103
- supervisely/nn/inference/inference.py +411 -64
- supervisely/nn/training/gui/gui.py +13 -5
- supervisely/nn/training/gui/training_artifacts.py +121 -51
- supervisely/nn/training/train_app.py +79 -23
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/METADATA +1 -1
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/RECORD +12 -12
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/LICENSE +0 -0
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/WHEEL +0 -0
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.268.dist-info → supervisely-6.73.270.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
1
3
|
import inspect
|
|
2
4
|
import json
|
|
3
5
|
import os
|
|
@@ -11,12 +13,14 @@ from collections import OrderedDict, defaultdict
|
|
|
11
13
|
from concurrent.futures import ThreadPoolExecutor
|
|
12
14
|
from dataclasses import asdict
|
|
13
15
|
from functools import partial, wraps
|
|
16
|
+
from pathlib import Path
|
|
14
17
|
from queue import Queue
|
|
15
18
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
16
19
|
from urllib.request import urlopen
|
|
17
20
|
|
|
18
21
|
import numpy as np
|
|
19
22
|
import requests
|
|
23
|
+
import uvicorn
|
|
20
24
|
import yaml
|
|
21
25
|
from fastapi import Form, HTTPException, Request, Response, UploadFile, status
|
|
22
26
|
from fastapi.responses import JSONResponse
|
|
@@ -24,8 +28,7 @@ from requests.structures import CaseInsensitiveDict
|
|
|
24
28
|
|
|
25
29
|
import supervisely.app.development as sly_app_development
|
|
26
30
|
import supervisely.imaging.image as sly_image
|
|
27
|
-
import supervisely.io.env as
|
|
28
|
-
import supervisely.io.fs as fs
|
|
31
|
+
import supervisely.io.env as sly_env
|
|
29
32
|
import supervisely.io.fs as sly_fs
|
|
30
33
|
import supervisely.io.json as sly_json
|
|
31
34
|
import supervisely.nn.inference.gui as GUI
|
|
@@ -33,6 +36,7 @@ from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
|
|
|
33
36
|
from supervisely._utils import (
|
|
34
37
|
add_callback,
|
|
35
38
|
get_filename_from_headers,
|
|
39
|
+
get_or_create_event_loop,
|
|
36
40
|
is_debug_with_sly_net,
|
|
37
41
|
is_production,
|
|
38
42
|
rand_str,
|
|
@@ -43,6 +47,7 @@ from supervisely.annotation.obj_class import ObjClass
|
|
|
43
47
|
from supervisely.annotation.tag_collection import TagCollection
|
|
44
48
|
from supervisely.annotation.tag_meta import TagMeta, TagValueType
|
|
45
49
|
from supervisely.api.api import Api
|
|
50
|
+
from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
|
|
46
51
|
from supervisely.api.image_api import ImageInfo
|
|
47
52
|
from supervisely.app.content import StateJson, get_data_dir
|
|
48
53
|
from supervisely.app.exceptions import DialogWindowError
|
|
@@ -106,9 +111,19 @@ class Inference:
|
|
|
106
111
|
multithread_inference: Optional[bool] = True,
|
|
107
112
|
use_serving_gui_template: Optional[bool] = False,
|
|
108
113
|
):
|
|
114
|
+
|
|
115
|
+
self._args, self._is_local_deploy = self._parse_local_deploy_args()
|
|
116
|
+
|
|
109
117
|
if model_dir is None:
|
|
110
|
-
|
|
111
|
-
|
|
118
|
+
if self._is_local_deploy is True:
|
|
119
|
+
try:
|
|
120
|
+
model_dir = get_data_dir()
|
|
121
|
+
except:
|
|
122
|
+
model_dir = Path("~/.cache/supervisely/app_data").expanduser()
|
|
123
|
+
else:
|
|
124
|
+
model_dir = os.path.join(get_data_dir(), "models")
|
|
125
|
+
sly_fs.mkdir(model_dir)
|
|
126
|
+
|
|
112
127
|
self.device: str = None
|
|
113
128
|
self.runtime: str = None
|
|
114
129
|
self.model_precision: str = None
|
|
@@ -128,7 +143,7 @@ class Inference:
|
|
|
128
143
|
self._autostart_delay_time = 5 * 60 # 5 min
|
|
129
144
|
self._tracker = None
|
|
130
145
|
self._hardware: str = None
|
|
131
|
-
self.pretrained_models = self.
|
|
146
|
+
self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
|
|
132
147
|
if custom_inference_settings is None:
|
|
133
148
|
if self.INFERENCE_SETTINGS is not None:
|
|
134
149
|
custom_inference_settings = self.INFERENCE_SETTINGS
|
|
@@ -136,7 +151,7 @@ class Inference:
|
|
|
136
151
|
logger.debug("Custom inference settings are not provided.")
|
|
137
152
|
custom_inference_settings = {}
|
|
138
153
|
if isinstance(custom_inference_settings, str):
|
|
139
|
-
if
|
|
154
|
+
if sly_fs.file_exists(custom_inference_settings):
|
|
140
155
|
with open(custom_inference_settings, "r") as f:
|
|
141
156
|
custom_inference_settings = f.read()
|
|
142
157
|
else:
|
|
@@ -146,12 +161,24 @@ class Inference:
|
|
|
146
161
|
self._use_gui = use_gui
|
|
147
162
|
self._use_serving_gui_template = use_serving_gui_template
|
|
148
163
|
self._gui = None
|
|
164
|
+
self._uvicorn_server = None
|
|
149
165
|
|
|
150
166
|
self.load_on_device = LOAD_ON_DEVICE_DECORATOR(self.load_on_device)
|
|
151
167
|
self.load_on_device = add_callback(self.load_on_device, self._set_served_callback)
|
|
152
168
|
|
|
153
169
|
self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
|
|
154
170
|
|
|
171
|
+
if self._is_local_deploy:
|
|
172
|
+
# self._args = self._parse_local_deploy_args()
|
|
173
|
+
self._use_gui = False
|
|
174
|
+
deploy_params, need_download = self._get_deploy_params_from_args()
|
|
175
|
+
if need_download:
|
|
176
|
+
local_model_files = self._download_model_files(
|
|
177
|
+
deploy_params["model_source"], deploy_params["model_files"], False
|
|
178
|
+
)
|
|
179
|
+
deploy_params["model_files"] = local_model_files
|
|
180
|
+
self._load_model_headless(**deploy_params)
|
|
181
|
+
|
|
155
182
|
if self._use_gui:
|
|
156
183
|
initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
|
|
157
184
|
original_initialize_custom_gui_method = getattr(
|
|
@@ -224,10 +251,10 @@ class Inference:
|
|
|
224
251
|
self.get_info = self._check_serve_before_call(self.get_info)
|
|
225
252
|
|
|
226
253
|
self.cache = InferenceImageCache(
|
|
227
|
-
maxsize=
|
|
228
|
-
ttl=
|
|
254
|
+
maxsize=sly_env.smart_cache_size(),
|
|
255
|
+
ttl=sly_env.smart_cache_ttl(),
|
|
229
256
|
is_persistent=True,
|
|
230
|
-
base_folder=
|
|
257
|
+
base_folder=sly_env.smart_cache_container_dir(),
|
|
231
258
|
log_progress=True,
|
|
232
259
|
)
|
|
233
260
|
|
|
@@ -248,9 +275,18 @@ class Inference:
|
|
|
248
275
|
)
|
|
249
276
|
device = "cpu"
|
|
250
277
|
|
|
251
|
-
def
|
|
278
|
+
def _load_json_file(self, file_path: str) -> dict:
|
|
279
|
+
if isinstance(file_path, str):
|
|
280
|
+
if sly_fs.file_exists(file_path) and sly_fs.get_file_ext(file_path) == ".json":
|
|
281
|
+
return sly_json.load_json_file(file_path)
|
|
282
|
+
else:
|
|
283
|
+
raise ValueError("File not found or invalid file format.")
|
|
284
|
+
else:
|
|
285
|
+
raise ValueError("Invalid file. Please provide a valid '.json' file.")
|
|
286
|
+
|
|
287
|
+
def _load_models_json_file(self, models: str) -> List[Dict[str, Any]]:
|
|
252
288
|
"""
|
|
253
|
-
Loads
|
|
289
|
+
Loads dictionary from the provided file.
|
|
254
290
|
"""
|
|
255
291
|
if isinstance(models, str):
|
|
256
292
|
if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
|
|
@@ -258,9 +294,7 @@ class Inference:
|
|
|
258
294
|
else:
|
|
259
295
|
raise ValueError("File not found or invalid file format.")
|
|
260
296
|
else:
|
|
261
|
-
raise ValueError(
|
|
262
|
-
"Invalid models file. Please provide a valid '.json' file with list of model configurations."
|
|
263
|
-
)
|
|
297
|
+
raise ValueError("Invalid file. Please provide a valid '.json' file.")
|
|
264
298
|
|
|
265
299
|
if not isinstance(models, list):
|
|
266
300
|
raise ValueError("models parameters must be a list of dicts")
|
|
@@ -394,13 +428,13 @@ class Inference:
|
|
|
394
428
|
else:
|
|
395
429
|
progress = None
|
|
396
430
|
|
|
397
|
-
if
|
|
431
|
+
if sly_fs.dir_exists(src_path) or sly_fs.file_exists(
|
|
398
432
|
src_path
|
|
399
433
|
): # only during debug, has no effect in production
|
|
400
434
|
dst_path = os.path.abspath(src_path)
|
|
401
435
|
logger.info(f"File {dst_path} found.")
|
|
402
436
|
elif src_path.startswith("/"): # folder from Team Files
|
|
403
|
-
team_id =
|
|
437
|
+
team_id = sly_env.team_id()
|
|
404
438
|
|
|
405
439
|
if src_path.endswith("/") and self.api.file.dir_exists(team_id, src_path):
|
|
406
440
|
|
|
@@ -436,7 +470,7 @@ class Inference:
|
|
|
436
470
|
def download_file(team_id, src_path, dst_path, progress_cb=None):
|
|
437
471
|
self.api.file.download(team_id, src_path, dst_path, progress_cb=progress_cb)
|
|
438
472
|
|
|
439
|
-
file_info = self.api.file.get_info_by_path(
|
|
473
|
+
file_info = self.api.file.get_info_by_path(sly_env.team_id(), src_path)
|
|
440
474
|
if progress is None:
|
|
441
475
|
download_file(team_id, src_path, dst_path)
|
|
442
476
|
else:
|
|
@@ -451,8 +485,8 @@ class Inference:
|
|
|
451
485
|
logger.info(f"📥 File {basename} has been successfully downloaded from Team Files")
|
|
452
486
|
logger.info(f"File {basename} path: {dst_path}")
|
|
453
487
|
else: # external url
|
|
454
|
-
if not
|
|
455
|
-
|
|
488
|
+
if not sly_fs.dir_exists(os.path.dirname(dst_path)):
|
|
489
|
+
sly_fs.mkdir(os.path.dirname(dst_path))
|
|
456
490
|
|
|
457
491
|
def download_external_file(url, save_path, progress=None):
|
|
458
492
|
def download_content(save_path, progress_cb=None):
|
|
@@ -531,13 +565,15 @@ class Inference:
|
|
|
531
565
|
def _checkpoints_cache_dir(self):
|
|
532
566
|
return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
|
|
533
567
|
|
|
534
|
-
def _download_model_files(
|
|
568
|
+
def _download_model_files(
    self, model_source: str, model_files: Dict[str, str], log_progress: bool = True
) -> dict:
    """
    Download the model files for the given model source.

    :param model_source: ``ModelSource.PRETRAINED`` or ``ModelSource.CUSTOM``.
    :param model_files: Mapping of file role (e.g. ``"checkpoint"``) to its URL/path.
    :param log_progress: Whether to show download progress in the GUI.
    :return: Mapping of the same file roles to local file paths.
    :raises ValueError: If ``model_source`` is not a supported source.
    """
    if model_source == ModelSource.PRETRAINED:
        return self._download_pretrained_model(model_files, log_progress)
    if model_source == ModelSource.CUSTOM:
        return self._download_custom_model(model_files, log_progress)
    # An unsupported source used to fall through and return None, which only
    # failed later when the result was used as the model-files mapping.
    raise ValueError(f"Unsupported model source: '{model_source}'")
|
|
539
575
|
|
|
540
|
-
def _download_pretrained_model(self, model_files: dict):
|
|
576
|
+
def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
|
|
541
577
|
"""
|
|
542
578
|
Downloads the pretrained model data.
|
|
543
579
|
"""
|
|
@@ -564,54 +600,61 @@ class Inference:
|
|
|
564
600
|
logger.debug(f"Model: '{file_name}' was found in model dir")
|
|
565
601
|
continue
|
|
566
602
|
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
603
|
+
if log_progress:
|
|
604
|
+
with self.gui.download_progress(
|
|
605
|
+
message=f"Downloading: '{file_name}'",
|
|
606
|
+
total=file_size,
|
|
607
|
+
unit="bytes",
|
|
608
|
+
unit_scale=True,
|
|
609
|
+
) as download_pbar:
|
|
610
|
+
self.gui.download_progress.show()
|
|
611
|
+
sly_fs.download(
|
|
612
|
+
url=file_url, save_path=file_path, progress=download_pbar.update
|
|
613
|
+
)
|
|
614
|
+
else:
|
|
615
|
+
sly_fs.download(url=file_url, save_path=file_path)
|
|
579
616
|
local_model_files[file] = file_path
|
|
580
617
|
else:
|
|
581
618
|
local_model_files[file] = file_url
|
|
582
|
-
|
|
619
|
+
|
|
620
|
+
if log_progress:
|
|
621
|
+
self.gui.download_progress.hide()
|
|
583
622
|
return local_model_files
|
|
584
623
|
|
|
585
|
-
def _download_custom_model(self, model_files: dict):
|
|
624
|
+
def _download_custom_model(self, model_files: dict, log_progress: bool = True):
    """
    Downloads the custom model data.

    Each entry of ``model_files`` is resolved in this order:

    1. If the path exists in Team Files (team taken from ``sly_env.team_id()``),
       it is downloaded into ``self.model_dir``.
    2. Otherwise, if the path exists on the local filesystem, it is used as is.

    :param model_files: Mapping of file role (e.g. ``"checkpoint"``) to a Team Files
        path or a local path.
    :param log_progress: Whether to show a download progress bar in the GUI.
    :return: Mapping of the same file roles to local file paths.
    :raises FileNotFoundError: If a file exists neither in Team Files nor locally.
    """
    team_id = sly_env.team_id()
    local_model_files = {}
    try:
        for file, file_url in model_files.items():
            file_info = self.api.file.get_info_by_path(team_id, file_url)
            if file_info is None:
                if sly_fs.file_exists(file_url):
                    local_model_files[file] = file_url
                    continue
                raise FileNotFoundError(
                    f"File '{file_url}' not found in Team Files or on the local filesystem"
                )
            file_name = os.path.basename(file_url)
            file_path = os.path.join(self.model_dir, file_name)
            if log_progress:
                with self.gui.download_progress(
                    message=f"Downloading: '{file_name}'",
                    total=file_info.sizeb,
                    unit="bytes",
                    unit_scale=True,
                ) as download_pbar:
                    self.gui.download_progress.show()
                    self.api.file.download(
                        team_id, file_url, file_path, progress_cb=download_pbar.update
                    )
            else:
                self.api.file.download(team_id, file_url, file_path)
            local_model_files[file] = file_path
    finally:
        # Hide the progress widget even when a lookup/download raises, so the GUI
        # is not left showing a stale progress bar.
        if log_progress:
            self.gui.download_progress.hide()
    return local_model_files
|
|
616
659
|
|
|
617
660
|
def _load_model(self, deploy_params: dict):
|
|
@@ -647,6 +690,13 @@ class Inference:
|
|
|
647
690
|
if model_source == ModelSource.CUSTOM:
|
|
648
691
|
self._set_model_meta_custom_model(model_info)
|
|
649
692
|
self._set_checkpoint_info_custom_model(deploy_params)
|
|
693
|
+
|
|
694
|
+
try:
|
|
695
|
+
if is_production():
|
|
696
|
+
self._add_workflow_input(model_source, model_files, model_info)
|
|
697
|
+
except Exception as e:
|
|
698
|
+
logger.warning(f"Failed to add input to the workflow: {repr(e)}")
|
|
699
|
+
|
|
650
700
|
self._load_model(deploy_params)
|
|
651
701
|
if self._model_meta is None:
|
|
652
702
|
self._set_model_meta_from_classes()
|
|
@@ -679,7 +729,7 @@ class Inference:
|
|
|
679
729
|
model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
|
|
680
730
|
)
|
|
681
731
|
checkpoint_file_info = self.api.file.get_info_by_path(
|
|
682
|
-
|
|
732
|
+
sly_env.team_id(), checkpoint_file_path
|
|
683
733
|
)
|
|
684
734
|
if checkpoint_file_info is None:
|
|
685
735
|
checkpoint_url = None
|
|
@@ -1253,18 +1303,18 @@ class Inference:
|
|
|
1253
1303
|
logger.debug("Inferring image_url...", extra={"state": state})
|
|
1254
1304
|
settings = self._get_inference_settings(state)
|
|
1255
1305
|
image_url = state["image_url"]
|
|
1256
|
-
ext =
|
|
1306
|
+
ext = sly_fs.get_file_ext(image_url)
|
|
1257
1307
|
if ext == "":
|
|
1258
1308
|
ext = ".jpg"
|
|
1259
1309
|
image_path = os.path.join(get_data_dir(), rand_str(15) + ext)
|
|
1260
|
-
|
|
1310
|
+
sly_fs.download(image_url, image_path)
|
|
1261
1311
|
logger.debug("Inference settings:", extra=settings)
|
|
1262
1312
|
logger.debug(f"Downloaded path: {image_path}")
|
|
1263
1313
|
anns, slides_data = self._inference_auto(
|
|
1264
1314
|
[image_path],
|
|
1265
1315
|
settings=settings,
|
|
1266
1316
|
)
|
|
1267
|
-
|
|
1317
|
+
sly_fs.silent_remove(image_path)
|
|
1268
1318
|
return self._format_output(anns, slides_data)[0]
|
|
1269
1319
|
|
|
1270
1320
|
def _inference_video_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
|
|
@@ -2208,28 +2258,46 @@ class Inference:
|
|
|
2208
2258
|
|
|
2209
2259
|
if is_debug_with_sly_net():
|
|
2210
2260
|
# advanced debug for Supervisely Team
|
|
2211
|
-
logger.
|
|
2261
|
+
logger.warning(
|
|
2212
2262
|
"Serving is running in advanced development mode with Supervisely VPN Network"
|
|
2213
2263
|
)
|
|
2214
|
-
team_id =
|
|
2264
|
+
team_id = sly_env.team_id()
|
|
2215
2265
|
# sly_app_development.supervisely_vpn_network(action="down") # for debug
|
|
2216
2266
|
sly_app_development.supervisely_vpn_network(action="up")
|
|
2217
2267
|
task = sly_app_development.create_debug_task(team_id, port="8000")
|
|
2218
2268
|
self._task_id = task["id"]
|
|
2219
2269
|
os.environ["TASK_ID"] = str(self._task_id)
|
|
2220
2270
|
else:
|
|
2221
|
-
|
|
2271
|
+
if not self._is_local_deploy:
|
|
2272
|
+
self._task_id = sly_env.task_id() if is_production() else None
|
|
2222
2273
|
|
|
2223
|
-
if isinstance(self.gui, GUI.
|
|
2274
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
2224
2275
|
self._app = Application(layout=self.get_ui())
|
|
2225
2276
|
elif isinstance(self.gui, GUI.ServingGUI):
|
|
2226
2277
|
self._app = Application(layout=self._app_layout)
|
|
2278
|
+
# elif isinstance(self.gui, GUI.InferenceGUI):
|
|
2279
|
+
# self._app = Application(layout=self.get_ui())
|
|
2227
2280
|
else:
|
|
2228
2281
|
self._app = Application(layout=self.get_ui())
|
|
2229
2282
|
|
|
2230
2283
|
server = self._app.get_server()
|
|
2231
2284
|
self._app.set_ready_check_function(self.is_model_deployed)
|
|
2232
2285
|
|
|
2286
|
+
if self._is_local_deploy:
|
|
2287
|
+
# Predict and shutdown
|
|
2288
|
+
if any(
|
|
2289
|
+
[
|
|
2290
|
+
self._args.predict_project,
|
|
2291
|
+
self._args.predict_dataset,
|
|
2292
|
+
self._args.predict_dir,
|
|
2293
|
+
self._args.predict_image,
|
|
2294
|
+
]
|
|
2295
|
+
):
|
|
2296
|
+
self._inference_by_local_deploy_args()
|
|
2297
|
+
# Gracefully shut down the server
|
|
2298
|
+
self._app.shutdown()
|
|
2299
|
+
# else: run server after endpoints
|
|
2300
|
+
|
|
2233
2301
|
@call_on_autostart()
|
|
2234
2302
|
def autostart_func():
|
|
2235
2303
|
gpu_count = get_gpu_count()
|
|
@@ -2725,6 +2793,285 @@ class Inference:
|
|
|
2725
2793
|
def _get_deploy_info():
|
|
2726
2794
|
return asdict(self._get_deploy_info())
|
|
2727
2795
|
|
|
2796
|
+
# Local deploy without predict args
|
|
2797
|
+
if self._is_local_deploy:
|
|
2798
|
+
self._run_server()
|
|
2799
|
+
|
|
2800
|
+
def _parse_local_deploy_args(self):
    """
    Parse the command-line arguments used for local (CLI) deployment.

    Unknown arguments are ignored (``parse_known_args``).

    :return: ``(None, False)`` when ``--model`` is absent. Otherwise ``(args, True)``,
        where ``args`` is the parsed namespace with ``predict_dataset`` normalized
        to a list of ints and a purely numeric ``predict_image`` converted to ``int``.
    """
    cli = argparse.ArgumentParser(description="Run Inference Serving")

    device_choices = ["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"]
    runtime_choices = [RuntimeType.PYTORCH, RuntimeType.ONNXRUNTIME, RuntimeType.TENSORRT]
    # Kept as a lambda so argparse error messages (which use the callable's
    # __name__) stay unchanged.
    dataset_ids_type = lambda x: [int(i) for i in x.split(",")] if "," in x else int(x)

    # --- deployment options ---
    cli.add_argument(
        "--model",
        type=str,
        help="Name of the pretrained model or path to custom checkpoint file",
    )
    cli.add_argument(
        "--device",
        type=str,
        choices=device_choices,
        default="cuda:0",
        help="Device to use for inference (default: 'cuda:0')",
    )
    cli.add_argument(
        "--runtime",
        type=str,
        choices=runtime_choices,
        default=RuntimeType.PYTORCH,
        help="Runtime type for inference (default: PYTORCH)",
    )

    # --- prediction targets ---
    cli.add_argument("--predict-project", type=int, required=False, help="ID of the project")
    cli.add_argument(
        "--predict-dataset",
        type=dataset_ids_type,
        required=False,
        help="ID of the dataset or a comma-separated list of dataset IDs",
    )
    cli.add_argument(
        "--predict-dir",
        type=str,
        required=False,
        help="Not implemented yet. Path to the local directory with images",
    )
    cli.add_argument(
        "--predict-image",
        type=str,
        required=False,
        help="Image ID on Supervisely instance or path to local image",
    )

    # --- output options (placeholders) ---
    for flag in ("--output", "--output-dir"):
        cli.add_argument(flag, type=str, required=False, help="Not implemented yet")

    parsed, _unknown = cli.parse_known_args()
    if parsed.model is None:
        # Without --model this is not a local deployment.
        return None, False

    if isinstance(parsed.predict_dataset, int):
        parsed.predict_dataset = [parsed.predict_dataset]
    image_arg = parsed.predict_image
    if image_arg is not None and image_arg.isdigit():
        # A numeric value is treated as a Supervisely image ID.
        parsed.predict_image = int(image_arg)
    return parsed, True
|
|
2863
|
+
|
|
2864
|
+
def _get_pretrained_model_params_from_args(self):
|
|
2865
|
+
model_files = None
|
|
2866
|
+
model_source = None
|
|
2867
|
+
model_info = None
|
|
2868
|
+
need_download = True
|
|
2869
|
+
|
|
2870
|
+
model = self._args.model
|
|
2871
|
+
for m in self.pretrained_models:
|
|
2872
|
+
meta = m.get("meta", None)
|
|
2873
|
+
if meta is None:
|
|
2874
|
+
continue
|
|
2875
|
+
model_name = meta.get("model_name", None)
|
|
2876
|
+
if model_name is None:
|
|
2877
|
+
continue
|
|
2878
|
+
m_files = meta.get("model_files", None)
|
|
2879
|
+
if m_files is None:
|
|
2880
|
+
continue
|
|
2881
|
+
checkpoint = m_files.get("checkpoint", None)
|
|
2882
|
+
if checkpoint is None:
|
|
2883
|
+
continue
|
|
2884
|
+
if model == m["meta"]["model_name"]:
|
|
2885
|
+
model_info = m
|
|
2886
|
+
model_source = ModelSource.PRETRAINED
|
|
2887
|
+
model_files = {"checkpoint": checkpoint}
|
|
2888
|
+
config = m_files.get("config", None)
|
|
2889
|
+
if config is not None:
|
|
2890
|
+
model_files["config"] = config
|
|
2891
|
+
break
|
|
2892
|
+
|
|
2893
|
+
return model_files, model_source, model_info, need_download
|
|
2894
|
+
|
|
2895
|
+
def _get_custom_model_params_from_args(self):
    """
    Resolve ``--model`` as a custom checkpoint, either a local file or a Team Files path.

    The checkpoint is expected at ``<artifacts_dir>/checkpoints/<file>``.
    ``artifacts_dir`` must contain an ``experiment_info.json`` with a non-empty
    ``model_files`` mapping (file role -> path relative to ``artifacts_dir``).

    - Local checkpoint: the experiment info is read in place and
      ``need_download`` is ``False``.
    - Remote checkpoint: the files listed in the remote ``artifacts_dir`` are
      downloaded into ``<model_dir>/local_deploy/<artifacts_dir name>``, and the
      experiment info is read from there. ``need_download`` is ``True`` because the
      returned ``"checkpoint"`` entry is still the Team Files path; it is fetched
      later by the caller.

    :return: Tuple ``(model_files, model_source, model_info, need_download)``.
    :raises ValueError: If ``TEAM_ID`` is missing for a remote checkpoint, the
        checkpoint cannot be found, or ``experiment_info.json`` is missing/invalid.
    """

    def _load_experiment_info(artifacts_dir):
        # Read experiment_info.json and make sure it lists the model files.
        experiment_path = os.path.join(artifacts_dir, "experiment_info.json")
        model_info = self._load_json_file(experiment_path)
        original_model_files = model_info.get("model_files")
        if not original_model_files:
            raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
        return model_info, original_model_files

    def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
        # Resolve file paths against artifacts_dir and pin the selected checkpoint.
        # ``{**a, k: v}`` is equivalent to ``a | {k: v}`` but does not need Python 3.9+.
        return {
            **{k: os.path.join(artifacts_dir, v) for k, v in original_model_files.items()},
            "checkpoint": checkpoint_path,
        }

    def _download_remote_files(team_id, artifacts_dir, local_artifacts_dir):
        # Mirror the (non-directory) entries listed in the remote artifacts dir.
        sly_fs.mkdir(local_artifacts_dir, True)
        file_infos = self.api.file.list(team_id, artifacts_dir, False, "fileinfo")
        files = [f for f in file_infos if not f.is_dir]
        remote_paths = [f.path for f in files]
        local_paths = [os.path.join(local_artifacts_dir, f.name) for f in files]

        coro = self.api.file.download_bulk_async(team_id, remote_paths, local_paths)
        loop = get_or_create_event_loop()
        if loop.is_running():
            # NOTE(review): this blocks until the coroutine completes on ``loop``; if
            # ``loop`` is running on the current thread it can never complete.
            # Confirm this path is only reached from another thread.
            future = asyncio.run_coroutine_threadsafe(coro, loop)
            future.result()
        else:
            loop.run_until_complete(coro)

    model_source = ModelSource.CUSTOM
    need_download = False
    team_id = None
    checkpoint_path = self._args.model

    if not os.path.isfile(checkpoint_path):
        team_id = sly_env.team_id(raise_not_found=False)
        if not team_id:
            raise ValueError("Team ID not found in env. Required for remote custom checkpoints.")
        file_info = self.api.file.get_info_by_path(team_id, checkpoint_path)
        if not file_info:
            raise ValueError(
                f"Couldn't find: '{checkpoint_path}' locally or remotely in Team Files."
            )
        need_download = True

    artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
    if need_download:
        local_artifacts_dir = os.path.join(
            self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
        )
        _download_remote_files(team_id, artifacts_dir, local_artifacts_dir)
        # From here on, experiment files are read from the local mirror.
        artifacts_dir = local_artifacts_dir

    model_info, original_model_files = _load_experiment_info(artifacts_dir)
    model_files = _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files)
    return model_files, model_source, model_info, need_download
|
|
2959
|
+
|
|
2960
|
+
def _get_deploy_params_from_args(self):
    """
    Build deploy parameters for a local deployment from the parsed CLI args.

    ``--model`` is first matched against pretrained models. If no pretrained
    model matches, it is resolved as a custom checkpoint.

    :return: ``(deploy_params, need_download)``. ``deploy_params`` has the keys
        ``model_files``, ``model_source``, ``model_info``, ``device`` and ``runtime``.
    :raises ValueError: If the model source, files or info could not be resolved.
    """
    cli_args = self._args
    device = cli_args.device or "cuda:0"
    runtime = cli_args.runtime or RuntimeType.PYTORCH

    resolved = self._get_pretrained_model_params_from_args()
    if resolved[1] is None:  # no pretrained match -> try a custom checkpoint
        resolved = self._get_custom_model_params_from_args()
    model_files, model_source, model_info, need_download = resolved

    required = (
        (model_source, "model_source"),
        (model_files, "model_files"),
        (model_info, "model_info"),
    )
    for value, label in required:
        if value is None:
            raise ValueError(f"Couldn't create '{label}' from args")

    deploy_params = dict(
        model_files=model_files,
        model_source=model_source,
        model_info=model_info,
        device=device,
        runtime=runtime,
    )
    logger.info(f"Deploy parameters: {deploy_params}")
    return deploy_params, need_download
|
|
2990
|
+
|
|
2991
|
+
def _run_server(self, host: str = "0.0.0.0", port: int = 8000):
    """
    Serve ``self._app`` with a uvicorn server; ``run()`` blocks until it stops.

    The server object is stored on ``self._uvicorn_server``.

    :param host: Interface to bind to (default: all interfaces, ``"0.0.0.0"``).
    :param port: Port to listen on (default: ``8000``).
    """
    config = uvicorn.Config(app=self._app, host=host, port=port, ws="websockets")
    self._uvicorn_server = uvicorn.Server(config)
    self._uvicorn_server.run()
|
|
2995
|
+
|
|
2996
|
+
def _inference_by_local_deploy_args(self):
    """
    Run a one-shot prediction selected by the local-deploy CLI arguments.

    Only the first target that is set is used, in this order:

    1. ``--predict-project`` (predicts the whole project into a new project).
    2. ``--predict-dataset`` (all datasets must belong to the same project).
    3. ``--predict-dir`` (not implemented; raises NotImplementedError).
    4. ``--predict-image``: an int is a Supervisely image ID, and its annotation is
       uploaded; a str is a local image path, and the prediction is saved next to it
       as ``<path>.json``.
    """

    def predict_project_by_args(api: Api, project_id: int, dataset_ids: List[int] = None):
        # Predictions go to a new project in the same workspace as the source one.
        source_project = api.project.get_info_by_id(project_id)
        output_project = api.project.create(
            source_project.workspace_id,
            f"{source_project.name} predicted",
            change_name_if_conflict=True,
        )
        self._inference_project_id(
            api=api,
            state={
                "projectId": project_id,
                "dataset_ids": dataset_ids,
                "output_project_id": output_project.id,
            },
        )

    def predict_datasets_by_args(api: Api, dataset_ids: List[int]):
        dataset_infos = [api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids]
        project_ids = list({dataset_info.project_id for dataset_info in dataset_infos})
        if len(project_ids) > 1:
            raise ValueError("All datasets should belong to the same project")
        predict_project_by_args(api, project_ids[0], dataset_ids)

    def predict_image_by_args(api: Api, image: Union[str, int]):
        def predict_image_np(image_np):
            settings = self._get_inference_settings({})
            anns, _ = self._inference_auto([image_np], settings)
            if len(anns) == 0:
                return Annotation(img_size=image_np.shape[:2])
            return anns[0]

        if isinstance(image, int):
            image_np = api.image.download_np(image)
            ann = predict_image_np(image_np)
            api.annotation.upload_ann(image, ann)
        elif isinstance(image, str):
            # Use the ``image`` argument itself: the parsed args have no ``predict``
            # attribute (the flag is ``--predict-image``), so the previous
            # ``self._args.predict`` lookup always raised AttributeError.
            if not sly_fs.file_exists(image):
                raise FileNotFoundError(f"Image '{image}' not found")
            image_np = sly_image.read(image)
            ann = predict_image_np(image_np)
            pred_ann_path = image + ".json"
            sly_json.dump_json_file(ann.to_json(), pred_ann_path)

    if self._args.predict_project is not None:
        predict_project_by_args(self.api, self._args.predict_project)
    elif self._args.predict_dataset is not None:
        predict_datasets_by_args(self.api, self._args.predict_dataset)
    elif self._args.predict_dir is not None:
        raise NotImplementedError("Predict from directory is not implemented yet")
    elif self._args.predict_image is not None:
        predict_image_by_args(self.api, self._args.predict_image)
|
|
3051
|
+
|
|
3052
|
+
def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
    """
    Register the served checkpoint as an input file of the current workflow.

    For ``ModelSource.PRETRAINED``, the checkpoint URL and name are read from
    ``model_info["meta"]``. For any other source, the checkpoint path is
    ``<model_info["artifacts_dir"]>/checkpoints/<checkpoint file name>``.

    The input is added only when that path exists in Team Files; otherwise a
    debug message is logged.
    """
    if model_source == ModelSource.PRETRAINED:
        pretrained_meta = model_info["meta"]
        checkpoint_url = pretrained_meta["model_files"]["checkpoint"]
        checkpoint_name = pretrained_meta["model_name"]
    else:
        checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
        artifacts_dir = model_info["artifacts_dir"]
        checkpoint_url = os.path.join(artifacts_dir, "checkpoints", checkpoint_name)

    workflow_meta = WorkflowMeta(
        node_settings=WorkflowSettings(title=f"Serve {sly_env.app_name()}")
    )

    logger.debug(
        f"Workflow Input: Checkpoint URL - {checkpoint_url}, Checkpoint Name - {checkpoint_name}"
    )
    in_team_files = bool(checkpoint_url) and self.api.file.exists(
        sly_env.team_id(), checkpoint_url
    )
    if not in_team_files:
        logger.debug(
            f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
        )
        return
    self.api.app.workflow.add_input_file(checkpoint_url, model_weight=True, meta=workflow_meta)
|
|
3074
|
+
|
|
2728
3075
|
|
|
2729
3076
|
def _get_log_extra_for_inference_request(inference_request_uuid, inference_request: dict):
|
|
2730
3077
|
log_extra = {
|
|
@@ -2998,7 +3345,7 @@ class TempImageWriter:
|
|
|
2998
3345
|
def __init__(self, format: str = "png"):
    """
    Create a writer backed by a new, randomly named directory under the app data dir.

    :param format: File extension (without the dot) given to written images.
    """
    self.format = format
    data_dir = get_data_dir()
    self.temp_dir = os.path.join(data_dir, rand_str(10))
    sly_fs.mkdir(self.temp_dir)
|
|
3002
3349
|
|
|
3003
3350
|
def write(self, image: np.ndarray):
|
|
3004
3351
|
image_path = os.path.join(self.temp_dir, f"{rand_str(10)}.{self.format}")
|
|
@@ -3006,7 +3353,7 @@ class TempImageWriter:
|
|
|
3006
3353
|
return image_path
|
|
3007
3354
|
|
|
3008
3355
|
def clean(self):
    """Remove this writer's temporary directory (``self.temp_dir``) via ``sly_fs.remove_dir``."""
    target_dir = self.temp_dir
    sly_fs.remove_dir(target_dir)
|
|
3010
3357
|
|
|
3011
3358
|
|
|
3012
3359
|
def get_hardware_info(device: str) -> str:
|