dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
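In the summary above, ultralytics/utils/export/__init__.py loses 236 lines while new engine.py and tensorflow.py modules of roughly the same size appear, which suggests the export helpers were split into per-backend submodules with the package __init__ reduced to a thin re-export shim. A minimal sketch of that refactoring pattern, assuming hypothetical symbol names (the actual contents of the new files are not shown in this diff):

# ultralytics/utils/export/__init__.py -- hypothetical re-export shim.
# Moving the implementations into engine.py and tensorflow.py while
# re-exporting them here keeps existing imports such as
# `from ultralytics.utils.export import export_engine` working unchanged.
from .engine import export_engine  # assumed name: TensorRT engine export helper
from .tensorflow import export_saved_model  # assumed name: TF SavedModel export helper

__all__ = ("export_engine", "export_saved_model")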
ultralytics/nn/tasks.py
CHANGED
@@ -95,11 +95,10 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseModel(torch.nn.Module):
-    """
-    Base class for all YOLO models in the Ultralytics family.
+    """Base class for all YOLO models in the Ultralytics family.
 
-    This class provides common functionality for YOLO models including forward pass handling, model fusion,
-    information display, and weight loading capabilities.
+    This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+    display, and weight loading capabilities.
 
     Attributes:
         model (torch.nn.Module): The neural network model.
@@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
     """
 
     def forward(self, x, *args, **kwargs):
-        """
-        Perform forward pass of the model for either training or inference.
+        """Perform forward pass of the model for either training or inference.
 
         If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
@@ -139,8 +137,7 @@
         return self.predict(x, *args, **kwargs)
 
     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -157,8 +154,7 @@
         return self._predict_once(x, profile, visualize, embed)
 
     def _predict_once(self, x, profile=False, visualize=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -196,8 +192,7 @@
         return self._predict_once(x)
 
     def _profile_one_layer(self, m, x, dt):
-        """
-        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        """Profile the computation time and FLOPs of a single layer of the model on a given input.
 
         Args:
             m (torch.nn.Module): The layer to be profiled.
@@ -222,8 +217,7 @@
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
 
     def fuse(self, verbose=True):
-        """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
         efficiency.
 
         Returns:
@@ -254,8 +248,7 @@
         return self
 
     def is_fused(self, thresh=10):
-        """
-        Check if the model has less than a certain threshold of BatchNorm layers.
+        """Check if the model has less than a certain threshold of BatchNorm layers.
 
         Args:
             thresh (int, optional): The threshold number of BatchNorm layers.
@@ -267,8 +260,7 @@
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
 
     def info(self, detailed=False, verbose=True, imgsz=640):
-        """
-        Print model information.
+        """Print model information.
 
         Args:
             detailed (bool): If True, prints out detailed information about the model.
@@ -278,8 +270,7 @@
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
             fn (function): The function to apply to the model.
@@ -298,8 +289,7 @@
         return self
 
     def load(self, weights, verbose=True):
-        """
-        Load weights into the model.
+        """Load weights into the model.
 
         Args:
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -324,8 +314,7 @@
             LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -344,11 +333,10 @@
 
 
 class DetectionModel(BaseModel):
-    """
-    YOLO detection model.
+    """YOLO detection model.
 
-    This class implements the YOLO detection architecture, handling model initialization, forward pass,
-    augmented inference, and loss computation for object detection tasks.
+    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+    inference, and loss computation for object detection tasks.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
     """
 
     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the YOLO detection model with the given config and parameters.
+        """Initialize the YOLO detection model with the given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -420,7 +407,7 @@ class DetectionModel(BaseModel):
             self.model.train()  # Set model back to training(default) mode
             m.bias_init()  # only run once
         else:
-            self.stride = torch.Tensor([32])  # default stride
+            self.stride = torch.Tensor([32])  # default stride, e.g., RTDETR
 
         # Init weights, biases
         initialize_weights(self)
@@ -429,8 +416,7 @@
         LOGGER.info("")
 
     def _predict_augment(self, x):
-        """
-        Perform augmentations on input image x and return augmented inference and train outputs.
+        """Perform augmentations on input image x and return augmented inference and train outputs.
 
         Args:
             x (torch.Tensor): Input image tensor.
@@ -455,8 +441,7 @@
 
     @staticmethod
     def _descale_pred(p, flips, scale, img_size, dim=1):
-        """
-        De-scale predictions following augmented inference (inverse operation).
+        """De-scale predictions following augmented inference (inverse operation).
 
         Args:
             p (torch.Tensor): Predictions tensor.
@@ -477,8 +462,7 @@
         return torch.cat((x, y, wh, cls), dim)
 
     def _clip_augmented(self, y):
-        """
-        Clip YOLO augmented inference tails.
+        """Clip YOLO augmented inference tails.
 
         Args:
             y (list[torch.Tensor]): List of detection tensors.
@@ -501,11 +485,10 @@
 
 
 class OBBModel(DetectionModel):
-    """
-    YOLO Oriented Bounding Box (OBB) model.
+    """YOLO Oriented Bounding Box (OBB) model.
 
-    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized
-    loss computation for rotated object detection.
+    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+    computation for rotated object detection.
 
     Methods:
         __init__: Initialize YOLO OBB model.
@@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLO OBB model with given config and parameters.
+        """Initialize YOLO OBB model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -535,11 +517,10 @@
 
 
 class SegmentationModel(DetectionModel):
-    """
-    YOLO segmentation model.
+    """YOLO segmentation model.
 
-    This class extends DetectionModel to handle instance segmentation tasks, providing specialized
-    loss computation for pixel-level object detection and segmentation.
+    This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+    pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLO segmentation model.
@@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize Ultralytics YOLO segmentation model with given config and parameters.
+        """Initialize Ultralytics YOLO segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -569,11 +549,10 @@
 
 
 class PoseModel(DetectionModel):
-    """
-    YOLO pose model.
+    """YOLO pose model.
 
-    This class extends DetectionModel to handle human pose estimation tasks, providing specialized
-    loss computation for keypoint detection and pose estimation.
+    This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+    keypoint detection and pose estimation.
 
     Attributes:
         kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
@@ -589,8 +568,7 @@
     """
 
     def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """
-        Initialize Ultralytics YOLO Pose model.
+        """Initialize Ultralytics YOLO Pose model.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -612,11 +590,10 @@
 
 
 class ClassificationModel(BaseModel):
-    """
-    YOLO classification model.
+    """YOLO classification model.
 
-    This class implements the YOLO classification architecture for image classification tasks,
-    providing model initialization, configuration, and output reshaping capabilities.
+    This class implements the YOLO classification architecture for image classification tasks, providing model
+    initialization, configuration, and output reshaping capabilities.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
     """
 
    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+        """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
         self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_yaml(self, cfg, ch, nc, verbose):
-        """
-        Set Ultralytics YOLO model configurations and define the model architecture.
+        """Set Ultralytics YOLO model configurations and define the model architecture.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -675,8 +650,7 @@
 
     @staticmethod
     def reshape_outputs(model, nc):
-        """
-        Update a TorchVision classification model to class count 'n' if required.
+        """Update a TorchVision classification model to class count 'n' if required.
 
         Args:
             model (torch.nn.Module): Model to update.
@@ -708,8 +682,7 @@
 
 
 class RTDETRDetectionModel(DetectionModel):
-    """
-    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+    """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
 
     This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
     the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
     """
 
     def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the RTDETRDetectionModel.
+        """Initialize the RTDETRDetectionModel.
 
         Args:
             cfg (str | dict): Configuration file name or path.
@@ -744,8 +716,7 @@ class RTDETRDetectionModel(DetectionModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
             fn (function): The function to apply to the model.
@@ -766,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
         return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
 
     def loss(self, batch, preds=None):
-        """
-        Compute the loss for the given batch of data.
+        """Compute the loss for the given batch of data.
 
         Args:
             batch (dict): Dictionary containing image and label data.
@@ -813,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
         )
 
     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -849,11 +818,10 @@ class RTDETRDetectionModel(DetectionModel):
 
 
 class WorldModel(DetectionModel):
-    """
-    YOLOv8 World Model.
+    """YOLOv8 World Model.
 
-    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based
-    class specification and CLIP model integration for zero-shot detection capabilities.
+    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+    specification and CLIP model integration for zero-shot detection capabilities.
 
     Attributes:
         txt_feats (torch.Tensor): Text feature embeddings for classes.
@@ -874,8 +842,7 @@
     """
 
    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOv8 world model with given config and parameters.
+        """Initialize YOLOv8 world model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -888,8 +855,7 @@
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             text (list[str]): List of class names.
@@ -900,8 +866,7 @@
         self.model[-1].nc = len(text)
 
     def get_text_pe(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Get text positional embeddings for offline inference without CLIP model.
 
         Args:
             text (list[str]): List of class names.
@@ -924,8 +889,7 @@
         return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
 
     def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -969,8 +933,7 @@
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -985,11 +948,10 @@
 
 
 class YOLOEModel(DetectionModel):
-    """
-    YOLOE detection model.
+    """YOLOE detection model.
 
-    This class implements the YOLOE architecture for efficient object detection with text and visual prompts,
-    supporting both prompt-based and prompt-free inference modes.
+    This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+    both prompt-based and prompt-free inference modes.
 
     Attributes:
         pe (torch.Tensor): Prompt embeddings for classes.
@@ -1013,8 +975,7 @@
     """
 
    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE model with given config and parameters.
+        """Initialize YOLOE model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1026,14 +987,13 @@
 
     @smart_inference_mode()
     def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Get text positional embeddings for offline inference without CLIP model.
 
         Args:
             text (list[str]): List of class names.
             batch (int): Batch size for processing text tokens.
             cache_clip_model (bool): Whether to cache the CLIP model.
-            without_reprta (bool): Whether to return text embeddings
+            without_reprta (bool): Whether to return text embeddings without reprta module processing.
 
         Returns:
             (torch.Tensor): Text positional embeddings.
@@ -1059,8 +1019,7 @@
 
     @smart_inference_mode()
     def get_visual_pe(self, img, visual):
-        """
-        Get visual embeddings.
+        """Get visual embeddings.
 
         Args:
             img (torch.Tensor): Input image tensor.
@@ -1072,8 +1031,7 @@
         return self(img, vpe=visual, return_vpe=True)
 
     def set_vocab(self, vocab, names):
-        """
-        Set vocabulary for the prompt-free model.
+        """Set vocabulary for the prompt-free model.
 
         Args:
             vocab (nn.ModuleList): List of vocabulary items.
@@ -1101,8 +1059,7 @@
         self.names = check_class_names(names)
 
     def get_vocab(self, names):
-        """
-        Get fused vocabulary layer from the model.
+        """Get fused vocabulary layer from the model.
 
         Args:
             names (list): List of class names.
@@ -1127,8 +1084,7 @@
         return vocab
 
     def set_classes(self, names, embeddings):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             names (list[str]): List of class names.
@@ -1143,8 +1099,7 @@
         self.names = check_class_names(names)
 
     def get_cls_pe(self, tpe, vpe):
-        """
-        Get class positional embeddings.
+        """Get class positional embeddings.
 
         Args:
             tpe (torch.Tensor, optional): Text positional embeddings.
@@ -1167,8 +1122,7 @@
     def predict(
         self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
     ):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -1215,8 +1169,7 @@
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1234,11 +1187,10 @@
 
 
 class YOLOESegModel(YOLOEModel, SegmentationModel):
-    """
-    YOLOE segmentation model.
+    """YOLOE segmentation model.
 
-    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,
-    providing specialized loss computation for pixel-level object detection and segmentation.
+    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+    specialized loss computation for pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLOE segmentation model.
@@ -1251,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
     """
 
    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE segmentation model with given config and parameters.
+        """Initialize YOLOE segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1263,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1282,11 +1232,10 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 
 
 class Ensemble(torch.nn.ModuleList):
-    """
-    Ensemble of models.
+    """Ensemble of models.
 
-    This class allows combining multiple YOLO models into an ensemble for improved performance through
-    model averaging or other ensemble techniques.
+    This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+    or other ensemble techniques.
 
     Methods:
         __init__: Initialize an ensemble of models.
@@ -1305,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
         super().__init__()
 
     def forward(self, x, augment=False, profile=False, visualize=False):
-        """
-        Generate the YOLO network's final layer.
+        """Generate the YOLO network's final layer.
 
         Args:
             x (torch.Tensor): Input tensor.
@@ -1330,12 +1278,11 @@ class Ensemble(torch.nn.ModuleList):
 
 @contextlib.contextmanager
 def temporary_modules(modules=None, attributes=None):
-    """
-    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
-    This function can be used to change the module paths during runtime. It's useful when refactoring code,
-    where you've moved a module from one location to another, but you still want to support the old import paths
-    for backwards compatibility.
+    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+    moved a module from one location to another, but you still want to support the old import paths for backwards
+    compatibility.
 
     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1346,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
        >>> import old.module  # this will now import new.module
        >>> from old.module import attribute  # this will now import new.module.attribute
 
-    Note:
+    Notes:
         The changes are only in effect inside the context manager and are undone once the context manager exits.
         Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
         applications or libraries. Use this function with caution.
@@ -1393,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""
 
    def find_class(self, module, name):
-        """
-        Attempt to find a class, returning SafeClass if not among safe modules.
+        """Attempt to find a class, returning SafeClass if not among safe modules.
 
         Args:
             module (str): Module name.
@@ -1419,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
 
 
 def torch_safe_load(weight, safe_only=False):
-    """
-    Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
-    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
-    After installation, the function again attempts to load the model using torch.load().
+    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+    function. After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
@@ -1501,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
 
 
 def load_checkpoint(weight, device=None, inplace=True, fuse=False):
-    """
-    Load a single model weights.
+    """Load a single model weights.
 
     Args:
         weight (str | Path): Model weight path.
@@ -1539,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):
-    """
-    Parse a YOLO model.yaml dictionary into a PyTorch model.
+    """Parse a YOLO model.yaml dictionary into a PyTorch model.
 
     Args:
         d (dict): Model dictionary.
@@ -1561,7 +1504,7 @@ def parse_model(d, ch, verbose=True):
     scale = d.get("scale")
     if scales:
         if not scale:
-            scale = tuple(scales.keys())[0]
+            scale = next(iter(scales.keys()))
             LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
         depth, width, max_channels = scales[scale]
 
@@ -1646,7 +1589,7 @@ def parse_model(d, ch, verbose=True):
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
-            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
+            if c2 != nc:  # if c2 != nc (e.g., Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
@@ -1708,7 +1651,7 @@ def parse_model(d, ch, verbose=True):
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
-            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
+            LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f}  {t:<45}{args!s:<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
@@ -1718,8 +1661,7 @@
 
 
 def yaml_model_load(path):
-    """
-    Load a YOLOv8 model from a YAML file.
+    """Load a YOLOv8 model from a YAML file.
 
     Args:
         path (str | Path): Path to the YAML file.
@@ -1742,8 +1684,7 @@
 
 
 def guess_model_scale(model_path):
-    """
-    Extract the size character n, s, m, l, or x of the model's scale from the model path.
+    """Extract the size character n, s, m, l, or x of the model's scale from the model path.
 
     Args:
         model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1752,14 +1693,13 @@ def guess_model_scale(model_path):
        (str): The size character of the model's scale (n, s, m, l, or x).
     """
     try:
-        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
+        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
     except AttributeError:
        return ""
 
 
 def guess_model_task(model):
-    """
-    Guess the task of a PyTorch model from its architecture or configuration.
+    """Guess the task of a PyTorch model from its architecture or configuration.
 
     Args:
         model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
@@ -1790,10 +1730,10 @@
    if isinstance(model, torch.nn.Module):  # PyTorch model
        for x in "model.args", "model.model.args", "model.model.model.args":
            with contextlib.suppress(Exception):
-                return eval(x)["task"]
+                return eval(x)["task"]  # nosec B307: safe eval of known attribute paths
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
-                return cfg2task(eval(x))
+                return cfg2task(eval(x))  # nosec B307: safe eval of known attribute paths
        for m in model.modules():
            if isinstance(m, (Segment, YOLOESegment)):
                return "segment"