ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This release of ultralytics has been flagged as potentially problematic.
Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
@@ -23,14 +23,31 @@ from torch import nn, optim
  from ultralytics.cfg import get_cfg, get_save_dir
  from ultralytics.data.utils import check_cls_dataset, check_det_dataset
  from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
- from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
-                                yaml_save)
+ from ultralytics.utils import (
+     DEFAULT_CFG,
+     LOGGER,
+     RANK,
+     TQDM,
+     __version__,
+     callbacks,
+     clean_url,
+     colorstr,
+     emojis,
+     yaml_save,
+ )
  from ultralytics.utils.autobatch import check_train_batch_size
  from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
  from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
  from ultralytics.utils.files import get_latest_run
- from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
-                                            strip_optimizer)
+ from ultralytics.utils.torch_utils import (
+     EarlyStopping,
+     ModelEMA,
+     de_parallel,
+     init_seeds,
+     one_cycle,
+     select_device,
+     strip_optimizer,
+ )
 
 
  class BaseTrainer:
@@ -89,12 +106,12 @@ class BaseTrainer:
          # Dirs
          self.save_dir = get_save_dir(self.args)
          self.args.name = self.save_dir.name  # update name for loggers
-         self.wdir = self.save_dir / 'weights'  # weights dir
+         self.wdir = self.save_dir / "weights"  # weights dir
          if RANK in (-1, 0):
              self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
              self.args.save_dir = str(self.save_dir)
-             yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
-         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
+             yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
          self.save_period = self.args.save_period
 
          self.batch_size = self.args.batch
@@ -104,18 +121,18 @@ class BaseTrainer:
          print_args(vars(self.args))
 
          # Device
-         if self.device.type in ('cpu', 'mps'):
+         if self.device.type in ("cpu", "mps"):
              self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
 
          # Model and Dataset
          self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
          try:
-             if self.args.task == 'classify':
+             if self.args.task == "classify":
                  self.data = check_cls_dataset(self.args.data)
-             elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
+             elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
                  self.data = check_det_dataset(self.args.data)
-                 if 'yaml_file' in self.data:
-                     self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
+                 if "yaml_file" in self.data:
+                     self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
          except Exception as e:
              raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
 
@@ -131,8 +148,8 @@ class BaseTrainer:
          self.fitness = None
          self.loss = None
          self.tloss = None
-         self.loss_names = ['Loss']
-         self.csv = self.save_dir / 'results.csv'
+         self.loss_names = ["Loss"]
+         self.csv = self.save_dir / "results.csv"
          self.plot_idx = [0, 1, 2]
 
          # Callbacks
@@ -156,7 +173,7 @@ class BaseTrainer:
      def train(self):
          """Allow device='', device=None on Multi-GPU systems to default to device=0."""
          if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-             world_size = len(self.args.device.split(','))
+             world_size = len(self.args.device.split(","))
          elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
              world_size = len(self.args.device)
          elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
@@ -165,14 +182,16 @@ class BaseTrainer:
              world_size = 0
 
          # Run subprocess if DDP training, else train normally
-         if world_size > 1 and 'LOCAL_RANK' not in os.environ:
+         if world_size > 1 and "LOCAL_RANK" not in os.environ:
              # Argument checks
              if self.args.rect:
                  LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                  self.args.rect = False
              if self.args.batch == -1:
-                 LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
-                                "default 'batch=16'")
+                 LOGGER.warning(
+                     "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                     "default 'batch=16'"
+                 )
                  self.args.batch = 16
 
          # Command
@@ -199,37 +218,45 @@ class BaseTrainer:
      def _setup_ddp(self, world_size):
          """Initializes and sets the DistributedDataParallel parameters for training."""
          torch.cuda.set_device(RANK)
-         self.device = torch.device('cuda', RANK)
+         self.device = torch.device("cuda", RANK)
          # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-         os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
+         os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
          dist.init_process_group(
-             'nccl' if dist.is_nccl_available() else 'gloo',
+             "nccl" if dist.is_nccl_available() else "gloo",
              timeout=timedelta(seconds=10800),  # 3 hours
              rank=RANK,
-             world_size=world_size)
+             world_size=world_size,
+         )
 
      def _setup_train(self, world_size):
          """Builds dataloaders and optimizer on correct rank process."""
 
          # Model
-         self.run_callbacks('on_pretrain_routine_start')
+         self.run_callbacks("on_pretrain_routine_start")
          ckpt = self.setup_model()
          self.model = self.model.to(self.device)
          self.set_model_attributes()
 
          # Freeze layers
-         freeze_list = self.args.freeze if isinstance(
-             self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
-         always_freeze_names = ['.dfl']  # always freeze these layers
-         freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
+         freeze_list = (
+             self.args.freeze
+             if isinstance(self.args.freeze, list)
+             else range(self.args.freeze)
+             if isinstance(self.args.freeze, int)
+             else []
+         )
+         always_freeze_names = [".dfl"]  # always freeze these layers
+         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
          for k, v in self.model.named_parameters():
              # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
              if any(x in k for x in freeze_layer_names):
                  LOGGER.info(f"Freezing layer '{k}'")
                  v.requires_grad = False
              elif not v.requires_grad:
-                 LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
-                             'See ultralytics.engine.trainer for customization of frozen layers.')
+                 LOGGER.info(
+                     f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                     "See ultralytics.engine.trainer for customization of frozen layers."
+                 )
                  v.requires_grad = True
 
          # Check AMP
@@ -246,7 +273,7 @@ class BaseTrainer:
              self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
 
          # Check imgsz
-         gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
+         gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
          self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
          self.stride = gs  # for multi-scale training
 
@@ -256,15 +283,14 @@ class BaseTrainer:
 
          # Dataloaders
          batch_size = self.batch_size // max(world_size, 1)
-         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
+         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
          if RANK in (-1, 0):
              # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects.
-             self.test_loader = self.get_dataloader(self.testset,
-                                                    batch_size=batch_size if self.args.task == 'obb' else batch_size * 2,
-                                                    rank=-1,
-                                                    mode='val')
+             self.test_loader = self.get_dataloader(
+                 self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+             )
              self.validator = self.get_validator()
-             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
+             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
              self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
              self.ema = ModelEMA(self.model)
              if self.args.plots:
@@ -274,18 +300,20 @@ class BaseTrainer:
          self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
          weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
          iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
-         self.optimizer = self.build_optimizer(model=self.model,
-                                               name=self.args.optimizer,
-                                               lr=self.args.lr0,
-                                               momentum=self.args.momentum,
-                                               decay=weight_decay,
-                                               iterations=iterations)
+         self.optimizer = self.build_optimizer(
+             model=self.model,
+             name=self.args.optimizer,
+             lr=self.args.lr0,
+             momentum=self.args.momentum,
+             decay=weight_decay,
+             iterations=iterations,
+         )
          # Scheduler
          self._setup_scheduler()
          self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
          self.resume_training(ckpt)
          self.scheduler.last_epoch = self.start_epoch - 1  # do not move
-         self.run_callbacks('on_pretrain_routine_end')
+         self.run_callbacks("on_pretrain_routine_end")
 
      def _do_train(self, world_size=1):
          """Train completed, evaluate and plot if specified by arguments."""
@@ -299,19 +327,23 @@ class BaseTrainer:
          self.epoch_time = None
          self.epoch_time_start = time.time()
          self.train_time_start = time.time()
-         self.run_callbacks('on_train_start')
-         LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
-                     f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
-                     f"Logging results to {colorstr('bold', self.save_dir)}\n"
-                     f'Starting training for '
-                     f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
+         self.run_callbacks("on_train_start")
+         LOGGER.info(
+             f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
+             f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+             f"Logging results to {colorstr('bold', self.save_dir)}\n"
+             f'Starting training for '
+             f'{self.args.time} hours...'
+             if self.args.time
+             else f"{self.epochs} epochs..."
+         )
          if self.args.close_mosaic:
              base_idx = (self.epochs - self.args.close_mosaic) * nb
              self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
          epoch = self.epochs  # predefine for resume fully trained model edge cases
          for epoch in range(self.start_epoch, self.epochs):
              self.epoch = epoch
-             self.run_callbacks('on_train_epoch_start')
+             self.run_callbacks("on_train_epoch_start")
              self.model.train()
              if RANK != -1:
                  self.train_loader.sampler.set_epoch(epoch)
@@ -327,7 +359,7 @@ class BaseTrainer:
              self.tloss = None
              self.optimizer.zero_grad()
              for i, batch in pbar:
-                 self.run_callbacks('on_train_batch_start')
+                 self.run_callbacks("on_train_batch_start")
                  # Warmup
                  ni = i + nb * epoch
                  if ni <= nw:
@@ -335,10 +367,11 @@ class BaseTrainer:
                      self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                      for j, x in enumerate(self.optimizer.param_groups):
                          # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                         x['lr'] = np.interp(
-                             ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
-                         if 'momentum' in x:
-                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+                         x["lr"] = np.interp(
+                             ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                         )
+                         if "momentum" in x:
+                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
                  # Forward
                  with torch.cuda.amp.autocast(self.amp):
@@ -346,8 +379,9 @@ class BaseTrainer:
                      self.loss, self.loss_items = self.model(batch)
                      if RANK != -1:
                          self.loss *= world_size
-                     self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
-                         else self.loss_items
+                     self.tloss = (
+                         (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                     )
 
                  # Backward
                  self.scaler.scale(self.loss).backward()
368
402
  break
369
403
 
370
404
  # Log
371
- mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
405
+ mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
372
406
  loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
373
407
  losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
374
408
  if RANK in (-1, 0):
375
409
  pbar.set_description(
376
- ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
377
- (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
378
- self.run_callbacks('on_batch_end')
410
+ ("%11s" * 2 + "%11.4g" * (2 + loss_len))
411
+ % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
412
+ )
413
+ self.run_callbacks("on_batch_end")
379
414
  if self.args.plots and ni in self.plot_idx:
380
415
  self.plot_training_samples(batch, ni)
381
416
 
382
- self.run_callbacks('on_train_batch_end')
417
+ self.run_callbacks("on_train_batch_end")
383
418
 
384
- self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
385
- self.run_callbacks('on_train_epoch_end')
419
+ self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
420
+ self.run_callbacks("on_train_epoch_end")
386
421
  if RANK in (-1, 0):
387
422
  final_epoch = epoch + 1 == self.epochs
388
- self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
423
+ self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
389
424
 
390
425
  # Validation
391
426
  if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
@@ -398,14 +433,14 @@ class BaseTrainer:
398
433
  # Save model
399
434
  if self.args.save or final_epoch:
400
435
  self.save_model()
401
- self.run_callbacks('on_model_save')
436
+ self.run_callbacks("on_model_save")
402
437
 
403
438
  # Scheduler
404
439
  t = time.time()
405
440
  self.epoch_time = t - self.epoch_time_start
406
441
  self.epoch_time_start = t
407
442
  with warnings.catch_warnings():
408
- warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
443
+ warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
409
444
  if self.args.time:
410
445
  mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
411
446
  self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
@@ -413,7 +448,7 @@ class BaseTrainer:
                      self.scheduler.last_epoch = self.epoch  # do not move
                      self.stop |= epoch >= self.epochs  # stop if exceeded epochs
                  self.scheduler.step()
-             self.run_callbacks('on_fit_epoch_end')
+             self.run_callbacks("on_fit_epoch_end")
              torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
 
              # Early Stopping
@@ -426,39 +461,43 @@ class BaseTrainer:
 
          if RANK in (-1, 0):
              # Do final val with best.pt
-             LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
-                         f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
+             LOGGER.info(
+                 f"\n{epoch - self.start_epoch + 1} epochs completed in "
+                 f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
+             )
              self.final_eval()
              if self.args.plots:
                  self.plot_metrics()
-             self.run_callbacks('on_train_end')
+             self.run_callbacks("on_train_end")
          torch.cuda.empty_cache()
-         self.run_callbacks('teardown')
+         self.run_callbacks("teardown")
 
      def save_model(self):
          """Save model training checkpoints with additional metadata."""
          import pandas as pd  # scope for faster startup
-         metrics = {**self.metrics, **{'fitness': self.fitness}}
-         results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()}
+
+         metrics = {**self.metrics, **{"fitness": self.fitness}}
+         results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
          ckpt = {
-             'epoch': self.epoch,
-             'best_fitness': self.best_fitness,
-             'model': deepcopy(de_parallel(self.model)).half(),
-             'ema': deepcopy(self.ema.ema).half(),
-             'updates': self.ema.updates,
-             'optimizer': self.optimizer.state_dict(),
-             'train_args': vars(self.args),  # save as dict
-             'train_metrics': metrics,
-             'train_results': results,
-             'date': datetime.now().isoformat(),
-             'version': __version__}
+             "epoch": self.epoch,
+             "best_fitness": self.best_fitness,
+             "model": deepcopy(de_parallel(self.model)).half(),
+             "ema": deepcopy(self.ema.ema).half(),
+             "updates": self.ema.updates,
+             "optimizer": self.optimizer.state_dict(),
+             "train_args": vars(self.args),  # save as dict
+             "train_metrics": metrics,
+             "train_results": results,
+             "date": datetime.now().isoformat(),
+             "version": __version__,
+         }
 
          # Save last and best
          torch.save(ckpt, self.last)
          if self.best_fitness == self.fitness:
              torch.save(ckpt, self.best)
          if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-             torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
+             torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
 
      @staticmethod
      def get_dataset(data):
467
506
 
468
507
  Returns None if data format is not recognized.
469
508
  """
470
- return data['train'], data.get('val') or data.get('test')
509
+ return data["train"], data.get("val") or data.get("test")
471
510
 
472
511
  def setup_model(self):
473
512
  """Load/create/download model for any task."""
@@ -476,9 +515,9 @@ class BaseTrainer:
 
          model, weights = self.model, None
          ckpt = None
-         if str(model).endswith('.pt'):
+         if str(model).endswith(".pt"):
              weights, ckpt = attempt_load_one_weight(model)
-             cfg = ckpt['model'].yaml
+             cfg = ckpt["model"].yaml
          else:
              cfg = model
          self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -505,7 +544,7 @@ class BaseTrainer:
          The returned dict is expected to contain "fitness" key.
          """
          metrics = self.validator(self)
-         fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
          if not self.best_fitness or self.best_fitness < fitness:
              self.best_fitness = fitness
          return metrics, fitness
@@ -516,24 +555,24 @@ class BaseTrainer:
 
      def get_validator(self):
          """Returns a NotImplementedError when the get_validator function is called."""
-         raise NotImplementedError('get_validator function not implemented in trainer')
+         raise NotImplementedError("get_validator function not implemented in trainer")
 
-     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
+     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
          """Returns dataloader derived from torch.data.Dataloader."""
-         raise NotImplementedError('get_dataloader function not implemented in trainer')
+         raise NotImplementedError("get_dataloader function not implemented in trainer")
 
-     def build_dataset(self, img_path, mode='train', batch=None):
+     def build_dataset(self, img_path, mode="train", batch=None):
          """Build dataset."""
-         raise NotImplementedError('build_dataset function not implemented in trainer')
+         raise NotImplementedError("build_dataset function not implemented in trainer")
 
-     def label_loss_items(self, loss_items=None, prefix='train'):
+     def label_loss_items(self, loss_items=None, prefix="train"):
          """Returns a loss dict with labelled training loss items tensor."""
          # Not needed for classification but necessary for segmentation & detection
-         return {'loss': loss_items} if loss_items is not None else ['loss']
+         return {"loss": loss_items} if loss_items is not None else ["loss"]
 
      def set_model_attributes(self):
          """To set or update model parameters before training."""
-         self.model.names = self.data['names']
+         self.model.names = self.data["names"]
 
      def build_targets(self, preds, targets):
          """Builds target tensors for training YOLO model."""
@@ -541,7 +580,7 @@ class BaseTrainer:
 
      def progress_string(self):
          """Returns a string describing training progress."""
-         return ''
+         return ""
 
      # TODO: may need to put these following functions into callback
      def plot_training_samples(self, batch, ni):
@@ -556,9 +595,9 @@ class BaseTrainer:
          """Saves training metrics to a CSV file."""
          keys, vals = list(metrics.keys()), list(metrics.values())
          n = len(metrics) + 1  # number of cols
-         s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
-         with open(self.csv, 'a') as f:
-             f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
+         s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+         with open(self.csv, "a") as f:
+             f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
 
      def plot_metrics(self):
          """Plot and display metrics visually."""
@@ -567,7 +606,7 @@ class BaseTrainer:
      def on_plot(self, name, data=None):
          """Registers plots (e.g. to be consumed in callbacks)"""
          path = Path(name)
-         self.plots[path] = {'data': data, 'timestamp': time.time()}
+         self.plots[path] = {"data": data, "timestamp": time.time()}
 
      def final_eval(self):
          """Performs final evaluation and validation for object detection YOLO model."""
@@ -575,11 +614,11 @@ class BaseTrainer:
              if f.exists():
                  strip_optimizer(f)  # strip optimizers
                  if f is self.best:
-                     LOGGER.info(f'\nValidating {f}...')
+                     LOGGER.info(f"\nValidating {f}...")
                      self.validator.args.plots = self.args.plots
                      self.metrics = self.validator(model=f)
-                     self.metrics.pop('fitness', None)
-                     self.run_callbacks('on_fit_epoch_end')
+                     self.metrics.pop("fitness", None)
+                     self.run_callbacks("on_fit_epoch_end")
 
      def check_resume(self, overrides):
          """Check if resume checkpoint exists and update arguments accordingly."""
@@ -591,19 +630,21 @@ class BaseTrainer:
 
                  # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                  ckpt_args = attempt_load_weights(last).args
-                 if not Path(ckpt_args['data']).exists():
-                     ckpt_args['data'] = self.args.data
+                 if not Path(ckpt_args["data"]).exists():
+                     ckpt_args["data"] = self.args.data
 
                  resume = True
                  self.args = get_cfg(ckpt_args)
                  self.args.model = str(last)  # reinstate model
-                 for k in 'imgsz', 'batch':  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+                 for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                      if k in overrides:
                          setattr(self.args, k, overrides[k])
 
              except Exception as e:
-                 raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
-                                         "i.e. 'yolo train resume model=path/to/last.pt'") from e
+                 raise FileNotFoundError(
+                     "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                     "i.e. 'yolo train resume model=path/to/last.pt'"
+                 ) from e
          self.resume = resume
 
      def resume_training(self, ckpt):
  def resume_training(self, ckpt):
@@ -611,23 +652,26 @@ class BaseTrainer:
611
652
  if ckpt is None:
612
653
  return
613
654
  best_fitness = 0.0
614
- start_epoch = ckpt['epoch'] + 1
615
- if ckpt['optimizer'] is not None:
616
- self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
617
- best_fitness = ckpt['best_fitness']
618
- if self.ema and ckpt.get('ema'):
619
- self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
620
- self.ema.updates = ckpt['updates']
655
+ start_epoch = ckpt["epoch"] + 1
656
+ if ckpt["optimizer"] is not None:
657
+ self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
658
+ best_fitness = ckpt["best_fitness"]
659
+ if self.ema and ckpt.get("ema"):
660
+ self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
661
+ self.ema.updates = ckpt["updates"]
621
662
  if self.resume:
622
- assert start_epoch > 0, \
623
- f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
663
+ assert start_epoch > 0, (
664
+ f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
624
665
  f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
666
+ )
625
667
  LOGGER.info(
626
- f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
668
+ f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
669
+ )
627
670
  if self.epochs < start_epoch:
628
671
  LOGGER.info(
629
- f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
630
- self.epochs += ckpt['epoch'] # finetune additional epochs
672
+ f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
673
+ )
674
+ self.epochs += ckpt["epoch"] # finetune additional epochs
631
675
  self.best_fitness = best_fitness
632
676
  self.start_epoch = start_epoch
633
677
  if start_epoch > (self.epochs - self.args.close_mosaic):
@@ -635,13 +679,13 @@ class BaseTrainer:
 
      def _close_dataloader_mosaic(self):
          """Update dataloaders to stop using mosaic augmentation."""
-         if hasattr(self.train_loader.dataset, 'mosaic'):
+         if hasattr(self.train_loader.dataset, "mosaic"):
              self.train_loader.dataset.mosaic = False
-         if hasattr(self.train_loader.dataset, 'close_mosaic'):
-             LOGGER.info('Closing dataloader mosaic')
+         if hasattr(self.train_loader.dataset, "close_mosaic"):
+             LOGGER.info("Closing dataloader mosaic")
              self.train_loader.dataset.close_mosaic(hyp=self.args)
 
-     def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
          """
          Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
          weight decay, and number of iterations.
@@ -661,41 +705,45 @@ class BaseTrainer:
          """
 
          g = [], [], []  # optimizer parameter groups
-         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-         if name == 'auto':
-             LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
-                         f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
-                         f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
-             nc = getattr(model, 'nc', 10)  # number of classes
+         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+         if name == "auto":
+             LOGGER.info(
+                 f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                 f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                 f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+             )
+             nc = getattr(model, "nc", 10)  # number of classes
              lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
-             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
+             name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
              self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
 
          for module_name, module in model.named_modules():
              for param_name, param in module.named_parameters(recurse=False):
-                 fullname = f'{module_name}.{param_name}' if module_name else param_name
-                 if 'bias' in fullname:  # bias (no decay)
+                 fullname = f"{module_name}.{param_name}" if module_name else param_name
+                 if "bias" in fullname:  # bias (no decay)
                      g[2].append(param)
                  elif isinstance(module, bn):  # weight (no decay)
                      g[1].append(param)
                  else:  # weight (with decay)
                      g[0].append(param)
 
-         if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+         if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
              optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
-         elif name == 'RMSProp':
+         elif name == "RMSProp":
              optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
-         elif name == 'SGD':
+         elif name == "SGD":
              optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
          else:
              raise NotImplementedError(
                  f"Optimizer '{name}' not found in list of available optimizers "
-                 f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
-                 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
+                 f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+                 "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+             )
 
-         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
-         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+         optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+         optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
          LOGGER.info(
              f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
+             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+         )
          return optimizer
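
For readers skimming the build_optimizer() hunk above, the "auto" branch can be condensed into a standalone sketch. This is a hedged illustration, not the ultralytics API: pick_optimizer and its defaults are hypothetical names, and the real method additionally splits parameters into bias, norm-weight, and decayed-weight groups as the diff shows.

# Hypothetical sketch of the "auto" optimizer selection seen in build_optimizer().
# Illustration only; the real method also builds separate no-decay parameter groups.
from torch import nn, optim


def pick_optimizer(model: nn.Module, nc: int = 80, iterations: int = 10000):
    """Mirror the 'auto' branch: SGD for long runs, AdamW with a fitted lr0 otherwise."""
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation from the diff
    name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
    if name == "SGD":
        return optim.SGD(model.parameters(), lr=lr, momentum=momentum, nesterov=True)
    return optim.AdamW(model.parameters(), lr=lr, betas=(momentum, 0.999), weight_decay=0.0)


if __name__ == "__main__":
    m = nn.Linear(10, 2)
    print(pick_optimizer(m, nc=80, iterations=5000))   # short run -> AdamW, lr0 = 0.000119
    print(pick_optimizer(m, nc=80, iterations=20000))  # long run -> SGD, lr0 = 0.01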