ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/utils/callbacks/mlflow.py

@@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
 try:
     import os
 
-    assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '')  # do not log pytest
-    assert SETTINGS['mlflow'] is True  # verify integration is enabled
+    assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
+    assert SETTINGS["mlflow"] is True  # verify integration is enabled
     import mlflow
 
-    assert hasattr(mlflow, '__version__')  # verify package is not directory
+    assert hasattr(mlflow, "__version__")  # verify package is not directory
     from pathlib import Path
 
-    PREFIX = colorstr('MLflow: ')
-    SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
+    PREFIX = colorstr("MLflow: ")
+    SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
 
 except (ImportError, AssertionError):
     mlflow = None

@@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer):
     """
     global mlflow
 
-    uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
-    LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
+    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
+    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
     mlflow.set_tracking_uri(uri)
 
     # Set experiment and run names
-    experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
-    run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
+    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
+    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
     mlflow.set_experiment(experiment_name)
 
     mlflow.autolog()
     try:
         active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
-        LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
+        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
         if Path(uri).is_dir():
             LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
         LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
         mlflow.log_params(dict(trainer.args))
     except Exception as e:
-        LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
-                       f'{PREFIX}WARNING ⚠️ Not tracking this run')
+        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
 
 
 def on_train_epoch_end(trainer):
     """Log training metrics at the end of each train epoch to MLflow."""
     if mlflow:
-        mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
-                           step=trainer.epoch)
+        mlflow.log_metrics(
+            metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch
+        )
         mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
 
 

@@ -101,16 +101,23 @@ def on_train_end(trainer):
     """Log model artifacts at the end of the training."""
     if mlflow:
         mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
-        for f in trainer.save_dir.glob('*'):  # log all other files in save_dir
-            if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
+        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
+            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                 mlflow.log_artifact(str(f))
 
         mlflow.end_run()
-        LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
-                    f"{PREFIX}disable with 'yolo settings mlflow=False'")
-
-
-callbacks = {
-    'on_pretrain_routine_end': on_pretrain_routine_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_end': on_train_end} if mlflow else {}
+        LOGGER.info(
+            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
+            f"{PREFIX}disable with 'yolo settings mlflow=False'"
+        )
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if mlflow
+    else {}
+)
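For context on the SANITIZE helper kept in this file: MLflow restricts the characters allowed in metric names, so YOLO-style keys such as "metrics/mAP50(B)" have their parentheses stripped (and values coerced to float) before logging. A minimal standalone sketch of that step; the sample metric names and values below are made up for illustration, not from a real run:

# Standalone sketch of the key-sanitizing step used in the MLflow callback above.
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

metrics = {"metrics/mAP50(B)": 0.612, "metrics/precision(B)": "0.731", "train/box_loss": 1.04}
print(SANITIZE(metrics))
# -> {'metrics/mAP50B': 0.612, 'metrics/precisionB': 0.731, 'train/box_loss': 1.04}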
ultralytics/utils/callbacks/neptune.py

@@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['neptune'] is True  # verify integration is enabled
+    assert SETTINGS["neptune"] is True  # verify integration is enabled
     import neptune
     from neptune.types import File
 
-    assert hasattr(neptune, '__version__')
+    assert hasattr(neptune, "__version__")
 
     run = None  # NeptuneAI experiment logger instance
 

@@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0):
             run[k].append(value=v, step=step)
 
 
-def _log_images(imgs_dict, group=''):
+def _log_images(imgs_dict, group=""):
     """Log scalars to the NeptuneAI experiment logger."""
     if run:
         for k, v in imgs_dict.items():
-            run[f'{group}/{k}'].upload(File(v))
+            run[f"{group}/{k}"].upload(File(v))
 
 
 def _log_plot(title, plot_path):

@@ -43,34 +43,35 @@ def _log_plot(title, plot_path):
 
     img = mpimg.imread(plot_path)
     fig = plt.figure()
-    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
+    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
     ax.imshow(img)
-    run[f'Plots/{title}'].upload(fig)
+    run[f"Plots/{title}"].upload(fig)
 
 
 def on_pretrain_routine_start(trainer):
     """Callback function called before the training routine starts."""
     try:
         global run
-        run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
-        run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
+        run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
+        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
 
 
 def on_train_epoch_end(trainer):
     """Callback function called at end of each training epoch."""
-    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)
     if trainer.epoch == 1:
-        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
+        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
 
 
 def on_fit_epoch_end(trainer):
     """Callback function called at end of each fit (train+val) epoch."""
     if run and trainer.epoch == 0:
         from ultralytics.utils.torch_utils import model_info_for_loggers
-        run['Configuration/Model'] = model_info_for_loggers(trainer)
+
+        run["Configuration/Model"] = model_info_for_loggers(trainer)
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 

@@ -78,7 +79,7 @@ def on_val_end(validator):
     """Callback function called at end of each validation."""
     if run:
         # Log val_labels and val_pred
-        _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
+        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
 
 
 def on_train_end(trainer):

@@ -86,19 +87,28 @@ def on_train_end(trainer):
     if run:
         # Log final results, CM matrix + PR plots
         files = [
-            'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
-            *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+            "results.png",
+            "confusion_matrix.png",
+            "confusion_matrix_normalized.png",
+            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+        ]
         files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
         for f in files:
             _log_plot(title=f.stem, plot_path=f)
         # Log the final model
-        run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
-            trainer.best)))
-
-
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_epoch_end': on_train_epoch_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_val_end': on_val_end,
-    'on_train_end': on_train_end} if neptune else {}
+        run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
+            File(str(trainer.best))
+        )
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_val_end": on_val_end,
+        "on_train_end": on_train_end,
+    }
+    if neptune
+    else {}
+)
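One detail worth noting in on_pretrain_routine_start above: trainer.args fields that are None are replaced with empty strings before being written to run["Configuration/Hyperparameters"], presumably because a bare None cannot be stored as a Neptune field value. A small self-contained sketch of that conversion; the SimpleNamespace and its fields stand in for trainer.args and are illustrative only:

# Sketch of the hyperparameter-serialization step from the Neptune callback above.
from types import SimpleNamespace

args = SimpleNamespace(model="yolov8n.pt", epochs=100, resume=None)  # stand-in for trainer.args
hyperparameters = {k: "" if v is None else v for k, v in vars(args).items()}
print(hyperparameters)
# -> {'model': 'yolov8n.pt', 'epochs': 100, 'resume': ''}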
ultralytics/utils/callbacks/raytune.py

@@ -3,7 +3,7 @@
 from ultralytics.utils import SETTINGS
 
 try:
-    assert SETTINGS['raytune'] is True  # verify integration is enabled
+    assert SETTINGS["raytune"] is True  # verify integration is enabled
     import ray
     from ray import tune
     from ray.air import session

@@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
     """Sends training metrics to Ray Tune at end of each epoch."""
     if ray.tune.is_session_enabled():
         metrics = trainer.metrics
-        metrics['epoch'] = trainer.epoch
+        metrics["epoch"] = trainer.epoch
         session.report(metrics)
 
 
-callbacks = {
-    'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
+callbacks = (
+    {
+        "on_fit_epoch_end": on_fit_epoch_end,
+    }
+    if tune
+    else {}
+)
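The callbacks assignment reformatted here follows the same guarded-registration pattern used by every integration in ultralytics/utils/callbacks: if the optional dependency fails to import, the module exposes an empty dict and the trainer has nothing to hook. A minimal sketch of the pattern; the stub callback body below is a placeholder, not the real implementation:

# Minimal sketch of the guarded callback registration shown in the diff above.
try:
    from ray import tune  # optional dependency being probed
except ImportError:
    tune = None

def on_fit_epoch_end(trainer):
    """Illustrative stub; the real callback reports trainer.metrics to Ray Tune."""
    pass

callbacks = (
    {
        "on_fit_epoch_end": on_fit_epoch_end,
    }
    if tune
    else {}
)
print(callbacks)  # {} when Ray Tune is not installed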
ultralytics/utils/callbacks/tensorboard.py

@@ -7,7 +7,7 @@ try:
     from torch.utils.tensorboard import SummaryWriter
 
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['tensorboard'] is True  # verify integration is enabled
+    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
     WRITER = None  # TensorBoard SummaryWriter instance
 
 except (ImportError, AssertionError, TypeError):

@@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer):
         p = next(trainer.model.parameters())  # for device, type
         im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
         with warnings.catch_warnings():
-            warnings.simplefilter('ignore', category=UserWarning)  # suppress jit trace warning
+            warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
             WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
+        LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
 
 
 def on_pretrain_routine_start(trainer):

@@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer):
     try:
         global WRITER
         WRITER = SummaryWriter(str(trainer.save_dir))
-        prefix = colorstr('TensorBoard: ')
+        prefix = colorstr("TensorBoard: ")
         LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
 
 
 def on_train_start(trainer):

@@ -60,7 +60,7 @@ def on_train_start(trainer):
 
 def on_train_epoch_end(trainer):
     """Logs scalar statistics at the end of a training epoch."""
-    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)
 
 

@@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer):
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_start': on_train_start,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_start": on_train_start,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_epoch_end": on_train_epoch_end,
+    }
+    if SummaryWriter
+    else {}
+)
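For reference, the graph-logging approach that _log_tensorboard_graph wraps is: trace the model on a zero input while silencing torch.jit trace UserWarnings, then hand the traced module to SummaryWriter.add_graph. A standalone sketch under that reading; the toy model, input size, and log directory are placeholders and not part of ultralytics:

# Sketch of the graph-logging technique used in the TensorBoard callback above.
import warnings

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
writer = SummaryWriter("runs/graph_demo")  # placeholder log directory
im = torch.zeros((1, 3, 64, 64))  # must be zeros, matching the callback's comment

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warnings
    writer.add_graph(torch.jit.trace(model, im, strict=False), [])
writer.close()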
ultralytics/utils/callbacks/wb.py

@@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['wandb'] is True  # verify integration is enabled
+    assert SETTINGS["wandb"] is True  # verify integration is enabled
     import wandb as wb
 
-    assert hasattr(wb, '__version__')  # verify package is not directory
+    assert hasattr(wb, "__version__")  # verify package is not directory
 
     import numpy as np
     import pandas as pd

@@ -19,7 +19,7 @@ except (ImportError, AssertionError):
     wb = None
 
 
-def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
+def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
     """
     Create and log a custom metric visualization to wandb.plot.pr_curve.
 

@@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
     Returns:
         (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
     """
-    df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3)
-    fields = {'x': 'x', 'y': 'y', 'class': 'class'}
-    string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title}
-    return wb.plot_table('wandb/area-under-curve/v0',
-                         wb.Table(dataframe=df),
-                         fields=fields,
-                         string_fields=string_fields)
-
-
-def _plot_curve(x,
-                y,
-                names=None,
-                id='precision-recall',
-                title='Precision Recall Curve',
-                x_title='Recall',
-                y_title='Precision',
-                num_x=100,
-                only_mean=False):
+    df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
+    fields = {"x": "x", "y": "y", "class": "class"}
+    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
+    return wb.plot_table(
+        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+    )
+
+
+def _plot_curve(
+    x,
+    y,
+    names=None,
+    id="precision-recall",
+    title="Precision Recall Curve",
+    x_title="Recall",
+    y_title="Precision",
+    num_x=100,
+    only_mean=False,
+):
     """
     Log a metric curve visualization.
 

@@ -88,7 +89,7 @@ def _plot_curve(x,
         table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
         wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
     else:
-        classes = ['mean'] * len(x_log)
+        classes = ["mean"] * len(x_log)
         for i, yi in enumerate(y):
             x_log.extend(x_new)  # add new x
             y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x

@@ -99,7 +100,7 @@ def _plot_curve(x,
 def _log_plots(plots, step):
     """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
     for name, params in plots.items():
-        timestamp = params['timestamp']
+        timestamp = params["timestamp"]
         if _processed_plots.get(name) != timestamp:
             wb.run.log({name.stem: wb.Image(str(name))}, step=step)
             _processed_plots[name] = timestamp

@@ -107,7 +108,7 @@ def _log_plots(plots, step):
 
 def on_pretrain_routine_start(trainer):
     """Initiate and start project if module is present."""
-    wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
+    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
 
 
 def on_fit_epoch_end(trainer):

@@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer):
 
 def on_train_epoch_end(trainer):
     """Log metrics and save images at the end of each training epoch."""
-    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
+    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
     wb.run.log(trainer.lr, step=trainer.epoch + 1)
     if trainer.epoch == 1:
         _log_plots(trainer.plots, step=trainer.epoch + 1)

@@ -131,17 +132,17 @@ def on_train_end(trainer):
     """Save the best model as an artifact at end of training."""
     _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
     _log_plots(trainer.plots, step=trainer.epoch + 1)
-    art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
+    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
     if trainer.best.exists():
         art.add_file(trainer.best)
-        wb.run.log_artifact(art, aliases=['best'])
+        wb.run.log_artifact(art, aliases=["best"])
     for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
         x, y, x_title, y_title = curve_values
         _plot_curve(
             x,
             y,
             names=list(trainer.validator.metrics.names.values()),
-            id=f'curves/{curve_name}',
+            id=f"curves/{curve_name}",
             title=curve_name,
             x_title=x_title,
             y_title=y_title,

@@ -149,8 +150,13 @@ def on_train_end(trainer):
     wb.run.finish()  # required or run continues on dashboard
 
 
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_epoch_end': on_train_epoch_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_end': on_train_end} if wb else {}
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if wb
+    else {}
+)
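A note on the reshaped _plot_curve above: before building the wandb table it resamples each per-class curve onto a shared grid of num_x points so the mean series and the per-class series can share one table. A rough numpy-only sketch of that resampling; the toy curves and class labels below are made up for illustration and only approximate the real function:

# Rough sketch of the curve resampling performed inside _plot_curve.
import numpy as np

num_x = 100
x = np.linspace(0.0, 1.0, 25)                      # original recall grid
y = np.stack([x**2, np.sqrt(x)])                   # two fake per-class precision curves
x_new = np.linspace(x[0], x[-1], num_x).round(5)   # common resampling grid

x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()  # mean curve first
classes = ["mean"] * len(x_log)

for i, yi in enumerate(y):                         # then one series per class
    x_log.extend(x_new)
    y_log.extend(np.interp(x_new, x, yi))          # interpolate y to the new x
    classes.extend([f"class_{i}"] * len(x_new))    # illustrative class labels

print(len(x_log), len(y_log), len(classes))        # 300 300 300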