supervisely 6.73.316__py3-none-any.whl → 6.73.318__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/nn/inference/cache.py +21 -0
 - supervisely/nn/inference/inference.py +341 -88
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/METADATA +1 -1
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/RECORD +8 -8
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/LICENSE +0 -0
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/WHEEL +0 -0
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/entry_points.txt +0 -0
 - {supervisely-6.73.316.dist-info → supervisely-6.73.318.dist-info}/top_level.txt +0 -0
 
| 
         @@ -65,6 +65,27 @@ class PersistentImageTTLCache(TTLCache): 
     | 
|
| 
       65 
65 
     | 
    
         
             
                def __init__(self, maxsize: int, ttl: int, filepath: Path):
         
     | 
| 
       66 
66 
     | 
    
         
             
                    super().__init__(maxsize, ttl)
         
     | 
| 
       67 
67 
     | 
    
         
             
                    self._base_dir = filepath
         
     | 
| 
      
 68 
     | 
    
         
            +
                
         
     | 
| 
      
 69 
     | 
    
         
            +
                def pop(self, *args, **kwargs):
         
     | 
| 
      
 70 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 71 
     | 
    
         
            +
                        super().pop(*args, **kwargs)
         
     | 
| 
      
 72 
     | 
    
         
            +
                    except Exception:
         
     | 
| 
      
 73 
     | 
    
         
            +
                        sly.logger.warn("Cache data corrupted. Cleaning the cache...", exc_info=True)
         
     | 
| 
      
 74 
     | 
    
         
            +
             
     | 
| 
      
 75 
     | 
    
         
            +
                        def _delitem(self, key):
         
     | 
| 
      
 76 
     | 
    
         
            +
                            try:
         
     | 
| 
      
 77 
     | 
    
         
            +
                                size = self._Cache__size.pop(key)
         
     | 
| 
      
 78 
     | 
    
         
            +
                            except:
         
     | 
| 
      
 79 
     | 
    
         
            +
                                size = 0
         
     | 
| 
      
 80 
     | 
    
         
            +
                            self._Cache__data.pop(key, None)
         
     | 
| 
      
 81 
     | 
    
         
            +
                            self._Cache__currsize -= size
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                        shutil.rmtree(self._base_dir, ignore_errors=True)
         
     | 
| 
      
 84 
     | 
    
         
            +
                        for key in self.keys():
         
     | 
| 
      
 85 
     | 
    
         
            +
                            try:
         
     | 
| 
      
 86 
     | 
    
         
            +
                                super().__delitem__(key, cache_delitem=_delitem)
         
     | 
| 
      
 87 
     | 
    
         
            +
                            except:
         
     | 
| 
      
 88 
     | 
    
         
            +
                                pass
         
     | 
| 
       68 
89 
     | 
    
         | 
| 
       69 
90 
     | 
    
         
             
                def __delitem__(self, key: Any) -> None:
         
     | 
| 
       70 
91 
     | 
    
         
             
                    self.__del_file(key)
         
     | 
| 
         @@ -66,6 +66,7 @@ from supervisely.decorators.inference import ( 
     | 
|
| 
       66 
66 
     | 
    
         
             
            )
         
     | 
| 
       67 
67 
     | 
    
         
             
            from supervisely.geometry.any_geometry import AnyGeometry
         
     | 
| 
       68 
68 
     | 
    
         
             
            from supervisely.imaging.color import get_predefined_colors
         
     | 
| 
      
 69 
     | 
    
         
            +
            from supervisely.io.fs import list_files
         
     | 
| 
       69 
70 
     | 
    
         
             
            from supervisely.nn.inference.cache import InferenceImageCache
         
     | 
| 
       70 
71 
     | 
    
         
             
            from supervisely.nn.prediction_dto import Prediction
         
     | 
| 
       71 
72 
     | 
    
         
             
            from supervisely.nn.utils import (
         
     | 
| 
         @@ -80,6 +81,7 @@ from supervisely.project.download import download_to_cache, read_from_cached_pro 
     | 
|
| 
       80 
81 
     | 
    
         
             
            from supervisely.project.project_meta import ProjectMeta
         
     | 
| 
       81 
82 
     | 
    
         
             
            from supervisely.sly_logger import logger
         
     | 
| 
       82 
83 
     | 
    
         
             
            from supervisely.task.progress import Progress
         
     | 
| 
      
 84 
     | 
    
         
            +
            from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS
         
     | 
| 
       83 
85 
     | 
    
         | 
| 
       84 
86 
     | 
    
         
             
            try:
         
     | 
| 
       85 
87 
     | 
    
         
             
                from typing import Literal
         
     | 
| 
         @@ -112,8 +114,8 @@ class Inference: 
     | 
|
| 
       112 
114 
     | 
    
         
             
                    use_serving_gui_template: Optional[bool] = False,
         
     | 
| 
       113 
115 
     | 
    
         
             
                ):
         
     | 
| 
       114 
116 
     | 
    
         | 
| 
      
 117 
     | 
    
         
            +
                    self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
         
     | 
| 
       115 
118 
     | 
    
         
             
                    self._args, self._is_local_deploy = self._parse_local_deploy_args()
         
     | 
| 
       116 
     | 
    
         
            -
             
     | 
| 
       117 
119 
     | 
    
         
             
                    if model_dir is None:
         
     | 
| 
       118 
120 
     | 
    
         
             
                        if self._is_local_deploy is True:
         
     | 
| 
       119 
121 
     | 
    
         
             
                            try:
         
     | 
| 
         @@ -143,7 +145,6 @@ class Inference: 
     | 
|
| 
       143 
145 
     | 
    
         
             
                    self._autostart_delay_time = 5 * 60  # 5 min
         
     | 
| 
       144 
146 
     | 
    
         
             
                    self._tracker = None
         
     | 
| 
       145 
147 
     | 
    
         
             
                    self._hardware: str = None
         
     | 
| 
       146 
     | 
    
         
            -
                    self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
         
     | 
| 
       147 
148 
     | 
    
         
             
                    if custom_inference_settings is None:
         
     | 
| 
       148 
149 
     | 
    
         
             
                        if self.INFERENCE_SETTINGS is not None:
         
     | 
| 
       149 
150 
     | 
    
         
             
                            custom_inference_settings = self.INFERENCE_SETTINGS
         
     | 
| 
         @@ -169,7 +170,6 @@ class Inference: 
     | 
|
| 
       169 
170 
     | 
    
         
             
                    self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
         
     | 
| 
       170 
171 
     | 
    
         | 
| 
       171 
172 
     | 
    
         
             
                    if self._is_local_deploy:
         
     | 
