supervisely 6.73.442__py3-none-any.whl → 6.73.443__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/nn/inference/inference.py +79 -2
- supervisely/nn/tracker/visualize.py +6 -3
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/METADATA +1 -1
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/RECORD +8 -8
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/LICENSE +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/WHEEL +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.443.dist-info}/top_level.txt +0 -0
|
@@ -863,6 +863,50 @@ class Inference:
|
|
|
863
863
|
self.gui.download_progress.hide()
|
|
864
864
|
return local_model_files
|
|
865
865
|
|
|
866
|
+
def _fallback_download_custom_model_pt(self, deploy_params: dict):
    """
    Downloads the PyTorch checkpoint from Team Files when the TensorRT engine failed to load.

    The remote file name is the base name of ``deploy_params["model_files"]["checkpoint"]``
    (its extension dropped) joined with the extension of the first entry in
    ``deploy_params["model_info"]["checkpoints"]``. It is looked up in
    ``<model_info["artifacts_dir"]>/checkpoints`` of the current team and saved into
    ``self.model_dir`` under the same name.

    :param deploy_params: Deployment parameters. Must contain ``model_files.checkpoint``,
        ``model_info.artifacts_dir`` and a non-empty ``model_info.checkpoints`` list.
    :type deploy_params: dict
    :return: Local path of the downloaded PyTorch checkpoint.
    :rtype: str
    :raises FileNotFoundError: If no file info is returned for the remote checkpoint path.
    """
    team_id = sly_env.team_id()

    model_info = deploy_params["model_info"]
    checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
    checkpoints_dir = os.path.join(model_info["artifacts_dir"], "checkpoints")
    # The failed checkpoint is a TensorRT engine, so its own extension is useless here;
    # borrow the extension of the model's listed training checkpoints instead.
    checkpoint_ext = sly_fs.get_file_ext(model_info["checkpoints"][0])

    pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
    remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
    local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)

    file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
    if file_info is None:
        # Fail with an explicit message instead of an AttributeError on `.sizeb`.
        raise FileNotFoundError(
            f"PyTorch checkpoint for TensorRT fallback not found in Team Files: "
            f"'{remote_checkpoint_path}'"
        )

    if self.gui is None:
        self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
        return local_checkpoint_path

    with self.gui.download_progress(
        message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
        total=file_info.sizeb,
        unit="bytes",
        unit_scale=True,
    ) as download_pbar:
        self.gui.download_progress.show()
        try:
            self.api.file.download(
                team_id,
                remote_checkpoint_path,
                local_checkpoint_path,
                progress_cb=download_pbar.update,
            )
        finally:
            # Hide the widget even when the download fails, so the UI is not left
            # showing a stale progress bar.
            self.gui.download_progress.hide()

    return local_checkpoint_path
|
|
897
|
+
|
|
898
|
+
def _remove_exported_checkpoints(self, checkpoint_path: str):
    """
    Removes the ONNX and TensorRT exports that sit next to a PyTorch checkpoint.

    For ``/path/best.pth`` this deletes ``/path/best.onnx`` and ``/path/best.engine``
    if they exist, so that a subsequent export starts from a clean state.
    Missing files are ignored, and the checkpoint itself is never removed.

    :param checkpoint_path: Local path to the PyTorch checkpoint.
    :type checkpoint_path: str
    """
    # splitext swaps only the final extension of the file name; a plain str.replace
    # would also rewrite matching text in directory names (e.g. "exp.pth_v2/best.pth").
    stem, _ = os.path.splitext(checkpoint_path)
    for export_ext in (".onnx", ".engine"):
        exported_path = stem + export_ext
        if exported_path == checkpoint_path:
            # Never delete the checkpoint we are about to export from.
            continue
        try:
            os.remove(exported_path)
        except OSError:
            # Best-effort cleanup: an absent or unremovable file must not abort deployment.
            pass
