ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic; see the original registry page for more details.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -5,10 +5,11 @@ import sys
5
5
  from pathlib import Path
6
6
  from typing import Union
7
7
 
8
+ from hub_sdk.config import HUB_WEB_ROOT
9
+
8
10
  from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
9
- from ultralytics.hub.utils import HUB_WEB_ROOT
10
11
  from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
11
- from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
12
+ from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
12
13
 
13
14
 
14
15
  class Model(nn.Module):
@@ -52,7 +53,7 @@ class Model(nn.Module):
52
53
  list(ultralytics.engine.results.Results): The prediction results.
53
54
  """
54
55
 
55
- def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
56
+ def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None:
56
57
  """
57
58
  Initializes the YOLO model.
58
59
 
@@ -76,8 +77,8 @@ class Model(nn.Module):
76
77
 
77
78
  # Check if Ultralytics HUB model from https://hub.ultralytics.com
78
79
  if self.is_hub_model(model):
79
- from ultralytics.hub.session import HUBTrainingSession
80
- self.session = HUBTrainingSession(model)
80
+ # Fetch model from HUB
81
+ self.session = self._get_hub_session(model)
81
82
  model = self.session.model_file
82
83
 
83
84
  # Check if Triton Server model
@@ -88,29 +89,43 @@ class Model(nn.Module):
88
89
 
89
90
  # Load or create new YOLO model
90
91
  model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
91
- if Path(model).suffix in ('.yaml', '.yml'):
92
+ if Path(model).suffix in (".yaml", ".yml"):
92
93
  self._new(model, task)
93
94
  else:
94
95
  self._load(model, task)
95
96
 
97
+ self.model_name = model
98
+
96
99
  def __call__(self, source=None, stream=False, **kwargs):
97
100
  """Calls the predict() method with given arguments to perform object detection."""
98
101
  return self.predict(source, stream, **kwargs)
99
102
 
103
+ @staticmethod
104
+ def _get_hub_session(model: str):
105
+ """Creates a session for Hub Training."""
106
+ from ultralytics.hub.session import HUBTrainingSession
107
+
108
+ session = HUBTrainingSession(model)
109
+ return session if session.client.authenticated else None
110
+
100
111
  @staticmethod
101
112
  def is_triton_model(model):
102
113
  """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
103
114
  from urllib.parse import urlsplit
115
+
104
116
  url = urlsplit(model)
105
- return url.netloc and url.path and url.scheme in {'http', 'grpc'}
117
+ return url.netloc and url.path and url.scheme in {"http", "grpc"}
106
118
 
107
119
  @staticmethod
108
120
  def is_hub_model(model):
109
121
  """Check if the provided model is a HUB model."""
110
- return any((
111
- model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
112
- [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
113
- len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
122
+ return any(
123
+ (
124
+ model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
125
+ [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
126
+ len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
127
+ )
128
+ ) # MODELID
114
129
 
115
130
  def _new(self, cfg: str, task=None, model=None, verbose=True):
116
131
  """
@@ -125,9 +140,9 @@ class Model(nn.Module):
125
140
  cfg_dict = yaml_model_load(cfg)
126
141
  self.cfg = cfg
127
142
  self.task = task or guess_model_task(cfg_dict)
128
- self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1) # build model
129
- self.overrides['model'] = self.cfg
130
- self.overrides['task'] = self.task
143
+ self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
144
+ self.overrides["model"] = self.cfg
145
+ self.overrides["task"] = self.task
131
146
 
132
147
  # Below added to allow export from YAMLs
133
148
  self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
@@ -142,9 +157,9 @@ class Model(nn.Module):
142
157
  task (str | None): model task
