supervisely-6.73.269-py3-none-any.whl → supervisely-6.73.271-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic.

@@ -1620,7 +1620,7 @@ class FileApi(ModuleApiBase):
             downloaded_file_hash = await get_file_hash_async(local_save_path)
             if hash_to_check != downloaded_file_hash:
                 raise RuntimeError(
-                    f"Downloaded hash of image with ID:{id} does not match the expected hash: {downloaded_file_hash} != {hash_to_check}"
+                    f"Downloaded hash of file path: '{remote_path}' does not match the expected hash: {downloaded_file_hash} != {hash_to_check}"
                 )
             if progress_cb is not None and progress_cb_type == "number":
                 progress_cb(1)
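The change above only rewords the error raised when a downloaded file fails its hash check: the message now reports the remote file path instead of an image ID. For context, a minimal sketch of the verify-after-download pattern (using `hashlib` directly; the base64-encoded MD5 digest is an assumption for illustration, not necessarily the SDK's exact hashing scheme):

```python
import base64
import hashlib


def verify_download(local_save_path: str, expected_hash: str, remote_path: str) -> None:
    # Compute a base64-encoded MD5 digest of the downloaded file
    # (assumed here; the SDK's get_file_hash_async may use a different scheme).
    with open(local_save_path, "rb") as f:
        digest = base64.b64encode(hashlib.md5(f.read()).digest()).decode("utf-8")
    if expected_hash != digest:
        # Mirrors the new error message: report the remote path, not an image ID.
        raise RuntimeError(
            f"Downloaded hash of file path: '{remote_path}' does not match "
            f"the expected hash: {digest} != {expected_hash}"
        )
```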
@@ -1,3 +1,5 @@
+import argparse
+import asyncio
 import inspect
 import json
 import os
@@ -11,12 +13,14 @@ from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
 from functools import partial, wraps
+from pathlib import Path
 from queue import Queue
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from urllib.request import urlopen

 import numpy as np
 import requests
+import uvicorn
 import yaml
 from fastapi import Form, HTTPException, Request, Response, UploadFile, status
 from fastapi.responses import JSONResponse
@@ -24,8 +28,7 @@ from requests.structures import CaseInsensitiveDict

 import supervisely.app.development as sly_app_development
 import supervisely.imaging.image as sly_image
-import supervisely.io.env as env
-import supervisely.io.fs as fs
+import supervisely.io.env as sly_env
 import supervisely.io.fs as sly_fs
 import supervisely.io.json as sly_json
 import supervisely.nn.inference.gui as GUI
@@ -33,6 +36,7 @@ from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
 from supervisely._utils import (
     add_callback,
     get_filename_from_headers,
+    get_or_create_event_loop,
     is_debug_with_sly_net,
     is_production,
     rand_str,
@@ -43,6 +47,7 @@ from supervisely.annotation.obj_class import ObjClass
 from supervisely.annotation.tag_collection import TagCollection
 from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api
+from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
 from supervisely.app.content import StateJson, get_data_dir
 from supervisely.app.exceptions import DialogWindowError
@@ -106,9 +111,19 @@ class Inference:
         multithread_inference: Optional[bool] = True,
         use_serving_gui_template: Optional[bool] = False,
     ):
+
+        self._args, self._is_local_deploy = self._parse_local_deploy_args()
+
         if model_dir is None:
-            model_dir = os.path.join(get_data_dir(), "models")
-        fs.mkdir(model_dir)
+            if self._is_local_deploy is True:
+                try:
+                    model_dir = get_data_dir()
+                except:
+                    model_dir = Path("~/.cache/supervisely/app_data").expanduser()
+            else:
+                model_dir = os.path.join(get_data_dir(), "models")
+        sly_fs.mkdir(model_dir)
+
         self.device: str = None
         self.runtime: str = None
         self.model_precision: str = None
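The fallback above matters because `get_data_dir()` can fail outside a Supervisely app container, where no app data directory is provisioned. A minimal sketch of that fallback using only the standard library (`get_app_data_dir` is a hypothetical stand-in for the SDK's `get_data_dir`):

```python
from pathlib import Path
from typing import Callable


def resolve_model_dir(get_app_data_dir: Callable[[], str]) -> str:
    # Prefer the platform-provisioned data directory; fall back to a
    # per-user cache path, as the new local-deploy branch does.
    try:
        return get_app_data_dir()
    except Exception:
        return str(Path("~/.cache/supervisely/app_data").expanduser())
```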
@@ -128,7 +143,7 @@ class Inference:
         self._autostart_delay_time = 5 * 60  # 5 min
         self._tracker = None
         self._hardware: str = None
-        self.pretrained_models = self._load_models_json(self.MODELS) if self.MODELS else None
+        self.pretrained_models = self._load_models_json_file(self.MODELS) if self.MODELS else None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:
                 custom_inference_settings = self.INFERENCE_SETTINGS
@@ -136,7 +151,7 @@ class Inference:
             logger.debug("Custom inference settings are not provided.")
             custom_inference_settings = {}
         if isinstance(custom_inference_settings, str):
-            if fs.file_exists(custom_inference_settings):
+            if sly_fs.file_exists(custom_inference_settings):
                 with open(custom_inference_settings, "r") as f:
                     custom_inference_settings = f.read()
             else:
@@ -146,12 +161,24 @@ class Inference:
         self._use_gui = use_gui
         self._use_serving_gui_template = use_serving_gui_template
         self._gui = None
+        self._uvicorn_server = None

         self.load_on_device = LOAD_ON_DEVICE_DECORATOR(self.load_on_device)
         self.load_on_device = add_callback(self.load_on_device, self._set_served_callback)

         self.load_model = LOAD_MODEL_DECORATOR(self.load_model)

+        if self._is_local_deploy:
+            # self._args = self._parse_local_deploy_args()
+            self._use_gui = False
+            deploy_params, need_download = self._get_deploy_params_from_args()
+            if need_download:
+                local_model_files = self._download_model_files(
+                    deploy_params["model_source"], deploy_params["model_files"], False
+                )
+                deploy_params["model_files"] = local_model_files
+            self._load_model_headless(**deploy_params)
+
         if self._use_gui:
             initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
             original_initialize_custom_gui_method = getattr(
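In local-deploy mode the constructor bypasses the GUI entirely and loads the model headlessly. A condensed sketch of that startup order, written as a free function over an `Inference`-like object (the method names are the ones added later in this diff; this is a reading aid, not the SDK's public API):

```python
def local_deploy_startup(inference) -> None:
    # `inference` is an Inference-like object exposing the methods added in this diff.
    args, is_local = inference._parse_local_deploy_args()
    if not is_local:
        return
    inference._use_gui = False  # headless: no serving GUI widgets
    deploy_params, need_download = inference._get_deploy_params_from_args()
    if need_download:
        # Remote checkpoint: fetch the files first, with progress logging off.
        deploy_params["model_files"] = inference._download_model_files(
            deploy_params["model_source"], deploy_params["model_files"], False
        )
    inference._load_model_headless(**deploy_params)
```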
@@ -224,10 +251,10 @@ class Inference:
         self.get_info = self._check_serve_before_call(self.get_info)

         self.cache = InferenceImageCache(
-            maxsize=env.smart_cache_size(),
-            ttl=env.smart_cache_ttl(),
+            maxsize=sly_env.smart_cache_size(),
+            ttl=sly_env.smart_cache_ttl(),
             is_persistent=True,
-            base_folder=env.smart_cache_container_dir(),
+            base_folder=sly_env.smart_cache_container_dir(),
             log_progress=True,
         )

@@ -248,9 +275,18 @@ class Inference:
             )
             device = "cpu"

-    def _load_models_json(self, models: str) -> List[Dict[str, Any]]:
+    def _load_json_file(self, file_path: str) -> dict:
+        if isinstance(file_path, str):
+            if sly_fs.file_exists(file_path) and sly_fs.get_file_ext(file_path) == ".json":
+                return sly_json.load_json_file(file_path)
+            else:
+                raise ValueError("File not found or invalid file format.")
+        else:
+            raise ValueError("Invalid file. Please provide a valid '.json' file.")
+
+    def _load_models_json_file(self, models: str) -> List[Dict[str, Any]]:
         """
-        Loads models from the provided file or list of model configurations.
+        Loads dictionary from the provided file.
         """
         if isinstance(models, str):
             if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
@@ -258,9 +294,7 @@ class Inference:
             else:
                 raise ValueError("File not found or invalid file format.")
         else:
-            raise ValueError(
-                "Invalid models file. Please provide a valid '.json' file with list of model configurations."
-            )
+            raise ValueError("Invalid file. Please provide a valid '.json' file.")

         if not isinstance(models, list):
             raise ValueError("models parameters must be a list of dicts")
@@ -394,13 +428,13 @@ class Inference:
         else:
             progress = None

-        if fs.dir_exists(src_path) or fs.file_exists(
+        if sly_fs.dir_exists(src_path) or sly_fs.file_exists(
             src_path
         ):  # only during debug, has no effect in production
             dst_path = os.path.abspath(src_path)
             logger.info(f"File {dst_path} found.")
         elif src_path.startswith("/"):  # folder from Team Files
-            team_id = env.team_id()
+            team_id = sly_env.team_id()

             if src_path.endswith("/") and self.api.file.dir_exists(team_id, src_path):

@@ -436,7 +470,7 @@ class Inference:
         def download_file(team_id, src_path, dst_path, progress_cb=None):
             self.api.file.download(team_id, src_path, dst_path, progress_cb=progress_cb)

-        file_info = self.api.file.get_info_by_path(env.team_id(), src_path)
+        file_info = self.api.file.get_info_by_path(sly_env.team_id(), src_path)
         if progress is None:
             download_file(team_id, src_path, dst_path)
         else:
@@ -451,8 +485,8 @@ class Inference:
             logger.info(f"📥 File {basename} has been successfully downloaded from Team Files")
             logger.info(f"File {basename} path: {dst_path}")
         else:  # external url
-            if not fs.dir_exists(os.path.dirname(dst_path)):
-                fs.mkdir(os.path.dirname(dst_path))
+            if not sly_fs.dir_exists(os.path.dirname(dst_path)):
+                sly_fs.mkdir(os.path.dirname(dst_path))

             def download_external_file(url, save_path, progress=None):
                 def download_content(save_path, progress_cb=None):
@@ -531,13 +565,15 @@ class Inference:
     def _checkpoints_cache_dir(self):
         return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")

-    def _download_model_files(self, model_source: str, model_files: List[str]) -> dict:
+    def _download_model_files(
+        self, model_source: str, model_files: List[str], log_progress: bool = True
+    ) -> dict:
         if model_source == ModelSource.PRETRAINED:
-            return self._download_pretrained_model(model_files)
+            return self._download_pretrained_model(model_files, log_progress)
         elif model_source == ModelSource.CUSTOM:
-            return self._download_custom_model(model_files)
+            return self._download_custom_model(model_files, log_progress)

-    def _download_pretrained_model(self, model_files: dict):
+    def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the pretrained model data.
         """
@@ -564,54 +600,61 @@ class Inference:
                 logger.debug(f"Model: '{file_name}' was found in model dir")
                 continue

-            with self.gui.download_progress(
-                message=f"Downloading: '{file_name}'",
-                total=file_size,
-                unit="bytes",
-                unit_scale=True,
-            ) as download_pbar:
-                self.gui.download_progress.show()
-                sly_fs.download(
-                    url=file_url,
-                    save_path=file_path,
-                    progress=download_pbar.update,
-                )
+            if log_progress:
+                with self.gui.download_progress(
+                    message=f"Downloading: '{file_name}'",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as download_pbar:
+                    self.gui.download_progress.show()
+                    sly_fs.download(
+                        url=file_url, save_path=file_path, progress=download_pbar.update
+                    )
+            else:
+                sly_fs.download(url=file_url, save_path=file_path)
             local_model_files[file] = file_path
         else:
             local_model_files[file] = file_url
-        self.gui.download_progress.hide()
+
+        if log_progress:
+            self.gui.download_progress.hide()
         return local_model_files

-    def _download_custom_model(self, model_files: dict):
+    def _download_custom_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the custom model data.
         """
-
-        team_id = env.team_id()
+        team_id = sly_env.team_id()
         local_model_files = {}
-
         for file in model_files:
             file_url = model_files[file]
             file_info = self.api.file.get_info_by_path(team_id, file_url)
             if file_info is None:
-                raise FileNotFoundError(
-                    f"File '{file_url}' not found in Team Files. Make sure the file exists."
-                )
+                if sly_fs.file_exists(file_url):
+                    local_model_files[file] = file_url
+                    continue
+                else:
+                    raise FileNotFoundError(f"File '{file_url}' not found in Team Files")
             file_size = file_info.sizeb
             file_name = os.path.basename(file_url)
             file_path = os.path.join(self.model_dir, file_name)
-            with self.gui.download_progress(
-                message=f"Downloading: '{file_name}'",
-                total=file_size,
-                unit="bytes",
-                unit_scale=True,
-            ) as download_pbar:
-                self.gui.download_progress.show()
-                self.api.file.download(
-                    team_id, file_url, file_path, progress_cb=download_pbar.update
-                )
+            if log_progress:
+                with self.gui.download_progress(
+                    message=f"Downloading: '{file_name}'",
+                    total=file_size,
+                    unit="bytes",
+                    unit_scale=True,
+                ) as download_pbar:
+                    self.gui.download_progress.show()
+                    self.api.file.download(
+                        team_id, file_url, file_path, progress_cb=download_pbar.update
+                    )
+            else:
+                self.api.file.download(team_id, file_url, file_path)
             local_model_files[file] = file_path
-        self.gui.download_progress.hide()
+        if log_progress:
+            self.gui.download_progress.hide()
         return local_model_files

     def _load_model(self, deploy_params: dict):
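The `log_progress` flag added above lets the same download routine run with or without GUI progress widgets, which headless local deploys require. A generic, self-contained sketch of the optional-progress pattern (the progress context factory is an assumption standing in for the SDK's `download_progress` widget):

```python
from contextlib import nullcontext
from typing import Callable, ContextManager, Optional
from urllib.request import urlopen


def download_with_optional_progress(
    url: str,
    save_path: str,
    progress_factory: Optional[Callable[[], ContextManager]] = None,
    chunk_size: int = 1 << 20,
) -> None:
    # In headless mode there is no progress widget, so fall back to a
    # no-op context instead of duplicating the download logic in two branches.
    ctx = progress_factory() if progress_factory is not None else nullcontext()
    with ctx as pbar, urlopen(url) as resp, open(save_path, "wb") as out:
        while chunk := resp.read(chunk_size):
            out.write(chunk)
            if pbar is not None:
                pbar.update(len(chunk))  # widget-style .update(), as in the diff
```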
@@ -647,6 +690,13 @@ class Inference:
         if model_source == ModelSource.CUSTOM:
             self._set_model_meta_custom_model(model_info)
             self._set_checkpoint_info_custom_model(deploy_params)
+
+        try:
+            if is_production():
+                self._add_workflow_input(model_source, model_files, model_info)
+        except Exception as e:
+            logger.warning(f"Failed to add input to the workflow: {repr(e)}")
+
         self._load_model(deploy_params)
         if self._model_meta is None:
             self._set_model_meta_from_classes()
@@ -679,7 +729,7 @@ class Inference:
             model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
         )
         checkpoint_file_info = self.api.file.get_info_by_path(
-            env.team_id(), checkpoint_file_path
+            sly_env.team_id(), checkpoint_file_path
         )
         if checkpoint_file_info is None:
             checkpoint_url = None
@@ -1253,18 +1303,18 @@ class Inference:
         logger.debug("Inferring image_url...", extra={"state": state})
         settings = self._get_inference_settings(state)
         image_url = state["image_url"]
-        ext = fs.get_file_ext(image_url)
+        ext = sly_fs.get_file_ext(image_url)
         if ext == "":
             ext = ".jpg"
         image_path = os.path.join(get_data_dir(), rand_str(15) + ext)
-        fs.download(image_url, image_path)
+        sly_fs.download(image_url, image_path)
         logger.debug("Inference settings:", extra=settings)
         logger.debug(f"Downloaded path: {image_path}")
         anns, slides_data = self._inference_auto(
             [image_path],
             settings=settings,
         )
-        fs.silent_remove(image_path)
+        sly_fs.silent_remove(image_path)
         return self._format_output(anns, slides_data)[0]

     def _inference_video_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
@@ -2208,28 +2258,46 @@ class Inference:

         if is_debug_with_sly_net():
             # advanced debug for Supervisely Team
-            logger.warn(
+            logger.warning(
                 "Serving is running in advanced development mode with Supervisely VPN Network"
             )
-            team_id = env.team_id()
+            team_id = sly_env.team_id()
             # sly_app_development.supervisely_vpn_network(action="down")  # for debug
             sly_app_development.supervisely_vpn_network(action="up")
             task = sly_app_development.create_debug_task(team_id, port="8000")
             self._task_id = task["id"]
             os.environ["TASK_ID"] = str(self._task_id)
         else:
-            self._task_id = env.task_id() if is_production() else None
+            if not self._is_local_deploy:
+                self._task_id = sly_env.task_id() if is_production() else None

         if isinstance(self.gui, GUI.InferenceGUI):
             self._app = Application(layout=self.get_ui())
         elif isinstance(self.gui, GUI.ServingGUI):
             self._app = Application(layout=self._app_layout)
+        # elif isinstance(self.gui, GUI.InferenceGUI):
+        #     self._app = Application(layout=self.get_ui())
         else:
             self._app = Application(layout=self.get_ui())

         server = self._app.get_server()
         self._app.set_ready_check_function(self.is_model_deployed)

+        if self._is_local_deploy:
+            # Predict and shutdown
+            if any(
+                [
+                    self._args.predict_project,
+                    self._args.predict_dataset,
+                    self._args.predict_dir,
+                    self._args.predict_image,
+                ]
+            ):
+                self._inference_by_local_deploy_args()
+                # Gracefully shut down the server
+                self._app.shutdown()
+            # else: run server after endpoints
+
         @call_on_autostart()
         def autostart_func():
             gpu_count = get_gpu_count()
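In local-deploy mode the app either runs the requested batch predictions and exits, or keeps serving HTTP through the embedded uvicorn server configured later in `_run_server`. A minimal standalone sketch of that serve pattern (the FastAPI app below is a placeholder for the SDK's `Application` object):

```python
import uvicorn
from fastapi import FastAPI

app = FastAPI()  # placeholder for the SDK's Application


@app.get("/health")
def health() -> dict:
    return {"status": "ok"}


if __name__ == "__main__":
    # Mirrors _run_server in this diff: bind on all interfaces, port 8000.
    config = uvicorn.Config(app=app, host="0.0.0.0", port=8000, ws="websockets")
    server = uvicorn.Server(config)
    server.run()
```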
@@ -2725,6 +2793,285 @@ class Inference:
         def _get_deploy_info():
             return asdict(self._get_deploy_info())

+        # Local deploy without predict args
+        if self._is_local_deploy:
+            self._run_server()
+
+    def _parse_local_deploy_args(self):
+        parser = argparse.ArgumentParser(description="Run Inference Serving")
+
+        # Deploy args
+        parser.add_argument(
+            "--model",
+            type=str,
+            help="Name of the pretrained model or path to custom checkpoint file",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
+            default="cuda:0",
+            help="Device to use for inference (default: 'cuda:0')",
+        )
+        parser.add_argument(
+            "--runtime",
+            type=str,
+            choices=[RuntimeType.PYTORCH, RuntimeType.ONNXRUNTIME, RuntimeType.TENSORRT],
+            default=RuntimeType.PYTORCH,
+            help="Runtime type for inference (default: PYTORCH)",
+        )
+        # -------------------------- #
+
+        # Predict args
+        parser.add_argument("--predict-project", type=int, required=False, help="ID of the project")
+        parser.add_argument(
+            "--predict-dataset",
+            type=lambda x: [int(i) for i in x.split(",")] if "," in x else int(x),
+            required=False,
+            help="ID of the dataset or a comma-separated list of dataset IDs",
+        )
+        parser.add_argument(
+            "--predict-dir",
+            type=str,
+            required=False,
+            help="Not implemented yet. Path to the local directory with images",
+        )
+        parser.add_argument(
+            "--predict-image",
+            type=str,
+            required=False,
+            help="Image ID on Supervisely instance or path to local image",
+        )
+        # -------------------------- #
+
+        # Output args
+        parser.add_argument("--output", type=str, required=False, help="Not implemented yet")
+        parser.add_argument("--output-dir", type=str, required=False, help="Not implemented yet")
+        # -------------------------- #
+
+        # Parse arguments
+        args, _ = parser.parse_known_args()
+        if args.model is None:
+            # raise ValueError("Argument '--model' is required for local deployment")
+            return None, False
+        if isinstance(args.predict_dataset, int):
+            args.predict_dataset = [args.predict_dataset]
+        if args.predict_image is not None:
+            if args.predict_image.isdigit():
+                args.predict_image = int(args.predict_image)
+        return args, True
+
+    def _get_pretrained_model_params_from_args(self):
+        model_files = None
+        model_source = None
+        model_info = None
+        need_download = True
+
+        model = self._args.model
+        for m in self.pretrained_models:
+            meta = m.get("meta", None)
+            if meta is None:
+                continue
+            model_name = meta.get("model_name", None)
+            if model_name is None:
+                continue
+            m_files = meta.get("model_files", None)
+            if m_files is None:
+                continue
+            checkpoint = m_files.get("checkpoint", None)
+            if checkpoint is None:
+                continue
+            if model == m["meta"]["model_name"]:
+                model_info = m
+                model_source = ModelSource.PRETRAINED
+                model_files = {"checkpoint": checkpoint}
+                config = m_files.get("config", None)
+                if config is not None:
+                    model_files["config"] = config
+                break
+
+        return model_files, model_source, model_info, need_download
+
+    def _get_custom_model_params_from_args(self):
+        def _load_experiment_info(artifacts_dir):
+            experiment_path = os.path.join(artifacts_dir, "experiment_info.json")
+            model_info = self._load_json_file(experiment_path)
+            original_model_files = model_info.get("model_files")
+            if not original_model_files:
+                raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
+            return model_info, original_model_files
+
+        def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
+            return {k: os.path.join(artifacts_dir, v) for k, v in original_model_files.items()} | {
+                "checkpoint": checkpoint_path
+            }
+
+        def _download_remote_files(team_id, artifacts_dir, local_artifacts_dir):
+            sly_fs.mkdir(local_artifacts_dir, True)
+            file_infos = self.api.file.list(team_id, artifacts_dir, False, "fileinfo")
+            remote_paths = [f.path for f in file_infos if not f.is_dir]
+            local_paths = [
+                os.path.join(local_artifacts_dir, f.name) for f in file_infos if not f.is_dir
+            ]
+
+            coro = self.api.file.download_bulk_async(team_id, remote_paths, local_paths)
+            loop = get_or_create_event_loop()
+            if loop.is_running():
+                future = asyncio.run_coroutine_threadsafe(coro, loop)
+                future.result()
+            else:
+                loop.run_until_complete(coro)
+
+        model_source = ModelSource.CUSTOM
+        need_download = False
+        checkpoint_path = self._args.model
+
+        if not os.path.isfile(checkpoint_path):
+            team_id = sly_env.team_id(raise_not_found=False)
+            if not team_id:
+                raise ValueError(
+                    "Team ID not found in env. Required for remote custom checkpoints."
+                )
+            file_info = self.api.file.get_info_by_path(team_id, checkpoint_path)
+            if not file_info:
+                raise ValueError(
+                    f"Couldn't find: '{checkpoint_path}' locally or remotely in Team ID."
+                )
+            need_download = True
+
+        artifacts_dir = os.path.dirname(os.path.dirname(checkpoint_path))
+        if not need_download:
+            model_info, original_model_files = _load_experiment_info(artifacts_dir)
+            model_files = _prepare_local_model_files(
+                artifacts_dir, checkpoint_path, original_model_files
+            )
+        else:
+            local_artifacts_dir = os.path.join(
+                self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
+            )
+            _download_remote_files(team_id, artifacts_dir, local_artifacts_dir)
+
+            model_info, original_model_files = _load_experiment_info(local_artifacts_dir)
+            model_files = _prepare_local_model_files(
+                local_artifacts_dir, checkpoint_path, original_model_files
+            )
+        return model_files, model_source, model_info, need_download
+
+    def _get_deploy_params_from_args(self):
+        # Ensure model directory exists
+        device = self._args.device if self._args.device else "cuda:0"
+        runtime = self._args.runtime if self._args.runtime else RuntimeType.PYTORCH
+
+        model_files, model_source, model_info, need_download = (
+            self._get_pretrained_model_params_from_args()
+        )
+        if model_source is None:
+            model_files, model_source, model_info, need_download = (
+                self._get_custom_model_params_from_args()
+            )
+
+        if model_source is None:
+            raise ValueError("Couldn't create 'model_source' from args")
+        if model_files is None:
+            raise ValueError("Couldn't create 'model_files' from args")
+        if model_info is None:
+            raise ValueError("Couldn't create 'model_info' from args")
+
+        deploy_params = {
+            "model_files": model_files,
+            "model_source": model_source,
+            "model_info": model_info,
+            "device": device,
+            "runtime": runtime,
+        }
+
+        logger.info(f"Deploy parameters: {deploy_params}")
+        return deploy_params, need_download
+
+    def _run_server(self):
+        config = uvicorn.Config(app=self._app, host="0.0.0.0", port=8000, ws="websockets")
+        self._uvicorn_server = uvicorn.Server(config)
+        self._uvicorn_server.run()
+
+    def _inference_by_local_deploy_args(self):
+        def predict_project_by_args(api: Api, project_id: int, dataset_ids: List[int] = None):
+            source_project = api.project.get_info_by_id(project_id)
+            workspace_id = source_project.workspace_id
+            output_project = api.project.create(
+                workspace_id, f"{source_project.name} predicted", change_name_if_conflict=True
+            )
+            results = self._inference_project_id(
+                api=self.api,
+                state={
+                    "projectId": project_id,
+                    "dataset_ids": dataset_ids,
+                    "output_project_id": output_project.id,
+                },
+            )
+
+        def predict_datasets_by_args(api: Api, dataset_ids: List[int]):
+            dataset_infos = [api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids]
+            project_ids = list(set([dataset_info.project_id for dataset_info in dataset_infos]))
+            if len(project_ids) > 1:
+                raise ValueError("All datasets should belong to the same project")
+            predict_project_by_args(api, project_ids[0], dataset_ids)
+
+        def predict_image_by_args(api: Api, image: Union[str, int]):
+            def predict_image_np(image_np):
+                settings = self._get_inference_settings({})
+                anns, _ = self._inference_auto([image_np], settings)
+                if len(anns) == 0:
+                    return Annotation(img_size=image_np.shape[:2])
+                ann = anns[0]
+                return ann
+
+            if isinstance(image, int):
+                image_np = api.image.download_np(image)
+                ann = predict_image_np(image_np)
+                api.annotation.upload_ann(image, ann)
+            elif isinstance(image, str):
+                if sly_fs.file_exists(self._args.predict):
+                    image_np = sly_image.read(self._args.predict)
+                    ann = predict_image_np(image_np)
+                    pred_ann_path = image + ".json"
+                    sly_json.dump_json_file(ann.to_json(), pred_ann_path)
+                    # Save image for debug
+                    # ann.draw_pretty(image_np)
+                    # pred_path = os.path.join(os.path.dirname(self._args.predict), "pred_" + os.path.basename(self._args.predict))
+                    # sly_image.write(pred_path, image_np)
+
+        if self._args.predict_project is not None:
+            predict_project_by_args(self.api, self._args.predict_project)
+        elif self._args.predict_dataset is not None:
+            predict_datasets_by_args(self.api, self._args.predict_dataset)
+        elif self._args.predict_dir is not None:
+            raise NotImplementedError("Predict from directory is not implemented yet")
+        elif self._args.predict_image is not None:
+            predict_image_by_args(self.api, self._args.predict_image)
+
+    def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
+        if model_source == ModelSource.PRETRAINED:
+            checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
+            checkpoint_name = model_info["meta"]["model_name"]
+        else:
+            checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
+            checkpoint_url = os.path.join(
+                model_info["artifacts_dir"], "checkpoints", checkpoint_name
+            )
+
+        app_name = sly_env.app_name()
+        meta = WorkflowMeta(node_settings=WorkflowSettings(title=f"Serve {app_name}"))
+
+        logger.debug(
+            f"Workflow Input: Checkpoint URL - {checkpoint_url}, Checkpoint Name - {checkpoint_name}"
+        )
+        if checkpoint_url and self.api.file.exists(sly_env.team_id(), checkpoint_url):
+            self.api.app.workflow.add_input_file(checkpoint_url, model_weight=True, meta=meta)
+        else:
+            logger.debug(
+                f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
+            )
+

 def _get_log_extra_for_inference_request(inference_request_uuid, inference_request: dict):
     log_extra = {
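One detail worth noting in `_parse_local_deploy_args` above is the `--predict-dataset` type callback, which accepts either a single ID or a comma-separated list, followed by a normalization step in the constructor. A runnable sketch of that parsing behavior in isolation:

```python
import argparse

parser = argparse.ArgumentParser(description="Run Inference Serving")
parser.add_argument(
    "--predict-dataset",
    # Same callback as in the diff: "12" -> 12, "12,34" -> [12, 34]
    type=lambda x: [int(i) for i in x.split(",")] if "," in x else int(x),
    required=False,
)

args, _ = parser.parse_known_args(["--predict-dataset", "12,34"])
print(args.predict_dataset)  # [12, 34]

args, _ = parser.parse_known_args(["--predict-dataset", "12"])
if isinstance(args.predict_dataset, int):
    # The constructor normalizes a bare int into a one-element list.
    args.predict_dataset = [args.predict_dataset]
print(args.predict_dataset)  # [12]
```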
@@ -2998,7 +3345,7 @@ class TempImageWriter:
     def __init__(self, format: str = "png"):
        self.format = format
        self.temp_dir = os.path.join(get_data_dir(), rand_str(10))
-       fs.mkdir(self.temp_dir)
+       sly_fs.mkdir(self.temp_dir)

     def write(self, image: np.ndarray):
         image_path = os.path.join(self.temp_dir, f"{rand_str(10)}.{self.format}")
@@ -3006,7 +3353,7 @@ class TempImageWriter:
         return image_path

     def clean(self):
-        fs.remove_dir(self.temp_dir)
+        sly_fs.remove_dir(self.temp_dir)


 def get_hardware_info(device: str) -> str:
@@ -5,6 +5,8 @@ This module provides the `TrainGUI` class that handles the graphical user interface for
 training workflows in Supervisely.
 """

+from os import environ
+
 import supervisely.io.env as sly_env
 from supervisely import Api, ProjectMeta
 from supervisely._utils import is_production
@@ -47,7 +49,6 @@ class TrainGUI:
         app_options: dict = None,
     ):
         self._api = Api.from_env()
-
         if is_production():
             self.task_id = sly_env.task_id()
         else:
@@ -61,12 +62,19 @@ class TrainGUI:
         self.app_options = app_options
         self.collapsable = app_options.get("collapsable", False)

-        self.team_id = sly_env.team_id()
-        self.workspace_id = sly_env.workspace_id()
-        self.project_id = sly_env.project_id()  # from app options?
+        self.team_id = sly_env.team_id(raise_not_found=False)
+        self.workspace_id = sly_env.workspace_id(raise_not_found=False)
+        self.project_id = sly_env.project_id()
         self.project_info = self._api.project.get_info_by_id(self.project_id)
         self.project_meta = ProjectMeta.from_json(self._api.project.get_meta(self.project_id))

+        if self.workspace_id is None:
+            self.workspace_id = self.project_info.workspace_id
+            environ["WORKSPACE_ID"] = str(self.workspace_id)
+        if self.team_id is None:
+            self.team_id = self.project_info.team_id
+            environ["TEAM_ID"] = str(self.team_id)
+
         # 1. Project selection + Train/val split
         self.input_selector = InputSelector(self.project_info, self.app_options)
         # 2. Select train val splits
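`TrainGUI` now tolerates missing TEAM_ID/WORKSPACE_ID environment variables by deriving them from the selected project and writing them back into the environment for later lookups. A standalone sketch of that read-or-backfill pattern (the `ProjectInfo` dataclass is a stand-in for the SDK's project info object):

```python
from dataclasses import dataclass
from os import environ
from typing import Tuple


@dataclass
class ProjectInfo:  # stand-in for the SDK's project info object
    team_id: int
    workspace_id: int


def resolve_ids(project_info: ProjectInfo) -> Tuple[int, int]:
    # Read from the environment without raising; fall back to the project's
    # own IDs and backfill the env so later sly_env lookups succeed.
    team_id = environ.get("TEAM_ID")
    workspace_id = environ.get("WORKSPACE_ID")
    if workspace_id is None:
        workspace_id = project_info.workspace_id
        environ["WORKSPACE_ID"] = str(workspace_id)
    if team_id is None:
        team_id = project_info.team_id
        environ["TEAM_ID"] = str(team_id)
    return int(team_id), int(workspace_id)
```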
@@ -90,7 +98,7 @@ class TrainGUI:
         self.training_logs = TrainingLogs(self.app_options)

         # 8. Training Artifacts
-        self.training_artifacts = TrainingArtifacts(self.app_options)
+        self.training_artifacts = TrainingArtifacts(self._api, self.app_options)

         # Stepper layout
         self.steps = [
@@ -1,6 +1,10 @@
+import os
 from typing import Any, Dict

+import supervisely.io.env as sly_env
 from supervisely import Api
+from supervisely._utils import is_production
+from supervisely.api.api import ApiField
 from supervisely.app.widgets import (
     Card,
     Container,
@@ -11,18 +15,24 @@ from supervisely.app.widgets import (
     ReportThumbnail,
     Text,
 )
+from supervisely.io.fs import file_exists

 PYTORCH_ICON = "https://img.icons8.com/?size=100&id=jH4BpkMnRrU5&format=png&color=000000"
 ONNX_ICON = "https://artwork.lfaidata.foundation/projects/onnx/icon/color/onnx-icon-color.png"
 TRT_ICON = "https://img.icons8.com/?size=100&id=yqf95864UzeQ&format=png&color=000000"

+OVERVIEW_FILE_NAME = "README.md"
+PYTORCH_FILE_NAME = "demo_pytorch.py"
+ONNX_FILE_NAME = "demo_onnx.py"
+TRT_FILE_NAME = "demo_tensorrt.py"
+

 class TrainingArtifacts:
     title = "Training Artifacts"
     description = "All outputs of the training process will appear here"
     lock_message = "Artifacts will be available after training is completed"

-    def __init__(self, app_options: Dict[str, Any]):
+    def __init__(self, api: Api, app_options: Dict[str, Any]):
         self.display_widgets = []
         self.success_message_text = (
             "Training completed. Training artifacts were uploaded to Team Files. "
@@ -73,59 +83,107 @@ class TrainingArtifacts:
         # -------------------------------- #

         # PyTorch, ONNX, TensorRT demo
-        self.inference_demo_field = []
+        self.inference_demo_widgets = []
+
         model_demo = self.app_options.get("demo", None)
         if model_demo is not None:
-            pytorch_demo_link = model_demo.get("pytorch", None)
-            if pytorch_demo_link is not None:
-                pytorch_icon = Field.Icon(image_url=PYTORCH_ICON, bg_color_rgb=[255, 255, 255])
-                self.pytorch_instruction = Field(
-                    title="PyTorch",
-                    description="Open file",
-                    description_url=pytorch_demo_link,
-                    icon=pytorch_icon,
-                    content=Empty(),
-                )
-                self.pytorch_instruction.hide()
-                self.inference_demo_field.extend([self.pytorch_instruction])
-
-            onnx_demo_link = model_demo.get("onnx", None)
-            if onnx_demo_link is not None:
-                if self.app_options.get("export_onnx_supported", False):
-                    onnx_icon = Field.Icon(image_url=ONNX_ICON, bg_color_rgb=[255, 255, 255])
-                    self.onnx_instruction = Field(
-                        title="ONNX",
-                        description="Open file",
-                        description_url=onnx_demo_link,
-                        icon=onnx_icon,
-                        content=Empty(),
+            model_demo_path = model_demo.get("path", None)
+            if model_demo_path is not None:
+                model_demo_gh_link = None
+                if is_production():
+                    task_id = sly_env.task_id()
+                    task_info = api.task.get_info_by_id(task_id)
+                    app_id = task_info["meta"]["app"]["id"]
+                    app_info = api.app.get_info_by_id(app_id)
+                    model_demo_gh_link = app_info.repo
+                else:
+                    app_name = sly_env.app_name()
+                    team_id = sly_env.team_id()
+                    apps = api.app.get_list(
+                        team_id,
+                        filter=[{"field": "name", "operator": "=", "value": app_name}],
+                        only_running=True,
                    )
-                    self.onnx_instruction.hide()
-                    self.inference_demo_field.extend([self.onnx_instruction])
-
-            trt_demo_link = model_demo.get("tensorrt", None)
-            if trt_demo_link is not None:
-                if self.app_options.get("export_tensorrt_supported", False):
-                    trt_icon = Field.Icon(image_url=TRT_ICON, bg_color_rgb=[255, 255, 255])
-                    self.trt_instruction = Field(
-                        title="TensorRT",
-                        description="Open file",
-                        description_url=trt_demo_link,
-                        icon=trt_icon,
-                        content=Empty(),
-                    )
-                    self.trt_instruction.hide()
-                    self.inference_demo_field.extend([self.trt_instruction])
-
-            demo_overview_link = model_demo.get("overview", None)
-            self.inference_demo_field = Field(
-                title="How to run inference",
-                description="Instructions on how to use your checkpoints outside of Supervisely Platform",
-                content=Flexbox(self.inference_demo_field),
-                title_url=demo_overview_link,
-            )
-            self.inference_demo_field.hide()
-            self.display_widgets.extend([self.inference_demo_field])
+                if len(apps) == 1:
+                    app_info = apps[0]
+                    model_demo_gh_link = app_info.repo
+
+                if model_demo_gh_link is not None:
+                    gh_branch = "blob/main"
+                    link_to_demo = f"{model_demo_gh_link}/{gh_branch}/{model_demo_path}"
+
+                if model_demo_gh_link is not None and model_demo_path is not None:
+                    # PyTorch
+                    local_pytorch_demo = os.path.join(
+                        os.getcwd(), model_demo_path, PYTORCH_FILE_NAME
+                    )
+                    if file_exists(local_pytorch_demo):
+                        pytorch_demo_link = f"{link_to_demo}/{PYTORCH_FILE_NAME}"
+                        pytorch_icon = Field.Icon(
+                            image_url=PYTORCH_ICON, bg_color_rgb=[255, 255, 255]
+                        )
+                        self.pytorch_instruction = Field(
+                            title="PyTorch",
+                            description="Open file",
+                            description_url=pytorch_demo_link,
+                            icon=pytorch_icon,
+                            content=Empty(),
+                        )
+                        self.pytorch_instruction.hide()
+                        self.inference_demo_widgets.extend([self.pytorch_instruction])
+
+                    # ONNX
+                    local_onnx_demo = os.path.join(os.getcwd(), model_demo_path, ONNX_FILE_NAME)
+                    if file_exists(local_onnx_demo):
+                        if self.app_options.get("export_onnx_supported", False):
+                            onnx_demo_link = f"{link_to_demo}/{ONNX_FILE_NAME}"
+                            onnx_icon = Field.Icon(
+                                image_url=ONNX_ICON, bg_color_rgb=[255, 255, 255]
+                            )
+                            self.onnx_instruction = Field(
+                                title="ONNX",
+                                description="Open file",
+                                description_url=onnx_demo_link,
+                                icon=onnx_icon,
+                                content=Empty(),
+                            )
+                            self.onnx_instruction.hide()
+                            self.inference_demo_widgets.extend([self.onnx_instruction])
+
+                    # TensorRT
+                    local_trt_demo = os.path.join(os.getcwd(), model_demo_path, TRT_FILE_NAME)
+                    if file_exists(local_trt_demo):
+                        if self.app_options.get("export_tensorrt_supported", False):
+                            trt_demo_link = f"{link_to_demo}/{TRT_FILE_NAME}"
+                            trt_icon = Field.Icon(
+                                image_url=TRT_ICON, bg_color_rgb=[255, 255, 255]
+                            )
+                            self.trt_instruction = Field(
+                                title="TensorRT",
+                                description="Open file",
+                                description_url=trt_demo_link,
+                                icon=trt_icon,
+                                content=Empty(),
+                            )
+                            self.trt_instruction.hide()
+                            self.inference_demo_widgets.extend([self.trt_instruction])
+
+                    local_demo_overview = os.path.join(
+                        os.getcwd(), model_demo_path, OVERVIEW_FILE_NAME
+                    )
+                    if file_exists(local_demo_overview):
+                        demo_overview_link = os.path.join(link_to_demo, OVERVIEW_FILE_NAME)
+                    else:
+                        demo_overview_link = None
+
+                    self.inference_demo_field = Field(
+                        title="How to run inference",
+                        description="Instructions on how to use your checkpoints outside of Supervisely Platform",
+                        content=Flexbox(self.inference_demo_widgets),
+                        title_url=demo_overview_link,
+                    )
+                    self.inference_demo_field.hide()
+                    self.display_widgets.extend([self.inference_demo_field])
         # -------------------------------- #

         self.container = Container(self.display_widgets)
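Instead of taking demo URLs directly from `app_options`, the widget now derives a GitHub link from the app's repository and only links to demo files that exist locally. A standalone sketch of the link construction (the repo URL and demo path below are placeholders; the `blob/main` branch segment is hard-coded in the diff):

```python
import os

# Placeholders: in the diff these come from app_info.repo and app_options["demo"]["path"]
model_demo_gh_link = "https://github.com/org/train-app"
model_demo_path = "supervisely_integration/demo"

gh_branch = "blob/main"  # hard-coded branch segment, as in the diff
link_to_demo = f"{model_demo_gh_link}/{gh_branch}/{model_demo_path}"

PYTORCH_FILE_NAME = "demo_pytorch.py"
local_pytorch_demo = os.path.join(os.getcwd(), model_demo_path, PYTORCH_FILE_NAME)
if os.path.isfile(local_pytorch_demo):
    # Only advertise the demo when the file actually ships with the app.
    print(f"{link_to_demo}/{PYTORCH_FILE_NAME}")
```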
@@ -143,3 +201,15 @@ class TrainingArtifacts:

     def validate_step(self) -> bool:
         return True
+
+    def overview_demo_exists(self, demo_path: str):
+        return file_exists(os.path.join(os.getcwd(), demo_path, OVERVIEW_FILE_NAME))
+
+    def pytorch_demo_exists(self, demo_path: str):
+        return file_exists(os.path.join(os.getcwd(), demo_path, PYTORCH_FILE_NAME))
+
+    def onnx_demo_exists(self, demo_path: str):
+        return file_exists(os.path.join(os.getcwd(), demo_path, ONNX_FILE_NAME))
+
+    def trt_demo_exists(self, demo_path: str):
+        return file_exists(os.path.join(os.getcwd(), demo_path, TRT_FILE_NAME))
@@ -8,7 +8,7 @@ training workflows in a Supervisely application.
 import shutil
 import subprocess
 from datetime import datetime
-from os import listdir
+from os import getcwd, listdir
 from os.path import basename, exists, expanduser, isdir, isfile, join
 from typing import Any, Dict, List, Literal, Optional, Union
 from urllib.request import urlopen
@@ -124,8 +124,6 @@ class TrainApp:
         logger.info("TrainApp is running in debug mode")

         self.framework_name = framework_name
-        self._team_id = sly_env.team_id()
-        self._workspace_id = sly_env.workspace_id()
         self._tensorboard_process = None

         self._models = self._load_models(models)
@@ -249,6 +247,26 @@ class TrainApp:
     # ----------------------------------------- #

     # Input Data
+    @property
+    def team_id(self) -> int:
+        """
+        Returns the ID of the team.
+
+        :return: Team ID.
+        :rtype: int
+        """
+        return self.gui.team_id
+
+    @property
+    def workspace_id(self) -> int:
+        """
+        Returns the ID of the workspace.
+
+        :return: Workspace ID.
+        :rtype: int
+        """
+        return self.gui.workspace_id
+
     @property
     def project_id(self) -> int:
         """
@@ -555,6 +573,7 @@ class TrainApp:
         self._generate_hyperparameters(remote_dir, experiment_info)
         self._generate_train_val_splits(remote_dir, splits_data)
         self._generate_model_meta(remote_dir, experiment_info)
+        self._upload_demo_files(remote_dir)

         # Step 7. Set output widgets
         self._set_text_status("reset")
@@ -1074,7 +1093,7 @@ class TrainApp:
         ) as model_download_main_pbar:
             self.progress_bar_main.show()
             for name, remote_path in remote_paths.items():
-                file_info = self._api.file.get_info_by_path(self._team_id, remote_path)
+                file_info = self._api.file.get_info_by_path(self.team_id, remote_path)
                 file_name = basename(remote_path)
                 local_path = join(self.model_dir, file_name)
                 file_size = file_info.sizeb
@@ -1087,7 +1106,7 @@ class TrainApp:
                 ) as model_download_secondary_pbar:
                     self.progress_bar_secondary.show()
                     self._api.file.download(
-                        self._team_id,
+                        self.team_id,
                         remote_path,
                         local_path,
                         progress_cb=model_download_secondary_pbar.update,
@@ -1325,7 +1344,7 @@ class TrainApp:
         ) as upload_artifacts_pbar:
             self.progress_bar_main.show()
             self._api.file.upload(
-                self._team_id,
+                self.team_id,
                 local_path,
                 remote_path,
                 progress_cb=upload_artifacts_pbar,
@@ -1421,7 +1440,7 @@ class TrainApp:

         remote_checkpoints_dir = join(remote_dir, self._remote_checkpoints_dir_name)
         checkpoint_files = self._api.file.list(
-            self._team_id, remote_checkpoints_dir, return_type="fileinfo"
+            self.team_id, remote_checkpoints_dir, return_type="fileinfo"
         )
         experiment_info["checkpoints"] = [
             f"checkpoints/{checkpoint.name}" for checkpoint in checkpoint_files
@@ -1482,6 +1501,38 @@ class TrainApp:
             local_path, remote_path, f"Uploading '{self._app_state_file}' to Team Files"
         )

+    def _upload_demo_files(self, remote_dir: str) -> None:
+        demo = self._app_options.get("demo")
+        if demo is None:
+            return
+        demo_path = demo.get("path")
+        if demo_path is None:
+            return
+
+        local_demo_dir = join(getcwd(), demo_path)
+        if not sly_fs.dir_exists(local_demo_dir):
+            logger.info(f"Demo directory '{local_demo_dir}' does not exist")
+            return
+
+        logger.debug(f"Uploading demo files to Supervisely")
+        remote_demo_dir = join(remote_dir, "demo")
+        local_files = sly_fs.list_files_recursively(local_demo_dir)
+        total_size = sum([sly_fs.get_file_size(file_path) for file_path in local_files])
+        with self.progress_bar_main(
+            message="Uploading demo files to Team Files",
+            total=total_size,
+            unit="bytes",
+            unit_scale=True,
+        ) as upload_artifacts_pbar:
+            self.progress_bar_main.show()
+            remote_dir = self._api.file.upload_directory(
+                self.team_id,
+                local_demo_dir,
+                remote_demo_dir,
+                progress_size_cb=upload_artifacts_pbar,
+            )
+        self.progress_bar_main.hide()
+
     def _get_train_val_splits_for_app_state(self) -> Dict:
         """
         Gets the train and val splits information for app_state.json.
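`_upload_demo_files` follows the artifact-upload pattern used elsewhere in `TrainApp`: compute the total byte size up front so the progress bar is byte-accurate, then pass a size callback into the directory upload. A standalone sketch of that pattern (`tqdm` stands in for the SDK's progress widget, and `upload_directory` is a hypothetical callable that reports uploaded byte counts):

```python
import os
from typing import Callable

from tqdm import tqdm  # stand-in for the SDK's progress widget


def upload_dir_with_progress(local_dir: str, upload_directory: Callable) -> None:
    # Walk the tree once to get a byte-accurate total for the progress bar.
    local_files = [
        os.path.join(root, name)
        for root, _, names in os.walk(local_dir)
        for name in names
    ]
    total_size = sum(os.path.getsize(p) for p in local_files)
    with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
        # Hypothetical: upload_directory reports uploaded bytes via the callback.
        upload_directory(local_dir, progress_size_cb=pbar.update)
```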
@@ -1557,13 +1608,13 @@ class TrainApp:

         # Clean debug directory if exists
         if task_id == "debug-session":
-            if self._api.file.dir_exists(self._team_id, f"{remote_artifacts_dir}/", True):
+            if self._api.file.dir_exists(self.team_id, f"{remote_artifacts_dir}/", True):
                 with self.progress_bar_main(
                     message=f"[Debug] Cleaning train artifacts: '{remote_artifacts_dir}/'",
                     total=1,
                 ) as upload_artifacts_pbar:
                     self.progress_bar_main.show()
-                    self._api.file.remove_dir(self._team_id, f"{remote_artifacts_dir}", True)
+                    self._api.file.remove_dir(self.team_id, f"{remote_artifacts_dir}", True)
                     upload_artifacts_pbar.update(1)
                     self.progress_bar_main.hide()

@@ -1586,14 +1637,14 @@ class TrainApp:
         ) as upload_artifacts_pbar:
             self.progress_bar_main.show()
             remote_dir = self._api.file.upload_directory(
-                self._team_id,
+                self.team_id,
                 self.output_dir,
                 remote_artifacts_dir,
                 progress_size_cb=upload_artifacts_pbar,
             )
             self.progress_bar_main.hide()

-        file_info = self._api.file.get_info_by_path(self._team_id, join(remote_dir, "open_app.lnk"))
+        file_info = self._api.file.get_info_by_path(self.team_id, join(remote_dir, "open_app.lnk"))
         return remote_dir, file_info

     def _set_training_output(
@@ -1629,33 +1680,36 @@ class TrainApp:

         # Set instruction to GUI
         demo_options = self._app_options.get("demo", {})
-        if demo_options:
+        demo_path = demo_options.get("path", None)
+        if demo_path is not None:
             # Show PyTorch demo if available
-            pytorch_demo = demo_options.get("pytorch")
-            if pytorch_demo:
+            if self.gui.training_artifacts.pytorch_demo_exists(demo_path):
                 self.gui.training_artifacts.pytorch_instruction.show()

             # Show ONNX demo if supported and available
-            onnx_demo = demo_options.get("onnx")
             if (
                 self._app_options.get("export_onnx_supported", False)
                 and self.gui.hyperparameters_selector.get_export_onnx_checkbox_value()
-                and onnx_demo
+                and self.gui.training_artifacts.onnx_demo_exists(demo_path)
             ):
                 self.gui.training_artifacts.onnx_instruction.show()

             # Show TensorRT demo if supported and available
-            tensorrt_demo = demo_options.get("tensorrt")
             if (
                 self._app_options.get("export_tensorrt_supported", False)
                 and self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value()
-                and tensorrt_demo
+                and self.gui.training_artifacts.trt_demo_exists(demo_path)
             ):
                 self.gui.training_artifacts.trt_instruction.show()

             # Show the inference demo widget if overview or any demo is available
-            demo_overview = self._app_options.get("overview", {})
-            if demo_overview or any([pytorch_demo, onnx_demo, tensorrt_demo]):
+            if self.gui.training_artifacts.overview_demo_exists(demo_path) or any(
+                [
+                    self.gui.training_artifacts.pytorch_demo_exists(demo_path),
+                    self.gui.training_artifacts.onnx_demo_exists(demo_path),
+                    self.gui.training_artifacts.trt_demo_exists(demo_path),
+                ]
+            ):
                 self.gui.training_artifacts.inference_demo_field.show()
         # ---------------------------- #
@@ -1676,7 +1730,7 @@ class TrainApp:
         eval_res_dir = (
             f"/model-benchmark/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
         )
-        eval_res_dir = self._api.storage.get_free_dir_name(self._team_id, eval_res_dir)
+        eval_res_dir = self._api.storage.get_free_dir_name(self.team_id, eval_res_dir)
         return eval_res_dir

     def _run_model_benchmark(
@@ -1742,6 +1796,8 @@ class TrainApp:
             use_gui=False,
             custom_inference_settings=self._inference_settings,
         )
+        if hasattr(m, "in_train"):
+            m.in_train = True

         logger.info(f"Using device: {self.device}")
@@ -1910,7 +1966,7 @@ class TrainApp:

         if self.model_source == ModelSource.CUSTOM:
             file_info = self._api.file.get_info_by_path(
-                self._team_id,
+                self.team_id,
                 self.gui.model_selector.experiment_selector.get_selected_checkpoint_path(),
             )
             if file_info is not None:
@@ -2319,7 +2375,7 @@ class TrainApp:
             self.progress_bar_secondary.show()
             destination_path = join(remote_dir, self._export_dir_name, file_name)
             self._api.file.upload(
-                self._team_id,
+                self.team_id,
                 path,
                 destination_path,
                 export_upload_secondary_pbar,
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: supervisely
-Version: 6.73.269
+Version: 6.73.271
 Summary: Supervisely Python SDK.
 Home-page: https://github.com/supervisely/supervisely
 Author: Supervisely
@@ -25,7 +25,7 @@ supervisely/api/annotation_api.py,sha256=kB9l0NhQEkunGDC9fWjNzf5DdhqRF1tv-RRnIbk
 supervisely/api/api.py,sha256=0dgPx_eizoCEFzfT8YH9uh1kq-OJwjrV5fBGD7uZ7E4,65840
 supervisely/api/app_api.py,sha256=RsbVej8WxWVn9cNo5s3Fqd1symsCdsfOaKVBKEUapRY,71927
 supervisely/api/dataset_api.py,sha256=GH7prDRJKyJlTv_7_Y-RkTwJN7ED4EkXNqqmi3iIdI4,41352
-supervisely/api/file_api.py,sha256=7yWt8lRQ4UfLmnMZ9T18UXzu8jihrtHtcqi6GZJG-0w,83414
+supervisely/api/file_api.py,sha256=v2FsD3oljwNPqcDgEJRe8Bu5k0PYKzVhqmRb5QFaHAQ,83422
 supervisely/api/github_api.py,sha256=NIexNjEer9H5rf5sw2LEZd7C1WR-tK4t6IZzsgeAAwQ,623
 supervisely/api/image_annotation_tool_api.py,sha256=YcUo78jRDBJYvIjrd-Y6FJAasLta54nnxhyaGyanovA,5237
 supervisely/api/image_api.py,sha256=lLt8z_OE7cXwb94_UKxWiSKxe28a4meMrVM7dhHIWZY,176956
@@ -864,7 +864,7 @@ supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8
 supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
 supervisely/nn/inference/__init__.py,sha256=mtEci4Puu-fRXDnGn8RP47o97rv3VTE0hjbYO34Zwqg,1622
 supervisely/nn/inference/cache.py,sha256=_pPSpkl8Wkqkiidn0vu6kWE19cngd80av--jncHxMEQ,30510
-supervisely/nn/inference/inference.py,sha256=8MrOen2oyYIKiVqy0WbBTwABJZss9MLQ70EwX0e_-es,128895
+supervisely/nn/inference/inference.py,sha256=IDkMFsURnB-vBM-2Kf-M7D5IQw6VUoIwVoh_IIdzA9Q,143958
 supervisely/nn/inference/session.py,sha256=jmkkxbe2kH-lEgUU6Afh62jP68dxfhF5v6OGDfLU62E,35757
 supervisely/nn/inference/video_inference.py,sha256=8Bshjr6rDyLay5Za8IB8Dr6FURMO2R_v7aELasO8pR4,5746
 supervisely/nn/inference/gui/__init__.py,sha256=wCxd-lF5Zhcwsis-wScDA8n1Gk_1O00PKgDviUZ3F1U,221
@@ -960,15 +960,15 @@ supervisely/nn/tracker/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
 supervisely/nn/tracker/utils/gmc.py,sha256=3JX8979H3NA-YHNaRQyj9Z-xb9qtyMittPEjGw8y2Jo,11557
 supervisely/nn/tracker/utils/kalman_filter.py,sha256=eSFmCjM0mikHCAFvj-KCVzw-0Jxpoc3Cfc2NWEjJC1Q,17268
 supervisely/nn/training/__init__.py,sha256=gY4PCykJ-42MWKsqb9kl-skemKa8yB6t_fb5kzqR66U,111
-supervisely/nn/training/train_app.py,sha256=Q3evZS3DChaWMEB7kZrabMSsHv621-XEQuyGdfq91nY,93101
+supervisely/nn/training/train_app.py,sha256=wXWQb6Xa49nNHNjCe1YuSur-vHqwjk_J0yoQYCKjHWw,94989
 supervisely/nn/training/gui/__init__.py,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
 supervisely/nn/training/gui/classes_selector.py,sha256=8UgzA4aogOAr1s42smwEcDbgaBj_i0JLhjwlZ9bFdIA,3772
-supervisely/nn/training/gui/gui.py,sha256=ERMyRqZABLBXcLxvvsF1TtL8VKK8Ak4MwoN9wrL4Dzw,23357
+supervisely/nn/training/gui/gui.py,sha256=nj4EVppoV9ZjLN0rVO0GKxmI56d6Qpp0qwnJJ6srT6w,23712
 supervisely/nn/training/gui/hyperparameters_selector.py,sha256=2qryuBss0bLcZJV8PNJ6_hKZM5Dbj2FIxTb3EULHQrE,6670
 supervisely/nn/training/gui/input_selector.py,sha256=Jp9PnVVADv1fhndPuZdMlKuzWTOBQZogrOks5dwATlc,2179
 supervisely/nn/training/gui/model_selector.py,sha256=QTFHMf-8-rREYPk64QKoRvE4zKPC8V6tcP4H4N6nyt0,4082
 supervisely/nn/training/gui/train_val_splits_selector.py,sha256=MLryFD2Tj_RobkFzZOeQXzXpch0eGiVFisq3FGA3dFg,8549
-supervisely/nn/training/gui/training_artifacts.py,sha256=JoeNn1cXSRrkatjxhYNwL_-yDsBT2aqYugjICMn4KUk,5887
+supervisely/nn/training/gui/training_artifacts.py,sha256=-sJQu5kBaJJp8JhSZzHpQdV4lJFjbd2YOaovbQPyVLM,9583
 supervisely/nn/training/gui/training_logs.py,sha256=1CBqnL0l5kiZVaegJ-NLgOVI1T4EDB_rLAtumuw18Jo,3222
 supervisely/nn/training/gui/training_process.py,sha256=wqlwt1cHG-HoVEOotDiBjp9YTTIbeMr1bHY2zVRaNH8,3071
 supervisely/nn/training/gui/utils.py,sha256=Bi7-BRsAqN7fUkhd7rXVEAqsxhBdIZ2MrrJtrNqVf8I,3905
@@ -1062,9 +1062,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
 supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
 supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
 supervisely_lib/__init__.py,sha256=7-3QnN8Zf0wj8NCr2oJmqoQWMKKPKTECvjH9pd2S5vY,159
-supervisely-6.73.269.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-supervisely-6.73.269.dist-info/METADATA,sha256=NZRYDDissMNHAzHNf8M4N6lVnQCMVCMf5DsJTLKcFb8,33573
-supervisely-6.73.269.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-supervisely-6.73.269.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
-supervisely-6.73.269.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
-supervisely-6.73.269.dist-info/RECORD,,
+supervisely-6.73.271.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+supervisely-6.73.271.dist-info/METADATA,sha256=sRSopR2p5d0tsyK9jC8MYDBbhngm6uX8Z2eQY76yQjE,33573
+supervisely-6.73.271.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+supervisely-6.73.271.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
+supervisely-6.73.271.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
+supervisely-6.73.271.dist-info/RECORD,,