ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,18 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 Train a model on a dataset.
 
 Usage:
-    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
+    $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+import gc
 import math
 import os
 import subprocess
 import time
 import warnings
-from copy import deepcopy
+from copy import copy, deepcopy
 from datetime import datetime, timedelta
 from pathlib import Path
 
@@ -25,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.utils import (
     DEFAULT_CFG,
+    LOCAL_RANK,
     LOGGER,
     RANK,
     TQDM,
@@ -40,20 +42,21 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
+    TORCH_2_4,
     EarlyStopping,
     ModelEMA,
-    de_parallel,
+    autocast,
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
     strip_optimizer,
+    torch_distributed_zero_first,
 )
 
 
 class BaseTrainer:
     """
-    BaseTrainer.
-
     A base class for creating trainers.
 
     Attributes:
@@ -107,7 +110,7 @@ class BaseTrainer:
         self.save_dir = get_save_dir(self.args)
         self.args.name = self.save_dir.name  # update name for loggers
         self.wdir = self.save_dir / "weights"  # weights dir
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
             yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
@@ -115,33 +118,19 @@ class BaseTrainer:
         self.save_period = self.args.save_period
 
         self.batch_size = self.args.batch
-        self.epochs = self.args.epochs
+        self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
         self.start_epoch = 0
         if RANK == -1:
             print_args(vars(self.args))
 
         # Device
-        if self.device.type in ("cpu", "mps"):
+        if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
 
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        try:
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
+            self.trainset, self.testset = self.get_dataset()
         self.ema = None
 
         # Optimization utils init
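Note: the dataset check that previously ran inline in `__init__` now happens inside `get_dataset()` under `torch_distributed_zero_first(LOCAL_RANK)`, so in DDP runs only the local rank-0 process auto-downloads the dataset while the other ranks wait at a barrier. A minimal sketch of how such a context manager can be built (the real helper lives in `ultralytics.utils.torch_utils`; this version is an illustrative assumption, not its exact code):

```python
import os
from contextlib import contextmanager

import torch.distributed as dist

LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # set by DDP launchers such as torchrun


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Have all ranks wait while local rank 0 runs the wrapped block (e.g. a dataset download)."""
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and local_rank not in {-1, 0}:
        dist.barrier()  # non-zero ranks block here until rank 0 finishes
    yield
    if initialized and local_rank == 0:
        dist.barrier()  # rank 0 releases the waiting ranks once its work is done
```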
@@ -157,9 +146,12 @@ class BaseTrainer:
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
 
+        # HUB
+        self.hub_session = None
+
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             callbacks.add_integration_callbacks(self)
 
     def add_callback(self, event: str, callback):
@@ -181,9 +173,11 @@ class BaseTrainer:
             world_size = len(self.args.device.split(","))
         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
             world_size = 1  # default to device 0
-        else:  # i.e. device='cpu' or 'mps'
+        else:  # i.e. device=None or device=''
             world_size = 0
 
         # Run subprocess if DDP training, else train normally
@@ -192,9 +186,9 @@ class BaseTrainer:
             if self.args.rect:
                 LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
-            if self.args.batch == -1:
+            if self.args.batch < 1.0:
                 LOGGER.warning(
-                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
                    "default 'batch=16'"
                 )
                 self.args.batch = 16
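The `batch == -1` check became `batch < 1.0` because AutoBatch now also accepts a fractional batch size, interpreted as a target GPU memory utilization; it remains a single-GPU feature, hence the DDP warning above. Illustrative usage (the values are examples, not recommendations):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# batch=-1 keeps the classic AutoBatch behavior at the default utilization target;
# a fraction such as 0.70 asks AutoBatch to fill roughly 70% of GPU memory.
model.train(data="coco8.yaml", epochs=1, batch=0.70)
```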
@@ -202,7 +196,7 @@ class BaseTrainer:
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
-                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
+                LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
                 subprocess.run(cmd, check=True)
             except Exception as e:
                 raise e
@@ -225,9 +219,9 @@ class BaseTrainer:
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            "nccl" if dist.is_nccl_available() else "gloo",
+            backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
             world_size=world_size,
@@ -235,7 +229,6 @@ class BaseTrainer:
 
     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""
-
         # Model
         self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
@@ -266,16 +259,19 @@ class BaseTrainer:
 
         # Check AMP
         self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
-        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
+        if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
         if world_size > 1:
-            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
+            self.set_model_attributes()  # set again after DDP wrapper
 
         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
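The scaler change tracks PyTorch's AMP API migration: `torch.cuda.amp.GradScaler` is deprecated in favor of `torch.amp.GradScaler("cuda")` as of PyTorch 2.4, hence the `TORCH_2_4` guard. A standalone sketch of the same guard, using `packaging` for the version check where ultralytics uses its own constant:

```python
import torch
from packaging.version import parse

TORCH_2_4 = parse(torch.__version__) >= parse("2.4.0")  # stand-in for the ultralytics constant
amp = True  # whether mixed precision is enabled

scaler = (
    torch.amp.GradScaler("cuda", enabled=amp)    # new unified API
    if TORCH_2_4
    else torch.cuda.amp.GradScaler(enabled=amp)  # legacy API, warns on PyTorch >= 2.4
)
```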
@@ -283,13 +279,13 @@ class BaseTrainer:
         self.stride = gs  # for multiscale training
 
         # Batch size
-        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
-            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
+        if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = self.auto_batch()
 
         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
-        if RANK in (-1, 0):
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
+        if RANK in {-1, 0}:
             # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
             self.test_loader = self.get_dataloader(
                 self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
@@ -334,18 +330,23 @@ class BaseTrainer:
         self.train_time_start = time.time()
         self.run_callbacks("on_train_start")
         LOGGER.info(
-            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
-            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+            f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
             f"Logging results to {colorstr('bold', self.save_dir)}\n"
-            f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+            f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
         )
         if self.args.close_mosaic:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         epoch = self.start_epoch
+        self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
         while True:
             self.epoch = epoch
             self.run_callbacks("on_train_epoch_start")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
+
             self.model.train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
@@ -355,11 +356,10 @@ class BaseTrainer:
                 self._close_dataloader_mosaic()
                 self.train_loader.reset()
 
-            if RANK in (-1, 0):
+            if RANK in {-1, 0}:
                 LOGGER.info(self.progress_string())
                 pbar = TQDM(enumerate(self.train_loader), total=nb)
             self.tloss = None
-            self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.run_callbacks("on_train_batch_start")
                 # Warmup
@@ -376,7 +376,7 @@ class BaseTrainer:
                         x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
                 # Forward
-                with torch.cuda.amp.autocast(self.amp):
+                with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     self.loss, self.loss_items = self.model(batch)
                     if RANK != -1:
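The bare `autocast(self.amp)` call uses the helper newly imported from `ultralytics.utils.torch_utils`, which papers over the same old/new AMP split for the autocast context manager. A hedged sketch of what such a wrapper can look like (the real helper keys off a PyTorch version constant rather than `hasattr`):

```python
import torch


def autocast(enabled: bool, device: str = "cuda"):
    """Return an AMP autocast context manager for old and new PyTorch versions."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device, enabled=enabled)  # modern unified API
    return torch.cuda.amp.autocast(enabled)  # legacy CUDA-only fallback


# Usage mirrors the training loop above:
# with autocast(True):
#     loss = model(batch)  # forward pass runs in mixed precision
```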
@@ -404,13 +404,17 @@ class BaseTrainer:
                         break
 
                 # Log
-                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
-                loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
-                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
-                if RANK in (-1, 0):
+                if RANK in {-1, 0}:
+                    loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
                     pbar.set_description(
-                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
-                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_length))
+                        % (
+                            f"{epoch + 1}/{self.epochs}",
+                            f"{self._get_memory():.3g}G",  # (GB) GPU memory util
+                            *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
+                            batch["cls"].shape[0],  # batch size, i.e. 8
+                            batch["img"].shape[-1],  # imgsz, i.e 640
+                        )
                     )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
@@ -420,8 +424,8 @@ class BaseTrainer:
 
             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
             self.run_callbacks("on_train_epoch_end")
-            if RANK in (-1, 0):
-                final_epoch = epoch + 1 == self.epochs
+            if RANK in {-1, 0}:
+                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
                 # Validation
@@ -441,17 +445,14 @@ class BaseTrainer:
             t = time.time()
             self.epoch_time = t - self.epoch_time_start
             self.epoch_time_start = t
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
-                if self.args.time:
-                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
-                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
-                    self._setup_scheduler()
-                    self.scheduler.last_epoch = self.epoch  # do not move
-                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
-                self.scheduler.step()
+            if self.args.time:
+                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                self._setup_scheduler()
+                self.scheduler.last_epoch = self.epoch  # do not move
+                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
-            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+            self._clear_memory()
 
             # Early Stopping
             if RANK != -1:  # if DDP training
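With `time` set, the trainer re-derives the total epoch budget every epoch from the mean epoch time observed so far, then rebuilds the scheduler for the new horizon (the `scheduler.step()` call itself moved to the top of the epoch loop above). A worked example of the arithmetic with illustrative numbers:

```python
import math

time_budget_h = 2.0      # args.time, in hours
mean_epoch_time = 180.0  # observed mean seconds per epoch so far

epochs = math.ceil(time_budget_h * 3600 / mean_epoch_time)
print(epochs)  # 40 -> the scheduler is re-setup for a 40-epoch horizon
```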
@@ -462,55 +463,109 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1
 
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             # Do final val with best.pt
-            LOGGER.info(
-                f"\n{epoch - self.start_epoch + 1} epochs completed in "
-                f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
-            )
+            seconds = time.time() - self.train_time_start
+            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
-        torch.cuda.empty_cache()
+        self._clear_memory()
         self.run_callbacks("teardown")
 
+    def auto_batch(self, max_num_obj=0):
+        """Get batch size by calculating memory occupation of model."""
+        return check_train_batch_size(
+            model=self.model,
+            imgsz=self.args.imgsz,
+            amp=self.amp,
+            batch=self.batch_size,
+            max_num_obj=max_num_obj,
+        )  # returns batch size
+
+    def _get_memory(self):
+        """Get accelerator memory utilization in GB."""
+        if self.device.type == "mps":
+            memory = torch.mps.driver_allocated_memory()
+        elif self.device.type == "cpu":
+            memory = 0
+        else:
+            memory = torch.cuda.memory_reserved()
+        return memory / 1e9
+
+    def _clear_memory(self):
+        """Clear accelerator memory on different platforms."""
+        gc.collect()
+        if self.device.type == "mps":
+            torch.mps.empty_cache()
+        elif self.device.type == "cpu":
+            return
+        else:
+            torch.cuda.empty_cache()
+
+    def read_results_csv(self):
+        """Read results.csv into a dict using pandas."""
+        import pandas as pd  # scope for faster 'import ultralytics'
+
+        return pd.read_csv(self.csv).to_dict(orient="list")
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
-        import pandas as pd  # scope for faster startup
-
-        metrics = {**self.metrics, **{"fitness": self.fitness}}
-        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
-        ckpt = {
-            "epoch": self.epoch,
-            "best_fitness": self.best_fitness,
-            "model": deepcopy(de_parallel(self.model)).half(),
-            "ema": deepcopy(self.ema.ema).half(),
-            "updates": self.ema.updates,
-            "optimizer": self.optimizer.state_dict(),
-            "train_args": vars(self.args),  # save as dict
-            "train_metrics": metrics,
-            "train_results": results,
-            "date": datetime.now().isoformat(),
-            "version": __version__,
-            "license": "AGPL-3.0 (https://ultralytics.com/license)",
-            "docs": "https://docs.ultralytics.com",
-        }
-
-        # Save last and best
-        torch.save(ckpt, self.last)
+        import io
+
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": self.read_results_csv(),
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best)
-        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
+        if (self.save_period > 0) and (self.epoch % self.save_period == 0):
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
+        # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
+        #     (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt)  # save mosaic checkpoint
 
-    @staticmethod
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.
 
         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            }:
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
        return data["train"], data.get("val") or data.get("test")
 
     def setup_model(self):
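The rewritten `save_model()` serializes the checkpoint dict once into an in-memory buffer and then writes the same bytes to `last.pt`, `best.pt`, and any periodic epoch checkpoints, instead of re-pickling with repeated `torch.save()` calls. The pattern in isolation, with a toy payload:

```python
import io
from pathlib import Path

import torch

ckpt = {"epoch": 3, "weights": torch.zeros(2, 2)}  # illustrative payload

buffer = io.BytesIO()
torch.save(ckpt, buffer)          # pickle once, into memory
serialized = buffer.getvalue()

for path in (Path("last.pt"), Path("best.pt")):
    path.write_bytes(serialized)  # reuse the same bytes for every destination
```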
@@ -518,13 +573,13 @@ class BaseTrainer:
         if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return
 
-        model, weights = self.model, None
+        cfg, weights = self.model, None
         ckpt = None
-        if str(model).endswith(".pt"):
-            weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt["model"].yaml
-        else:
-            cfg = model
+        if str(self.model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(self.model)
+            cfg = weights.yaml
+        elif isinstance(self.args.pretrained, (str, Path)):
+            weights, _ = attempt_load_one_weight(self.args.pretrained)
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
         return ckpt
 
@@ -603,26 +658,31 @@ class BaseTrainer:
     def save_metrics(self, metrics):
         """Saves training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
-        n = len(metrics) + 1  # number of cols
-        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        n = len(metrics) + 2  # number of cols
+        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        t = time.time() - self.train_time_start
         with open(self.csv, "a") as f:
-            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot and display metrics visually."""
         pass
 
     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}
 
     def final_eval(self):
         """Performs final evaluation and validation for object detection YOLO model."""
+        ckpt = {}
         for f in self.last, self.best:
             if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
+                if f is self.last:
+                    ckpt = strip_optimizer(f)
+                elif f is self.best:
+                    k = "train_results"  # update best.pt train_metrics from last.pt
+                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
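`save_metrics()` now prepends a `time` column and switches from fixed-width `%23.5g` fields to compact `%.6g` values. Reconstructing a header and one row with illustrative metrics shows the resulting results.csv layout:

```python
metrics = {"train/box_loss": 1.23456789, "metrics/mAP50(B)": 0.654321}
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 2  # epoch + time + metric columns

header = ("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n"
row = ("%.6g," * n % tuple([1, 42.5] + vals)).rstrip(",") + "\n"
print(header + row)
# epoch,time,train/box_loss,metrics/mAP50(B)
# 1,42.5,1.23457,0.654321
```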
@@ -644,8 +704,13 @@ class BaseTrainer:
 
             resume = True
             self.args = get_cfg(ckpt_args)
-            self.args.model = str(last)  # reinstate model
-            for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+            self.args.model = self.args.resume = str(last)  # reinstate model
+            for k in (
+                "imgsz",
+                "batch",
+                "device",
+                "close_mosaic",
+            ):  # allow arg updates to reduce memory or update device on resume
                 if k in overrides:
                     setattr(self.args, k, overrides[k])
 
@@ -658,24 +723,21 @@ class BaseTrainer:
 
     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
-        if ckpt is None:
+        if ckpt is None or not self.resume:
             return
         best_fitness = 0.0
-        start_epoch = ckpt["epoch"] + 1
-        if ckpt["optimizer"] is not None:
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
             self.ema.updates = ckpt["updates"]
-        if self.resume:
-            assert start_epoch > 0, (
-                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
-                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
-            )
-            LOGGER.info(
-                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
-            )
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
         if self.epochs < start_epoch:
             LOGGER.info(
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
@@ -692,7 +754,7 @@ class BaseTrainer:
             self.train_loader.dataset.mosaic = False
             if hasattr(self.train_loader.dataset, "close_mosaic"):
                 LOGGER.info("Closing dataloader mosaic")
-                self.train_loader.dataset.close_mosaic(hyp=self.args)
+                self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
@@ -712,7 +774,6 @@ class BaseTrainer:
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
@@ -736,7 +797,9 @@ class BaseTrainer:
             else:  # weight (with decay)
                 g[0].append(param)
 
-        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        name = {x.lower(): x for x in optimizers}.get(name.lower())
+        if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == "RMSProp":
             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
@@ -744,15 +807,14 @@ class BaseTrainer:
             optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
             raise NotImplementedError(
-                f"Optimizer '{name}' not found in list of available optimizers "
-                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
-                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+                f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
+                "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
             )
 
         optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
         optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
         LOGGER.info(
             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+            f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
         )
         return optimizer
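The new `optimizers` set plus the lowercase lookup make the optimizer name case-insensitive before dispatch, so `optimizer=adamw` resolves to `AdamW`, while anything unknown maps to `None` and falls through to the `NotImplementedError`. The trick in isolation:

```python
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
canonical = {x.lower(): x for x in optimizers}

print(canonical.get("adamw"))  # AdamW
print(canonical.get("sgd"))    # SGD
print(canonical.get("lion"))   # None -> raises NotImplementedError in build_optimizer
```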