dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
|
@@ -175,13 +175,12 @@ default_callbacks = {
|
|
|
175
175
|
|
|
176
176
|
|
|
177
177
|
def get_default_callbacks():
|
|
178
|
-
"""
|
|
179
|
-
Get the default callbacks for Ultralytics training, validation, prediction, and export processes.
|
|
178
|
+
"""Get the default callbacks for Ultralytics training, validation, prediction, and export processes.
|
|
180
179
|
|
|
181
180
|
Returns:
|
|
182
|
-
(dict): Dictionary of default callbacks for various training events. Each key
|
|
183
|
-
|
|
184
|
-
|
|
181
|
+
(dict): Dictionary of default callbacks for various training events. Each key represents an event during the
|
|
182
|
+
training process, and the corresponding value is a list of callback functions executed when that
|
|
183
|
+
event occurs.
|
|
185
184
|
|
|
186
185
|
Examples:
|
|
187
186
|
>>> callbacks = get_default_callbacks()
|
|
@@ -192,27 +191,27 @@ def get_default_callbacks():
|
|
|
192
191
|
|
|
193
192
|
|
|
194
193
|
def add_integration_callbacks(instance):
|
|
195
|
-
"""
|
|
196
|
-
Add integration callbacks to the instance's callbacks dictionary.
|
|
194
|
+
"""Add integration callbacks to the instance's callbacks dictionary.
|
|
197
195
|
|
|
198
196
|
This function loads and adds various integration callbacks to the provided instance. The specific callbacks added
|
|
199
197
|
depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive
|
|
200
|
-
additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,
|
|
201
|
-
|
|
198
|
+
additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard, and
|
|
199
|
+
Weights & Biases.
|
|
202
200
|
|
|
203
201
|
Args:
|
|
204
|
-
instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.
|
|
205
|
-
|
|
202
|
+
instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added. The
|
|
203
|
+
type of instance determines which callbacks are loaded.
|
|
206
204
|
|
|
207
205
|
Examples:
|
|
208
206
|
>>> from ultralytics.engine.trainer import BaseTrainer
|
|
209
207
|
>>> trainer = BaseTrainer()
|
|
210
208
|
>>> add_integration_callbacks(trainer)
|
|
211
209
|
"""
|
|
212
|
-
# Load HUB callbacks
|
|
213
210
|
from .hub import callbacks as hub_cb
|
|
211
|
+
from .platform import callbacks as platform_cb
|
|
214
212
|
|
|
215
|
-
|
|
213
|
+
# Load Ultralytics callbacks
|
|
214
|
+
callbacks_list = [hub_cb, platform_cb]
|
|
216
215
|
|
|
217
216
|
# Load training callbacks
|
|
218
217
|
if "Trainer" in instance.__class__.__name__:
|
|
@@ -15,11 +15,10 @@ except (ImportError, AssertionError):
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def _log_debug_samples(files, title: str = "Debug Samples") -> None:
|
|
18
|
-
"""
|
|
19
|
-
Log files (images) as debug samples in the ClearML task.
|
|
18
|
+
"""Log files (images) as debug samples in the ClearML task.
|
|
20
19
|
|
|
21
20
|
Args:
|
|
22
|
-
files (
|
|
21
|
+
files (list[Path]): A list of file paths in PosixPath format.
|
|
23
22
|
title (str): A title that groups together images with the same values.
|
|
24
23
|
"""
|
|
25
24
|
import re
|
|
@@ -35,8 +34,7 @@ def _log_debug_samples(files, title: str = "Debug Samples") -> None:
|
|
|
35
34
|
|
|
36
35
|
|
|
37
36
|
def _log_plot(title: str, plot_path: str) -> None:
|
|
38
|
-
"""
|
|
39
|
-
Log an image as a plot in the plot section of ClearML.
|
|
37
|
+
"""Log an image as a plot in the plot section of ClearML.
|
|
40
38
|
|
|
41
39
|
Args:
|
|
42
40
|
title (str): The title of the plot.
|
|
@@ -56,7 +54,7 @@ def _log_plot(title: str, plot_path: str) -> None:
|
|
|
56
54
|
|
|
57
55
|
|
|
58
56
|
def on_pretrain_routine_start(trainer) -> None:
|
|
59
|
-
"""
|
|
57
|
+
"""Initialize and connect ClearML task at the start of pretraining routine."""
|
|
60
58
|
try:
|
|
61
59
|
if task := Task.current_task():
|
|
62
60
|
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
|
@@ -85,9 +83,9 @@ def on_pretrain_routine_start(trainer) -> None:
|
|
|
85
83
|
|
|
86
84
|
|
|
87
85
|
def on_train_epoch_end(trainer) -> None:
|
|
88
|
-
"""
|
|
86
|
+
"""Log debug samples for the first epoch and report current training progress."""
|
|
89
87
|
if task := Task.current_task():
|
|
90
|
-
# Log debug samples
|
|
88
|
+
# Log debug samples for first epoch only
|
|
91
89
|
if trainer.epoch == 1:
|
|
92
90
|
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
|
|
93
91
|
# Report the current training progress
|
|
@@ -98,14 +96,15 @@ def on_train_epoch_end(trainer) -> None:
|
|
|
98
96
|
|
|
99
97
|
|
|
100
98
|
def on_fit_epoch_end(trainer) -> None:
|
|
101
|
-
"""
|
|
99
|
+
"""Report model information and metrics to logger at the end of an epoch."""
|
|
102
100
|
if task := Task.current_task():
|
|
103
101
|
# Report epoch time and validation metrics
|
|
104
102
|
task.get_logger().report_scalar(
|
|
105
103
|
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
|
|
106
104
|
)
|
|
107
105
|
for k, v in trainer.metrics.items():
|
|
108
|
-
|
|
106
|
+
title = k.split("/")[0]
|
|
107
|
+
task.get_logger().report_scalar(title, k, v, iteration=trainer.epoch)
|
|
109
108
|
if trainer.epoch == 0:
|
|
110
109
|
from ultralytics.utils.torch_utils import model_info_for_loggers
|
|
111
110
|
|
|
@@ -114,23 +113,23 @@ def on_fit_epoch_end(trainer) -> None:
|
|
|
114
113
|
|
|
115
114
|
|
|
116
115
|
def on_val_end(validator) -> None:
|
|
117
|
-
"""
|
|
116
|
+
"""Log validation results including labels and predictions."""
|
|
118
117
|
if Task.current_task():
|
|
119
|
-
# Log
|
|
118
|
+
# Log validation labels and predictions
|
|
120
119
|
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
|
|
121
120
|
|
|
122
121
|
|
|
123
122
|
def on_train_end(trainer) -> None:
|
|
124
|
-
"""
|
|
123
|
+
"""Log final model and training results on training completion."""
|
|
125
124
|
if task := Task.current_task():
|
|
126
|
-
# Log final results,
|
|
125
|
+
# Log final results, confusion matrix and PR plots
|
|
127
126
|
files = [
|
|
128
127
|
"results.png",
|
|
129
128
|
"confusion_matrix.png",
|
|
130
129
|
"confusion_matrix_normalized.png",
|
|
131
130
|
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
|
|
132
131
|
]
|
|
133
|
-
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
|
132
|
+
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files
|
|
134
133
|
for f in files:
|
|
135
134
|
_log_plot(title=f.stem, plot_path=f)
|
|
136
135
|
# Report final metrics
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
from collections.abc import Callable
|
|
4
6
|
from types import SimpleNamespace
|
|
5
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
6
8
|
|
|
7
9
|
import cv2
|
|
8
10
|
import numpy as np
|
|
@@ -26,9 +28,12 @@ try:
|
|
|
26
28
|
# Names of plots created by Ultralytics that are logged to Comet
|
|
27
29
|
CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
|
|
28
30
|
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
|
|
29
|
-
LABEL_PLOT_NAMES = "labels"
|
|
31
|
+
LABEL_PLOT_NAMES = ["labels"]
|
|
30
32
|
SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
|
|
31
33
|
POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
|
|
34
|
+
DETECTION_METRICS_PLOT_PREFIX = ["Box"]
|
|
35
|
+
RESULTS_TABLE_NAME = "results.csv"
|
|
36
|
+
ARGS_YAML_NAME = "args.yaml"
|
|
32
37
|
|
|
33
38
|
_comet_image_prediction_count = 0
|
|
34
39
|
|
|
@@ -37,7 +42,7 @@ except (ImportError, AssertionError):
|
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
def _get_comet_mode() -> str:
|
|
40
|
-
"""
|
|
45
|
+
"""Return the Comet mode from environment variables, defaulting to 'online'."""
|
|
41
46
|
comet_mode = os.getenv("COMET_MODE")
|
|
42
47
|
if comet_mode is not None:
|
|
43
48
|
LOGGER.warning(
|
|
@@ -52,7 +57,7 @@ def _get_comet_mode() -> str:
|
|
|
52
57
|
|
|
53
58
|
|
|
54
59
|
def _get_comet_model_name() -> str:
|
|
55
|
-
"""
|
|
60
|
+
"""Return the Comet model name from environment variable or default to 'Ultralytics'."""
|
|
56
61
|
return os.getenv("COMET_MODEL_NAME", "Ultralytics")
|
|
57
62
|
|
|
58
63
|
|
|
@@ -62,31 +67,33 @@ def _get_eval_batch_logging_interval() -> int:
|
|
|
62
67
|
|
|
63
68
|
|
|
64
69
|
def _get_max_image_predictions_to_log() -> int:
|
|
65
|
-
"""Get the maximum number of image predictions to log from
|
|
70
|
+
"""Get the maximum number of image predictions to log from environment variables."""
|
|
66
71
|
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
|
|
67
72
|
|
|
68
73
|
|
|
69
74
|
def _scale_confidence_score(score: float) -> float:
|
|
70
|
-
"""
|
|
75
|
+
"""Scale the confidence score by a factor specified in environment variable."""
|
|
71
76
|
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
|
|
72
77
|
return score * scale
|
|
73
78
|
|
|
74
79
|
|
|
75
80
|
def _should_log_confusion_matrix() -> bool:
|
|
76
|
-
"""
|
|
81
|
+
"""Determine if the confusion matrix should be logged based on environment variable settings."""
|
|
77
82
|
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
|
|
78
83
|
|
|
79
84
|
|
|
80
85
|
def _should_log_image_predictions() -> bool:
|
|
81
|
-
"""
|
|
86
|
+
"""Determine whether to log image predictions based on environment variable."""
|
|
82
87
|
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
|
|
83
88
|
|
|
84
89
|
|
|
85
90
|
def _resume_or_create_experiment(args: SimpleNamespace) -> None:
|
|
86
|
-
"""
|
|
87
|
-
Resumes CometML experiment or creates a new experiment based on args.
|
|
91
|
+
"""Resume CometML experiment or create a new experiment based on args.
|
|
88
92
|
|
|
89
93
|
Ensures that the experiment object is only created in a single process during distributed training.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
args (SimpleNamespace): Training arguments containing project configuration and other parameters.
|
|
90
97
|
"""
|
|
91
98
|
if RANK not in {-1, 0}:
|
|
92
99
|
return
|
|
@@ -116,7 +123,14 @@ def _resume_or_create_experiment(args: SimpleNamespace) -> None:
|
|
|
116
123
|
|
|
117
124
|
|
|
118
125
|
def _fetch_trainer_metadata(trainer) -> dict:
|
|
119
|
-
"""
|
|
126
|
+
"""Return metadata for YOLO training including epoch and asset saving status.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
(dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.
|
|
133
|
+
"""
|
|
120
134
|
curr_epoch = trainer.epoch + 1
|
|
121
135
|
|
|
122
136
|
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
|
|
@@ -133,11 +147,20 @@ def _fetch_trainer_metadata(trainer) -> dict:
|
|
|
133
147
|
|
|
134
148
|
def _scale_bounding_box_to_original_image_shape(
|
|
135
149
|
box, resized_image_shape, original_image_shape, ratio_pad
|
|
136
|
-
) ->
|
|
137
|
-
"""
|
|
138
|
-
|
|
150
|
+
) -> list[float]:
|
|
151
|
+
"""Scale bounding box from resized image coordinates to original image coordinates.
|
|
152
|
+
|
|
153
|
+
YOLO resizes images during training and the label values are normalized based on this resized shape. This function
|
|
154
|
+
rescales the bounding box labels to the original image shape.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
box (torch.Tensor): Bounding box in normalized xywh format.
|
|
158
|
+
resized_image_shape (tuple): Shape of the resized image (height, width).
|
|
159
|
+
original_image_shape (tuple): Shape of the original image (height, width).
|
|
160
|
+
ratio_pad (tuple): Ratio and padding information for scaling.
|
|
139
161
|
|
|
140
|
-
|
|
162
|
+
Returns:
|
|
163
|
+
(list[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.
|
|
141
164
|
"""
|
|
142
165
|
resized_image_height, resized_image_width = resized_image_shape
|
|
143
166
|
|
|
@@ -154,9 +177,8 @@ def _scale_bounding_box_to_original_image_shape(
|
|
|
154
177
|
return box
|
|
155
178
|
|
|
156
179
|
|
|
157
|
-
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) ->
|
|
158
|
-
"""
|
|
159
|
-
Format ground truth annotations for object detection.
|
|
180
|
+
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> dict | None:
|
|
181
|
+
"""Format ground truth annotations for object detection.
|
|
160
182
|
|
|
161
183
|
This function processes ground truth annotations from a batch of images for object detection tasks. It extracts
|
|
162
184
|
bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for
|
|
@@ -172,7 +194,7 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
|
|
|
172
194
|
- 'ori_shape': Original image shapes
|
|
173
195
|
- 'resized_shape': Resized image shapes
|
|
174
196
|
- 'ratio_pad': Ratio and padding information
|
|
175
|
-
class_name_map (dict
|
|
197
|
+
class_name_map (dict, optional): Mapping from class indices to class names.
|
|
176
198
|
|
|
177
199
|
Returns:
|
|
178
200
|
(dict | None): Formatted ground truth annotations with the following structure:
|
|
@@ -209,8 +231,18 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
|
|
|
209
231
|
return {"name": "ground_truth", "data": data}
|
|
210
232
|
|
|
211
233
|
|
|
212
|
-
def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) ->
|
|
213
|
-
"""Format YOLO predictions for object detection visualization.
|
|
234
|
+
def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> dict | None:
|
|
235
|
+
"""Format YOLO predictions for object detection visualization.
|
|
236
|
+
|
|
237
|
+
Args:
|
|
238
|
+
image_path (Path): Path to the image file.
|
|
239
|
+
metadata (dict): Prediction metadata containing bounding boxes and class information.
|
|
240
|
+
class_label_map (dict, optional): Mapping from class indices to class names.
|
|
241
|
+
class_map (dict, optional): Additional class mapping for label conversion.
|
|
242
|
+
|
|
243
|
+
Returns:
|
|
244
|
+
(dict | None): Formatted prediction annotations or None if no predictions exist.
|
|
245
|
+
"""
|
|
214
246
|
stem = image_path.stem
|
|
215
247
|
image_id = int(stem) if stem.isnumeric() else stem
|
|
216
248
|
|
|
@@ -224,7 +256,7 @@ def _format_prediction_annotations(image_path, metadata, class_label_map=None, c
|
|
|
224
256
|
class_label_map = {class_map[k]: v for k, v in class_label_map.items()}
|
|
225
257
|
try:
|
|
226
258
|
# import pycotools utilities to decompress annotations for various tasks, e.g. segmentation
|
|
227
|
-
from
|
|
259
|
+
from faster_coco_eval.core.mask import decode
|
|
228
260
|
except ImportError:
|
|
229
261
|
decode = None
|
|
230
262
|
|
|
@@ -251,16 +283,15 @@ def _format_prediction_annotations(image_path, metadata, class_label_map=None, c
|
|
|
251
283
|
return {"name": "prediction", "data": data}
|
|
252
284
|
|
|
253
285
|
|
|
254
|
-
def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) ->
|
|
255
|
-
"""
|
|
256
|
-
Extracts segmentation annotation from compressed segmentations as list of polygons.
|
|
286
|
+
def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> list[list[Any]] | None:
|
|
287
|
+
"""Extract segmentation annotation from compressed segmentations as list of polygons.
|
|
257
288
|
|
|
258
289
|
Args:
|
|
259
|
-
segmentation_raw: Raw segmentation data in compressed format.
|
|
260
|
-
decode: Function to decode the compressed segmentation data.
|
|
290
|
+
segmentation_raw (str): Raw segmentation data in compressed format.
|
|
291
|
+
decode (Callable): Function to decode the compressed segmentation data.
|
|
261
292
|
|
|
262
293
|
Returns:
|
|
263
|
-
(
|
|
294
|
+
(list[list[Any]] | None): List of polygon points or None if extraction fails.
|
|
264
295
|
"""
|
|
265
296
|
try:
|
|
266
297
|
mask = decode(segmentation_raw)
|
|
@@ -272,10 +303,20 @@ def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) ->
|
|
|
272
303
|
return None
|
|
273
304
|
|
|
274
305
|
|
|
275
|
-
def _fetch_annotations(
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
306
|
+
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map) -> list | None:
|
|
307
|
+
"""Join the ground truth and prediction annotations if they exist.
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
img_idx (int): Index of the image in the batch.
|
|
311
|
+
image_path (Path): Path to the image file.
|
|
312
|
+
batch (dict): Batch data containing ground truth annotations.
|
|
313
|
+
prediction_metadata_map (dict): Map of prediction metadata by image ID.
|
|
314
|
+
class_label_map (dict): Mapping from class indices to class names.
|
|
315
|
+
class_map (dict): Additional class mapping for label conversion.
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
(list | None): List of annotation dictionaries or None if no annotations exist.
|
|
319
|
+
"""
|
|
279
320
|
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
|
|
280
321
|
img_idx, image_path, batch, class_label_map
|
|
281
322
|
)
|
|
@@ -290,7 +331,7 @@ def _fetch_annotations(
|
|
|
290
331
|
|
|
291
332
|
|
|
292
333
|
def _create_prediction_metadata_map(model_predictions) -> dict:
|
|
293
|
-
"""Create metadata map for model predictions by
|
|
334
|
+
"""Create metadata map for model predictions by grouping them based on image ID."""
|
|
294
335
|
pred_metadata_map = {}
|
|
295
336
|
for prediction in model_predictions:
|
|
296
337
|
pred_metadata_map.setdefault(prediction["image_id"], [])
|
|
@@ -302,28 +343,24 @@ def _create_prediction_metadata_map(model_predictions) -> dict:
|
|
|
302
343
|
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
|
|
303
344
|
"""Log the confusion matrix to Comet experiment."""
|
|
304
345
|
conf_mat = trainer.validator.confusion_matrix.matrix
|
|
305
|
-
names = list(trainer.data["names"].values())
|
|
346
|
+
names = [*list(trainer.data["names"].values()), "background"]
|
|
306
347
|
experiment.log_confusion_matrix(
|
|
307
348
|
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
|
|
308
349
|
)
|
|
309
350
|
|
|
310
351
|
|
|
311
|
-
def _log_images(experiment, image_paths, curr_step, annotations=None) -> None:
|
|
312
|
-
"""
|
|
313
|
-
Log images to the experiment with optional annotations.
|
|
352
|
+
def _log_images(experiment, image_paths, curr_step: int | None, annotations=None) -> None:
|
|
353
|
+
"""Log images to the experiment with optional annotations.
|
|
314
354
|
|
|
315
|
-
This function logs images to a Comet ML experiment, optionally including annotation data for visualization
|
|
316
|
-
|
|
355
|
+
This function logs images to a Comet ML experiment, optionally including annotation data for visualization such as
|
|
356
|
+
bounding boxes or segmentation masks.
|
|
317
357
|
|
|
318
358
|
Args:
|
|
319
|
-
experiment (comet_ml.
|
|
320
|
-
image_paths (
|
|
359
|
+
experiment (comet_ml.CometExperiment): The Comet ML experiment to log images to.
|
|
360
|
+
image_paths (list[Path]): List of paths to images that will be logged.
|
|
321
361
|
curr_step (int): Current training step/iteration for tracking in the experiment timeline.
|
|
322
|
-
annotations (
|
|
362
|
+
annotations (list[list[dict]], optional): Nested list of annotation dictionaries for each image. Each annotation
|
|
323
363
|
contains visualization data like bounding boxes, labels, and confidence scores.
|
|
324
|
-
|
|
325
|
-
Returns:
|
|
326
|
-
None
|
|
327
364
|
"""
|
|
328
365
|
if annotations:
|
|
329
366
|
for image_path, annotation in zip(image_paths, annotations):
|
|
@@ -335,15 +372,14 @@ def _log_images(experiment, image_paths, curr_step, annotations=None) -> None:
|
|
|
335
372
|
|
|
336
373
|
|
|
337
374
|
def _log_image_predictions(experiment, validator, curr_step) -> None:
|
|
338
|
-
"""
|
|
339
|
-
Log predicted boxes for a single image during training.
|
|
375
|
+
"""Log predicted boxes for a single image during training.
|
|
340
376
|
|
|
341
|
-
This function logs image predictions to a Comet ML experiment during model validation. It processes
|
|
342
|
-
|
|
377
|
+
This function logs image predictions to a Comet ML experiment during model validation. It processes validation data
|
|
378
|
+
and formats both ground truth and prediction annotations for visualization in the Comet
|
|
343
379
|
dashboard. The function respects configured limits on the number of images to log.
|
|
344
380
|
|
|
345
381
|
Args:
|
|
346
|
-
experiment (comet_ml.
|
|
382
|
+
experiment (comet_ml.CometExperiment): The Comet ML experiment to log to.
|
|
347
383
|
validator (BaseValidator): The validator instance containing validation data and predictions.
|
|
348
384
|
curr_step (int): The current training step for logging timeline.
|
|
349
385
|
|
|
@@ -398,15 +434,14 @@ def _log_image_predictions(experiment, validator, curr_step) -> None:
|
|
|
398
434
|
|
|
399
435
|
|
|
400
436
|
def _log_plots(experiment, trainer) -> None:
|
|
401
|
-
"""
|
|
402
|
-
Log evaluation plots and label plots for the experiment.
|
|
437
|
+
"""Log evaluation plots and label plots for the experiment.
|
|
403
438
|
|
|
404
439
|
This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles
|
|
405
|
-
different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots
|
|
406
|
-
|
|
440
|
+
different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots for
|
|
441
|
+
each type.
|
|
407
442
|
|
|
408
443
|
Args:
|
|
409
|
-
experiment (comet_ml.
|
|
444
|
+
experiment (comet_ml.CometExperiment): The Comet ML experiment to log plots to.
|
|
410
445
|
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save
|
|
411
446
|
directory information.
|
|
412
447
|
|
|
@@ -415,7 +450,7 @@ def _log_plots(experiment, trainer) -> None:
|
|
|
415
450
|
>>> _log_plots(experiment, trainer)
|
|
416
451
|
"""
|
|
417
452
|
plot_filenames = None
|
|
418
|
-
if isinstance(trainer.validator.metrics, SegmentMetrics)
|
|
453
|
+
if isinstance(trainer.validator.metrics, SegmentMetrics):
|
|
419
454
|
plot_filenames = [
|
|
420
455
|
trainer.save_dir / f"{prefix}{plots}.png"
|
|
421
456
|
for plots in EVALUATION_PLOT_NAMES
|
|
@@ -428,7 +463,11 @@ def _log_plots(experiment, trainer) -> None:
|
|
|
428
463
|
for prefix in POSE_METRICS_PLOT_PREFIX
|
|
429
464
|
]
|
|
430
465
|
elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
|
|
431
|
-
plot_filenames = [
|
|
466
|
+
plot_filenames = [
|
|
467
|
+
trainer.save_dir / f"{prefix}{plots}.png"
|
|
468
|
+
for plots in EVALUATION_PLOT_NAMES
|
|
469
|
+
for prefix in DETECTION_METRICS_PLOT_PREFIX
|
|
470
|
+
]
|
|
432
471
|
|
|
433
472
|
if plot_filenames is not None:
|
|
434
473
|
_log_images(experiment, plot_filenames, None)
|
|
@@ -448,13 +487,38 @@ def _log_model(experiment, trainer) -> None:
|
|
|
448
487
|
|
|
449
488
|
|
|
450
489
|
def _log_image_batches(experiment, trainer, curr_step: int) -> None:
|
|
451
|
-
"""Log samples of
|
|
490
|
+
"""Log samples of image batches for train, validation, and test."""
|
|
452
491
|
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
|
|
453
492
|
_log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
|
|
454
493
|
|
|
455
494
|
|
|
495
|
+
def _log_asset(experiment, asset_path) -> None:
|
|
496
|
+
"""Logs a specific asset file to the given experiment.
|
|
497
|
+
|
|
498
|
+
This function facilitates logging an asset, such as a file, to the provided
|
|
499
|
+
experiment. It enables integration with experiment tracking platforms.
|
|
500
|
+
|
|
501
|
+
Args:
|
|
502
|
+
experiment (comet_ml.CometExperiment): The experiment instance to which the asset will be logged.
|
|
503
|
+
asset_path (Path): The file path of the asset to log.
|
|
504
|
+
"""
|
|
505
|
+
experiment.log_asset(asset_path)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def _log_table(experiment, table_path) -> None:
|
|
509
|
+
"""Logs a table to the provided experiment.
|
|
510
|
+
|
|
511
|
+
This function is used to log a table file to the given experiment. The table is identified by its file path.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
experiment (comet_ml.CometExperiment): The experiment object where the table file will be logged.
|
|
515
|
+
table_path (Path): The file path of the table to be logged.
|
|
516
|
+
"""
|
|
517
|
+
experiment.log_table(str(table_path))
|
|
518
|
+
|
|
519
|
+
|
|
456
520
|
def on_pretrain_routine_start(trainer) -> None:
|
|
457
|
-
"""
|
|
521
|
+
"""Create or resume a CometML experiment at the start of a YOLO pre-training routine."""
|
|
458
522
|
_resume_or_create_experiment(trainer.args)
|
|
459
523
|
|
|
460
524
|
|
|
@@ -472,16 +536,15 @@ def on_train_epoch_end(trainer) -> None:
|
|
|
472
536
|
|
|
473
537
|
|
|
474
538
|
def on_fit_epoch_end(trainer) -> None:
|
|
475
|
-
"""
|
|
476
|
-
Log model assets at the end of each epoch during training.
|
|
539
|
+
"""Log model assets at the end of each epoch during training.
|
|
477
540
|
|
|
478
|
-
This function is called at the end of each training epoch to log metrics, learning rates, and model information
|
|
479
|
-
|
|
480
|
-
|
|
541
|
+
This function is called at the end of each training epoch to log metrics, learning rates, and model information to a
|
|
542
|
+
Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on configuration
|
|
543
|
+
settings.
|
|
481
544
|
|
|
482
545
|
The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,
|
|
483
|
-
it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),
|
|
484
|
-
|
|
546
|
+
it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled), and
|
|
547
|
+
image predictions (if enabled).
|
|
485
548
|
|
|
486
549
|
Args:
|
|
487
550
|
trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.
|
|
@@ -534,6 +597,16 @@ def on_train_end(trainer) -> None:
|
|
|
534
597
|
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
|
|
535
598
|
_log_image_predictions(experiment, trainer.validator, curr_step)
|
|
536
599
|
_log_image_batches(experiment, trainer, curr_step)
|
|
600
|
+
# log results table
|
|
601
|
+
table_path = trainer.save_dir / RESULTS_TABLE_NAME
|
|
602
|
+
if table_path.exists():
|
|
603
|
+
_log_table(experiment, table_path)
|
|
604
|
+
|
|
605
|
+
# log arguments YAML
|
|
606
|
+
args_path = trainer.save_dir / ARGS_YAML_NAME
|
|
607
|
+
if args_path.exists():
|
|
608
|
+
_log_asset(experiment, args_path)
|
|
609
|
+
|
|
537
610
|
experiment.end()
|
|
538
611
|
|
|
539
612
|
global _comet_image_prediction_count
|
|
@@ -27,8 +27,7 @@ except (ImportError, AssertionError, TypeError):
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def _log_images(path: Path, prefix: str = "") -> None:
|
|
30
|
-
"""
|
|
31
|
-
Log images at specified path with an optional prefix using DVCLive.
|
|
30
|
+
"""Log images at specified path with an optional prefix using DVCLive.
|
|
32
31
|
|
|
33
32
|
This function logs images found at the given path to DVCLive, organizing them by batch to enable slider
|
|
34
33
|
functionality in the UI. It processes image filenames to extract batch information and restructures the path
|
|
@@ -36,7 +35,7 @@ def _log_images(path: Path, prefix: str = "") -> None:
|
|
|
36
35
|
|
|
37
36
|
Args:
|
|
38
37
|
path (Path): Path to the image file to be logged.
|
|
39
|
-
prefix (str): Optional prefix to add to the image name when logging.
|
|
38
|
+
prefix (str, optional): Optional prefix to add to the image name when logging.
|
|
40
39
|
|
|
41
40
|
Examples:
|
|
42
41
|
>>> from pathlib import Path
|
|
@@ -55,8 +54,7 @@ def _log_images(path: Path, prefix: str = "") -> None:
|
|
|
55
54
|
|
|
56
55
|
|
|
57
56
|
def _log_plots(plots: dict, prefix: str = "") -> None:
|
|
58
|
-
"""
|
|
59
|
-
Log plot images for training progress if they have not been previously processed.
|
|
57
|
+
"""Log plot images for training progress if they have not been previously processed.
|
|
60
58
|
|
|
61
59
|
Args:
|
|
62
60
|
plots (dict): Dictionary containing plot information with timestamps.
|
|
@@ -70,18 +68,14 @@ def _log_plots(plots: dict, prefix: str = "") -> None:
|
|
|
70
68
|
|
|
71
69
|
|
|
72
70
|
def _log_confusion_matrix(validator) -> None:
|
|
73
|
-
"""
|
|
74
|
-
Log confusion matrix for a validator using DVCLive.
|
|
71
|
+
"""Log confusion matrix for a validator using DVCLive.
|
|
75
72
|
|
|
76
|
-
This function processes the confusion matrix from a validator object and logs it to DVCLive by converting
|
|
77
|
-
|
|
73
|
+
This function processes the confusion matrix from a validator object and logs it to DVCLive by converting the matrix
|
|
74
|
+
into lists of target and prediction labels.
|
|
78
75
|
|
|
79
76
|
Args:
|
|
80
|
-
validator (BaseValidator): The validator object containing the confusion matrix and class names.
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
Returns:
|
|
84
|
-
None
|
|
77
|
+
validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have
|
|
78
|
+
attributes confusion_matrix.matrix, confusion_matrix.task, and names.
|
|
85
79
|
"""
|
|
86
80
|
targets = []
|
|
87
81
|
preds = []
|
|
@@ -99,7 +93,7 @@ def _log_confusion_matrix(validator) -> None:
|
|
|
99
93
|
|
|
100
94
|
|
|
101
95
|
def on_pretrain_routine_start(trainer) -> None:
|
|
102
|
-
"""
|
|
96
|
+
"""Initialize DVCLive logger for training metadata during pre-training routine."""
|
|
103
97
|
try:
|
|
104
98
|
global live
|
|
105
99
|
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
|
|
@@ -109,28 +103,27 @@ def on_pretrain_routine_start(trainer) -> None:
|
|
|
109
103
|
|
|
110
104
|
|
|
111
105
|
def on_pretrain_routine_end(trainer) -> None:
|
|
112
|
-
"""
|
|
106
|
+
"""Log plots related to the training process at the end of the pretraining routine."""
|
|
113
107
|
_log_plots(trainer.plots, "train")
|
|
114
108
|
|
|
115
109
|
|
|
116
110
|
def on_train_start(trainer) -> None:
|
|
117
|
-
"""
|
|
111
|
+
"""Log the training parameters if DVCLive logging is active."""
|
|
118
112
|
if live:
|
|
119
113
|
live.log_params(trainer.args)
|
|
120
114
|
|
|
121
115
|
|
|
122
116
|
def on_train_epoch_start(trainer) -> None:
|
|
123
|
-
"""
|
|
117
|
+
"""Set the global variable _training_epoch value to True at the start of training each epoch."""
|
|
124
118
|
global _training_epoch
|
|
125
119
|
_training_epoch = True
|
|
126
120
|
|
|
127
121
|
|
|
128
122
|
def on_fit_epoch_end(trainer) -> None:
|
|
129
|
-
"""
|
|
130
|
-
Log training metrics, model info, and advance to next step at the end of each fit epoch.
|
|
123
|
+
"""Log training metrics, model info, and advance to next step at the end of each fit epoch.
|
|
131
124
|
|
|
132
|
-
This function is called at the end of each fit epoch during training. It logs various metrics including
|
|
133
|
-
|
|
125
|
+
This function is called at the end of each fit epoch during training. It logs various metrics including training
|
|
126
|
+
loss items, validation metrics, and learning rates. On the first epoch, it also logs model
|
|
134
127
|
information. Additionally, it logs training and validation plots and advances the DVCLive step counter.
|
|
135
128
|
|
|
136
129
|
Args:
|
|
@@ -160,12 +153,11 @@ def on_fit_epoch_end(trainer) -> None:
|
|
|
160
153
|
|
|
161
154
|
|
|
162
155
|
def on_train_end(trainer) -> None:
|
|
163
|
-
"""
|
|
164
|
-
Log best metrics, plots, and confusion matrix at the end of training.
|
|
156
|
+
"""Log best metrics, plots, and confusion matrix at the end of training.
|
|
165
157
|
|
|
166
|
-
This function is called at the conclusion of the training process to log final metrics, visualizations, and
|
|
167
|
-
|
|
168
|
-
|
|
158
|
+
This function is called at the conclusion of the training process to log final metrics, visualizations, and model
|
|
159
|
+
artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots, validation
|
|
160
|
+
plots, and confusion matrix for later analysis.
|
|
169
161
|
|
|
170
162
|
Args:
|
|
171
163
|
trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.
|