dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/vl_combiner.py
CHANGED
@@ -0,0 +1,160 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Provides utility to combine a vision backbone with a language backbone."""
+
+from __future__ import annotations
+
+from copy import copy
+
+import torch
+import torch.nn as nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from .necks import Sam3DualViTDetNeck
+
+
+class SAM3VLBackbone(nn.Module):
+    """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a
+    convenience wrapper to handle the two backbones together.
+
+    It adds support for activation checkpointing and compilation.
+    """
+
+    def __init__(
+        self,
+        visual: Sam3DualViTDetNeck,
+        text,
+        compile_visual: bool = False,
+        act_ckpt_whole_vision_backbone: bool = False,
+        act_ckpt_whole_language_backbone: bool = False,
+        scalp=0,
+    ):
+        """Initialize the backbone combiner.
+
+        :param visual: The vision backbone to use
+        :param text: The text encoder to use
+        """
+        super().__init__()
+        self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
+        self.language_backbone = text
+        self.scalp = scalp
+        # allow running activation checkpointing on the entire vision and language backbones
+        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
+        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
+
+    def forward(
+        self,
+        samples: torch.Tensor,
+        captions: list[str],
+        input_boxes: torch.Tensor = None,
+        additional_text: list[str] | None = None,
+    ):
+        """Forward pass of the backbone combiner.
+
+        :param samples: The input images
+        :param captions: The input captions
+        :param input_boxes: If the text contains place-holders for boxes, this
+            parameter contains the tensor containing their spatial features
+        :param additional_text: This can be used to encode some additional text
+            (different from the captions) in the same forward of the backbone
+        :return: Output dictionary with the following keys:
+            - vision_features: The output of the vision backbone
+            - language_features: The output of the language backbone
+            - language_mask: The attention mask of the language backbone
+            - vision_pos_enc: The positional encoding of the vision backbone
+            - (optional) additional_text_features: The output of the language
+                backbone for the additional text
+            - (optional) additional_text_mask: The attention mask of the
+                language backbone for the additional text
+        """
+        output = self.forward_image(samples)
+        output.update(self.forward_text(captions, input_boxes, additional_text))
+        return output
+
+    def forward_image(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone and get both SAM3 and SAM2 features."""
+        # Forward through backbone
+        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam3_features, sam3_pos = (
+                sam3_features[: -self.scalp],
+                sam3_pos[: -self.scalp],
+            )
+            if sam2_features is not None and sam2_pos is not None:
+                sam2_features, sam2_pos = (
+                    sam2_features[: -self.scalp],
+                    sam2_pos[: -self.scalp],
+                )
+
+        sam2_output = None
+
+        if sam2_features is not None and sam2_pos is not None:
+            sam2_src = sam2_features[-1]
+            sam2_output = {
+                "vision_features": sam2_src,
+                "vision_pos_enc": sam2_pos,
+                "backbone_fpn": sam2_features,
+            }
+
+        sam3_src = sam3_features[-1]
+        return {
+            "vision_features": sam3_src,
+            "vision_pos_enc": sam3_pos,
+            "backbone_fpn": sam3_features,
+            "sam2_backbone_out": sam2_output,
+        }
+
+    def forward_image_sam2(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone to get SAM2 features only."""
+        xs = self.vision_backbone.trunk(samples)
+        x = xs[-1]  # simpleFPN
+
+        assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
+        sam2_features, sam2_pos = self.vision_backbone.sam_forward_feature_levels(x, self.vision_backbone.sam2_convs)
+
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam2_features, sam2_pos = (
+                sam2_features[: -self.scalp],
+                sam2_pos[: -self.scalp],
+            )
+
+        return {
+            "vision_features": sam2_features[-1],
+            "vision_pos_enc": sam2_pos,
+            "backbone_fpn": sam2_features,
+        }
+
+    def forward_text(self, captions, input_boxes=None, additional_text=None):
+        """Forward pass of the text encoder."""
+        output = {}
+
+        # Forward through text_encoder
+        text_to_encode = copy(captions)
+        if additional_text is not None:
+            # if there are additional_text, we piggy-back them into this forward.
+            # They'll be used later for output alignment
+            text_to_encode += additional_text
+
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
+            text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)
+
+        if additional_text is not None:
+            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
+            output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]
+
+        text_memory = text_memory[:, : len(captions)]
+        text_attention_mask = text_attention_mask[: len(captions)]
+        text_embeds = text_embeds[:, : len(captions)]
+        output["language_features"] = text_memory
+        output["language_mask"] = text_attention_mask
+        output["language_embeds"] = text_embeds  # Text embeddings before forward to the encoder
+
+        return output
+
+    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+        """Set the image size for the vision backbone."""
+        self.vision_backbone.set_imgsz(imgsz)
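For orientation only, a minimal sketch of the composition pattern the new SAM3VLBackbone follows: the vision and language branches run independently (no fusion) and their output dictionaries are merged in forward(). The toy modules and tensor shapes below are made-up stand-ins, not the real Sam3DualViTDetNeck or text encoder.

import torch
import torch.nn as nn


class ToyVisionNeck(nn.Module):
    """Hypothetical stand-in for the vision backbone/neck."""

    def forward(self, images: torch.Tensor) -> dict:
        feats = images.mean(dim=(2, 3), keepdim=True)  # fake single pyramid level
        return {"vision_features": feats, "backbone_fpn": [feats]}


class ToyTextEncoder(nn.Module):
    """Hypothetical stand-in for the language backbone."""

    def forward(self, captions: list[str]) -> dict:
        embeds = torch.zeros(len(captions), 8)  # fake text embeddings
        return {"language_features": embeds, "language_mask": torch.ones(len(captions), dtype=torch.bool)}


class ToyVLBackbone(nn.Module):
    """Run both branches without fusion and merge their output dicts, as SAM3VLBackbone.forward does."""

    def __init__(self, visual: nn.Module, text: nn.Module):
        super().__init__()
        self.vision_backbone, self.language_backbone = visual, text

    def forward(self, images: torch.Tensor, captions: list[str]) -> dict:
        output = self.vision_backbone(images)
        output.update(self.language_backbone(captions))  # merge vision + language outputs
        return output


out = ToyVLBackbone(ToyVisionNeck(), ToyTextEncoder())(torch.rand(2, 3, 32, 32), ["a dog", "a cat"])
print(sorted(out))  # ['backbone_fpn', 'language_features', 'language_mask', 'vision_features']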
ultralytics/models/utils/loss.py
CHANGED
@@ -15,11 +15,10 @@ from .ops import HungarianMatcher
 
 
 class DETRLoss(nn.Module):
-    """
-    DETR (DEtection TRansformer) Loss class for calculating various loss components.
+    """DETR (DEtection TRansformer) Loss class for calculating various loss components.
 
-    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
-
+    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
+    object detection model.
 
     Attributes:
         nc (int): Number of classes.
@@ -47,8 +46,7 @@ class DETRLoss(nn.Module):
         gamma: float = 1.5,
         alpha: float = 0.25,
     ):
-        """
-        Initialize DETR loss function with customizable components and gains.
+        """Initialize DETR loss function with customizable components and gains.
 
         Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
        losses and various loss types.
@@ -82,8 +80,7 @@ class DETRLoss(nn.Module):
     def _get_loss_class(
         self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute classification loss based on predictions, target values, and ground truth scores.
+        """Compute classification loss based on predictions, target values, and ground truth scores.
 
         Args:
             pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
@@ -124,8 +121,7 @@ class DETRLoss(nn.Module):
     def _get_loss_bbox(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
+        """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
@@ -199,8 +195,7 @@ class DETRLoss(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: torch.Tensor | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Get auxiliary losses for intermediate decoder layers.
+        """Get auxiliary losses for intermediate decoder layers.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
@@ -261,8 +256,7 @@ class DETRLoss(nn.Module):
 
     @staticmethod
     def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
-        """
-        Extract batch indices, source indices, and destination indices from match indices.
+        """Extract batch indices, source indices, and destination indices from match indices.
 
         Args:
             match_indices (list[tuple]): List of tuples containing matched indices.
@@ -279,8 +273,7 @@ class DETRLoss(nn.Module):
     def _get_assigned_bboxes(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
+        """Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -317,8 +310,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         match_indices: list[tuple] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate losses for a single prediction layer.
+        """Calculate losses for a single prediction layer.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -364,8 +356,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         **kwargs: Any,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate loss for predicted bounding boxes and scores.
+        """Calculate loss for predicted bounding boxes and scores.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
@@ -400,8 +391,7 @@ class DETRLoss(nn.Module):
 
 
 class RTDETRDetectionLoss(DETRLoss):
-    """
-    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
+    """Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
 
     This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
     an additional denoising training loss when provided with denoising metadata.
@@ -415,8 +405,7 @@ class RTDETRDetectionLoss(DETRLoss):
         dn_scores: torch.Tensor | None = None,
         dn_meta: dict[str, Any] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Forward pass to compute detection loss with optional denoising loss.
+        """Forward pass to compute detection loss with optional denoising loss.
 
         Args:
             preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
@@ -452,8 +441,7 @@ class RTDETRDetectionLoss(DETRLoss):
     def get_dn_match_indices(
         dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Get match indices for denoising.
+        """Get match indices for denoising.
 
         Args:
             dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
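Since several of these docstrings refer to the GIoU term, here is a small self-contained numeric sketch of the standard GIoU definition (generic math, not Ultralytics code); the GIoU loss is typically taken as 1 - GIoU.

def giou(box1, box2):
    """Generalized IoU for two (x1, y1, x2, y2) boxes."""
    x1, y1, x2, y2 = box1
    X1, Y1, X2, Y2 = box2
    inter = max(0.0, min(x2, X2) - max(x1, X1)) * max(0.0, min(y2, Y2) - max(y1, Y1))
    union = (x2 - x1) * (y2 - y1) + (X2 - X1) * (Y2 - Y1) - inter
    iou = inter / union
    # Area of the smallest axis-aligned box enclosing both boxes
    c_area = (max(x2, X2) - min(x1, X1)) * (max(y2, Y2) - min(y1, Y1))
    return iou - (c_area - union) / c_area


print(round(giou((0, 0, 2, 2), (1, 1, 3, 3)), 3))  # IoU = 1/7, enclosure penalty = 2/9 -> GIoU ~= -0.079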
ultralytics/models/utils/ops.py
CHANGED
@@ -14,8 +14,7 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
 
 
 class HungarianMatcher(nn.Module):
-    """
-    A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
+    """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
 
     HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
     function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
@@ -56,8 +55,7 @@ class HungarianMatcher(nn.Module):
         alpha: float = 0.25,
         gamma: float = 2.0,
     ):
-        """
-        Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
+        """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
 
         Args:
             cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
@@ -88,8 +86,7 @@ class HungarianMatcher(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: list[torch.Tensor] | None = None,
     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+        """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
 
         This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
         mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
@@ -105,10 +102,10 @@ class HungarianMatcher(nn.Module):
             gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
 
         Returns:
-            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
-
-
-
+            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
+                index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
+                the tensor of indices of the corresponding selected ground truth targets (in order).
+                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
         """
         bs, nq, nc = pred_scores.shape
 
@@ -198,16 +195,15 @@ def get_cdn_group(
     box_noise_scale: float = 1.0,
     training: bool = False,
 ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
-    """
-    Generate contrastive denoising training group with positive and negative samples from ground truths.
+    """Generate contrastive denoising training group with positive and negative samples from ground truths.
 
-    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
-
+    This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
+    boxes and class labels. It generates both positive and negative samples to improve model robustness.
 
     Args:
-        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
-
-
+        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
+            (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
+            per image.
         num_classes (int): Total number of object classes.
         num_queries (int): Number of object queries.
         class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
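To make the Returns description above concrete, a minimal sketch of the optimal bipartite assignment step using SciPy's linear_sum_assignment on a made-up cost matrix (illustrative only; it does not call the HungarianMatcher class):

import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows = predicted queries, columns = ground-truth objects; lower cost = better match (values are made up).
cost = np.array(
    [
        [0.2, 0.9, 0.8],
        [0.7, 0.1, 0.9],
        [0.6, 0.8, 0.3],
        [0.5, 0.5, 0.5],
    ]
)
index_i, index_j = linear_sum_assignment(cost)  # optimal bipartite assignment
print(index_i, index_j)  # [0 1 2] [0 1 2]; len(index_i) == len(index_j) == min(num_queries, num_targets)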
ultralytics/models/yolo/__init__.py
CHANGED
@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world,
 
 from .model import YOLO, YOLOE, YOLOWorld
 
-__all__ = "
+__all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
ultralytics/models/yolo/classify/predict.py
CHANGED
@@ -11,11 +11,10 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class ClassificationPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a classification model.
+    """A class extending the BasePredictor class for prediction based on a classification model.
 
-    This predictor handles the specific requirements of classification models, including preprocessing images
-
+    This predictor handles the specific requirements of classification models, including preprocessing images and
+    postprocessing predictions to generate classification results.
 
     Attributes:
         args (dict): Configuration arguments for the predictor.
@@ -24,20 +23,19 @@ class ClassificationPredictor(BasePredictor):
         preprocess: Convert input images to model-compatible format.
         postprocess: Process model predictions into Results objects.
 
-    Notes:
-        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
-
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.classify import ClassificationPredictor
         >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
        >>> predictor = ClassificationPredictor(overrides=args)
        >>> predictor.predict_cli()
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+        """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
 
         This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
         tasks. It ensures the task is set to 'classify' regardless of input configuration.
@@ -72,8 +70,7 @@ class ClassificationPredictor(BasePredictor):
         return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Process predictions to return Results objects with classification probabilities.
+        """Process predictions to return Results objects with classification probabilities.
 
         Args:
             preds (torch.Tensor): Raw predictions from the model.
@@ -84,7 +81,7 @@ class ClassificationPredictor(BasePredictor):
             (list[Results]): List of Results objects containing classification results for each image.
         """
         if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
 
         preds = preds[0] if isinstance(preds, (list, tuple)) else preds
         return [
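The functional change in postprocess above appends [..., ::-1], which reverses the last (channel) axis of the converted NumPy batch (e.g. swapping RGB and BGR ordering). A tiny NumPy illustration of that indexing:

import numpy as np

batch = np.zeros((1, 2, 2, 3), dtype=np.uint8)  # (N, H, W, C)
batch[..., 0] = 255  # fill the first channel only
flipped = batch[..., ::-1]  # reverse the channel axis
print(batch[0, 0, 0], flipped[0, 0, 0])  # [255   0   0] [  0   0 255]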
ultralytics/models/yolo/classify/train.py
CHANGED
@@ -11,14 +11,13 @@ from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.trainer import BaseTrainer
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import ClassificationModel
-from ultralytics.utils import DEFAULT_CFG,
-from ultralytics.utils.plotting import plot_images
-from ultralytics.utils.torch_utils import is_parallel,
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.utils.plotting import plot_images
+from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first
 
 
 class ClassificationTrainer(BaseTrainer):
-    """
-    A trainer class extending BaseTrainer for training image classification models.
+    """A trainer class extending BaseTrainer for training image classification models.
 
     This trainer handles the training process for image classification tasks, supporting both YOLO classification models
     and torchvision models with comprehensive dataset handling and validation.
@@ -38,8 +37,7 @@ class ClassificationTrainer(BaseTrainer):
         preprocess_batch: Preprocess a batch of images and classes.
         progress_string: Return a formatted string showing training progress.
         get_validator: Return an instance of ClassificationValidator.
-        label_loss_items: Return a loss dict with
-        plot_metrics: Plot metrics from a CSV file.
+        label_loss_items: Return a loss dict with labeled training loss items.
         final_eval: Evaluate trained model and save validation results.
         plot_training_samples: Plot training samples with their annotations.
 
@@ -52,8 +50,7 @@ class ClassificationTrainer(BaseTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a ClassificationTrainer object.
+        """Initialize a ClassificationTrainer object.
 
         Args:
             cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
@@ -72,8 +69,7 @@ class ClassificationTrainer(BaseTrainer):
         self.model.names = self.data["names"]
 
     def get_model(self, cfg=None, weights=None, verbose: bool = True):
-        """
-        Return a modified PyTorch model configured for training YOLO classification.
+        """Return a modified PyTorch model configured for training YOLO classification.
 
         Args:
             cfg (Any, optional): Model configuration.
@@ -97,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
         return model
 
     def setup_model(self):
-        """
-        Load, create or download model for classification tasks.
+        """Load, create or download model for classification tasks.
 
         Returns:
             (Any): Model checkpoint if applicable, otherwise None.
@@ -116,8 +111,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path: str, mode: str = "train", batch=None):
-        """
-        Create a ClassificationDataset instance given an image path and mode.
+        """Create a ClassificationDataset instance given an image path and mode.
 
         Args:
             img_path (str): Path to the dataset images.
@@ -130,8 +124,7 @@ class ClassificationTrainer(BaseTrainer):
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Return PyTorch DataLoader with transforms to preprocess images.
+        """Return PyTorch DataLoader with transforms to preprocess images.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -156,8 +149,8 @@ class ClassificationTrainer(BaseTrainer):
 
     def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         """Preprocess a batch of images and classes."""
-        batch["img"] = batch["img"].to(self.device, non_blocking=
-        batch["cls"] = batch["cls"].to(self.device, non_blocking=
+        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
         return batch
 
     def progress_string(self) -> str:
@@ -178,8 +171,7 @@ class ClassificationTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
@@ -195,32 +187,14 @@ class ClassificationTrainer(BaseTrainer):
             loss_items = [round(float(loss_items), 5)]
         return dict(zip(keys, loss_items))
 
-    def plot_metrics(self):
-        """Plot metrics from a CSV file."""
-        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
-
-    def final_eval(self):
-        """Evaluate trained model and save validation results."""
-        for f in self.last, self.best:
-            if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
-                    LOGGER.info(f"\nValidating {f}...")
-                    self.validator.args.data = self.args.data
-                    self.validator.args.plots = self.args.plots
-                    self.metrics = self.validator(model=f)
-                    self.metrics.pop("fitness", None)
-                    self.run_callbacks("on_fit_epoch_end")
-
     def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, torch.Tensor]): Batch containing images and class labels.
             ni (int): Number of iterations.
         """
-        batch["batch_idx"] = torch.arange(
+        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
         plot_images(
             labels=batch,
             fname=self.save_dir / f"train_batch{ni}.jpg",