143
158
  """
144
159
  suffix = Path(weights).suffix
145
- if suffix == '.pt':
160
+ if suffix == ".pt":
146
161
  self.model, self.ckpt = attempt_load_one_weight(weights)
147
- self.task = self.model.args['task']
162
+ self.task = self.model.args["task"]
148
163
  self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
149
164
  self.ckpt_path = self.model.pt_path
150
165
  else:
@@ -152,12 +167,12 @@ class Model(nn.Module):
152
167
  self.model, self.ckpt = weights, None
153
168
  self.task = task or guess_model_task(weights)
154
169
  self.ckpt_path = weights
155
- self.overrides['model'] = weights
156
- self.overrides['task'] = self.task
170
+ self.overrides["model"] = weights
171
+ self.overrides["task"] = self.task
157
172
 
158
173
  def _check_is_pytorch_model(self):
159
174
  """Raises TypeError is model is not a PyTorch model."""
160
- pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
175
+ pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
161
176
  pt_module = isinstance(self.model, nn.Module)
162
177
  if not (pt_module or pt_str):
163
178
  raise TypeError(
@@ -165,19 +180,20 @@ class Model(nn.Module):
165
180
  f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
166
181
  f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
167
182
  f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
168
- f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")
183
+ f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
184
+ )
169
185
 
170
186
  def reset_weights(self):
171
187
  """Resets the model modules parameters to randomly initialized values, losing all training information."""
172
188
  self._check_is_pytorch_model()
173
189
  for m in self.model.modules():
174
- if hasattr(m, 'reset_parameters'):
190
+ if hasattr(m, "reset_parameters"):
175
191
  m.reset_parameters()
176
192
  for p in self.model.parameters():
177
193
  p.requires_grad = True
178
194
  return self
179
195
 
180
- def load(self, weights='yolov8n.pt'):
196
+ def load(self, weights="yolov8n.pt"):
181
197
  """Transfers parameters with matching names and shapes from 'weights' to model."""
182
198
  self._check_is_pytorch_model()
183
199
  if isinstance(weights, (str, Path)):
@@ -215,8 +231,8 @@ class Model(nn.Module):
215
231
  Returns:
216
232
  (List[torch.Tensor]): A list of image embeddings.
217
233
  """
218
- if not kwargs.get('embed'):
219
- kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
234
+ if not kwargs.get("embed"):
235
+ kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
220
236
  return self.predict(source, stream, **kwargs)
221
237
 
222
238
  def predict(self, source=None, stream=False, predictor=None, **kwargs):
@@ -238,21 +254,22 @@ class Model(nn.Module):
238
254
  source = ASSETS
239
255
  LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
240
256
 
241
- is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
242
- x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
257
+ is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
258
+ x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
259
+ )
243
260
 
244
- custom = {'conf': 0.25, 'save': is_cli} # method defaults
245
- args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right
246
- prompts = args.pop('prompts', None) # for SAM-type models
261
+ custom = {"conf": 0.25, "save": is_cli} # method defaults
262
+ args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right
263
+ prompts = args.pop("prompts", None) # for SAM-type models
247
264
 
248
265
  if not self.predictor:
249
- self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks)
266
+ self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
250
267
  self.predictor.setup_model(model=self.model, verbose=is_cli)
251
268
  else: # only update args if predictor is already setup
252
269
  self.predictor.args = get_cfg(self.predictor.args, args)
253
- if 'project' in args or 'name' in args:
270
+ if "project" in args or "name" in args:
254
271
  self.predictor.save_dir = get_save_dir(self.predictor.args)
255
- if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models
272
+ if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
256
273
  self.predictor.set_prompts(prompts)
257
274
  return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
258
275
 
@@ -269,11 +286,12 @@ class Model(nn.Module):
269
286
  Returns:
270
287
  (List[ultralytics.engine.results.Results]): The tracking results.
271
288
  """
272
- if not hasattr(self.predictor, 'trackers'):
289
+ if not hasattr(self.predictor, "trackers"):
273
290
  from ultralytics.trackers import register_tracker
291
+
274
292
  register_tracker(self, persist)
275
- kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input
276
- kwargs['mode'] = 'track'
293
+ kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
294
+ kwargs["mode"] = "track"
277
295
  return self.predict(source=source, stream=stream, **kwargs)
278
296
 
279
297
  def val(self, validator=None, **kwargs):
@@ -284,10 +302,10 @@ class Model(nn.Module):
284
302
  validator (BaseValidator): Customized validator.
285
303
  **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
286
304
  """
