dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/nn/tasks.py
ADDED
@@ -0,0 +1,1627 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import pickle
|
5
|
+
import re
|
6
|
+
import types
|
7
|
+
from copy import deepcopy
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
import torch
|
11
|
+
import torch.nn as nn
|
12
|
+
|
13
|
+
from ultralytics.nn.autobackend import check_class_names
|
14
|
+
from ultralytics.nn.modules import (
|
15
|
+
AIFI,
|
16
|
+
C1,
|
17
|
+
C2,
|
18
|
+
C2PSA,
|
19
|
+
C3,
|
20
|
+
C3TR,
|
21
|
+
ELAN1,
|
22
|
+
OBB,
|
23
|
+
PSA,
|
24
|
+
SPP,
|
25
|
+
SPPELAN,
|
26
|
+
SPPF,
|
27
|
+
A2C2f,
|
28
|
+
AConv,
|
29
|
+
ADown,
|
30
|
+
Bottleneck,
|
31
|
+
BottleneckCSP,
|
32
|
+
C2f,
|
33
|
+
C2fAttn,
|
34
|
+
C2fCIB,
|
35
|
+
C2fPSA,
|
36
|
+
C3Ghost,
|
37
|
+
C3k2,
|
38
|
+
C3x,
|
39
|
+
CBFuse,
|
40
|
+
CBLinear,
|
41
|
+
Classify,
|
42
|
+
Concat,
|
43
|
+
Conv,
|
44
|
+
Conv2,
|
45
|
+
ConvTranspose,
|
46
|
+
Detect,
|
47
|
+
DWConv,
|
48
|
+
DWConvTranspose2d,
|
49
|
+
Focus,
|
50
|
+
GhostBottleneck,
|
51
|
+
GhostConv,
|
52
|
+
HGBlock,
|
53
|
+
HGStem,
|
54
|
+
ImagePoolingAttn,
|
55
|
+
Index,
|
56
|
+
LRPCHead,
|
57
|
+
Pose,
|
58
|
+
RepC3,
|
59
|
+
RepConv,
|
60
|
+
RepNCSPELAN4,
|
61
|
+
RepVGGDW,
|
62
|
+
ResNetLayer,
|
63
|
+
RTDETRDecoder,
|
64
|
+
SCDown,
|
65
|
+
Segment,
|
66
|
+
TorchVision,
|
67
|
+
WorldDetect,
|
68
|
+
YOLOEDetect,
|
69
|
+
YOLOESegment,
|
70
|
+
v10Detect,
|
71
|
+
)
|
72
|
+
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, YAML, colorstr, emojis
|
73
|
+
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
74
|
+
from ultralytics.utils.loss import (
|
75
|
+
E2EDetectLoss,
|
76
|
+
v8ClassificationLoss,
|
77
|
+
v8DetectionLoss,
|
78
|
+
v8OBBLoss,
|
79
|
+
v8PoseLoss,
|
80
|
+
v8SegmentationLoss,
|
81
|
+
)
|
82
|
+
from ultralytics.utils.ops import make_divisible
|
83
|
+
from ultralytics.utils.plotting import feature_visualization
|
84
|
+
from ultralytics.utils.torch_utils import (
|
85
|
+
fuse_conv_and_bn,
|
86
|
+
fuse_deconv_and_bn,
|
87
|
+
initialize_weights,
|
88
|
+
intersect_dicts,
|
89
|
+
model_info,
|
90
|
+
scale_img,
|
91
|
+
smart_inference_mode,
|
92
|
+
time_sync,
|
93
|
+
)
|
94
|
+
|
95
|
+
|
96
|
+
class BaseModel(torch.nn.Module):
    """
    Base class for all models in the Ultralytics YOLO family.

    Subclasses provide ``self.model`` (an iterable of layers, each carrying an ``f`` "from" index or list of indices,
    an ``i`` layer index, and ``type``/``np`` attributes used for visualization and profiling), ``self.save`` (indices
    of layers whose outputs later layers read back) and an ``init_criterion`` implementation for training.
    """

    def forward(self, x, *args, **kwargs):
        """
        Perform forward pass of the model for either training or inference.

        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

        Args:
            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            augment (bool): Augment image during prediction.
            embed (list, optional): A list of layer indices whose pooled feature vectors should be returned.

        Returns:
            (torch.Tensor): The last output of the model (or a tuple of embeddings when ``embed`` is given).
        """
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize, embed)

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a single (non-augmented) forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            embed (list, optional): A list of layer indices whose pooled feature vectors should be returned.

        Returns:
            (torch.Tensor | tuple): The last output of the model, or, when ``embed`` is given, a per-sample tuple of
                the concatenated globally-average-pooled features of the requested layers.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                # Route input from a single earlier layer or a list of layers (-1 means "previous output").
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # keep only outputs that later layers read back
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):  # stop early once the deepest requested layer has been reached
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def _predict_augment(self, x):
        """Warn that test-time augmentation is unsupported and fall back to single-scale prediction."""
        LOGGER.warning(
            f"{self.__class__.__name__} does not support 'augment=True' prediction. "
            "Reverting to single-scale prediction."
        )
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input.

        Args:
            m (torch.nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to which the layer's measured time (ms) is appended.
        """
        try:
            import thop
        except ImportError:
            thop = None  # conda support without 'ultralytics-thop' installed

        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)  # 10 runs: seconds / 10 * 1000 -> ms per run
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
        efficiency.

        Args:
            verbose (bool): Whether to print model information after fusing.

        Returns:
            (torch.nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepVGGDW):
                    m.fuse()
                    m.forward = m.forward_fuse
                if isinstance(m, v10Detect):
                    m.fuse()  # remove one2many head
            self.info(verbose=verbose)

        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has fewer than a threshold number of normalization layers.

        Every ``torch.nn`` class whose name contains "Norm" (BatchNorm2d, LayerNorm, GroupNorm, ...) is counted.

        Args:
            thresh (int, optional): The threshold number of normalization layers.

        Returns:
            (bool): True if the number of normalization layers in the model is less than the threshold.
        """
        bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Print model information.

        Args:
            detailed (bool): If True, prints out detailed information about the model.
            verbose (bool): If True, prints out the model information.
            imgsz (int): The size of the image that the model will be trained on.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Apply ``fn`` to all parameters and buffers, and also to the Detect head's plain tensor attributes.

        ``stride``, ``anchors`` and ``strides`` on the final Detect layer are ordinary tensor attributes rather than
        registered buffers, so they are converted explicitly here (e.g. on ``.to()``, ``.half()``, ``.cuda()``).

        Args:
            fn (function): The function to apply to the model.

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """
        Load weights into the model.

        Entries selected by ``intersect_dicts`` (presumably matching names and shapes) are loaded non-strictly. If the
        first convolution could not be transferred that way but its kernel size matches, the overlapping channel slice
        is copied instead, which helps models with a different number of input channels reuse pretrained weights.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress.
        """
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        updated_csd = intersect_dicts(csd, self.state_dict())  # intersect
        self.load_state_dict(updated_csd, strict=False)  # load
        len_updated_csd = len(updated_csd)
        first_conv = "model.0.conv.weight"  # hard-coded to yolo models for now
        # mostly used to boost multi-channel training
        state_dict = self.state_dict()
        # The checkpoint must also contain the key: non-YOLO weights may lack it, which previously raised KeyError.
        if first_conv not in updated_csd and first_conv in state_dict and first_conv in csd:
            c1, c2, h, w = state_dict[first_conv].shape
            cc1, cc2, ch, cw = csd[first_conv].shape
            if ch == h and cw == w:
                c1, c2 = min(c1, cc1), min(c2, cc2)
                # state_dict() tensors share storage with the parameters, so this in-place copy updates the model.
                state_dict[first_conv][:c1, :c2] = csd[first_conv][:c1, :c2]
                len_updated_csd += 1
        if verbose:
            LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")

    def loss(self, batch, preds=None):
        """
        Compute loss.

        The criterion is created lazily via ``init_criterion`` on first use.

        Args:
            batch (dict): Batch to compute loss on; ``batch["img"]`` is forwarded when ``preds`` is not given.
            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.

        Returns:
            Whatever the task-specific criterion returns for ``(preds, batch)``.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()

        preds = self.forward(batch["img"]) if preds is None else preds
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion for the BaseModel; must be overridden by task-specific subclasses."""
        raise NotImplementedError("compute_loss() needs to be implemented by task heads")
|
317
|
+
|
318
|
+
|
319
|
+
class DetectionModel(BaseModel):
    """YOLO detection model built from a YAML config, with stride discovery and optional test-time augmentation."""

    def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
        """
        Build the detection network described by ``cfg``.

        Args:
            cfg (str | dict): Model configuration file path or an already-loaded config dictionary (used in place).
            ch (int): Number of input channels.
            nc (int, optional): Number of classes; overrides the config value when given and different.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg
        else:
            self.yaml = yaml_model_load(cfg)

        # Old YOLOv9 checkpoints reference the removed `Silence` module; swap in nn.Identity.
        first_layer = self.yaml["backbone"][0]
        if first_layer[2] == "Silence":
            LOGGER.warning(
                "YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
                "Please delete local *.pt file and re-download the latest model checkpoint."
            )
            first_layer[2] = "nn.Identity"

        # Record input channels and apply any class-count override before building the layers.
        self.yaml["channels"] = ch
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)
        self.names = {idx: str(idx) for idx in range(self.yaml["nc"])}  # placeholder class names
        self.inplace = self.yaml.get("inplace", True)
        head = self.model[-1]
        self.end2end = getattr(head, "end2end", False)

        if not isinstance(head, Detect):
            self.stride = torch.Tensor([32])  # fallback for heads without stride discovery, e.g. RTDETR
        else:
            probe = 256  # square probe image side, 2x the minimum stride
            head.inplace = self.inplace

            def _feature_maps(img):
                """Forward ``img`` and pick out the per-level feature maps for the given head type."""
                out = self.forward(img)
                if self.end2end:
                    return out["one2many"]
                return out[0] if isinstance(head, (Segment, YOLOESegment, Pose, OBB)) else out

            # Each level's stride is the probe size divided by that level's feature-map height.
            head.stride = torch.tensor([probe / fmap.shape[-2] for fmap in _feature_maps(torch.zeros(1, ch, probe, probe))])
            self.stride = head.stride
            head.bias_init()  # only run once

        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info("")

    def _predict_augment(self, x):
        """
        Run multi-scale / flip test-time augmentation and merge the results.

        Plain ``DetectionModel`` instances without an end-to-end head are supported; anything else falls back to a
        single-scale pass.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            (tuple): Concatenated augmented predictions and ``None`` in place of training outputs.
        """
        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
            LOGGER.warning("Model does not support 'augment=True', reverting to single-scale prediction.")
            return self._predict_once(x)

        img_size = x.shape[-2:]  # (height, width)
        grid = int(self.stride.max())
        merged = []
        # (scale, flip dim) pairs; flip dim 3 is a left-right flip, None means no flip.
        for scale, flip in ((1, None), (0.83, 3), (0.67, None)):
            source = x.flip(flip) if flip else x
            raw = super().predict(scale_img(source, scale, gs=grid))[0]
            merged.append(self._descale_pred(raw, flip, scale, img_size))
        return torch.cat(self._clip_augmented(merged), -1), None

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """
        Undo the scaling and flipping applied during augmented inference.

        The first four channels of ``p`` are divided by ``scale`` in place; then the x (flip 3) or y (flip 2)
        coordinate is mirrored back against the original image size.

        Args:
            p (torch.Tensor): Predictions tensor (modified in place by the de-scaling step).
            flips (int): Flip type (0/None=none, 2=ud, 3=lr).
            scale (float): Scale factor.
            img_size (tuple): Original image size (height, width).
            dim (int): Dimension to split at.

        Returns:
            (torch.Tensor): De-scaled predictions.
        """
        p[:, :4] /= scale
        cx, cy, wh, rest = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:
            cy = img_size[0] - cy
        elif flips == 3:
            cx = img_size[1] - cx
        return torch.cat((cx, cy, wh, rest), dim)

    def _clip_augmented(self, y):
        """
        Trim the overlapping tails of the augmented prediction list.

        The tail of the first (largest-scale) output and the head of the last output are cut, sized by the number of
        grid cells belonging to the excluded detection level(s).

        Args:
            y (List[torch.Tensor]): List of detection tensors (modified in place).

        Returns:
            (List[torch.Tensor]): Clipped detection tensors.
        """
        n_levels = self.model[-1].nl  # number of detection layers
        total_cells = sum(4**k for k in range(n_levels))
        excluded = 1  # number of levels to exclude
        head_cut = (y[0].shape[-1] // total_cells) * sum(4**k for k in range(excluded))
        y[0] = y[0][..., :-head_cut]
        tail_cut = (y[-1].shape[-1] // total_cells) * sum(4 ** (n_levels - 1 - k) for k in range(excluded))
        y[-1] = y[-1][..., tail_cut:]
        return y

    def init_criterion(self):
        """Return the end-to-end detection loss for end2end heads, otherwise the standard v8 detection loss."""
        if getattr(self, "end2end", False):
            return E2EDetectLoss(self)
        return v8DetectionLoss(self)
|
445
|
+
|
446
|
+
|
447
|
+
class OBBModel(DetectionModel):
    """Detection model whose head predicts oriented (rotated) bounding boxes."""

    def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
        """
        Build an oriented-bounding-box model; all arguments are forwarded unchanged to ``DetectionModel``.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return the v8 oriented-bounding-box loss bound to this model."""
        criterion = v8OBBLoss(self)
        return criterion
|
465
|
+
|
466
|
+
|
467
|
+
class SegmentationModel(DetectionModel):
    """Detection model extended with an instance-segmentation head."""

    def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
        """
        Build a segmentation model; all arguments are forwarded unchanged to ``DetectionModel``.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return the v8 segmentation loss bound to this model."""
        criterion = v8SegmentationLoss(self)
        return criterion
|
485
|
+
|
486
|
+
|
487
|
+
class PoseModel(DetectionModel):
    """YOLO model variant for keypoint (pose) estimation."""

    def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """
        Build a pose model, optionally overriding the keypoint shape declared in the model config.

        Args:
            cfg (str | dict): Path to a model YAML file, or an already-parsed configuration dictionary.
            ch (int): Number of input image channels.
            nc (int, optional): Number of classes.
            data_kpt_shape (tuple): Keypoint shape from the dataset. When any entry is truthy and it differs from the
                config's ``kpt_shape``, it replaces that value.
            verbose (bool): Whether to display model information.

        Notes:
            When ``cfg`` is passed as a dict and an override happens, its ``kpt_shape`` entry is overwritten in place.
        """
        if isinstance(cfg, dict):
            pass  # already parsed; used as-is (and possibly updated below)
        else:
            cfg = yaml_model_load(cfg)  # load model YAML

        # `and` short-circuits, so cfg["kpt_shape"] is only read when the dataset actually supplies a shape.
        needs_override = any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"])
        if needs_override:
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg["kpt_shape"] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return the keypoint loss (v8PoseLoss) bound to this model."""
        criterion = v8PoseLoss(self)
        return criterion
|
511
|
+
|
512
|
+
|
513
|
+
class ClassificationModel(BaseModel):
    """YOLO classification model."""

    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
        """
        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        self._from_yaml(cfg, ch, nc, verbose)

    def _from_yaml(self, cfg, ch, nc, verbose):
        """
        Set Ultralytics YOLO model configurations and define the model architecture.

        Args:
            cfg (str | dict): Model configuration file path or dictionary. A dict is stored (and updated) in place.
            ch (int): Number of input channels; a "channels" entry in the config takes precedence.
            nc (int, optional): Number of classes; overrides the config's "nc" when given and different.
            verbose (bool): Whether to display model information.

        Raises:
            ValueError: If neither ``nc`` nor the config provides a number of classes.
        """
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

        # Define model
        ch = self.yaml["channels"] = self.yaml.get("channels", ch)  # input channels
        yaml_nc = self.yaml.get("nc", None)  # the config may omit "nc" when it is supplied as an argument
        if nc and nc != yaml_nc:
            LOGGER.info(f"Overriding model.yaml nc={yaml_nc} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        elif not nc and not yaml_nc:
            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.stride = torch.Tensor([1])  # no stride constraints
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        if verbose:  # honour the documented `verbose` flag
            self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """
        Update a classification model's final layer to output ``nc`` classes, if it does not already.

        Handles a YOLO ``Classify`` head, a trailing ``torch.nn.Linear`` (e.g. ResNet, EfficientNet), or a trailing
        ``torch.nn.Sequential`` whose last ``Linear`` (preferred) or last ``Conv2d`` is replaced. Other final-layer
        types are left unchanged.

        Args:
            model (torch.nn.Module): Model to update in place. If it has a ``model`` attribute, the last child of
                ``model.model`` is inspected; otherwise the last child of ``model`` itself.
            nc (int): New number of classes.
        """
        # The module whose last child is the head; replacements must be attached to this same object.
        parent = model.model if hasattr(model, "model") else model
        name, m = list(parent.named_children())[-1]  # last module
        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
                m.linear = torch.nn.Linear(m.linear.in_features, nc)
        elif isinstance(m, torch.nn.Linear):  # ResNet, EfficientNet
            if m.out_features != nc:
                setattr(parent, name, torch.nn.Linear(m.in_features, nc))
        elif isinstance(m, torch.nn.Sequential):
            types = [type(x) for x in m]
            if torch.nn.Linear in types:
                i = len(types) - 1 - types[::-1].index(torch.nn.Linear)  # last torch.nn.Linear index
                if m[i].out_features != nc:
                    m[i] = torch.nn.Linear(m[i].in_features, nc)
            elif torch.nn.Conv2d in types:
                i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d)  # last torch.nn.Conv2d index
                old = m[i]
                if old.out_channels != nc:
                    # Keep the original spatial geometry so the layer's output size is unchanged.
                    m[i] = torch.nn.Conv2d(
                        old.in_channels,
                        nc,
                        old.kernel_size,
                        old.stride,
                        padding=old.padding,
                        dilation=old.dilation,
                        bias=old.bias is not None,
                        padding_mode=old.padding_mode,
                    )

    def init_criterion(self):
        """Initialize the loss criterion for the ClassificationModel."""
        return v8ClassificationLoss()
|
585
|
+
|
586
|
+
|
587
|
+
class RTDETRDetectionModel(DetectionModel):
    """
    RT-DETR (Real-Time DEtection TRansformer) detection model.

    Builds the RT-DETR architecture from a configuration and extends DetectionModel with its own loss criterion,
    loss computation and forward pass. The forward pass runs all non-head layers first and then calls the head with
    the ground-truth targets (when provided).

    Methods:
        init_criterion: Build the RTDETRDetectionLoss criterion.
        loss: Compute the training loss for a batch.
        predict: Run the layers and the head, returning the head output.
    """

    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
        """
        Build an RT-DETR model.

        Args:
            cfg (str | dict): Configuration file name/path, or a parsed configuration dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Print additional information during initialization.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Return an RTDETRDetectionLoss for this model's class count, created with ``use_vfl=True``."""
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        criterion = RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
        return criterion

    def loss(self, batch, preds=None):
        """
        Compute the loss for a batch.

        Args:
            batch (dict): Batch with "img", "batch_idx", "cls" and "bboxes" entries.
            preds (torch.Tensor, optional): Precomputed model predictions; computed via ``predict`` when None.

        Returns:
            (tuple): Sum of all loss terms, and a detached tensor of the "loss_giou", "loss_class" and "loss_bbox"
                terms (in that order) for logging.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        img = batch["img"]
        device = img.device
        batch_idx = batch["batch_idx"]
        # NOTE: preprocess gt_bbox and gt_labels to list.
        gt_groups = [(batch_idx == i).sum().item() for i in range(len(img))]  # target count per image
        targets = {
            "cls": batch["cls"].to(device, dtype=torch.long).view(-1),
            "bboxes": batch["bboxes"].to(device=device),
            "batch_idx": batch_idx.to(device, dtype=torch.long).view(-1),
            "gt_groups": gt_groups,
        }

        if preds is None:
            preds = self.predict(img, batch=targets)
        # Outside training mode the decoder outputs are the second element of the head's return value.
        head_out = preds if self.training else preds[1]
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = head_out

        dn_bboxes = dn_scores = None
        if dn_meta is not None:
            split_sizes = dn_meta["dn_num_split"]
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, split_sizes, dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, split_sizes, dim=2)

        # Prepend the encoder output as an extra leading entry before the decoder outputs.
        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion(
            (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
        )
        total = sum(loss.values())
        # NOTE: backward uses every loss term, but only the main three are reported.
        reported = [loss[k].detach() for k in ("loss_giou", "loss_class", "loss_bbox")]
        return total, torch.as_tensor(reported, device=device)

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """
        Run a forward pass: every layer except the head, then the head with ``batch``.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            batch (dict, optional): Ground truth targets passed to the head.
            augment (bool): Accepted for interface compatibility; not used here.
            embed (list, optional): Layer indices whose pooled features are returned instead of the head output.

        Returns:
            (torch.Tensor): Model's output tensor, or a tuple of per-image embeddings when ``embed`` is given.
        """
        y, dt, embeddings = [], [], []  # outputs
        head = self.model[-1]
        for m in self.model[:-1]:  # all layers except the head
            if m.f != -1:  # input is not just the previous layer's output
                if isinstance(m.f, int):
                    x = y[m.f]
                else:
                    x = [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                pooled = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
                embeddings.append(pooled.squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head_inputs = [y[j] for j in head.f]
        return head(head_inputs, batch)  # head inference
|
696
|
+
|
697
|
+
|
698
|
+
class WorldModel(DetectionModel):
    """YOLOv8 World model whose class set is defined by CLIP text embeddings (see ``set_classes``)."""

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """
        Build a YOLOv8 World model.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes; also sizes the placeholder text features (80 when None).
            verbose (bool): Whether to display model information.
        """
        # Assigned before DetectionModel.__init__ runs; presumably because the parent constructor may call
        # predict(), which reads self.txt_feats. TODO confirm.
        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
        self.clip_model = None  # CLIP model placeholder
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def set_classes(self, text, batch=80, cache_clip_model=True):
        """
        Precompute normalized CLIP text features for class names so inference needs no CLIP model.

        Args:
            text (List[str]): List of class names.
            batch (int): Number of tokenized names encoded per CLIP forward pass.
            cache_clip_model (bool): Whether to keep the loaded CLIP model on ``self.clip_model`` for reuse.
        """
        try:
            import clip
        except ImportError:
            check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip

        # getattr keeps models saved before `clip_model` existed working.
        if cache_clip_model and not getattr(self, "clip_model", None):
            self.clip_model = clip.load("ViT-B/32")[0]
        encoder = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
        device = next(encoder.parameters()).device
        tokens = clip.tokenize(text).to(device)
        chunks = [encoder.encode_text(chunk).detach() for chunk in tokens.split(batch)]
        feats = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)  # L2-normalize each embedding
        self.txt_feats = feats.reshape(-1, len(text), feats.shape[-1])
        self.model[-1].nc = len(text)

    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
        """
        Perform a forward pass through every layer, feeding text features to the text-aware modules.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            txt_feats (torch.Tensor, optional): Text features to use instead of ``self.txt_feats``.
            augment (bool): Accepted for interface compatibility; not used here.
            embed (list, optional): Layer indices whose pooled features are returned instead of the output.

        Returns:
            (torch.Tensor): Model's output tensor, or a tuple of per-image embeddings when ``embed`` is given.
        """
        source = self.txt_feats if txt_feats is None else txt_feats
        txt_feats = source.to(device=x.device, dtype=x.dtype)
        if len(txt_feats) != len(x) or self.model[-1].export:
            txt_feats = txt_feats.expand(x.shape[0], -1, -1)  # one set of text features per image
        ori_txt_feats = txt_feats.clone()  # copy taken before ImagePoolingAttn updates; used by WorldDetect
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:  # every layer, head included
            if m.f != -1:  # input is not just the previous layer's output
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if isinstance(m, C2fAttn):
                x = m(x, txt_feats)
            elif isinstance(m, WorldDetect):
                x = m(x, ori_txt_feats)
            elif isinstance(m, ImagePoolingAttn):
                txt_feats = m(x, txt_feats)  # updates the text features; x is left unchanged
            else:
                x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                pooled = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
                embeddings.append(pooled.squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def loss(self, batch, preds=None):
        """
        Compute loss, running the model on ``batch["img"]`` with ``batch["txt_feats"]`` if no predictions are given.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        predictions = preds
        if predictions is None:
            predictions = self.forward(batch["img"], txt_feats=batch["txt_feats"])
        return self.criterion(predictions, batch)
|
800
|
+
|
801
|
+
|
802
|
+
class YOLOEModel(DetectionModel):
    """YOLOE detection model whose classes come from text prompts, visual prompts, or a fused prompt-free vocabulary."""

    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
        """
        Build a YOLOE model.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    @smart_inference_mode()
    def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
        """
        Encode class names into text prompt embeddings.

        Args:
            text (List[str]): Class names to encode.
            batch (int): Number of tokenized names encoded per text-model forward pass.
            cache_clip_model (bool): Whether to keep the text model on ``self.clip_model`` for reuse.
            without_reprta (bool): If True, return the raw text-model embeddings without passing them through the
                YOLOEDetect head's ``get_tpe``.

        Returns:
            (torch.Tensor): The text-model embeddings reshaped to (-1, len(text), dim) when ``without_reprta`` is
                True, otherwise the head's ``get_tpe`` applied to them.
        """
        from ultralytics.nn.text_model import build_text_model

        device = next(self.model.parameters()).device
        if cache_clip_model and not getattr(self, "clip_model", None):
            # For backwards compatibility of models lacking clip_model attribute
            self.clip_model = build_text_model("mobileclip:blt", device=device)

        encoder = self.clip_model if cache_clip_model else build_text_model("mobileclip:blt", device=device)
        tokens = encoder.tokenize(text)
        parts = [encoder.encode_text(chunk).detach() for chunk in tokens.split(batch)]
        feats = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
        feats = feats.reshape(-1, len(text), feats.shape[-1])
        if without_reprta:
            return feats

        assert not self.training
        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)
        return head.get_tpe(feats)  # run auxiliary text head

    @smart_inference_mode()
    def get_visual_pe(self, img, visual):
        """
        Compute visual prompt embeddings for an image.

        Args:
            img (torch.Tensor): Input image tensor.
            visual (torch.Tensor): Visual prompts, passed to the forward pass as ``vpe``.

        Returns:
            (torch.Tensor): The embedding returned early by the forward pass (``return_vpe=True``).
        """
        vpe = self(img, vpe=visual, return_vpe=True)
        return vpe

    def set_vocab(self, vocab, names):
        """
        Install a fused vocabulary on the head for prompt-free inference.

        Args:
            vocab (nn.ModuleList): Per-level vocabulary modules, e.g. from ``get_vocab``.
            names (List[str]): List of class names.
        """
        assert not self.training
        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)

        # Cache anchors for head with one dummy forward pass.
        device = next(self.parameters()).device
        size = self.args["imgsz"]
        self(torch.empty(1, 3, size, size).to(device))  # warmup

        # Re-parameterize for the prompt-free model. LRPCHead keeps the last cv3/cv2 layers, so build it before
        # those layers are removed from the head below.
        head.lrpc = nn.ModuleList(
            LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)
            for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))
        )
        for loc_head, cls_head in zip(head.cv2, head.cv3):
            assert isinstance(loc_head, nn.Sequential)
            assert isinstance(cls_head, nn.Sequential)
            del loc_head[-1]
            del cls_head[-1]
        head.nc = len(names)
        self.names = check_class_names(names)

    def get_vocab(self, names):
        """
        Fuse text embeddings for ``names`` into the head and return the resulting classification layers.

        Args:
            names (list): List of class names.

        Returns:
            (nn.ModuleList): The last module of each ``head.cv3`` branch after fusion.
        """
        assert not self.training
        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)
        assert not head.is_fused

        tpe = self.get_text_pe(names)
        self.set_classes(names, tpe)
        device = next(self.model.parameters()).device
        head.fuse(self.pe.to(device))  # fuse prompt embeddings to classify head

        vocab = nn.ModuleList()
        for cls_head in head.cv3:
            assert isinstance(cls_head, nn.Sequential)
            vocab.append(cls_head[-1])
        return vocab

    def set_classes(self, names, embeddings):
        """
        Store class embeddings and names so the model can run offline, without a text model.

        Args:
            names (List[str]): List of class names.
            embeddings (torch.Tensor): 3-D embeddings tensor, stored as ``self.pe``.
        """
        assert not hasattr(self.model[-1], "lrpc"), (
            "Prompt-free model does not support setting classes. Please try with Text/Visual prompt models."
        )
        assert embeddings.ndim == 3
        self.pe = embeddings
        self.model[-1].nc = len(names)
        self.names = check_class_names(names)

    def get_cls_pe(self, tpe, vpe):
        """
        Concatenate the available class embeddings along dim 1.

        Args:
            tpe (torch.Tensor, optional): 3-D text prompt embeddings.
            vpe (torch.Tensor, optional): 3-D visual prompt embeddings.

        Returns:
            (torch.Tensor): ``tpe`` and/or ``vpe`` concatenated along dim 1. If both are None, ``self.pe`` is used,
                or a zeros tensor of shape (1, 80, 512) when ``self.pe`` does not exist.
        """
        all_pe = []
        for pe in (tpe, vpe):
            if pe is not None:
                assert pe.ndim == 3
                all_pe.append(pe)
        if not all_pe:
            # The zeros fallback is only built when actually needed.
            all_pe.append(self.pe if hasattr(self, "pe") else torch.zeros(1, 80, 512))
        return torch.cat(all_pe, dim=1)

    def predict(
        self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
    ):
        """
        Perform a forward pass through every layer, giving the YOLOEDetect head the class embeddings.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            tpe (torch.Tensor, optional): Text prompt embeddings, passed through the head's ``get_tpe``.
            augment (bool): Accepted for interface compatibility; not used here.
            embed (list, optional): Layer indices whose pooled features are returned instead of the output.
            vpe (torch.Tensor, optional): Visual prompts, passed through the head's ``get_vpe``.
            return_vpe (bool): If True, return the visual prompt embeddings computed at the head.

        Returns:
            (torch.Tensor): Model's output tensor, the visual embeddings when ``return_vpe`` is True, or a tuple of
                per-image embeddings when ``embed`` is given.
        """
        y, dt, embeddings = [], [], []  # outputs
        batch_size = x.shape[0]
        for m in self.model:  # every layer, head included
            if m.f != -1:  # input is not just the previous layer's output
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if not isinstance(m, YOLOEDetect):
                x = m(x)  # run
            else:
                vpe = None if vpe is None else m.get_vpe(x, vpe)
                if return_vpe:
                    assert vpe is not None
                    assert not self.training
                    return vpe
                cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
                if cls_pe.shape[0] != batch_size or m.export:
                    cls_pe = cls_pe.expand(batch_size, -1, -1)
                x = m(x, cls_pe)

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                pooled = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
                embeddings.append(pooled.squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def loss(self, batch, preds=None):
        """
        Compute loss; the criterion is chosen once, from the first batch seen.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor], optional): Predictions; computed from the batch when None.
        """
        if not hasattr(self, "criterion"):
            from ultralytics.utils.loss import TVPDetectLoss

            if batch.get("visuals", None) is not None:  # TODO
                self.criterion = TVPDetectLoss(self)
            else:
                self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
        return self.criterion(preds, batch)
|
1023
|
+
|
1024
|
+
|
1025
|
+
class YOLOESegModel(YOLOEModel, SegmentationModel):
    """YOLOE segmentation model combining YOLOEModel prompting with SegmentationModel."""

    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
        """
        Build a YOLOE segmentation model.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def loss(self, batch, preds=None):
        """
        Compute loss; the criterion is chosen once, from the first batch seen.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor], optional): Predictions; computed from the batch when None.
        """
        if not hasattr(self, "criterion"):
            from ultralytics.utils.loss import TVPSegmentLoss

            if batch.get("visuals", None) is not None:  # TODO
                self.criterion = TVPSegmentLoss(self)
            else:
                self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
        return self.criterion(preds, batch)