|
|
909
|
+
|
|
866
910
|
def _download_custom_model(self, model_files: dict, log_progress: bool = True):
|
|
867
911
|
"""
|
|
868
912
|
Downloads the custom model data.
|
|
@@ -1060,7 +1104,35 @@ class Inference:
|
|
|
1060
1104
|
self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
|
|
1061
1105
|
self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
|
|
1062
1106
|
self._hardware = get_hardware_info(self.device)
|
|
1063
|
-
|
|
1107
|
+
|
|
1108
|
+
checkpoint_path = deploy_params["model_files"]["checkpoint"]
|
|
1109
|
+
checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
|
|
1110
|
+
if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
|
|
1111
|
+
try:
|
|
1112
|
+
self.load_model(**deploy_params)
|
|
1113
|
+
except Exception as e:
|
|
1114
|
+
logger.warning(f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}")
|
|
1115
|
+
checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
|
|
1116
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1117
|
+
logger.info("Exporting PyTorch model to TensorRT...")
|
|
1118
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1119
|
+
checkpoint_path = self.export_tensorrt(deploy_params)
|
|
1120
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1121
|
+
self.load_model(**deploy_params)
|
|
1122
|
+
if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
|
|
1123
|
+
if self.runtime == RuntimeType.ONNXRUNTIME:
|
|
1124
|
+
logger.info("Exporting PyTorch model to ONNX...")
|
|
1125
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1126
|
+
checkpoint_path = self.export_onnx(deploy_params)
|
|
1127
|
+
elif self.runtime == RuntimeType.TENSORRT:
|
|
1128
|
+
logger.info("Exporting PyTorch model to TensorRT...")
|
|
1129
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1130
|
+
checkpoint_path = self.export_tensorrt(deploy_params)
|
|
1131
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1132
|
+
self.load_model(**deploy_params)
|
|
1133
|
+
else:
|
|
1134
|
+
self.load_model(**deploy_params)
|
|
1135
|
+
|
|
1064
1136
|
self._model_served = True
|
|
1065
1137
|
self._deploy_params = deploy_params
|
|
1066
1138
|
if self._task_id is not None and is_production():
|
|
@@ -1181,6 +1253,7 @@ class Inference:
|
|
|
1181
1253
|
if self._model_meta is None:
|
|
1182
1254
|
self._set_model_meta_from_classes()
|
|
1183
1255
|
|
|
1256
|
+
|
|
1184
1257
|
def _set_model_meta_custom_model(self, model_info: dict):
|
|
1185
1258
|
model_meta = model_info.get("model_meta")
|
|
1186
1259
|
if model_meta is None:
|
|
@@ -2526,7 +2599,6 @@ class Inference:
|
|
|
2526
2599
|
timer.daemon = True
|
|
2527
2600
|
timer.start()
|
|
2528
2601
|
self._freeze_timer = timer
|
|
2529
|
-
logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
|
|
2530
2602
|
|
|
2531
2603
|
def _set_served_callback(self):
|
|
2532
2604
|
self._model_served = True
|
|
@@ -4214,6 +4286,11 @@ class Inference:
|
|
|
4214
4286
|
return
|
|
4215
4287
|
self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
|
|
4216
4288
|
|
|
4289
|
+
def export_onnx(self, deploy_params: dict):
    """
    Hook for exporting the served model's PyTorch checkpoint to ONNX.

    The base class provides no implementation; model-specific subclasses are
    expected to override it. Callers assign the returned value to the
    checkpoint path used for loading.

    :param deploy_params: Deployment parameters of the model being served.
    :type deploy_params: dict
    :raises NotImplementedError: Always, in the base class.
    """
    reason = "Have to be implemented in child class after inheritance"
    raise NotImplementedError(reason)
|
|
4291
|
+
|
|
4292
|
+
def export_tensorrt(self, deploy_params: dict):
    """
    Hook for exporting the served model's PyTorch checkpoint to a TensorRT engine.

    The base class provides no implementation; model-specific subclasses are
    expected to override it. Callers assign the returned value to the
    checkpoint path used for loading.

    :param deploy_params: Deployment parameters of the model being served.
    :type deploy_params: dict
    :raises NotImplementedError: Always, in the base class.
    """
    reason = "Have to be implemented in child class after inheritance"
    raise NotImplementedError(reason)
|
|
4217
4294
|
|
|
4218
4295
|
def _exclude_duplicated_predictions(
|
|
4219
4296
|
api: Api,
|
|
@@ -161,8 +161,8 @@ class TrackingVisualizer:
|
|
|
161
161
|
process = (
|
|
162
162
|
ffmpeg
|
|
163
163
|
.input(str(video_path))
|
|
164
|
-
.output('pipe:', format='rawvideo', pix_fmt='bgr24')
|
|
165
|
-
.run_async(pipe_stdout=True, pipe_stderr=
|
|
164
|
+
.output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='quiet')
|
|
165
|
+
.run_async(pipe_stdout=True, pipe_stderr=False)
|
|
166
166
|
)
|
|
167
167
|
|
|
168
168
|
try:
|
|
@@ -177,7 +177,10 @@ class TrackingVisualizer:
|
|
|
177
177
|
frame = np.frombuffer(frame_data, np.uint8).reshape([height, width, 3])
|
|
178
178
|
yield frame_idx, frame
|
|
179
179
|
frame_idx += 1
|
|
180
|
-
|
|
180
|
+
|
|
181
|
+
except ffmpeg.Error as e:
|
|
182
|
+
logger.error(f"ffmpeg error: {e.stderr.decode() if e.stderr else str(e)}", exc_info=True)
|
|
183
|
+
|
|
181
184
|
finally:
|
|
182
185
|
process.stdout.close()
|
|
183
186
|
if process.stderr:
|
|
@@ -904,7 +904,7 @@ supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8
|
|
|
904
904
|
supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
|
|
905
905
|
supervisely/nn/inference/__init__.py,sha256=QFukX2ip-U7263aEPCF_UCFwj6EujbMnsgrXp5Bbt8I,1623
|
|
906
906
|
supervisely/nn/inference/cache.py,sha256=rfmb1teJ9lNDfisUSh6bwDCVkPZocn8GMvDgLQktnbo,35023
|
|
907
|
-
supervisely/nn/inference/inference.py,sha256=
|
|
907
|
+
supervisely/nn/inference/inference.py,sha256=aFFFuNhRV8m4Ch4eB-zvw1gnhz7X9WJZz-bQ3g6wUyM,207124
|
|
908
908
|
supervisely/nn/inference/inference_request.py,sha256=y6yw0vbaRRcEBS27nq3y0sL6Gmq2qLA_Bm0GrnJGegE,14267
|
|
909
909
|
supervisely/nn/inference/session.py,sha256=XUqJ_CqHk3ZJYkWxdnErN_6afCpIBU76nq6Ek7DiOQI,35792
|
|
910
910
|
supervisely/nn/inference/uploader.py,sha256=Dn5MfMRq7tclEWpP0B9fJjTiQPBpwumfXxC8-lOYgnM,5659
|
|
@@ -999,7 +999,7 @@ supervisely/nn/tracker/base_tracker.py,sha256=2d23JlHizOqVye324YT20EE8RP52uwoQUk
|
|
|
999
999
|
supervisely/nn/tracker/botsort_tracker.py,sha256=F2OaoeK1EAlBKAY95Fd9ZooZIlOZBh4YThhzmKNyP6w,10224
|
|
1000
1000
|
supervisely/nn/tracker/calculate_metrics.py,sha256=JjXI4VYWYSZ5j2Ed81FNYozkS3v2UAM73ztjLrHGg58,10434
|
|
1001
1001
|
supervisely/nn/tracker/utils.py,sha256=4WdtFHSEx5Buq6aND7cD21FdH0sGg92QNIMTpC9EGD4,10126
|
|
1002
|
-
supervisely/nn/tracker/visualize.py,sha256=
|
|
1002
|
+
supervisely/nn/tracker/visualize.py,sha256=s0QdqEE0rcKsXow3b9lSoM9FMtms-p26HqsfNiZ2W_s,20286
|
|
1003
1003
|
supervisely/nn/tracker/botsort/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
1004
1004
|
supervisely/nn/tracker/botsort/botsort_config.yaml,sha256=q_7Gp1-15lGYOLv7JvxVJ69mm6hbCLbUAl_ZBOYNGpw,535
|
|
1005
1005
|
supervisely/nn/tracker/botsort/osnet_reid/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -1127,9 +1127,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
|
|
|
1127
1127
|
supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
|
|
1128
1128
|
supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
|
|
1129
1129
|
supervisely_lib/__init__.py,sha256=yRwzEQmVwSd6lUQoAUdBngKEOlnoQ6hA9ZcoZGJRNC4,331
|
|
1130
|
-
supervisely-6.73.
|
|
1131
|
-
supervisely-6.73.
|
|
1132
|
-
supervisely-6.73.
|
|
1133
|
-
supervisely-6.73.
|
|
1134
|
-
supervisely-6.73.
|
|
1135
|
-
supervisely-6.73.
|
|
1130
|
+
supervisely-6.73.443.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
1131
|
+
supervisely-6.73.443.dist-info/METADATA,sha256=XiNJG5fGmMFlBABN2mptuO55-VvjCbOxLMflB-PHDOE,35480
|
|
1132
|
+
supervisely-6.73.443.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
1133
|
+
supervisely-6.73.443.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
|
|
1134
|
+
supervisely-6.73.443.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
|
|
1135
|
+
supervisely-6.73.443.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|