dgenerate-ultralytics-headless 8.3.155__py3-none-any.whl → 8.3.156__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/RECORD +13 -13
- ultralytics/__init__.py +1 -1
- ultralytics/data/build.py +3 -1
- ultralytics/engine/exporter.py +3 -28
- ultralytics/models/yolo/classify/train.py +2 -3
- ultralytics/models/yolo/detect/train.py +4 -1
- ultralytics/utils/export.py +7 -2
- ultralytics/utils/patches.py +48 -1
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.155.dist-info → dgenerate_ultralytics_headless-8.3.156.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: dgenerate-ultralytics-headless
|
3
|
-
Version: 8.3.155
|
3
|
+
Version: 8.3.156
|
4
4
|
Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
|
5
5
|
Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
|
6
6
|
Maintainer-email: Ultralytics <hello@ultralytics.com>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
dgenerate_ultralytics_headless-8.3.155.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
1
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
2
2
|
tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
|
3
3
|
tests/conftest.py,sha256=JjgKSs36ZaGmmtqGmAapmFSoFF1YwyV3IZsOgqt2IVM,2593
|
4
4
|
tests/test_cli.py,sha256=Kpfxq_RlbKK1Z8xNScDUbre6GB7neZhXZAYGI1tiDS8,5660
|
@@ -8,7 +8,7 @@ tests/test_exports.py,sha256=HmMKOTCia9ZDC0VYc_EPmvBTM5LM5eeI1NF_pKjLpd8,9677
|
|
8
8
|
tests/test_integrations.py,sha256=cQfgueFhEZ8Xs-tF0uiIEhvn0DlhOH-Wqrx96LXp3D0,6303
|
9
9
|
tests/test_python.py,sha256=nOoaPDg-0j7ZPRz9-uGFny3uocxjUM1ze5wA3BpGxKQ,27865
|
10
10
|
tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
|
11
|
-
ultralytics/__init__.py,sha256=
|
11
|
+
ultralytics/__init__.py,sha256=J6_0KTPXPXrT2RzKhx1IG0zT6K_pszayz4L88pBLnzA,730
|
12
12
|
ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
|
13
13
|
ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
|
14
14
|
ultralytics/cfg/__init__.py,sha256=ds63URbbeRj5UxkCSyl62OrNw6HQy7xeit5-0wGDEKg,39699
|
@@ -108,7 +108,7 @@ ultralytics/data/__init__.py,sha256=nAXaL1puCc7z_NjzQNlJnhbVhT9Fla2u7Dsqo7q1dAc,
|
|
108
108
|
ultralytics/data/annotator.py,sha256=uAgd7K-yudxiwdNqHz0ubfFg5JsfNlae4cgxdvCMyuY,3030
|
109
109
|
ultralytics/data/augment.py,sha256=fvYug6B0qrSSS8IYpvdju9uENnEJWCf-GNG5WqIayng,128964
|
110
110
|
ultralytics/data/base.py,sha256=mRcuehK1thNuuzQGL6D1AaZkod71oHRdYTod_zdQZQg,19688
|
111
|
-
ultralytics/data/build.py,sha256=
|
111
|
+
ultralytics/data/build.py,sha256=13gPxCJIZRjgcNh7zbzanCgtyK6_oZM0ho9KQhHcM6c,11153
|
112
112
|
ultralytics/data/converter.py,sha256=oKW8ODtvFOKBx9Un8n87xUUm3b5GStU4ViIBH5UDylM,27200
|
113
113
|
ultralytics/data/dataset.py,sha256=bVi1yTfQKJGKItMDTYzIE6MIEPpWqzXnUqra5AXmV18,35443
|
114
114
|
ultralytics/data/loaders.py,sha256=kTGO1P-HntpQk078i1ASyXYckDx9Z7Pe7o1YbePcjC4,31657
|
@@ -120,7 +120,7 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
|
|
120
120
|
ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
|
121
121
|
ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
|
122
122
|
ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
|
123
|
-
ultralytics/engine/exporter.py,sha256=
|
123
|
+
ultralytics/engine/exporter.py,sha256=6tBiT7xYg4U_C9OA-7vJQlowk08Jsb4ASZAFcJyTnCg,73077
|
124
124
|
ultralytics/engine/model.py,sha256=DwugtVxUbCGzpY2pStFMcEloim0ai6LrT6kTbwskSJ8,53302
|
125
125
|
ultralytics/engine/predictor.py,sha256=88zrgZP91ehwdeGl8BM_cQ_caeuwKIPDy3OzxcRBjTU,22474
|
126
126
|
ultralytics/engine/results.py,sha256=Mb8pBTOrBtQh0PQtGVbhRZ_C1VyqYFumjLggiKCRIJs,72295
|
@@ -168,11 +168,11 @@ ultralytics/models/yolo/__init__.py,sha256=or0j5xvcM0usMlsFTYhNAOcQUri7reD0cD9JR
|
|
168
168
|
ultralytics/models/yolo/model.py,sha256=C0wInQC6rFuFOGpdAen1s2e5LIFDmqevto8uPbpmB8c,18449
|
169
169
|
ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
|
170
170
|
ultralytics/models/yolo/classify/predict.py,sha256=_GiN6muuZOBrMS1KER85FE4ktcw_Onn1bZdGvpbsGCE,4618
|
171
|
-
ultralytics/models/yolo/classify/train.py,sha256=
|
171
|
+
ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
|
172
172
|
ultralytics/models/yolo/classify/val.py,sha256=YakPxBVZCd85Kp4wFKx8KH6JJFiU7nkFS3r9_ZSwFRM,10036
|
173
173
|
ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
|
174
174
|
ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
|
175
|
-
ultralytics/models/yolo/detect/train.py,sha256=
|
175
|
+
ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
|
176
176
|
ultralytics/models/yolo/detect/val.py,sha256=1w7sP4GQEIdSq_D26fTtqD4t8K_YlAu_GhCUM6uw4_0,19323
|
177
177
|
ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
|
178
178
|
ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
|
@@ -243,13 +243,13 @@ ultralytics/utils/checks.py,sha256=PPVmxfxoHuC4YR7i56uklCKXFAPnltzbHHCxUwERjUQ,3
|
|
243
243
|
ultralytics/utils/dist.py,sha256=A9lDGtGefTjSVvVS38w86GOdbtLzNBDZuDGK0MT4PRI,4170
|
244
244
|
ultralytics/utils/downloads.py,sha256=YB6rJkcRGQfklUjZqi9dOkTiZaDSqbkGyZEFcZLQkgc,22080
|
245
245
|
ultralytics/utils/errors.py,sha256=XT9Ru7ivoBgofK6PlnyigGoa7Fmf5nEhyHtnD-8TRXI,1584
|
246
|
-
ultralytics/utils/export.py,sha256=
|
246
|
+
ultralytics/utils/export.py,sha256=0gG_GZNRqHcORJbjQq_1MXEHc3UEfzPAdpOl2X5VoDc,10008
|
247
247
|
ultralytics/utils/files.py,sha256=ZCbLGleiF0f-PqYfaxMFAWop88w7U1hpreHXl8b2ko0,8238
|
248
248
|
ultralytics/utils/instance.py,sha256=vhqaZRGT_4K9Q3oQH5KNNK4ISOzxlf1_JjauwhuFhu0,18408
|
249
249
|
ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,39799
|
250
250
|
ultralytics/utils/metrics.py,sha256=1XaTT3n3tfLms6LOCiEzg_QGHQJzjZmfjFoAYsCCc24,62646
|
251
251
|
ultralytics/utils/ops.py,sha256=Jkh80ujyi0XDQwNqCUYyomH8NQ145AH9doMUS8Vt8GE,34545
|
252
|
-
ultralytics/utils/patches.py,sha256=
|
252
|
+
ultralytics/utils/patches.py,sha256=P2uQy7S4RzSHBfwJEXJsjyuRUluaaUusiVU84lV3moQ,6577
|
253
253
|
ultralytics/utils/plotting.py,sha256=OzanAqs7Z02ddAd1LiXce0Jjjo8DSjAjbKViE6S5CKg,47176
|
254
254
|
ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
|
255
255
|
ultralytics/utils/torch_utils.py,sha256=iIAjf2g4hikzBeHvKN-EQK8QFlC_QtWWRuYQuBF2zIk,39184
|
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
|
|
266
266
|
ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
|
267
267
|
ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
|
268
268
|
ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
|
269
|
-
dgenerate_ultralytics_headless-8.3.
|
270
|
-
dgenerate_ultralytics_headless-8.3.
|
271
|
-
dgenerate_ultralytics_headless-8.3.
|
272
|
-
dgenerate_ultralytics_headless-8.3.
|
273
|
-
dgenerate_ultralytics_headless-8.3.
|
269
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/METADATA,sha256=MPgCOrMsVi12cWvoXIS9DUYpa2RTJBpiYgNE6tA-m6w,38296
|
270
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
271
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
|
272
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
|
273
|
+
dgenerate_ultralytics_headless-8.3.156.dist-info/RECORD,,
|
ultralytics/__init__.py
CHANGED
ultralytics/data/build.py
CHANGED
@@ -154,7 +154,7 @@ def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, s
|
|
154
154
|
)
|
155
155
|
|
156
156
|
|
157
|
-
def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1):
|
157
|
+
def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
|
158
158
|
"""
|
159
159
|
Create and return an InfiniteDataLoader or DataLoader for training or validation.
|
160
160
|
|
@@ -164,6 +164,7 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, ra
|
|
164
164
|
workers (int): Number of worker threads for loading data.
|
165
165
|
shuffle (bool, optional): Whether to shuffle the dataset.
|
166
166
|
rank (int, optional): Process rank in distributed training. -1 for single-GPU training.
|
167
|
+
drop_last (bool, optional): Whether to drop the last incomplete batch.
|
167
168
|
|
168
169
|
Returns:
|
169
170
|
(InfiniteDataLoader): A dataloader that can be used for training or validation.
|
@@ -189,6 +190,7 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, ra
|
|
189
190
|
collate_fn=getattr(dataset, "collate_fn", None),
|
190
191
|
worker_init_fn=seed_worker,
|
191
192
|
generator=generator,
|
193
|
+
drop_last=drop_last,
|
192
194
|
)
|
193
195
|
|
194
196
|
|
ultralytics/engine/exporter.py
CHANGED
@@ -62,7 +62,6 @@ import shutil
|
|
62
62
|
import subprocess
|
63
63
|
import time
|
64
64
|
import warnings
|
65
|
-
from contextlib import contextmanager
|
66
65
|
from copy import deepcopy
|
67
66
|
from datetime import datetime
|
68
67
|
from pathlib import Path
|
@@ -107,6 +106,7 @@ from ultralytics.utils.downloads import attempt_download_asset, get_github_asset
|
|
107
106
|
from ultralytics.utils.export import export_engine, export_onnx
|
108
107
|
from ultralytics.utils.files import file_size, spaces_in_path
|
109
108
|
from ultralytics.utils.ops import Profile, nms_rotated
|
109
|
+
from ultralytics.utils.patches import arange_patch
|
110
110
|
from ultralytics.utils.torch_utils import TORCH_1_13, get_cpu_info, get_latest_opset, select_device
|
111
111
|
|
112
112
|
|
@@ -199,27 +199,6 @@ def try_export(inner_func):
|
|
199
199
|
return outer_func
|
200
200
|
|
201
201
|
|
202
|
-
@contextmanager
|
203
|
-
def arange_patch(args):
|
204
|
-
"""
|
205
|
-
Workaround for ONNX torch.arange incompatibility with FP16.
|
206
|
-
|
207
|
-
https://github.com/pytorch/pytorch/issues/148041.
|
208
|
-
"""
|
209
|
-
if args.dynamic and args.half and args.format == "onnx":
|
210
|
-
func = torch.arange
|
211
|
-
|
212
|
-
def arange(*args, dtype=None, **kwargs):
|
213
|
-
"""Return a 1-D tensor of size with values from the interval and common difference."""
|
214
|
-
return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype
|
215
|
-
|
216
|
-
torch.arange = arange # patch
|
217
|
-
yield
|
218
|
-
torch.arange = func # unpatch
|
219
|
-
else:
|
220
|
-
yield
|
221
|
-
|
222
|
-
|
223
202
|
class Exporter:
|
224
203
|
"""
|
225
204
|
A class for exporting YOLO models to various formats.
|
@@ -345,8 +324,6 @@ class Exporter:
|
|
345
324
|
LOGGER.warning("half=True only compatible with GPU export, i.e. use device=0")
|
346
325
|
self.args.half = False
|
347
326
|
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
348
|
-
if self.args.int8 and engine:
|
349
|
-
self.args.dynamic = True # enforce dynamic to export TensorRT INT8
|
350
327
|
if self.args.optimize:
|
351
328
|
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
|
352
329
|
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
|
@@ -555,8 +532,6 @@ class Exporter:
|
|
555
532
|
"""Build and return a dataloader for calibration of INT8 models."""
|
556
533
|
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
557
534
|
data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
|
558
|
-
# TensorRT INT8 calibration should use 2x batch size
|
559
|
-
batch = self.args.batch * (2 if self.args.format == "engine" else 1)
|
560
535
|
dataset = YOLODataset(
|
561
536
|
data[self.args.split or "val"],
|
562
537
|
data=data,
|
@@ -564,7 +539,7 @@ class Exporter:
|
|
564
539
|
task=self.model.task,
|
565
540
|
imgsz=self.imgsz[0],
|
566
541
|
augment=False,
|
567
|
-
batch_size=batch,
|
542
|
+
batch_size=self.args.batch,
|
568
543
|
)
|
569
544
|
n = len(dataset)
|
570
545
|
if n < self.args.batch:
|
@@ -574,7 +549,7 @@ class Exporter:
|
|
574
549
|
)
|
575
550
|
elif n < 300:
|
576
551
|
LOGGER.warning(f"{prefix} >300 images recommended for INT8 calibration, found {n} images.")
|
577
|
-
return build_dataloader(dataset, batch=batch, workers=0) # required for batch loading
|
552
|
+
return build_dataloader(dataset, batch=self.args.batch, workers=0, drop_last=True) # required for batch loading
|
578
553
|
|
579
554
|
@try_export
|
580
555
|
def export_torchscript(self, prefix=colorstr("TorchScript:")):
|
@@ -228,10 +228,9 @@ class ClassificationTrainer(BaseTrainer):
|
|
228
228
|
batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
|
229
229
|
ni (int): Number of iterations.
|
230
230
|
"""
|
231
|
+
batch["batch_idx"] = torch.arange(len(batch["img"])) # add batch index for plotting
|
231
232
|
plot_images(
|
232
|
-
|
233
|
-
batch_idx=torch.arange(len(batch["img"])),
|
234
|
-
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
233
|
+
labels=batch,
|
235
234
|
fname=self.save_dir / f"train_batch{ni}.jpg",
|
236
235
|
on_plot=self.on_plot,
|
237
236
|
)
|
@@ -13,6 +13,7 @@ from ultralytics.engine.trainer import BaseTrainer
|
|
13
13
|
from ultralytics.models import yolo
|
14
14
|
from ultralytics.nn.tasks import DetectionModel
|
15
15
|
from ultralytics.utils import LOGGER, RANK
|
16
|
+
from ultralytics.utils.patches import override_configs
|
16
17
|
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
17
18
|
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
18
19
|
|
@@ -210,6 +211,8 @@ class DetectionTrainer(BaseTrainer):
|
|
210
211
|
Returns:
|
211
212
|
(int): Optimal batch size.
|
212
213
|
"""
|
213
|
-
|
214
|
+
with override_configs(self.args, overrides={"cache": False}) as self.args:
|
215
|
+
train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
|
214
216
|
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
|
217
|
+
del train_dataset # free memory
|
215
218
|
return super().auto_batch(max_num_obj)
|
ultralytics/utils/export.py
CHANGED
@@ -143,11 +143,12 @@ def export_engine(
|
|
143
143
|
for inp in inputs:
|
144
144
|
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
|
145
145
|
config.add_optimization_profile(profile)
|
146
|
+
if int8:
|
147
|
+
config.set_calibration_profile(profile)
|
146
148
|
|
147
149
|
LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
|
148
150
|
if int8:
|
149
151
|
config.set_flag(trt.BuilderFlag.INT8)
|
150
|
-
config.set_calibration_profile(profile)
|
151
152
|
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
|
152
153
|
|
153
154
|
class EngineCalibrator(trt.IInt8Calibrator):
|
@@ -181,7 +182,11 @@ def export_engine(
|
|
181
182
|
trt.IInt8Calibrator.__init__(self)
|
182
183
|
self.dataset = dataset
|
183
184
|
self.data_iter = iter(dataset)
|
184
|
-
self.algo =
|
185
|
+
self.algo = (
|
186
|
+
trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
|
187
|
+
if dla is not None
|
188
|
+
else trt.CalibrationAlgoType.MINMAX_CALIBRATION
|
189
|
+
)
|
185
190
|
self.batch = dataset.batch_size
|
186
191
|
self.cache = Path(cache)
|
187
192
|
|
ultralytics/utils/patches.py
CHANGED
@@ -2,8 +2,10 @@
|
|
2
2
|
"""Monkey patches to update/extend functionality of existing functions."""
|
3
3
|
|
4
4
|
import time
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from copy import copy
|
5
7
|
from pathlib import Path
|
6
|
-
from typing import List, Optional
|
8
|
+
from typing import Any, Dict, List, Optional
|
7
9
|
|
8
10
|
import cv2
|
9
11
|
import numpy as np
|
@@ -139,3 +141,48 @@ def torch_save(*args, **kwargs):
|
|
139
141
|
if i == 3:
|
140
142
|
raise e
|
141
143
|
time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
|
144
|
+
|
145
|
+
|
146
|
+
@contextmanager
|
147
|
+
def arange_patch(args):
|
148
|
+
"""
|
149
|
+
Workaround for ONNX torch.arange incompatibility with FP16.
|
150
|
+
|
151
|
+
https://github.com/pytorch/pytorch/issues/148041.
|
152
|
+
"""
|
153
|
+
if args.dynamic and args.half and args.format == "onnx":
|
154
|
+
func = torch.arange
|
155
|
+
|
156
|
+
def arange(*args, dtype=None, **kwargs):
|
157
|
+
"""Return a 1-D tensor of size with values from the interval and common difference."""
|
158
|
+
return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype
|
159
|
+
|
160
|
+
torch.arange = arange # patch
|
161
|
+
yield
|
162
|
+
torch.arange = func # unpatch
|
163
|
+
else:
|
164
|
+
yield
|
165
|
+
|
166
|
+
|
167
|
+
@contextmanager
|
168
|
+
def override_configs(args, overrides: Optional[Dict[str, Any]] = None):
|
169
|
+
"""
|
170
|
+
Context manager to temporarily override configurations in args.
|
171
|
+
|
172
|
+
Args:
|
173
|
+
args (IterableSimpleNamespace): Original configuration arguments.
|
174
|
+
overrides (Dict[str, Any]): Dictionary of overrides to apply.
|
175
|
+
|
176
|
+
Yields:
|
177
|
+
(IterableSimpleNamespace): Configuration arguments with overrides applied.
|
178
|
+
"""
|
179
|
+
if overrides:
|
180
|
+
original_args = copy(args)
|
181
|
+
for key, value in overrides.items():
|
182
|
+
setattr(args, key, value)
|
183
|
+
try:
|
184
|
+
yield args
|
185
|
+
finally:
|
186
|
+
args.__dict__.update(original_args.__dict__)
|
187
|
+
else:
|
188
|
+
yield args
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|