dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/classify/train.py:

```diff
@@ -1,6 +1,9 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 from copy import copy
+from typing import Any

 import torch

@@ -8,22 +11,21 @@ from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.trainer import BaseTrainer
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import ClassificationModel
-from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_images, plot_results
-from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.utils.plotting import plot_images
+from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first


 class ClassificationTrainer(BaseTrainer):
-    """
-    A class extending the BaseTrainer class for training based on a classification model.
+    """A trainer class extending BaseTrainer for training image classification models.

     This trainer handles the training process for image classification tasks, supporting both YOLO classification models
-    and torchvision models.
+    and torchvision models with comprehensive dataset handling and validation.

     Attributes:
         model (ClassificationModel): The classification model to be trained.
-        data (dict): Dictionary containing dataset information including class names and number of classes.
-        loss_names (List[str]): Names of the loss functions used during training.
+        data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
+        loss_names (list[str]): Names of the loss functions used during training.
         validator (ClassificationValidator): Validator instance for model evaluation.

     Methods:
@@ -35,35 +37,25 @@ class ClassificationTrainer(BaseTrainer):
         preprocess_batch: Preprocess a batch of images and classes.
         progress_string: Return a formatted string showing training progress.
         get_validator: Return an instance of ClassificationValidator.
-        label_loss_items: Return a loss dict with labelled training loss items.
-        plot_metrics: Plot metrics from a CSV file.
+        label_loss_items: Return a loss dict with labeled training loss items.
         final_eval: Evaluate trained model and save validation results.
         plot_training_samples: Plot training samples with their annotations.

     Examples:
+        Initialize and train a classification model
         >>> from ultralytics.models.yolo.classify import ClassificationTrainer
         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
         >>> trainer = ClassificationTrainer(overrides=args)
         >>> trainer.train()
     """

-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a ClassificationTrainer object.
-
-        This constructor sets up a trainer for image classification tasks, configuring the task type and default
-        image size if not specified.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+        """Initialize a ClassificationTrainer object.

         Args:
-            cfg (dict, optional): Default configuration dictionary containing training parameters.
-            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
-            _callbacks (list, optional): List of callback functions to be executed during training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.classify import ClassificationTrainer
-            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
-            >>> trainer = ClassificationTrainer(overrides=args)
-            >>> trainer.train()
+            cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
+            overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list[Any], optional): List of callback functions to be executed during training.
         """
         if overrides is None:
             overrides = {}
@@ -76,14 +68,13 @@ class ClassificationTrainer(BaseTrainer):
         """Set the YOLO model's class names from the loaded dataset."""
         self.model.names = self.data["names"]

-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return a modified PyTorch model configured for training YOLO.
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
+        """Return a modified PyTorch model configured for training YOLO classification.

         Args:
-            cfg (Any): Model configuration.
-            weights (Any): Pre-trained model weights.
-            verbose (bool): Whether to display model information.
+            cfg (Any, optional): Model configuration.
+            weights (Any, optional): Pre-trained model weights.
+            verbose (bool, optional): Whether to display model information.

         Returns:
             (ClassificationModel): Configured PyTorch model for classification.
@@ -102,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
         return model

     def setup_model(self):
-        """
-        Load, create or download model for classification tasks.
+        """Load, create or download model for classification tasks.

         Returns:
             (Any): Model checkpoint if applicable, otherwise None.
@@ -120,29 +110,27 @@ class ClassificationTrainer(BaseTrainer):
         ClassificationModel.reshape_outputs(self.model, self.data["nc"])
         return ckpt

-    def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Create a ClassificationDataset instance given an image path and mode.
+    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
+        """Create a ClassificationDataset instance given an image path and mode.

         Args:
             img_path (str): Path to the dataset images.
-            mode (str): Dataset mode ('train', 'val', or 'test').
-            batch (Any): Batch information (unused in this implementation).
+            mode (str, optional): Dataset mode ('train', 'val', or 'test').
+            batch (Any, optional): Batch information (unused in this implementation).

         Returns:
             (ClassificationDataset): Dataset for the specified mode.
         """
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """
-        Return PyTorch DataLoader with transforms to preprocess images.
+    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
+        """Return PyTorch DataLoader with transforms to preprocess images.

         Args:
             dataset_path (str): Path to the dataset.
-            batch_size (int): Number of images per batch.
-            rank (int): Process rank for distributed training.
-            mode (str): 'train', 'val', or 'test' mode.
+            batch_size (int, optional): Number of images per batch.
+            rank (int, optional): Process rank for distributed training.
+            mode (str, optional): 'train', 'val', or 'test' mode.

         Returns:
             (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
@@ -150,7 +138,7 @@ class ClassificationTrainer(BaseTrainer):
         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
             dataset = self.build_dataset(dataset_path, mode)

-        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
+        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
         # Attach inference transforms
         if mode != "train":
             if is_parallel(self.model):
@@ -159,14 +147,14 @@ class ClassificationTrainer(BaseTrainer):
             self.model.transforms = loader.dataset.torch_transforms
         return loader

-    def preprocess_batch(self, batch):
-        """Preprocesses a batch of images and classes."""
-        batch["img"] = batch["img"].to(self.device)
-        batch["cls"] = batch["cls"].to(self.device)
+    def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """Preprocess a batch of images and classes."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
         return batch

-    def progress_string(self):
-        """Returns a formatted string showing training progress."""
+    def progress_string(self) -> str:
+        """Return a formatted string showing training progress."""
         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
             "Epoch",
             "GPU_mem",
@@ -176,22 +164,22 @@ class ClassificationTrainer(BaseTrainer):
         )

     def get_validator(self):
-        """Returns an instance of ClassificationValidator for validation."""
+        """Return an instance of ClassificationValidator for validation."""
         self.loss_names = ["loss"]
         return yolo.classify.ClassificationValidator(
             self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )

-    def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+    def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
+        """Return a loss dict with labeled training loss items tensor.

         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
-            prefix (str): Prefix to prepend to loss names.
+            prefix (str, optional): Prefix to prepend to loss names.

         Returns:
-            (
+            keys (list[str]): List of loss keys if loss_items is None.
+            loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
         """
         keys = [f"{prefix}/{x}" for x in self.loss_names]
         if loss_items is None:
@@ -199,35 +187,16 @@ class ClassificationTrainer(BaseTrainer):
             loss_items = [round(float(loss_items), 5)]
         return dict(zip(keys, loss_items))

-    def plot_metrics(self):
-        """Plot metrics from a CSV file."""
-        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
-
-    def final_eval(self):
-        """Evaluate trained model and save validation results."""
-        for f in self.last, self.best:
-            if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
-                    LOGGER.info(f"\nValidating {f}...")
-                    self.validator.args.data = self.args.data
-                    self.validator.args.plots = self.args.plots
-                    self.metrics = self.validator(model=f)
-                    self.metrics.pop("fitness", None)
-                    self.run_callbacks("on_fit_epoch_end")
-
-    def plot_training_samples(self, batch, ni):
-        """
-        Plot training samples with their annotations.
+    def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
+        """Plot training samples with their annotations.

         Args:
-            batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
+            batch (dict[str, torch.Tensor]): Batch containing images and class labels.
             ni (int): Number of iterations.
         """
+        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
         plot_images(
-            images=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
-            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            labels=batch,
             fname=self.save_dir / f"train_batch{ni}.jpg",
             on_plot=self.on_plot,
         )
```
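The substantive behavioral changes in this file are the `non_blocking` device transfers in `preprocess_batch` and the new `drop_last=self.args.compile` dataloader argument. Below is a minimal standalone sketch of the transfer pattern; the helper function and dummy batch are illustrative assumptions, not package code:

```python
# Sketch of the device-transfer pattern used by the updated preprocess_batch:
# copies are made non-blocking only when the target device is CUDA, where
# asynchronous host-to-device transfers can overlap with compute.
import torch


def preprocess_batch(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    non_blocking = device.type == "cuda"  # async copies only help on CUDA devices
    batch["img"] = batch["img"].to(device, non_blocking=non_blocking)
    batch["cls"] = batch["cls"].to(device, non_blocking=non_blocking)
    return batch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"img": torch.rand(4, 3, 224, 224), "cls": torch.randint(0, 10, (4,))}
print(preprocess_batch(batch, device)["img"].device)
```

The companion `drop_last=self.args.compile` change drops the final incomplete batch when `compile` is enabled, which keeps batch shapes static across steps so `torch.compile` does not need to retrace on the last iteration.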
ultralytics/models/yolo/classify/val.py:

```diff
@@ -1,24 +1,29 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
 import torch
+import torch.distributed as dist

 from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER
+from ultralytics.utils import LOGGER, RANK
 from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
 from ultralytics.utils.plotting import plot_images


 class ClassificationValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a classification model.
+    """A class extending the BaseValidator class for validation based on a classification model.

-    This validator handles the validation process for classification models, including metrics calculation,
-    confusion matrix generation, and visualization of results.
+    This validator handles the validation process for classification models, including metrics calculation, confusion
+    matrix generation, and visualization of results.

     Attributes:
-        targets (List[torch.Tensor]): Ground truth class labels.
-        pred (List[torch.Tensor]): Model predictions.
+        targets (list[torch.Tensor]): Ground truth class labels.
+        pred (list[torch.Tensor]): Model predictions.
         metrics (ClassifyMetrics): Object to calculate and store classification metrics.
         names (dict): Mapping of class indices to class names.
         nc (int): Number of classes.
@@ -48,17 +53,12 @@ class ClassificationValidator(BaseValidator):
         Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """

-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
-
-        This validator handles the validation process for classification models, including metrics calculation,
-        confusion matrix generation, and visualization of results.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters.

         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
             save_dir (str | Path, optional): Directory to save results.
-            pbar (bool, optional): Display a progress bar.
             args (dict, optional): Arguments containing model and validation configuration.
             _callbacks (list, optional): List of callback functions to be called during validation.

@@ -68,56 +68,48 @@ class ClassificationValidator(BaseValidator):
         >>> validator = ClassificationValidator(args=args)
         >>> validator()
         """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.targets = None
         self.pred = None
         self.args.task = "classify"
         self.metrics = ClassifyMetrics()

-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return a formatted string summarizing classification metrics."""
         return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")

-    def init_metrics(self, model):
+    def init_metrics(self, model: torch.nn.Module) -> None:
         """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
         self.pred = []
         self.targets = []
+        self.confusion_matrix = ConfusionMatrix(names=model.names)

-    def preprocess(self, batch):
+    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
-        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
-        batch["cls"] = batch["cls"].to(self.device)
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
         return batch

-    def update_metrics(self, preds, batch):
-        """
-        Update running metrics with model predictions and batch targets.
+    def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
+        """Update running metrics with model predictions and batch targets.

         Args:
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
             batch (dict): Batch data containing images and class labels.

-        This method appends the top-N predictions (sorted by confidence in descending order) to the
-        prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
+        Notes:
+            This method appends the top-N predictions (sorted by confidence in descending order) to the
+            prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
         """
         n5 = min(len(self.names), 5)
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())

-    def finalize_metrics(self, *args, **kwargs):
-        """
-        Finalize metrics including confusion matrix and processing speed.
-
-        This method processes the accumulated predictions and targets to generate the confusion matrix,
-        optionally plots it, and updates the metrics object with speed information.
-
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
+    def finalize_metrics(self) -> None:
+        """Finalize metrics including confusion matrix and processing speed.

         Examples:
             >>> validator = ClassificationValidator()
@@ -125,33 +117,47 @@ class ClassificationValidator(BaseValidator):
         >>> validator.targets = [torch.tensor([0])]  # Ground truth class
         >>> validator.finalize_metrics()
         >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
         """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
             for normalize in True, False:
-                self.confusion_matrix.plot(
-                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
-                )
+                self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
         self.metrics.speed = self.speed
-        self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
+        self.metrics.confusion_matrix = self.confusion_matrix

-    def postprocess(self, preds):
+    def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
         """Extract the primary prediction from model output if it's in a list or tuple format."""
         return preds[0] if isinstance(preds, (list, tuple)) else preds

-    def get_stats(self):
+    def get_stats(self) -> dict[str, float]:
         """Calculate and return a dictionary of metrics by processing targets and predictions."""
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict

-    def build_dataset(self, img_path):
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_preds = [None] * dist.get_world_size()
+            gathered_targets = [None] * dist.get_world_size()
+            dist.gather_object(self.pred, gathered_preds, dst=0)
+            dist.gather_object(self.targets, gathered_targets, dst=0)
+            self.pred = [pred for rank in gathered_preds for pred in rank]
+            self.targets = [targets for rank in gathered_targets for targets in rank]
+        elif RANK > 0:
+            dist.gather_object(self.pred, None, dst=0)
+            dist.gather_object(self.targets, None, dst=0)
+
+    def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

-    def get_dataloader(self, dataset_path, batch_size):
-        """
-        Build and return a data loader for classification validation.
+    def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
+        """Build and return a data loader for classification validation.

         Args:
             dataset_path (str | Path): Path to the dataset directory.
@@ -163,17 +169,16 @@ class ClassificationValidator(BaseValidator):
         dataset = self.build_dataset(dataset_path)
         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

-    def print_results(self):
+    def print_results(self) -> None:
         """Print evaluation metrics for the classification model."""
         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

-    def plot_val_samples(self, batch, ni):
-        """
-        Plot validation image samples with their ground truth labels.
+    def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
+        """Plot validation image samples with their ground truth labels.

         Args:
-            batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+            batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
             ni (int): Batch index used for naming the output file.

         Examples:
@@ -181,21 +186,19 @@ class ClassificationValidator(BaseValidator):
         >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
         >>> validator.plot_val_samples(batch, 0)
         """
+        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
         plot_images(
-            images=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
-            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            labels=batch,
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )

-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot images with their predicted class labels and save the visualization.
+    def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
+        """Plot images with their predicted class labels and save the visualization.

         Args:
-            batch (dict): Batch data containing images and other information.
+            batch (dict[str, Any]): Batch data containing images and other information.
             preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
             ni (int): Batch index used for naming the output file.

@@ -205,10 +208,14 @@ class ClassificationValidator(BaseValidator):
         >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
         >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
+        batched_preds = dict(
+            img=batch["img"],
+            batch_idx=torch.arange(batch["img"].shape[0]),
             cls=torch.argmax(preds, dim=1),
+            conf=torch.amax(preds, dim=1),
+        )
+        plot_images(
+            batched_preds,
             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
             on_plot=self.on_plot,
```
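The new `gather_stats` method collects every rank's accumulated predictions and targets onto rank 0 via `torch.distributed.gather_object` before metrics are computed. Below is a minimal standalone sketch of that gather-and-flatten pattern, run as a single-process `gloo` group purely for illustration (the setup and variable names are assumptions, not package code):

```python
# Rank-0 object gather: each rank contributes its local list of tensors, rank 0
# receives one entry per rank and flattens them, mirroring gather_stats() above.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # single process for the demo

local_preds = [torch.tensor([1, 2]), torch.tensor([3])]  # this rank's predictions
gathered = [None] * dist.get_world_size()  # one slot per rank, filled on dst only
dist.gather_object(local_preds, gathered, dst=0)

# Flatten the list-of-lists exactly as gather_stats() does before metrics processing
all_preds = [p for rank_preds in gathered for p in rank_preds]
print(all_preds)
dist.destroy_process_group()
```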
ultralytics/models/yolo/detect/predict.py:

```diff
@@ -2,12 +2,11 @@

 from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
-from ultralytics.utils import ops
+from ultralytics.utils import nms, ops


 class DetectionPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a detection model.
+    """A class extending the BasePredictor class for prediction based on a detection model.

     This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
     with bounding boxes and class predictions.
@@ -21,6 +20,7 @@ class DetectionPredictor(BasePredictor):
         postprocess: Process raw model predictions into detection results.
         construct_results: Build Results objects from processed predictions.
         construct_result: Create a single Result object from a prediction.
+        get_obj_feats: Extract object features from the feature maps.

     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -31,8 +31,7 @@ class DetectionPredictor(BasePredictor):
     """

     def postprocess(self, preds, img, orig_imgs, **kwargs):
-        """
-        Post-process predictions and return a list of Results objects.
+        """Post-process predictions and return a list of Results objects.

         This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
         further analysis.
@@ -52,7 +51,7 @@ class DetectionPredictor(BasePredictor):
         >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
         """
         save_feats = getattr(self, "_feats", None) is not None
-        preds = ops.non_max_suppression(
+        preds = nms.non_max_suppression(
             preds,
             self.args.conf,
             self.args.iou,
@@ -84,23 +83,22 @@ class DetectionPredictor(BasePredictor):
         """Extract object features from the feature maps."""
         import torch

-        s = min([x.shape[1] for x in feat_maps])  # find smallest vector length
+        s = min(x.shape[1] for x in feat_maps)  # find shortest vector length
         obj_feats = torch.cat(
             [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
         )  # mean reduce all vectors to same length
-        return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
+        return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch

     def construct_results(self, preds, img, orig_imgs):
-        """
-        Construct a list of Results objects from model predictions.
+        """Construct a list of Results objects from model predictions.

         Args:
-            preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.
+            preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
             img (torch.Tensor): Batch of preprocessed images used for inference.
-            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+            orig_imgs (list[np.ndarray]): List of original images before preprocessing.

         Returns:
-            (List[Results]): List of Results objects containing detection information for each image.
+            (list[Results]): List of Results objects containing detection information for each image.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path)
@@ -108,8 +106,7 @@ class DetectionPredictor(BasePredictor):
         ]

     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct a single Results object from one image prediction.
+        """Construct a single Results object from one image prediction.

         Args:
             pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
```