dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,268 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn as nn
|
5
|
+
import torch.nn.functional as F
|
6
|
+
from scipy.optimize import linear_sum_assignment
|
7
|
+
|
8
|
+
from ultralytics.utils.metrics import bbox_iou
|
9
|
+
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
10
|
+
|
11
|
+
|
12
|
+
class HungarianMatcher(nn.Module):
    """
    Optimal bipartite matcher between predicted and ground-truth bounding boxes (Hungarian algorithm).

    HungarianMatcher builds a cost matrix from classification scores, L1 box distance and GIoU, then solves the
    assignment independently for every image with `scipy.optimize.linear_sum_assignment`. Inputs are detached before
    the costs are computed, so the matching itself is not differentiable; it only decides which predictions are
    supervised by which ground truths.

    Attributes:
        cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
        use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
        with_mask (bool): Indicates whether the model makes mask predictions. Mask costs are not implemented yet, so
            `forward` raises NotImplementedError when this is True and the batch contains ground truths.
        num_sample_points (int): The number of sample points used in mask cost calculation.
        alpha (float): The alpha factor in Focal Loss calculation.
        gamma (float): The gamma factor in Focal Loss calculation.

    Methods:
        forward: Computes the assignment between predictions and ground truths for a batch.
    """

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
        """
        Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.

        Args:
            cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
                Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'. Defaults to
                {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}.
            use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
            with_mask (bool, optional): Whether the model makes mask predictions (mask matching is not implemented).
            num_sample_points (int, optional): Number of sample points used in mask cost calculation.
            alpha (float, optional): Alpha factor in Focal Loss calculation.
            gamma (float, optional): Gamma factor in Focal Loss calculation.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
        """
        Compute matching costs and find the optimal assignment between predictions and ground truths per image.

        Args:
            pred_bboxes (torch.Tensor): Predicted boxes in xywh format with shape (batch_size, num_queries, 4).
            pred_scores (torch.Tensor): Predicted class logits with shape (batch_size, num_queries, num_classes).
            gt_bboxes (torch.Tensor): Ground truth boxes in xywh format with shape (num_gts, 4).
            gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
            gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
                each image. Ground truths of image k are stored contiguously after those of images 0..k-1.
            masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
                Reserved for mask matching, which is not implemented yet.
            gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape
                (num_masks, Height, Width). Reserved for mask matching, which is not implemented yet.

        Returns:
            (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
                (index_i, index_j), where:
                - index_i is the tensor of indices of the selected predictions (in order)
                - index_j is the tensor of indices of the corresponding selected ground truth targets (in order),
                  expressed as indices into the flattened (num_gts, ...) ground-truth tensors
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
                When the batch contains no ground truths at all, every tuple holds two empty long tensors.

        Raises:
            NotImplementedError: If `with_mask` is True and the batch contains ground truths, because the mask
                cost is not implemented.
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

        if self.with_mask:
            # The mask/dice cost (`_cost_mask`) only exists as a commented-out reference below; fail with a clear
            # message instead of an AttributeError deep inside the cost computation.
            raise NotImplementedError("HungarianMatcher mask matching (with_mask=True) is not implemented yet.")

        # Flatten the batch so all cost matrices are computed at once.
        # (batch_size * num_queries, num_classes); reshape also accepts non-contiguous inputs.
        pred_scores = pred_scores.detach().reshape(-1, nc)
        pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        # (batch_size * num_queries, 4)
        pred_bboxes = pred_bboxes.detach().reshape(-1, 4)

        # Classification cost: only the scores of each ground truth's class matter.
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # L1 cost between boxes, (bs*num_queries, num_gt)
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)

        # GIoU cost between boxes, (bs*num_queries, num_gt)
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Final cost matrix
        C = (
            self.cost_gain["class"] * cost_class
            + self.cost_gain["bbox"] * cost_bbox
            + self.cost_gain["giou"] * cost_giou
        )

        # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
        C[C.isnan() | C.isinf()] = 0.0

        # Solve one assignment per image: image i's queries against only image i's ground-truth columns.
        C = C.view(bs, nq, -1).cpu()
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        # Offset of each image's first ground truth in the flattened gt tensors, so index_j is global.
        gt_offsets = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        return [
            (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_offsets[k])
            for k, (i, j) in enumerate(indices)
        ]

    # Reference implementation for future RT-DETR Segment models. It is not wired in: `forward` raises
    # NotImplementedError when `with_mask=True`.
    # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
    #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
    #     # all masks share the same set of points for efficient matching
    #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
    #     sample_points = 2.0 * sample_points - 1.0
    #
    #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
    #     out_mask = out_mask.flatten(0, 1)
    #
    #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
    #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
    #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
    #
    #     with torch.amp.autocast("cuda", enabled=False):
    #         # binary cross entropy cost
    #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
    #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
    #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
    #         cost_mask /= self.num_sample_points
    #
    #         # dice cost
    #         out_mask = F.sigmoid(out_mask)
    #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
    #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
    #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
    #
    #     C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
    #     return C
