ultralytics 8.0.66__py3-none-any.whl → 8.0.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic; see the registry's advisory page for this release for more details.

Files changed (31):
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/hub/__init__.py +6 -4
  3. ultralytics/hub/auth.py +6 -4
  4. ultralytics/hub/session.py +1 -1
  5. ultralytics/yolo/data/utils.py +2 -2
  6. ultralytics/yolo/engine/exporter.py +8 -3
  7. ultralytics/yolo/engine/model.py +9 -11
  8. ultralytics/yolo/engine/predictor.py +8 -3
  9. ultralytics/yolo/engine/trainer.py +4 -5
  10. ultralytics/yolo/engine/validator.py +8 -3
  11. ultralytics/yolo/utils/__init__.py +16 -5
  12. ultralytics/yolo/utils/callbacks/__init__.py +2 -2
  13. ultralytics/yolo/utils/callbacks/base.py +6 -0
  14. ultralytics/yolo/utils/callbacks/clearml.py +85 -18
  15. ultralytics/yolo/utils/callbacks/comet.py +290 -20
  16. ultralytics/yolo/utils/checks.py +5 -5
  17. ultralytics/yolo/utils/downloads.py +17 -9
  18. ultralytics/yolo/utils/files.py +0 -7
  19. ultralytics/yolo/v8/classify/train.py +2 -2
  20. ultralytics/yolo/v8/classify/val.py +2 -2
  21. ultralytics/yolo/v8/detect/val.py +2 -2
  22. ultralytics/yolo/v8/pose/train.py +2 -2
  23. ultralytics/yolo/v8/pose/val.py +2 -2
  24. ultralytics/yolo/v8/segment/train.py +2 -2
  25. ultralytics/yolo/v8/segment/val.py +2 -2
  26. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/METADATA +11 -11
  27. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/RECORD +31 -31
  28. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/LICENSE +0 -0
  29. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/WHEEL +0 -0
  30. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/entry_points.txt +0 -0
  31. {ultralytics-8.0.66.dist-info → ultralytics-8.0.68.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, GPL-3.0 license
2
2
 
3
- __version__ = '8.0.66'
3
+ __version__ = '8.0.68'
4
4
 
5
5
  from ultralytics.hub import start
6
6
  from ultralytics.yolo.engine.model import YOLO
ultralytics/hub/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  import requests
4
4
 
5
5
  from ultralytics.hub.utils import PREFIX, split_key
6
- from ultralytics.yolo.utils import LOGGER
6
+ from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
7
7
 
8
8
 
9
9
  def login(api_key=''):
@@ -15,7 +15,7 @@ def login(api_key=''):
15
15
 
16
16
  Example:
17
17
  from ultralytics import hub
18
- hub.login('your_api_key')
18
+ hub.login('API_KEY')
19
19
  """
20
20
  from ultralytics.hub.auth import Auth
21
21
  Auth(api_key)
@@ -23,13 +23,15 @@ def login(api_key=''):
23
23
 
24
24
  def logout():
25
25
  """
26
- Logout Ultralytics HUB
26
+ Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
27
27
 
28
28
  Example:
29
29
  from ultralytics import hub
30
30
  hub.logout()
31
31
  """
32
- LOGGER.warning('WARNING ⚠️ This method is not yet implemented.')
32
+ SETTINGS['api_key'] = ''
33
+ yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
34
+ LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
33
35
 
34
36
 
35
37
  def start(key=''):
ultralytics/hub/auth.py CHANGED
@@ -11,7 +11,7 @@ API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
11
11
  class Auth:
12
12
  id_token = api_key = model_key = False
13
13
 
14
- def __init__(self, api_key=''):
14
+ def __init__(self, api_key='', verbose=True):
15
15
  """
16
16
  Initialize the Auth class with an optional API key.
17
17
 
@@ -29,7 +29,8 @@ class Auth:
29
29
  # If the provided API key matches the API key in the SETTINGS
30
30
  if self.api_key == SETTINGS.get('api_key'):
31
31
  # Log that the user is already logged in
32
- LOGGER.info(f'{PREFIX}Authenticated ✅')
32
+ if verbose:
33
+ LOGGER.info(f'{PREFIX}Authenticated ✅')
33
34
  return
34
35
  else:
35
36
  # Attempt to authenticate with the provided API key
@@ -46,8 +47,9 @@ class Auth:
46
47
  if success:
47
48
  set_settings({'api_key': self.api_key})
48
49
  # Log that the new login was successful
49
- LOGGER.info(f'{PREFIX}New authentication successful ✅')
50
- else:
50
+ if verbose:
51
+ LOGGER.info(f'{PREFIX}New authentication successful ✅')
52
+ elif verbose:
51
53
  LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
52
54
 
53
55
  def request_api_key(self, max_attempts=3):
ultralytics/hub/session.py CHANGED
@@ -58,7 +58,7 @@ class HUBTrainingSession:
58
58
  raise ValueError(f'Invalid HUBTrainingSession input: {url}')
59
59
 
60
60
  # Authorize
61
- auth = Auth(key)
61
+ auth = Auth(key, verbose=False)
62
62
  self.agent_id = None # identifies which instance is communicating with server
63
63
  self.model_id = model_id
64
64
  self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
ultralytics/yolo/data/utils.py CHANGED
@@ -17,7 +17,7 @@ from PIL import ExifTags, Image, ImageOps
17
17
  from tqdm import tqdm
18
18
 
19
19
  from ultralytics.nn.autobackend import check_class_names
20
- from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
20
+ from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, clean_url, colorstr, emojis, yaml_load
21
21
  from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
22
22
  from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
23
23
  from ultralytics.yolo.utils.ops import segments2boxes
@@ -241,7 +241,7 @@ def check_det_dataset(dataset, autodownload=True):
241
241
  if val:
242
242
  val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
243
243
  if not all(x.exists() for x in val):
244
- name = str(dataset).split('?')[0] # dataset name with URL auth stripped
244
+ name = clean_url(dataset) # dataset name with URL auth stripped
245
245
  m = f"\nDataset '{name}' images not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
246
246
  if s and autodownload:
247
247
  LOGGER.warning(m)
ultralytics/yolo/engine/exporter.py CHANGED
@@ -53,7 +53,6 @@ import platform
53
53
  import subprocess
54
54
  import time
55
55
  import warnings
56
- from collections import defaultdict
57
56
  from copy import deepcopy
58
57
  from pathlib import Path
59
58
 
@@ -130,7 +129,7 @@ class Exporter:
130
129
  save_dir (Path): Directory to save results.
131
130
  """
132
131
 
133
- def __init__(self, cfg=DEFAULT_CFG, overrides=None):
132
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
134
133
  """
135
134
  Initializes the Exporter class.
136
135
 
@@ -139,7 +138,7 @@ class Exporter:
139
138
  overrides (dict, optional): Configuration overrides. Defaults to None.
140
139
  """
141
140
  self.args = get_cfg(cfg, overrides)
142
- self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
141
+ self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
143
142
  callbacks.add_integration_callbacks(self)
144
143
 
145
144
  @smart_inference_mode()
@@ -854,6 +853,12 @@ class Exporter:
854
853
  LOGGER.info(f'{prefix} pipeline success')
855
854
  return model
856
855
 
856
+ def add_callback(self, event: str, callback):
857
+ """
858
+ Appends the given callback.
859
+ """
860
+ self.callbacks[event].append(callback)
861
+
857
862
  def run_callbacks(self, event: str):
858
863
  for callback in self.callbacks.get(event, []):
859
864
  callback(self)
ultralytics/yolo/engine/model.py CHANGED
@@ -78,7 +78,7 @@ class YOLO:
78
78
  task (Any, optional): Task type for the YOLO model. Defaults to None.
79
79
 
80
80
  """
81
- self._reset_callbacks()
81
+ self.callbacks = callbacks.get_default_callbacks()
82
82
  self.predictor = None # reuse predictor
83
83
  self.model = None # model object
84
84
  self.trainer = None # trainer object
@@ -238,7 +238,7 @@ class YOLO:
238
238
  overrides['save'] = kwargs.get('save', False) # not save files by default
239
239
  if not self.predictor:
240
240
  self.task = overrides.get('task') or self.task
241
- self.predictor = TASK_MAP[self.task][3](overrides=overrides)
241
+ self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
242
242
  self.predictor.setup_model(model=self.model, verbose=is_cli)
243
243
  else: # only update args if predictor is already setup
244
244
  self.predictor.args = get_cfg(self.predictor.args, overrides)
@@ -277,7 +277,7 @@ class YOLO:
277
277
  args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
278
278
  args.imgsz = check_imgsz(args.imgsz, max_dim=1)
279
279
 
280
- validator = TASK_MAP[self.task][2](args=args)
280
+ validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
281
281
  validator(model=self.model)
282
282
  self.metrics = validator.metrics
283
283
 
@@ -316,7 +316,7 @@ class YOLO:
316
316
  args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
317
317
  if args.batch == DEFAULT_CFG.batch:
318
318
  args.batch = 1 # default to 1 if not modified
319
- return Exporter(overrides=args)(model=self.model)
319
+ return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
320
320
 
321
321
  def train(self, **kwargs):
322
322
  """
@@ -344,7 +344,7 @@ class YOLO:
344
344
  overrides['resume'] = self.ckpt_path
345
345
 
346
346
  self.task = overrides.get('task') or self.task
347
- self.trainer = TASK_MAP[self.task][1](overrides=overrides)
347
+ self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
348
348
  if not overrides.get('resume'): # manually set model only if not resuming
349
349
  self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
350
350
  self.model = self.trainer.model
@@ -387,19 +387,17 @@ class YOLO:
387
387
  """
388
388
  return self.model.transforms if hasattr(self.model, 'transforms') else None
389
389
 
390
- @staticmethod
391
- def add_callback(event: str, func):
390
+ def add_callback(self, event: str, func):
392
391
  """
393
392
  Add callback
394
393
  """
395
- callbacks.default_callbacks[event].append(func)
394
+ self.callbacks[event].append(func)
396
395
 
397
396
  @staticmethod
398
397
  def _reset_ckpt_args(args):
399
398
  include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
400
399
  return {k: v for k, v in args.items() if k in include}
401
400
 
402
- @staticmethod
403
- def _reset_callbacks():
401
+ def _reset_callbacks(self):
404
402
  for event in callbacks.default_callbacks.keys():
405
- callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
403
+ self.callbacks[event] = [callbacks.default_callbacks[event][0]]
ultralytics/yolo/engine/predictor.py CHANGED
@@ -28,7 +28,6 @@ Usage - formats:
28
28
  yolov8n_paddle_model # PaddlePaddle
29
29
  """
30
30
  import platform
31
- from collections import defaultdict
32
31
  from pathlib import Path
33
32
 
34
33
  import cv2
@@ -75,7 +74,7 @@ class BasePredictor:
75
74
  data_path (str): Path to data.
76
75
  """
77
76
 
78
- def __init__(self, cfg=DEFAULT_CFG, overrides=None):
77
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
79
78
  """
80
79
  Initializes the BasePredictor class.
81
80
 
@@ -104,7 +103,7 @@ class BasePredictor:
104
103
  self.data_path = None
105
104
  self.source_type = None
106
105
  self.batch = None
107
- self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
106
+ self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
108
107
  callbacks.add_integration_callbacks(self)
109
108
 
110
109
  def preprocess(self, img):
@@ -283,3 +282,9 @@ class BasePredictor:
283
282
  def run_callbacks(self, event: str):
284
283
  for callback in self.callbacks.get(event, []):
285
284
  callback(self)
285
+
286
+ def add_callback(self, event: str, func):
287
+ """
288
+ Add callback
289
+ """
290
+ self.callbacks[event].append(func)
ultralytics/yolo/engine/trainer.py CHANGED
@@ -8,7 +8,6 @@ Usage:
8
8
  import os
9
9
  import subprocess
10
10
  import time
11
- from collections import defaultdict
12
11
  from copy import deepcopy
13
12
  from datetime import datetime
14
13
  from pathlib import Path
@@ -26,7 +25,7 @@ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
26
25
  from ultralytics.yolo.cfg import get_cfg
27
26
  from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
28
27
  from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
29
- callbacks, colorstr, emojis, yaml_save)
28
+ callbacks, clean_url, colorstr, emojis, yaml_save)
30
29
  from ultralytics.yolo.utils.autobatch import check_train_batch_size
31
30
  from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
32
31
  from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@@ -72,7 +71,7 @@ class BaseTrainer:
72
71
  csv (Path): Path to results CSV file.
73
72
  """
74
73
 
75
- def __init__(self, cfg=DEFAULT_CFG, overrides=None):
74
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
76
75
  """
77
76
  Initializes the BaseTrainer class.
78
77
 
@@ -124,7 +123,7 @@ class BaseTrainer:
124
123
  if 'yaml_file' in self.data:
125
124
  self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
126
125
  except Exception as e:
127
- raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
126
+ raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
128
127
 
129
128
  self.trainset, self.testset = self.get_dataset(self.data)
130
129
  self.ema = None
@@ -143,7 +142,7 @@ class BaseTrainer:
143
142
  self.plot_idx = [0, 1, 2]
144
143
 
145
144
  # Callbacks
146
- self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
145
+ self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
147
146
  if RANK in (-1, 0):
148
147
  callbacks.add_integration_callbacks(self)
149
148
 
ultralytics/yolo/engine/validator.py CHANGED
@@ -19,7 +19,6 @@ Usage - formats:
19
19
  yolov8n_paddle_model # PaddlePaddle
20
20
  """
21
21
  import json
22
- from collections import defaultdict
23
22
  from pathlib import Path
24
23
 
25
24
  import torch
@@ -55,7 +54,7 @@ class BaseValidator:
55
54
  save_dir (Path): Directory to save results.
56
55
  """
57
56
 
58
- def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
57
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
59
58
  """
60
59
  Initializes a BaseValidator instance.
61
60
 
@@ -85,7 +84,7 @@ class BaseValidator:
85
84
  if self.args.conf is None:
86
85
  self.args.conf = 0.001 # default conf=0.001
87
86
 
88
- self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
87
+ self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
89
88
 
90
89
  @smart_inference_mode()
91
90
  def __call__(self, trainer=None, model=None):
@@ -195,6 +194,12 @@ class BaseValidator:
195
194
  LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
196
195
  return stats
197
196
 
197
+ def add_callback(self, event: str, callback):
198
+ """
199
+ Appends the given callback.
200
+ """
201
+ self.callbacks[event].append(callback)
202
+
198
203
  def run_callbacks(self, event: str):
199
204
  for callback in self.callbacks.get(event, []):
200
205
  callback(self)
ultralytics/yolo/utils/__init__.py CHANGED
@@ -10,6 +10,7 @@ import subprocess
10
10
  import sys
11
11
  import tempfile
12
12
  import threading
13
+ import urllib
13
14
  import uuid
14
15
  from pathlib import Path
15
16
  from types import SimpleNamespace
@@ -165,7 +166,7 @@ class IterableSimpleNamespace(SimpleNamespace):
165
166
  def set_logging(name=LOGGING_NAME, verbose=True):
166
167
  # sets up logging for the given name
167
168
  rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
168
- level = logging.INFO if verbose and rank in (-1, 0) else logging.ERROR
169
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
169
170
  logging.config.dictConfig({
170
171
  'version': 1,
171
172
  'disable_existing_loggers': False,
@@ -649,10 +650,20 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
649
650
 
650
651
  def deprecation_warn(arg, new_arg, version=None):
651
652
  if not version:
652
- version = float(__version__[0:3]) + 0.2 # deprecate after 2nd major release
653
- LOGGER.warning(
654
- f'WARNING: `{arg}` is deprecated and will be removed in upcoming major release {version}. Use `{new_arg}` instead'
655
- )
653
+ version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
654
+ LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
655
+ f"Please use '{new_arg}' instead.")
656
+
657
+
658
+ def clean_url(url):
659
+ # Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt
660
+ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
661
+ return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
662
+
663
+
664
+ def url2file(url):
665
+ # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
666
+ return Path(clean_url(url)).name
656
667
 
657
668
 
658
669
  # Run below code on yolo/utils init ------------------------------------------------------------------------------------
ultralytics/yolo/utils/callbacks/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from .base import add_integration_callbacks, default_callbacks
1
+ from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
2
2
 
3
- __all__ = 'add_integration_callbacks', 'default_callbacks'
3
+ __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
ultralytics/yolo/utils/callbacks/base.py CHANGED
@@ -2,6 +2,8 @@
2
2
  """
3
3
  Base callbacks
4
4
  """
5
+ from collections import defaultdict
6
+ from copy import deepcopy
5
7
 
6
8
 
7
9
  # Trainer callbacks ----------------------------------------------------------------------------------------------------
@@ -143,6 +145,10 @@ default_callbacks = {
143
145
  'on_export_end': [on_export_end]}
144
146
 
145
147
 
148
+ def get_default_callbacks():
149
+ return defaultdict(list, deepcopy(default_callbacks))
150
+
151
+
146
152
  def add_integration_callbacks(instance):
147
153
  from .clearml import callbacks as clearml_callbacks
148
154
  from .comet import callbacks as comet_callbacks
ultralytics/yolo/utils/callbacks/clearml.py CHANGED
@@ -1,10 +1,17 @@
1
1
  # Ultralytics YOLO 🚀, GPL-3.0 license
2
+ import re
3
+
4
+ import matplotlib.image as mpimg
5
+ import matplotlib.pyplot as plt
6
+
2
7
  from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
3
8
  from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
4
9
 
5
10
  try:
6
11
  import clearml
7
12
  from clearml import Task
13
+ from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
14
+ from clearml.binding.matplotlib_bind import PatchedMatplotlib
8
15
 
9
16
  assert hasattr(clearml, '__version__') # verify package is not directory
10
17
  assert not TESTS_RUNNING # do not log pytest
@@ -12,21 +19,61 @@ except (ImportError, AssertionError):
12
19
  clearml = None
13
20
 
14
21
 
15
- def _log_images(imgs_dict, group='', step=0):
16
- task = Task.current_task()
17
- if task:
18
- for k, v in imgs_dict.items():
19
- task.get_logger().report_image(group, k, step, v)
22
+ def _log_debug_samples(files, title='Debug Samples'):
23
+ """
24
+ Log files (images) as debug samples in the ClearML task.
25
+
26
+ arguments:
27
+ files (List(PosixPath)) a list of file paths in PosixPath format
28
+ title (str) A title that groups together images with the same values
29
+ """
30
+ for f in files:
31
+ if f.exists():
32
+ it = re.search(r'_batch(\d+)', f.name)
33
+ iteration = int(it.groups()[0]) if it else 0
34
+ Task.current_task().get_logger().report_image(title=title,
35
+ series=f.name.replace(it.group(), ''),
36
+ local_path=str(f),
37
+ iteration=iteration)
38
+
39
+
40
+ def _log_plot(title, plot_path):
41
+ """
42
+ Log image as plot in the plot section of ClearML
43
+
44
+ arguments:
45
+ title (str) Title of the plot
46
+ plot_path (PosixPath or str) Path to the saved image file
47
+ """
48
+ img = mpimg.imread(plot_path)
49
+ fig = plt.figure()
50
+ ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
51
+ ax.imshow(img)
52
+
53
+ Task.current_task().get_logger().report_matplotlib_figure(title, '', figure=fig, report_interactive=False)
20
54
 
21
55
 
22
56
  def on_pretrain_routine_start(trainer):
57
+ # TODO: reuse existing task
23
58
  try:
24
- task = Task.init(project_name=trainer.args.project or 'YOLOv8',
25
- task_name=trainer.args.name,
26
- tags=['YOLOv8'],
27
- output_uri=True,
28
- reuse_last_task_id=False,
29
- auto_connect_frameworks={'pytorch': False})
59
+ if Task.current_task():
60
+ task = Task.current_task()
61
+
62
+ # Make sure the automatic pytorch and matplotlib bindings are disabled!
63
+ # We are logging these plots and model files manually in the integration
64
+ PatchPyTorchModelIO.update_current_task(None)
65
+ PatchedMatplotlib.update_current_task(None)
66
+ else:
67
+ task = Task.init(project_name=trainer.args.project or 'YOLOv8',
68
+ task_name=trainer.args.name,
69
+ tags=['YOLOv8'],
70
+ output_uri=True,
71
+ reuse_last_task_id=False,
72
+ auto_connect_frameworks={
73
+ 'pytorch': False,
74
+ 'matplotlib': False})
75
+ LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
76
+ 'please add clearml-init and connect your arguments before initializing YOLO.')
30
77
  task.connect(vars(trainer.args), name='General')
31
78
  except Exception as e:
32
79
  LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
@@ -34,27 +81,47 @@ def on_pretrain_routine_start(trainer):
34
81
 
35
82
  def on_train_epoch_end(trainer):
36
83
  if trainer.epoch == 1:
37
- _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
84
+ _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
38
85
 
39
86
 
40
87
  def on_fit_epoch_end(trainer):
41
- task = Task.current_task()
42
- if task and trainer.epoch == 0:
88
+ # You should have access to the validation bboxes under jdict
89
+ Task.current_task().get_logger().report_scalar(title='Epoch Time',
90
+ series='Epoch Time',
91
+ value=trainer.epoch_time,
92
+ iteration=trainer.epoch)
93
+ if trainer.epoch == 0:
43
94
  model_info = {
44
95
  'model/parameters': get_num_params(trainer.model),
45
96
  'model/GFLOPs': round(get_flops(trainer.model), 3),
46
97
  'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
47
- task.connect(model_info, name='Model')
98
+ for k, v in model_info.items():
99
+ Task.current_task().get_logger().report_single_value(k, v)
100
+
101
+
102
+ def on_val_end(validator):
103
+ # Log val_labels and val_pred
104
+ _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
48
105
 
49
106
 
50
107
  def on_train_end(trainer):
51
- task = Task.current_task()
52
- if task:
53
- task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
108
+ # Log final results, CM matrix + PR plots
109
+ files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
110
+ files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
111
+ for f in files:
112
+ _log_plot(title=f.stem, plot_path=f)
113
+ # Report final metrics
114
+ for k, v in trainer.validator.metrics.results_dict.items():
115
+ Task.current_task().get_logger().report_single_value(k, v)
116
+ # Log the final model
117
+ Task.current_task().update_output_model(model_path=str(trainer.best),
118
+ model_name=trainer.args.name,
119
+ auto_delete_file=False)
54
120
 
55
121
 
56
122
  callbacks = {
57
123
  'on_pretrain_routine_start': on_pretrain_routine_start,
58
124
  'on_train_epoch_end': on_train_epoch_end,
59
125
  'on_fit_epoch_end': on_fit_epoch_end,
126
+ 'on_val_end': on_val_end,
60
127
  'on_train_end': on_train_end} if clearml else {}