ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -7,16 +7,54 @@ from pathlib import Path
7
7
  import torch
8
8
  import torch.nn as nn
9
9
 
10
- from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
11
- C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
12
- DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
13
- RepConv, ResNetLayer, RTDETRDecoder, Segment)
10
+ from ultralytics.nn.modules import (
11
+ AIFI,
12
+ C1,
13
+ C2,
14
+ C3,
15
+ C3TR,
16
+ OBB,
17
+ SPP,
18
+ SPPF,
19
+ Bottleneck,
20
+ BottleneckCSP,
21
+ C2f,
22
+ C3Ghost,
23
+ C3x,
24
+ Classify,
25
+ Concat,
26
+ Conv,
27
+ Conv2,
28
+ ConvTranspose,
29
+ Detect,
30
+ DWConv,
31
+ DWConvTranspose2d,
32
+ Focus,
33
+ GhostBottleneck,
34
+ GhostConv,
35
+ HGBlock,
36
+ HGStem,
37
+ Pose,
38
+ RepC3,
39
+ RepConv,
40
+ ResNetLayer,
41
+ RTDETRDecoder,
42
+ Segment,
43
+ )
14
44
  from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
15
45
  from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
16
46
  from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
17
47
  from ultralytics.utils.plotting import feature_visualization
18
- from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
19
- make_divisible, model_info, scale_img, time_sync)
48
+ from ultralytics.utils.torch_utils import (
49
+ fuse_conv_and_bn,
50
+ fuse_deconv_and_bn,
51
+ initialize_weights,
52
+ intersect_dicts,
53
+ make_divisible,
54
+ model_info,
55
+ scale_img,
56
+ time_sync,
57
+ )
20
58
 
21
59
  try:
22
60
  import thop
@@ -90,8 +128,10 @@ class BaseModel(nn.Module):
90
128
 
91
129
  def _predict_augment(self, x):
92
130
  """Perform augmentations on input image x and return augmented inference."""
93
- LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
94
- f'Reverting to single-scale inference instead.')
131
+ LOGGER.warning(
132
+ f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
133
+ f"Reverting to single-scale inference instead."
134
+ )
95
135
  return self._predict_once(x)
96
136
 
97
137
  def _profile_one_layer(self, m, x, dt):
@@ -108,14 +148,14 @@ class BaseModel(nn.Module):
108
148
  None
109
149
  """
110
150
  c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
111
- flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
151
+ flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
112
152
  t = time_sync()
113
153
  for _ in range(10):
114
154
  m(x.copy() if c else x)
115
155
  dt.append((time_sync() - t) * 100)
116
156
  if m == self.model[0]:
117
157
  LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
118
- LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
158
+ LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
119
159
  if c:
120
160
  LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
121
161
 
@@ -129,15 +169,15 @@ class BaseModel(nn.Module):
129
169
  """
130
170
  if not self.is_fused():
131
171
  for m in self.model.modules():
132
- if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
172
+ if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
133
173
  if isinstance(m, Conv2):
134
174
  m.fuse_convs()
135
175
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
136
- delattr(m, 'bn') # remove batchnorm
176
+ delattr(m, "bn") # remove batchnorm
137
177
  m.forward = m.forward_fuse # update forward
138
- if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
178
+ if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
139
179
  m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
140
- delattr(m, 'bn') # remove batchnorm
180
+ delattr(m, "bn") # remove batchnorm
141
181
  m.forward = m.forward_fuse # update forward
142
182
  if isinstance(m, RepConv):
143
183
  m.fuse_convs()
@@ -156,7 +196,7 @@ class BaseModel(nn.Module):
156
196
  Returns:
