ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py

@@ -1,17 +1,18 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 Train a model on a dataset.

 Usage:
-    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
+    $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """

+import gc
 import math
 import os
 import subprocess
 import time
 import warnings
-from copy import deepcopy
+from copy import copy, deepcopy
 from datetime import datetime, timedelta
 from pathlib import Path

@@ -25,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.utils import (
     DEFAULT_CFG,
+    LOCAL_RANK,
     LOGGER,
     RANK,
     TQDM,
@@ -40,20 +42,21 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
+    TORCH_2_4,
     EarlyStopping,
     ModelEMA,
-    de_parallel,
+    autocast,
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
     strip_optimizer,
+    torch_distributed_zero_first,
 )


 class BaseTrainer:
     """
-    BaseTrainer.
-
     A base class for creating trainers.

     Attributes:
@@ -107,7 +110,7 @@ class BaseTrainer:
         self.save_dir = get_save_dir(self.args)
         self.args.name = self.save_dir.name  # update name for loggers
         self.wdir = self.save_dir / "weights"  # weights dir
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
             yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
@@ -115,33 +118,19 @@
         self.save_period = self.args.save_period

         self.batch_size = self.args.batch
-        self.epochs = self.args.epochs
+        self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
         self.start_epoch = 0
         if RANK == -1:
             print_args(vars(self.args))

         # Device
-        if self.device.type in ("cpu", "mps"):
+        if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        try:
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
+            self.trainset, self.testset = self.get_dataset()
         self.ema = None

         # Optimization utils init
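Note on the `torch_distributed_zero_first(LOCAL_RANK)` change above: dataset checking now runs under a rank-zero-first guard, so in DDP training only local rank 0 downloads or prepares the dataset while the other ranks wait at a barrier. A minimal sketch of that pattern, assuming torch.distributed is initialized; `rank_zero_first` is a hypothetical stand-in, not the actual Ultralytics helper:

    from contextlib import contextmanager

    import torch.distributed as dist

    @contextmanager
    def rank_zero_first(local_rank: int):
        """Hold non-zero ranks at a barrier until local rank 0 finishes the guarded block."""
        initialized = dist.is_available() and dist.is_initialized()
        if initialized and local_rank not in {-1, 0}:
            dist.barrier()  # non-zero ranks wait while rank 0 downloads the dataset
        yield
        if initialized and local_rank == 0:
            dist.barrier()  # rank 0 releases the waiting ranks once done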
@@ -157,9 +146,12 @@
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]

+        # HUB
+        self.hub_session = None
+
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             callbacks.add_integration_callbacks(self)

     def add_callback(self, event: str, callback):
@@ -181,9 +173,11 @@
             world_size = len(self.args.device.split(","))
         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
             world_size = 1  # default to device 0
-        else:  # i.e. device='cpu' or 'mps'
+        else:  # i.e. device=None or device=''
             world_size = 0

         # Run subprocess if DDP training, else train normally
@@ -192,9 +186,9 @@
             if self.args.rect:
                 LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
-            if self.args.batch == -1:
+            if self.args.batch < 1.0:
                 LOGGER.warning(
-                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
                     "default 'batch=16'"
                 )
                 self.args.batch = 16
@@ -202,7 +196,7 @@
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
-                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
+                LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
                 subprocess.run(cmd, check=True)
             except Exception as e:
                 raise e
@@ -225,9 +219,9 @@
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            "nccl" if dist.is_nccl_available() else "gloo",
+            backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
             world_size=world_size,
@@ -235,7 +229,6 @@

     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""
-
         # Model
         self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
@@ -266,16 +259,18 @@

         # Check AMP
         self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
-        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
+        if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
         if world_size > 1:
-            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
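The scaler construction above is version-gated because `torch.cuda.amp.GradScaler` is deprecated in newer PyTorch releases in favor of `torch.amp.GradScaler("cuda", ...)`. A hedged sketch of the same gate, with the version test inlined via the `packaging` library instead of importing the `TORCH_2_4` constant from `ultralytics.utils.torch_utils`:

    import torch
    from packaging.version import parse

    TORCH_2_4 = parse(torch.__version__) >= parse("2.4.0")  # assumption: stand-in for the imported constant
    amp_enabled = True  # assumption: AMP requested and supported on this device
    scaler = (
        torch.amp.GradScaler("cuda", enabled=amp_enabled)
        if TORCH_2_4
        else torch.cuda.amp.GradScaler(enabled=amp_enabled)
    )

Only the branch matching the installed version is evaluated, so the deprecated constructor is never touched on PyTorch 2.4+.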
@@ -283,13 +278,13 @@
         self.stride = gs  # for multiscale training

         # Batch size
-        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
-            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
+        if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = self.auto_batch()

         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
-        if RANK in (-1, 0):
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
+        if RANK in {-1, 0}:
             # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
             self.test_loader = self.get_dataloader(
                 self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
@@ -334,18 +329,23 @@
         self.train_time_start = time.time()
         self.run_callbacks("on_train_start")
         LOGGER.info(
-            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
-            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+            f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
             f"Logging results to {colorstr('bold', self.save_dir)}\n"
-            f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+            f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
         )
         if self.args.close_mosaic:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         epoch = self.start_epoch
+        self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
         while True:
             self.epoch = epoch
             self.run_callbacks("on_train_epoch_start")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
+
             self.model.train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
@@ -355,11 +355,10 @@
                 self._close_dataloader_mosaic()
                 self.train_loader.reset()

-            if RANK in (-1, 0):
+            if RANK in {-1, 0}:
                 LOGGER.info(self.progress_string())
                 pbar = TQDM(enumerate(self.train_loader), total=nb)
             self.tloss = None
-            self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.run_callbacks("on_train_batch_start")
                 # Warmup
@@ -376,7 +375,7 @@
                     x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                 # Forward
-                with torch.cuda.amp.autocast(self.amp):
+                with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     self.loss, self.loss_items = self.model(batch)
                     if RANK != -1:
@@ -404,13 +403,17 @@
                     break

                 # Log
-                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
-                loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
-                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
-                if RANK in (-1, 0):
+                if RANK in {-1, 0}:
+                    loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
                     pbar.set_description(
-                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
-                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_length))
+                        % (
+                            f"{epoch + 1}/{self.epochs}",
+                            f"{self._get_memory():.3g}G",  # (GB) GPU memory util
+                            *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
+                            batch["cls"].shape[0],  # batch size, i.e. 8
+                            batch["img"].shape[-1],  # imgsz, i.e 640
+                        )
                     )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
@@ -420,8 +423,8 @@

             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
             self.run_callbacks("on_train_epoch_end")
-            if RANK in (-1, 0):
-                final_epoch = epoch + 1 == self.epochs
+            if RANK in {-1, 0}:
+                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

                 # Validation
@@ -441,17 +444,14 @@
             t = time.time()
             self.epoch_time = t - self.epoch_time_start
             self.epoch_time_start = t
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
-                if self.args.time:
-                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
-                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
-                    self._setup_scheduler()
-                    self.scheduler.last_epoch = self.epoch  # do not move
-                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
-                self.scheduler.step()
+            if self.args.time:
+                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                self._setup_scheduler()
+                self.scheduler.last_epoch = self.epoch  # do not move
+                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
-            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+            self._clear_memory()

             # Early Stopping
             if RANK != -1:  # if DDP training
@@ -462,55 +462,109 @@
                 break  # must break all DDP ranks
             epoch += 1

-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             # Do final val with best.pt
-            LOGGER.info(
-                f"\n{epoch - self.start_epoch + 1} epochs completed in "
-                f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
-            )
+            seconds = time.time() - self.train_time_start
+            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
-        torch.cuda.empty_cache()
+        self._clear_memory()
         self.run_callbacks("teardown")

+    def auto_batch(self, max_num_obj=0):
+        """Get batch size by calculating memory occupation of model."""
+        return check_train_batch_size(
+            model=self.model,
+            imgsz=self.args.imgsz,
+            amp=self.amp,
+            batch=self.batch_size,
+            max_num_obj=max_num_obj,
+        )  # returns batch size
+
+    def _get_memory(self):
+        """Get accelerator memory utilization in GB."""
+        if self.device.type == "mps":
+            memory = torch.mps.driver_allocated_memory()
+        elif self.device.type == "cpu":
+            memory = 0
+        else:
+            memory = torch.cuda.memory_reserved()
+        return memory / 1e9
+
+    def _clear_memory(self):
+        """Clear accelerator memory on different platforms."""
+        gc.collect()
+        if self.device.type == "mps":
+            torch.mps.empty_cache()
+        elif self.device.type == "cpu":
+            return
+        else:
+            torch.cuda.empty_cache()
+
+    def read_results_csv(self):
+        """Read results.csv into a dict using pandas."""
+        import pandas as pd  # scope for faster 'import ultralytics'
+
+        return pd.read_csv(self.csv).to_dict(orient="list")
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
-        import pandas as pd  # scope for faster startup
-
-        metrics = {**self.metrics, **{"fitness": self.fitness}}
-        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
-        ckpt = {
-            "epoch": self.epoch,
-            "best_fitness": self.best_fitness,
-            "model": deepcopy(de_parallel(self.model)).half(),
-            "ema": deepcopy(self.ema.ema).half(),
-            "updates": self.ema.updates,
-            "optimizer": self.optimizer.state_dict(),
-            "train_args": vars(self.args),  # save as dict
-            "train_metrics": metrics,
-            "train_results": results,
-            "date": datetime.now().isoformat(),
-            "version": __version__,
-            "license": "AGPL-3.0 (https://ultralytics.com/license)",
-            "docs": "https://docs.ultralytics.com",
-        }
-
-        # Save last and best
-        torch.save(ckpt, self.last)
+        import io
+
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": self.read_results_csv(),
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best)
-        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
+        if (self.save_period > 0) and (self.epoch % self.save_period == 0):
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
+        # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
+        #     (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt)  # save mosaic checkpoint

-    @staticmethod
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.

         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            }:
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
        return data["train"], data.get("val") or data.get("test")

     def setup_model(self):
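The rewritten `save_model()` above pickles the checkpoint dict once into an in-memory buffer and then writes the identical bytes to `last.pt`, `best.pt`, and any periodic `epoch{N}.pt` files, replacing one `torch.save()` call per file. A self-contained sketch of the serialize-once, write-many pattern; the checkpoint contents here are placeholders, not the full Ultralytics dict:

    import io
    from pathlib import Path

    import torch

    ckpt = {"epoch": 3, "model": None}  # placeholder checkpoint contents
    buffer = io.BytesIO()
    torch.save(ckpt, buffer)  # serialize once into memory
    serialized_ckpt = buffer.getvalue()

    Path("last.pt").write_bytes(serialized_ckpt)  # reuse the same bytes...
    Path("best.pt").write_bytes(serialized_ckpt)  # ...with no second torch.save() call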
@@ -518,13 +572,13 @@
         if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return

-        model, weights = self.model, None
+        cfg, weights = self.model, None
         ckpt = None
-        if str(model).endswith(".pt"):
-            weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt["model"].yaml
-        else:
-            cfg = model
+        if str(self.model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(self.model)
+            cfg = weights.yaml
+        elif isinstance(self.args.pretrained, (str, Path)):
+            weights, _ = attempt_load_one_weight(self.args.pretrained)
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
         return ckpt

@@ -603,26 +657,31 @@
     def save_metrics(self, metrics):
         """Saves training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
-        n = len(metrics) + 1  # number of cols
-        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        n = len(metrics) + 2  # number of cols
+        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        t = time.time() - self.train_time_start
         with open(self.csv, "a") as f:
-            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")

     def plot_metrics(self):
         """Plot and display metrics visually."""
         pass

     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}

     def final_eval(self):
         """Performs final evaluation and validation for object detection YOLO model."""
+        ckpt = {}
         for f in self.last, self.best:
             if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
+                if f is self.last:
+                    ckpt = strip_optimizer(f)
+                elif f is self.best:
+                    k = "train_results"  # update best.pt train_metrics from last.pt
+                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
@@ -644,8 +703,13 @@

             resume = True
             self.args = get_cfg(ckpt_args)
-            self.args.model = str(last)  # reinstate model
-            for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+            self.args.model = self.args.resume = str(last)  # reinstate model
+            for k in (
+                "imgsz",
+                "batch",
+                "device",
+                "close_mosaic",
+            ):  # allow arg updates to reduce memory or update device on resume
                 if k in overrides:
                     setattr(self.args, k, overrides[k])

@@ -658,24 +722,21 @@

     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
-        if ckpt is None:
+        if ckpt is None or not self.resume:
             return
         best_fitness = 0.0
-        start_epoch = ckpt["epoch"] + 1
-        if ckpt["optimizer"] is not None:
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
             self.ema.updates = ckpt["updates"]
-        if self.resume:
-            assert start_epoch > 0, (
-                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
-                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
-            )
-            LOGGER.info(
-                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
-            )
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
         if self.epochs < start_epoch:
             LOGGER.info(
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
@@ -692,7 +753,7 @@
             self.train_loader.dataset.mosaic = False
         if hasattr(self.train_loader.dataset, "close_mosaic"):
             LOGGER.info("Closing dataloader mosaic")
-            self.train_loader.dataset.close_mosaic(hyp=self.args)
+            self.train_loader.dataset.close_mosaic(hyp=copy(self.args))

     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
@@ -712,7 +773,6 @@
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
@@ -736,7 +796,9 @@
             else:  # weight (with decay)
                 g[0].append(param)

-        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        name = {x.lower(): x for x in optimizers}.get(name.lower())
+        if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == "RMSProp":
             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
@@ -744,15 +806,14 @@
             optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
             raise NotImplementedError(
-                f"Optimizer '{name}' not found in list of available optimizers "
-                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
-                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+                f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
+                "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
             )

         optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
         optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
         LOGGER.info(
             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+            f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
         )
         return optimizer
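The `optimizers` set and the `{x.lower(): x for x in optimizers}` lookup introduced in `build_optimizer` canonicalize the optimizer name case-insensitively, so input like `name="adamw"` resolves to `"AdamW"` before the dispatch above, and unknown names fall through to the `NotImplementedError`. A standalone illustration; `resolve_optimizer_name` is a hypothetical helper, not part of the package:

    optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}

    def resolve_optimizer_name(name: str):
        """Return the canonical spelling for a user-supplied name, or None if unknown."""
        return {x.lower(): x for x in optimizers}.get(name.lower())

    assert resolve_optimizer_name("adamw") == "AdamW"
    assert resolve_optimizer_name("SGD") == "SGD"
    assert resolve_optimizer_name("bogus") is None  # would raise NotImplementedError in build_optimizer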