dgenerate-ultralytics-headless 8.3.214-py3-none-any.whl → 8.3.248-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/export/imx.py
CHANGED
@@ -3,20 +3,51 @@
 from __future__ import annotations
 
 import subprocess
+import sys
 import types
 from pathlib import Path
+from shutil import which
 
+import numpy as np
 import torch
 
-from ultralytics.nn.modules import Detect, Pose
-from ultralytics.utils import LOGGER
+from ultralytics.nn.modules import Detect, Pose, Segment
+from ultralytics.utils import LOGGER, WINDOWS
+from ultralytics.utils.patches import onnx_export_patch
 from ultralytics.utils.tal import make_anchors
 from ultralytics.utils.torch_utils import copy_attr
 
+# Configuration for Model Compression Toolkit (MCT) quantization
+MCT_CONFIG = {
+    "YOLO11": {
+        "detect": {
+            "layer_names": ["sub", "mul_2", "add_14", "cat_21"],
+            "weights_memory": 2585350.2439,
+            "n_layers": 238,
+        },
+        "pose": {
+            "layer_names": ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"],
+            "weights_memory": 2437771.67,
+            "n_layers": 257,
+        },
+        "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 112},
+        "segment": {"layer_names": ["sub", "mul_2", "add_14", "cat_22"], "weights_memory": 2466604.8, "n_layers": 265},
+    },
+    "YOLOv8": {
+        "detect": {"layer_names": ["sub", "mul", "add_6", "cat_17"], "weights_memory": 2550540.8, "n_layers": 168},
+        "pose": {
+            "layer_names": ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"],
+            "weights_memory": 2482451.85,
+            "n_layers": 187,
+        },
+        "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 73},
+        "segment": {"layer_names": ["sub", "mul", "add_6", "cat_18"], "weights_memory": 2580060.0, "n_layers": 195},
+    },
+}
+
 
 class FXModel(torch.nn.Module):
-    """
-    A custom model class for torch.fx compatibility.
+    """A custom model class for torch.fx compatibility.
 
     This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
     manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
@@ -27,8 +58,7 @@ class FXModel(torch.nn.Module):
     """
 
     def __init__(self, model, imgsz=(640, 640)):
-        """
-        Initialize the FXModel.
+        """Initialize the FXModel.
 
         Args:
             model (nn.Module): The original model to wrap for torch.fx compatibility.
@@ -41,8 +71,7 @@ class FXModel(torch.nn.Module):
         self.imgsz = imgsz
 
     def forward(self, x):
-        """
-        Forward pass through the model.
+        """Forward pass through the model.
 
         This method performs the forward pass through the model, handling the dependencies between layers and saving
         intermediate outputs.
@@ -68,6 +97,8 @@ class FXModel(torch.nn.Module):
             )
             if type(m) is Pose:
                 m.forward = types.MethodType(pose_forward, m)  # bind method to Detect
+            if type(m) is Segment:
+                m.forward = types.MethodType(segment_forward, m)  # bind method to Detect
             x = m(x)  # run
             y.append(x)  # save output
         return x
@@ -87,11 +118,20 @@ def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
     kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
     x = Detect.forward(self, x)
     pred_kpt = self.kpts_decode(bs, kpt)
-    return
+    return *x, pred_kpt.permute(0, 2, 1)
+
+
+def segment_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Forward pass for imx segmentation."""
+    p = self.proto(x[0])  # mask protos
+    bs = p.shape[0]  # batch size
+    mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
+    x = Detect.forward(self, x)
+    return *x, mc.transpose(1, 2), p
 
 
 class NMSWrapper(torch.nn.Module):
-    """Wrap PyTorch Module with multiclass_nms layer from
+    """Wrap PyTorch Module with multiclass_nms layer from edge-mdt-cl."""
 
     def __init__(
         self,
@@ -101,8 +141,7 @@ class NMSWrapper(torch.nn.Module):
         max_detections: int = 300,
         task: str = "detect",
     ):
-        """
-        Initialize NMSWrapper with PyTorch Module and NMS parameters.
+        """Initialize NMSWrapper with PyTorch Module and NMS parameters.
 
         Args:
             model (torch.nn.Module): Model instance.
@@ -120,7 +159,7 @@ class NMSWrapper(torch.nn.Module):
 
     def forward(self, images):
         """Forward pass with model inference and NMS post-processing."""
-        from
+        from edgemdt_cl.pytorch.nms.nms_with_indices import multiclass_nms_with_indices
 
         # model inference
         outputs = self.model(images)
@@ -136,6 +175,10 @@ class NMSWrapper(torch.nn.Module):
             kpts = outputs[2]  # (bs, max_detections, kpts 17*3)
             out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
             return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
+        if self.task == "segment":
+            mc, proto = outputs[2], outputs[3]
+            out_mc = torch.gather(mc, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, mc.size(-1)))
+            return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_mc, proto
         return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
 
 
@@ -150,12 +193,10 @@ def torch2imx(
     dataset=None,
     prefix: str = "",
 ):
-    """
-    Export YOLO model to IMX format for deployment on Sony IMX500 devices.
+    """Export YOLO model to IMX format for deployment on Sony IMX500 devices.
 
-    This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
-
-    models for detection and pose estimation tasks.
+    This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it to IMX format compatible
+    with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n models for detection and pose estimation tasks.
 
     Args:
         model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
@@ -164,8 +205,8 @@ def torch2imx(
         iou (float): IoU threshold for NMS post-processing.
         max_det (int): Maximum number of detections to return.
        metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
-        gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
-
+        gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization. If False, uses standard Post
+            Training Quantization. Defaults to False.
        dataset (optional): Representative dataset for quantization calibration. Defaults to None.
         prefix (str, optional): Logging prefix string. Defaults to "".
 
@@ -175,13 +216,13 @@ def torch2imx(
     Raises:
         ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
 
-
+    Examples:
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
         >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
 
-
-        - Requires model_compression_toolkit, onnx, edgemdt_tpc, and
+    Notes:
+        - Requires model_compression_toolkit, onnx, edgemdt_tpc, and edge-mdt-cl packages
        - Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
         - Output includes quantized ONNX model, IMX binary, and labels.txt file
     """
@@ -197,33 +238,17 @@ def torch2imx(
             img = img / 255.0
             yield [img]
 
+    # NOTE: need tpc_version to be "4.0" for IMX500 Pose estimation models
     tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
 
     bit_cfg = mct.core.BitWidthConfig()
-    if "C2PSA" in model.__str__():
-        if model.task == "detect":
-            layer_names = ["sub", "mul_2", "add_14", "cat_21"]
-            weights_memory = 2585350.2439
-            n_layers = 238  # 238 layers for fused YOLO11n
-        elif model.task == "pose":
-            layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
-            weights_memory = 2437771.67
-            n_layers = 257  # 257 layers for fused YOLO11n-pose
-    else:  # YOLOv8
-        if model.task == "detect":
-            layer_names = ["sub", "mul", "add_6", "cat_17"]
-            weights_memory = 2550540.8
-            n_layers = 168  # 168 layers for fused YOLOv8n
-        elif model.task == "pose":
-            layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
-            weights_memory = 2482451.85
-            n_layers = 187  # 187 layers for fused YOLO11n-pose
+    mct_config = MCT_CONFIG["YOLO11" if "C2PSA" in model.__str__() else "YOLOv8"][model.task]
 
     # Check if the model has the expected number of layers
-    if len(list(model.modules())) != n_layers:
+    if len(list(model.modules())) != mct_config["n_layers"]:
         raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
 
-    for layer_name in layer_names:
+    for layer_name in mct_config["layer_names"]:
         bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
 
     config = mct.core.CoreConfig(
@@ -232,7 +257,7 @@ def torch2imx(
         bit_width_config=bit_cfg,
     )
 
-    resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
+    resource_utilization = mct.core.ResourceUtilization(weights_memory=mct_config["weights_memory"])
 
     quant_model = (
         mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
@@ -255,20 +280,23 @@ def torch2imx(
         )[0]
     )
 
-
-
-
-
-
-
-
+    if model.task != "classify":
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=conf or 0.001,
+            iou_threshold=iou,
+            max_detections=max_det,
+            task=model.task,
+        )
 
     f = Path(str(file).replace(file.suffix, "_imx_model"))
     f.mkdir(exist_ok=True)
     onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx"))  # js dir
-
-
-
+
+    with onnx_export_patch():
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
 
     model_onnx = onnx.load(onnx_model)  # load onnx model
     for k, v in metadata.items():
@@ -277,8 +305,16 @@ def torch2imx(
 
     onnx.save(model_onnx, onnx_model)
 
+    # Find imxconv-pt binary - check venv bin directory first, then PATH
+    bin_dir = Path(sys.executable).parent
+    imxconv = bin_dir / ("imxconv-pt.exe" if WINDOWS else "imxconv-pt")
+    if not imxconv.exists():
+        imxconv = which("imxconv-pt")  # fallback to PATH
+    if not imxconv:
+        raise FileNotFoundError("imxconv-pt not found. Install with: pip install imx500-converter[pt]")
+
     subprocess.run(
-        [
+        [str(imxconv), "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
         check=True,
     )
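The net effect is that IMX export now covers segmentation and classification alongside detection and pose, with the per-architecture MCT settings centralized in `MCT_CONFIG`. A minimal usage sketch, not part of the diff: `format="imx"` is the documented Ultralytics export hook, and picking a YOLO11n segmentation checkpoint is an assumption based on the new `segment` entries above.

```python
from ultralytics import YOLO

# Assumes the standard yolo11n-seg.pt checkpoint; segment support is new in this release
model = YOLO("yolo11n-seg.pt")

# Quantizes via MCT, embeds NMS via NMSWrapper, then invokes the imxconv-pt converter
model.export(format="imx")
```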
ultralytics/utils/export/tensorflow.py
ADDED
@@ -0,0 +1,231 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.nn.modules import Detect, Pose
+from ultralytics.utils import LOGGER
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.files import spaces_in_path
+from ultralytics.utils.tal import make_anchors
+
+
+def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
+    """A wrapper to add TensorFlow compatible inference methods to Detect and Pose layers."""
+    for m in model.modules():
+        if not isinstance(m, Detect):
+            continue
+        import types
+
+        m._inference = types.MethodType(_tf_inference, m)
+        if type(m) is Pose:
+            m.kpts_decode = types.MethodType(tf_kpts_decode, m)
+    return model
+
+
+def _tf_inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
+    """Decode boxes and cls scores for tf object detection."""
+    shape = x[0].shape  # BCHW
+    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
+    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+    if self.dynamic or self.shape != shape:
+        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+        self.shape = shape
+    grid_h, grid_w = shape[2], shape[3]
+    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
+    norm = self.strides / (self.stride[0] * grid_size)
+    dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+    return torch.cat((dbox, cls.sigmoid()), 1)
+
+
+def tf_kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
+    """Decode keypoints for tf pose estimation."""
+    ndim = self.kpt_shape[1]
+    # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
+    # Precompute normalization factor to increase numerical stability
+    y = kpts.view(bs, *self.kpt_shape, -1)
+    grid_h, grid_w = self.shape[2], self.shape[3]
+    grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
+    norm = self.strides / (self.stride[0] * grid_size)
+    a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
+    if ndim == 3:
+        a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
+    return a.view(bs, self.nk, -1)
+
+
+def onnx2saved_model(
+    onnx_file: str,
+    output_dir: Path,
+    int8: bool = False,
+    images: np.ndarray = None,
+    disable_group_convolution: bool = False,
+    prefix="",
+):
+    """Convert a ONNX model to TensorFlow SavedModel format via ONNX.
+
+    Args:
+        onnx_file (str): ONNX file path.
+        output_dir (Path): Output directory path for the SavedModel.
+        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
+        images (np.ndarray, optional): Calibration images for INT8 quantization in BHWC format.
+        disable_group_convolution (bool, optional): Disable group convolution optimization. Defaults to False.
+        prefix (str, optional): Logging prefix. Defaults to "".
+
+    Returns:
+        (keras.Model): Converted Keras model.
+
+    Notes:
+        - Requires onnx2tf package. Downloads calibration data if INT8 quantization is enabled.
+        - Removes temporary files and renames quantized models after conversion.
+    """
+    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
+    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
+    if not onnx2tf_file.exists():
+        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
+    np_data = None
+    if int8:
+        tmp_file = output_dir / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
+        if images is not None:
+            output_dir.mkdir(parents=True, exist_ok=True)
+            np.save(str(tmp_file), images)  # BHWC
+            np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
+
+    # Patch onnx.helper for onnx_graphsurgeon compatibility with ONNX>=1.17
+    # The float32_to_bfloat16 function was removed in ONNX 1.17, but onnx_graphsurgeon still uses it
+    import onnx.helper
+
+    if not hasattr(onnx.helper, "float32_to_bfloat16"):
+        import struct
+
+        def float32_to_bfloat16(fval):
+            """Convert float32 to bfloat16 (truncates lower 16 bits of mantissa)."""
+            ival = struct.unpack("=I", struct.pack("=f", fval))[0]
+            return ival >> 16
+
+        onnx.helper.float32_to_bfloat16 = float32_to_bfloat16
+
+    import onnx2tf  # scoped for after ONNX export for reduced conflict during import
+
+    LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
+    keras_model = onnx2tf.convert(
+        input_onnx_file_path=onnx_file,
+        output_folder_path=str(output_dir),
+        not_use_onnxsim=True,
+        verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
+        output_integer_quantized_tflite=int8,
+        custom_input_op_name_np_data_path=np_data,
+        enable_batchmatmul_unfold=True and not int8,  # fix lower no. of detected objects on GPU delegate
+        output_signaturedefs=True,  # fix error with Attention block group convolution
+        disable_group_convolution=disable_group_convolution,  # fix error with group convolution
+    )
+
+    # Remove/rename TFLite models
+    if int8:
+        tmp_file.unlink(missing_ok=True)
+        for file in output_dir.rglob("*_dynamic_range_quant.tflite"):
+            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+        for file in output_dir.rglob("*_integer_quant_with_int16_act.tflite"):
+            file.unlink()  # delete extra fp16 activation TFLite files
+    return keras_model
+
+
+def keras2pb(keras_model, file: Path, prefix=""):
+    """Convert a Keras model to TensorFlow GraphDef (.pb) format.
+
+    Args:
+        keras_model (keras.Model): Keras model to convert to frozen graph format.
+        file (Path): Output file path (suffix will be changed to .pb).
+        prefix (str, optional): Logging prefix. Defaults to "".
+
+    Notes:
+        Creates a frozen graph by converting variables to constants for inference optimization.
+    """
+    import tensorflow as tf
+    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+    m = tf.function(lambda x: keras_model(x))  # full model
+    m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+    frozen_func = convert_variables_to_constants_v2(m)
+    frozen_func.graph.as_graph_def()
+    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(file.parent), name=file.name, as_text=False)
+
+
+def tflite2edgetpu(tflite_file: str | Path, output_dir: str | Path, prefix: str = ""):
+    """Convert a TensorFlow Lite model to Edge TPU format using the Edge TPU compiler.
+
+    Args:
+        tflite_file (str | Path): Path to the input TensorFlow Lite (.tflite) model file.
+        output_dir (str | Path): Output directory path for the compiled Edge TPU model.
+        prefix (str, optional): Logging prefix. Defaults to "".
+
+    Notes:
+        Requires the Edge TPU compiler to be installed. The function compiles the TFLite model
+        for optimal performance on Google's Edge TPU hardware accelerator.
+    """
+    import subprocess
+
+    cmd = (
+        "edgetpu_compiler "
+        f'--out_dir "{output_dir}" '
+        "--show_operations "
+        "--search_delegate "
+        "--delegate_search_step 30 "
+        "--timeout_sec 180 "
+        f'"{tflite_file}"'
+    )
+    LOGGER.info(f"{prefix} running '{cmd}'")
+    subprocess.run(cmd, shell=True)
+
+
+def pb2tfjs(pb_file: str, output_dir: str, half: bool = False, int8: bool = False, prefix: str = ""):
+    """Convert a TensorFlow GraphDef (.pb) model to TensorFlow.js format.
+
+    Args:
+        pb_file (str): Path to the input TensorFlow GraphDef (.pb) model file.
+        output_dir (str): Output directory path for the converted TensorFlow.js model.
+        half (bool, optional): Enable FP16 quantization. Defaults to False.
+        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
+        prefix (str, optional): Logging prefix. Defaults to "".
+
+    Notes:
+        Requires tensorflowjs package. Uses tensorflowjs_converter command-line tool for conversion.
+        Handles spaces in file paths and warns if output directory contains spaces.
+    """
+    import subprocess
+
+    import tensorflow as tf
+    import tensorflowjs as tfjs
+
+    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
+
+    gd = tf.Graph().as_graph_def()  # TF GraphDef
+    with open(pb_file, "rb") as file:
+        gd.ParseFromString(file.read())
+    outputs = ",".join(gd_outputs(gd))
+    LOGGER.info(f"\n{prefix} output node names: {outputs}")
+
+    quantization = "--quantize_float16" if half else "--quantize_uint8" if int8 else ""
+    with spaces_in_path(pb_file) as fpb_, spaces_in_path(output_dir) as f_:  # exporter cannot handle spaces in paths
+        cmd = (
+            "tensorflowjs_converter "
+            f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
+        )
+        LOGGER.info(f"{prefix} running '{cmd}'")
+        subprocess.run(cmd, shell=True)
+
+    if " " in output_dir:
+        LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{output_dir}'.")
+
+
+def gd_outputs(gd):
+    """Return TensorFlow GraphDef model output node names."""
+    name_list, input_list = [], []
+    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
+        name_list.append(node.name)
+        input_list.extend(node.input)
+    return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
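This new module factors the TensorFlow conversion chain out of the exporter (note the matching `ultralytics/utils/export/__init__.py +4 -236` entry in the file list above). As a hedged sketch of how the helpers compose, using hypothetical file names with the signatures added here:

```python
from pathlib import Path

from ultralytics.utils.export.tensorflow import keras2pb, onnx2saved_model, pb2tfjs

onnx_file = "yolo11n.onnx"  # hypothetical: produced by a prior ONNX export step
saved_model_dir = Path("yolo11n_saved_model")

# ONNX -> TF SavedModel (pass int8=True plus BHWC calibration images for INT8 TFLite)
keras_model = onnx2saved_model(onnx_file, saved_model_dir, int8=False)

# Keras model -> frozen GraphDef (.pb) written next to the SavedModel
pb_file = saved_model_dir / "yolo11n.pb"
keras2pb(keras_model, pb_file)

# Frozen graph -> TensorFlow.js web model via the tensorflowjs_converter CLI
pb2tfjs(str(pb_file), "yolo11n_web_model")
```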
ultralytics/utils/files.py
CHANGED
@@ -13,11 +13,10 @@ from pathlib import Path
 
 
 class WorkingDirectory(contextlib.ContextDecorator):
-    """
-    A context manager and decorator for temporarily changing the working directory.
+    """A context manager and decorator for temporarily changing the working directory.
 
-    This class allows for the temporary change of the working directory using a context manager or decorator.
-
+    This class allows for the temporary change of the working directory using a context manager or decorator. It ensures
+    that the original working directory is restored after the context or decorated function completes.
 
     Attributes:
         dir (Path | str): The new directory to switch to.
@@ -29,15 +28,15 @@ class WorkingDirectory(contextlib.ContextDecorator):
 
     Examples:
         Using as a context manager:
-        >>> with WorkingDirectory(
-
-
+        >>> with WorkingDirectory("/path/to/new/dir"):
+        ...     # Perform operations in the new directory
+        ...     pass
 
         Using as a decorator:
-        >>> @WorkingDirectory(
-
-
-
+        >>> @WorkingDirectory("/path/to/new/dir")
+        ... def some_function():
+        ...     # Perform operations in the new directory
+        ...     pass
     """
 
     def __init__(self, new_dir: str | Path):
@@ -49,15 +48,14 @@ class WorkingDirectory(contextlib.ContextDecorator):
         """Change the current working directory to the specified directory upon entering the context."""
         os.chdir(self.dir)
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb):
         """Restore the original working directory when exiting the context."""
         os.chdir(self.cwd)
 
 
 @contextmanager
 def spaces_in_path(path: str | Path):
-    """
-    Context manager to handle paths with spaces in their names.
+    """Context manager to handle paths with spaces in their names.
 
     If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
     the context code block, then copies the file/directory back to its original location.
@@ -66,13 +64,12 @@ def spaces_in_path(path: str | Path):
         path (str | Path): The original path that may contain spaces.
 
     Yields:
-        (Path | str): Temporary path with spaces replaced by underscores
-            original path.
+        (Path | str): Temporary path with any spaces replaced by underscores.
 
     Examples:
-        >>> with spaces_in_path(
-
-
+        >>> with spaces_in_path("/path/with spaces") as new_path:
+        ...     # Your code here
+        ...     pass
     """
     # If path has spaces, replace them with underscores
     if " " in str(path):
@@ -107,12 +104,11 @@ def spaces_in_path(path: str | Path):
 
 
 def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
-    """
-    Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+    """Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
 
-    If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
-
-
+    If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to the
+    end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the number
+    will be appended directly to the end of the path.
 
     Args:
         path (str | Path): Path to increment.
@@ -185,8 +181,7 @@ def get_latest_run(search_dir: str = ".") -> str:
 
 
 def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
-    """
-    Update and re-save specified YOLO models in an 'updated_models' subdirectory.
+    """Update and re-save specified YOLO models in an 'updated_models' subdirectory.
 
     Args:
         model_names (tuple, optional): Model filenames to update.
@@ -201,13 +196,14 @@ def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path(
     """
     from ultralytics import YOLO
     from ultralytics.nn.autobackend import default_class_names
+    from ultralytics.utils import LOGGER
 
     target_dir = source_dir / "updated_models"
     target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists
 
     for model_name in model_names:
         model_path = source_dir / model_name
-
+        LOGGER.info(f"Loading model from {model_path}")
 
         # Load model
         model = YOLO(model_path)
@@ -219,5 +215,5 @@ def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path(
         save_path = target_dir / model_name
 
         # Save model using model.save()
-
+        LOGGER.info(f"Re-saving {model_name} model to {save_path}")
         model.save(save_path)