| 
       172 
     | 
    
         
            -
                        # self._args = self._parse_local_deploy_args()
         
     | 
| 
       173 
173 
     | 
    
         
             
                        self._use_gui = False
         
     | 
| 
       174 
174 
     | 
    
         
             
                        deploy_params, need_download = self._get_deploy_params_from_args()
         
     | 
| 
       175 
175 
     | 
    
         
             
                        if need_download:
         
     | 
| 
         @@ -379,7 +379,8 @@ class Inference: 
     | 
|
| 
       379 
379 
     | 
    
         | 
| 
       380 
380 
     | 
    
         
             
                    if isinstance(self.gui, GUI.ServingGUITemplate):
         
     | 
| 
       381 
381 
     | 
    
         
             
                        self._app_layout = Container(
         
     | 
| 
       382 
     | 
    
         
            -
                            [self._user_layout_card, self._api_request_model_layout, self.get_ui()], 
     | 
| 
      
 382 
     | 
    
         
            +
                            [self._user_layout_card, self._api_request_model_layout, self.get_ui()],
         
     | 
| 
      
 383 
     | 
    
         
            +
                            gap=5,
         
     | 
| 
       383 
384 
     | 
    
         
             
                        )
         
     | 
| 
       384 
385 
     | 
    
         
             
                        return
         
     | 
| 
       385 
386 
     | 
    
         | 
| 
         @@ -399,7 +400,8 @@ class Inference: 
     | 
|
| 
       399 
400 
     | 
    
         
             
                        )
         
     | 
| 
       400 
401 
     | 
    
         | 
| 
       401 
402 
     | 
    
         
             
                    self._app_layout = Container(
         
     | 
| 
       402 
     | 
    
         
            -
                        [self._user_layout_card, self._api_request_model_layout, self.get_ui()], 
     | 
| 
      
 403 
     | 
    
         
            +
                        [self._user_layout_card, self._api_request_model_layout, self.get_ui()],
         
     | 
| 
      
 404 
     | 
    
         
            +
                        gap=5,
         
     | 
| 
       403 
405 
     | 
    
         
             
                    )
         
     | 
| 
       404 
406 
     | 
    
         | 
| 
       405 
407 
     | 
    
         
             
                def support_custom_models(self) -> bool:
         
     | 
| 
         @@ -609,7 +611,9 @@ class Inference: 
     | 
|
| 
       609 
611 
     | 
    
         
             
                                    ) as download_pbar:
         
     | 
| 
       610 
612 
     | 
    
         
             
                                        self.gui.download_progress.show()
         
     | 
| 
       611 
613 
     | 
    
         
             
                                        sly_fs.download(
         
     | 
| 
       612 
     | 
    
         
            -
                                            url=file_url, 
     | 
| 
      
 614 
     | 
    
         
            +
                                            url=file_url,
         
     | 
| 
      
 615 
     | 
    
         
            +
                                            save_path=file_path,
         
     | 
| 
      
 616 
     | 
    
         
            +
                                            progress=download_pbar.update,
         
     | 
| 
       613 
617 
     | 
    
         
             
                                        )
         
     | 
| 
       614 
618 
     | 
    
         
             
                                else:
         
     | 
| 
       615 
619 
     | 
    
         
             
                                    sly_fs.download(url=file_url, save_path=file_path)
         
     | 
| 
         @@ -887,7 +891,14 @@ class Inference: 
     | 
|
| 
       887 
891 
     | 
    
         
             
                @property
         
     | 
| 
       888 
892 
     | 
    
         
             
                def api(self) -> Api:
         
     | 
| 
       889 
893 
     | 
    
         
             
                    if self._api is None:
         
     | 
| 
       890 
     | 
    
         
            -
                         
     | 
| 
      
 894 
     | 
    
         
            +
                        if (
         
     | 
| 
      
 895 
     | 
    
         
            +
                            self._is_local_deploy
         
     | 
| 
      
 896 
     | 
    
         
            +
                            and os.getenv("SERVER_ADDRESS") is None
         
     | 
| 
      
 897 
     | 
    
         
            +
                            and os.getenv("API_TOKEN") is None
         
     | 
| 
      
 898 
     | 
    
         
            +
                        ):
         
     | 
| 
      
 899 
     | 
    
         
            +
                            return None
         
     | 
| 
      
 900 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 901 
     | 
    
         
            +
                            self._api = Api()
         
     | 
| 
       891 
902 
     | 
    
         
             
                    return self._api
         
     | 
| 
       892 
903 
     | 
    
         | 
| 
       893 
904 
     | 
    
         
             
                @property
         
     | 
| 
         @@ -2372,7 +2383,7 @@ class Inference: 
     | 
|
| 
       2372 
2383 
     | 
    
         
             
                    logger.debug("Scheduled task.", extra={"inference_request_uuid": inference_request_uuid})
         
     | 
| 
       2373 
2384 
     | 
    
         | 
| 
       2374 
2385 
     | 
    
         
             
                def serve(self):
         
     | 
| 
       2375 
     | 
    
         
            -
                    if not self._use_gui:
         
     | 
| 
      
 2386 
     | 
    
         
            +
                    if not self._use_gui and not self._is_local_deploy:
         
     | 
| 
       2376 
2387 
     | 
    
         
             
                        Progress("Deploying model ...", 1)
         
     | 
| 
       2377 
2388 
     | 
    
         | 
| 
       2378 
2389 
     | 
    
         
             
                    if is_debug_with_sly_net():
         
     | 
| 
         @@ -2390,6 +2401,21 @@ class Inference: 
     | 
|
| 
       2390 
2401 
     | 
    
         
             
                        if not self._is_local_deploy:
         
     | 
| 
       2391 
2402 
     | 
    
         
             
                            self._task_id = sly_env.task_id() if is_production() else None
         
     | 
| 
       2392 
2403 
     | 
    
         | 
| 
      
 2404 
     | 
    
         
            +
                    if self._is_local_deploy:
         
     | 
| 
      
 2405 
     | 
    
         
            +
                        # Predict and shutdown
         
     | 
| 
      
 2406 
     | 
    
         
            +
                        if self._args.mode == "predict" and any(
         
     | 
| 
      
 2407 
     | 
    
         
            +
                            [
         
     | 
| 
      
 2408 
     | 
    
         
            +
                                self._args.input,
         
     | 
| 
      
 2409 
     | 
    
         
            +
                                self._args.project_id,
         
     | 
| 
      
 2410 
     | 
    
         
            +
                                self._args.dataset_id,
         
     | 
| 
      
 2411 
     | 
    
         
            +
                                self._args.image_id,
         
     | 
| 
      
 2412 
     | 
    
         
            +
                            ]
         
     | 
| 
      
 2413 
     | 
    
         
            +
                        ):
         
     | 
