ultralytics-opencv-headless 8.3.242 (ultralytics_opencv_headless-8.3.242-py3-none-any.whl)
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1026 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1574 -0
- ultralytics/engine/model.py +1124 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +974 -0
- ultralytics/engine/tuner.py +448 -0
- ultralytics/engine/validator.py +384 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +958 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +998 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +315 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +444 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1560 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1045 -0
- ultralytics/utils/tal.py +403 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +440 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +160 -0
- ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
- ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
- ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
- ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
- ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
- ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/utils/ops.py (new file)
@@ -0,0 +1,315 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh


class HungarianMatcher(nn.Module):
    """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.

    HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
    function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
    used in end-to-end object detection models like DETR.

    Attributes:
        cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
            components.
        use_fl (bool): Whether to use Focal Loss for classification cost calculation.
        with_mask (bool): Whether the model makes mask predictions.
        num_sample_points (int): Number of sample points used in mask cost calculation.
        alpha (float): Alpha factor in Focal Loss calculation.
        gamma (float): Gamma factor in Focal Loss calculation.

    Methods:
        forward: Compute optimal assignment between predictions and ground truths for a batch.
        _cost_mask: Compute mask cost and dice cost if masks are predicted.

    Examples:
        Initialize a HungarianMatcher with custom cost gains
        >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})

        Perform matching between predictions and ground truth
        >>> pred_boxes = torch.rand(2, 100, 4)  # batch_size=2, num_queries=100
        >>> pred_scores = torch.rand(2, 100, 80)  # 80 classes
        >>> gt_boxes = torch.rand(10, 4)  # 10 ground truth boxes
        >>> gt_classes = torch.randint(0, 80, (10,))
        >>> gt_groups = [5, 5]  # 5 GT boxes per image
        >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
    """

    def __init__(
        self,
        cost_gain: dict[str, float] | None = None,
        use_fl: bool = True,
        with_mask: bool = False,
        num_sample_points: int = 12544,
        alpha: float = 0.25,
        gamma: float = 2.0,
    ):
        """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.

        Args:
            cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
                components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
            use_fl (bool): Whether to use Focal Loss for classification cost calculation.
            with_mask (bool): Whether the model makes mask predictions.
            num_sample_points (int): Number of sample points used in mask cost calculation.
            alpha (float): Alpha factor in Focal Loss calculation.
            gamma (float): Gamma factor in Focal Loss calculation.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(
        self,
        pred_bboxes: torch.Tensor,
        pred_scores: torch.Tensor,
        gt_bboxes: torch.Tensor,
        gt_cls: torch.Tensor,
        gt_groups: list[int],
        masks: torch.Tensor | None = None,
        gt_mask: list[torch.Tensor] | None = None,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.

        This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
        mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
            pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
                num_classes).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
            gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
            gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
            masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
            gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).

        Returns:
            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
                index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
                the tensor of indices of the corresponding selected ground truth targets (in order). For each batch
                element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

        # Flatten to compute cost matrices in batch format
        pred_scores = pred_scores.detach().view(-1, nc)
        pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        pred_bboxes = pred_bboxes.detach().view(-1, 4)

        # Compute classification cost
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # Compute L1 cost between boxes
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)

        # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Combine costs into final cost matrix
        C = (
            self.cost_gain["class"] * cost_class
            + self.cost_gain["bbox"] * cost_bbox
            + self.cost_gain["giou"] * cost_giou
        )

        # Add mask costs if available
        if self.with_mask:
            C += self._cost_mask(bs, gt_groups, masks, gt_mask)

        # Set invalid values (NaNs and infinities) to 0
        C[C.isnan() | C.isinf()] = 0.0

        C = C.view(bs, nq, -1).cpu()
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt)
        return [
            (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
            for k, (i, j) in enumerate(indices)
        ]

    # This function is for future RT-DETR Segment models
    # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
    #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
    #     # all masks share the same set of points for efficient matching
    #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
    #     sample_points = 2.0 * sample_points - 1.0
    #
    #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
    #     out_mask = out_mask.flatten(0, 1)
    #
    #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
    #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
    #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
    #
    #     with torch.amp.autocast("cuda", enabled=False):
    #         # binary cross entropy cost
    #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
    #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
    #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
    #         cost_mask /= self.num_sample_points
    #
    #         # dice cost
    #         out_mask = F.sigmoid(out_mask)
    #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
    #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
    #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
    #
    #         C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
    #     return C


def get_cdn_group(
    batch: dict[str, Any],
    num_classes: int,
    num_queries: int,
    class_embed: torch.Tensor,
    num_dn: int = 100,
    cls_noise_ratio: float = 0.5,
    box_noise_scale: float = 1.0,
    training: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
    """Generate contrastive denoising training group with positive and negative samples from ground truths.

    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
    bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.

    Args:
        batch (dict[str, Any]): Batch dictionary containing 'cls' (torch.Tensor with shape (num_gts,)), 'bboxes'
            (torch.Tensor with shape (num_gts, 4)), 'batch_idx' (torch.Tensor with shape (num_gts,)), and 'gt_groups'
            (list[int]) indicating the number of ground truths per image.
        num_classes (int): Total number of object classes.
        num_queries (int): Number of object queries.
        class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
        num_dn (int): Number of denoising queries to generate.
        cls_noise_ratio (float): Noise ratio for class labels.
        box_noise_scale (float): Noise scale for bounding box coordinates.
        training (bool): Whether model is in training mode.

    Returns:
        padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
        padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
        attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
        dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.

    Examples:
        Generate denoising group for training
        >>> batch = {
        ...     "cls": torch.tensor([0, 1, 2]),
        ...     "bboxes": torch.rand(3, 4),
        ...     "batch_idx": torch.tensor([0, 0, 1]),
        ...     "gt_groups": [2, 1],
        ... }
        >>> class_embed = torch.rand(80, 256)  # 80 classes, 256 embedding dim
        >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
    """
    if (not training) or num_dn <= 0 or batch is None:
        return None, None, None, None
    gt_groups = batch["gt_groups"]
    total_num = sum(gt_groups)
    max_nums = max(gt_groups)
    if max_nums == 0:
        return None, None, None, None

    num_group = num_dn // max_nums
    num_group = 1 if num_group == 0 else num_group
    # Pad gt to max_num of a batch
    bs = len(gt_groups)
    gt_cls = batch["cls"]  # (bs*num, )
    gt_bbox = batch["bboxes"]  # bs*num, 4
    b_idx = batch["batch_idx"]

    # Each group has positive and negative queries
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # Positive and negative mask
    # (bs*num*num_group, ), the second total_num*num_group entries are the negative samples
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # Apply class label noise to half of the samples
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # Randomly assign new class labels
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)

        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4

        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        rand_part[neg_idx] += 1.0
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)
        dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid

    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # Matching queries cannot see the reconstruction (denoising) queries
    attn_mask[num_dn:, :num_dn] = True
    # Reconstruction queries in different groups cannot see each other
    for i in range(num_group):
        if i == 0:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
        if i == num_group - 1:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
        else:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
    dn_meta = {
        "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
        "dn_num_group": num_group,
        "dn_num_split": [num_dn, num_queries],
    }

    return (
        padding_cls.to(class_embed.device),
        padding_bbox.to(class_embed.device),
        attn_mask.to(class_embed.device),
        dn_meta,
    )
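Since the matching pipeline above is shape-sensitive, here is a minimal usage sketch (not part of the package diff; the tensor shapes and the two-image split are illustrative assumptions) that exercises HungarianMatcher end to end:

# Sketch: run HungarianMatcher on random tensors and inspect the
# (prediction index, ground-truth index) pairs it returns per image.
import torch

from ultralytics.models.utils.ops import HungarianMatcher

matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
pred_bboxes = torch.rand(2, 100, 4)  # (batch, queries, xywh), normalized coords
pred_scores = torch.rand(2, 100, 80)  # raw class logits for 80 classes
gt_bboxes = torch.rand(10, 4)  # all GT boxes in the batch, concatenated
gt_cls = torch.randint(0, 80, (10,))
gt_groups = [6, 4]  # image 0 has 6 GT boxes, image 1 has 4

indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
for k, (pred_idx, gt_idx) in enumerate(indices):
    # len(pred_idx) == len(gt_idx) == gt_groups[k], since num_queries > num GT here
    print(f"image {k}: {len(pred_idx)} matched pairs")

Note that the returned ground-truth indices point into the concatenated gt_bboxes tensor, which is why forward() adds the per-image cumulative offset before returning.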
ultralytics/models/yolo/__init__.py (new file)
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe

from .model import YOLO, YOLOE, YOLOWorld

__all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
ultralytics/models/yolo/classify/__init__.py (new file)
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator

__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
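These two `__init__.py` hunks only define the public import surface. A quick check (not from the diff) that both import paths resolve to the same object:

# The re-exports above make the classify symbols reachable from either level.
from ultralytics.models.yolo import classify
from ultralytics.models.yolo.classify import ClassificationPredictor

assert classify.ClassificationPredictor is ClassificationPredictor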
ultralytics/models/yolo/classify/predict.py (new file)
@@ -0,0 +1,90 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import cv2
import torch
from PIL import Image

from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops


class ClassificationPredictor(BasePredictor):
    """A class extending the BasePredictor class for prediction based on a classification model.

    This predictor handles the specific requirements of classification models, including preprocessing images and
    postprocessing predictions to generate classification results.

    Attributes:
        args (dict): Configuration arguments for the predictor.

    Methods:
        preprocess: Convert input images to model-compatible format.
        postprocess: Process model predictions into Results objects.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.classify import ClassificationPredictor
        >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
        >>> predictor = ClassificationPredictor(overrides=args)
        >>> predictor.predict_cli()

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.

        This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for
        classification tasks. It ensures the task is set to 'classify' regardless of input configuration.

        Args:
            cfg (dict): Default configuration dictionary containing prediction settings.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be executed during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"

    def setup_source(self, source):
        """Set up the source, inference mode, and classification transforms."""
        super().setup_source(source)
        updated = (
            self.model.model.transforms.transforms[0].size != max(self.imgsz)
            if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
            else False
        )
        self.transforms = (
            classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
        )

    def preprocess(self, img):
        """Convert input images to model-compatible tensor format with appropriate normalization."""
        if not isinstance(img, torch.Tensor):
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
            )
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """Process predictions to return Results objects with classification probabilities.

        Args:
            preds (torch.Tensor): Raw predictions from the model.
            img (torch.Tensor): Input images after preprocessing.
            orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.

        Returns:
            (list[Results]): List of Results objects containing classification results for each image.
        """
        if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]

        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
        return [
            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]
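In normal use this predictor is reached through the high-level YOLO API rather than instantiated directly. A minimal sketch (not from the diff; assumes the yolo11n-cls.pt checkpoint can be fetched on first use):

# Sketch: classify the bundled bus.jpg asset via the YOLO API; a
# ClassificationPredictor runs under the hood because the task is 'classify'.
from ultralytics import YOLO
from ultralytics.utils import ASSETS

model = YOLO("yolo11n-cls.pt")
results = model(ASSETS / "bus.jpg")
for r in results:
    print(r.names[r.probs.top1], float(r.probs.top1conf))  # top-1 label and confidence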
ultralytics/models/yolo/classify/train.py (new file)
@@ -0,0 +1,202 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from typing import Any

import torch

from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first


class ClassificationTrainer(BaseTrainer):
    """A trainer class extending BaseTrainer for training image classification models.

    This trainer handles the training process for image classification tasks, supporting both YOLO classification
    models and torchvision models with comprehensive dataset handling and validation.

    Attributes:
        model (ClassificationModel): The classification model to be trained.
        data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
        loss_names (list[str]): Names of the loss functions used during training.
        validator (ClassificationValidator): Validator instance for model evaluation.

    Methods:
        set_model_attributes: Set the model's class names from the loaded dataset.
        get_model: Return a modified PyTorch model configured for training.
        setup_model: Load, create or download model for classification.
        build_dataset: Create a ClassificationDataset instance.
        get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
        preprocess_batch: Preprocess a batch of images and classes.
        progress_string: Return a formatted string showing training progress.
        get_validator: Return an instance of ClassificationValidator.
        label_loss_items: Return a loss dict with labeled training loss items.
        final_eval: Evaluate trained model and save validation results.
        plot_training_samples: Plot training samples with their annotations.

    Examples:
        Initialize and train a classification model
        >>> from ultralytics.models.yolo.classify import ClassificationTrainer
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """Initialize a ClassificationTrainer object.

        Args:
            cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
            overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list[Any], optional): List of callback functions to be executed during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """Return a modified PyTorch model configured for training YOLO classification.

        Args:
            cfg (Any, optional): Model configuration.
            weights (Any, optional): Pre-trained model weights.
            verbose (bool, optional): Whether to display model information.

        Returns:
            (ClassificationModel): Configured PyTorch model for classification.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """Load, create or download model for classification tasks.

        Returns:
            (Any): Model checkpoint if applicable, otherwise None.
        """
        import torchvision  # scope for faster 'import ultralytics'

        if str(self.model) in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[self.model](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
        """Create a ClassificationDataset instance given an image path and mode.

        Args:
            img_path (str): Path to the dataset images.
            mode (str, optional): Dataset mode ('train', 'val', or 'test').
            batch (Any, optional): Batch information (unused in this implementation).

        Returns:
            (ClassificationDataset): Dataset for the specified mode.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
        """Return PyTorch DataLoader with transforms to preprocess images.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int, optional): Number of images per batch.
            rank (int, optional): Process rank for distributed training.
            mode (str, optional): 'train', 'val', or 'test' mode.

        Returns:
            (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Preprocess a batch of images and classes."""
        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
        return batch

    def progress_string(self) -> str:
        """Return a formatted string showing training progress."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Return an instance of ClassificationValidator for validation."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(
            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
        """Return a loss dict with labeled training loss items tensor.

        Args:
            loss_items (torch.Tensor, optional): Loss tensor items.
            prefix (str, optional): Prefix to prepend to loss names.

        Returns:
            keys (list[str]): List of loss keys if loss_items is None.
            loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
        """Plot training samples with their annotations.

        Args:
            batch (dict[str, torch.Tensor]): Batch containing images and class labels.
            ni (int): Number of iterations.
        """
        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
        plot_images(
            labels=batch,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )
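A minimal end-to-end sketch mirroring the class docstring (not from the diff; assumes the small imagenet10 dataset is resolvable, which Ultralytics handles on first use). It also shows the two modes of label_loss_items:

# Sketch: drive ClassificationTrainer directly; imgsz falls back to the
# classification default of 224 set in __init__.
import torch

from ultralytics.models.yolo.classify import ClassificationTrainer

trainer = ClassificationTrainer(overrides={"model": "yolo11n-cls.pt", "data": "imagenet10", "epochs": 3})
trainer.train()  # validation after training reports top-1 / top-5 accuracy

# label_loss_items() doubles as a key provider and a formatter:
print(trainer.label_loss_items())  # ["train/loss"]
print(trainer.label_loss_items(torch.tensor(0.123456), prefix="val"))  # {"val/loss": 0.12346}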