ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
@@ -23,14 +23,31 @@ from torch import nn, optim
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
-from ultralytics.utils import (
-
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    LOGGER,
+    RANK,
+    TQDM,
+    __version__,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_save,
+)
 from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
-from ultralytics.utils.torch_utils import (
-
+from ultralytics.utils.torch_utils import (
+    EarlyStopping,
+    ModelEMA,
+    de_parallel,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+)


 class BaseTrainer:
@@ -89,12 +106,12 @@ class BaseTrainer:
         # Dirs
         self.save_dir = get_save_dir(self.args)
         self.args.name = self.save_dir.name  # update name for loggers
-        self.wdir = self.save_dir /
+        self.wdir = self.save_dir / "weights"  # weights dir
         if RANK in (-1, 0):
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            yaml_save(self.save_dir /
-        self.last, self.best = self.wdir /
+            yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period

         self.batch_size = self.args.batch
@@ -104,18 +121,18 @@ class BaseTrainer:
             print_args(vars(self.args))

         # Device
-        if self.device.type in (
+        if self.device.type in ("cpu", "mps"):
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
         try:
-            if self.args.task ==
+            if self.args.task == "classify":
                 self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
                 self.data = check_det_dataset(self.args.data)
-                if
-                    self.args.data = self.data[
+                if "yaml_file" in self.data:
+                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
         except Exception as e:
             raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

@@ -131,8 +148,8 @@ class BaseTrainer:
         self.fitness = None
         self.loss = None
         self.tloss = None
-        self.loss_names = [
-        self.csv = self.save_dir /
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]

         # Callbacks
@@ -156,7 +173,7 @@ class BaseTrainer:
     def train(self):
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
         if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(
+            world_size = len(self.args.device.split(","))
         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
@@ -165,14 +182,16 @@ class BaseTrainer:
             world_size = 0

         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and
+        if world_size > 1 and "LOCAL_RANK" not in os.environ:
             # Argument checks
             if self.args.rect:
                 LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
             if self.args.batch == -1:
-                LOGGER.warning(
-
+                LOGGER.warning(
+                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "default 'batch=16'"
+                )
                 self.args.batch = 16

             # Command
@@ -199,37 +218,45 @@ class BaseTrainer:
     def _setup_ddp(self, world_size):
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
-        self.device = torch.device(
+        self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-        os.environ[
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-
+            "nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
-            world_size=world_size
+            world_size=world_size,
+        )

     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""

         # Model
-        self.run_callbacks(
+        self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()

         # Freeze layers
-        freeze_list =
-            self.args.freeze
-
-
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
         for k, v in self.model.named_parameters():
            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
            if any(x in k for x in freeze_layer_names):
                LOGGER.info(f"Freezing layer '{k}'")
                v.requires_grad = False
            elif not v.requires_grad:
-                LOGGER.info(
-
+                LOGGER.info(
+                    f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
                v.requires_grad = True

         # Check AMP
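The freeze logic added above normalizes self.args.freeze (an int, a list of layer indices, or nothing) into layer-name prefixes and always appends ".dfl". A standalone sketch of that normalization, with freeze = 10 as an illustrative input rather than a value taken from this release:

# Sketch of the freeze-list normalization shown in the hunk above (freeze=10 is an example value).
freeze = 10
freeze_list = freeze if isinstance(freeze, list) else range(freeze) if isinstance(freeze, int) else []
always_freeze_names = [".dfl"]  # the DFL layer is always frozen
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
print(freeze_layer_names)  # ['model.0.', 'model.1.', ..., 'model.9.', '.dfl']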
@@ -246,7 +273,7 @@ class BaseTrainer:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

         # Check imgsz
-        gs = max(int(self.model.stride.max() if hasattr(self.model,
+        gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
         self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
         self.stride = gs  # for multi-scale training

@@ -256,15 +283,14 @@ class BaseTrainer:

         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode=
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
         if RANK in (-1, 0):
             # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects.
-            self.test_loader = self.get_dataloader(
-
-
-                mode='val')
+            self.test_loader = self.get_dataloader(
+                self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+            )
             self.validator = self.get_validator()
-            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix=
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
             self.ema = ModelEMA(self.model)
             if self.args.plots:
@@ -274,18 +300,20 @@ class BaseTrainer:
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
-        self.optimizer = self.build_optimizer(
-
-
-
-
-
+        self.optimizer = self.build_optimizer(
+            model=self.model,
+            name=self.args.optimizer,
+            lr=self.args.lr0,
+            momentum=self.args.momentum,
+            decay=weight_decay,
+            iterations=iterations,
+        )
         # Scheduler
         self._setup_scheduler()
         self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
-        self.run_callbacks(
+        self.run_callbacks("on_pretrain_routine_end")

     def _do_train(self, world_size=1):
         """Train completed, evaluate and plot if specified by arguments."""
@@ -299,19 +327,23 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        self.run_callbacks(
-        LOGGER.info(
-
-
-
-
+        self.run_callbacks("on_train_start")
+        LOGGER.info(
+            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
+            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+            f"Logging results to {colorstr('bold', self.save_dir)}\n"
+            f'Starting training for '
+            f'{self.args.time} hours...'
+            if self.args.time
+            else f"{self.epochs} epochs..."
+        )
         if self.args.close_mosaic:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         epoch = self.epochs  # predefine for resume fully trained model edge cases
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
-            self.run_callbacks(
+            self.run_callbacks("on_train_epoch_start")
             self.model.train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
@@ -327,7 +359,7 @@ class BaseTrainer:
             self.tloss = None
             self.optimizer.zero_grad()
             for i, batch in pbar:
-                self.run_callbacks(
+                self.run_callbacks("on_train_batch_start")
                 # Warmup
                 ni = i + nb * epoch
                 if ni <= nw:
@@ -335,10 +367,11 @@ class BaseTrainer:
                     self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                     for j, x in enumerate(self.optimizer.param_groups):
                         # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                        x[
-                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x[
-
-
+                        x["lr"] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                        )
+                        if "momentum" in x:
+                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                 # Forward
                 with torch.cuda.amp.autocast(self.amp):
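The warmup block above uses np.interp to ramp every optimizer parameter group linearly over the first nw iterations: the bias group's lr falls from warmup_bias_lr toward the scheduled value (initial_lr * lf(epoch) in the trainer), the other groups' lr rises from 0.0, and momentum rises from warmup_momentum to momentum. A standalone sketch with made-up numbers (nw, the lr target and the momenta below are illustrative assumptions, not values read from this release):

import numpy as np

nw = 100                                            # assumed warmup length in batches
xi = [0, nw]
scheduled_lr = 0.01                                 # stand-in for x["initial_lr"] * lf(epoch)
warmup_bias_lr, warmup_momentum, momentum = 0.1, 0.8, 0.937  # assumed example hyperparameters
for ni in (0, 50, 100):                             # integrated batch counter
    bias_lr = np.interp(ni, xi, [warmup_bias_lr, scheduled_lr])  # falls 0.1 -> 0.01
    other_lr = np.interp(ni, xi, [0.0, scheduled_lr])            # rises 0.0 -> 0.01
    mom = np.interp(ni, xi, [warmup_momentum, momentum])         # rises 0.8 -> 0.937
    print(f"ni={ni:3d}  bias_lr={bias_lr:.4f}  other_lr={other_lr:.4f}  momentum={mom:.3f}")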
@@ -346,8 +379,9 @@ class BaseTrainer:
                     self.loss, self.loss_items = self.model(batch)
                     if RANK != -1:
                         self.loss *= world_size
-                    self.tloss = (
-                        else self.loss_items
+                    self.tloss = (
+                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                    )

                 # Backward
                 self.scaler.scale(self.loss).backward()
@@ -368,24 +402,25 @@ class BaseTrainer:
                         break

                 # Log
-                mem = f
+                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
                 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                 if RANK in (-1, 0):
                     pbar.set_description(
-                        (
-                        (f
-
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                    )
+                    self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)

-                self.run_callbacks(
+                self.run_callbacks("on_train_batch_end")

-            self.lr = {f
-            self.run_callbacks(
+            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.run_callbacks("on_train_epoch_end")
             if RANK in (-1, 0):
                 final_epoch = epoch + 1 == self.epochs
-                self.ema.update_attr(self.model, include=[
+                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

                 # Validation
                 if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
@@ -398,14 +433,14 @@ class BaseTrainer:
                 # Save model
                 if self.args.save or final_epoch:
                     self.save_model()
-                    self.run_callbacks(
+                    self.run_callbacks("on_model_save")

             # Scheduler
             t = time.time()
             self.epoch_time = t - self.epoch_time_start
             self.epoch_time_start = t
             with warnings.catch_warnings():
-                warnings.simplefilter(
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                 if self.args.time:
                     mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                     self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
@@ -413,7 +448,7 @@ class BaseTrainer:
                     self.scheduler.last_epoch = self.epoch  # do not move
                     self.stop |= epoch >= self.epochs  # stop if exceeded epochs
                 self.scheduler.step()
-            self.run_callbacks(
+            self.run_callbacks("on_fit_epoch_end")
             torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

             # Early Stopping
@@ -426,39 +461,43 @@ class BaseTrainer:

         if RANK in (-1, 0):
             # Do final val with best.pt
-            LOGGER.info(
-
+            LOGGER.info(
+                f"\n{epoch - self.start_epoch + 1} epochs completed in "
+                f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
+            )
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
-            self.run_callbacks(
+            self.run_callbacks("on_train_end")
         torch.cuda.empty_cache()
-        self.run_callbacks(
+        self.run_callbacks("teardown")

     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import pandas as pd  # scope for faster startup
-
-
+
+        metrics = {**self.metrics, **{"fitness": self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
         ckpt = {
-
-
-
-
-
-
-
-
-
-
-
+            "epoch": self.epoch,
+            "best_fitness": self.best_fitness,
+            "model": deepcopy(de_parallel(self.model)).half(),
+            "ema": deepcopy(self.ema.ema).half(),
+            "updates": self.ema.updates,
+            "optimizer": self.optimizer.state_dict(),
+            "train_args": vars(self.args),  # save as dict
+            "train_metrics": metrics,
+            "train_results": results,
+            "date": datetime.now().isoformat(),
+            "version": __version__,
+        }

         # Save last and best
         torch.save(ckpt, self.last)
         if self.best_fitness == self.fitness:
             torch.save(ckpt, self.best)
         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f
+            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")

     @staticmethod
     def get_dataset(data):
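The checkpoint written by save_model above is a plain torch pickle whose keys are exactly those of the ckpt dict (epoch, best_fitness, model, ema, updates, optimizer, train_args, train_metrics, train_results, date, version). A hedged sketch of reading that metadata back; the path is only an example, and ultralytics must be importable so the pickled model classes resolve:

import torch

ckpt = torch.load("runs/detect/train/weights/last.pt", map_location="cpu")  # example path
print(ckpt["version"], ckpt["date"], ckpt["epoch"], ckpt["best_fitness"])
print(ckpt["train_args"]["imgsz"], ckpt["train_args"]["batch"])  # the run's arguments, saved as a dict
model = (ckpt.get("ema") or ckpt["model"]).float()               # EMA weights are preferred when present
print(sum(p.numel() for p in model.parameters()), "parameters")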
@@ -467,7 +506,7 @@ class BaseTrainer:

         Returns None if data format is not recognized.
         """
-        return data[
+        return data["train"], data.get("val") or data.get("test")

     def setup_model(self):
         """Load/create/download model for any task."""
@@ -476,9 +515,9 @@ class BaseTrainer:

         model, weights = self.model, None
         ckpt = None
-        if str(model).endswith(
+        if str(model).endswith(".pt"):
             weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt[
+            cfg = ckpt["model"].yaml
         else:
             cfg = model
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -505,7 +544,7 @@ class BaseTrainer:
         The returned dict is expected to contain "fitness" key.
         """
         metrics = self.validator(self)
-        fitness = metrics.pop(
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
         return metrics, fitness
@@ -516,24 +555,24 @@ class BaseTrainer:

     def get_validator(self):
         """Returns a NotImplementedError when the get_validator function is called."""
-        raise NotImplementedError(
+        raise NotImplementedError("get_validator function not implemented in trainer")

-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode=
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
         """Returns dataloader derived from torch.data.Dataloader."""
-        raise NotImplementedError(
+        raise NotImplementedError("get_dataloader function not implemented in trainer")

-    def build_dataset(self, img_path, mode=
+    def build_dataset(self, img_path, mode="train", batch=None):
         """Build dataset."""
-        raise NotImplementedError(
+        raise NotImplementedError("build_dataset function not implemented in trainer")

-    def label_loss_items(self, loss_items=None, prefix=
+    def label_loss_items(self, loss_items=None, prefix="train"):
         """Returns a loss dict with labelled training loss items tensor."""
         # Not needed for classification but necessary for segmentation & detection
-        return {
+        return {"loss": loss_items} if loss_items is not None else ["loss"]

     def set_model_attributes(self):
         """To set or update model parameters before training."""
-        self.model.names = self.data[
+        self.model.names = self.data["names"]

     def build_targets(self, preds, targets):
         """Builds target tensors for training YOLO model."""
@@ -541,7 +580,7 @@ class BaseTrainer:

     def progress_string(self):
         """Returns a string describing training progress."""
-        return
+        return ""

     # TODO: may need to put these following functions into callback
     def plot_training_samples(self, batch, ni):
@@ -556,9 +595,9 @@ class BaseTrainer:
         """Saves training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
         n = len(metrics) + 1  # number of cols
-        s =
-        with open(self.csv,
-            f.write(s + (
+        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        with open(self.csv, "a") as f:
+            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")

     def plot_metrics(self):
         """Plot and display metrics visually."""
@@ -567,7 +606,7 @@ class BaseTrainer:
     def on_plot(self, name, data=None):
         """Registers plots (e.g. to be consumed in callbacks)"""
         path = Path(name)
-        self.plots[path] = {
+        self.plots[path] = {"data": data, "timestamp": time.time()}

     def final_eval(self):
         """Performs final evaluation and validation for object detection YOLO model."""
@@ -575,11 +614,11 @@ class BaseTrainer:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
                 if f is self.best:
-                    LOGGER.info(f
+                    LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
-                    self.metrics.pop(
-                    self.run_callbacks(
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")

     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -591,19 +630,21 @@ class BaseTrainer:

                 # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                 ckpt_args = attempt_load_weights(last).args
-                if not Path(ckpt_args[
-                    ckpt_args[
+                if not Path(ckpt_args["data"]).exists():
+                    ckpt_args["data"] = self.args.data

                 resume = True
                 self.args = get_cfg(ckpt_args)
                 self.args.model = str(last)  # reinstate model
-                for k in
+                for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                     if k in overrides:
                         setattr(self.args, k, overrides[k])

             except Exception as e:
-                raise FileNotFoundError(
-
+                raise FileNotFoundError(
+                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                    "i.e. 'yolo train resume model=path/to/last.pt'"
+                ) from e
         self.resume = resume

     def resume_training(self, ckpt):
@@ -611,23 +652,26 @@ class BaseTrainer:
         if ckpt is None:
             return
         best_fitness = 0.0
-        start_epoch = ckpt[
-        if ckpt[
-            self.optimizer.load_state_dict(ckpt[
-            best_fitness = ckpt[
-        if self.ema and ckpt.get(
-            self.ema.ema.load_state_dict(ckpt[
-            self.ema.updates = ckpt[
+        start_epoch = ckpt["epoch"] + 1
+        if ckpt["optimizer"] is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
+            best_fitness = ckpt["best_fitness"]
+        if self.ema and ckpt.get("ema"):
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
+            self.ema.updates = ckpt["updates"]
         if self.resume:
-            assert start_epoch > 0,
-                f
+            assert start_epoch > 0, (
+                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
                 f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+            )
             LOGGER.info(
-                f
+                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
+            )
         if self.epochs < start_epoch:
             LOGGER.info(
-                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
-
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt["epoch"]  # finetune additional epochs
         self.best_fitness = best_fitness
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
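check_resume and resume_training above restore args, optimizer state, EMA and best_fitness from last.pt and then continue to the original epoch target, matching the 'yolo train resume model=path/to/last.pt' hint in the error message. The Python-API equivalent, sketched with an example path:

from ultralytics import YOLO

# Resume an interrupted run; the checkpoint path below is only an example.
model = YOLO("runs/detect/train/weights/last.pt")
model.train(resume=True)  # picks up epoch, optimizer, EMA and best_fitness from the checkpoint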
@@ -635,13 +679,13 @@ class BaseTrainer:

     def _close_dataloader_mosaic(self):
         """Update dataloaders to stop using mosaic augmentation."""
-        if hasattr(self.train_loader.dataset,
+        if hasattr(self.train_loader.dataset, "mosaic"):
             self.train_loader.dataset.mosaic = False
-        if hasattr(self.train_loader.dataset,
-            LOGGER.info(
+        if hasattr(self.train_loader.dataset, "close_mosaic"):
+            LOGGER.info("Closing dataloader mosaic")
             self.train_loader.dataset.close_mosaic(hyp=self.args)

-    def build_optimizer(self, model, name=
+    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
         Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
         weight decay, and number of iterations.
@@ -661,41 +705,45 @@ class BaseTrainer:
         """

         g = [], [], []  # optimizer parameter groups
-        bn = tuple(v for k, v in nn.__dict__.items() if
-        if name ==
-            LOGGER.info(
-
-
-
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == "auto":
+            LOGGER.info(
+                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+            )
+            nc = getattr(model, "nc", 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
-            name, lr, momentum = (
+            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

         for module_name, module in model.named_modules():
             for param_name, param in module.named_parameters(recurse=False):
-                fullname = f
-                if
+                fullname = f"{module_name}.{param_name}" if module_name else param_name
+                if "bias" in fullname:  # bias (no decay)
                     g[2].append(param)
                 elif isinstance(module, bn):  # weight (no decay)
                     g[1].append(param)
                 else:  # weight (with decay)
                     g[0].append(param)

-        if name in (
+        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
-        elif name ==
+        elif name == "RMSProp":
             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
-        elif name ==
+        elif name == "SGD":
             optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
             raise NotImplementedError(
                 f"Optimizer '{name}' not found in list of available optimizers "
-                f
-
+                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+            )

-        optimizer.add_param_group({
-        optimizer.add_param_group({
+        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
         LOGGER.info(
             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+        )
         return optimizer
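For optimizer='auto', the block above derives the optimizer and learning rate from the model's class count and the estimated iteration count: lr_fit = round(0.002 * 5 / (4 + nc), 6), then SGD(lr=0.01, momentum=0.9) when iterations > 10000, otherwise AdamW(lr=lr_fit, momentum=0.9). A worked example with illustrative inputs (nc=80 and the two iteration counts are assumptions, not values from this release):

# Worked example of the 'auto' rule in build_optimizer above (inputs are illustrative).
nc = 80                                  # e.g. a COCO-style class count
lr_fit = round(0.002 * 5 / (4 + nc), 6)  # -> 0.000119
for iterations in (500, 50000):
    name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
    print(f"iterations={iterations}: {name}(lr={lr}, momentum={momentum})")
# iterations=500: AdamW(lr=0.000119, momentum=0.9)
# iterations=50000: SGD(lr=0.01, momentum=0.9)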