supervisely 6.73.442__py3-none-any.whl → 6.73.444__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/nn/inference/inference.py +79 -2
- supervisely/nn/tracker/visualize.py +6 -3
- supervisely/nn/training/train_app.py +18 -7
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/METADATA +1 -1
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/RECORD +9 -9
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/LICENSE +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/WHEEL +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py

@@ -863,6 +863,50 @@ class Inference:
             self.gui.download_progress.hide()
         return local_model_files
 
+    def _fallback_download_custom_model_pt(self, deploy_params: dict):
+        """
+        Downloads the PyTorch checkpoint from Team Files if TensorRT is failed to load.
+        """
+        team_id = sly_env.team_id()
+
+        checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
+        artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
+        checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
+        checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])
+
+        pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
+        remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
+        local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)
+
+        file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
+        file_size = file_info.sizeb
+        if self.gui is not None:
+            with self.gui.download_progress(
+                message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
+                total=file_size,
+                unit="bytes",
+                unit_scale=True,
+            ) as download_pbar:
+                self.gui.download_progress.show()
+                self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path, progress_cb=download_pbar.update)
+            self.gui.download_progress.hide()
+        else:
+            self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
+
+        return local_checkpoint_path
+
+    def _remove_exported_checkpoints(self, checkpoint_path: str):
+        """
+        Removes the exported checkpoints for provided PyTorch checkpoint path.
+        """
+        checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+        onnx_path = checkpoint_path.replace(checkpoint_ext, ".onnx")
+        engine_path = checkpoint_path.replace(checkpoint_ext, ".engine")
+        if os.path.exists(onnx_path):
+            sly_fs.silent_remove(onnx_path)
+        if os.path.exists(engine_path):
+            sly_fs.silent_remove(engine_path)
+
     def _download_custom_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the custom model data.
@@ -1060,7 +1104,35 @@ class Inference:
         self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
         self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
         self._hardware = get_hardware_info(self.device)
-
+
+        checkpoint_path = deploy_params["model_files"]["checkpoint"]
+        checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+        if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
+            try:
+                self.load_model(**deploy_params)
+            except Exception as e:
+                logger.warning(f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}")
+                checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
+                deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                logger.info("Exporting PyTorch model to TensorRT...")
+                self._remove_exported_checkpoints(checkpoint_path)
+                checkpoint_path = self.export_tensorrt(deploy_params)
+                deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                self.load_model(**deploy_params)
+        if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
+            if self.runtime == RuntimeType.ONNXRUNTIME:
+                logger.info("Exporting PyTorch model to ONNX...")
+                self._remove_exported_checkpoints(checkpoint_path)
+                checkpoint_path = self.export_onnx(deploy_params)
+            elif self.runtime == RuntimeType.TENSORRT:
+                logger.info("Exporting PyTorch model to TensorRT...")
+                self._remove_exported_checkpoints(checkpoint_path)
+                checkpoint_path = self.export_tensorrt(deploy_params)
+            deploy_params["model_files"]["checkpoint"] = checkpoint_path
+            self.load_model(**deploy_params)
+        else:
+            self.load_model(**deploy_params)
+
         self._model_served = True
         self._deploy_params = deploy_params
         if self._task_id is not None and is_production():
@@ -1181,6 +1253,7 @@ class Inference:
         if self._model_meta is None:
             self._set_model_meta_from_classes()
 
+
     def _set_model_meta_custom_model(self, model_info: dict):
         model_meta = model_info.get("model_meta")
         if model_meta is None:
@@ -2526,7 +2599,6 @@ class Inference:
         timer.daemon = True
         timer.start()
         self._freeze_timer = timer
-        logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
 
     def _set_served_callback(self):
         self._model_served = True
@@ -4214,6 +4286,11 @@ class Inference:
             return
         self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
 
+    def export_onnx(self, deploy_params: dict):
+        raise NotImplementedError("Have to be implemented in child class after inheritance")
+
+    def export_tensorrt(self, deploy_params: dict):
+        raise NotImplementedError("Have to be implemented in child class after inheritance")
 
     def _exclude_duplicated_predictions(
         api: Api,
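
Note: export_onnx and export_tensorrt are hooks that the deploy logic above calls and expects to return the path of the exported checkpoint. A minimal hypothetical subclass sketch, assuming the Inference base class from this module (the class name and the commented-out conversion step are illustrative, not package API):

    class MyModelInference(Inference):
        def export_onnx(self, deploy_params: dict) -> str:
            # Derive the target path; the framework-specific PyTorch -> ONNX
            # conversion would go here (e.g. torch.onnx.export on the loaded model).
            pt_path = deploy_params["model_files"]["checkpoint"]
            onnx_path = pt_path.rsplit(".", 1)[0] + ".onnx"
            return onnx_path

        def export_tensorrt(self, deploy_params: dict) -> str:
            # Build a TensorRT engine from the PyTorch/ONNX model, then return its path.
            pt_path = deploy_params["model_files"]["checkpoint"]
            engine_path = pt_path.rsplit(".", 1)[0] + ".engine"
            return engine_path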
supervisely/nn/tracker/visualize.py

@@ -161,8 +161,8 @@ class TrackingVisualizer:
         process = (
             ffmpeg
             .input(str(video_path))
-            .output('pipe:', format='rawvideo', pix_fmt='bgr24')
-            .run_async(pipe_stdout=True, pipe_stderr=
+            .output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='quiet')
+            .run_async(pipe_stdout=True, pipe_stderr=False)
         )
 
         try:
@@ -177,7 +177,10 @@ class TrackingVisualizer:
                 frame = np.frombuffer(frame_data, np.uint8).reshape([height, width, 3])
                 yield frame_idx, frame
                 frame_idx += 1
-
+
+        except ffmpeg.Error as e:
+            logger.error(f"ffmpeg error: {e.stderr.decode() if e.stderr else str(e)}", exc_info=True)
+
         finally:
             process.stdout.close()
             if process.stderr:
supervisely/nn/training/train_app.py

@@ -3177,22 +3177,33 @@ class TrainApp:
         experiment_name = self.gui.training_process.get_experiment_name()
 
         train_collection_idx = 1
-        val_collection_idx = 1
+        val_collection_idx = 1
+
+        def _extract_index_from_col_name(name: str, expected_prefix: str) -> Optional[int]:
+            parts = name.split("_")
+            if len(parts) == 2 and parts[0] == expected_prefix and parts[1].isdigit():
+                return int(parts[1])
+            return None
 
         # Get train collection with max idx
         if len(all_train_collections) > 0:
-
-
+            train_indices = [_extract_index_from_col_name(collection.name, "train") for collection in all_train_collections]
+            train_indices = [idx for idx in train_indices if idx is not None]
+            if len(train_indices) > 0:
+                train_collection_idx = max(train_indices) + 1
+
         # Get val collection with max idx
         if len(all_val_collections) > 0:
-
-
+            val_indices = [_extract_index_from_col_name(collection.name, "val") for collection in all_val_collections]
+            val_indices = [idx for idx in val_indices if idx is not None]
+            if len(val_indices) > 0:
+                val_collection_idx = max(val_indices) + 1
         # -------------------------------- #
 
         # Create Train Collection
         train_img_ids = list(self._train_split_item_ids)
         train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
-        train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx}", train_collection_description)
+        train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx:03d}", train_collection_description)
         train_collection_id = getattr(train_collection, "id", None)
         if train_collection_id is None:
             raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
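
Note: _extract_index_from_col_name only accepts names of the exact form prefix_digits, and the create() calls now zero-pad the index (train_001, val_001, ...). A small standalone illustration with made-up collection names:

    from typing import Optional

    def _extract_index_from_col_name(name: str, expected_prefix: str) -> Optional[int]:
        parts = name.split("_")
        if len(parts) == 2 and parts[0] == expected_prefix and parts[1].isdigit():
            return int(parts[1])
        return None

    names = ["train_1", "train_002", "train_final", "val_3"]
    indices = [i for i in (_extract_index_from_col_name(n, "train") for n in names) if i is not None]
    next_idx = max(indices) + 1 if indices else 1
    print(f"train_{next_idx:03d}")  # train_003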
@@ -3202,7 +3213,7 @@ class TrainApp:
         # Create Val Collection
         val_img_ids = list(self._val_split_item_ids)
         val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
-        val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx}", val_collection_description)
+        val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx:03d}", val_collection_description)
         val_collection_id = getattr(val_collection, "id", None)
        if val_collection_id is None:
            raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
{supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/RECORD

@@ -904,7 +904,7 @@ supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8
 supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
 supervisely/nn/inference/__init__.py,sha256=QFukX2ip-U7263aEPCF_UCFwj6EujbMnsgrXp5Bbt8I,1623
 supervisely/nn/inference/cache.py,sha256=rfmb1teJ9lNDfisUSh6bwDCVkPZocn8GMvDgLQktnbo,35023
-supervisely/nn/inference/inference.py,sha256=
+supervisely/nn/inference/inference.py,sha256=aFFFuNhRV8m4Ch4eB-zvw1gnhz7X9WJZz-bQ3g6wUyM,207124
 supervisely/nn/inference/inference_request.py,sha256=y6yw0vbaRRcEBS27nq3y0sL6Gmq2qLA_Bm0GrnJGegE,14267
 supervisely/nn/inference/session.py,sha256=XUqJ_CqHk3ZJYkWxdnErN_6afCpIBU76nq6Ek7DiOQI,35792
 supervisely/nn/inference/uploader.py,sha256=Dn5MfMRq7tclEWpP0B9fJjTiQPBpwumfXxC8-lOYgnM,5659
@@ -999,7 +999,7 @@ supervisely/nn/tracker/base_tracker.py,sha256=2d23JlHizOqVye324YT20EE8RP52uwoQUk
 supervisely/nn/tracker/botsort_tracker.py,sha256=F2OaoeK1EAlBKAY95Fd9ZooZIlOZBh4YThhzmKNyP6w,10224
 supervisely/nn/tracker/calculate_metrics.py,sha256=JjXI4VYWYSZ5j2Ed81FNYozkS3v2UAM73ztjLrHGg58,10434
 supervisely/nn/tracker/utils.py,sha256=4WdtFHSEx5Buq6aND7cD21FdH0sGg92QNIMTpC9EGD4,10126
-supervisely/nn/tracker/visualize.py,sha256=
+supervisely/nn/tracker/visualize.py,sha256=s0QdqEE0rcKsXow3b9lSoM9FMtms-p26HqsfNiZ2W_s,20286
 supervisely/nn/tracker/botsort/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 supervisely/nn/tracker/botsort/botsort_config.yaml,sha256=q_7Gp1-15lGYOLv7JvxVJ69mm6hbCLbUAl_ZBOYNGpw,535
 supervisely/nn/tracker/botsort/osnet_reid/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -1012,7 +1012,7 @@ supervisely/nn/tracker/botsort/tracker/kalman_filter.py,sha256=waTArMcbmpHAzb57a
 supervisely/nn/tracker/botsort/tracker/matching.py,sha256=bgnheHwWD3XZSI3OJVfdrU5bYJ44rxPHzzSElfg6LZM,6600
 supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py,sha256=dFjWmubyJLrUP4i-CJaOhPEkQD-WD144deW7Ua5a7Rc,17775
 supervisely/nn/training/__init__.py,sha256=gY4PCykJ-42MWKsqb9kl-skemKa8yB6t_fb5kzqR66U,111
-supervisely/nn/training/train_app.py,sha256=
+supervisely/nn/training/train_app.py,sha256=l4t4zzE-nPwV-tXzbw4ENQcGMQ1PCeqmyxfkLg_eFDI,133159
 supervisely/nn/training/gui/__init__.py,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
 supervisely/nn/training/gui/classes_selector.py,sha256=tqmVwUfC2u5K53mZmvDvNOhu9Mw5mddjpB2kxRXXUO8,12453
 supervisely/nn/training/gui/gui.py,sha256=_CtpzlwP6WLFgOTBDB_4RPcaqrQPK92DwSCDvO-dIKM,51749
@@ -1127,9 +1127,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
 supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
 supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
 supervisely_lib/__init__.py,sha256=yRwzEQmVwSd6lUQoAUdBngKEOlnoQ6hA9ZcoZGJRNC4,331
-supervisely-6.73.
-supervisely-6.73.
-supervisely-6.73.
-supervisely-6.73.
-supervisely-6.73.
-supervisely-6.73.
+supervisely-6.73.444.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+supervisely-6.73.444.dist-info/METADATA,sha256=wnmK1_pmFJflYjyd6PQCw6Sa4k5F7i_ISTiEq4qa3-0,35480
+supervisely-6.73.444.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+supervisely-6.73.444.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
+supervisely-6.73.444.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
+supervisely-6.73.444.dist-info/RECORD,,
Files without changes: {supervisely-6.73.442.dist-info → supervisely-6.73.444.dist-info}/LICENSE, WHEEL, entry_points.txt, top_level.txt