supervisely 6.73.442__py3-none-any.whl → 6.73.444__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -863,6 +863,50 @@ class Inference:
863
863
  self.gui.download_progress.hide()
864
864
  return local_model_files
865
865
 
866
+ def _fallback_download_custom_model_pt(self, deploy_params: dict):
867
+ """
868
+ Downloads the PyTorch checkpoint from Team Files if TensorRT is failed to load.
869
+ """
870
+ team_id = sly_env.team_id()
871
+
872
+ checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
873
+ artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
874
+ checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
875
+ checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])
876
+
877
+ pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
878
+ remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
879
+ local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)
880
+
881
+ file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
882
+ file_size = file_info.sizeb
883
+ if self.gui is not None:
884
+ with self.gui.download_progress(
885
+ message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
886
+ total=file_size,
887
+ unit="bytes",
888
+ unit_scale=True,
889
+ ) as download_pbar:
890
+ self.gui.download_progress.show()
891
+ self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path, progress_cb=download_pbar.update)
892
+ self.gui.download_progress.hide()
893
+ else:
894
+ self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
895
+
896
+ return local_checkpoint_path
897
+
898
+ def _remove_exported_checkpoints(self, checkpoint_path: str):
899
+ """
900
+ Removes the exported checkpoints for provided PyTorch checkpoint path.
901
+ """
902
+ checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
903
+ onnx_path = checkpoint_path.replace(checkpoint_ext, ".onnx")
904
+ engine_path = checkpoint_path.replace(checkpoint_ext, ".engine")
905
+ if os.path.exists(onnx_path):
906
+ sly_fs.silent_remove(onnx_path)
907
+ if os.path.exists(engine_path):
908
+ sly_fs.silent_remove(engine_path)
909
+
866
910
  def _download_custom_model(self, model_files: dict, log_progress: bool = True):
