ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
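
The hunks reproduced below cover the optional-logger callback modules (mlflow.py, neptune.py, raytune.py, tensorboard.py, wb.py). Nearly every change is mechanical: single-quoted strings become double-quoted and long expressions are reflowed across lines, consistent with the project adopting an autoformatter. The recurring structural change is the guarded callbacks table each module exports; a minimal standalone sketch of that pattern (a simplified illustration, not the package's exact code, which also catches AssertionError):

# Each callback module only wires its hooks in when the optional dependency
# imported; otherwise it exports an empty table. `mlflow` stands in for any
# of the optional loggers here, and the hook body is a placeholder.
try:
    import mlflow
except ImportError:
    mlflow = None


def on_train_end(trainer):
    """Placeholder hook; the real hooks log metrics and artifacts."""


# 8.0.239 reflows the conditional expression over multiple lines:
callbacks = (
    {
        "on_train_end": on_train_end,
    }
    if mlflow
    else {}
)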
ultralytics/utils/callbacks/mlflow.py

@@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorst
 try:
     import os
 
-    assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '')  # do not log pytest
-    assert SETTINGS['mlflow'] is True  # verify integration is enabled
+    assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
+    assert SETTINGS["mlflow"] is True  # verify integration is enabled
     import mlflow
 
-    assert hasattr(mlflow, '__version__')  # verify package is not directory
+    assert hasattr(mlflow, "__version__")  # verify package is not directory
     from pathlib import Path
 
-    PREFIX = colorstr('MLflow: ')
-    SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
+    PREFIX = colorstr("MLflow: ")
+    SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
 
 except (ImportError, AssertionError):
     mlflow = None
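
A note on the SANITIZE lambda above: MLflow restricts the characters allowed in metric keys (parentheses are rejected), while Ultralytics metric names such as "metrics/mAP50(B)" contain them. A standalone check of the same expression:

# The same key-sanitizing expression as in the hunk above, applied to an
# Ultralytics-style metric name; values are cast to float for MLflow.
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

print(SANITIZE({"metrics/mAP50(B)": "0.650"}))  # {'metrics/mAP50B': 0.65}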
@@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer):
     """
     global mlflow
 
-    uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
-    LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
+    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
+    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
     mlflow.set_tracking_uri(uri)
 
     # Set experiment and run names
-    experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
-    run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
+    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
+    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
     mlflow.set_experiment(experiment_name)
 
     mlflow.autolog()
     try:
         active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
-        LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
+        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
         if Path(uri).is_dir():
             LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
         LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
         mlflow.log_params(dict(trainer.args))
     except Exception as e:
-        LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
-                       f'{PREFIX}WARNING ⚠️ Not tracking this run')
+        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
 
 
 def on_train_epoch_end(trainer):
     """Log training metrics at the end of each train epoch to MLflow."""
     if mlflow:
-        mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
-                           step=trainer.epoch)
+        mlflow.log_metrics(
+            metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch
+        )
         mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
 
 
@@ -101,16 +101,23 @@ def on_train_end(trainer):
     """Log model artifacts at the end of the training."""
     if mlflow:
         mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
-        for f in trainer.save_dir.glob('*'):  # log all other files in save_dir
-            if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
+        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
+            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                 mlflow.log_artifact(str(f))
 
         mlflow.end_run()
-        LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
-                    f"{PREFIX}disable with 'yolo settings mlflow=False'")
-
-
-callbacks = {
-    'on_pretrain_routine_end': on_pretrain_routine_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_end': on_train_end} if mlflow else {}
+        LOGGER.info(
+            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
+            f"{PREFIX}disable with 'yolo settings mlflow=False'"
+        )
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if mlflow
+    else {}
+)
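
The resolution order in on_pretrain_routine_end above is environment variable first, then the trainer's project/name arguments, then a default. A standalone sketch of the same precedence (RUNS_DIR and trainer are local stand-ins here, not imports from the package):

# Env var > trainer.args > default, as in the mlflow hunks above.
import os
from pathlib import Path
from types import SimpleNamespace

RUNS_DIR = Path("runs")  # stand-in; ultralytics resolves this internally
trainer = SimpleNamespace(args=SimpleNamespace(project=None, name="exp1"))

uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
print(uri, experiment_name, run_name)  # runs/mlflow /Shared/YOLOv8 exp1 (with no env vars set)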
ultralytics/utils/callbacks/neptune.py

@@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['neptune'] is True  # verify integration is enabled
+    assert SETTINGS["neptune"] is True  # verify integration is enabled
     import neptune
     from neptune.types import File
 
-    assert hasattr(neptune, '__version__')
+    assert hasattr(neptune, "__version__")
 
     run = None  # NeptuneAI experiment logger instance
 
@@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0):
             run[k].append(value=v, step=step)
 
 
-def _log_images(imgs_dict, group=''):
+def _log_images(imgs_dict, group=""):
     """Log scalars to the NeptuneAI experiment logger."""
     if run:
         for k, v in imgs_dict.items():
-            run[f'{group}/{k}'].upload(File(v))
+            run[f"{group}/{k}"].upload(File(v))
 
 
 def _log_plot(title, plot_path):
@@ -43,34 +43,35 @@ def _log_plot(title, plot_path):
 
     img = mpimg.imread(plot_path)
     fig = plt.figure()
-    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
+    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
     ax.imshow(img)
-    run[f'Plots/{title}'].upload(fig)
+    run[f"Plots/{title}"].upload(fig)
 
 
 def on_pretrain_routine_start(trainer):
     """Callback function called before the training routine starts."""
     try:
         global run
-        run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
-        run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
+        run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
+        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
 
 
 def on_train_epoch_end(trainer):
     """Callback function called at end of each training epoch."""
-    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)
     if trainer.epoch == 1:
-        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
+        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
 
 
 def on_fit_epoch_end(trainer):
     """Callback function called at end of each fit (train+val) epoch."""
     if run and trainer.epoch == 0:
         from ultralytics.utils.torch_utils import model_info_for_loggers
-        run['Configuration/Model'] = model_info_for_loggers(trainer)
+
+        run["Configuration/Model"] = model_info_for_loggers(trainer)
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 
@@ -78,7 +79,7 @@ def on_val_end(validator):
     """Callback function called at end of each validation."""
     if run:
         # Log val_labels and val_pred
-        _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
+        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
 
 
 def on_train_end(trainer):
@@ -86,19 +87,28 @@ def on_train_end(trainer):
     if run:
         # Log final results, CM matrix + PR plots
         files = [
-            'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
-            *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+            "results.png",
+            "confusion_matrix.png",
+            "confusion_matrix_normalized.png",
+            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+        ]
         files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
         for f in files:
             _log_plot(title=f.stem, plot_path=f)
         # Log the final model
-        run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
-            trainer.best)))
-
-
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_epoch_end': on_train_epoch_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_val_end': on_val_end,
-    'on_train_end': on_train_end} if neptune else {}
+        run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
+            File(str(trainer.best))
+        )
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_val_end": on_val_end,
+        "on_train_end": on_train_end,
+    }
+    if neptune
+    else {}
+)
ultralytics/utils/callbacks/raytune.py

@@ -3,7 +3,7 @@
 from ultralytics.utils import SETTINGS
 
 try:
-    assert SETTINGS['raytune'] is True  # verify integration is enabled
+    assert SETTINGS["raytune"] is True  # verify integration is enabled
     import ray
     from ray import tune
     from ray.air import session
@@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
     """Sends training metrics to Ray Tune at end of each epoch."""
     if ray.tune.is_session_enabled():
         metrics = trainer.metrics
-        metrics['epoch'] = trainer.epoch
+        metrics["epoch"] = trainer.epoch
         session.report(metrics)
 
 
-callbacks = {
-    'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
+callbacks = (
+    {
+        "on_fit_epoch_end": on_fit_epoch_end,
+    }
+    if tune
+    else {}
+)
ultralytics/utils/callbacks/tensorboard.py

@@ -7,7 +7,7 @@ try:
     from torch.utils.tensorboard import SummaryWriter
 
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['tensorboard'] is True  # verify integration is enabled
+    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
     WRITER = None  # TensorBoard SummaryWriter instance
 
 except (ImportError, AssertionError, TypeError):
@@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer):
         p = next(trainer.model.parameters())  # for device, type
         im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
         with warnings.catch_warnings():
-            warnings.simplefilter('ignore', category=UserWarning)  # suppress jit trace warning
+            warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
             WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
+        LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
 
 
 def on_pretrain_routine_start(trainer):
@@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer):
     try:
         global WRITER
         WRITER = SummaryWriter(str(trainer.save_dir))
-        prefix = colorstr('TensorBoard: ')
+        prefix = colorstr("TensorBoard: ")
         LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
 
 
 def on_train_start(trainer):
@@ -60,7 +60,7 @@ def on_train_start(trainer):
 
 def on_train_epoch_end(trainer):
     """Logs scalar statistics at the end of a training epoch."""
-    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
     _log_scalars(trainer.lr, trainer.epoch + 1)
 
 
@@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer):
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_start': on_train_start,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_start": on_train_start,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_epoch_end": on_train_epoch_end,
+    }
+    if SummaryWriter
+    else {}
+)
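
For context on the _log_tensorboard_graph hunk above: the graph comes from tracing the model with torch.jit.trace on an all-zeros input and handing the trace to SummaryWriter.add_graph. A standalone sketch with a toy module in place of trainer.model (the log directory name is illustrative):

# Trace the model on a zeros tensor (must be zeros, per the comment in the
# diff) and pass the traced module to TensorBoard.
import warnings

import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
im = torch.zeros((1, 3, 64, 64))  # batch of one all-zeros image

writer = SummaryWriter("runs/graph_demo")
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
    writer.add_graph(torch.jit.trace(model, im, strict=False), [])
writer.close()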
ultralytics/utils/callbacks/wb.py

@@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
-    assert SETTINGS['wandb'] is True  # verify integration is enabled
+    assert SETTINGS["wandb"] is True  # verify integration is enabled
     import wandb as wb
 
-    assert hasattr(wb, '__version__')  # verify package is not directory
+    assert hasattr(wb, "__version__")  # verify package is not directory
 
     import numpy as np
     import pandas as pd
@@ -19,7 +19,7 @@ except (ImportError, AssertionError):
     wb = None
 
 
-def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
+def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
     """
     Create and log a custom metric visualization to wandb.plot.pr_curve.
 
@@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall
     Returns:
         (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
     """
-    df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3)
-    fields = {'x': 'x', 'y': 'y', 'class': 'class'}
-    string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title}
-    return wb.plot_table('wandb/area-under-curve/v0',
-                         wb.Table(dataframe=df),
-                         fields=fields,
-                         string_fields=string_fields)
-
-
-def _plot_curve(x,
-                y,
-                names=None,
-                id='precision-recall',
-                title='Precision Recall Curve',
-                x_title='Recall',
-                y_title='Precision',
-                num_x=100,
-                only_mean=False):
+    df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
+    fields = {"x": "x", "y": "y", "class": "class"}
+    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
+    return wb.plot_table(
+        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+    )
+
+
+def _plot_curve(
+    x,
+    y,
+    names=None,
+    id="precision-recall",
+    title="Precision Recall Curve",
+    x_title="Recall",
+    y_title="Precision",
+    num_x=100,
+    only_mean=False,
+):
     """
     Log a metric curve visualization.
 
@@ -88,7 +89,7 @@ def _plot_curve(x,
         table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
         wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
     else:
-        classes = ['mean'] * len(x_log)
+        classes = ["mean"] * len(x_log)
         for i, yi in enumerate(y):
             x_log.extend(x_new)  # add new x
             y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
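
The np.interp call above is the heart of _plot_curve: each class curve is resampled onto one shared x grid so curves of different lengths fit a single table. A toy resampling sketch (the construction of x_new is an assumption about surrounding code the hunk does not show):

# Resample per-class curves onto a fixed grid of num_x points, mirroring the
# loop in the hunk above. Data values are illustrative.
import numpy as np

num_x = 100
x = np.array([0.0, 0.5, 1.0])  # e.g. recall points
y = np.array([[1.0, 0.8, 0.2], [0.9, 0.7, 0.1]])  # one row per class (e.g. precision)

x_new = np.linspace(x[0], x[-1], num_x)  # assumed shared grid
x_log, y_log = [], []
for yi in y:
    x_log.extend(x_new)  # add new x
    y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
print(len(x_log), len(y_log))  # 200 200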
@@ -99,7 +100,7 @@ def _plot_curve(x,
 def _log_plots(plots, step):
     """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
     for name, params in plots.items():
-        timestamp = params['timestamp']
+        timestamp = params["timestamp"]
         if _processed_plots.get(name) != timestamp:
             wb.run.log({name.stem: wb.Image(str(name))}, step=step)
             _processed_plots[name] = timestamp
@@ -107,7 +108,7 @@ def _log_plots(plots, step):
 
 def on_pretrain_routine_start(trainer):
     """Initiate and start project if module is present."""
-    wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
+    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
 
 
 def on_fit_epoch_end(trainer):
@@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer):
 
 def on_train_epoch_end(trainer):
     """Log metrics and save images at the end of each training epoch."""
-    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
+    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
     wb.run.log(trainer.lr, step=trainer.epoch + 1)
     if trainer.epoch == 1:
         _log_plots(trainer.plots, step=trainer.epoch + 1)
@@ -131,17 +132,17 @@ def on_train_end(trainer):
     """Save the best model as an artifact at end of training."""
     _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
     _log_plots(trainer.plots, step=trainer.epoch + 1)
-    art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
+    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
     if trainer.best.exists():
         art.add_file(trainer.best)
-        wb.run.log_artifact(art, aliases=['best'])
+        wb.run.log_artifact(art, aliases=["best"])
     for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
         x, y, x_title, y_title = curve_values
         _plot_curve(
             x,
             y,
             names=list(trainer.validator.metrics.names.values()),
-            id=f'curves/{curve_name}',
+            id=f"curves/{curve_name}",
             title=curve_name,
             x_title=x_title,
             y_title=y_title,
@@ -149,8 +150,13 @@ def on_train_end(trainer):
     wb.run.finish()  # required or run continues on dashboard
 
 
-callbacks = {
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_train_epoch_end': on_train_epoch_end,
-    'on_fit_epoch_end': on_fit_epoch_end,
-    'on_train_end': on_train_end} if wb else {}
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if wb
+    else {}
+)