|
1057
|
+
|
1058
|
+
|
1059
|
+
class Ensemble(torch.nn.ModuleList):
    """A list of models whose predictions are merged into a single output."""

    def __init__(self):
        """Create an empty ensemble; members are added with ModuleList methods such as ``append``."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """
        Run every member model on the same input and concatenate their first outputs.

        Args:
            x (torch.Tensor): Input tensor.
            augment (bool): Passed positionally to each member.
            profile (bool): Passed positionally to each member.
            visualize (bool): Passed positionally to each member.

        Returns:
            (tuple): The members' first outputs concatenated along dim 2, and None.
        """
        outputs = []
        for member in self:
            outputs.append(member(x, augment, profile, visualize)[0])
        # Alternatives: torch.stack(outputs).max(0)[0] (max ensemble) or torch.stack(outputs).mean(0) (mean ensemble).
        merged = torch.cat(outputs, 2)  # nms ensemble; presumably NMS later filters the combined candidates
        return merged, None  # inference, train output
|
1084
|
+
|
1085
|
+
|
1086
|
+
# Functions ------------------------------------------------------------------------------------------------------------
|
1087
|
+
|
1088
|
+
|
1089
|
+
@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.
        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.

    Examples:
        >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
        >>> import old.module  # this will now import new.module
        >>> from old.module import attribute  # this will now import new.module.attribute

    Note:
        The changes are only in effect inside the context manager and are undone once it exits, including when an
        exception is raised. Aliased `sys.modules` entries are removed, or restored if one already existed under
        that name. Patched attributes are restored to their previous values, or deleted if they did not exist
        before. Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in
        larger applications or libraries. Use this function with caution.
    """
    if modules is None:
        modules = {}
    if attributes is None:
        attributes = {}
    import sys
    from importlib import import_module

    missing = object()  # sentinel meaning "there was no previous value"
    patched_attrs = []  # (module object, attribute name, previous value), in the order applied
    patched_modules = []  # (sys.modules key, previous entry), in the order applied
    try:
        # Set attributes in sys.modules under their old name
        for old, new in attributes.items():
            old_module, old_attr = old.rsplit(".", 1)
            new_module, new_attr = new.rsplit(".", 1)
            target = import_module(old_module)
            value = getattr(import_module(new_module), new_attr)
            # Read the module __dict__ directly so a module-level __getattr__ is not materialized on restore.
            patched_attrs.append((target, old_attr, vars(target).get(old_attr, missing)))
            setattr(target, old_attr, value)

        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            replacement = import_module(new)
            patched_modules.append((old, sys.modules.get(old, missing)))
            sys.modules[old] = replacement

        yield
    finally:
        # Undo in reverse order so repeated keys end up at their original state.
        for old, previous in reversed(patched_modules):
            if previous is missing:
                sys.modules.pop(old, None)
            else:
                sys.modules[old] = previous
        for target, name, previous in reversed(patched_attrs):
            if previous is missing:
                vars(target).pop(name, None)
            else:
                setattr(target, name, previous)
