dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
```diff
@@ -6,6 +6,8 @@ Usage:
     $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -61,8 +63,7 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
 
     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@@ -112,8 +113,7 @@ class BaseTrainer:
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.
 
         Args:
             cfg (str, optional): Path to a configuration file.
@@ -138,7 +138,12 @@ class BaseTrainer:
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period
 
@@ -318,18 +323,18 @@ class BaseTrainer:
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()
 
@@ -464,13 +469,13 @@ class BaseTrainer:
 
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
+                # Validation
+                final_epoch = epoch + 1 >= self.epochs
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                    self.metrics, self.fitness = self.validate()
 
                 # NaN recovery
                 if self._handle_nan_recovery(epoch):
@@ -510,11 +515,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1
 
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -545,7 +550,7 @@ class BaseTrainer:
             total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
@@ -618,8 +623,7 @@ class BaseTrainer:
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
@@ -627,7 +631,7 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
+            elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
                 # Convert NDJSON to YOLO format
                 import asyncio
 
@@ -636,7 +640,7 @@ class BaseTrainer:
                 yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
                 self.args.data = str(yaml_path)
                 data = check_det_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
@@ -654,8 +658,7 @@ class BaseTrainer:
         return data
 
     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.
 
         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -688,14 +691,19 @@ class BaseTrainer:
         return batch
 
     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.
 
         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
@@ -706,11 +714,11 @@ class BaseTrainer:
             raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -718,10 +726,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")
 
     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
-
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -753,9 +760,9 @@ class BaseTrainer:
         n = len(metrics) + 2  # number of cols
         t = time.time() - self.train_time_start
         self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
-        s = "" if self.csv.exists() else (
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n %
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot metrics from a CSV file."""
@@ -768,20 +775,20 @@ class BaseTrainer:
 
     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
 
     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -804,10 +811,29 @@ class BaseTrainer:
                 "batch",
                 "device",
                 "close_mosaic",
+                "augmentations",
+                "save_period",
+                "workers",
+                "cache",
+                "patience",
+                "time",
+                "freeze",
+                "val",
+                "plots",
             ):  # allow arg updates to reduce memory or update device on resume
                 if k in overrides:
                     setattr(self.args, k, overrides[k])
 
+            # Handle augmentations parameter for resume: check if user provided custom augmentations
+            if ckpt_args.get("augmentations") is not None:
+                # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                LOGGER.warning(
+                    "Custom Albumentations transforms were used in the original training run but are not "
+                    "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                    "'augmentations' parameter again to get expected results. Example: \n"
+                    f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                )
+
         except Exception as e:
             raise FileNotFoundError(
                 "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
@@ -887,18 +913,16 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
            decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
```
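The resume warning added in `check_resume` above implies a specific user-side workflow: custom Albumentations transforms are written to `args.yaml` only as `repr()` strings, so they must be passed again when resuming. A minimal usage sketch, assuming albumentations is installed and a prior run saved checkpoints under `runs/detect/train/weights/` (the specific transforms and paths are illustrative, not taken from this diff):

```python
import albumentations as A  # illustrative transforms, not values from this diff
from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/last.pt")
# Custom augmentations are not restored from args.yaml automatically, so pass
# them again when resuming to keep the run consistent with the original.
model.train(resume=True, augmentations=[A.Blur(p=0.01), A.MedianBlur(p=0.01)])
```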
ultralytics/engine/tuner.py
CHANGED
```diff
@@ -8,7 +8,7 @@ that yield the best model performance. This is particularly crucial in deep lear
 where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
 
 Examples:
-    Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=
+    Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
     >>> from ultralytics import YOLO
     >>> model = YOLO("yolo11n.pt")
     >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
@@ -34,12 +34,11 @@ from ultralytics.utils.plotting import plot_tune_results
 
 
 class Tuner:
-    """
-    A class for hyperparameter tuning of YOLO models.
+    """A class for hyperparameter tuning of YOLO models.
 
     The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
-    search space and retraining the model to evaluate their performance. Supports both local CSV storage and
+    search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
+    MongoDB Atlas coordination for multi-machine hyperparameter optimization.
 
     Attributes:
         space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
@@ -56,7 +55,7 @@ class Tuner:
         __call__: Execute the hyperparameter evolution across multiple iterations.
 
     Examples:
-        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=
+        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
         >>> model.tune(
@@ -83,8 +82,7 @@ class Tuner:
     """
 
     def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
-        """
-        Initialize the Tuner with configurations.
+        """Initialize the Tuner with configurations.
 
         Args:
             args (dict): Configuration for hyperparameter evolution.
@@ -142,8 +140,7 @@ class Tuner:
         )
 
     def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
-        """
-        Create MongoDB client with exponential backoff retry on connection failures.
+        """Create MongoDB client with exponential backoff retry on connection failures.
 
         Args:
             uri (str): MongoDB connection string with credentials and cluster information.
@@ -183,12 +180,10 @@ class Tuner:
             time.sleep(wait_time)
 
     def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
-        """
-        Initialize MongoDB connection for distributed tuning.
+        """Initialize MongoDB connection for distributed tuning.
 
-        Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
-        from all workers for evolution.
+        Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
+        saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
 
         Args:
             mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
@@ -206,8 +201,7 @@ class Tuner:
         LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
 
     def _get_mongodb_results(self, n: int = 5) -> list:
-        """
-        Get top N results from MongoDB sorted by fitness.
+        """Get top N results from MongoDB sorted by fitness.
 
         Args:
             n (int): Number of top results to retrieve.
@@ -221,8 +215,7 @@ class Tuner:
             return []
 
     def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
-        """
-        Save results to MongoDB with proper type conversion.
+        """Save results to MongoDB with proper type conversion.
 
         Args:
             fitness (float): Fitness score achieved with these hyperparameters.
@@ -233,7 +226,7 @@ class Tuner:
         try:
             self.collection.insert_one(
                 {
-                    "fitness":
+                    "fitness": fitness,
                     "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
                     "metrics": metrics,
                     "timestamp": datetime.now(),
@@ -244,8 +237,7 @@ class Tuner:
             LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
 
     def _sync_mongodb_to_csv(self):
-        """
-        Sync MongoDB results to CSV for plotting compatibility.
+        """Sync MongoDB results to CSV for plotting compatibility.
 
         Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
         the existing plotting functions to work seamlessly with distributed MongoDB data.
@@ -257,19 +249,20 @@ class Tuner:
                 return
 
             # Write to CSV
-            headers = ",".join(["fitness"
+            headers = ",".join(["fitness", *list(self.space.keys())]) + "\n"
             with open(self.tune_csv, "w", encoding="utf-8") as f:
                 f.write(headers)
                 for result in all_results:
                     fitness = result["fitness"]
                     hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
-                    log_row = [round(fitness, 5)
+                    log_row = [round(fitness, 5), *hyp_values]
                     f.write(",".join(map(str, log_row)) + "\n")
 
         except Exception as e:
             LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
 
+    @staticmethod
+    def _crossover(x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
         """BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
         k = min(k, len(x))
         # fitness weights (shifted to >0); fallback to uniform if degenerate
@@ -288,11 +281,9 @@ class Tuner:
         mutation: float = 0.5,
         sigma: float = 0.2,
     ) -> dict[str, float]:
-        """
-        Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
+        """Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
 
         Args:
-            parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
             n (int): Number of top parents to consider.
             mutation (float): Probability of a parameter mutation in any given iteration.
             sigma (float): Standard deviation for Gaussian random number generator.
@@ -304,8 +295,7 @@ class Tuner:
 
         # Try MongoDB first if available
         if self.mongodb:
-            results
-            if results:
+            if results := self._get_mongodb_results(n):
                 # MongoDB already sorted by fitness DESC, so results[0] is best
                 x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
             elif self.collection.name in self.collection.database.list_collection_names():  # Tuner started elsewhere
@@ -344,13 +334,12 @@ class Tuner:
 
         # Update types
         if "close_mosaic" in hyp:
-            hyp["close_mosaic"] =
+            hyp["close_mosaic"] = round(hyp["close_mosaic"])
 
         return hyp
 
     def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
-        """
-        Execute the hyperparameter evolution process when the Tuner instance is called.
+        """Execute the hyperparameter evolution process when the Tuner instance is called.
 
         This method iterates through the specified number of iterations, performing the following steps:
         1. Sync MongoDB results to CSV (if using distributed mode)
@@ -421,7 +410,7 @@ class Tuner:
             else:
                 # Save to CSV only if no MongoDB
                 log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-                headers = "" if self.tune_csv.exists() else (",".join(["fitness"
+                headers = "" if self.tune_csv.exists() else (",".join(["fitness", *list(self.space.keys())]) + "\n")
                 with open(self.tune_csv, "a", encoding="utf-8") as f:
                     f.write(headers + ",".join(map(str, log_row)) + "\n")
 
```
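The new `Tuner._crossover` docstring above names BLX-α crossover over up to the top-k fitness-weighted parents. As a rough illustration of the operator itself (not the library's exact implementation), a standard two-parent BLX-α step samples each child gene uniformly from the parents' interval widened by α on both sides:

```python
import numpy as np

def blx_alpha(p1: np.ndarray, p2: np.ndarray, alpha: float = 0.2, rng=None) -> np.ndarray:
    """Sample a child uniformly from each gene's parent interval widened by alpha on both sides."""
    rng = rng or np.random.default_rng()
    lo, hi = np.minimum(p1, p2), np.maximum(p1, p2)
    span = hi - lo
    return rng.uniform(lo - alpha * span, hi + alpha * span)

# Example: two parent hyperparameter vectors (lr0, momentum, weight_decay)
child = blx_alpha(np.array([0.01, 0.90, 0.0005]), np.array([0.02, 0.95, 0.0010]))
```

The variant in the diff additionally weights up to k parents by fitness (shifted to be positive, with a uniform fallback when fitnesses are degenerate) before sampling, per the comment in `_crossover`.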
ultralytics/engine/validator.py
CHANGED
```diff
@@ -29,26 +29,26 @@ from pathlib import Path
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import AutoBackend
-from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
 from ultralytics.utils.ops import Profile
 from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
 
 
 class BaseValidator:
-    """
-    A base class for creating validators.
+    """A base class for creating validators.
 
     This class provides the foundation for validation processes, including model evaluation, metric computation, and
     result visualization.
 
     Attributes:
         args (SimpleNamespace): Configuration for the validator.
-        dataloader (DataLoader):
+        dataloader (DataLoader): DataLoader to use for validation.
         model (nn.Module): Model to validate.
         data (dict): Data dictionary containing dataset information.
         device (torch.device): Device to use for validation.
@@ -61,8 +61,8 @@ class BaseValidator:
         nc (int): Number of classes.
         iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
         jdict (list): List to store JSON validation results.
-        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
+            processing times in milliseconds.
         save_dir (Path): Directory to save results.
         plots (dict): Dictionary to store plots for visualization.
         callbacks (dict): Dictionary to store various callback functions.
@@ -92,11 +92,10 @@ class BaseValidator:
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
-        """
-        Initialize a BaseValidator instance.
+        """Initialize a BaseValidator instance.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
             save_dir (Path, optional): Directory to save results.
             args (SimpleNamespace, optional): Configuration for the validator.
             _callbacks (dict, optional): Dictionary to store various callback functions.
@@ -130,8 +129,7 @@ class BaseValidator:
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
-        """
-        Execute validation process, running inference on dataloader and computing performance metrics.
+        """Execute validation process, running inference on dataloader and computing performance metrics.
 
         Args:
             trainer (object, optional): Trainer object that contains the model to validate.
@@ -160,7 +158,7 @@ class BaseValidator:
             callbacks.add_integration_callbacks(self)
             model = AutoBackend(
                 model=model or self.args.model,
-                device=select_device(self.args.device),
+                device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
                 dnn=self.args.dnn,
                 data=self.args.data,
                 fp16=self.args.half,
@@ -223,21 +221,34 @@ class BaseValidator:
             preds = self.postprocess(preds)
 
             self.update_metrics(preds, batch)
-            if self.args.plots and batch_i < 3:
+            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                 self.plot_val_samples(batch, batch_i)
                 self.plot_predictions(batch, preds, batch_i)
 
             self.run_callbacks("on_val_batch_end")
+
+        stats = {}
+        self.gather_stats()
+        if RANK in {-1, 0}:
+            stats = self.get_stats()
+            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+            self.finalize_metrics()
+            self.print_results()
+            self.run_callbacks("on_val_end")
+
         if self.training:
             model.float()
+            # Reduce loss across all GPUs
+            loss = self.loss.clone().detach()
+            if trainer.world_size > 1:
+                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+            if RANK > 0:
+                return
+            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
+            if RANK > 0:
+                return stats
             LOGGER.info(
                 "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                     *tuple(self.speed.values())
@@ -255,8 +266,7 @@ class BaseValidator:
     def match_predictions(
         self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
     ) -> torch.Tensor:
-        """
-        Match predictions to ground truth objects using IoU.
+        """Match predictions to ground truth objects using IoU.
 
         Args:
             pred_classes (torch.Tensor): Predicted class indices of shape (N,).
@@ -336,6 +346,10 @@ class BaseValidator:
         """Return statistics about the model's performance."""
         return {}
 
+    def gather_stats(self):
+        """Gather statistics from all the GPUs during DDP training to GPU 0."""
+        pass
+
     def print_results(self):
         """Print the results of the model's predictions."""
         pass
@@ -350,7 +364,10 @@ class BaseValidator:
         return []
 
     def on_plot(self, name, data=None):
-        """Register plots for visualization."""
+        """Register plots for visualization, deduplicating by type."""
+        plot_type = data.get("type") if data else None
+        if plot_type and any((v.get("data") or {}).get("type") == plot_type for v in self.plots.values()):
+            return  # Skip duplicate plot types
         self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
     def plot_val_samples(self, batch, ni):
```