dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
@@ -20,6 +20,7 @@ from ultralytics.nn.modules import (
 C3TR,
 ELAN1,
 OBB,
+OBB26,
 PSA,
 SPP,
 SPPELAN,
@@ -55,6 +56,7 @@ from ultralytics.nn.modules import (
 Index,
 LRPCHead,
 Pose,
+Pose26,
 RepC3,
 RepConv,
 RepNCSPELAN4,
@@ -63,16 +65,19 @@ from ultralytics.nn.modules import (
 RTDETRDecoder,
 SCDown,
 Segment,
+Segment26,
 TorchVision,
 WorldDetect,
 YOLOEDetect,
 YOLOESegment,
+YOLOESegment26,
 v10Detect,
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.utils.loss import (
-
+E2ELoss,
+PoseLoss26,
 v8ClassificationLoss,
 v8DetectionLoss,
 v8OBBLoss,
@@ -95,11 +100,10 @@ from ultralytics.utils.torch_utils import (


 class BaseModel(torch.nn.Module):
-"""
-Base class for all YOLO models in the Ultralytics family.
+"""Base class for all YOLO models in the Ultralytics family.

-This class provides common functionality for YOLO models including forward pass handling, model fusion,
-
+This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+display, and weight loading capabilities.

 Attributes:
 model (torch.nn.Module): The neural network model.
@@ -121,8 +125,7 @@ class BaseModel(torch.nn.Module):
 """

 def forward(self, x, *args, **kwargs):
-"""
-Perform forward pass of the model for either training or inference.
+"""Perform forward pass of the model for either training or inference.

 If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

@@ -139,8 +142,7 @@ class BaseModel(torch.nn.Module):
 return self.predict(x, *args, **kwargs)

 def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
-"""
-Perform a forward pass through the network.
+"""Perform a forward pass through the network.

 Args:
 x (torch.Tensor): The input tensor to the model.
@@ -157,8 +159,7 @@ class BaseModel(torch.nn.Module):
 return self._predict_once(x, profile, visualize, embed)

 def _predict_once(self, x, profile=False, visualize=False, embed=None):
-"""
-Perform a forward pass through the network.
+"""Perform a forward pass through the network.

 Args:
 x (torch.Tensor): The input tensor to the model.
@@ -196,8 +197,7 @@ class BaseModel(torch.nn.Module):
 return self._predict_once(x)

 def _profile_one_layer(self, m, x, dt):
-"""
-Profile the computation time and FLOPs of a single layer of the model on a given input.
+"""Profile the computation time and FLOPs of a single layer of the model on a given input.

 Args:
 m (torch.nn.Module): The layer to be profiled.
@@ -222,8 +222,7 @@ class BaseModel(torch.nn.Module):
 LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

 def fuse(self, verbose=True):
-"""
-Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+"""Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
 efficiency.

 Returns:
@@ -247,15 +246,14 @@ class BaseModel(torch.nn.Module):
 if isinstance(m, RepVGGDW):
 m.fuse()
 m.forward = m.forward_fuse
-if isinstance(m,
+if isinstance(m, Detect) and getattr(m, "end2end", False):
 m.fuse() # remove one2many head
 self.info(verbose=verbose)

 return self

 def is_fused(self, thresh=10):
-"""
-Check if the model has less than a certain threshold of BatchNorm layers.
+"""Check if the model has less than a certain threshold of BatchNorm layers.

 Args:
 thresh (int, optional): The threshold number of BatchNorm layers.
@@ -267,8 +265,7 @@ class BaseModel(torch.nn.Module):
 return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model

 def info(self, detailed=False, verbose=True, imgsz=640):
-"""
-Print model information.
+"""Print model information.

 Args:
 detailed (bool): If True, prints out detailed information about the model.
@@ -278,8 +275,7 @@ class BaseModel(torch.nn.Module):
 return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

 def _apply(self, fn):
-"""
-Apply a function to all tensors in the model that are not parameters or registered buffers.
+"""Apply a function to all tensors in the model that are not parameters or registered buffers.

 Args:
 fn (function): The function to apply to the model.
@@ -298,8 +294,7 @@ class BaseModel(torch.nn.Module):
 return self

 def load(self, weights, verbose=True):
-"""
-Load weights into the model.
+"""Load weights into the model.

 Args:
 weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -324,8 +319,7 @@ class BaseModel(torch.nn.Module):
 LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")

 def loss(self, batch, preds=None):
-"""
-Compute loss.
+"""Compute loss.

 Args:
 batch (dict): Batch to compute loss on.
@@ -344,11 +338,10 @@ class BaseModel(torch.nn.Module):


 class DetectionModel(BaseModel):
-"""
-YOLO detection model.
+"""YOLO detection model.

-This class implements the YOLO detection architecture, handling model initialization, forward pass,
-
+This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+inference, and loss computation for object detection tasks.

 Attributes:
 yaml (dict): Model configuration dictionary.
@@ -368,13 +361,12 @@ class DetectionModel(BaseModel):

 Examples:
 Initialize a detection model
->>> model = DetectionModel("
+>>> model = DetectionModel("yolo26n.yaml", ch=3, nc=80)
 >>> results = model.predict(image_tensor)
 """

-def __init__(self, cfg="
-"""
-Initialize the YOLO detection model with the given config and parameters.
+def __init__(self, cfg="yolo26n.yaml", ch=3, nc=None, verbose=True):
+"""Initialize the YOLO detection model with the given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -399,7 +391,6 @@ class DetectionModel(BaseModel):
 self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
 self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
 self.inplace = self.yaml.get("inplace", True)
-self.end2end = getattr(self.model[-1], "end2end", False)

 # Build strides
 m = self.model[-1] # Detect()
@@ -409,9 +400,10 @@ class DetectionModel(BaseModel):

 def _forward(x):
 """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
+output = self.forward(x)
 if self.end2end:
-
-return
+output = output["one2many"]
+return output["feats"]

 self.model.eval() # Avoid changing batch statistics until training begins
 m.training = True # Setting it to True to properly return strides
@@ -420,7 +412,7 @@ class DetectionModel(BaseModel):
 self.model.train() # Set model back to training(default) mode
 m.bias_init() # only run once
 else:
-self.stride = torch.Tensor([32]) # default stride
+self.stride = torch.Tensor([32]) # default stride, e.g., RTDETR

 # Init weights, biases
 initialize_weights(self)
@@ -428,9 +420,13 @@ class DetectionModel(BaseModel):
 self.info()
 LOGGER.info("")

+@property
+def end2end(self):
+"""Return whether the model uses end-to-end NMS-free detection."""
+return getattr(self.model[-1], "end2end", False)
+
 def _predict_augment(self, x):
-"""
-Perform augmentations on input image x and return augmented inference and train outputs.
+"""Perform augmentations on input image x and return augmented inference and train outputs.

 Args:
 x (torch.Tensor): Input image tensor.
@@ -455,8 +451,7 @@ class DetectionModel(BaseModel):

 @staticmethod
 def _descale_pred(p, flips, scale, img_size, dim=1):
-"""
-De-scale predictions following augmented inference (inverse operation).
+"""De-scale predictions following augmented inference (inverse operation).

 Args:
 p (torch.Tensor): Predictions tensor.
@@ -477,8 +472,7 @@ class DetectionModel(BaseModel):
 return torch.cat((x, y, wh, cls), dim)

 def _clip_augmented(self, y):
-"""
-Clip YOLO augmented inference tails.
+"""Clip YOLO augmented inference tails.

 Args:
 y (list[torch.Tensor]): List of detection tensors.
@@ -497,15 +491,14 @@ class DetectionModel(BaseModel):

 def init_criterion(self):
 """Initialize the loss criterion for the DetectionModel."""
-return
+return E2ELoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)


 class OBBModel(DetectionModel):
-"""
-YOLO Oriented Bounding Box (OBB) model.
+"""YOLO Oriented Bounding Box (OBB) model.

-This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized
-
+This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+computation for rotated object detection.

 Methods:
 __init__: Initialize YOLO OBB model.
@@ -513,13 +506,12 @@ class OBBModel(DetectionModel):

 Examples:
 Initialize an OBB model
->>> model = OBBModel("
+>>> model = OBBModel("yolo26n-obb.yaml", ch=3, nc=80)
 >>> results = model.predict(image_tensor)
 """

-def __init__(self, cfg="
-"""
-Initialize YOLO OBB model with given config and parameters.
+def __init__(self, cfg="yolo26n-obb.yaml", ch=3, nc=None, verbose=True):
+"""Initialize YOLO OBB model with given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -531,15 +523,14 @@ class OBBModel(DetectionModel):

 def init_criterion(self):
 """Initialize the loss criterion for the model."""
-return v8OBBLoss(self)
+return E2ELoss(self, v8OBBLoss) if getattr(self, "end2end", False) else v8OBBLoss(self)


 class SegmentationModel(DetectionModel):
-"""
-YOLO segmentation model.
+"""YOLO segmentation model.

-This class extends DetectionModel to handle instance segmentation tasks, providing specialized
-
+This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+pixel-level object detection and segmentation.

 Methods:
 __init__: Initialize YOLO segmentation model.
@@ -547,13 +538,12 @@ class SegmentationModel(DetectionModel):

 Examples:
 Initialize a segmentation model
->>> model = SegmentationModel("
+>>> model = SegmentationModel("yolo26n-seg.yaml", ch=3, nc=80)
 >>> results = model.predict(image_tensor)
 """

-def __init__(self, cfg="
-"""
-Initialize Ultralytics YOLO segmentation model with given config and parameters.
+def __init__(self, cfg="yolo26n-seg.yaml", ch=3, nc=None, verbose=True):
+"""Initialize Ultralytics YOLO segmentation model with given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -565,15 +555,14 @@ class SegmentationModel(DetectionModel):

 def init_criterion(self):
 """Initialize the loss criterion for the SegmentationModel."""
-return v8SegmentationLoss(self)
+return E2ELoss(self, v8SegmentationLoss) if getattr(self, "end2end", False) else v8SegmentationLoss(self)


 class PoseModel(DetectionModel):
-"""
-YOLO pose model.
+"""YOLO pose model.

-This class extends DetectionModel to handle human pose estimation tasks, providing specialized
-
+This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+keypoint detection and pose estimation.

 Attributes:
 kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
@@ -584,13 +573,12 @@ class PoseModel(DetectionModel):

 Examples:
 Initialize a pose model
->>> model = PoseModel("
+>>> model = PoseModel("yolo26n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
 >>> results = model.predict(image_tensor)
 """

-def __init__(self, cfg="
-"""
-Initialize Ultralytics YOLO Pose model.
+def __init__(self, cfg="yolo26n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
+"""Initialize Ultralytics YOLO Pose model.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -608,15 +596,14 @@ class PoseModel(DetectionModel):

 def init_criterion(self):
 """Initialize the loss criterion for the PoseModel."""
-return v8PoseLoss(self)
+return E2ELoss(self, PoseLoss26) if getattr(self, "end2end", False) else v8PoseLoss(self)


 class ClassificationModel(BaseModel):
-"""
-YOLO classification model.
+"""YOLO classification model.

-This class implements the YOLO classification architecture for image classification tasks,
-
+This class implements the YOLO classification architecture for image classification tasks, providing model
+initialization, configuration, and output reshaping capabilities.

 Attributes:
 yaml (dict): Model configuration dictionary.
@@ -632,13 +619,12 @@ class ClassificationModel(BaseModel):

 Examples:
 Initialize a classification model
->>> model = ClassificationModel("
+>>> model = ClassificationModel("yolo26n-cls.yaml", ch=3, nc=1000)
 >>> results = model.predict(image_tensor)
 """

-def __init__(self, cfg="
-"""
-Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+def __init__(self, cfg="yolo26n-cls.yaml", ch=3, nc=None, verbose=True):
+"""Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -650,8 +636,7 @@ class ClassificationModel(BaseModel):
 self._from_yaml(cfg, ch, nc, verbose)

 def _from_yaml(self, cfg, ch, nc, verbose):
-"""
-Set Ultralytics YOLO model configurations and define the model architecture.
+"""Set Ultralytics YOLO model configurations and define the model architecture.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -675,8 +660,7 @@ class ClassificationModel(BaseModel):

 @staticmethod
 def reshape_outputs(model, nc):
-"""
-Update a TorchVision classification model to class count 'n' if required.
+"""Update a TorchVision classification model to class count 'n' if required.

 Args:
 model (torch.nn.Module): Model to update.
@@ -708,8 +692,7 @@ class ClassificationModel(BaseModel):


 class RTDETRDetectionModel(DetectionModel):
-"""
-RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+"""RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

 This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
 the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@@ -732,8 +715,7 @@ class RTDETRDetectionModel(DetectionModel):
 """

 def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
-"""
-Initialize the RTDETRDetectionModel.
+"""Initialize the RTDETRDetectionModel.

 Args:
 cfg (str | dict): Configuration file name or path.
@@ -744,8 +726,7 @@ class RTDETRDetectionModel(DetectionModel):
 super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

 def _apply(self, fn):
-"""
-Apply a function to all tensors in the model that are not parameters or registered buffers.
+"""Apply a function to all tensors in the model that are not parameters or registered buffers.

 Args:
 fn (function): The function to apply to the model.
@@ -766,8 +747,7 @@ class RTDETRDetectionModel(DetectionModel):
 return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

 def loss(self, batch, preds=None):
-"""
-Compute the loss for the given batch of data.
+"""Compute the loss for the given batch of data.

 Args:
 batch (dict): Dictionary containing image and label data.
@@ -813,8 +793,7 @@ class RTDETRDetectionModel(DetectionModel):
 )

 def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
-"""
-Perform a forward pass through the model.
+"""Perform a forward pass through the model.

 Args:
 x (torch.Tensor): The input tensor.
@@ -849,11 +828,10 @@ class RTDETRDetectionModel(DetectionModel):


 class WorldModel(DetectionModel):
-"""
-YOLOv8 World Model.
+"""YOLOv8 World Model.

-This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based
-
+This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+specification and CLIP model integration for zero-shot detection capabilities.

 Attributes:
 txt_feats (torch.Tensor): Text feature embeddings for classes.
@@ -874,8 +852,7 @@ class WorldModel(DetectionModel):
 """

 def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-"""
-Initialize YOLOv8 world model with given config and parameters.
+"""Initialize YOLOv8 world model with given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -888,8 +865,7 @@ class WorldModel(DetectionModel):
 super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

 def set_classes(self, text, batch=80, cache_clip_model=True):
-"""
-Set classes in advance so that model could do offline-inference without clip model.
+"""Set classes in advance so that model could do offline-inference without clip model.

 Args:
 text (list[str]): List of class names.
@@ -900,8 +876,7 @@ class WorldModel(DetectionModel):
 self.model[-1].nc = len(text)

 def get_text_pe(self, text, batch=80, cache_clip_model=True):
-"""
-Set classes in advance so that model could do offline-inference without clip model.
+"""Get text positional embeddings for offline inference without CLIP model.

 Args:
 text (list[str]): List of class names.
@@ -924,8 +899,7 @@ class WorldModel(DetectionModel):
 return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

 def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
-"""
-Perform a forward pass through the model.
+"""Perform a forward pass through the model.

 Args:
 x (torch.Tensor): The input tensor.
@@ -969,8 +943,7 @@ class WorldModel(DetectionModel):
 return x

 def loss(self, batch, preds=None):
-"""
-Compute loss.
+"""Compute loss.

 Args:
 batch (dict): Batch to compute loss on.
@@ -985,11 +958,10 @@ class WorldModel(DetectionModel):


 class YOLOEModel(DetectionModel):
-"""
-YOLOE detection model.
+"""YOLOE detection model.

-This class implements the YOLOE architecture for efficient object detection with text and visual prompts,
-
+This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+both prompt-based and prompt-free inference modes.

 Attributes:
 pe (torch.Tensor): Prompt embeddings for classes.
@@ -1013,8 +985,7 @@ class YOLOEModel(DetectionModel):
 """

 def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
-"""
-Initialize YOLOE model with given config and parameters.
+"""Initialize YOLOE model with given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -1023,17 +994,17 @@ class YOLOEModel(DetectionModel):
 verbose (bool): Whether to display model information.
 """
 super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
+self.text_model = self.yaml.get("text_model", "mobileclip:blt")

 @smart_inference_mode()
 def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
-"""
-Set classes in advance so that model could do offline-inference without clip model.
+"""Get text positional embeddings for offline inference without CLIP model.

 Args:
 text (list[str]): List of class names.
 batch (int): Batch size for processing text tokens.
 cache_clip_model (bool): Whether to cache the CLIP model.
-without_reprta (bool): Whether to return text embeddings
+without_reprta (bool): Whether to return text embeddings without reprta module processing.

 Returns:
 (torch.Tensor): Text positional embeddings.
@@ -1043,9 +1014,13 @@ class YOLOEModel(DetectionModel):
 device = next(self.model.parameters()).device
 if not getattr(self, "clip_model", None) and cache_clip_model:
 # For backwards compatibility of models lacking clip_model attribute
-self.clip_model = build_text_model("mobileclip:blt", device=device)
+self.clip_model = build_text_model(getattr(self, "text_model", "mobileclip:blt"), device=device)

-model =
+model = (
+self.clip_model
+if cache_clip_model
+else build_text_model(getattr(self, "text_model", "mobileclip:blt"), device=device)
+)
 text_token = model.tokenize(text)
 txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
 txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
@@ -1059,8 +1034,7 @@ class YOLOEModel(DetectionModel):

 @smart_inference_mode()
 def get_visual_pe(self, img, visual):
-"""
-Get visual embeddings.
+"""Get visual embeddings.

 Args:
 img (torch.Tensor): Input image tensor.
@@ -1072,8 +1046,7 @@ class YOLOEModel(DetectionModel):
 return self(img, vpe=visual, return_vpe=True)

 def set_vocab(self, vocab, names):
-"""
-Set vocabulary for the prompt-free model.
+"""Set vocabulary for the prompt-free model.

 Args:
 vocab (nn.ModuleList): List of vocabulary items.
@@ -1087,10 +1060,12 @@ class YOLOEModel(DetectionModel):
 device = next(self.parameters()).device
 self(torch.empty(1, 3, self.args["imgsz"], self.args["imgsz"]).to(device)) # warmup

+cv3 = getattr(head, "one2one_cv3", head.cv3)
+cv2 = getattr(head, "one2one_cv2", head.cv2)
+
 # re-parameterization for prompt-free model
 self.model[-1].lrpc = nn.ModuleList(
-LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)
-for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))
+LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2) for i, (cls, pf, loc) in enumerate(zip(vocab, cv3, cv2))
 )
 for loc_head, cls_head in zip(head.cv2, head.cv3):
 assert isinstance(loc_head, nn.Sequential)
@@ -1101,8 +1076,7 @@ class YOLOEModel(DetectionModel):
 self.names = check_class_names(names)

 def get_vocab(self, names):
-"""
-Get fused vocabulary layer from the model.
+"""Get fused vocabulary layer from the model.

 Args:
 names (list): List of class names.
@@ -1120,15 +1094,15 @@ class YOLOEModel(DetectionModel):
 device = next(self.model.parameters()).device
 head.fuse(self.pe.to(device)) # fuse prompt embeddings to classify head

+cv3 = getattr(head, "one2one_cv3", head.cv3)
 vocab = nn.ModuleList()
-for cls_head in
+for cls_head in cv3:
 assert isinstance(cls_head, nn.Sequential)
 vocab.append(cls_head[-1])
 return vocab

 def set_classes(self, names, embeddings):
-"""
-Set classes in advance so that model could do offline-inference without clip model.
+"""Set classes in advance so that model could do offline-inference without clip model.

 Args:
 names (list[str]): List of class names.
@@ -1143,8 +1117,7 @@ class YOLOEModel(DetectionModel):
 self.names = check_class_names(names)

 def get_cls_pe(self, tpe, vpe):
-"""
-Get class positional embeddings.
+"""Get class positional embeddings.

 Args:
 tpe (torch.Tensor, optional): Text positional embeddings.
@@ -1167,8 +1140,7 @@ class YOLOEModel(DetectionModel):
 def predict(
 self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
 ):
-"""
-Perform a forward pass through the model.
+"""Perform a forward pass through the model.

 Args:
 x (torch.Tensor): The input tensor.
@@ -1201,9 +1173,8 @@ class YOLOEModel(DetectionModel):
 cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
 if cls_pe.shape[0] != b or m.export:
 cls_pe = cls_pe.expand(b, -1, -1)
-x
-
-x = m(x) # run
+x.append(cls_pe) # adding cls embedding
+x = m(x) # run

 y.append(x if m.i in self.save else None) # save output
 if visualize:
@@ -1215,8 +1186,7 @@ class YOLOEModel(DetectionModel):
 return x

 def loss(self, batch, preds=None):
-"""
-Compute loss.
+"""Compute loss.

 Args:
 batch (dict): Batch to compute loss on.
@@ -1226,19 +1196,25 @@ class YOLOEModel(DetectionModel):
 from ultralytics.utils.loss import TVPDetectLoss

 visual_prompt = batch.get("visuals", None) is not None # TODO
-self.criterion =
-
+self.criterion = (
+(E2ELoss(self, TVPDetectLoss) if getattr(self, "end2end", False) else TVPDetectLoss(self))
+if visual_prompt
+else self.init_criterion()
+)
 if preds is None:
-preds = self.forward(
+preds = self.forward(
+batch["img"],
+tpe=None if "visuals" in batch else batch.get("txt_feats", None),
+vpe=batch.get("visuals", None),
+)
 return self.criterion(preds, batch)


 class YOLOESegModel(YOLOEModel, SegmentationModel):
-"""
-YOLOE segmentation model.
+"""YOLOE segmentation model.

-This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,
-
+This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+specialized loss computation for pixel-level object detection and segmentation.

 Methods:
 __init__: Initialize YOLOE segmentation model.
@@ -1251,8 +1227,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 """

 def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
-"""
-Initialize YOLOE segmentation model with given config and parameters.
+"""Initialize YOLOE segmentation model with given config and parameters.

 Args:
 cfg (str | dict): Model configuration file path or dictionary.
@@ -1263,8 +1238,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

 def loss(self, batch, preds=None):
-"""
-Compute loss.
+"""Compute loss.

 Args:
 batch (dict): Batch to compute loss on.
@@ -1274,7 +1248,11 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 from ultralytics.utils.loss import TVPSegmentLoss

 visual_prompt = batch.get("visuals", None) is not None # TODO
-self.criterion =
+self.criterion = (
+(E2ELoss(self, TVPSegmentLoss) if getattr(self, "end2end", False) else TVPSegmentLoss(self))
+if visual_prompt
+else self.init_criterion()
+)

 if preds is None:
 preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
@@ -1282,11 +1260,10 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):


 class Ensemble(torch.nn.ModuleList):
-"""
-Ensemble of models.
+"""Ensemble of models.

-This class allows combining multiple YOLO models into an ensemble for improved performance through
-
+This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+or other ensemble techniques.

 Methods:
 __init__: Initialize an ensemble of models.
@@ -1305,8 +1282,7 @@ class Ensemble(torch.nn.ModuleList):
 super().__init__()

 def forward(self, x, augment=False, profile=False, visualize=False):
-"""
-Generate the YOLO network's final layer.
+"""Generate the YOLO network's final layer.

 Args:
 x (torch.Tensor): Input tensor.
@@ -1321,7 +1297,7 @@ class Ensemble(torch.nn.ModuleList):
 y = [module(x, augment, profile, visualize)[0] for module in self]
 # y = torch.stack(y).max(0)[0] # max ensemble
 # y = torch.stack(y).mean(0) # mean ensemble
-y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
+y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C*num_models)
 return y, None # inference, train output


@@ -1330,12 +1306,11 @@ class Ensemble(torch.nn.ModuleList):

 @contextlib.contextmanager
 def temporary_modules(modules=None, attributes=None):
-"""
-Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+"""Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

-This function can be used to change the module paths during runtime. It's useful when refactoring code,
-
-
+This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+moved a module from one location to another, but you still want to support the old import paths for backwards
+compatibility.

 Args:
 modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1346,7 +1321,7 @@ def temporary_modules(modules=None, attributes=None):
 >>> import old.module # this will now import new.module
 >>> from old.module import attribute # this will now import new.module.attribute

-
+Notes:
 The changes are only in effect inside the context manager and are undone once the context manager exits.
 Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
 applications or libraries. Use this function with caution.
@@ -1393,8 +1368,7 @@ class SafeUnpickler(pickle.Unpickler):
 """Custom Unpickler that replaces unknown classes with SafeClass."""

 def find_class(self, module, name):
-"""
-Attempt to find a class, returning SafeClass if not among safe modules.
+"""Attempt to find a class, returning SafeClass if not among safe modules.

 Args:
 module (str): Module name.
@@ -1419,10 +1393,9 @@ class SafeUnpickler(pickle.Unpickler):


 def torch_safe_load(weight, safe_only=False):
-"""
-
-
-After installation, the function again attempts to load the model using torch.load().
+"""Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+function. After installation, the function again attempts to load the model using torch.load().

 Args:
 weight (str): The file path of the PyTorch model.
@@ -1471,7 +1444,7 @@ def torch_safe_load(weight, safe_only=False):
 f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
 f"YOLOv8 at https://github.com/ultralytics/ultralytics."
 f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-f"run a command with an official Ultralytics model, i.e. 'yolo predict model=
+f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo26n.pt'"
 )
 ) from e
 elif e.name == "numpy._core":
@@ -1484,7 +1457,7 @@ def torch_safe_load(weight, safe_only=False):
 f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
 f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
 f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-f"run a command with an official Ultralytics model, i.e. 'yolo predict model=
+f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo26n.pt'"
 )
 check_requirements(e.name) # install missing module
 ckpt = torch_load(file, map_location="cpu")
@@ -1501,8 +1474,7 @@ def torch_safe_load(weight, safe_only=False):


 def load_checkpoint(weight, device=None, inplace=True, fuse=False):
-"""
-Load a single model weights.
+"""Load a single model weights.

 Args:
 weight (str | Path): Model weight path.
@@ -1539,8 +1511,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):


 def parse_model(d, ch, verbose=True):
-"""
-Parse a YOLO model.yaml dictionary into a PyTorch model.
+"""Parse a YOLO model.yaml dictionary into a PyTorch model.

 Args:
 d (dict): Model dictionary.
@@ -1556,12 +1527,13 @@ def parse_model(d, ch, verbose=True):
 # Args
 legacy = True # backward compatibility for v3/v5/v8/v9 models
 max_channels = float("inf")
-nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
+nc, act, scales, end2end = (d.get(x) for x in ("nc", "activation", "scales", "end2end"))
+reg_max = d.get("reg_max", 16)
 depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
 scale = d.get("scale")
 if scales:
 if not scale:
-scale =
+scale = next(iter(scales.keys()))
 LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
 depth, width, max_channels = scales[scale]

@@ -1646,7 +1618,7 @@ def parse_model(d, ch, verbose=True):
 n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
 if m in base_modules:
 c1, c2 = ch[f], args[0]
-if c2 != nc: # if c2
+if c2 != nc: # if c2 != nc (e.g., Classify() output)
 c2 = make_divisible(min(c2, max_channels) * width, 8)
 if m is C2fAttn: # set 1) embed channels and 2) num heads
 args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
@@ -1681,13 +1653,29 @@ def parse_model(d, ch, verbose=True):
 elif m is Concat:
 c2 = sum(ch[x] for x in f)
 elif m in frozenset(
-{
+{
+Detect,
+WorldDetect,
+YOLOEDetect,
+Segment,
+Segment26,
+YOLOESegment,
+YOLOESegment26,
+Pose,
+Pose26,
+OBB,
+OBB26,
+}
 ):
-args.
-if m is Segment or m is YOLOESegment:
+args.extend([reg_max, end2end, [ch[x] for x in f]])
+if m is Segment or m is YOLOESegment or m is Segment26 or m is YOLOESegment26:
 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
-if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
+if m in {Detect, YOLOEDetect, Segment, Segment26, YOLOESegment, YOLOESegment26, Pose, Pose26, OBB, OBB26}:
 m.legacy = legacy
+elif m is v10Detect:
+args.append([ch[x] for x in f])
+elif m is ImagePoolingAttn:
+args.insert(1, [ch[x] for x in f]) # channels as second arg
 elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
 args.insert(1, [ch[x] for x in f])
 elif m is CBLinear:
@@ -1708,7 +1696,7 @@ def parse_model(d, ch, verbose=True):
 m_.np = sum(x.numel() for x in m_.parameters()) # number params
 m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
 if verbose:
-LOGGER.info(f"{i:>3}{
+LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f} {t:<45}{args!s:<30}") # print
 save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
 layers.append(m_)
 if i == 0:
@@ -1718,8 +1706,7 @@ def parse_model(d, ch, verbose=True):


 def yaml_model_load(path):
-"""
-Load a YOLOv8 model from a YAML file.
+"""Load a YOLOv8 model from a YAML file.

 Args:
 path (str | Path): Path to the YAML file.
@@ -1742,8 +1729,7 @@ def yaml_model_load(path):


 def guess_model_scale(model_path):
-"""
-Extract the size character n, s, m, l, or x of the model's scale from the model path.
+"""Extract the size character n, s, m, l, or x of the model's scale from the model path.

 Args:
 model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1752,14 +1738,13 @@ def guess_model_scale(model_path):
 (str): The size character of the model's scale (n, s, m, l, or x).
 """
 try:
-return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
+return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
 except AttributeError:
 return ""


 def guess_model_task(model):
-"""
-Guess the task of a PyTorch model from its architecture or configuration.
+"""Guess the task of a PyTorch model from its architecture or configuration.

 Args:
 model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
@@ -1777,9 +1762,9 @@ def guess_model_task(model):
 return "detect"
 if "segment" in m:
 return "segment"
-if
+if "pose" in m:
 return "pose"
-if
+if "obb" in m:
 return "obb"

 # Guess from model cfg
@@ -1790,10 +1775,10 @@ def guess_model_task(model):
 if isinstance(model, torch.nn.Module): # PyTorch model
 for x in "model.args", "model.model.args", "model.model.model.args":
 with contextlib.suppress(Exception):
-return eval(x)["task"]
+return eval(x)["task"] # nosec B307: safe eval of known attribute paths
 for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
 with contextlib.suppress(Exception):
-return cfg2task(eval(x))
+return cfg2task(eval(x)) # nosec B307: safe eval of known attribute paths
 for m in model.modules():
 if isinstance(m, (Segment, YOLOESegment)):
 return "segment"