|
1136
|
+
|
1137
|
+
|
1138
|
+
class SafeClass:
    """Stand-in that SafeUnpickler substitutes for classes outside its allow-list; it ignores all arguments."""

    def __init__(self, *args, **kwargs):
        """Create the placeholder, discarding every positional and keyword argument."""

    def __call__(self, *args, **kwargs):
        """Accept any call arguments, do nothing, and return None."""
        return None
|
1148
|
+
|
1149
|
+
|
1150
|
+
class SafeUnpickler(pickle.Unpickler):
    """Unpickler that resolves globals only from an allow-list of modules and substitutes SafeClass for the rest."""

    def find_class(self, module, name):
        """
        Resolve a global referenced by the pickle stream.

        Only globals whose module name exactly equals one of the allow-listed names ("torch", "collections",
        "collections.abc", "builtins", "math", "numpy") are imported. Submodules such as "torch._utils" do not
        match. Anything else resolves to SafeClass.

        Args:
            module (str): Module name recorded in the pickle.
            name (str): Class or function name within that module.

        Returns:
            (type): The real object for allow-listed modules, otherwise SafeClass.

        Note:
            This avoids import errors for unknown classes; it is not a security boundary. "builtins" is
            allow-listed, so names such as ``eval`` still resolve. Do not use it to load untrusted files.
        """
        allowed_modules = {
            "torch",
            "collections",
            "collections.abc",
            "builtins",
            "math",
            "numpy",
            # Add other modules considered safe
        }
        if module not in allowed_modules:
            return SafeClass
        return super().find_class(module, name)