287
- custom = {'rect': True} # method defaults
288
- args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
305
+ custom = {"rect": True} # method defaults
306
+ args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
289
307
 
290
- validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
308
+ validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
291
309
  validator(model=self.model)
292
310
  self.metrics = validator.metrics
293
311
  return validator.metrics
@@ -302,16 +320,17 @@ class Model(nn.Module):
302
320
  self._check_is_pytorch_model()
303
321
  from ultralytics.utils.benchmarks import benchmark
304
322
 
305
- custom = {'verbose': False} # method defaults
306
- args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
323
+ custom = {"verbose": False} # method defaults
324
+ args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
307
325
  return benchmark(
308
326
  model=self,
309
- data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
310
- imgsz=args['imgsz'],
311
- half=args['half'],
312
- int8=args['int8'],
313
- device=args['device'],
314
- verbose=kwargs.get('verbose'))
327
+ data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
328
+ imgsz=args["imgsz"],
329
+ half=args["half"],
330
+ int8=args["int8"],
331
+ device=args["device"],
332
+ verbose=kwargs.get("verbose"),
333
+ )
315
334
 
316
335
  def export(self, **kwargs):
317
336
  """
@@ -323,8 +342,8 @@ class Model(nn.Module):
323
342
  self._check_is_pytorch_model()
324
343
  from .exporter import Exporter
325
344
 
326
- custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults
327
- args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
345
+ custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
346
+ args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
328
347
  return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
329
348
 
330
349
  def train(self, trainer=None, **kwargs):
@@ -336,22 +355,37 @@ class Model(nn.Module):
336
355
  **kwargs (Any): Any number of arguments representing the training configuration.
337
356
  """
338
357
  self._check_is_pytorch_model()
339
- if self.session: # Ultralytics HUB session
358
+ if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
340
359
  if any(kwargs):
341
- LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
342
- kwargs = self.session.train_args
360
+ LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
361
+ kwargs = self.session.train_args # overwrite kwargs
362
+
343
363
  checks.check_pip_update_available()
344
364
 
345
- overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
346
- custom = {'data': DEFAULT_CFG_DICT['data'] or TASK2DATA[self.task]} # method defaults
347
- args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
348
- if args.get('resume'):
349
- args['resume'] = self.ckpt_path
365
+ overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
366
+ custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
367
+ args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
368
+ if args.get("resume"):
369
+ args["resume"] = self.ckpt_path
350
370
 
351
- self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
352
- if not args.get('resume'): # manually set model only if not resuming
371
+ self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
372
+ if not args.get("resume"): # manually set model only if not resuming
353
373
  self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
354
374
  self.model = self.trainer.model
375
+
376
+ if SETTINGS["hub"] is True and not self.session:
377
+ # Create a model in HUB
378
+ try:
379
+ self.session = self._get_hub_session(self.model_name)
380
+ if self.session:
381
+ self.session.create_model(args)
382
+ # Check model was created
383
+ if not getattr(self.session.model, "id", None):
384
+ self.session = None
385
+ except PermissionError:
386
+ # Ignore permission error
387
+ pass
388
+
355
389
  self.trainer.hub_session = self.session # attach optional HUB session
356
390
  self.trainer.train()
357
391
  # Update model and cfg after training
@@ -359,7 +393,7 @@ class Model(nn.Module):
359
393
  ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
360
394
  self.model, _ = attempt_load_one_weight(ckpt)
361
395
  self.overrides = self.model.args
362
- self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
396
+ self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
363
397
  return self.metrics
364
398
 
365
399
  def tune(self, use_ray=False, iterations=10, *args, **kwargs):
@@ -372,12 +406,13 @@ class Model(nn.Module):
372
406
  self._check_is_pytorch_model()
373
407
  if use_ray:
374
408
  from ultralytics.utils.tuner import run_ray_tune
409
+
375
410
  return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
376
411
  else:
377
412
  from .tuner import Tuner
378
413
 
