ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
import contextlib
|
4
|
+
import pickle
|
5
|
+
import re
|
6
|
+
import types
|
4
7
|
from copy import deepcopy
|
5
8
|
from pathlib import Path
|
6
9
|
|
10
|
+
import thop
|
7
11
|
import torch
|
8
12
|
import torch.nn as nn
|
9
13
|
|
@@ -11,18 +15,28 @@ from ultralytics.nn.modules import (
|
|
11
15
|
AIFI,
|
12
16
|
C1,
|
13
17
|
C2,
|
18
|
+
C2PSA,
|
14
19
|
C3,
|
15
20
|
C3TR,
|
21
|
+
ELAN1,
|
16
22
|
OBB,
|
23
|
+
PSA,
|
17
24
|
SPP,
|
25
|
+
SPPELAN,
|
18
26
|
SPPF,
|
27
|
+
AConv,
|
28
|
+
ADown,
|
19
29
|
Bottleneck,
|
20
30
|
BottleneckCSP,
|
21
31
|
C2f,
|
22
32
|
C2fAttn,
|
23
|
-
|
33
|
+
C2fCIB,
|
34
|
+
C2fPSA,
|
24
35
|
C3Ghost,
|
36
|
+
C3k2,
|
25
37
|
C3x,
|
38
|
+
CBFuse,
|
39
|
+
CBLinear,
|
26
40
|
Classify,
|
27
41
|
Concat,
|
28
42
|
Conv,
|
@@ -36,53 +50,60 @@ from ultralytics.nn.modules import (
|
|
36
50
|
GhostConv,
|
37
51
|
HGBlock,
|
38
52
|
HGStem,
|
53
|
+
ImagePoolingAttn,
|
54
|
+
Index,
|
39
55
|
Pose,
|
40
56
|
RepC3,
|
41
57
|
RepConv,
|
58
|
+
RepNCSPELAN4,
|
59
|
+
RepVGGDW,
|
42
60
|
ResNetLayer,
|
43
61
|
RTDETRDecoder,
|
62
|
+
SCDown,
|
44
63
|
Segment,
|
64
|
+
TorchVision,
|
45
65
|
WorldDetect,
|
46
|
-
|
47
|
-
ADown,
|
48
|
-
SPPELAN,
|
49
|
-
CBFuse,
|
50
|
-
CBLinear,
|
51
|
-
Silence,
|
66
|
+
v10Detect,
|
52
67
|
)
|
53
68
|
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
54
69
|
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
55
|
-
from ultralytics.utils.loss import
|
70
|
+
from ultralytics.utils.loss import (
|
71
|
+
E2EDetectLoss,
|
72
|
+
v8ClassificationLoss,
|
73
|
+
v8DetectionLoss,
|
74
|
+
v8OBBLoss,
|
75
|
+
v8PoseLoss,
|
76
|
+
v8SegmentationLoss,
|
77
|
+
)
|
78
|
+
from ultralytics.utils.ops import make_divisible
|
56
79
|
from ultralytics.utils.plotting import feature_visualization
|
57
80
|
from ultralytics.utils.torch_utils import (
|
58
81
|
fuse_conv_and_bn,
|
59
82
|
fuse_deconv_and_bn,
|
60
83
|
initialize_weights,
|
61
84
|
intersect_dicts,
|
62
|
-
make_divisible,
|
63
85
|
model_info,
|
64
86
|
scale_img,
|
65
87
|
time_sync,
|
66
88
|
)
|
67
89
|
|
68
|
-
try:
|
69
|
-
import thop
|
70
|
-
except ImportError:
|
71
|
-
thop = None
|
72
|
-
|
73
90
|
|
74
91
|
class BaseModel(nn.Module):
|
75
92
|
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
|
76
93
|
|
77
94
|
def forward(self, x, *args, **kwargs):
|
78
95
|
"""
|
79
|
-
|
96
|
+
Perform forward pass of the model for either training or inference.
|
97
|
+
|
98
|
+
If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
|
80
99
|
|
81
100
|
Args:
|
82
|
-
x (torch.Tensor | dict):
|
101
|
+
x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
|
102
|
+
*args (Any): Variable length argument list.
|
103
|
+
**kwargs (Any): Arbitrary keyword arguments.
|
83
104
|
|
84
105
|
Returns:
|
85
|
-
(torch.Tensor):
|
106
|
+
(torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
|
86
107
|
"""
|
87
108
|
if isinstance(x, dict): # for cases of training and validating while training.
|
88
109
|
return self.loss(x, *args, **kwargs)
|
@@ -138,8 +159,8 @@ class BaseModel(nn.Module):
|
|
138
159
|
def _predict_augment(self, x):
|
139
160
|
"""Perform augmentations on input image x and return augmented inference."""
|
140
161
|
LOGGER.warning(
|
141
|
-
f"WARNING ⚠️ {self.__class__.__name__} does not support
|
142
|
-
f"Reverting to single-scale
|
162
|
+
f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
|
163
|
+
f"Reverting to single-scale prediction."
|
143
164
|
)
|
144
165
|
return self._predict_once(x)
|
145
166
|
|
@@ -157,7 +178,7 @@ class BaseModel(nn.Module):
|
|
157
178
|
None
|
158
179
|
"""
|
159
180
|
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
160
|
-
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 #
|
181
|
+
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
|
161
182
|
t = time_sync()
|
162
183
|
for _ in range(10):
|
163
184
|
m(x.copy() if c else x)
|
@@ -191,6 +212,9 @@ class BaseModel(nn.Module):
|
|
191
212
|
if isinstance(m, RepConv):
|
192
213
|
m.fuse_convs()
|
193
214
|
m.forward = m.forward_fuse # update forward
|
215
|
+
if isinstance(m, RepVGGDW):
|
216
|
+
m.fuse()
|
217
|
+
m.forward = m.forward_fuse
|
194
218
|
self.info(verbose=verbose)
|
195
219
|
|
196
220
|
return self
|
@@ -260,7 +284,7 @@ class BaseModel(nn.Module):
|
|
260
284
|
batch (dict): Batch to compute loss on
|
261
285
|
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
262
286
|
"""
|
263
|
-
if
|
287
|
+
if getattr(self, "criterion", None) is None:
|
264
288
|
self.criterion = self.init_criterion()
|
265
289
|
|
266
290
|
preds = self.forward(batch["img"]) if preds is None else preds
|
@@ -278,6 +302,12 @@ class DetectionModel(BaseModel):
|
|
278
302
|
"""Initialize the YOLOv8 detection model with the given config and parameters."""
|
279
303
|
super().__init__()
|
280
304
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
305
|
+
if self.yaml["backbone"][0][2] == "Silence":
|
306
|
+
LOGGER.warning(
|
307
|
+
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
|
308
|
+
"Please delete local *.pt file and re-download the latest model checkpoint."
|
309
|
+
)
|
310
|
+
self.yaml["backbone"][0][2] = "nn.Identity"
|
281
311
|
|
282
312
|
# Define model
|
283
313
|
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
|
@@ -287,14 +317,21 @@ class DetectionModel(BaseModel):
|
|
287
317
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
288
318
|
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
|
289
319
|
self.inplace = self.yaml.get("inplace", True)
|
320
|
+
self.end2end = getattr(self.model[-1], "end2end", False)
|
290
321
|
|
291
322
|
# Build strides
|
292
323
|
m = self.model[-1] # Detect()
|
293
324
|
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
294
325
|
s = 256 # 2x min stride
|
295
326
|
m.inplace = self.inplace
|
296
|
-
|
297
|
-
|
327
|
+
|
328
|
+
def _forward(x):
|
329
|
+
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
|
330
|
+
if self.end2end:
|
331
|
+
return self.forward(x)["one2many"]
|
332
|
+
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
333
|
+
|
334
|
+
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
|
298
335
|
self.stride = m.stride
|
299
336
|
m.bias_init() # only run once
|
300
337
|
else:
|
@@ -308,6 +345,9 @@ class DetectionModel(BaseModel):
|
|
308
345
|
|
309
346
|
def _predict_augment(self, x):
|
310
347
|
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
348
|
+
if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
|
349
|
+
LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
|
350
|
+
return self._predict_once(x)
|
311
351
|
img_size = x.shape[-2:] # height, width
|
312
352
|
s = [1, 0.83, 0.67] # scales
|
313
353
|
f = [None, 3, None] # flips (2-ud, 3-lr)
|
@@ -344,7 +384,7 @@ class DetectionModel(BaseModel):
|
|
344
384
|
|
345
385
|
def init_criterion(self):
|
346
386
|
"""Initialize the loss criterion for the DetectionModel."""
|
347
|
-
return v8DetectionLoss(self)
|
387
|
+
return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
|
348
388
|
|
349
389
|
|
350
390
|
class OBBModel(DetectionModel):
|
@@ -425,11 +465,11 @@ class ClassificationModel(BaseModel):
|
|
425
465
|
elif isinstance(m, nn.Sequential):
|
426
466
|
types = [type(x) for x in m]
|
427
467
|
if nn.Linear in types:
|
428
|
-
i = types.index(nn.Linear) # nn.Linear index
|
468
|
+
i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index
|
429
469
|
if m[i].out_features != nc:
|
430
470
|
m[i] = nn.Linear(m[i].in_features, nc)
|
431
471
|
elif nn.Conv2d in types:
|
432
|
-
i = types.index(nn.Conv2d) # nn.Conv2d index
|
472
|
+
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index
|
433
473
|
if m[i].out_channels != nc:
|
434
474
|
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
435
475
|
|
@@ -560,30 +600,32 @@ class WorldModel(DetectionModel):
|
|
560
600
|
|
561
601
|
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
|
562
602
|
"""Initialize YOLOv8 world model with given config and parameters."""
|
563
|
-
self.txt_feats = torch.randn(1, nc or 80, 512) # placeholder
|
603
|
+
self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
|
604
|
+
self.clip_model = None # CLIP model placeholder
|
564
605
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
565
606
|
|
566
|
-
def set_classes(self, text):
|
567
|
-
"""
|
607
|
+
def set_classes(self, text, batch=80, cache_clip_model=True):
|
608
|
+
"""Set classes in advance so that model could do offline-inference without clip model."""
|
568
609
|
try:
|
569
610
|
import clip
|
570
611
|
except ImportError:
|
571
|
-
check_requirements("git+https://github.com/
|
612
|
+
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
572
613
|
import clip
|
573
614
|
|
574
|
-
|
615
|
+
if (
|
616
|
+
not getattr(self, "clip_model", None) and cache_clip_model
|
617
|
+
): # for backwards compatibility of models lacking clip_model attribute
|
618
|
+
self.clip_model = clip.load("ViT-B/32")[0]
|
619
|
+
model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
|
575
620
|
device = next(model.parameters()).device
|
576
621
|
text_token = clip.tokenize(text).to(device)
|
577
|
-
txt_feats = model.encode_text(
|
622
|
+
txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
|
623
|
+
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
|
578
624
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
579
|
-
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
|
625
|
+
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
|
580
626
|
self.model[-1].nc = len(text)
|
581
627
|
|
582
|
-
def
|
583
|
-
"""Initialize the loss criterion for the model."""
|
584
|
-
raise NotImplementedError
|
585
|
-
|
586
|
-
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
628
|
+
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
|
587
629
|
"""
|
588
630
|
Perform a forward pass through the model.
|
589
631
|
|
@@ -591,13 +633,14 @@ class WorldModel(DetectionModel):
|
|
591
633
|
x (torch.Tensor): The input tensor.
|
592
634
|
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
|
593
635
|
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
|
636
|
+
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
|
594
637
|
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
|
595
638
|
embed (list, optional): A list of feature vectors/embeddings to return.
|
596
639
|
|
597
640
|
Returns:
|
598
641
|
(torch.Tensor): Model's output tensor.
|
599
642
|
"""
|
600
|
-
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
|
643
|
+
txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
|
601
644
|
if len(txt_feats) != len(x):
|
602
645
|
txt_feats = txt_feats.repeat(len(x), 1, 1)
|
603
646
|
ori_txt_feats = txt_feats.clone()
|
@@ -625,6 +668,21 @@ class WorldModel(DetectionModel):
|
|
625
668
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
626
669
|
return x
|
627
670
|
|
671
|
+
def loss(self, batch, preds=None):
|
672
|
+
"""
|
673
|
+
Compute loss.
|
674
|
+
|
675
|
+
Args:
|
676
|
+
batch (dict): Batch to compute loss on.
|
677
|
+
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
678
|
+
"""
|
679
|
+
if not hasattr(self, "criterion"):
|
680
|
+
self.criterion = self.init_criterion()
|
681
|
+
|
682
|
+
if preds is None:
|
683
|
+
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
|
684
|
+
return self.criterion(preds, batch)
|
685
|
+
|
628
686
|
|
629
687
|
class Ensemble(nn.ModuleList):
|
630
688
|
"""Ensemble of models."""
|
@@ -646,7 +704,7 @@ class Ensemble(nn.ModuleList):
|
|
646
704
|
|
647
705
|
|
648
706
|
@contextlib.contextmanager
|
649
|
-
def temporary_modules(modules=None):
|
707
|
+
def temporary_modules(modules=None, attributes=None):
|
650
708
|
"""
|
651
709
|
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
|
652
710
|
|
@@ -656,11 +714,13 @@ def temporary_modules(modules=None):
|
|
656
714
|
|
657
715
|
Args:
|
658
716
|
modules (dict, optional): A dictionary mapping old module paths to new module paths.
|
717
|
+
attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
|
659
718
|
|
660
719
|
Example:
|
661
720
|
```python
|
662
|
-
with temporary_modules({
|
663
|
-
import old.module
|
721
|
+
with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
|
722
|
+
import old.module # this will now import new.module
|
723
|
+
from old.module import attribute # this will now import new.module.attribute
|
664
724
|
```
|
665
725
|
|
666
726
|
Note:
|
@@ -668,16 +728,23 @@ def temporary_modules(modules=None):
|
|
668
728
|
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
|
669
729
|
applications or libraries. Use this function with caution.
|
670
730
|
"""
|
671
|
-
if
|
731
|
+
if modules is None:
|
672
732
|
modules = {}
|
673
|
-
|
674
|
-
|
733
|
+
if attributes is None:
|
734
|
+
attributes = {}
|
675
735
|
import sys
|
736
|
+
from importlib import import_module
|
676
737
|
|
677
738
|
try:
|
739
|
+
# Set attributes in sys.modules under their old name
|
740
|
+
for old, new in attributes.items():
|
741
|
+
old_module, old_attr = old.rsplit(".", 1)
|
742
|
+
new_module, new_attr = new.rsplit(".", 1)
|
743
|
+
setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
|
744
|
+
|
678
745
|
# Set modules in sys.modules under their old name
|
679
746
|
for old, new in modules.items():
|
680
|
-
sys.modules[old] =
|
747
|
+
sys.modules[old] = import_module(new)
|
681
748
|
|
682
749
|
yield
|
683
750
|
finally:
|
@@ -687,17 +754,58 @@ def temporary_modules(modules=None):
|
|
687
754
|
del sys.modules[old]
|
688
755
|
|
689
756
|
|
690
|
-
|
757
|
+
class SafeClass:
|
758
|
+
"""A placeholder class to replace unknown classes during unpickling."""
|
759
|
+
|
760
|
+
def __init__(self, *args, **kwargs):
|
761
|
+
"""Initialize SafeClass instance, ignoring all arguments."""
|
762
|
+
pass
|
763
|
+
|
764
|
+
def __call__(self, *args, **kwargs):
|
765
|
+
"""Run SafeClass instance, ignoring all arguments."""
|
766
|
+
pass
|
767
|
+
|
768
|
+
|
769
|
+
class SafeUnpickler(pickle.Unpickler):
|
770
|
+
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
771
|
+
|
772
|
+
def find_class(self, module, name):
|
773
|
+
"""Attempt to find a class, returning SafeClass if not among safe modules."""
|
774
|
+
safe_modules = (
|
775
|
+
"torch",
|
776
|
+
"collections",
|
777
|
+
"collections.abc",
|
778
|
+
"builtins",
|
779
|
+
"math",
|
780
|
+
"numpy",
|
781
|
+
# Add other modules considered safe
|
782
|
+
)
|
783
|
+
if module in safe_modules:
|
784
|
+
return super().find_class(module, name)
|
785
|
+
else:
|
786
|
+
return SafeClass
|
787
|
+
|
788
|
+
|
789
|
+
def torch_safe_load(weight, safe_only=False):
|
691
790
|
"""
|
692
|
-
|
693
|
-
|
694
|
-
|
791
|
+
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
792
|
+
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
793
|
+
After installation, the function again attempts to load the model using torch.load().
|
695
794
|
|
696
795
|
Args:
|
697
796
|
weight (str): The file path of the PyTorch model.
|
797
|
+
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
|
798
|
+
|
799
|
+
Example:
|
800
|
+
```python
|
801
|
+
from ultralytics.nn.tasks import torch_safe_load
|
802
|
+
|
803
|
+
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
|
804
|
+
```
|
698
805
|
|
699
806
|
Returns:
|
700
|
-
(dict): The loaded
|
807
|
+
ckpt (dict): The loaded model checkpoint.
|
808
|
+
file (str): The loaded filename
|
701
809
|
"""
|
702
810
|
from ultralytics.utils.downloads import attempt_download_asset
|
703
811
|
|
@@ -705,13 +813,26 @@ def torch_safe_load(weight):
|
|
705
813
|
file = attempt_download_asset(weight) # search online if missing locally
|
706
814
|
try:
|
707
815
|
with temporary_modules(
|
708
|
-
{
|
816
|
+
modules={
|
709
817
|
"ultralytics.yolo.utils": "ultralytics.utils",
|
710
818
|
"ultralytics.yolo.v8": "ultralytics.models.yolo",
|
711
819
|
"ultralytics.yolo.data": "ultralytics.data",
|
712
|
-
}
|
713
|
-
|
714
|
-
|
820
|
+
},
|
821
|
+
attributes={
|
822
|
+
"ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
|
823
|
+
"ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
|
824
|
+
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
|
825
|
+
},
|
826
|
+
):
|
827
|
+
if safe_only:
|
828
|
+
# Load via custom pickle module
|
829
|
+
safe_pickle = types.ModuleType("safe_pickle")
|
830
|
+
safe_pickle.Unpickler = SafeUnpickler
|
831
|
+
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
|
832
|
+
with open(file, "rb") as f:
|
833
|
+
ckpt = torch.load(f, pickle_module=safe_pickle)
|
834
|
+
else:
|
835
|
+
ckpt = torch.load(file, map_location="cpu")
|
715
836
|
|
716
837
|
except ModuleNotFoundError as e: # e.name is missing module name
|
717
838
|
if e.name == "models":
|
@@ -721,14 +842,14 @@ def torch_safe_load(weight):
|
|
721
842
|
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
722
843
|
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
723
844
|
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
724
|
-
f"run a command with an official
|
845
|
+
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
725
846
|
)
|
726
847
|
) from e
|
727
848
|
LOGGER.warning(
|
728
|
-
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in
|
849
|
+
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
|
729
850
|
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
730
851
|
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
731
|
-
f"run a command with an official
|
852
|
+
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
732
853
|
)
|
733
854
|
check_requirements(e.name) # install missing module
|
734
855
|
ckpt = torch.load(file, map_location="cpu")
|
@@ -741,12 +862,11 @@ def torch_safe_load(weight):
|
|
741
862
|
)
|
742
863
|
ckpt = {"model": ckpt.model}
|
743
864
|
|
744
|
-
return ckpt, file
|
865
|
+
return ckpt, file
|
745
866
|
|
746
867
|
|
747
868
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
748
869
|
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
749
|
-
|
750
870
|
ensemble = Ensemble()
|
751
871
|
for w in weights if isinstance(weights, list) else [weights]:
|
752
872
|
ckpt, w = torch_safe_load(w) # load ckpt
|
@@ -814,6 +934,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
814
934
|
import ast
|
815
935
|
|
816
936
|
# Args
|
937
|
+
legacy = True # backward compatibility for v3/v5/v8/v9 models
|
817
938
|
max_channels = float("inf")
|
818
939
|
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
819
940
|
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
@@ -839,9 +960,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
839
960
|
if isinstance(a, str):
|
840
961
|
with contextlib.suppress(ValueError):
|
841
962
|
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
842
|
-
|
843
963
|
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
844
|
-
if m in
|
964
|
+
if m in {
|
845
965
|
Classify,
|
846
966
|
Conv,
|
847
967
|
ConvTranspose,
|
@@ -850,14 +970,19 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
850
970
|
GhostBottleneck,
|
851
971
|
SPP,
|
852
972
|
SPPF,
|
973
|
+
C2fPSA,
|
974
|
+
C2PSA,
|
853
975
|
DWConv,
|
854
976
|
Focus,
|
855
977
|
BottleneckCSP,
|
856
978
|
C1,
|
857
979
|
C2,
|
858
980
|
C2f,
|
981
|
+
C3k2,
|
859
982
|
RepNCSPELAN4,
|
983
|
+
ELAN1,
|
860
984
|
ADown,
|
985
|
+
AConv,
|
861
986
|
SPPELAN,
|
862
987
|
C2fAttn,
|
863
988
|
C3,
|
@@ -867,7 +992,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
867
992
|
DWConvTranspose2d,
|
868
993
|
C3x,
|
869
994
|
RepC3,
|
870
|
-
|
995
|
+
PSA,
|
996
|
+
SCDown,
|
997
|
+
C2fCIB,
|
998
|
+
}:
|
871
999
|
c1, c2 = ch[f], args[0]
|
872
1000
|
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
873
1001
|
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
@@ -878,12 +1006,31 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
878
1006
|
) # num heads
|
879
1007
|
|
880
1008
|
args = [c1, c2, *args[1:]]
|
881
|
-
if m in
|
1009
|
+
if m in {
|
1010
|
+
BottleneckCSP,
|
1011
|
+
C1,
|
1012
|
+
C2,
|
1013
|
+
C2f,
|
1014
|
+
C3k2,
|
1015
|
+
C2fAttn,
|
1016
|
+
C3,
|
1017
|
+
C3TR,
|
1018
|
+
C3Ghost,
|
1019
|
+
C3x,
|
1020
|
+
RepC3,
|
1021
|
+
C2fPSA,
|
1022
|
+
C2fCIB,
|
1023
|
+
C2PSA,
|
1024
|
+
}:
|
882
1025
|
args.insert(2, n) # number of repeats
|
883
1026
|
n = 1
|
1027
|
+
if m is C3k2: # for M/L/X sizes
|
1028
|
+
legacy = False
|
1029
|
+
if scale in "mlx":
|
1030
|
+
args[3] = True
|
884
1031
|
elif m is AIFI:
|
885
1032
|
args = [ch[f], *args]
|
886
|
-
elif m in
|
1033
|
+
elif m in {HGStem, HGBlock}:
|
887
1034
|
c1, cm, c2 = ch[f], args[0], args[1]
|
888
1035
|
args = [c1, cm, c2, *args[2:]]
|
889
1036
|
if m is HGBlock:
|
@@ -895,13 +1042,15 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
895
1042
|
args = [ch[f]]
|
896
1043
|
elif m is Concat:
|
897
1044
|
c2 = sum(ch[x] for x in f)
|
898
|
-
elif m in
|
1045
|
+
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
|
899
1046
|
args.append([ch[x] for x in f])
|
900
1047
|
if m is Segment:
|
901
1048
|
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
1049
|
+
if m in {Detect, Segment, Pose, OBB}:
|
1050
|
+
m.legacy = legacy
|
902
1051
|
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
903
1052
|
args.insert(1, [ch[x] for x in f])
|
904
|
-
elif m
|
1053
|
+
elif m in {CBLinear, TorchVision, Index}:
|
905
1054
|
c2 = args[0]
|
906
1055
|
c1 = ch[f]
|
907
1056
|
args = [c1, c2, *args[1:]]
|
@@ -912,10 +1061,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
912
1061
|
|
913
1062
|
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
914
1063
|
t = str(m)[8:-2].replace("__main__.", "") # module type
|
915
|
-
|
1064
|
+
m_.np = sum(x.numel() for x in m_.parameters()) # number params
|
916
1065
|
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
917
1066
|
if verbose:
|
918
|
-
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{
|
1067
|
+
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
|
919
1068
|
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
920
1069
|
layers.append(m_)
|
921
1070
|
if i == 0:
|
@@ -926,8 +1075,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
926
1075
|
|
927
1076
|
def yaml_model_load(path):
|
928
1077
|
"""Load a YOLOv8 model from a YAML file."""
|
929
|
-
import re
|
930
|
-
|
931
1078
|
path = Path(path)
|
932
1079
|
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
|
933
1080
|
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
|
@@ -954,11 +1101,10 @@ def guess_model_scale(model_path):
|
|
954
1101
|
Returns:
|
955
1102
|
(str): The size character of the model's scale, which can be n, s, m, l, or x.
|
956
1103
|
"""
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
return
|
961
|
-
return ""
|
1104
|
+
try:
|
1105
|
+
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # noqa, returns n, s, m, l, or x
|
1106
|
+
except AttributeError:
|
1107
|
+
return ""
|
962
1108
|
|
963
1109
|
|
964
1110
|
def guess_model_task(model):
|
@@ -978,9 +1124,9 @@ def guess_model_task(model):
|
|
978
1124
|
def cfg2task(cfg):
|
979
1125
|
"""Guess from YAML dictionary."""
|
980
1126
|
m = cfg["head"][-1][-2].lower() # output module name
|
981
|
-
if m in
|
1127
|
+
if m in {"classify", "classifier", "cls", "fc"}:
|
982
1128
|
return "classify"
|
983
|
-
if
|
1129
|
+
if "detect" in m:
|
984
1130
|
return "detect"
|
985
1131
|
if m == "segment":
|
986
1132
|
return "segment"
|
@@ -993,7 +1139,6 @@ def guess_model_task(model):
|
|
993
1139
|
if isinstance(model, dict):
|
994
1140
|
with contextlib.suppress(Exception):
|
995
1141
|
return cfg2task(model)
|
996
|
-
|
997
1142
|
# Guess from PyTorch model
|
998
1143
|
if isinstance(model, nn.Module): # PyTorch model
|
999
1144
|
for x in "model.args", "model.model.args", "model.model.model.args":
|
@@ -1002,7 +1147,6 @@ def guess_model_task(model):
|
|
1002
1147
|
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
|
1003
1148
|
with contextlib.suppress(Exception):
|
1004
1149
|
return cfg2task(eval(x))
|
1005
|
-
|
1006
1150
|
for m in model.modules():
|
1007
1151
|
if isinstance(m, Segment):
|
1008
1152
|
return "segment"
|
@@ -1012,7 +1156,7 @@ def guess_model_task(model):
|
|
1012
1156
|
return "pose"
|
1013
1157
|
elif isinstance(m, OBB):
|
1014
1158
|
return "obb"
|
1015
|
-
elif isinstance(m, (Detect, WorldDetect)):
|
1159
|
+
elif isinstance(m, (Detect, WorldDetect, v10Detect)):
|
1016
1160
|
return "detect"
|
1017
1161
|
|
1018
1162
|
# Guess from model filename
|