dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
@@ -69,7 +69,7 @@ from ultralytics.nn.modules import (
     YOLOESegment,
     v10Detect,
 )
-from ultralytics.utils import DEFAULT_CFG_DICT,
+from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.utils.loss import (
     E2EDetectLoss,
@@ -80,6 +80,7 @@ from ultralytics.utils.loss import (
     v8SegmentationLoss,
 )
 from ultralytics.utils.ops import make_divisible
+from ultralytics.utils.patches import torch_load
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
@@ -94,11 +95,32 @@ from ultralytics.utils.torch_utils import (


 class BaseModel(torch.nn.Module):
-    """
+    """Base class for all YOLO models in the Ultralytics family.
+
+    This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+    display, and weight loading capabilities.
+
+    Attributes:
+        model (torch.nn.Module): The neural network model.
+        save (list): List of layer indices to save outputs from.
+        stride (torch.Tensor): Model stride values.
+
+    Methods:
+        forward: Perform forward pass for training or inference.
+        predict: Perform inference on input tensor.
+        fuse: Fuse Conv2d and BatchNorm2d layers for optimization.
+        info: Print model information.
+        load: Load weights into the model.
+        loss: Compute loss for training.
+
+    Examples:
+        Create a BaseModel instance
+        >>> model = BaseModel()
+        >>> model.info()  # Display model information
+    """

     def forward(self, x, *args, **kwargs):
-        """
-        Perform forward pass of the model for either training or inference.
+        """Perform forward pass of the model for either training or inference.

         If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

@@ -115,8 +137,7 @@ class BaseModel(torch.nn.Module):
         return self.predict(x, *args, **kwargs)

     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.

         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -133,8 +154,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x, profile, visualize, embed)

     def _predict_once(self, x, profile=False, visualize=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.

         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -172,8 +192,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x)

     def _profile_one_layer(self, m, x, dt):
-        """
-        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        """Profile the computation time and FLOPs of a single layer of the model on a given input.

         Args:
             m (torch.nn.Module): The layer to be profiled.
@@ -198,8 +217,7 @@ class BaseModel(torch.nn.Module):
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

     def fuse(self, verbose=True):
-        """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
         efficiency.

         Returns:
@@ -230,8 +248,7 @@ class BaseModel(torch.nn.Module):
         return self

     def is_fused(self, thresh=10):
-        """
-        Check if the model has less than a certain threshold of BatchNorm layers.
+        """Check if the model has less than a certain threshold of BatchNorm layers.

         Args:
             thresh (int, optional): The threshold number of BatchNorm layers.
@@ -243,8 +260,7 @@ class BaseModel(torch.nn.Module):
         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

     def info(self, detailed=False, verbose=True, imgsz=640):
-        """
-        Print model information.
+        """Print model information.

         Args:
             detailed (bool): If True, prints out detailed information about the model.
@@ -254,8 +270,7 @@ class BaseModel(torch.nn.Module):
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.

         Args:
             fn (function): The function to apply to the model.
@@ -274,8 +289,7 @@ class BaseModel(torch.nn.Module):
         return self

     def load(self, weights, verbose=True):
-        """
-        Load weights into the model.
+        """Load weights into the model.

         Args:
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -300,17 +314,17 @@ class BaseModel(torch.nn.Module):
             LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")

     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.

         Args:
             batch (dict): Batch to compute loss on.
-            preds (torch.Tensor |
+            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
         """
         if getattr(self, "criterion", None) is None:
             self.criterion = self.init_criterion()

-
+        if preds is None:
+            preds = self.forward(batch["img"])
         return self.criterion(preds, batch)

     def init_criterion(self):
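
The `loss()` hunk above is a behavior change, not just a docstring cleanup: when `preds` is omitted, the forward pass now runs inside `loss()` itself. A minimal sketch of the new call pattern (the hand-built batch below is hypothetical and assumes a local yolo11n.pt; real batches come from the Ultralytics dataloader):

    import torch
    from ultralytics import YOLO

    model = YOLO("yolo11n.pt").model  # underlying DetectionModel, a BaseModel subclass
    batch = {  # empty-target batch, for illustration only
        "img": torch.zeros(1, 3, 640, 640),
        "batch_idx": torch.zeros(0),
        "cls": torch.zeros(0, 1),
        "bboxes": torch.zeros(0, 4),
    }
    loss, loss_items = model.loss(batch)  # no preds argument: forward pass happens internally
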
@@ -319,11 +333,35 @@ class BaseModel(torch.nn.Module):


 class DetectionModel(BaseModel):
-    """YOLO detection model.
+    """YOLO detection model.
+
+    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+    inference, and loss computation for object detection tasks.
+
+    Attributes:
+        yaml (dict): Model configuration dictionary.
+        model (torch.nn.Sequential): The neural network model.
+        save (list): List of layer indices to save outputs from.
+        names (dict): Class names dictionary.
+        inplace (bool): Whether to use inplace operations.
+        end2end (bool): Whether the model uses end-to-end detection.
+        stride (torch.Tensor): Model stride values.
+
+    Methods:
+        __init__: Initialize the YOLO detection model.
+        _predict_augment: Perform augmented inference.
+        _descale_pred: De-scale predictions following augmented inference.
+        _clip_augmented: Clip YOLO augmented inference tails.
+        init_criterion: Initialize the loss criterion.
+
+    Examples:
+        Initialize a detection model
+        >>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the YOLO detection model with the given config and parameters.
+        """Initialize the YOLO detection model with the given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -362,8 +400,11 @@ class DetectionModel(BaseModel):
                     return self.forward(x)["one2many"]
                 return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)

+            self.model.eval()  # Avoid changing batch statistics until training begins
+            m.training = True  # Setting it to True to properly return strides
             m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
             self.stride = m.stride
+            self.model.train()  # Set model back to training(default) mode
             m.bias_init()  # only run once
         else:
             self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
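
The `eval()`/`train()` bracket added here keeps the dummy stride-probe forward pass from polluting BatchNorm running statistics before training starts, while `m.training = True` keeps the detection head emitting training-format outputs whose shapes encode the strides. A standalone PyTorch snippet showing the effect being avoided:

    import torch

    bn = torch.nn.BatchNorm2d(8)  # fresh layer: running_mean = 0, running_var = 1

    bn.train()
    before = bn.running_mean.clone()
    bn(torch.randn(1, 8, 32, 32))  # train-mode forward updates running statistics
    print(torch.equal(before, bn.running_mean))  # False: stats drifted from a dummy input

    bn.eval()
    before = bn.running_mean.clone()
    bn(torch.randn(1, 8, 32, 32))  # eval-mode forward leaves them untouched
    print(torch.equal(before, bn.running_mean))  # True
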
@@ -375,8 +416,7 @@ class DetectionModel(BaseModel):
         LOGGER.info("")

     def _predict_augment(self, x):
-        """
-        Perform augmentations on input image x and return augmented inference and train outputs.
+        """Perform augmentations on input image x and return augmented inference and train outputs.

         Args:
             x (torch.Tensor): Input image tensor.
@@ -401,8 +441,7 @@ class DetectionModel(BaseModel):

     @staticmethod
     def _descale_pred(p, flips, scale, img_size, dim=1):
-        """
-        De-scale predictions following augmented inference (inverse operation).
+        """De-scale predictions following augmented inference (inverse operation).

         Args:
             p (torch.Tensor): Predictions tensor.
@@ -423,14 +462,13 @@ class DetectionModel(BaseModel):
         return torch.cat((x, y, wh, cls), dim)

     def _clip_augmented(self, y):
-        """
-        Clip YOLO augmented inference tails.
+        """Clip YOLO augmented inference tails.

         Args:
-            y (
+            y (list[torch.Tensor]): List of detection tensors.

         Returns:
-            (
+            (list[torch.Tensor]): Clipped detection tensors.
         """
         nl = self.model[-1].nl  # number of detection layers (P3-P5)
         g = sum(4**x for x in range(nl))  # grid points
@@ -447,11 +485,23 @@ class DetectionModel(BaseModel):


 class OBBModel(DetectionModel):
-    """YOLO Oriented Bounding Box (OBB) model.
+    """YOLO Oriented Bounding Box (OBB) model.
+
+    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+    computation for rotated object detection.
+
+    Methods:
+        __init__: Initialize YOLO OBB model.
+        init_criterion: Initialize the loss criterion for OBB detection.
+
+    Examples:
+        Initialize an OBB model
+        >>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLO OBB model with given config and parameters.
+        """Initialize YOLO OBB model with given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -467,11 +517,23 @@ class OBBModel(DetectionModel):


 class SegmentationModel(DetectionModel):
-    """YOLO segmentation model.
+    """YOLO segmentation model.
+
+    This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+    pixel-level object detection and segmentation.
+
+    Methods:
+        __init__: Initialize YOLO segmentation model.
+        init_criterion: Initialize the loss criterion for segmentation.
+
+    Examples:
+        Initialize a segmentation model
+        >>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize Ultralytics YOLO segmentation model with given config and parameters.
+        """Initialize Ultralytics YOLO segmentation model with given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -487,11 +549,26 @@ class SegmentationModel(DetectionModel):


 class PoseModel(DetectionModel):
-    """YOLO pose model.
+    """YOLO pose model.
+
+    This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+    keypoint detection and pose estimation.
+
+    Attributes:
+        kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
+
+    Methods:
+        __init__: Initialize YOLO pose model.
+        init_criterion: Initialize the loss criterion for pose estimation.
+
+    Examples:
+        Initialize a pose model
+        >>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """
-        Initialize Ultralytics YOLO Pose model.
+        """Initialize Ultralytics YOLO Pose model.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -513,11 +590,31 @@ class PoseModel(DetectionModel):


 class ClassificationModel(BaseModel):
-    """YOLO classification model.
+    """YOLO classification model.
+
+    This class implements the YOLO classification architecture for image classification tasks, providing model
+    initialization, configuration, and output reshaping capabilities.
+
+    Attributes:
+        yaml (dict): Model configuration dictionary.
+        model (torch.nn.Sequential): The neural network model.
+        stride (torch.Tensor): Model stride values.
+        names (dict): Class names dictionary.
+
+    Methods:
+        __init__: Initialize ClassificationModel.
+        _from_yaml: Set model configurations and define architecture.
+        reshape_outputs: Update model to specified class count.
+        init_criterion: Initialize the loss criterion.
+
+    Examples:
+        Initialize a classification model
+        >>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+        """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -529,8 +626,7 @@ class ClassificationModel(BaseModel):
         self._from_yaml(cfg, ch, nc, verbose)

     def _from_yaml(self, cfg, ch, nc, verbose):
-        """
-        Set Ultralytics YOLO model configurations and define the model architecture.
+        """Set Ultralytics YOLO model configurations and define the model architecture.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -554,8 +650,7 @@ class ClassificationModel(BaseModel):

     @staticmethod
     def reshape_outputs(model, nc):
-        """
-        Update a TorchVision classification model to class count 'n' if required.
+        """Update a TorchVision classification model to class count 'n' if required.

         Args:
             model (torch.nn.Module): Model to update.
@@ -587,22 +682,30 @@ class ClassificationModel(BaseModel):


 class RTDETRDetectionModel(DetectionModel):
-    """
-    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+    """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

     This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
     the training and inference processes. RTDETR is an object detection and tracking model that extends from the
     DetectionModel base class.

+    Attributes:
+        nc (int): Number of classes for detection.
+        criterion (RTDETRDetectionLoss): Loss function for training.
+
     Methods:
-
-
-
+        __init__: Initialize the RTDETRDetectionModel.
+        init_criterion: Initialize the loss criterion.
+        loss: Compute loss for training.
+        predict: Perform forward pass through the model.
+
+    Examples:
+        Initialize an RTDETR model
+        >>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor)
     """

     def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the RTDETRDetectionModel.
+        """Initialize the RTDETRDetectionModel.

         Args:
             cfg (str | dict): Configuration file name or path.
@@ -612,6 +715,21 @@ class RTDETRDetectionModel(DetectionModel):
         """
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

+    def _apply(self, fn):
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
+
+        Args:
+            fn (function): The function to apply to the model.
+
+        Returns:
+            (RTDETRDetectionModel): An updated BaseModel object.
+        """
+        self = super()._apply(fn)
+        m = self.model[-1]
+        m.anchors = fn(m.anchors)
+        m.valid_mask = fn(m.valid_mask)
+        return self
+
     def init_criterion(self):
         """Initialize the loss criterion for the RTDETRDetectionModel."""
         from ultralytics.models.utils.loss import RTDETRDetectionLoss
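
The new `_apply` override exists because the RT-DETR decode head evidently keeps `anchors` and `valid_mask` as plain tensor attributes rather than registered buffers, so `model.to(device)` or `model.half()` would otherwise leave them on the original device/dtype. A self-contained sketch of the same pattern (the `Head` class below is illustrative, not the actual RT-DETR module):

    import torch

    class Head(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.anchors = torch.zeros(100, 4)  # plain attribute: .to()/.half() skip it by default
            self.valid_mask = torch.ones(100, dtype=torch.bool)

        def _apply(self, fn):
            self = super()._apply(fn)  # handles parameters and registered buffers
            self.anchors = fn(self.anchors)  # apply the same conversion to cached tensors
            self.valid_mask = fn(self.valid_mask)
            return self

    head = Head().half()
    print(head.anchors.dtype)     # torch.float16
    print(head.valid_mask.dtype)  # torch.bool (half() leaves non-floating tensors alone)
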
@@ -619,22 +737,22 @@ class RTDETRDetectionModel(DetectionModel):
         return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

     def loss(self, batch, preds=None):
-        """
-        Compute the loss for the given batch of data.
+        """Compute the loss for the given batch of data.

         Args:
             batch (dict): Dictionary containing image and label data.
             preds (torch.Tensor, optional): Precomputed model predictions.

         Returns:
-            (
+            loss_sum (torch.Tensor): Total loss value.
+            loss_items (torch.Tensor): Main three losses in a tensor.
         """
         if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()

         img = batch["img"]
         # NOTE: preprocess gt_bbox and gt_labels to list.
-        bs =
+        bs = img.shape[0]
         batch_idx = batch["batch_idx"]
         gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
         targets = {
@@ -644,7 +762,8 @@ class RTDETRDetectionModel(DetectionModel):
             "gt_groups": gt_groups,
         }

-
+        if preds is None:
+            preds = self.predict(img, batch=targets)
         dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
         if dn_meta is None:
             dn_bboxes, dn_scores = None, None
@@ -664,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
         )

     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.

         Args:
             x (torch.Tensor): The input tensor.
@@ -700,11 +818,31 @@ class RTDETRDetectionModel(DetectionModel):


 class WorldModel(DetectionModel):
-    """YOLOv8 World Model.
+    """YOLOv8 World Model.
+
+    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+    specification and CLIP model integration for zero-shot detection capabilities.
+
+    Attributes:
+        txt_feats (torch.Tensor): Text feature embeddings for classes.
+        clip_model (torch.nn.Module): CLIP model for text encoding.
+
+    Methods:
+        __init__: Initialize YOLOv8 world model.
+        set_classes: Set classes for offline inference.
+        get_text_pe: Get text positional embeddings.
+        predict: Perform forward pass with text features.
+        loss: Compute loss with text features.
+
+    Examples:
+        Initialize a world model
+        >>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
+        >>> model.set_classes(["person", "car", "bicycle"])
+        >>> results = model.predict(image_tensor)
+    """

     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOv8 world model with given config and parameters.
+        """Initialize YOLOv8 world model with given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -717,24 +855,21 @@ class WorldModel(DetectionModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

     def set_classes(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.

         Args:
-            text (
+            text (list[str]): List of class names.
             batch (int): Batch size for processing text tokens.
             cache_clip_model (bool): Whether to cache the CLIP model.
         """
         self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
         self.model[-1].nc = len(text)

-    @smart_inference_mode()
     def get_text_pe(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.

         Args:
-            text (
+            text (list[str]): List of class names.
             batch (int): Batch size for processing text tokens.
             cache_clip_model (bool): Whether to cache the CLIP model.

@@ -754,8 +889,7 @@ class WorldModel(DetectionModel):
         return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

     def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.

         Args:
             x (torch.Tensor): The input tensor.
@@ -769,7 +903,7 @@ class WorldModel(DetectionModel):
             (torch.Tensor): Model's output tensor.
         """
         txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
-        if
+        if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
             txt_feats = txt_feats.expand(x.shape[0], -1, -1)
         ori_txt_feats = txt_feats.clone()
         y, dt, embeddings = [], [], []  # outputs
@@ -799,12 +933,11 @@ class WorldModel(DetectionModel):
         return x

     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.

         Args:
             batch (dict): Batch to compute loss on.
-            preds (torch.Tensor |
+            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
         """
         if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()
@@ -815,11 +948,34 @@ class WorldModel(DetectionModel):


 class YOLOEModel(DetectionModel):
-    """YOLOE detection model.
+    """YOLOE detection model.
+
+    This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+    both prompt-based and prompt-free inference modes.
+
+    Attributes:
+        pe (torch.Tensor): Prompt embeddings for classes.
+        clip_model (torch.nn.Module): CLIP model for text encoding.
+
+    Methods:
+        __init__: Initialize YOLOE model.
+        get_text_pe: Get text positional embeddings.
+        get_visual_pe: Get visual embeddings.
+        set_vocab: Set vocabulary for prompt-free model.
+        get_vocab: Get fused vocabulary layer.
+        set_classes: Set classes for offline inference.
+        get_cls_pe: Get class positional embeddings.
+        predict: Perform forward pass with prompts.
+        loss: Compute loss with prompts.
+
+    Examples:
+        Initialize a YOLOE model
+        >>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor, tpe=text_embeddings)
+    """

     def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE model with given config and parameters.
+        """Initialize YOLOE model with given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -831,11 +987,10 @@ class YOLOEModel(DetectionModel):

     @smart_inference_mode()
     def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.

         Args:
-            text (
+            text (list[str]): List of class names.
             batch (int): Batch size for processing text tokens.
             cache_clip_model (bool): Whether to cache the CLIP model.
             without_reprta (bool): Whether to return text embeddings cooperated with reprta module.
@@ -858,15 +1013,13 @@ class YOLOEModel(DetectionModel):
         if without_reprta:
             return txt_feats

-        assert not self.training
         head = self.model[-1]
         assert isinstance(head, YOLOEDetect)
-        return head.get_tpe(txt_feats)  # run
+        return head.get_tpe(txt_feats)  # run auxiliary text head

     @smart_inference_mode()
     def get_visual_pe(self, img, visual):
-        """
-        Get visual embeddings.
+        """Get visual embeddings.

         Args:
             img (torch.Tensor): Input image tensor.
@@ -878,12 +1031,11 @@ class YOLOEModel(DetectionModel):
         return self(img, vpe=visual, return_vpe=True)

     def set_vocab(self, vocab, names):
-        """
-        Set vocabulary for the prompt-free model.
+        """Set vocabulary for the prompt-free model.

         Args:
             vocab (nn.ModuleList): List of vocabulary items.
-            names (
+            names (list[str]): List of class names.
         """
         assert not self.training
         head = self.model[-1]
@@ -907,8 +1059,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)

     def get_vocab(self, names):
-        """
-        Get fused vocabulary layer from the model.
+        """Get fused vocabulary layer from the model.

         Args:
             names (list): List of class names.
@@ -933,11 +1084,10 @@ class YOLOEModel(DetectionModel):
         return vocab

     def set_classes(self, names, embeddings):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.

         Args:
-            names (
+            names (list[str]): List of class names.
             embeddings (torch.Tensor): Embeddings tensor.
         """
         assert not hasattr(self.model[-1], "lrpc"), (
@@ -949,8 +1099,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)

     def get_cls_pe(self, tpe, vpe):
-        """
-        Get class positional embeddings.
+        """Get class positional embeddings.

         Args:
             tpe (torch.Tensor, optional): Text positional embeddings.
@@ -973,8 +1122,7 @@ class YOLOEModel(DetectionModel):
     def predict(
         self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
     ):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.

         Args:
             x (torch.Tensor): The input tensor.
@@ -1021,12 +1169,11 @@ class YOLOEModel(DetectionModel):
         return x

     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.

         Args:
             batch (dict): Batch to compute loss on.
-            preds (torch.Tensor |
+            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
         """
         if not hasattr(self, "criterion"):
             from ultralytics.utils.loss import TVPDetectLoss
@@ -1040,11 +1187,23 @@ class YOLOEModel(DetectionModel):


 class YOLOESegModel(YOLOEModel, SegmentationModel):
-    """YOLOE segmentation model.
+    """YOLOE segmentation model.
+
+    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+    specialized loss computation for pixel-level object detection and segmentation.
+
+    Methods:
+        __init__: Initialize YOLOE segmentation model.
+        loss: Compute loss with prompts for segmentation.
+
+    Examples:
+        Initialize a YOLOE segmentation model
+        >>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
+        >>> results = model.predict(image_tensor, tpe=text_embeddings)
+    """

     def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE segmentation model with given config and parameters.
+        """Initialize YOLOE segmentation model with given config and parameters.

         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1055,12 +1214,11 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.

         Args:
             batch (dict): Batch to compute loss on.
-            preds (torch.Tensor |
+            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
         """
         if not hasattr(self, "criterion"):
             from ultralytics.utils.loss import TVPSegmentLoss
@@ -1074,15 +1232,29 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):


 class Ensemble(torch.nn.ModuleList):
-    """Ensemble of models.
+    """Ensemble of models.
+
+    This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+    or other ensemble techniques.
+
+    Methods:
+        __init__: Initialize an ensemble of models.
+        forward: Generate predictions from all models in the ensemble.
+
+    Examples:
+        Create an ensemble of models
+        >>> ensemble = Ensemble()
+        >>> ensemble.append(model1)
+        >>> ensemble.append(model2)
+        >>> results = ensemble(image_tensor)
+    """

     def __init__(self):
         """Initialize an ensemble of models."""
         super().__init__()

     def forward(self, x, augment=False, profile=False, visualize=False):
-        """
-        Generate the YOLO network's final layer.
+        """Generate the YOLO network's final layer.

         Args:
             x (torch.Tensor): Input tensor.
@@ -1091,7 +1263,8 @@ class Ensemble(torch.nn.ModuleList):
             visualize (bool): Whether to visualize the features.

         Returns:
-            (
+            y (torch.Tensor): Concatenated predictions from all models.
+            train_out (None): Always None for ensemble inference.
         """
         y = [module(x, augment, profile, visualize)[0] for module in self]
         # y = torch.stack(y).max(0)[0]  # max ensemble
@@ -1105,12 +1278,11 @@ class Ensemble(torch.nn.ModuleList):

 @contextlib.contextmanager
 def temporary_modules(modules=None, attributes=None):
-    """
-    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

-    This function can be used to change the module paths during runtime. It's useful when refactoring code,
-
-
+    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+    moved a module from one location to another, but you still want to support the old import paths for backwards
+    compatibility.

     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1121,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
         >>> import old.module  # this will now import new.module
         >>> from old.module import attribute  # this will now import new.module.attribute

-
+    Notes:
         The changes are only in effect inside the context manager and are undone once the context manager exits.
         Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
         applications or libraries. Use this function with caution.
@@ -1168,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""

     def find_class(self, module, name):
-        """
-        Attempt to find a class, returning SafeClass if not among safe modules.
+        """Attempt to find a class, returning SafeClass if not among safe modules.

         Args:
             module (str): Module name.
@@ -1194,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):


 def torch_safe_load(weight, safe_only=False):
-    """
-
-
-    After installation, the function again attempts to load the model using torch.load().
+    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+    function. After installation, the function again attempts to load the model using torch.load().

     Args:
         weight (str): The file path of the PyTorch model.
@@ -1234,9 +1404,9 @@ def torch_safe_load(weight, safe_only=False):
             safe_pickle.Unpickler = SafeUnpickler
             safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
             with open(file, "rb") as f:
-                ckpt =
+                ckpt = torch_load(f, pickle_module=safe_pickle)
         else:
-            ckpt =
+            ckpt = torch_load(file, map_location="cpu")

     except ModuleNotFoundError as e:  # e.name is missing module name
         if e.name == "models":
@@ -1249,6 +1419,12 @@ def torch_safe_load(weight, safe_only=False):
                     f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                 )
             ) from e
+        elif e.name == "numpy._core":
+            raise ModuleNotFoundError(
+                emojis(
+                    f"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed."
+                )
+            ) from e
         LOGGER.warning(
             f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
             f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
@@ -1256,7 +1432,7 @@ def torch_safe_load(weight, safe_only=False):
             f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
         )
         check_requirements(e.name)  # install missing module
-        ckpt =
+        ckpt = torch_load(file, map_location="cpu")

     if not isinstance(ckpt, dict):
         # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
@@ -1269,80 +1445,31 @@ def torch_safe_load(weight, safe_only=False):
     return ckpt, file


-def
-    """
-    Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
-
-    Args:
-        weights (str | List[str]): Model weights path(s).
-        device (torch.device, optional): Device to load model to.
-        inplace (bool): Whether to do inplace operations.
-        fuse (bool): Whether to fuse model.
-
-    Returns:
-        (torch.nn.Module): Loaded model.
-    """
-    ensemble = Ensemble()
-    for w in weights if isinstance(weights, list) else [weights]:
-        ckpt, w = torch_safe_load(w)  # load ckpt
-        args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None  # combined args
-        model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model
-
-        # Model compatibility updates
-        model.args = args  # attach args to model
-        model.pt_path = w  # attach *.pt file path to model
-        model.task = guess_model_task(model)
-        if not hasattr(model, "stride"):
-            model.stride = torch.tensor([32.0])
-
-        # Append
-        ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval())  # model in eval mode
-
-    # Module updates
-    for m in ensemble.modules():
-        if hasattr(m, "inplace"):
-            m.inplace = inplace
-        elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
-            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
-
-    # Return model
-    if len(ensemble) == 1:
-        return ensemble[-1]
-
-    # Return ensemble
-    LOGGER.info(f"Ensemble created with {weights}\n")
-    for k in "names", "nc", "yaml":
-        setattr(ensemble, k, getattr(ensemble[0], k))
-    ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
-    assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
-    return ensemble
-
-
-def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
-    """
-    Load a single model weights.
+def load_checkpoint(weight, device=None, inplace=True, fuse=False):
+    """Load a single model weights.

     Args:
-        weight (str): Model weight path.
+        weight (str | Path): Model weight path.
         device (torch.device, optional): Device to load model to.
         inplace (bool): Whether to do inplace operations.
         fuse (bool): Whether to fuse model.

     Returns:
-        (
+        model (torch.nn.Module): Loaded model.
+        ckpt (dict): Model checkpoint dictionary.
     """
     ckpt, weight = torch_safe_load(weight)  # load ckpt
     args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args
-    model = (ckpt.get("ema") or ckpt["model"]).
+    model = (ckpt.get("ema") or ckpt["model"]).float()  # FP32 model

     # Model compatibility updates
-    model.args =
+    model.args = args  # attach args to model
     model.pt_path = weight  # attach *.pt file path to model
-    model.task = guess_model_task(model)
+    model.task = getattr(model, "task", guess_model_task(model))
     if not hasattr(model, "stride"):
         model.stride = torch.tensor([32.0])

-    model = model.fuse()
+    model = (model.fuse() if fuse and hasattr(model, "fuse") else model).eval().to(device)  # model in eval mode

     # Module updates
     for m in model.modules():
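
Note that `attempt_load_weights` is deleted outright and `attempt_load_one_weight` is renamed to `load_checkpoint` in this release, so downstream imports of the old names break at this version. One hedged way to stay compatible on both sides of the rename (assuming a local yolo11n.pt):

    try:
        from ultralytics.nn.tasks import load_checkpoint  # 8.3.224 name
    except ImportError:
        from ultralytics.nn.tasks import attempt_load_one_weight as load_checkpoint  # 8.3.137 name

    model, ckpt = load_checkpoint("yolo11n.pt", device="cpu")  # both variants return (model, ckpt)
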
@@ -1355,9 +1482,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
     return model, ckpt


-def parse_model(d, ch, verbose=True):
-    """
-    Parse a YOLO model.yaml dictionary into a PyTorch model.
+def parse_model(d, ch, verbose=True):
+    """Parse a YOLO model.yaml dictionary into a PyTorch model.

     Args:
         d (dict): Model dictionary.
@@ -1365,7 +1491,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
         verbose (bool): Whether to print model details.

     Returns:
-        (
+        model (torch.nn.Sequential): PyTorch model.
+        save (list): Sorted list of output layers.
     """
     import ast

@@ -1374,10 +1501,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     max_channels = float("inf")
     nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
     depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
+    scale = d.get("scale")
     if scales:
-        scale = d.get("scale")
         if not scale:
-            scale =
+            scale = next(iter(scales.keys()))
             LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
         depth, width, max_channels = scales[scale]

@@ -1524,7 +1651,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
         m_.np = sum(x.numel() for x in m_.parameters())  # number params
         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
         if verbose:
-            LOGGER.info(f"{i:>3}{
+            LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f} {t:<45}{args!s:<30}")  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
@@ -1534,8 +1661,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)


 def yaml_model_load(path):
-    """
-    Load a YOLOv8 model from a YAML file.
+    """Load a YOLOv8 model from a YAML file.

     Args:
         path (str | Path): Path to the YAML file.
@@ -1558,8 +1684,7 @@ def yaml_model_load(path):


 def guess_model_scale(model_path):
-    """
-    Extract the size character n, s, m, l, or x of the model's scale from the model path.
+    """Extract the size character n, s, m, l, or x of the model's scale from the model path.

     Args:
         model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1568,14 +1693,13 @@ def guess_model_scale(model_path):
         (str): The size character of the model's scale (n, s, m, l, or x).
     """
     try:
-        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
+        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
     except AttributeError:
         return ""


 def guess_model_task(model):
-    """
-    Guess the task of a PyTorch model from its architecture or configuration.
+    """Guess the task of a PyTorch model from its architecture or configuration.

     Args:
         model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.