dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/callbacks/hub.py (+8 -6)

@@ -3,8 +3,9 @@
 import json
 from time import time

-from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
+from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
 from ultralytics.utils import LOGGER, RANK, SETTINGS
+from ultralytics.utils.events import events


 def on_pretrain_routine_start(trainer):
@@ -73,22 +74,23 @@ def on_train_end(trainer):

 def on_train_start(trainer):
     """Run events on train start."""
-    events(trainer.args)
+    events(trainer.args, trainer.device)


 def on_val_start(validator):
     """Run events on validation start."""
-    events(validator.args)
+    if not validator.training:
+        events(validator.args, validator.device)


 def on_predict_start(predictor):
     """Run events on predict start."""
-    events(predictor.args)
+    events(predictor.args, predictor.device)


 def on_export_start(exporter):
     """Run events on export start."""
-    events(exporter.args)
+    events(exporter.args, exporter.device)


 callbacks = (
@@ -105,4 +107,4 @@ callbacks = (
     }
     if SETTINGS["hub"] is True
     else {}
-)
+)
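The substantive change in hub.py is that every analytics call now passes the runtime device alongside the run configuration, the `events` helper moves to the new `ultralytics/utils/events.py` module, and validation-time events fire only for standalone validation runs. A minimal sketch of the new call shape, using a toy `events` stand-in rather than the real `ultralytics.utils.events` implementation:

```python
from types import SimpleNamespace


def events(args, device=None):
    """Toy stand-in for ultralytics.utils.events.events: the hook now receives the
    run config plus the runtime device instead of the config alone."""
    payload = {"task": getattr(args, "task", None), "device": str(device)}
    print(f"queue analytics event: {payload}")


def on_predict_start(predictor):
    """Same shape as the patched callback above."""
    events(predictor.args, predictor.device)


# Exercise the callback with a stand-in predictor object.
predictor = SimpleNamespace(args=SimpleNamespace(task="detect"), device="cpu")
on_predict_start(predictor)
```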
ultralytics/utils/callbacks/mlflow.py (+6 -10)

@@ -45,24 +45,20 @@ def sanitize_dict(x: dict) -> dict:


 def on_pretrain_routine_end(trainer):
-    """
-    Log training parameters to MLflow at the end of the pretraining routine.
+    """Log training parameters to MLflow at the end of the pretraining routine.

     This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
-    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
-
+    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters from
+    the trainer.

     Args:
         trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

-
-        mlflow: The imported mlflow module to use for logging.
-
-    Environment Variables:
+    Notes:
         MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
         MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
         MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
-        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after
+        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
     """
     global mlflow

@@ -107,7 +103,7 @@ def on_fit_epoch_end(trainer):


 def on_train_end(trainer):
-    """Log model artifacts at the end of
+    """Log model artifacts at the end of training."""
     if not mlflow:
         return
     mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
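The mlflow.py hunks only consolidate the docstring; behavior is unchanged. As a hedged illustration of how the environment variables named in that docstring are conventionally resolved (defaults follow the docstring wording, with hypothetical stand-ins for the `trainer.args` fallbacks):

```python
import os

# Defaults as described in the docstring; the real callback falls back to
# trainer.args.project and trainer.args.name, which are stubbed here.
tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "runs/mlflow")
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or "my-project"  # fallback: trainer.args.project
run_name = os.environ.get("MLFLOW_RUN") or "my-run"  # fallback: trainer.args.name
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "false").lower() == "true"
print(tracking_uri, experiment_name, run_name, keep_run_active)
```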
ultralytics/utils/callbacks/neptune.py (+11 -19)

@@ -18,12 +18,11 @@ except (ImportError, AssertionError):


 def _log_scalars(scalars: dict, step: int = 0) -> None:
-    """
-    Log scalars to the NeptuneAI experiment logger.
+    """Log scalars to the NeptuneAI experiment logger.

     Args:
         scalars (dict): Dictionary of scalar values to log to NeptuneAI.
-        step (int): The current step or iteration number for logging.
+        step (int, optional): The current step or iteration number for logging.

     Examples:
         >>> metrics = {"mAP": 0.85, "loss": 0.32}
@@ -35,11 +34,10 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:


 def _log_images(imgs_dict: dict, group: str = "") -> None:
-    """
-    Log images to the NeptuneAI experiment logger.
+    """Log images to the NeptuneAI experiment logger.

-    This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized
-
+    This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized under the
+    specified group name.

     Args:
         imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.
@@ -55,13 +53,7 @@ def _log_images(imgs_dict: dict, group: str = "") -> None:


 def _log_plot(title: str, plot_path: str) -> None:
-    """
-    Log plots to the NeptuneAI experiment logger.
-
-    Args:
-        title (str): Title of the plot.
-        plot_path (str): Path to the saved image file.
-    """
+    """Log plots to the NeptuneAI experiment logger."""
     import matplotlib.image as mpimg
     import matplotlib.pyplot as plt

@@ -73,7 +65,7 @@ def _log_plot(title: str, plot_path: str) -> None:


 def on_pretrain_routine_start(trainer) -> None:
-    """
+    """Initialize NeptuneAI run and log hyperparameters before training starts."""
     try:
         global run
         run = neptune.init_run(
@@ -87,7 +79,7 @@ def on_pretrain_routine_start(trainer) -> None:


 def on_train_epoch_end(trainer) -> None:
-    """
+    """Log training metrics and learning rate at the end of each training epoch."""
     _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)
     if trainer.epoch == 1:
@@ -95,7 +87,7 @@ def on_train_epoch_end(trainer) -> None:


 def on_fit_epoch_end(trainer) -> None:
-    """
+    """Log model info and validation metrics at the end of each fit epoch."""
     if run and trainer.epoch == 0:
         from ultralytics.utils.torch_utils import model_info_for_loggers

@@ -104,14 +96,14 @@ def on_fit_epoch_end(trainer) -> None:


 def on_val_end(validator) -> None:
-    """
+    """Log validation images at the end of validation."""
     if run:
         # Log val_labels and val_pred
         _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")


 def on_train_end(trainer) -> None:
-    """
+    """Log final results, plots, and model weights at the end of training."""
     if run:
         # Log final results, CM matrix + PR plots
         files = [
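The neptune.py hunks likewise only compress docstrings. For context on the `_log_scalars` example retained in the docstring, a hedged sketch of the neptune 1.x series-logging pattern the helper relies on; `run` is left unset here so the snippet executes without the neptune client installed:

```python
# "run" stands in for a neptune.Run handle, e.g. neptune.init_run(project="..."),
# and run[key].append(...) is the neptune 1.x way to log a series value.
metrics = {"mAP": 0.85, "loss": 0.32}
run = None  # no active Neptune run in this sketch
if run:
    for k, v in metrics.items():
        run[k].append(value=v, step=100)
else:
    print("no active Neptune run; nothing logged")
```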
ultralytics/utils/callbacks/platform.py (new file, +73 -0)

@@ -0,0 +1,73 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import RANK, SETTINGS
+
+
+def on_pretrain_routine_start(trainer):
+    """Initialize and start console logging immediately at the very beginning."""
+    if RANK in {-1, 0}:
+        from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
+
+        trainer.system_logger = SystemLogger()
+        trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
+        trainer.console_logger.start_capture()
+
+
+def on_pretrain_routine_end(trainer):
+    """Handle pre-training routine completion event."""
+    pass
+
+
+def on_fit_epoch_end(trainer):
+    """Handle end of training epoch event and collect system metrics."""
+    if RANK in {-1, 0} and hasattr(trainer, "system_logger"):
+        system_metrics = trainer.system_logger.get_metrics()
+        print(system_metrics)  # for debug
+
+
+def on_model_save(trainer):
+    """Handle model checkpoint save event."""
+    pass
+
+
+def on_train_end(trainer):
+    """Stop console capture and finalize logs."""
+    if logger := getattr(trainer, "console_logger", None):
+        logger.stop_capture()
+
+
+def on_train_start(trainer):
+    """Handle training start event."""
+    pass
+
+
+def on_val_start(validator):
+    """Handle validation start event."""
+    pass
+
+
+def on_predict_start(predictor):
+    """Handle prediction start event."""
+    pass
+
+
+def on_export_start(exporter):
+    """Handle model export start event."""
+    pass
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_model_save": on_model_save,
+        "on_train_end": on_train_end,
+        "on_train_start": on_train_start,
+        "on_val_start": on_val_start,
+        "on_predict_start": on_predict_start,
+        "on_export_start": on_export_start,
+    }
+    if SETTINGS.get("platform", False) is True  # disabled for debugging
+    else {}
+)
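The new platform.py callback depends on `ConsoleLogger`, `SystemLogger`, and `DEFAULT_LOG_PATH` from the also-new `ultralytics/utils/logger.py` (+404 lines), which this section does not show. The sketch below is an illustrative stand-in for the console-capture idea only (tee stdout into a file between `start_capture()` and `stop_capture()`), not the real `ConsoleLogger`:

```python
import sys
from pathlib import Path


class TeeConsoleLogger:
    """Duplicate stdout to a log file between start_capture() and stop_capture()."""

    def __init__(self, log_path):
        self.log_path = Path(log_path)
        self._stdout = None
        self._file = None

    def start_capture(self):
        self._stdout = sys.stdout
        self._file = self.log_path.open("a", encoding="utf-8")
        sys.stdout = self  # route subsequent prints through write()

    def write(self, text):
        self._stdout.write(text)
        self._file.write(text)

    def flush(self):
        self._stdout.flush()
        self._file.flush()

    def stop_capture(self):
        sys.stdout = self._stdout
        self._file.close()


logger = TeeConsoleLogger("train_console.log")
logger.start_capture()
print("captured to both console and file")
logger.stop_capture()
```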
ultralytics/utils/callbacks/raytune.py (+3 -4)

@@ -13,11 +13,10 @@ except (ImportError, AssertionError):


 def on_fit_epoch_end(trainer):
-    """
-    Reports training metrics to Ray Tune at epoch end when a Ray session is active.
+    """Report training metrics to Ray Tune at epoch end when a Ray session is active.

-    Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
-
+    Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number, enabling
+    hyperparameter tuning optimization. Only executes when within an active Ray Tune session.

     Args:
         trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.
ultralytics/utils/callbacks/tensorboard.py (+9 -12)

@@ -22,8 +22,7 @@ except (ImportError, AssertionError, TypeError, AttributeError):


 def _log_scalars(scalars: dict, step: int = 0) -> None:
-    """
-    Log scalar values to TensorBoard.
+    """Log scalar values to TensorBoard.

     Args:
         scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
@@ -31,7 +30,7 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
         step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.

     Examples:
-
+        Log training metrics
         >>> metrics = {"loss": 0.5, "accuracy": 0.95}
         >>> _log_scalars(metrics, step=100)
     """
@@ -41,17 +40,15 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:


 def _log_tensorboard_graph(trainer) -> None:
-    """
-    Log model graph to TensorBoard.
+    """Log model graph to TensorBoard.

     This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
     tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
     approach for models like RTDETR that may require special handling.

     Args:
-        trainer (BaseTrainer): The trainer object containing the model to visualize. Must
-
-            - args: Configuration arguments with 'imgsz' attribute
+        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize. Must
+            have attributes model and args with imgsz.

     Notes:
         This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
@@ -71,14 +68,14 @@ def _log_tensorboard_graph(trainer) -> None:
     # Try simple method first (YOLO)
     try:
         trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
-        WRITER.add_graph(torch.jit.trace(torch_utils.
+        WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
         LOGGER.info(f"{PREFIX}model graph visualization added ✅")
         return

     except Exception:
         # Fallback to TorchScript export steps (RTDETR)
         try:
-            model = deepcopy(torch_utils.
+            model = deepcopy(torch_utils.unwrap_model(trainer.model))
             model.eval()
             model = model.fuse(verbose=False)
             for m in model.modules():
@@ -110,13 +107,13 @@ def on_train_start(trainer) -> None:


 def on_train_epoch_end(trainer) -> None:
-    """
+    """Log scalar statistics at the end of a training epoch."""
     _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)


 def on_fit_epoch_end(trainer) -> None:
-    """
+    """Log epoch metrics at end of training epoch."""
     _log_scalars(trainer.metrics, trainer.epoch + 1)

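Besides docstring cleanup, the tensorboard.py graph hunks now trace `torch_utils.unwrap_model(trainer.model)` (presumably stripping DDP/compile wrappers) before handing the result to TensorBoard. A hedged, standalone illustration of the `add_graph(torch.jit.trace(...))` call shape, with a toy model standing in for `trainer.model`:

```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Toy model in place of trainer.model; it is already bare, so no unwrap_model()
# equivalent is needed in this sketch.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2)
)
model.eval()  # avoid BatchNorm statistics changes, as in the callback
im = torch.zeros(1, 3, 64, 64)  # dummy input shaped like (1, 3, imgsz, imgsz)

writer = SummaryWriter("runs/graph_demo")
writer.add_graph(torch.jit.trace(model, im, strict=False), [])  # same call shape as the patched line
writer.close()
```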
ultralytics/utils/callbacks/wb.py (+33 -30)

@@ -16,8 +16,7 @@ except (ImportError, AssertionError):


 def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
-    """
-    Create and log a custom metric visualization to wandb.plot.pr_curve.
+    """Create and log a custom metric visualization to wandb.plot.pr_curve.

     This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
     curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
@@ -27,20 +26,26 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
         x (list): Values for the x-axis; expected to have length N.
         y (list): Corresponding values for the y-axis; also expected to have length N.
         classes (list): Labels identifying the class of each point; length N.
-        title (str): Title for the plot
-        x_title (str): Label for the x-axis
-        y_title (str): Label for the y-axis
+        title (str, optional): Title for the plot.
+        x_title (str, optional): Label for the x-axis.
+        y_title (str, optional): Label for the y-axis.

     Returns:
         (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
     """
-    import
+    import polars as pl  # scope for faster 'import ultralytics'
+    import polars.selectors as cs
+
+    df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
+    data = df.select(["class", "y", "x"]).rows()

-    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
     fields = {"x": "x", "y": "y", "class": "class"}
     string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
     return wb.plot_table(
-        "wandb/area-under-curve/v0",
+        "wandb/area-under-curve/v0",
+        wb.Table(data=data, columns=["class", "y", "x"]),
+        fields=fields,
+        string_fields=string_fields,
     )


@@ -55,22 +60,21 @@ def _plot_curve(
     num_x=100,
     only_mean=False,
 ):
-    """
-    Log a metric curve visualization.
+    """Log a metric curve visualization.

-    This function generates a metric curve based on input data and logs the visualization to wandb.
-
+    This function generates a metric curve based on input data and logs the visualization to wandb. The curve can
+    represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

     Args:
         x (np.ndarray): Data points for the x-axis with length N.
         y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
-        names (list): Names of the classes corresponding to the y-axis data; length C.
-        id (str): Unique identifier for the logged data in wandb.
-        title (str): Title for the visualization plot.
-        x_title (str): Label for the x-axis.
-        y_title (str): Label for the y-axis.
-        num_x (int): Number of interpolated data points for visualization.
-        only_mean (bool): Flag to indicate if only the mean curve should be plotted.
+        names (list, optional): Names of the classes corresponding to the y-axis data; length C.
+        id (str, optional): Unique identifier for the logged data in wandb.
+        title (str, optional): Title for the visualization plot.
+        x_title (str, optional): Label for the x-axis.
+        y_title (str, optional): Label for the y-axis.
+        num_x (int, optional): Number of interpolated data points for visualization.
+        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.

     Notes:
         The function leverages the '_custom_table' function to generate the actual visualization.
@@ -99,21 +103,20 @@ def _plot_curve(


 def _log_plots(plots, step):
-    """
-    Log plots to WandB at a specific step if they haven't been logged already.
+    """Log plots to WandB at a specific step if they haven't been logged already.

-    This function checks each plot in the input dictionary against previously processed plots and logs
-
+    This function checks each plot in the input dictionary against previously processed plots and logs new or updated
+    plots to WandB at the specified step.

     Args:
-        plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
-
+        plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries containing plot
+            metadata including timestamps.
         step (int): The step/epoch at which to log the plots in the WandB run.

     Notes:
-
-
-
+        The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
+        Plots are identified by their stem name (filename without extension).
+        Each plot is logged as a WandB Image object.
     """
     for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
         timestamp = params["timestamp"]
@@ -123,7 +126,7 @@ def _log_plots(plots, step):


 def on_pretrain_routine_start(trainer):
-    """
+    """Initialize and start wandb project if module is present."""
     if not wb.run:
         wb.init(
             project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
@@ -134,11 +137,11 @@ def on_pretrain_routine_start(trainer):

 def on_fit_epoch_end(trainer):
     """Log training metrics and model information at the end of an epoch."""
-    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
     _log_plots(trainer.plots, step=trainer.epoch + 1)
     _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
     if trainer.epoch == 0:
         wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
+    wb.run.log(trainer.metrics, step=trainer.epoch + 1, commit=True)  # commit forces sync


 def on_train_epoch_end(trainer):
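The behavioral changes in wb.py are the switch from pandas to polars when building the precision-recall table and the `commit=True` flag on the final metrics log. A standalone illustration of the new polars table preparation, with invented sample values:

```python
import polars as pl
import polars.selectors as cs

# Sample values are made up for the demo; x/y would be recall/precision points.
classes = ["person", "person", "car"]
x = [0.10, 0.50, 0.90]
y = [0.95, 0.80, 0.6666]

# Round every numeric column to 3 decimals, then pull rows as tuples for wb.Table.
df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
data = df.select(["class", "y", "x"]).rows()
print(data)  # [('person', 0.95, 0.1), ('person', 0.8, 0.5), ('car', 0.667, 0.9)]
```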