dgenerate-ultralytics-headless 8.3.196-py3-none-any.whl → 8.3.248-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
```diff
@@ -95,11 +95,10 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseModel(torch.nn.Module):
-    """
-    Base class for all YOLO models in the Ultralytics family.
+    """Base class for all YOLO models in the Ultralytics family.
 
-    This class provides common functionality for YOLO models including forward pass handling, model fusion,
-    information display, and weight loading capabilities.
+    This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+    display, and weight loading capabilities.
 
     Attributes:
         model (torch.nn.Module): The neural network model.
@@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
     """
 
     def forward(self, x, *args, **kwargs):
-        """
-        Perform forward pass of the model for either training or inference.
+        """Perform forward pass of the model for either training or inference.
 
         If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
@@ -139,8 +137,7 @@ class BaseModel(torch.nn.Module):
         return self.predict(x, *args, **kwargs)
 
     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -157,8 +154,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x, profile, visualize, embed)
 
     def _predict_once(self, x, profile=False, visualize=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -196,8 +192,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x)
 
     def _profile_one_layer(self, m, x, dt):
-        """
-        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        """Profile the computation time and FLOPs of a single layer of the model on a given input.
 
         Args:
             m (torch.nn.Module): The layer to be profiled.
@@ -222,8 +217,7 @@ class BaseModel(torch.nn.Module):
         LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
 
     def fuse(self, verbose=True):
-        """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
         efficiency.
 
         Returns:
@@ -254,8 +248,7 @@ class BaseModel(torch.nn.Module):
         return self
 
     def is_fused(self, thresh=10):
-        """
-        Check if the model has less than a certain threshold of BatchNorm layers.
+        """Check if the model has less than a certain threshold of BatchNorm layers.
 
         Args:
             thresh (int, optional): The threshold number of BatchNorm layers.
@@ -267,8 +260,7 @@ class BaseModel(torch.nn.Module):
         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
 
     def info(self, detailed=False, verbose=True, imgsz=640):
-        """
-        Print model information.
+        """Print model information.
 
         Args:
             detailed (bool): If True, prints out detailed information about the model.
@@ -278,8 +270,7 @@ class BaseModel(torch.nn.Module):
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
             fn (function): The function to apply to the model.
@@ -298,8 +289,7 @@ class BaseModel(torch.nn.Module):
         return self
 
     def load(self, weights, verbose=True):
-        """
-        Load weights into the model.
+        """Load weights into the model.
 
         Args:
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -324,8 +314,7 @@ class BaseModel(torch.nn.Module):
         LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -344,11 +333,10 @@ class BaseModel(torch.nn.Module):
 
 
 class DetectionModel(BaseModel):
-    """
-    YOLO detection model.
+    """YOLO detection model.
 
-    This class implements the YOLO detection architecture, handling model initialization, forward pass,
-    augmented inference, and loss computation for object detection tasks.
+    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+    inference, and loss computation for object detection tasks.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
     """
 
     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the YOLO detection model with the given config and parameters.
+        """Initialize the YOLO detection model with the given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -420,7 +407,7 @@ class DetectionModel(BaseModel):
             self.model.train()  # Set model back to training(default) mode
             m.bias_init()  # only run once
         else:
-            self.stride = torch.Tensor([32])  # default stride
+            self.stride = torch.Tensor([32])  # default stride, e.g., RTDETR
 
         # Init weights, biases
         initialize_weights(self)
@@ -429,8 +416,7 @@ class DetectionModel(BaseModel):
             LOGGER.info("")
 
     def _predict_augment(self, x):
-        """
-        Perform augmentations on input image x and return augmented inference and train outputs.
+        """Perform augmentations on input image x and return augmented inference and train outputs.
 
         Args:
             x (torch.Tensor): Input image tensor.
@@ -455,8 +441,7 @@ class DetectionModel(BaseModel):
 
     @staticmethod
     def _descale_pred(p, flips, scale, img_size, dim=1):
-        """
-        De-scale predictions following augmented inference (inverse operation).
+        """De-scale predictions following augmented inference (inverse operation).
 
         Args:
             p (torch.Tensor): Predictions tensor.
@@ -477,8 +462,7 @@ class DetectionModel(BaseModel):
         return torch.cat((x, y, wh, cls), dim)
 
     def _clip_augmented(self, y):
-        """
-        Clip YOLO augmented inference tails.
+        """Clip YOLO augmented inference tails.
 
         Args:
             y (list[torch.Tensor]): List of detection tensors.
@@ -501,11 +485,10 @@ class DetectionModel(BaseModel):
 
 
 class OBBModel(DetectionModel):
-    """
-    YOLO Oriented Bounding Box (OBB) model.
+    """YOLO Oriented Bounding Box (OBB) model.
 
-    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized
-    loss computation for rotated object detection.
+    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+    computation for rotated object detection.
 
     Methods:
         __init__: Initialize YOLO OBB model.
@@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLO OBB model with given config and parameters.
+        """Initialize YOLO OBB model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -535,11 +517,10 @@ class OBBModel(DetectionModel):
 
 
 class SegmentationModel(DetectionModel):
-    """
-    YOLO segmentation model.
+    """YOLO segmentation model.
 
-    This class extends DetectionModel to handle instance segmentation tasks, providing specialized
-    loss computation for pixel-level object detection and segmentation.
+    This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+    pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLO segmentation model.
@@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize Ultralytics YOLO segmentation model with given config and parameters.
+        """Initialize Ultralytics YOLO segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -569,11 +549,10 @@ class SegmentationModel(DetectionModel):
 
 
 class PoseModel(DetectionModel):
-    """
-    YOLO pose model.
+    """YOLO pose model.
 
-    This class extends DetectionModel to handle human pose estimation tasks, providing specialized
-    loss computation for keypoint detection and pose estimation.
+    This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+    keypoint detection and pose estimation.
 
     Attributes:
         kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
@@ -589,8 +568,7 @@ class PoseModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """
-        Initialize Ultralytics YOLO Pose model.
+        """Initialize Ultralytics YOLO Pose model.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -612,11 +590,10 @@ class PoseModel(DetectionModel):
 
 
 class ClassificationModel(BaseModel):
-    """
-    YOLO classification model.
+    """YOLO classification model.
 
-    This class implements the YOLO classification architecture for image classification tasks,
-    providing model initialization, configuration, and output reshaping capabilities.
+    This class implements the YOLO classification architecture for image classification tasks, providing model
+    initialization, configuration, and output reshaping capabilities.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
     """
 
    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+        """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
         self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_yaml(self, cfg, ch, nc, verbose):
-        """
-        Set Ultralytics YOLO model configurations and define the model architecture.
+        """Set Ultralytics YOLO model configurations and define the model architecture.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -675,8 +650,7 @@ class ClassificationModel(BaseModel):
 
     @staticmethod
     def reshape_outputs(model, nc):
-        """
-        Update a TorchVision classification model to class count 'n' if required.
+        """Update a TorchVision classification model to class count 'n' if required.
 
         Args:
             model (torch.nn.Module): Model to update.
@@ -708,8 +682,7 @@ class ClassificationModel(BaseModel):
 
 
 class RTDETRDetectionModel(DetectionModel):
-    """
-    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+    """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
 
     This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
     the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
     """
 
     def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the RTDETRDetectionModel.
+        """Initialize the RTDETRDetectionModel.
 
         Args:
             cfg (str | dict): Configuration file name or path.
@@ -743,6 +715,21 @@ class RTDETRDetectionModel(DetectionModel):
         """
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
+    def _apply(self, fn):
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
+
+        Args:
+            fn (function): The function to apply to the model.
+
+        Returns:
+            (RTDETRDetectionModel): An updated BaseModel object.
+        """
+        self = super()._apply(fn)
+        m = self.model[-1]
+        m.anchors = fn(m.anchors)
+        m.valid_mask = fn(m.valid_mask)
+        return self
+
     def init_criterion(self):
         """Initialize the loss criterion for the RTDETRDetectionModel."""
         from ultralytics.models.utils.loss import RTDETRDetectionLoss
@@ -750,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
         return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
 
     def loss(self, batch, preds=None):
-        """
-        Compute the loss for the given batch of data.
+        """Compute the loss for the given batch of data.
 
         Args:
             batch (dict): Dictionary containing image and label data.
@@ -766,7 +752,7 @@ class RTDETRDetectionModel(DetectionModel):
 
         img = batch["img"]
         # NOTE: preprocess gt_bbox and gt_labels to list.
-        bs = len(img)
+        bs = img.shape[0]
         batch_idx = batch["batch_idx"]
         gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
         targets = {
@@ -797,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
         )
 
     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -833,11 +818,10 @@ class RTDETRDetectionModel(DetectionModel):
 
 
 class WorldModel(DetectionModel):
-    """
-    YOLOv8 World Model.
+    """YOLOv8 World Model.
 
-    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based
-    class specification and CLIP model integration for zero-shot detection capabilities.
+    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+    specification and CLIP model integration for zero-shot detection capabilities.
 
     Attributes:
         txt_feats (torch.Tensor): Text feature embeddings for classes.
@@ -858,8 +842,7 @@ class WorldModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOv8 world model with given config and parameters.
+        """Initialize YOLOv8 world model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -872,8 +855,7 @@ class WorldModel(DetectionModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             text (list[str]): List of class names.
@@ -884,8 +866,7 @@ class WorldModel(DetectionModel):
         self.model[-1].nc = len(text)
 
     def get_text_pe(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Get text positional embeddings for offline inference without CLIP model.
 
         Args:
             text (list[str]): List of class names.
@@ -908,8 +889,7 @@ class WorldModel(DetectionModel):
         return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
 
     def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -923,7 +903,7 @@ class WorldModel(DetectionModel):
             (torch.Tensor): Model's output tensor.
         """
         txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
-        if len(txt_feats) != len(x) or self.model[-1].export:
+        if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
             txt_feats = txt_feats.expand(x.shape[0], -1, -1)
         ori_txt_feats = txt_feats.clone()
         y, dt, embeddings = [], [], []  # outputs
@@ -953,8 +933,7 @@ class WorldModel(DetectionModel):
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -969,11 +948,10 @@ class WorldModel(DetectionModel):
 
 
 class YOLOEModel(DetectionModel):
-    """
-    YOLOE detection model.
+    """YOLOE detection model.
 
-    This class implements the YOLOE architecture for efficient object detection with text and visual prompts,
-    supporting both prompt-based and prompt-free inference modes.
+    This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+    both prompt-based and prompt-free inference modes.
 
     Attributes:
         pe (torch.Tensor): Prompt embeddings for classes.
@@ -997,8 +975,7 @@ class YOLOEModel(DetectionModel):
     """
 
    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE model with given config and parameters.
+        """Initialize YOLOE model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1010,14 +987,13 @@ class YOLOEModel(DetectionModel):
 
     @smart_inference_mode()
     def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Get text positional embeddings for offline inference without CLIP model.
 
         Args:
             text (list[str]): List of class names.
             batch (int): Batch size for processing text tokens.
             cache_clip_model (bool): Whether to cache the CLIP model.
-            without_reprta (bool): Whether to return text embeddings
+            without_reprta (bool): Whether to return text embeddings without reprta module processing.
 
         Returns:
             (torch.Tensor): Text positional embeddings.
@@ -1037,15 +1013,13 @@ class YOLOEModel(DetectionModel):
         if without_reprta:
             return txt_feats
 
-        assert not self.training
         head = self.model[-1]
         assert isinstance(head, YOLOEDetect)
         return head.get_tpe(txt_feats)  # run auxiliary text head
 
     @smart_inference_mode()
     def get_visual_pe(self, img, visual):
-        """
-        Get visual embeddings.
+        """Get visual embeddings.
 
         Args:
             img (torch.Tensor): Input image tensor.
@@ -1057,8 +1031,7 @@ class YOLOEModel(DetectionModel):
         return self(img, vpe=visual, return_vpe=True)
 
     def set_vocab(self, vocab, names):
-        """
-        Set vocabulary for the prompt-free model.
+        """Set vocabulary for the prompt-free model.
 
         Args:
             vocab (nn.ModuleList): List of vocabulary items.
@@ -1086,8 +1059,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)
 
     def get_vocab(self, names):
-        """
-        Get fused vocabulary layer from the model.
+        """Get fused vocabulary layer from the model.
 
         Args:
             names (list): List of class names.
@@ -1112,8 +1084,7 @@ class YOLOEModel(DetectionModel):
         return vocab
 
     def set_classes(self, names, embeddings):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             names (list[str]): List of class names.
@@ -1128,8 +1099,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)
 
     def get_cls_pe(self, tpe, vpe):
-        """
-        Get class positional embeddings.
+        """Get class positional embeddings.
 
         Args:
             tpe (torch.Tensor, optional): Text positional embeddings.
@@ -1152,8 +1122,7 @@ class YOLOEModel(DetectionModel):
     def predict(
         self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
     ):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -1200,8 +1169,7 @@ class YOLOEModel(DetectionModel):
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1219,11 +1187,10 @@ class YOLOEModel(DetectionModel):
 
 
 class YOLOESegModel(YOLOEModel, SegmentationModel):
-    """
-    YOLOE segmentation model.
+    """YOLOE segmentation model.
 
-    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,
-    providing specialized loss computation for pixel-level object detection and segmentation.
+    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+    specialized loss computation for pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLOE segmentation model.
@@ -1236,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
     """
 
    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE segmentation model with given config and parameters.
+        """Initialize YOLOE segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1248,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1267,11 +1232,10 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 
 
 class Ensemble(torch.nn.ModuleList):
-    """
-    Ensemble of models.
+    """Ensemble of models.
 
-    This class allows combining multiple YOLO models into an ensemble for improved performance through
-    model averaging or other ensemble techniques.
+    This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+    or other ensemble techniques.
 
     Methods:
         __init__: Initialize an ensemble of models.
@@ -1290,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
         super().__init__()
 
     def forward(self, x, augment=False, profile=False, visualize=False):
-        """
-        Generate the YOLO network's final layer.
+        """Generate the YOLO network's final layer.
 
         Args:
             x (torch.Tensor): Input tensor.
@@ -1315,12 +1278,11 @@ class Ensemble(torch.nn.ModuleList):
 
 @contextlib.contextmanager
 def temporary_modules(modules=None, attributes=None):
-    """
-    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
-    This function can be used to change the module paths during runtime. It's useful when refactoring code,
-    where you've moved a module from one location to another, but you still want to support the old import
-    paths for backwards compatibility.
+    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+    moved a module from one location to another, but you still want to support the old import paths for backwards
+    compatibility.
 
     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1331,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
         >>> import old.module  # this will now import new.module
         >>> from old.module import attribute  # this will now import new.module.attribute
 
-    Note:
+    Notes:
         The changes are only in effect inside the context manager and are undone once the context manager exits.
         Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
         applications or libraries. Use this function with caution.
@@ -1378,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""
 
     def find_class(self, module, name):
-        """
-        Attempt to find a class, returning SafeClass if not among safe modules.
+        """Attempt to find a class, returning SafeClass if not among safe modules.
 
         Args:
             module (str): Module name.
@@ -1404,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
 
 
 def torch_safe_load(weight, safe_only=False):
-    """
-    Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
-    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
-    After installation, the function again attempts to load the model using torch.load().
+    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+    function. After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
@@ -1486,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
 
 
 def load_checkpoint(weight, device=None, inplace=True, fuse=False):
-    """
-    Load a single model weights.
+    """Load a single model weights.
 
     Args:
         weight (str | Path): Model weight path.
@@ -1524,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):
-    """
-    Parse a YOLO model.yaml dictionary into a PyTorch model.
+    """Parse a YOLO model.yaml dictionary into a PyTorch model.
 
     Args:
         d (dict): Model dictionary.
@@ -1543,10 +1501,10 @@ def parse_model(d, ch, verbose=True):
     max_channels = float("inf")
     nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
     depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
+    scale = d.get("scale")
     if scales:
-        scale = d.get("scale")
         if not scale:
-            scale = tuple(scales.keys())[0]
+            scale = next(iter(scales.keys()))
             LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
         depth, width, max_channels = scales[scale]
 
@@ -1631,7 +1589,7 @@ def parse_model(d, ch, verbose=True):
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
         if m in base_modules:
             c1, c2 = ch[f], args[0]
-            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
+            if c2 != nc:  # if c2 != nc (e.g., Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
             if m is C2fAttn:  # set 1) embed channels and 2) num heads
                 args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
@@ -1693,7 +1651,7 @@ def parse_model(d, ch, verbose=True):
         m_.np = sum(x.numel() for x in m_.parameters())  # number params
         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
         if verbose:
-            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
+            LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f}  {t:<45}{args!s:<30}")  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
@@ -1703,8 +1661,7 @@ def parse_model(d, ch, verbose=True):
 
 
 def yaml_model_load(path):
-    """
-    Load a YOLOv8 model from a YAML file.
+    """Load a YOLOv8 model from a YAML file.
 
     Args:
         path (str | Path): Path to the YAML file.
@@ -1727,8 +1684,7 @@ def yaml_model_load(path):
 
 
 def guess_model_scale(model_path):
-    """
-    Extract the size character n, s, m, l, or x of the model's scale from the model path.
+    """Extract the size character n, s, m, l, or x of the model's scale from the model path.
 
     Args:
         model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1737,14 +1693,13 @@ def guess_model_scale(model_path):
         (str): The size character of the model's scale (n, s, m, l, or x).
     """
     try:
-        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
+        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
     except AttributeError:
         return ""
 
 
 def guess_model_task(model):
-    """
-    Guess the task of a PyTorch model from its architecture or configuration.
+    """Guess the task of a PyTorch model from its architecture or configuration.
 
     Args:
         model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
@@ -1775,10 +1730,10 @@ def guess_model_task(model):
     if isinstance(model, torch.nn.Module):  # PyTorch model
         for x in "model.args", "model.model.args", "model.model.model.args":
             with contextlib.suppress(Exception):
-                return eval(x)["task"]
+                return eval(x)["task"]  # nosec B307: safe eval of known attribute paths
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
-                return cfg2task(eval(x))
+                return cfg2task(eval(x))  # nosec B307: safe eval of known attribute paths
     for m in model.modules():
         if isinstance(m, (Segment, YOLOESegment)):
             return "segment"
```