|
1177
|
+
|
1178
|
+
|
1179
|
+
def torch_safe_load(weight, safe_only=False):
    """
    Load a PyTorch checkpoint with torch.load(), retrying once after auto-installing a missing module.

    Every load attempt, including the retry, runs inside temporary_modules() so that legacy module paths
    ('ultralytics.yolo.*') and renamed attributes (e.g. YOLOv9e Silence, YOLOv10 classes) are remapped. Tensors are
    always mapped to CPU. If a ModuleNotFoundError is raised for anything other than the YOLOv5 'models' package, a
    warning is logged, check_requirements() is asked to install the missing module, and the load is attempted again
    with the same options.

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, unpickle with SafeUnpickler so globals from modules outside its allowlist are
            replaced with SafeClass placeholders.

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.

    Raises:
        TypeError: If the checkpoint requires the YOLOv5 'models' package, which is not forwards compatible.

    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally

    def _load():
        """Run torch.load() on `file` inside the legacy-module remapping context, honoring `safe_only`."""
        with temporary_modules(
            modules={
                "ultralytics.yolo.utils": "ultralytics.utils",
                "ultralytics.yolo.v8": "ultralytics.models.yolo",
                "ultralytics.yolo.data": "ultralytics.data",
            },
            attributes={
                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
            },
        ):
            if safe_only:
                # Load via custom pickle module
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as f:
                    # map_location="cpu" (as in the non-safe branch) so CUDA-saved checkpoints load on CPU-only hosts
                    return torch.load(f, map_location="cpu", pickle_module=safe_pickle)
            return torch.load(file, map_location="cpu")

    try:
        ckpt = _load()
    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                )
            ) from e
        LOGGER.warning(
            f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
        )
        check_requirements(e.name)  # install missing module
        # Retry with the same module remapping and safe_only setting as the first attempt
        ckpt = _load()

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file
|
1253
|
+
|
1254
|
+
|
1255
|
+
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """
    Load one or more checkpoints, returning a single model or an Ensemble of them.

    Accepts weights=[a, b, c] (ensemble), weights=[a] or weights=a (single model).

    Args:
        weights (str | List[str]): Model weights path(s).
        device (torch.device, optional): Device to load model to.
        inplace (bool): Value assigned to every submodule exposing an `inplace` attribute.
        fuse (bool): Whether to fuse model layers (only when the model provides `fuse()`).

    Returns:
        (torch.nn.Module): The loaded model when a single weight is given, otherwise an Ensemble whose names/nc/yaml
            come from its first member and whose stride is that of the member with the largest stride.
    """
    ensemble = Ensemble()
    sources = weights if isinstance(weights, list) else [weights]
    for source in sources:
        ckpt, pt_file = torch_safe_load(source)  # load ckpt
        merged_args = None
        if "train_args" in ckpt:
            merged_args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]}  # checkpoint args override defaults
        network = ckpt.get("ema") or ckpt["model"]
        model = network.to(device).float()  # FP32 model

        # Attach metadata expected by downstream code
        model.args = merged_args
        model.pt_path = pt_file
        model.task = guess_model_task(model)
        if not hasattr(model, "stride"):
            model.stride = torch.tensor([32.0])

        if fuse and hasattr(model, "fuse"):
            model = model.fuse()
        ensemble.append(model.eval())  # store in eval mode

    # Propagate inplace flag and patch Upsample layers for torch 1.11.0 compatibility
    for module in ensemble.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, torch.nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None

    if len(ensemble) == 1:
        return ensemble[-1]

    LOGGER.info(f"Ensemble created with {weights}\n")
    for attr in ("names", "nc", "yaml"):
        setattr(ensemble, attr, getattr(ensemble[0], attr))
    max_strides = torch.tensor([member.stride.max() for member in ensemble])
    ensemble.stride = ensemble[int(torch.argmax(max_strides))].stride
    assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
    return ensemble
|
1302
|
+
|
1303
|
+
|
1304
|
+
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """
    Load a single checkpoint file and return the prepared model together with the raw checkpoint.

    Args:
        weight (str): Model weight path.
        device (torch.device, optional): Device to load model to.
        inplace (bool): Value assigned to every submodule exposing an `inplace` attribute.
        fuse (bool): Whether to fuse model layers (only when the model provides `fuse()`).

    Returns:
        (tuple): The FP32 model in eval mode and the loaded checkpoint dictionary.
    """
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    train_args = ckpt.get("train_args", {})
    merged = {**DEFAULT_CFG_DICT, **train_args}  # defaults first so checkpoint args take precedence
    network = ckpt.get("ema") or ckpt["model"]
    model = network.to(device).float()  # FP32 model

    # Keep only recognised configuration keys on the model
    model.args = {key: value for key, value in merged.items() if key in DEFAULT_CFG_KEYS}
    model.pt_path = weight  # remember the *.pt file the model came from
    model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    if fuse and hasattr(model, "fuse"):
        model = model.fuse()
    model = model.eval()

    # Propagate inplace flag and patch Upsample layers for torch 1.11.0 compatibility
    for module in model.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, torch.nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None

    return model, ckpt
|
1339
|
+
|
1340
|
+
|
1341
|
+
def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Each row of ``d["backbone"] + d["head"]`` is ``[from, number, module, args]``. String arguments are looked up
    first among this function's local variables (e.g. ``nc``, ``kpt_shape``) and otherwise passed to
    ``ast.literal_eval``; strings rejected with ValueError (e.g. ``"nearest"``) are kept unchanged.

    Args:
        d (dict): Model dictionary.
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        (tuple): Tuple containing the PyTorch model and sorted list of output layers.

    Notes:
        - Because string args are resolved through ``locals()``, local variable names here are effectively part of
          the YAML format; rename them with care.
        - When ``d["activation"]`` is set, ``Conv.default_act`` is overwritten globally with ``eval()`` of that string.
        - The ``legacy`` flag is written onto the Detect-family head *classes*, not just this model's instances.
    """
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    # Bound unconditionally: the C3k2/A2C2f branches below read `scale` even when the YAML defines no `scales`
    scale = d.get("scale")
    if scales:
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # NOTE(review): eval() of a YAML-provided string; only parse model YAMLs from trusted sources
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    repeat_modules = frozenset(  # modules with 'repeat' arguments
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = (
            getattr(torch.nn, m[3:])
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])
            if "torchvision.ops." in m
            else globals()[m]
        )  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            args = [c1, c2, *args[1:]]
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale and scale in "mlx":  # guard: scale may be None/"" when no scale is known
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale and scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
            if m is C2fCIB:
                legacy = False
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in frozenset({HGStem, HGBlock}):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is torch.nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in frozenset(
            {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
        ):
            args.append([ch[x] for x in f])
            if m is Segment or m is YOLOESegment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
                m.legacy = legacy  # class attribute, shared by all instances of this head type
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m in frozenset({TorchVision, Index}):
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]
        else:
            c2 = ch[f]

        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # drop the input-channel placeholder once the first layer's output is known
        ch.append(c2)
    return torch.nn.Sequential(*layers), sorted(save)