|
158
|
+
|
159
|
+
|
160
|
+
def get_cdn_group(
    batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
):
    """
    Get contrastive denoising training group with positive and negative samples from ground truths.

    Ground-truth labels and boxes are repeated into `2 * num_group` chunks. The first `num_group` chunks are
    positives and the last `num_group` chunks are negatives. They are optionally perturbed with label noise and box
    noise, converted to inverse-sigmoid (logit) box space, and scattered into zero-padded per-image tensors.

    Args:
        batch (dict): A dict containing:
            - 'cls' (torch.Tensor with shape (num_gts, )): class index of every ground truth.
            - 'bboxes' (torch.Tensor with shape (num_gts, 4)): normalized xywh boxes (values in [0, 1]).
            - 'batch_idx' (torch.Tensor with shape (num_gts, )): image index of every ground truth.
            - 'gt_groups' (List[int]): a list of batch-size length with the number of gts in each image.
        num_classes (int): Number of classes; noisy labels are drawn uniformly from [0, num_classes).
        num_queries (int): Number of (matching) queries that follow the denoising queries.
        class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
        num_dn (int, optional): Requested number of denoising queries (used to derive the number of groups).
        cls_noise_ratio (float, optional): Noise ratio for class labels; each label is replaced with probability
            cls_noise_ratio / 2.
        box_noise_scale (float, optional): Noise scale for bounding box coordinates; <= 0 disables box noise.
        training (bool, optional): If it's in training mode.

    Returns:
        padding_cls (Optional[torch.Tensor]): Denoising class embeddings with shape (bs, total_dn, embed_dim),
            where total_dn = 2 * num_group * max(gt_groups).
        padding_bbox (Optional[torch.Tensor]): Denoising boxes in logit (inverse-sigmoid) space with shape
            (bs, total_dn, 4).
        attn_mask (Optional[torch.Tensor]): Boolean attention mask with shape
            (total_dn + num_queries, total_dn + num_queries); True marks a blocked position.
        dn_meta (Optional[Dict]): Meta information for denoising with keys 'dn_pos_idx', 'dn_num_group' and
            'dn_num_split'.

        All four values are None when not training, when num_dn <= 0, when batch is None, or when no image has
        ground truths.
    """
    if (not training) or num_dn <= 0 or batch is None:
        return None, None, None, None
    gt_groups = batch["gt_groups"]
    total_num = sum(gt_groups)
    max_nums = max(gt_groups)
    if max_nums == 0:
        return None, None, None, None

    # At least one group, even if a single image has more gts than num_dn.
    num_group = max(num_dn // max_nums, 1)
    bs = len(gt_groups)
    gt_cls = batch["cls"]  # (bs*num, )
    gt_bbox = batch["bboxes"]  # (bs*num, 4)
    b_idx = batch["batch_idx"]

    # Each group has positive and negative queries; repeat() copies, so the batch tensors are never modified.
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # (2*num_group*bs*num, 4)
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # The second total_num*num_group entries of the flattened copies are the negative samples.
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # Replace each label with a random class with probability cls_noise_ratio / 2.
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)

        # Corner jitter is proportional to half the box width/height.
        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # (2*num_group*bs*num, 4)

        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        rand_part[neg_idx] += 1.0  # negatives are pushed further away: offsets in [1, 2) instead of [0, 1)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)

    # Always return boxes in inverse-sigmoid (logit) space, so the output space does not depend on whether box
    # noise was applied (previously this only happened when box_noise_scale > 0).
    dn_bbox = torch.logit(dn_bbox, eps=1e-6)

    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
    dn_cls_embed = class_embed[dn_cls]  # (bs*num*2*num_group, embed_dim)
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    # Position of every gt inside its own image, then shifted by max_nums per chunk.
    map_indices = torch.cat([torch.arange(num, dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # Match query cannot see the reconstruct
    attn_mask[num_dn:, :num_dn] = True
    # Reconstruct groups (each spanning 2 * max_nums rows) cannot see each other.
    # NOTE(review): positives fill the first num_group chunks and negatives the last num_group chunks, while each
    # attention group spans two consecutive chunks; confirm this pos/neg pairing is intended.
    for i in range(num_group):
        start, end = max_nums * 2 * i, max_nums * 2 * (i + 1)
        attn_mask[start:end, end:num_dn] = True  # later groups
        attn_mask[start:end, :start] = True  # earlier groups
    dn_meta = {
        "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
        "dn_num_group": num_group,
        "dn_num_split": [num_dn, num_queries],
    }

    return (
        padding_cls.to(class_embed.device),
        padding_bbox.to(class_embed.device),
        attn_mask.to(class_embed.device),
        dn_meta,
    )
|
@@ -0,0 +1,7 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
|
4
|
+
|
5
|
+
from .model import YOLO, YOLOE, YOLOWorld
|
6
|
+
|
7
|
+
# Public API of the yolo package: task sub-packages followed by the model classes.
__all__ = (
    "classify",
    "segment",
    "detect",
    "pose",
    "obb",
    "world",
    "yoloe",
    "YOLO",
    "YOLOWorld",
    "YOLOE",
)
|
@@ -0,0 +1,7 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
|
4
|
+
from ultralytics.models.yolo.classify.train import ClassificationTrainer
|
5
|
+
from ultralytics.models.yolo.classify.val import ClassificationValidator
|
6
|
+
|
7
|
+
# Public API of the classify package.
__all__ = (
    "ClassificationPredictor",
    "ClassificationTrainer",
    "ClassificationValidator",
)
|
@@ -0,0 +1,88 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import cv2
|
4
|
+
import torch
|
5
|
+
from PIL import Image
|
6
|
+
|
7
|
+
from ultralytics.engine.predictor import BasePredictor
|
8
|
+
from ultralytics.engine.results import Results
|
9
|
+
from ultralytics.utils import DEFAULT_CFG, ops
|
10
|
+
|
11
|
+
|
12
|
+
class ClassificationPredictor(BasePredictor):
|
13
|
+
"""
|
14
|
+
A class extending the BasePredictor class for prediction based on a classification model.
|
15
|
+
|
16
|
+
This predictor handles the specific requirements of classification models, including preprocessing images
|
17
|
+
and postprocessing predictions to generate classification results.
|
18
|
+
|
19
|
+
Attributes:
|
20
|
+
args (dict): Configuration arguments for the predictor.
|
21
|
+
_legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
|
22
|
+
|
23
|
+
Methods:
|
24
|
+
preprocess: Convert input images to model-compatible format.
|
25
|
+
postprocess: Process model predictions into Results objects.
|
26
|
+
|
27
|
+
Notes:
|
28
|
+
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
29
|
+
|
30
|
+
Examples:
|
31
|
+
>>> from ultralytics.utils import ASSETS
|
32
|
+
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
|
33
|
+
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
|
34
|
+
>>> predictor = ClassificationPredictor(overrides=args)
|
35
|
+
>>> predictor.predict_cli()
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
39
|
+
"""
|
40
|
+
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
41
|
+
|
42
|
+
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
|
43
|
+
tasks. It ensures the task is set to 'classify' regardless of input configuration.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
|
47
|
+
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
48
|
+
_callbacks (list, optional): List of callback functions to be executed during prediction.
|
49
|
+
"""
|
50
|
+
super().__init__(cfg, overrides, _callbacks)
|
51
|
+
self.args.task = "classify"
|
52
|
+
self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
|
53
|
+
|
54
|
+
def preprocess(self, img):
|
55
|
+
"""Convert input images to model-compatible tensor format with appropriate normalization."""
|
56
|
+
if not isinstance(img, torch.Tensor):
|
57
|
+
is_legacy_transform = any(
|
58
|
+
self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
|
59
|
+
)
|
60
|
+
if is_legacy_transform: # to handle legacy transforms
|
61
|
+
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
62
|
+
else:
|
63
|
+
img = torch.stack(
|
64
|
+
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
65
|
+
)
|
66
|
+
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
67
|
+
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
68
|
+
|
69
|
+
def postprocess(self, preds, img, orig_imgs):
    """
    Package per-image classification outputs as Results objects.

    Args:
        preds (torch.Tensor | list | tuple): Model output. When it is a list or tuple, only its first element is used
            as the class probabilities.
        img (torch.Tensor): Input images after preprocessing (not read here).
        orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing. Anything that is not a list
            is converted with ``ops.convert_torch2numpy_batch``.

    Returns:
        (List[Results]): One Results object per image. Each carries the original image, its path (from
            ``self.batch[0]``), the model's class names and its prediction as ``probs``.
    """
    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    if isinstance(preds, (list, tuple)):
        # Only the first output of a multi-output model is treated as class probabilities.
        preds = preds[0]

    image_paths = self.batch[0]
    return [
        Results(orig_img, path=img_path, names=self.model.names, probs=pred)
        for pred, orig_img, img_path in zip(preds, orig_imgs, image_paths)
    ]
|
@@ -0,0 +1,233 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from copy import copy
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from ultralytics.data import ClassificationDataset, build_dataloader
|
8
|
+
from ultralytics.engine.trainer import BaseTrainer
|
9
|
+
from ultralytics.models import yolo
|
10
|
+
from ultralytics.nn.tasks import ClassificationModel
|
11
|
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
12
|
+
from ultralytics.utils.plotting import plot_images, plot_results
|
13
|
+
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
14
|
+
|
15
|
+
|
16
|
+
class ClassificationTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a classification model.

    This trainer handles the training process for image classification tasks, supporting both YOLO classification models
    and torchvision models.

    Attributes:
        model (ClassificationModel): The classification model to be trained.
        data (dict): Dictionary containing dataset information including class names and number of classes.
        loss_names (List[str]): Names of the loss functions used during training.
        validator (ClassificationValidator): Validator instance for model evaluation.

    Methods:
        set_model_attributes: Set the model's class names from the loaded dataset.
        get_model: Return a modified PyTorch model configured for training.
        setup_model: Load, create or download model for classification.
        build_dataset: Create a ClassificationDataset instance.
        get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
        preprocess_batch: Move a batch of images and classes to the training device.
        progress_string: Return a formatted string showing training progress.
        get_validator: Return an instance of ClassificationValidator.
        label_loss_items: Return a loss dict with labelled training loss items.
        plot_metrics: Plot metrics from a CSV file.
        final_eval: Evaluate trained model and save validation results.
        plot_training_samples: Plot training samples with their annotations.

    Examples:
        >>> from ultralytics.models.yolo.classify import ClassificationTrainer
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object.

        This constructor sets up a trainer for image classification tasks. It forces the task to 'classify' and
        defaults the image size to 224 when no size is given. The caller's ``overrides`` mapping is not modified.

        Args:
            cfg (dict, optional): Default configuration dictionary containing training parameters.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Examples:
            >>> from ultralytics.models.yolo.classify import ClassificationTrainer
            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
            >>> trainer = ClassificationTrainer(overrides=args)
            >>> trainer.train()
        """
        # Shallow-copy so the task/imgsz defaults below do not leak into the caller's dict.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a modified PyTorch model configured for training YOLO.

        Parameters are re-initialized when ``self.args.pretrained`` is falsy. Dropout probability is overridden when
        ``self.args.dropout`` is set. All parameters are marked trainable.

        Args:
            cfg (Any): Model configuration.
            weights (Any): Pre-trained model weights.
            verbose (bool): Whether to display model information (only honoured when RANK == -1).

        Returns:
            (ClassificationModel): Configured PyTorch model for classification.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download model for classification tasks.

        If the model spec names a torchvision builder, that architecture is built, with "IMAGENET1K_V1" weights when
        ``self.args.pretrained`` is set. Otherwise loading is delegated to BaseTrainer. In both cases the output
        layer is reshaped to ``self.data["nc"]`` classes.

        Returns:
            (Any): Model checkpoint if applicable, otherwise None.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Use the same string key for the membership test and the lookup so non-str specs (e.g. Path) work.
        model_name = str(self.model)
        if model_name in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[model_name](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Create a ClassificationDataset instance given an image path and mode.

        Args:
            img_path (str): Path to the dataset images.
            mode (str): Dataset mode ('train', 'val', or 'test'); augmentation is enabled only for 'train'.
            batch (Any): Batch information (unused in this implementation).

        Returns:
            (ClassificationDataset): Dataset for the specified mode.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """
        Return PyTorch DataLoader with transforms to preprocess images.

        For non-training modes, the dataset's ``torch_transforms`` are also attached to the model (or to its
        ``.module`` when wrapped for parallel training) as the inference transforms.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int): Number of images per batch.
            rank (int): Process rank for distributed training.
            mode (str): 'train', 'val', or 'test' mode.

        Returns:
            (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch):
        """Move the batch's images ('img') and class labels ('cls') to the training device and return the batch."""
        batch["img"] = batch["img"].to(self.device)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def progress_string(self):
        """Return the header row (epoch, GPU memory, loss names, instances, size) for training progress output."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Set the loss names to ['loss'] and return a ClassificationValidator bound to the test loader."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(
            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Return a loss dict with labelled training loss items tensor.

        Args:
            loss_items (torch.Tensor, optional): Loss value; it is converted with ``float()``, so it presumably holds
                a single element.
            prefix (str): Prefix to prepend to loss names.

        Returns:
            (Dict[str, float] | List[str]): Dictionary of loss items (rounded to 5 decimals) or list of loss keys if
                loss_items is None.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_metrics(self):
        """Plot metrics from a CSV file."""
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png

    def final_eval(self):
        """Strip optimizer state from the last/best checkpoints and validate the best one, storing its metrics."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")

    def plot_training_samples(self, batch, ni):
        """
        Plot training samples with their annotations.

        Args:
            batch (Dict[str, torch.Tensor]): Batch containing images ('img') and class labels ('cls').
            ni (int): Current training iteration index, used in the output filename ``train_batch{ni}.jpg``.
        """
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )
|