| 
      
 2414 
     | 
    
         
            +
             
     | 
| 
      
 2415 
     | 
    
         
            +
                            self._parse_inference_settings_from_args()
         
     | 
| 
      
 2416 
     | 
    
         
            +
                            self._inference_by_local_deploy_args()
         
     | 
| 
      
 2417 
     | 
    
         
            +
                            exit(0)
         
     | 
| 
      
 2418 
     | 
    
         
            +
             
     | 
| 
       2393 
2419 
     | 
    
         
             
                    if isinstance(self.gui, GUI.InferenceGUI):
         
     | 
| 
       2394 
2420 
     | 
    
         
             
                        self._app = Application(layout=self.get_ui())
         
     | 
| 
       2395 
2421 
     | 
    
         
             
                    elif isinstance(self.gui, GUI.ServingGUI):
         
     | 
| 
         @@ -2402,22 +2428,6 @@ class Inference: 
     | 
|
| 
       2402 
2428 
     | 
    
         
             
                    server = self._app.get_server()
         
     | 
| 
       2403 
2429 
     | 
    
         
             
                    self._app.set_ready_check_function(self.is_model_deployed)
         
     | 
| 
       2404 
2430 
     | 
    
         | 
| 
       2405 
     | 
    
         
            -
                    if self._is_local_deploy:
         
     | 
| 
       2406 
     | 
    
         
            -
                        # Predict and shutdown
         
     | 
| 
       2407 
     | 
    
         
            -
                        if any(
         
     | 
| 
       2408 
     | 
    
         
            -
                            [
         
     | 
| 
       2409 
     | 
    
         
            -
                                self._args.predict_project,
         
     | 
| 
       2410 
     | 
    
         
            -
                                self._args.predict_dataset,
         
     | 
| 
       2411 
     | 
    
         
            -
                                self._args.predict_dir,
         
     | 
| 
       2412 
     | 
    
         
            -
                                self._args.predict_image,
         
     | 
| 
       2413 
     | 
    
         
            -
                            ]
         
     | 
| 
       2414 
     | 
    
         
            -
                        ):
         
     | 
| 
       2415 
     | 
    
         
            -
                            self._inference_by_local_deploy_args()
         
     | 
| 
       2416 
     | 
    
         
            -
                            # Gracefully shut down the server
         
     | 
| 
       2417 
     | 
    
         
            -
                            self._app.shutdown()
         
     | 
| 
       2418 
     | 
    
         
            -
                            exit(0)
         
     | 
| 
       2419 
     | 
    
         
            -
                    # else: run server after endpoints
         
     | 
| 
       2420 
     | 
    
         
            -
             
     | 
| 
       2421 
2431 
     | 
    
         
             
                    @call_on_autostart()
         
     | 
| 
       2422 
2432 
     | 
    
         
             
                    def autostart_func():
         
     | 
| 
       2423 
2433 
     | 
    
         
             
                        gpu_count = get_gpu_count()
         
     | 
| 
         @@ -2924,6 +2934,10 @@ class Inference: 
     | 
|
| 
       2924 
2934 
     | 
    
         
             
                def _parse_local_deploy_args(self):
         
     | 
| 
       2925 
2935 
     | 
    
         
             
                    parser = argparse.ArgumentParser(description="Run Inference Serving")
         
     | 
| 
       2926 
2936 
     | 
    
         | 
| 
      
 2937 
     | 
    
         
            +
                    # Positional args
         
     | 
| 
      
 2938 
     | 
    
         
            +
                    parser.add_argument("mode", choices=["deploy", "predict"], help="Mode of operation")
         
     | 
| 
      
 2939 
     | 
    
         
            +
                    parser.add_argument("input", nargs="?", type=str, help="Local path to input data")
         
     | 
| 
      
 2940 
     | 
    
         
            +
             
     | 
| 
       2927 
2941 
     | 
    
         
             
                    # Deploy args
         
     | 
| 
       2928 