|
1517
|
+
|
1518
|
+
|
1519
|
+
def yaml_model_load(path):
    """
    Load a YOLO model configuration dictionary from a YAML file.

    Legacy P6 stems such as 'yolov8n6' are renamed to the '-p6' form ('yolov8n-p6') with a warning. The loader first
    tries the scale-agnostic YAML whose *filename* has the scale letter removed (e.g. 'yolov8n.yaml' -> 'yolov8.yaml')
    and falls back to the exact path when that file cannot be found.

    Args:
        path (str | Path): Path to the YAML file.

    Returns:
        (dict): Model dictionary, with added 'scale' (guessed from the filename) and 'yaml_file' keys.
    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    # Strip the scale letter from the filename only, i.e. yolov8x.yaml -> yolov8.yaml. Running the regex over the full
    # path string could instead match a digit+letter pair in a parent directory (e.g. '.../exp2s/yolov8n.yaml').
    unified_name = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", path.name)
    unified_path = str(path.parent / unified_name)
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = YAML.load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d
|
1541
|
+
|
1542
|
+
|
1543
|
+
def guess_model_scale(model_path):
    """
    Return the scale letter (n, s, m, l or x) embedded in a YOLO model filename.

    Args:
        model_path (str | Path): Path to the YOLO model's YAML (or weights) file; only the file stem is inspected.

    Returns:
        (str): The scale letter, or an empty string when the stem does not match 'yolo[e-][v]<digits><scale>'.
    """
    found = re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem)
    if found is None:
        return ""
    return found.group(2)
|
1557
|
+
|
1558
|
+
|
1559
|
+
def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Strategies are tried in order, and one that yields no task falls through to the next:

    1. For a dict, the name of the last head module.
    2. For a torch.nn.Module, ``args["task"]`` and then the YAML attached at ``.``, ``.model`` or ``.model.model``.
    3. For a torch.nn.Module, the types of its submodules.
    4. For a str/Path, the filename suffix and parent directory names.

    Args:
        model (torch.nn.Module | dict | str | Path): PyTorch model, model configuration dictionary, or model path.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb'). Falls back to 'detect' with a
            warning when no strategy succeeds.
    """
    from operator import attrgetter

    def cfg2task(cfg):
        """Guess from YAML dictionary; returns None when the last head module name is not recognized."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if "segment" in m:
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"
        return None  # unrecognized head (e.g. RTDETRDecoder): let the caller try other strategies

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            task = cfg2task(model)
            if task:
                return task

    # Guess from PyTorch model
    if isinstance(model, torch.nn.Module):  # PyTorch model
        for attrs in ("args", "model.args", "model.model.args"):
            with contextlib.suppress(Exception):  # missing attribute, args=None, or no 'task' key
                task = attrgetter(attrs)(model)["task"]
                if task:
                    return task
        for attrs in ("yaml", "model.yaml", "model.model.yaml"):
            with contextlib.suppress(Exception):
                task = cfg2task(attrgetter(attrs)(model))
                if task:
                    return task
        for m in model.modules():
            if isinstance(m, (Segment, YOLOESegment)):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, YOLOEDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect
|