379
414
  custom = {} # method defaults
380
- args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
415
+ args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
381
416
  return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
382
417
 
383
418
  def _apply(self, fn):
@@ -385,13 +420,13 @@ class Model(nn.Module):
385
420
  self._check_is_pytorch_model()
386
421
  self = super()._apply(fn) # noqa
387
422
  self.predictor = None # reset predictor as device may have changed
388
- self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
423
+ self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
389
424
  return self
390
425
 
391
426
  @property
392
427
  def names(self):
393
428
  """Returns class names of the loaded model."""
394
- return self.model.names if hasattr(self.model, 'names') else None
429
+ return self.model.names if hasattr(self.model, "names") else None
395
430
 
396
431
  @property
397
432
  def device(self):
@@ -401,7 +436,7 @@ class Model(nn.Module):
401
436
  @property
402
437
  def transforms(self):
403
438
  """Returns transform of the loaded model."""
404
- return self.model.transforms if hasattr(self.model, 'transforms') else None
439
+ return self.model.transforms if hasattr(self.model, "transforms") else None
405
440
 
406
441
  def add_callback(self, event: str, func):
407
442
  """Add a callback."""
@@ -419,7 +454,7 @@ class Model(nn.Module):
419
454
  @staticmethod
420
455
  def _reset_ckpt_args(args):
421
456
  """Reset arguments when loading a PyTorch model."""
422
- include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
457
+ include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
423
458
  return {k: v for k, v in args.items() if k in include}
424
459
 
425
460
  # def __getattr__(self, attr):
@@ -435,7 +470,8 @@ class Model(nn.Module):
435
470
  name = self.__class__.__name__
436
471
  mode = inspect.stack()[1][3] # get the function name.
437
472
  raise NotImplementedError(
438
- emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
473
+ emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
474
+ ) from e
439
475
 
440
476
  @property
441
477
  def task_map(self):
@@ -445,4 +481,4 @@ class Model(nn.Module):
445
481
  Returns:
446
482
  task_map (dict): The map of model task to mode classes.