157
197
  (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
158
198
  """
159
- bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
199
+ bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
160
200
  return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
161
201
 
162
202
  def info(self, detailed=False, verbose=True, imgsz=640):
@@ -196,12 +236,12 @@ class BaseModel(nn.Module):
196
236
  weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
197
237
  verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
198
238
  """
199
- model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
239
+ model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
200
240
  csd = model.float().state_dict() # checkpoint state_dict as FP32
201
241
  csd = intersect_dicts(csd, self.state_dict()) # intersect
202
242
  self.load_state_dict(csd, strict=False) # load
203
243
  if verbose:
204
- LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
244
+ LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
205
245
 
206
246
  def loss(self, batch, preds=None):
207
247
  """
@@ -211,33 +251,33 @@ class BaseModel(nn.Module):
211
251
  batch (dict): Batch to compute loss on
212
252
  preds (torch.Tensor | List[torch.Tensor]): Predictions.
213
253
  """
214
- if not hasattr(self, 'criterion'):
254
+ if not hasattr(self, "criterion"):
215
255
  self.criterion = self.init_criterion()
216
256
 
217
- preds = self.forward(batch['img']) if preds is None else preds
257
+ preds = self.forward(batch["img"]) if preds is None else preds
218
258
  return self.criterion(preds, batch)
219
259
 
220
260
  def init_criterion(self):
221
261
  """Initialize the loss criterion for the BaseModel."""
222
- raise NotImplementedError('compute_loss() needs to be implemented by task heads')
262
+ raise NotImplementedError("compute_loss() needs to be implemented by task heads")
223
263
 
224
264
 
225
265
  class DetectionModel(BaseModel):
226
266
  """YOLOv8 detection model."""
227
267
 
228
- def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
268
+ def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
229
269
  """Initialize the YOLOv8 detection model with the given config and parameters."""
230
270
  super().__init__()
231
271
  self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
232
272
 
233
273
  # Define model
234
- ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
235
- if nc and nc != self.yaml['nc']:
274
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
275
+ if nc and nc != self.yaml["nc"]:
236
276
  LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
237
- self.yaml['nc'] = nc # override YAML value
277
+ self.yaml["nc"] = nc # override YAML value
238
278
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
239
- self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
240
- self.inplace = self.yaml.get('inplace', True)
279
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
280
+ self.inplace = self.yaml.get("inplace", True)
241
281
 
242
282
  # Build strides
243
283
  m = self.model[-1] # Detect()
@@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
255
295
  initialize_weights(self)
256
296
  if verbose:
257
297
  self.info()
258
- LOGGER.info('')
298
+ LOGGER.info("")
259
299
 
260
300
  def _predict_augment(self, x):
261
301
  """Perform augmentations on input image x and return augmented inference and train outputs."""
@@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
285
325
  def _clip_augmented(self, y):
286
326
  """Clip YOLO augmented inference tails."""
287
327
  nl = self.model[-1].nl # number of detection layers (P3-P5)
288
- g = sum(4 ** x for x in range(nl)) # grid points
328
+ g = sum(4**x for x in range(nl)) # grid points
289
329
  e = 1 # exclude layer count
290
- i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
330
+ i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
291
331
  y[0] = y[0][..., :-i] # large
292
332
  i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
293
333
  y[-1] = y[-1][..., i:] # small
@@ -301,18 +341,19 @@ class DetectionModel(BaseModel):
301
341
  class OBBModel(DetectionModel):
302
342
  """"YOLOv8 Oriented Bounding Box (OBB) model."""
303
343
 
304
- def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
344
+ def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
305
345
  """Initialize YOLOv8 OBB model with given config and parameters."""
306
346
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
307
347
 
308
348
  def init_criterion(self):
349
+ """Initialize the loss criterion for the model."""
309
350
  return v8OBBLoss(self)
310
351
 
311
352
 
312
353
  class SegmentationModel(DetectionModel):
313
354
  """YOLOv8 segmentation model."""
314
355
 
315
- def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
356
+ def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
316
357
  """Initialize YOLOv8 segmentation model with given config and parameters."""
317
358
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
318
359
 
@@ -324,13 +365,13 @@ class SegmentationModel(DetectionModel):
324
365
  class PoseModel(DetectionModel):
325
366
  """YOLOv8 pose model."""
326
367
 
327
- def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
368
+ def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
328
369
  """Initialize YOLOv8 Pose model."""
329
370
  if not isinstance(cfg, dict):
330
371
  cfg = yaml_model_load(cfg) # load model YAML
331
- if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
372
+ if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
332
373
  LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
333
- cfg['kpt_shape'] = data_kpt_shape
374
+ cfg["kpt_shape"] = data_kpt_shape
334
375
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
335
376
 
336
377
  def init_criterion(self):
@@ -341,7 +382,7 @@ class PoseModel(DetectionModel):
341
382
  class ClassificationModel(BaseModel):
342
383
  """YOLOv8 classification model."""
343
384
 
344
- def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
385
+ def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
345
386
  """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
346
387
  super().__init__()
347
388
  self._from_yaml(cfg, ch, nc, verbose)
@@ -351,21 +392,21 @@ class ClassificationModel(BaseModel):
351
392
  self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
352
393
 
353
394
  # Define model
354
- ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
355
- if nc and nc != self.yaml['nc']:
395
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
396
+ if nc and nc != self.yaml["nc"]:
356
397
  LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
357
- self.yaml['nc'] = nc # override YAML value
358
- elif not nc and not self.yaml.get('nc', None):
359
- raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
398
+ self.yaml["nc"] = nc # override YAML value
399
+ elif not nc and not self.yaml.get("nc", None):
400
+ raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
360
401
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
361
402
  self.stride = torch.Tensor([1]) # no stride constraints
362
- self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
403
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
363
404
  self.info()
364
405
 
365
406
  @staticmethod
366
407
  def reshape_outputs(model, nc):
367
408
  """Update a TorchVision classification model to class count 'n' if required."""
368
- name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
409
+ name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
369
410
  if isinstance(m, Classify): # YOLO Classify() head
370
411
  if m.linear.out_features != nc:
371
412
  m.linear = nn.Linear(m.linear.in_features, nc)
@@ -408,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
408
449
  predict: Performs a forward pass through the network and returns the output.
409
450
  """
410
451
 
411
- def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
452
+ def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
412
453
  """
413
454
  Initialize the RTDETRDetectionModel.
414
455
 
@@ -437,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
437
478
  Returns:
438
479
  (tuple): A tuple containing the total loss and main three losses in a tensor.
439
480
  """
440
- if not hasattr(self, 'criterion'):
481
+ if not hasattr(self, "criterion"):
441
482
  self.criterion = self.init_criterion()
442
483
 
443
- img = batch['img']
484
+ img = batch["img"]
444
485
  # NOTE: preprocess gt_bbox and gt_labels to list.
445
486
  bs = len(img)
446
- batch_idx = batch['batch_idx']
487
+ batch_idx = batch["batch_idx"]
447
488
  gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
448
489
  targets = {
449
- 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
450
- 'bboxes': batch['bboxes'].to(device=img.device),
451
- 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
452
- 'gt_groups': gt_groups}
490
+ "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
491
+ "bboxes": batch["bboxes"].to(device=img.device),
492
+ "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
493
+ "gt_groups": gt_groups,
494
+ }
453
495
 