867
911
  """
868
912
  Downloads the custom model data.
@@ -1060,7 +1104,35 @@ class Inference:
1060
1104
  self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
1061
1105
  self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
1062
1106
  self._hardware = get_hardware_info(self.device)
1063
- self.load_model(**deploy_params)
1107
+
1108
+ checkpoint_path = deploy_params["model_files"]["checkpoint"]
1109
+ checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
1110
+ if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
1111
+ try:
1112
+ self.load_model(**deploy_params)
1113
+ except Exception as e:
1114
+ logger.warning(f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}")
1115
+ checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
1116
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1117
+ logger.info("Exporting PyTorch model to TensorRT...")
1118
+ self._remove_exported_checkpoints(checkpoint_path)
1119
+ checkpoint_path = self.export_tensorrt(deploy_params)
1120
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1121
+ self.load_model(**deploy_params)
1122
+ if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
1123
+ if self.runtime == RuntimeType.ONNXRUNTIME:
1124
+ logger.info("Exporting PyTorch model to ONNX...")
1125
+ self._remove_exported_checkpoints(checkpoint_path)
1126
+ checkpoint_path = self.export_onnx(deploy_params)
1127
+ elif self.runtime == RuntimeType.TENSORRT:
1128
+ logger.info("Exporting PyTorch model to TensorRT...")
1129
+ self._remove_exported_checkpoints(checkpoint_path)
1130
+ checkpoint_path = self.export_tensorrt(deploy_params)
1131
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1132
+ self.load_model(**deploy_params)
1133
+ else:
1134
+ self.load_model(**deploy_params)
1135
+
1064
1136
  self._model_served = True
1065
1137
  self._deploy_params = deploy_params
1066
1138
  if self._task_id is not None and is_production():
@@ -1181,6 +1253,7 @@ class Inference:
1181
1253
  if self._model_meta is None:
1182
1254
  self._set_model_meta_from_classes()
1183
1255
 
1256
+
1184
1257
  def _set_model_meta_custom_model(self, model_info: dict):
1185
1258
  model_meta = model_info.get("model_meta")
1186
1259
  if model_meta is None:
@@ -2526,7 +2599,6 @@ class Inference:
2526
2599
  timer.daemon = True
2527
2600
  timer.start()
2528
2601
  self._freeze_timer = timer
2529
- logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
2530
2602
 
2531
2603
  def _set_served_callback(self):
2532
2604
  self._model_served = True
@@ -4214,6 +4286,11 @@ class Inference:
4214
4286
  return
4215
4287
  self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
4216
4288
 
4289
+ def export_onnx(self, deploy_params: dict):
4290
+ raise NotImplementedError("Have to be implemented in child class after inheritance")
4291
+
4292
+ def export_tensorrt(self, deploy_params: dict):
4293
+ raise NotImplementedError("Have to be implemented in child class after inheritance")
4217
4294
 
4218
4295
  def _exclude_duplicated_predictions(
4219
4296
  api: Api,
@@ -161,8 +161,8 @@ class TrackingVisualizer:
161
161
  process = (
162
162
  ffmpeg
163
163
  .input(str(video_path))
164
- .output('pipe:', format='rawvideo', pix_fmt='bgr24')
165
- .run_async(pipe_stdout=True, pipe_stderr=True)
164
+ .output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='quiet')
165
+ .run_async(pipe_stdout=True, pipe_stderr=False)
166
166
  )
167
167
 
168
168
  try:
@@ -177,7 +177,10 @@ class TrackingVisualizer:
177
177
  frame = np.frombuffer(frame_data, np.uint8).reshape([height, width, 3])
178
178
  yield frame_idx, frame
179
179
  frame_idx += 1
180
-
180
+
181
+ except ffmpeg.Error as e:
182
+ logger.error(f"ffmpeg error: {e.stderr.decode() if e.stderr else str(e)}", exc_info=True)
183
+
181
184
  finally:
182
185
  process.stdout.close()
183
186
  if process.stderr:
@@ -3177,22 +3177,33 @@ class TrainApp:
3177
3177
  experiment_name = self.gui.training_process.get_experiment_name()
3178
3178
 
3179
3179
  train_collection_idx = 1
3180
- val_collection_idx = 1
3180
+ val_collection_idx = 1
3181
+
3182
+ def _extract_index_from_col_name(name: str, expected_prefix: str) -> Optional[int]:
3183
+ parts = name.split("_")
3184
+ if len(parts) == 2 and parts[0] == expected_prefix and parts[1].isdigit():
3185
+ return int(parts[1])
3186
+ return None
3181
3187
 
3182
3188
  # Get train collection with max idx
3183
3189
  if len(all_train_collections) > 0:
3184
- train_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_train_collections])
3185
- train_collection_idx += 1
3190
+ train_indices = [_extract_index_from_col_name(collection.name, "train") for collection in all_train_collections]
3191
+ train_indices = [idx for idx in train_indices if idx is not None]
3192
+ if len(train_indices) > 0:
3193
+ train_collection_idx = max(train_indices) + 1
3194
+
3186
3195
  # Get val collection with max idx
3187
3196
  if len(all_val_collections) > 0:
3188
- val_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_val_collections])
3189
- val_collection_idx += 1
3197
+ val_indices = [_extract_index_from_col_name(collection.name, "val") for collection in all_val_collections]
3198
+ val_indices = [idx for idx in val_indices if idx is not None]
3199
+ if len(val_indices) > 0:
3200
+ val_collection_idx = max(val_indices) + 1
3190
3201
  # -------------------------------- #
3191
3202
 
3192
3203
  # Create Train Collection
3193
3204
  train_img_ids = list(self._train_split_item_ids)
3194
3205
  train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
3195
- train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx}", train_collection_description)
3206
+ train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx:03d}", train_collection_description)
3196
3207
  train_collection_id = getattr(train_collection, "id", None)
3197
3208
  if train_collection_id is None:
3198
3209
  raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
@@ -3202,7 +3213,7 @@ class TrainApp:
3202
3213
  # Create Val Collection
3203
3214
  val_img_ids = list(self._val_split_item_ids)
3204
3215
  val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
3205
- val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx}", val_collection_description)
3216
+ val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx:03d}", val_collection_description)
3206
3217
  val_collection_id = getattr(val_collection, "id", None)
3207
3218
  if val_collection_id is None:
3208
3219
  raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: supervisely
3
- Version: 6.73.442
3
+ Version: 6.73.444
4
4
  Summary: Supervisely Python SDK.
5
5
  Home-page: https://github.com/supervisely/supervisely
6
6
  Author: Supervisely
@@ -904,7 +904,7 @@ supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8
904
904
  supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
905
905
  supervisely/nn/inference/__init__.py,sha256=QFukX2ip-U7263aEPCF_UCFwj6EujbMnsgrXp5Bbt8I,1623
906
906
  supervisely/nn/inference/cache.py,sha256=rfmb1teJ9lNDfisUSh6bwDCVkPZocn8GMvDgLQktnbo,35023
907
- supervisely/nn/inference/inference.py,sha256=sr70lji4u7V5MsZBiUBBUuc8_dL_FsNLboAPexAq0HU,203145
907
+ supervisely/nn/inference/inference.py,sha256=aFFFuNhRV8m4Ch4eB-zvw1gnhz7X9WJZz-bQ3g6wUyM,207124
908
908
  supervisely/nn/inference/inference_request.py,sha256=y6yw0vbaRRcEBS27nq3y0sL6Gmq2qLA_Bm0GrnJGegE,14267
909
909
  supervisely/nn/inference/session.py,sha256=XUqJ_CqHk3ZJYkWxdnErN_6afCpIBU76nq6Ek7DiOQI,35792
910
910
  supervisely/nn/inference/uploader.py,sha256=Dn5MfMRq7tclEWpP0B9fJjTiQPBpwumfXxC8-lOYgnM,5659
@@ -999,7 +999,7 @@ supervisely/nn/tracker/base_tracker.py,sha256=2d23JlHizOqVye324YT20EE8RP52uwoQUk
999
999
  supervisely/nn/tracker/botsort_tracker.py,sha256=F2OaoeK1EAlBKAY95Fd9ZooZIlOZBh4YThhzmKNyP6w,10224
1000
1000
  supervisely/nn/tracker/calculate_metrics.py,sha256=JjXI4VYWYSZ5j2Ed81FNYozkS3v2UAM73ztjLrHGg58,10434
1001
1001
  supervisely/nn/tracker/utils.py,sha256=4WdtFHSEx5Buq6aND7cD21FdH0sGg92QNIMTpC9EGD4,10126
1002
- supervisely/nn/tracker/visualize.py,sha256=DY2cnRm4w6_e47xVuF9fwSnOP27ZWTRfv4MFPCyxsD4,20146
1002
+ supervisely/nn/tracker/visualize.py,sha256=s0QdqEE0rcKsXow3b9lSoM9FMtms-p26HqsfNiZ2W_s,20286
1003
1003
  supervisely/nn/tracker/botsort/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
1004
1004
  supervisely/nn/tracker/botsort/botsort_config.yaml,sha256=q_7Gp1-15lGYOLv7JvxVJ69mm6hbCLbUAl_ZBOYNGpw,535
1005
1005
  supervisely/nn/tracker/botsort/osnet_reid/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -1012,7 +1012,7 @@ supervisely/nn/tracker/botsort/tracker/kalman_filter.py,sha256=waTArMcbmpHAzb57a
1012
1012
  supervisely/nn/tracker/botsort/tracker/matching.py,sha256=bgnheHwWD3XZSI3OJVfdrU5bYJ44rxPHzzSElfg6LZM,6600
1013
1013
  supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py,sha256=dFjWmubyJLrUP4i-CJaOhPEkQD-WD144deW7Ua5a7Rc,17775
1014
1014
  supervisely/nn/training/__init__.py,sha256=gY4PCykJ-42MWKsqb9kl-skemKa8yB6t_fb5kzqR66U,111
1015
- supervisely/nn/training/train_app.py,sha256=yHr8s8xQsk1zs3wxOxSlJ_qMecO4g112SXvFUL6n99M,132593
1015
+ supervisely/nn/training/train_app.py,sha256=l4t4zzE-nPwV-tXzbw4ENQcGMQ1PCeqmyxfkLg_eFDI,133159
1016
1016
  supervisely/nn/training/gui/__init__.py,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
1017
1017
  supervisely/nn/training/gui/classes_selector.py,sha256=tqmVwUfC2u5K53mZmvDvNOhu9Mw5mddjpB2kxRXXUO8,12453
1018
1018
  supervisely/nn/training/gui/gui.py,sha256=_CtpzlwP6WLFgOTBDB_4RPcaqrQPK92DwSCDvO-dIKM,51749
@@ -1127,9 +1127,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
1127
1127
  supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
1128
1128
  supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
1129
1129
  supervisely_lib/__init__.py,sha256=yRwzEQmVwSd6lUQoAUdBngKEOlnoQ6hA9ZcoZGJRNC4,331
1130
- supervisely-6.73.442.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1131
- supervisely-6.73.442.dist-info/METADATA,sha256=QnoEurAbF56kquhG3igJuoDl4p4SO571wv2ngekr6Fk,35480
1132
- supervisely-6.73.442.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
1133
- supervisely-6.73.442.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
1134
- supervisely-6.73.442.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
1135
- supervisely-6.73.442.dist-info/RECORD,,
1130
+ supervisely-6.73.444.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1131
+ supervisely-6.73.444.dist-info/METADATA,sha256=wnmK1_pmFJflYjyd6PQCw6Sa4k5F7i_ISTiEq4qa3-0,35480
1132
+ supervisely-6.73.444.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
1133
+ supervisely-6.73.444.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
1134
+ supervisely-6.73.444.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
1135
+ supervisely-6.73.444.dist-info/RECORD,,