447
483
  """
448
- raise NotImplementedError('Please provide task map for your model!')
484
+ raise NotImplementedError("Please provide task map for your model!")
@@ -132,8 +132,11 @@ class BasePredictor:
132
132
 
133
133
  def inference(self, im, *args, **kwargs):
134
134
  """Runs inference on a given image using the specified model and arguments."""
135
- visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
136
- mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
135
+ visualize = (
136
+ increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
137
+ if self.args.visualize and (not self.source_type.tensor)
138
+ else False
139
+ )
137
140
  return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
138
141
 
139
142
  def pre_transform(self, im):
@@ -153,35 +156,38 @@ class BasePredictor:
153
156
  def write_results(self, idx, results, batch):
154
157
  """Write inference results to a file or directory."""
155
158
  p, im, _ = batch
156
- log_string = ''
159
+ log_string = ""
157
160
  if len(im.shape) == 3:
158
161
  im = im[None] # expand for batch dim
159
162
  if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
160
- log_string += f'{idx}: '
163
+ log_string += f"{idx}: "
161
164
  frame = self.dataset.count
162
165
  else:
163
- frame = getattr(self.dataset, 'frame', 0)
166
+ frame = getattr(self.dataset, "frame", 0)
164
167
  self.data_path = p
165
- self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
166
- log_string += '%gx%g ' % im.shape[2:] # print string
168
+ self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
169
+ log_string += "%gx%g " % im.shape[2:] # print string
167
170
  result = results[idx]
168
171
  log_string += result.verbose()
169
172
 
170
173
  if self.args.save or self.args.show: # Add bbox to image
171
174
  plot_args = {
172
- 'line_width': self.args.line_width,
173
- 'boxes': self.args.show_boxes,
174
- 'conf': self.args.show_conf,
175
- 'labels': self.args.show_labels}
175
+ "line_width": self.args.line_width,
176
+ "boxes": self.args.show_boxes,
177
+ "conf": self.args.show_conf,
178
+ "labels": self.args.show_labels,
179
+ }
176
180
  if not self.args.retina_masks:
177
- plot_args['im_gpu'] = im[idx]
181
+ plot_args["im_gpu"] = im[idx]
178
182
  self.plotted_img = result.plot(**plot_args)
179
183
  # Write
180
184
  if self.args.save_txt:
181
- result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
185
+ result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
182
186
  if self.args.save_crop:
183
- result.save_crop(save_dir=self.save_dir / 'crops',
184
- file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
187
+ result.save_crop(
188
+ save_dir=self.save_dir / "crops",
189
+ file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
190
+ )
185
191
 
186
192
  return log_string
187
193
 
@@ -210,17 +216,24 @@ class BasePredictor:
210
216
  def setup_source(self, source):
211
217
  """Sets up source and inference mode."""
212
218
  self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
213
- self.transforms = getattr(
214
- self.model.model, 'transforms', classify_transforms(
215
- self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None
216
- self.dataset = load_inference_source(source=source,
217
- imgsz=self.imgsz,
218
- vid_stride=self.args.vid_stride,
219
- buffer=self.args.stream_buffer)
219
+ self.transforms = (
220
+ getattr(
221
+ self.model.model,
222
+ "transforms",
223
+ classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
224
+ )
225
+ if self.args.task == "classify"
226
+ else None
227
+ )
228
+ self.dataset = load_inference_source(
229
+ source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
230
+ )
220
231
  self.source_type = self.dataset.source_type
221
- if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
222
- len(self.dataset) > 1000 or # images
223
- any(getattr(self.dataset, 'video_flag', [False]))): # videos
232
+ if not getattr(self, "stream", True) and (
233
+ self.dataset.mode == "stream" # streams
234
+ or len(self.dataset) > 1000 # images
235
+ or any(getattr(self.dataset, "video_flag", [False]))
236
+ ): # videos
224
237
  LOGGER.warning(STREAM_WARNING)
225
238
  self.vid_path = [None] * self.dataset.bs
226
239
  self.vid_writer = [None] * self.dataset.bs
@@ -230,7 +243,7 @@ class BasePredictor:
230
243
  def stream_inference(self, source=None, model=None, *args, **kwargs):
231
244
  """Streams real-time inference on camera feed and saves results to file."""
232
245
  if self.args.verbose:
233
- LOGGER.info('')
246
+ LOGGER.info("")
234
247
 
235
248
  # Setup model
236
249
  if not self.model:
@@ -242,7 +255,7 @@ class BasePredictor:
242
255
 
243
256
  # Check if save_dir/ label file exists
244
257
  if self.args.save or self.args.save_txt:
245
- (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
258
+ (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
246
259
 
247
260
  # Warmup model
248
261
  if not self.done_warmup:
@@ -250,10 +263,10 @@ class BasePredictor:
250
263
  self.done_warmup = True
251
264
 
252
265
  self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
253
- self.run_callbacks('on_predict_start')
266
+ self.run_callbacks("on_predict_start")
254
267
 
255
268
  for batch in self.dataset:
256
- self.run_callbacks('on_predict_batch_start')
269
+ self.run_callbacks("on_predict_batch_start")
257
270
  self.batch = batch
258
271
  path, im0s, vid_cap, s = batch
259
272
 
@@ -272,15 +285,16 @@ class BasePredictor:
272
285
  with profilers[2]:
273
286
  self.results = self.postprocess(preds, im, im0s)
274
287
 
275
- self.run_callbacks('on_predict_postprocess_end')
288
+ self.run_callbacks("on_predict_postprocess_end")
276
289
  # Visualize, save, write results
277
290
  n = len(im0s)
278
291
  for i in range(n):
279
292
  self.seen += 1
280
293
  self.results[i].speed = {
281
- 'preprocess': profilers[0].dt * 1E3 / n,
282
- 'inference': profilers[1].dt * 1E3 / n,
283
- 'postprocess': profilers[2].dt * 1E3 / n}
294
+ "preprocess": profilers[0].dt * 1e3 / n,
295
+ "inference": profilers[1].dt * 1e3 / n,
296
+ "postprocess": profilers[2].dt * 1e3 / n,
297
+ }
284
298
  p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
285
299
  p = Path(p)
286
300
 
@@ -293,12 +307,12 @@ class BasePredictor:
293
307
  if self.args.save and self.plotted_img is not None:
294
308
  self.save_preds(vid_cap, i, str(self.save_dir / p.name))
295
309
 
296
- self.run_callbacks('on_predict_batch_end')
310
+ self.run_callbacks("on_predict_batch_end")
297
311
  yield from self.results
298
312
 
299
313
  # Print time (inference-only)
300
314
  if self.args.verbose:
301
- LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
315
+ LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")
302
316
 
303
317
  # Release assets
304
318
  if isinstance(self.vid_writer[-1], cv2.VideoWriter):
@@ -306,25 +320,29 @@ class BasePredictor:
306
320
 
307
321
  # Print results
308
322
  if self.args.verbose and self.seen:
309
- t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
310
- LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
311
- f'{(1, 3, *im.shape[2:])}' % t)
323
+ t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
324
+ LOGGER.info(
325
+ f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
326
+ f"{(1, 3, *im.shape[2:])}" % t
327
+ )
312
328
  if self.args.save or self.args.save_txt or self.args.save_crop:
313
- nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
314
- s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
329
+ nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
330
+ s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
315
331
  LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
316
332
 
317
- self.run_callbacks('on_predict_end')
333
+ self.run_callbacks("on_predict_end")
318
334
 
319
335
  def setup_model(self, model, verbose=True):
320
336
  """Initialize YOLO model with given parameters and set it to evaluation mode."""
321
- self.model = AutoBackend(model or self.args.model,
322
- device=select_device(self.args.device, verbose=verbose),
323
- dnn=self.args.dnn,
324
- data=self.args.data,
325
- fp16=self.args.half,
326
- fuse=True,
327
- verbose=verbose)
337
+ self.model = AutoBackend(
338
+ model or self.args.model,
339
+ device=select_device(self.args.device, verbose=verbose),
340
+ dnn=self.args.dnn,
341
+ data=self.args.data,
342
+ fp16=self.args.half,
343
+ fuse=True,
344
+ verbose=verbose,
345
+ )
328
346
 
329
347
  self.device = self.model.device # update device
330
348
  self.args.half = self.model.fp16 # update half
@@ -333,18 +351,18 @@ class BasePredictor:
333
351
  def show(self, p):
334
352
  """Display an image in a window using OpenCV imshow()."""
335
353
  im0 = self.plotted_img
336
- if platform.system() == 'Linux' and p not in self.windows:
354
+ if platform.system() == "Linux" and p not in self.windows:
337
355
  self.windows.append(p)
338
356
  cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
339
357
  cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
340
358
  cv2.imshow(str(p), im0)
341
- cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
359
+ cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond
342
360
 
343
361
  def save_preds(self, vid_cap, idx, save_path):
344
362
  """Save video predictions as mp4 at specified path."""
345
363
  im0 = self.plotted_img
346
364
  # Save imgs
347
- if self.dataset.mode == 'image':
365
+ if self.dataset.mode == "image":
348
366
  cv2.imwrite(save_path, im0)
349
367
  else: # 'video' or 'stream'
350
368
  frames_path = f'{save_path.split(".", 1)[0]}_frames/'
@@ -361,15 +379,16 @@ class BasePredictor:
361
379
  h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
362
380
  else: # stream
363
381
  fps, w, h = 30, im0.shape[1], im0.shape[0]
364
- suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
365
- self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)),
366
- cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
382
+ suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
383
+ self.vid_writer[idx] = cv2.VideoWriter(
384
+ str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
385
+ )
367
386
  # Write video
368
387
  self.vid_writer[idx].write(im0)
369
388
 
370
389
  # Write frame
371
390
  if self.args.save_frames:
372
- cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0)
391
+ cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
373
392
  self.vid_frame[idx] += 1
374
393
 
375
394
  def run_callbacks(self, event: str):