454
496
  preds = self.predict(img, batch=targets) if preds is None else preds
455
497
  dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
456
498
  if dn_meta is None:
457
499
  dn_bboxes, dn_scores = None, None
458
500
  else:
459
- dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
460
- dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
501
+ dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
502
+ dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
461
503
 
462
504
  dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
463
505
  dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
464
506
 
465
- loss = self.criterion((dec_bboxes, dec_scores),
466
- targets,
467
- dn_bboxes=dn_bboxes,
468
- dn_scores=dn_scores,
469
- dn_meta=dn_meta)
507
+ loss = self.criterion(
508
+ (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
509
+ )
470
510
  # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
471
- return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
472
- device=img.device)
511
+ return sum(loss.values()), torch.as_tensor(
512
+ [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
513
+ )
473
514
 
474
515
  def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
475
516
  """
@@ -552,6 +593,7 @@ def temporary_modules(modules=None):
552
593
 
553
594
  import importlib
554
595
  import sys
596
+
555
597
  try:
556
598
  # Set modules in sys.modules under their old name
557
599
  for old, new in modules.items():
@@ -579,30 +621,38 @@ def torch_safe_load(weight):
579
621
  """
580
622
  from ultralytics.utils.downloads import attempt_download_asset
581
623
 
582
- check_suffix(file=weight, suffix='.pt')
624
+ check_suffix(file=weight, suffix=".pt")
583
625
  file = attempt_download_asset(weight) # search online if missing locally
584
626
  try:
585
- with temporary_modules({
586
- 'ultralytics.yolo.utils': 'ultralytics.utils',
587
- 'ultralytics.yolo.v8': 'ultralytics.models.yolo',
588
- 'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
589
- return torch.load(file, map_location='cpu'), file # load
627
+ with temporary_modules(
628
+ {
629
+ "ultralytics.yolo.utils": "ultralytics.utils",
630
+ "ultralytics.yolo.v8": "ultralytics.models.yolo",
631
+ "ultralytics.yolo.data": "ultralytics.data",
632
+ }
633
+ ): # for legacy 8.0 Classify and Pose models
634
+ return torch.load(file, map_location="cpu"), file # load
590
635
 
591
636
  except ModuleNotFoundError as e: # e.name is missing module name
592
- if e.name == 'models':
637
+ if e.name == "models":
593
638
  raise TypeError(
594
- emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
595
- f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
596
- f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
597
- f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
598
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
599
- LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
600
- f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
601
- f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
602
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
639
+ emojis(
640
+ f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
641
+ f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
642
+ f"YOLOv8 at https://github.com/ultralytics/ultralytics."
643
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
644
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
645
+ )
646
+ ) from e
647
+ LOGGER.warning(
648
+ f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
649
+ f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
650
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
651
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
652
+ )
603
653
  check_requirements(e.name) # install missing module
604
654
 
605
- return torch.load(file, map_location='cpu'), file # load
655
+ return torch.load(file, map_location="cpu"), file # load
606
656
 
607
657
 
608
658
  def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
@@ -611,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
611
661
  ensemble = Ensemble()
612
662
  for w in weights if isinstance(weights, list) else [weights]:
613
663
  ckpt, w = torch_safe_load(w) # load ckpt
614
- args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
615
- model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
664
+ args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
665
+ model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
616
666
 
617
667
  # Model compatibility updates
618
668
  model.args = args # attach args to model
619
669
  model.pt_path = w # attach *.pt file path to model
620
670
  model.task = guess_model_task(model)
621
- if not hasattr(model, 'stride'):
622
- model.stride = torch.tensor([32.])
671
+ if not hasattr(model, "stride"):
672
+ model.stride = torch.tensor([32.0])
623
673
 
624
674
  # Append
625
- ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
675
+ ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
626
676
 
627
677
  # Module updates
628
678
  for m in ensemble.modules():
629
679
  t = type(m)
630
680
  if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
631
681
  m.inplace = inplace
632
- elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
682
+ elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
633
683
  m.recompute_scale_factor = None # torch 1.11.0 compatibility
634
684
 
635
685
  # Return model
@@ -637,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
637
687
  return ensemble[-1]
638
688
 
639
689
  # Return ensemble
640
- LOGGER.info(f'Ensemble created with {weights}\n')
641
- for k in 'names', 'nc', 'yaml':
690
+ LOGGER.info(f"Ensemble created with {weights}\n")
691
+ for k in "names", "nc", "yaml":
642
692
  setattr(ensemble, k, getattr(ensemble[0], k))
643
693
  ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
644
- assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
694
+ assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
645
695
  return ensemble
646
696
 
647
697
 
648
698
  def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
649
699
  """Loads a single model weights."""
650
700
  ckpt, weight = torch_safe_load(weight) # load ckpt
651
- args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
652
- model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
701
+ args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
702
+ model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
653
703
 
654
704
  # Model compatibility updates
655
705
  model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
656
706
  model.pt_path = weight # attach *.pt file path to model
657
707
  model.task = guess_model_task(model)
658
- if not hasattr(model, 'stride'):
659
- model.stride = torch.tensor([32.])
708
+ if not hasattr(model, "stride"):
709
+ model.stride = torch.tensor([32.0])
660
710
 
661
- model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
711
+ model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
662
712
 
663
713
  # Module updates
664
714
  for m in model.modules():
665
715
  t = type(m)
666
716
  if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
667
717
  m.inplace = inplace
668
- elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
718
+ elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
669
719
  m.recompute_scale_factor = None # torch 1.11.0 compatibility
670
720
 
671
721
  # Return model and ckpt
@@ -677,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
677
727
  import ast
678
728
 
679
729
  # Args
680
- max_channels = float('inf')
681
- nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
682
- depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
730
+ max_channels = float("inf")
731
+ nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
732
+ depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
683
733
  if scales:
684
- scale = d.get('scale')
734
+ scale = d.get("scale")
685
735
  if not scale:
686
736
  scale = tuple(scales.keys())[0]
687
737
  LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
@@ -696,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
696
746
  LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
697
747
  ch = [ch]
698
748
  layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
699
- for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
700
- m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
749
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
750
+ m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
701
751
  for j, a in enumerate(args):
702
752
  if isinstance(a, str):
703
753
  with contextlib.suppress(ValueError):
704
754
  args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
705
755
 
706
756
  n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
707
- if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
708
- BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
757
+ if m in (
758
+ Classify,
759
+ Conv,
760
+ ConvTranspose,
761
+ GhostConv,
762
+ Bottleneck,
763
+ GhostBottleneck,
764
+ SPP,
765
+ SPPF,
766
+ DWConv,
767
+ Focus,
768
+ BottleneckCSP,
769
+ C1,
770
+ C2,
771
+ C2f,
772
+ C3,
773
+ C3TR,
774
+ C3Ghost,
775
+ nn.ConvTranspose2d,
776
+ DWConvTranspose2d,
777
+ C3x,
778
+ RepC3,
779
+ ):
709
780
  c1, c2 = ch[f], args[0]
710
781
  if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
711
782
  c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -738,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
738
809
  c2 = ch[f]
739
810
 
740
811
  m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
741
- t = str(m)[8:-2].replace('__main__.', '') # module type
812
+ t = str(m)[8:-2].replace("__main__.", "") # module type
742
813
  m.np = sum(x.numel() for x in m_.parameters()) # number params
743
814
  m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
744
815
  if verbose:
745
- LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
816
+ LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
746
817
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
747
818
  layers.append(m_)
748
819
  if i == 0:
@@ -756,16 +827,16 @@ def yaml_model_load(path):
756
827
  import re
757
828
 
758
829
  path = Path(path)
759
- if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
760
- new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
761
- LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
830
+ if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
831
+ new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
832
+ LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
762
833
  path = path.with_name(new_stem + path.suffix)
763
834
 
764
- unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
835
+ unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
765
836
  yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
766
837
  d = yaml_load(yaml_file) # model dict
767
- d['scale'] = guess_model_scale(path)
768
- d['yaml_file'] = str(path)
838
+ d["scale"] = guess_model_scale(path)
839
+ d["yaml_file"] = str(path)
769
840
  return d
770
841
 
771
842
 
@@ -783,8 +854,9 @@ def guess_model_scale(model_path):
783
854
  """
784
855
  with contextlib.suppress(AttributeError):
785
856
  import re
786
- return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
787
- return ''
857
+
858
+ return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
859
+ return ""
788
860
 
789
861
 
790
862
  def guess_model_task(model):
@@ -803,17 +875,17 @@ def guess_model_task(model):
803
875
 
804
876
  def cfg2task(cfg):
805
877
  """Guess from YAML dictionary."""
806
- m = cfg['head'][-1][-2].lower() # output module name
807
- if m in ('classify', 'classifier', 'cls', 'fc'):
808
- return 'classify'
809
- if m == 'detect':
810
- return 'detect'
811
- if m == 'segment':
812
- return 'segment'
813
- if m == 'pose':
814
- return 'pose'
815
- if m == 'obb':
816
- return 'obb'
878
+ m = cfg["head"][-1][-2].lower() # output module name
879
+ if m in ("classify", "classifier", "cls", "fc"):
880
+ return "classify"
881
+ if m == "detect":
882
+ return "detect"
883
+ if m == "segment":
884
+ return "segment"
885
+ if m == "pose":
886
+ return "pose"
887
+ if m == "obb":
888
+ return "obb"
817
889
 
818
890
  # Guess from model cfg
819
891
  if isinstance(model, dict):
@@ -822,40 +894,42 @@ def guess_model_task(model):
822
894
 
823
895
  # Guess from PyTorch model
824
896
  if isinstance(model, nn.Module): # PyTorch model
825
- for x in 'model.args', 'model.model.args', 'model.model.model.args':
897
+ for x in "model.args", "model.model.args", "model.model.model.args":
826
898
  with contextlib.suppress(Exception):
827
- return eval(x)['task']
828
- for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
899
+ return eval(x)["task"]
900
+ for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
829
901
  with contextlib.suppress(Exception):
830
902
  return cfg2task(eval(x))
831
903
 
832
904
  for m in model.modules():
833
905
  if isinstance(m, Detect):
834
- return 'detect'
906
+ return "detect"
835
907
  elif isinstance(m, Segment):
836
- return 'segment'
908
+ return "segment"
837
909
  elif isinstance(m, Classify):
838
- return 'classify'
910
+ return "classify"
839
911
  elif isinstance(m, Pose):
840
- return 'pose'
912
+ return "pose"
841
913
  elif isinstance(m, OBB):
842
- return 'obb'
914
+ return "obb"
843
915
 
844
916
  # Guess from model filename
845
917
  if isinstance(model, (str, Path)):
846
918
  model = Path(model)
847
- if '-seg' in model.stem or 'segment' in model.parts:
848
- return 'segment'
849
- elif '-cls' in model.stem or 'classify' in model.parts:
850
- return 'classify'
851
- elif '-pose' in model.stem or 'pose' in model.parts:
852
- return 'pose'
853
- elif '-obb' in model.stem or 'obb' in model.parts:
854
- return 'obb'
855
- elif 'detect' in model.parts:
856
- return 'detect'
919
+ if "-seg" in model.stem or "segment" in model.parts:
920
+ return "segment"
921
+ elif "-cls" in model.stem or "classify" in model.parts:
922
+ return "classify"
923
+ elif "-pose" in model.stem or "pose" in model.parts:
924
+ return "pose"
925
+ elif "-obb" in model.stem or "obb" in model.parts:
926
+ return "obb"
927
+ elif "detect" in model.parts:
928
+ return "detect"
857
929
 
858
930
  # Unable to determine task from model
859
- LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
860
- "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
861
- return 'detect' # assume detect
931
+ LOGGER.warning(
932
+ "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
933
+ "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
934
+ )
935
+ return "detect" # assume detect