ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
@@ -1,6 +1,9 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import contextlib
+import pickle
+import re
+import types
 from copy import deepcopy
 from pathlib import Path
 
@@ -11,18 +14,28 @@ from ultralytics.nn.modules import (
     AIFI,
     C1,
     C2,
+    C2PSA,
     C3,
     C3TR,
+    ELAN1,
     OBB,
+    PSA,
     SPP,
+    SPPELAN,
     SPPF,
+    AConv,
+    ADown,
     Bottleneck,
     BottleneckCSP,
     C2f,
     C2fAttn,
-    ImagePoolingAttn,
+    C2fCIB,
+    C2fPSA,
     C3Ghost,
+    C3k2,
     C3x,
+    CBFuse,
+    CBLinear,
     Classify,
     Concat,
     Conv,
@@ -36,30 +49,38 @@ from ultralytics.nn.modules import (
     GhostConv,
     HGBlock,
     HGStem,
+    ImagePoolingAttn,
+    Index,
     Pose,
     RepC3,
     RepConv,
+    RepNCSPELAN4,
+    RepVGGDW,
     ResNetLayer,
     RTDETRDecoder,
+    SCDown,
     Segment,
+    TorchVision,
     WorldDetect,
-    RepNCSPELAN4,
-    ADown,
-    SPPELAN,
-    CBFuse,
-    CBLinear,
-    Silence,
+    v10Detect,
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
+from ultralytics.utils.loss import (
+    E2EDetectLoss,
+    v8ClassificationLoss,
+    v8DetectionLoss,
+    v8OBBLoss,
+    v8PoseLoss,
+    v8SegmentationLoss,
+)
+from ultralytics.utils.ops import make_divisible
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
     fuse_deconv_and_bn,
     initialize_weights,
     intersect_dicts,
-    make_divisible,
     model_info,
     scale_img,
     time_sync,
@@ -76,13 +97,17 @@ class BaseModel(nn.Module):
 
     def forward(self, x, *args, **kwargs):
         """
-        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
+        Perform forward pass of the model for either training or inference.
+
+        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
         Args:
-            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
+            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
 
         Returns:
-            (torch.Tensor): The output of the network.
+            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
         """
         if isinstance(x, dict):  # for cases of training and validating while training.
             return self.loss(x, *args, **kwargs)
@@ -138,8 +163,8 @@ class BaseModel(nn.Module):
     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference."""
         LOGGER.warning(
-            f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
-            f"Reverting to single-scale inference instead."
+            f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
+            f"Reverting to single-scale prediction."
         )
         return self._predict_once(x)
 
@@ -157,7 +182,7 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
-        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs
+        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
         t = time_sync()
         for _ in range(10):
             m(x.copy() if c else x)
@@ -191,6 +216,9 @@ class BaseModel(nn.Module):
                 if isinstance(m, RepConv):
                     m.fuse_convs()
                     m.forward = m.forward_fuse  # update forward
+                if isinstance(m, RepVGGDW):
+                    m.fuse()
+                    m.forward = m.forward_fuse
             self.info(verbose=verbose)
 
         return self
@@ -260,7 +288,7 @@ class BaseModel(nn.Module):
             batch (dict): Batch to compute loss on
             preds (torch.Tensor | List[torch.Tensor]): Predictions.
         """
-        if not hasattr(self, "criterion"):
+        if getattr(self, "criterion", None) is None:
             self.criterion = self.init_criterion()
 
         preds = self.forward(batch["img"]) if preds is None else preds
@@ -278,6 +306,12 @@ class DetectionModel(BaseModel):
         """Initialize the YOLOv8 detection model with the given config and parameters."""
         super().__init__()
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
+        if self.yaml["backbone"][0][2] == "Silence":
+            LOGGER.warning(
+                "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
+                "Please delete local *.pt file and re-download the latest model checkpoint."
+            )
+            self.yaml["backbone"][0][2] = "nn.Identity"
 
         # Define model
         ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
@@ -287,14 +321,21 @@ class DetectionModel(BaseModel):
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
         self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
         self.inplace = self.yaml.get("inplace", True)
+        self.end2end = getattr(self.model[-1], "end2end", False)
 
         # Build strides
         m = self.model[-1]  # Detect()
         if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
             s = 256  # 2x min stride
             m.inplace = self.inplace
-            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
-            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
+
+            def _forward(x):
+                """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
+                if self.end2end:
+                    return self.forward(x)["one2many"]
+                return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+
+            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
             self.stride = m.stride
             m.bias_init()  # only run once
         else:
@@ -308,6 +349,9 @@ class DetectionModel(BaseModel):
 
     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference and train outputs."""
+        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
+            LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
+            return self._predict_once(x)
         img_size = x.shape[-2:]  # height, width
         s = [1, 0.83, 0.67]  # scales
         f = [None, 3, None]  # flips (2-ud, 3-lr)
@@ -344,7 +388,7 @@ class DetectionModel(BaseModel):
 
     def init_criterion(self):
         """Initialize the loss criterion for the DetectionModel."""
-        return v8DetectionLoss(self)
+        return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
 
 
 class OBBModel(DetectionModel):
@@ -425,11 +469,11 @@ class ClassificationModel(BaseModel):
         elif isinstance(m, nn.Sequential):
             types = [type(x) for x in m]
             if nn.Linear in types:
-                i = types.index(nn.Linear)  # nn.Linear index
+                i = len(types) - 1 - types[::-1].index(nn.Linear)  # last nn.Linear index
                 if m[i].out_features != nc:
                     m[i] = nn.Linear(m[i].in_features, nc)
             elif nn.Conv2d in types:
-                i = types.index(nn.Conv2d)  # nn.Conv2d index
+                i = len(types) - 1 - types[::-1].index(nn.Conv2d)  # last nn.Conv2d index
                 if m[i].out_channels != nc:
                     m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
 
@@ -560,30 +604,32 @@ class WorldModel(DetectionModel):
 
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 world model with given config and parameters."""
-        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
+        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
+        self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def set_classes(self, text):
+    def set_classes(self, text, batch=80, cache_clip_model=True):
         """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
 
-        model = clip.load("ViT-B/32")[0]
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
+            self.clip_model = clip.load("ViT-B/32")[0]
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
         device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = model.encode_text(text_token).detach()
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)
 
-    def init_criterion(self):
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -591,13 +637,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -625,6 +672,21 @@ class WorldModel(DetectionModel):
             return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -646,7 +708,7 @@ class Ensemble(nn.ModuleList):
 
 
 @contextlib.contextmanager
-def temporary_modules(modules=None):
+def temporary_modules(modules=None, attributes=None):
     """
     Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
@@ -656,11 +718,13 @@ def temporary_modules(modules=None):
 
     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
+        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
 
     Example:
         ```python
-        with temporary_modules({'old.module': 'new.module'}):
-            import old.module  # this will now import new.module
+        with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
+            import old.module  # this will now import new.module
+            from old.module import attribute  # this will now import new.module.attribute
         ```
 
     Note:
@@ -668,16 +732,23 @@ def temporary_modules(modules=None):
         Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
         applications or libraries. Use this function with caution.
     """
-    if not modules:
+    if modules is None:
         modules = {}
-
-    import importlib
+    if attributes is None:
+        attributes = {}
     import sys
+    from importlib import import_module
 
     try:
+        # Set attributes in sys.modules under their old name
+        for old, new in attributes.items():
+            old_module, old_attr = old.rsplit(".", 1)
+            new_module, new_attr = new.rsplit(".", 1)
+            setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
+
         # Set modules in sys.modules under their old name
         for old, new in modules.items():
-            sys.modules[old] = importlib.import_module(new)
+            sys.modules[old] = import_module(new)
 
         yield
     finally:
@@ -687,17 +758,58 @@ def temporary_modules(modules=None):
             del sys.modules[old]
 
 
-def torch_safe_load(weight):
+class SafeClass:
+    """A placeholder class to replace unknown classes during unpickling."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize SafeClass instance, ignoring all arguments."""
+        pass
+
+    def __call__(self, *args, **kwargs):
+        """Run SafeClass instance, ignoring all arguments."""
+        pass
+
+
+class SafeUnpickler(pickle.Unpickler):
+    """Custom Unpickler that replaces unknown classes with SafeClass."""
+
+    def find_class(self, module, name):
+        """Attempt to find a class, returning SafeClass if not among safe modules."""
+        safe_modules = (
+            "torch",
+            "collections",
+            "collections.abc",
+            "builtins",
+            "math",
+            "numpy",
+            # Add other modules considered safe
+        )
+        if module in safe_modules:
+            return super().find_class(module, name)
+        else:
+            return SafeClass
+
+
+def torch_safe_load(weight, safe_only=False):
     """
-    This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
-    it catches the error, logs a warning message, and attempts to install the missing module via the
-    check_requirements() function. After installation, the function again attempts to load the model.
+    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
+    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
+    After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
+        safe_only (bool): If True, replace unknown classes with SafeClass during loading.
+
+    Example:
+        ```python
+        from ultralytics.nn.tasks import torch_safe_load
+
+        ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
+        ```
 
     Returns:
-        (dict): The loaded PyTorch model.
+        ckpt (dict): The loaded model checkpoint.
+        file (str): The loaded filename
     """
     from ultralytics.utils.downloads import attempt_download_asset
 
@@ -705,13 +817,26 @@ def torch_safe_load(weight):
     file = attempt_download_asset(weight)  # search online if missing locally
     try:
         with temporary_modules(
-            {
+            modules={
                 "ultralytics.yolo.utils": "ultralytics.utils",
                 "ultralytics.yolo.v8": "ultralytics.models.yolo",
                 "ultralytics.yolo.data": "ultralytics.data",
-            }
-        ):  # for legacy 8.0 Classify and Pose models
-            ckpt = torch.load(file, map_location="cpu")
+            },
+            attributes={
+                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
+                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
+                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
+            },
+        ):
+            if safe_only:
+                # Load via custom pickle module
+                safe_pickle = types.ModuleType("safe_pickle")
+                safe_pickle.Unpickler = SafeUnpickler
+                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
+                with open(file, "rb") as f:
+                    ckpt = torch.load(f, pickle_module=safe_pickle)
+            else:
+                ckpt = torch.load(file, map_location="cpu")
 
 
     except ModuleNotFoundError as e:  # e.name is missing module name
@@ -721,14 +846,14 @@ def torch_safe_load(weight):
                     f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                     f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                     f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-                    f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
                 )
             ) from e
         LOGGER.warning(
-            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
+            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
             f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
             f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-            f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
         )
         check_requirements(e.name)  # install missing module
         ckpt = torch.load(file, map_location="cpu")
@@ -741,12 +866,11 @@ def torch_safe_load(weight):
         )
         ckpt = {"model": ckpt.model}
 
-    return ckpt, file  # load, filename
+    return ckpt, file
 
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
-
     ensemble = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt, w = torch_safe_load(w)  # load ckpt
@@ -814,6 +938,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     import ast
 
     # Args
+    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
     nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
     depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
@@ -839,9 +964,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             if isinstance(a, str):
                 with contextlib.suppress(ValueError):
                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
-
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in (
+        if m in {
             Classify,
             Conv,
             ConvTranspose,
@@ -850,14 +974,19 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             GhostBottleneck,
             SPP,
             SPPF,
+            C2fPSA,
+            C2PSA,
             DWConv,
             Focus,
             BottleneckCSP,
             C1,
             C2,
             C2f,
+            C3k2,
             RepNCSPELAN4,
+            ELAN1,
             ADown,
+            AConv,
             SPPELAN,
             C2fAttn,
             C3,
@@ -867,7 +996,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             DWConvTranspose2d,
             C3x,
             RepC3,
-        ):
+            PSA,
+            SCDown,
+            C2fCIB,
+        }:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -878,12 +1010,31 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 )  # num heads
 
             args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in {
+                BottleneckCSP,
+                C1,
+                C2,
+                C2f,
+                C3k2,
+                C2fAttn,
+                C3,
+                C3TR,
+                C3Ghost,
+                C3x,
+                RepC3,
+                C2fPSA,
+                C2fCIB,
+                C2PSA,
+            }:
                 args.insert(2, n)  # number of repeats
                 n = 1
+            if m is C3k2:  # for M/L/X sizes
+                legacy = False
+                if scale in "mlx":
+                    args[3] = True
         elif m is AIFI:
             args = [ch[f], *args]
-        elif m in (HGStem, HGBlock):
+        elif m in {HGStem, HGBlock}:
             c1, cm, c2 = ch[f], args[0], args[1]
             args = [c1, cm, c2, *args[2:]]
             if m is HGBlock:
@@ -895,13 +1046,15 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
+            if m in {Detect, Segment, Pose, OBB}:
+                m.legacy = legacy
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
-        elif m is CBLinear:
+        elif m in {CBLinear, TorchVision, Index}:
             c2 = args[0]
             c1 = ch[f]
             args = [c1, c2, *args[1:]]
@@ -912,10 +1065,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
         m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
         t = str(m)[8:-2].replace("__main__.", "")  # module type
-        m.np = sum(x.numel() for x in m_.parameters())  # number params
+        m_.np = sum(x.numel() for x in m_.parameters())  # number params
         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
         if verbose:
-            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
+            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
@@ -926,8 +1079,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
 def yaml_model_load(path):
     """Load a YOLOv8 model from a YAML file."""
-    import re
-
     path = Path(path)
     if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
         new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@@ -954,11 +1105,10 @@ def guess_model_scale(model_path):
     Returns:
         (str): The size character of the model's scale, which can be n, s, m, l, or x.
     """
-    with contextlib.suppress(AttributeError):
-        import re
-
-        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
-    return ""
+    try:
+        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # noqa, returns n, s, m, l, or x
+    except AttributeError:
+        return ""
 
 
 def guess_model_task(model):
@@ -978,9 +1128,9 @@ def guess_model_task(model):
     def cfg2task(cfg):
         """Guess from YAML dictionary."""
         m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ("classify", "classifier", "cls", "fc"):
+        if m in {"classify", "classifier", "cls", "fc"}:
             return "classify"
-        if m == "detect":
+        if "detect" in m:
             return "detect"
         if m == "segment":
             return "segment"
@@ -993,7 +1143,6 @@ def guess_model_task(model):
     if isinstance(model, dict):
         with contextlib.suppress(Exception):
             return cfg2task(model)
-
     # Guess from PyTorch model
     if isinstance(model, nn.Module):  # PyTorch model
         for x in "model.args", "model.model.args", "model.model.model.args":
@@ -1002,7 +1151,6 @@ def guess_model_task(model):
         for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
             with contextlib.suppress(Exception):
                 return cfg2task(eval(x))
-
         for m in model.modules():
             if isinstance(m, Segment):
                 return "segment"
@@ -1012,7 +1160,7 @@ def guess_model_task(model):
                 return "pose"
             elif isinstance(m, OBB):
                 return "obb"
-            elif isinstance(m, (Detect, WorldDetect)):
+            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                 return "detect"
 
     # Guess from model filename
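One practical consequence of the `guess_model_scale` change above is visible in the regex itself: the pattern widens from `yolov\d+([nslmx])` to `yolo[v]?\d+([nslmx])`, so scale detection now also works for YOLO11-style stems that drop the "v". A standalone check of the new pattern, using only the stdlib `re` module and a few illustrative (hypothetical) filenames:

```python
import re
from pathlib import Path

# Pattern taken verbatim from the new guess_model_scale in 8.3.62.
PATTERN = r"yolo[v]?\d+([nslmx])"

for name in ("yolov8n.pt", "yolo11s.yaml", "yolov10x.pt", "rtdetr-l.pt"):
    m = re.search(PATTERN, Path(name).stem)
    print(name, "->", m.group(1) if m else "")
# yolov8n.pt -> n, yolo11s.yaml -> s, yolov10x.pt -> x, rtdetr-l.pt -> (empty)
```

The old 8.1.28 pattern would have returned an empty string for "yolo11s.yaml", since it required the literal "v".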
|
@@ -1 +1,30 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from .ai_gym import AIGym
|
4
|
+
from .analytics import Analytics
|
5
|
+
from .distance_calculation import DistanceCalculation
|
6
|
+
from .heatmap import Heatmap
|
7
|
+
from .object_counter import ObjectCounter
|
8
|
+
from .parking_management import ParkingManagement, ParkingPtsSelection
|
9
|
+
from .queue_management import QueueManager
|
10
|
+
from .region_counter import RegionCounter
|
11
|
+
from .security_alarm import SecurityAlarm
|
12
|
+
from .speed_estimation import SpeedEstimator
|
13
|
+
from .streamlit_inference import Inference
|
14
|
+
from .trackzone import TrackZone
|
15
|
+
|
16
|
+
__all__ = (
|
17
|
+
"AIGym",
|
18
|
+
"DistanceCalculation",
|
19
|
+
"Heatmap",
|
20
|
+
"ObjectCounter",
|
21
|
+
"ParkingManagement",
|
22
|
+
"ParkingPtsSelection",
|
23
|
+
"QueueManager",
|
24
|
+
"SpeedEstimator",
|
25
|
+
"Analytics",
|
26
|
+
"Inference",
|
27
|
+
"RegionCounter",
|
28
|
+
"TrackZone",
|
29
|
+
"SecurityAlarm",
|
30
|
+
)
|