2942 
     | 
    
         
             
                    parser.add_argument(
         
     | 
| 
       2929 
2943 
     | 
    
         
             
                        "--model",
         
     | 
| 
         @@ -2940,51 +2954,129 @@ class Inference: 
     | 
|
| 
       2940 
2954 
     | 
    
         
             
                    parser.add_argument(
         
     | 
| 
       2941 
2955 
     | 
    
         
             
                        "--runtime",
         
     | 
| 
       2942 
2956 
     | 
    
         
             
                        type=str,
         
     | 
| 
       2943 
     | 
    
         
            -
                        choices=[ 
     | 
| 
      
 2957 
     | 
    
         
            +
                        choices=[
         
     | 
| 
      
 2958 
     | 
    
         
            +
                            RuntimeType.PYTORCH,
         
     | 
| 
      
 2959 
     | 
    
         
            +
                            RuntimeType.ONNXRUNTIME,
         
     | 
| 
      
 2960 
     | 
    
         
            +
                            RuntimeType.TENSORRT,
         
     | 
| 
      
 2961 
     | 
    
         
            +
                        ],
         
     | 
| 
       2944 
2962 
     | 
    
         
             
                        default=RuntimeType.PYTORCH,
         
     | 
| 
       2945 
2963 
     | 
    
         
             
                        help="Runtime type for inference (default: PYTORCH)",
         
     | 
| 
       2946 
2964 
     | 
    
         
             
                    )
         
     | 
| 
       2947 
2965 
     | 
    
         
             
                    # -------------------------- #
         
     | 
| 
       2948 
2966 
     | 
    
         | 
| 
       2949 
     | 
    
         
            -
                    #  
     | 
| 
       2950 
     | 
    
         
            -
                    parser.add_argument("--predict-project", type=int, required=False, help="ID of the project")
         
     | 
| 
      
 2967 
     | 
    
         
            +
                    # Remote predict
         
     | 
| 
       2951 
2968 
     | 
    
         
             
                    parser.add_argument(
         
     | 
| 
       2952 
     | 
    
         
            -
                        "-- 
     | 
| 
      
 2969 
     | 
    
         
            +
                        "--project_id",
         
     | 
| 
      
 2970 
     | 
    
         
            +
                        type=int,
         
     | 
| 
      
 2971 
     | 
    
         
            +
                        required=False,
         
     | 
| 
      
 2972 
     | 
    
         
            +
                        help="Project ID on Supervisely instance",
         
     | 
| 
      
 2973 
     | 
    
         
            +
                    )
         
     | 
| 
      
 2974 
     | 
    
         
            +
                    parser.add_argument(
         
     | 
| 
      
 2975 
     | 
    
         
            +
                        "--dataset_id",
         
     | 
| 
       2953 
2976 
     | 
    
         
             
                        type=lambda x: [int(i) for i in x.split(",")] if "," in x else int(x),
         
     | 
| 
       2954 
2977 
     | 
    
         
             
                        required=False,
         
     | 
| 
       2955 
     | 
    
         
            -
                        help="ID of the dataset or a comma-separated list of dataset IDs",
         
     | 
| 
      
 2978 
     | 
    
         
            +
                        help="ID of the dataset or a comma-separated list of dataset IDs e.g. '505,506,507'",
         
     | 
| 
       2956 
2979 
     | 
    
         
             
                    )
         
     | 
| 
       2957 
2980 
     | 
    
         
             
                    parser.add_argument(
         
     | 
| 
       2958 
     | 
    
         
            -
                        "-- 
     | 
| 
       2959 
     | 
    
         
            -
                        type= 
     | 
| 
      
 2981 
     | 
    
         
            +
                        "--image_id",
         
     | 
| 
      
 2982 
     | 
    
         
            +
                        type=int,
         
     | 
| 
       2960 
2983 
     | 
    
         
             
                        required=False,
         
     | 
| 
       2961 
     | 
    
         
            -
                        help=" 
     | 
| 
      
 2984 
     | 
    
         
            +
                        help="Image ID on Supervisely instance",
         
     | 
| 
       2962 
2985 
     | 
    
         
             
                    )
         
     | 
| 
      
 2986 
     | 
    
         
            +
                    # -------------------------- #
         
     | 
| 
      
 2987 
     | 
    
         
            +
             
     | 
| 
      
 2988 
     | 
    
         
            +
                    # Output args
         
     | 
| 
       2963 
2989 
     | 
    
         
             
                    parser.add_argument(
         
     | 
| 
       2964 
     | 
    
         
            -
                        "-- 
     | 
| 
      
 2990 
     | 
    
         
            +
                        "--output",
         
     | 
| 
       2965 
2991 
     | 
    
         
             
                        type=str,
         
     | 
| 
       2966 
2992 
     | 
    
         
             
                        required=False,
         
     | 
| 
       2967 
     | 
    
         
            -
                        help=" 
     | 
| 
      
 2993 
     | 
    
         
            +
                        help="Path to local directory where predictions will be saved. Default: './predictions'",
         
     | 
| 
      
 2994 
     | 
    
         
            +
                    )
         
     | 
| 
      
 2995 
     | 
    
         
            +
                    parser.add_argument(
         
     | 
| 
      
 2996 
     | 
    
         
            +
                        "--upload",
         
     | 
| 
      
 2997 
     | 
    
         
            +
                        required=False,
         
     | 
| 
      
 2998 
     | 
    
         
            +
                        action="store_true",
         
     | 
| 
      
 2999 
     | 
    
         
            +
                        help="Upload predictions to Supervisely instance. Works only with: '--project_id', '--dataset_id', '--image_id'. For project and dataset predictions a new project will be created. Default: False",
         
     | 
| 
       2968 
3000 
     | 
    
         
             
                    )
         
     | 
| 
       2969 
3001 
     | 
    
         
             
                    # -------------------------- #
         
     | 
| 
       2970 
3002 
     | 
    
         | 
| 
       2971 
     | 
    
         
            -
                    #  
     | 
| 
       2972 
     | 
    
         
            -
                    parser.add_argument( 
     | 
| 
       2973 
     | 
    
         
            -
             
     | 
| 
      
 3003 
     | 
    
         
            +
                    # Other args
         
     | 
| 
      
 3004 
     | 
    
         
            +
                    parser.add_argument(
         
     | 
| 
      
 3005 
     | 
    
         
            +
                        "--settings",
         
     | 
| 
      
 3006 
     | 
    
         
            +
                        type=str,
         
     | 
| 
      
 3007 
     | 
    
         
            +
                        required=False,
         
     | 
| 
      
 3008 
     | 
    
         
            +
                        nargs="*",
         
     | 
| 
      
 3009 
     | 
    
         
            +
                        help="Path to the settings JSON/YAML file or key=value pairs",
         
     | 
| 
      
 3010 
     | 
    
         
            +
                    )
         
     | 
| 
      
 3011 
     | 
    
         
            +
                    parser.add_argument(
         
     | 
| 
      
 3012 
     | 
    
         
            +
                        "--draw",
         
     | 
| 
      
 3013 
     | 
    
         
            +
                        required=False,
         
     | 
| 
      
 3014 
     | 
    
         
            +
                        action="store_true",
         
     | 
| 
      
 3015 
     | 
    
         
            +
                        help="Generate new images with visualized predictions. Default: False",
         
     | 
| 
      
 3016 
     | 
    
         
            +
                    )
         
     | 
| 
       2974 
3017 
     | 
    
         
             
                    # -------------------------- #
         
     | 
| 
       2975 
3018 
     | 
    
         | 
| 
       2976 
3019 
     | 
    
         
             
                    # Parse arguments
         
     | 
| 
       2977 
3020 
     | 
    
         
             
                    args, _ = parser.parse_known_args()
         
     | 
| 
       2978 
3021 
     | 
    
         
             
                    if args.model is None:
         
     | 
| 
       2979 
     | 
    
         
            -
                         
     | 
| 
       2980 
     | 
    
         
            -
             
     | 
| 
       2981 
     | 
    
         
            -
             
     | 
| 
       2982 
     | 
    
         
            -
                         
     | 
| 
       2983 
     | 
    
         
            -
             
     | 
| 
       2984 
     | 
    
         
            -
                        if  
     | 
| 
       2985 
     | 
    
         
            -
                             
     | 
| 
      
 3022 
     | 
    
         
            +
                        if len(self.pretrained_models) == 0:
         
     | 
| 
      
 3023 
     | 
    
         
            +
                            raise ValueError("No pretrained models found.")
         
     | 
| 
      
 3024 
     | 
    
         
            +
             
     | 
| 
      
 3025 
     | 
    
         
            +
                        model = self.pretrained_models[0]
         
     | 
| 
      
 3026 
     | 
    
         
            +
                        model_name = model.get("meta", {}).get("model_name", None)
         
     | 
| 
      
 3027 
     | 
    
         
            +
                        if model_name is None:
         
     | 
| 
      
 3028 
     | 
    
         
            +
                            raise ValueError("No model name found in the first pretrained model.")
         
     | 
| 
      
 3029 
     | 
    
         
            +
             
     | 
| 
      
 3030 
     | 
    
         
            +
                        args.model = model_name
         
     | 
| 
      
 3031 
     | 
    
         
            +
                        logger.info(
         
     | 
| 
      
 3032 
     | 
    
         
            +
                            f"Argument '--model' is not provided. Model: '{model_name}' will be deployed."
         
     | 
| 
      
 3033 
     | 
    
         
            +
                        )
         
     | 
| 
      
 3034 
     | 
    
         
            +
                    if args.mode not in ["deploy", "predict"]:
         
     | 
| 
      
 3035 
     | 
    
         
            +
                        raise ValueError("Invalid operation. Only 'deploy' or 'predict' is supported.")
         
     | 
| 
      
 3036 
     | 
    
         
            +
                    if args.output is None:
         
     | 
| 
      
 3037 
     | 
    
         
            +
                        args.output = "./predictions"
         
     | 
| 
      
 3038 
     | 
    
         
            +
                    if isinstance(args.dataset_id, int):
         
     | 
| 
      
 3039 
     | 
    
         
            +
                        args.dataset_id = [args.dataset_id]
         
     | 
| 
      
 3040 
     | 
    
         
            +
             
     | 
| 
       2986 
3041 
     | 
    
         
             
                    return args, True
         
     | 
| 
       2987 
3042 
     | 
    
         | 
| 
      
 3043 
     | 
    
         
            +
                def _parse_inference_settings_from_args(self):
         
     | 
| 
      
 3044 
     | 
    
         
            +
                    def parse_value(value: str):
         
     | 
| 
      
 3045 
     | 
    
         
            +
                        if value.lower() in ("true", "false"):
         
     | 
| 
      
 3046 
     | 
    
         
            +
                            return value.lower() == "true"
         
     | 
| 
      
 3047 
     | 
    
         
            +
                        if value.lower() == ("none", "null"):
         
     | 
| 
      
 3048 
     | 
    
         
            +
                            return None
         
     | 
| 
      
 3049 
     | 
    
         
            +
                        if value.isdigit():
         
     | 
| 
      
 3050 
     | 
    
         
            +
                            return int(value)
         
     | 
| 
      
 3051 
     | 
    
         
            +
                        if "." in value:
         
     | 
| 
      
 3052 
     | 
    
         
            +
                            parts = value.split(".")
         
     | 
| 
      
 3053 
     | 
    
         
            +
                            if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
         
     | 
| 
      
 3054 
     | 
    
         
            +
                                return float(value)
         
     | 
| 
      
 3055 
     | 
    
         
            +
                        return value
         
     | 
| 
      
 3056 
     | 
    
         
            +
             
     | 
| 
      
 3057 
     | 
    
         
            +
                    args = self._args
         
     | 
| 
      
 3058 
     | 
    
         
            +
                    # Parse settings argument
         
     | 
| 
      
 3059 
     | 
    
         
            +
                    settings_dict = {}
         
     | 
| 
      
 3060 
     | 
    
         
            +
                    if args.settings:
         
     | 
| 
      
 3061 
     | 
    
         
            +
                        is_settings_file = args.settings[0].endswith((".json", ".yaml", ".yml"))
         
     | 
| 
      
 3062 
     | 
    
         
            +
                        if len(args.settings) == 1 and is_settings_file:
         
     | 
| 
      
 3063 
     | 
    
         
            +
                            args.settings = args.settings[0]
         
     | 
| 
      
 3064 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 3065 
     | 
    
         
            +
                            for setting in args.settings:
         
     | 
| 
      
 3066 
     | 
    
         
            +
                                if "=" in setting:
         
     | 
| 
      
 3067 
     | 
    
         
            +
                                    key, value = setting.split("=", 1)
         
     | 
| 
      
 3068 
     | 
    
         
            +
                                    settings_dict[key] = parse_value(value)
         
     | 
| 
      
 3069 
     | 
    
         
            +
                                elif ":" in setting:
         
     | 
| 
      
 3070 
     | 
    
         
            +
                                    key, value = setting.split(":", 1)
         
     | 
| 
      
 3071 
     | 
    
         
            +
                                    settings_dict[key] = parse_value(value)
         
     | 
| 
      
 3072 
     | 
    
         
            +
                                else:
         
     | 
| 
      
 3073 
     | 
    
         
            +
                                    raise ValueError(
         
     | 
| 
      
 3074 
     | 
    
         
            +
                                        f"Invalid setting: '{setting}'. Please use key value pairs separated by '=', e.g. conf=0.4'"
         
     | 
| 
      
 3075 
     | 
    
         
            +
                                    )
         
     | 
| 
      
 3076 
     | 
    
         
            +
                            args.settings = settings_dict
         
     | 
| 
      
 3077 
     | 
    
         
            +
                    args.settings = self._read_settings(args.settings)
         
     | 
| 
      
 3078 
     | 
    
         
            +
                    self._validate_settings(args.settings)
         
     | 
| 
      
 3079 
     | 
    
         
            +
             
     | 
| 
       2988 
3080 
     | 
    
         
             
                def _get_pretrained_model_params_from_args(self):
         
     | 
| 
       2989 
3081 
     | 
    
         
             
                    model_files = None
         
     | 
| 
       2990 
3082 
     | 
    
         
             
                    model_source = None
         
     | 
| 
         @@ -3119,64 +3211,225 @@ class Inference: 
     | 
|
| 
       3119 
3211 
     | 
    
         
             
                    self._uvicorn_server = uvicorn.Server(config)
         
     | 
| 
       3120 
3212 
     | 
    
         
             
                    self._uvicorn_server.run()
         
     | 
| 
       3121 
3213 
     | 
    
         | 
| 
      
 3214 
     | 
    
         
            +
                def _read_settings(self, settings: Union[str, Dict[str, Any]]):
         
     | 
| 
      
 3215 
     | 
    
         
            +
                    if isinstance(settings, dict):
         
     | 
| 
      
 3216 
     | 
    
         
            +
                        return settings
         
     | 
| 
      
 3217 
     | 
    
         
            +
             
     | 
| 
      
 3218 
     | 
    
         
            +
                    settings_path = settings
         
     | 
| 
      
 3219 
     | 
    
         
            +
                    if settings_path is None:
         
     | 
| 
      
 3220 
     | 
    
         
            +
                        return {}
         
     | 
| 
      
 3221 
     | 
    
         
            +
                    if settings_path.endswith(".json"):
         
     | 
| 
      
 3222 
     | 
    
         
            +
                        return sly_json.load_json_file(settings_path)
         
     | 
| 
      
 3223 
     | 
    
         
            +
                    elif settings_path.endswith(".yaml") or settings_path.endswith(".yml"):
         
     | 
| 
      
 3224 
     | 
    
         
            +
                        with open(settings_path, "r") as f:
         
     | 
| 
      
 3225 
     | 
    
         
            +
                            return yaml.safe_load(f)
         
     | 
| 
      
 3226 
     | 
    
         
            +
                    raise ValueError("Settings file should be in JSON or YAML format")
         
     | 
| 
      
 3227 
     | 
    
         
            +
             
     | 
| 
      
 3228 
     | 
    
         
            +
                def _validate_settings(self, settings: dict):
         
     | 
| 
      
 3229 
     | 
    
         
            +
                    default_settings = self.custom_inference_settings_dict
         
     | 
| 
      
 3230 
     | 
    
         
            +
                    if settings == {}:
         
     | 
| 
      
 3231 
     | 
    
         
            +
                        self._args.settings = default_settings
         
     | 
| 
      
 3232 
     | 
    
         
            +
                        return
         
     | 
| 
      
 3233 
     | 
    
         
            +
                    for key, value in settings.items():
         
     | 
| 
      
 3234 
     | 
    
         
            +
                        if key not in default_settings and key != "classes":
         
     | 
| 
      
 3235 
     | 
    
         
            +
                            acceptable_keys = ", ".join(default_settings.keys()) + ", 'classes'"
         
     | 
| 
      
 3236 
     | 
    
         
            +
                            raise ValueError(
         
     | 
| 
      
 3237 
     | 
    
         
            +
                                f"Inference settings doesn't have key: '{key}'. Available keys are: '{acceptable_keys}'"
         
     | 
| 
      
 3238 
     | 
    
         
            +
                            )
         
     | 
| 
      
 3239 
     | 
    
         
            +
             
     | 
| 
       3122 
3240 
     | 
    
         
             
                def _inference_by_local_deploy_args(self):
         
     | 
| 
       3123 
     | 
    
         
            -
                     
     | 
| 
       3124 
     | 
    
         
            -
             
     | 
| 
       3125 
     | 
    
         
            -
             
     | 
| 
       3126 
     | 
    
         
            -
                         
     | 
| 
       3127 
     | 
    
         
            -
             
     | 
| 
       3128 
     | 
    
         
            -
                         
     | 
| 
       3129 
     | 
    
         
            -
                         
     | 
| 
       3130 
     | 
    
         
            -
             
     | 
| 
       3131 
     | 
    
         
            -
             
     | 
| 
       3132 
     | 
    
         
            -
             
     | 
| 
       3133 
     | 
    
         
            -
             
     | 
| 
       3134 
     | 
    
         
            -
             
     | 
| 
       3135 
     | 
    
         
            -
                             
     | 
| 
       3136 
     | 
    
         
            -
                        )
         
     | 
| 
      
 3241 
     | 
    
         
            +
                    missing_env_message = "Set 'SERVER_ADDRESS' and 'API_TOKEN' environment variables to predict data on Supervisely platform."
         
     | 
| 
      
 3242 
     | 
    
         
            +
             
     | 
| 
      
 3243 
     | 
    
         
            +
                    def predict_project_id_by_args(
                        api: Api,
                        project_id: int,
                        dataset_ids: List[int] = None,
                        output_dir: str = "./predictions",
                        settings: str = None,
                        draw: bool = False,
                        upload: bool = False,
                    ):
                        """Run inference on a platform project (optionally limited to datasets).

                        With ``upload`` a new project named "<source name> predicted" is
                        created in the source project's workspace and its id is passed in
                        the inference state. Without ``upload`` each returned prediction's
                        annotation is written to
                        ``<output_dir>/<dataset name>/<image name>.json``.

                        Raises:
                            ValueError: if ``self.api`` is None, or if ``draw`` is requested.
                        """
                        if self.api is None:
                            raise ValueError(missing_env_message)

                        logger.info(
                            f"Predicting datasets: '{dataset_ids}'"
                            if dataset_ids
                            else f"Predicting project: '{project_id}'"
                        )

                        if draw:
                            raise ValueError("Draw visualization is not supported for project inference")

                        state = {"projectId": project_id, "dataset_ids": dataset_ids, "settings": settings}
                        if upload:
                            src_info = api.project.get_info_by_id(project_id)
                            dst_info = api.project.create(
                                src_info.workspace_id,
                                f"{src_info.name} predicted",
                                change_name_if_conflict=True,
                            )
                            state["output_project_id"] = dst_info.id
                        results = self._inference_project_id(api=self.api, state=state)

                        # Dataset id -> name lookup used to build the local output layout.
                        names_by_id = {info.id: info.name for info in api.dataset.get_list(project_id)}

                        if upload:
                            return

                        for item in results:
                            ds_name = names_by_id[item["dataset_id"]]
                            img_name = item["image_name"]
                            target_dir = os.path.join(output_dir, ds_name)
                            target_file = os.path.join(target_dir, f"{img_name}.json")
                            payload = item["annotation"]
                            if not sly_fs.dir_exists(target_dir):
                                sly_fs.mkdir(target_dir)
                            sly_json.dump_json_file(payload, target_file)
         
     | 
| 
      
 3288 
     | 
    
         
            +
             
     | 
| 
      
 3289 
     | 
    
         
            +
                    def predict_dataset_id_by_args(
                        api: Api,
                        dataset_ids: List[int],
                        output_dir: str = "./predictions",
                        settings: str = None,
                        draw: bool = False,
                        upload: bool = False,
                    ):
                        """Run inference on the given datasets, which must share one project.

                        Resolves the common project id from the dataset infos and delegates
                        to ``predict_project_id_by_args``.

                        Raises:
                            ValueError: if ``draw`` is requested, if ``self.api`` is None, or
                                if the datasets belong to more than one project.
                        """
                        if draw:
                            raise ValueError("Draw visualization is not supported for dataset inference")
                        if self.api is None:
                            raise ValueError(missing_env_message)

                        infos = [api.dataset.get_info_by_id(ds_id) for ds_id in dataset_ids]
                        unique_projects = list({info.project_id for info in infos})
                        if len(unique_projects) > 1:
                            raise ValueError("All datasets should belong to the same project")

                        predict_project_id_by_args(
                            api=api,
                            project_id=unique_projects[0],
                            dataset_ids=dataset_ids,
                            output_dir=output_dir,
                            settings=settings,
                            draw=draw,
                            upload=upload,
                        )
         
     | 
| 
      
 3308 
     | 
    
         
            +
             
     | 
| 
      
 3309 
     | 
    
         
            +
                    def predict_image_id_by_args(
                        api: Api,
                        image_id: int,
                        output_dir: str = "./predictions",
                        settings: str = None,
                        draw: bool = False,
                        upload: bool = False,
                    ):
                        """Run inference on a single platform image.

                        The image is downloaded and predicted. Without ``upload`` the
                        annotation JSON is written to
                        ``<output_dir>/<dataset name>/<image name>.json``; with ``draw`` a
                        visualization is written next to it as ``<image name>.png``; with
                        ``upload`` the annotation is uploaded to the image.

                        Raises:
                            ValueError: if ``self.api`` is None.
                        """
                        if self.api is None:
                            raise ValueError(missing_env_message)

                        logger.info(f"Predicting image: '{image_id}'")

                        def predict_image_np(image_np):
                            anns, _ = self._inference_auto([image_np], settings)
                            if len(anns) == 0:
                                # No predictions returned: use an empty annotation of the image size.
                                return Annotation(img_size=image_np.shape[:2])
                            return anns[0]

                        image_np = api.image.download_np(image_id)
                        ann = predict_image_np(image_np)

                        image_info = None
                        pred_dir = None
                        if draw or not upload:
                            # Both local outputs need the image/dataset names. Resolving them
                            # here (rather than only in the "not upload" branch) keeps
                            # dataset_info defined when upload and draw are both requested.
                            image_info = api.image.get_info_by_id(image_id)
                            dataset_info = api.dataset.get_info_by_id(image_info.dataset_id)
                            pred_dir = os.path.join(output_dir, dataset_info.name)
                            if not sly_fs.dir_exists(pred_dir):
                                sly_fs.mkdir(pred_dir)

                        if not upload:
                            pred_path = os.path.join(pred_dir, f"{image_info.name}.json")
                            sly_json.dump_json_file(ann.to_json(), pred_path)

                        if draw:
                            vis_path = os.path.join(pred_dir, f"{image_info.name}.png")
                            ann.draw_pretty(image_np, output_path=vis_path)
                        if upload:
                            api.annotation.upload_ann(image_id, ann)
         
     | 
| 
      
 3350 
     | 
    
         
            +
             
     | 
| 
      
 3351 
     | 
    
         
            +
                    def predict_local_data_by_args(
         
     | 
| 
      
 3352 
     | 
    
         
            +
                        input_path: str,
         
     | 
| 
      
 3353 
     | 
    
         
            +
                        settings: str = None,
         
     | 
| 
      
 3354 
     | 
    
         
            +
                        output_dir: str = "./predictions",
         
     | 
| 
      
 3355 
     | 
    
         
            +
                        draw: bool = False,
         
     | 
| 
      
 3356 
     | 
    
         
            +
                    ):
         
     | 
| 
      
 3357 
     | 
    
         
            +
                        logger.info(f"Predicting '{input_path}'")
         
     | 
| 
      
 3358 
     | 
    
         
            +
             
     | 
| 
      
 3359 
     | 
    
         
            +
                        def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
         
     | 
| 
      
 3360 
     | 
    
         
            +
                            image_name = sly_fs.get_file_name_with_ext(image_path)
         
     | 
| 
      
 3361 
     | 
    
         
            +
                            if pred_dir is not None:
         
     | 
| 
      
 3362 
     | 
    
         
            +
                                pred_dir = os.path.join(output_dir, pred_dir)
         
     | 
| 
      
 3363 
     | 
    
         
            +
                                pred_ann_path = os.path.join(pred_dir, f"{image_name}.json")
         
     | 
| 
      
 3364 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 3365 
     | 
    
         
            +
                                pred_dir = output_dir
         
     | 
| 
      
 3366 
     | 
    
         
            +
                                pred_ann_path = os.path.join(pred_dir, f"{image_name}.json")
         
     | 
| 
      
 3367 
     | 
    
         
            +
             
     | 
| 
      
 3368 
     | 
    
         
            +
                            if not os.path.exists(pred_dir):
         
     | 
| 
      
 3369 
     | 
    
         
            +
                                sly_fs.mkdir(pred_dir)
         
     | 
| 
      
 3370 
     | 
    
         
            +
                            sly_json.dump_json_file(ann.to_json(), pred_ann_path)
         
     | 
| 
      
 3371 
     | 
    
         
            +
                            if draw:
         
     | 
| 
      
 3372 
     | 
    
         
            +
                                image = sly_image.read(image_path)
         
     | 
| 
      
 3373 
     | 
    
         
            +
                                ann.draw_pretty(image, output_path=os.path.join(pred_dir, image_name))
         
     | 
| 
      
 3374 
     | 
    
         
            +
             
     | 
| 
      
 3375 
     | 
    
         
            +
                        # 1. Input Directory
         
     | 
| 
      
 3376 
     | 
    
         
            +
                        if os.path.isdir(input_path):
         
     | 
| 
      
 3377 
     | 
    
         
            +
                            pred_dir = os.path.basename(input_path)
         
     | 
| 
      
 3378 
     | 
    
         
            +
                            images = list_files(input_path, valid_extensions=sly_image.SUPPORTED_IMG_EXTS)
         
     | 
| 
      
 3379 
     | 
    
         
            +
                            anns, _ = self._inference_auto(images, settings)
         
     | 
| 
      
 3380 
     | 
    
         
            +
                            for image_path, ann in zip(images, anns):
         
     | 
| 
      
 3381 
     | 
    
         
            +
                                postprocess_image(image_path, ann, pred_dir)
         
     | 
| 
      
 3382 
     | 
    
         
            +
                        # 2. Input File
         
     | 
| 
      
 3383 
     | 
    
         
            +
                        elif os.path.isfile(input_path):
         
     | 
| 
      
 3384 
     | 
    
         
            +
                            if input_path.endswith(tuple(sly_image.SUPPORTED_IMG_EXTS)):
         
     | 
| 
      
 3385 
     | 
    
         
            +
                                image_np = sly_image.read(input_path)
         
     | 
| 
      
 3386 
     | 
    
         
            +
                                anns, _ = self._inference_auto([image_np], settings)
         
     | 
| 
      
 3387 
     | 
    
         
            +
                                ann = anns[0]
         
     | 
| 
      
 3388 
     | 
    
         
            +
                                postprocess_image(input_path, ann)
         
     | 
| 
      
 3389 
     | 
    
         
            +
                            elif input_path.endswith(tuple(ALLOWED_VIDEO_EXTENSIONS)):
         
     | 
| 
      
 3390 
     | 
    
         
            +
                                raise NotImplementedError("Video inference is not implemented yet")
         
     | 
| 
      
 3391 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 3392 
     | 
    
         
            +
                                raise ValueError(
         
     | 
| 
      
 3393 
     | 
    
         
            +
                                    f"Unsupported input format: '{input_path}'. Expect image or directory with images"
         
     | 
| 
      
 3394 
     | 
    
         
            +
                                )
         
     | 
| 
      
 3395 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 3396 
     | 
    
         
            +
                            raise ValueError(f"Please provide a valid input path: '{input_path}'")
         
     | 
| 
      
 3397 
     | 
    
         
            +
             
     | 
| 
      
 3398 
     | 
    
         
            +
                    if self._args.project_id is not None:
         
     | 
| 
      
 3399 
     | 
    
         
            +
                        predict_project_id_by_args(
         
     | 
| 
      
 3400 
     | 
    
         
            +
                            self.api,
         
     | 
| 
      
 3401 
     | 
    
         
            +
                            self._args.project_id,
         
     | 
| 
      
 3402 
     | 
    
         
            +
                            None,
         
     | 
| 
      
 3403 
     | 
    
         
            +
                            self._args.output,
         
     | 
| 
      
 3404 
     | 
    
         
            +
                            self._args.settings,
         
     | 
| 
      
 3405 
     | 
    
         
            +
                            self._args.draw,
         
     | 
| 
      
 3406 
     | 
    
         
            +
                            self._args.upload,
         
     | 
| 
      
 3407 
     | 
    
         
            +
                        )
         
     | 
| 
      
 3408 
     | 
    
         
            +
                    elif self._args.dataset_id is not None:
         
     | 
| 
      
 3409 
     | 
    
         
            +
                        predict_dataset_id_by_args(
         
     | 
| 
      
 3410 
     | 
    
         
            +
                            self.api,
         
     | 
| 
      
 3411 
     | 
    
         
            +
                            self._args.dataset_id,
         
     | 
| 
      
 3412 
     | 
    
         
            +
                            self._args.output,
         
     | 
| 
      
 3413 
     | 
    
         
            +
                            self._args.settings,
         
     | 
| 
      
 3414 
     | 
    
         
            +
                            self._args.draw,
         
     | 
| 
      
 3415 
     | 
    
         
            +
                            self._args.upload,
         
     | 
| 
      
 3416 
     | 
    
         
            +
                        )
         
     | 
| 
      
 3417 
     | 
    
         
            +
                    elif self._args.image_id is not None:
         
     | 
| 
      
 3418 
     | 
    
         
            +
                        predict_image_id_by_args(
         
     | 
| 
      
 3419 
     | 
    
         
            +
                            self.api,
         
     | 
| 
      
 3420 
     | 
    
         
            +
                            self._args.image_id,
         
     | 
| 
      
 3421 
     | 
    
         
            +
                            self._args.output,
         
     | 
| 
      
 3422 
     | 
    
         
            +
                            self._args.settings,
         
     | 
| 
      
 3423 
     | 
    
         
            +
                            self._args.draw,
         
     | 
| 
      
 3424 
     | 
    
         
            +
                            self._args.upload,
         
     | 
| 
      
 3425 
     | 
    
         
            +
                        )
         
     | 
| 
      
 3426 
     | 
    
         
            +
                    elif self._args.input is not None:
         
     | 
| 
      
 3427 
     | 
    
         
            +
                        predict_local_data_by_args(
         
     | 
| 
      
 3428 
     | 
    
         
            +
                            self._args.input,
         
     | 
| 
      
 3429 
     | 
    
         
            +
                            self._args.settings,
         
     | 
| 
      
 3430 
     | 
    
         
            +
                            self._args.output,
         
     | 
| 
      
 3431 
     | 
    
         
            +
                            self._args.draw,
         
     | 
| 
      
 3432 
     | 
    
         
            +
                        )
         
     | 
| 
       3180 
3433 
     | 
    
         | 
| 
       3181 
3434 
     | 
    
         
             
                def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
         
     | 
| 
       3182 
3435 
     | 
    
         
             
                    if model_source == ModelSource.PRETRAINED:
         
     | 
| 
         @@ -875,8 +875,8 @@ supervisely/nn/benchmark/visualization/widgets/sidebar/sidebar.py,sha256=tKPURRS 
     | 
|
| 
       875 
875 
     | 
    
         
             
            supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       876 
876 
     | 
    
         
             
            supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
         
     | 
| 
       877 
877 
     | 
    
         
             
            supervisely/nn/inference/__init__.py,sha256=QFukX2ip-U7263aEPCF_UCFwj6EujbMnsgrXp5Bbt8I,1623
         
     | 
| 
       878 
     | 
    
         
            -
            supervisely/nn/inference/cache.py,sha256= 
     | 
| 
       879 
     | 
    
         
            -
            supervisely/nn/inference/inference.py,sha256= 
     | 
| 
      
 878 
     | 
    
         
            +
            supervisely/nn/inference/cache.py,sha256=q4F7ZRzZghNWSVFClXEIHNMNW4PK6xddYckCFUgyhCo,32027
         
     | 
| 
      
 879 
     | 
    
         
            +
            supervisely/nn/inference/inference.py,sha256=Fq2aMKIxLZwiFXW8a2mfbvW3KXd8O31fccKoKYDvqvQ,158506
         
     | 
| 
       880 
880 
     | 
    
         
             
            supervisely/nn/inference/session.py,sha256=jmkkxbe2kH-lEgUU6Afh62jP68dxfhF5v6OGDfLU62E,35757
         
     | 
| 
       881 
881 
     | 
    
         
             
            supervisely/nn/inference/video_inference.py,sha256=8Bshjr6rDyLay5Za8IB8Dr6FURMO2R_v7aELasO8pR4,5746
         
     | 
| 
       882 
882 
     | 
    
         
             
            supervisely/nn/inference/gui/__init__.py,sha256=wCxd-lF5Zhcwsis-wScDA8n1Gk_1O00PKgDviUZ3F1U,221
         
     | 
| 
         @@ -1075,9 +1075,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ 
     | 
|
| 
       1075 
1075 
     | 
    
         
             
            supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
         
     | 
| 
       1076 
1076 
     | 
    
         
             
            supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
         
     | 
| 
       1077 
1077 
     | 
    
         
             
            supervisely_lib/__init__.py,sha256=7-3QnN8Zf0wj8NCr2oJmqoQWMKKPKTECvjH9pd2S5vY,159
         
     | 
| 
       1078 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
       1079 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
       1080 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
       1081 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
       1082 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
       1083 
     | 
    
         
            -
            supervisely-6.73. 
     | 
| 
      
 1078 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
         
     | 
| 
      
 1079 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/METADATA,sha256=BG22_yvudczrZ-0Nq_eCZNa6hM2CocTY3YpADXH47uE,33596
         
     | 
| 
      
 1080 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
         
     | 
| 
      
 1081 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
         
     | 
| 
      
 1082 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
         
     | 
| 
      
 1083 
     | 
    
         
            +
            supervisely-6.73.318.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |