birder 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/scripts/train.py
CHANGED
|
@@ -160,8 +160,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
160
160
|
|
|
161
161
|
num_outputs = len(class_to_idx)
|
|
162
162
|
batch_size: int = args.batch_size
|
|
163
|
-
|
|
164
|
-
|
|
163
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
164
|
+
model_ema_steps: int = args.model_ema_steps
|
|
165
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
165
166
|
|
|
166
167
|
# Set data iterators
|
|
167
168
|
if args.mixup_alpha is not None or args.cutmix is True:
|
|
@@ -220,8 +221,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
220
221
|
pin_memory=True,
|
|
221
222
|
)
|
|
222
223
|
|
|
223
|
-
optimizer_steps_per_epoch = math.ceil(len(training_loader) /
|
|
224
|
-
assert args.model_ema is False or
|
|
224
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
225
|
+
assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch
|
|
225
226
|
|
|
226
227
|
last_batch_idx = len(training_loader) - 1
|
|
227
228
|
begin_epoch = 1
|
|
@@ -317,20 +318,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
317
318
|
|
|
318
319
|
# Learning rate scaling
|
|
319
320
|
lr = training_utils.scale_lr(args)
|
|
320
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
321
321
|
|
|
322
322
|
if args.lr_scheduler_update == "epoch":
|
|
323
323
|
step_update = False
|
|
324
|
-
|
|
324
|
+
scheduler_steps_per_epoch = 1
|
|
325
325
|
elif args.lr_scheduler_update == "step":
|
|
326
326
|
step_update = True
|
|
327
|
-
|
|
327
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
328
328
|
else:
|
|
329
329
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
330
330
|
|
|
331
331
|
# Optimizer and learning rate scheduler
|
|
332
332
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
333
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
333
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
334
334
|
if args.compile_opt is True:
|
|
335
335
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
336
336
|
|
|
@@ -356,11 +356,14 @@ def train(args: argparse.Namespace) -> None:
|
|
|
356
356
|
optimizer.step()
|
|
357
357
|
lrs = []
|
|
358
358
|
for _ in range(begin_epoch, epochs):
|
|
359
|
-
for _ in range(
|
|
359
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
360
360
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
361
361
|
scheduler.step()
|
|
362
362
|
|
|
363
|
-
plt.plot(
|
|
363
|
+
plt.plot(
|
|
364
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
|
|
365
|
+
lrs,
|
|
366
|
+
)
|
|
364
367
|
plt.show()
|
|
365
368
|
raise SystemExit(0)
|
|
366
369
|
|
|
@@ -368,15 +371,15 @@ def train(args: argparse.Namespace) -> None:
|
|
|
368
371
|
# Distributed (DDP) and Model EMA
|
|
369
372
|
#
|
|
370
373
|
if args.model_ema_warmup is not None:
|
|
371
|
-
|
|
374
|
+
ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
|
|
372
375
|
elif args.warmup_epochs is not None:
|
|
373
|
-
|
|
376
|
+
ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
|
|
374
377
|
elif args.warmup_steps is not None:
|
|
375
|
-
|
|
378
|
+
ema_warmup_steps = args.warmup_steps
|
|
376
379
|
else:
|
|
377
|
-
|
|
380
|
+
ema_warmup_steps = 0
|
|
378
381
|
|
|
379
|
-
logger.debug(f"EMA warmup
|
|
382
|
+
logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
|
|
380
383
|
net_without_ddp = net
|
|
381
384
|
if args.distributed is True:
|
|
382
385
|
net = torch.nn.parallel.DistributedDataParallel(
|
|
@@ -474,6 +477,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
474
477
|
#
|
|
475
478
|
# Training loop
|
|
476
479
|
#
|
|
480
|
+
optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
|
|
477
481
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
478
482
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
479
483
|
tic = time.time()
|
|
@@ -542,10 +546,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
542
546
|
if step_update is True:
|
|
543
547
|
scheduler.step()
|
|
544
548
|
|
|
549
|
+
if optimizer_update is True:
|
|
550
|
+
optimizer_step += 1
|
|
551
|
+
|
|
545
552
|
# Exponential moving average
|
|
546
|
-
if args.model_ema is True and
|
|
553
|
+
if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
|
|
547
554
|
model_ema.update_parameters(net)
|
|
548
|
-
if
|
|
555
|
+
if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
|
|
549
556
|
# Reset ema buffer to keep copying weights during warmup period
|
|
550
557
|
model_ema.n_averaged.fill_(0) # pylint: disable=no-member
|
|
551
558
|
|
birder/scripts/train_barlow_twins.py
CHANGED
|
@@ -149,7 +149,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
149
149
|
logger.info(f"Training on {len(training_dataset):,} samples")
|
|
150
150
|
|
|
151
151
|
batch_size: int = args.batch_size
|
|
152
|
-
|
|
152
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
153
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
153
154
|
|
|
154
155
|
# Data loaders and samplers
|
|
155
156
|
if args.distributed is True:
|
|
@@ -181,6 +182,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
181
182
|
drop_last=args.drop_last,
|
|
182
183
|
)
|
|
183
184
|
|
|
185
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
184
186
|
last_batch_idx = len(training_loader) - 1
|
|
185
187
|
begin_epoch = 1
|
|
186
188
|
epochs = args.epochs + 1
|
|
@@ -244,20 +246,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
244
246
|
|
|
245
247
|
# Learning rate scaling
|
|
246
248
|
lr = training_utils.scale_lr(args)
|
|
247
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
248
249
|
|
|
249
250
|
if args.lr_scheduler_update == "epoch":
|
|
250
251
|
step_update = False
|
|
251
|
-
|
|
252
|
+
scheduler_steps_per_epoch = 1
|
|
252
253
|
elif args.lr_scheduler_update == "step":
|
|
253
254
|
step_update = True
|
|
254
|
-
|
|
255
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
255
256
|
else:
|
|
256
257
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
257
258
|
|
|
258
259
|
# Optimizer and learning rate scheduler
|
|
259
260
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
260
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
261
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
261
262
|
if args.compile_opt is True:
|
|
262
263
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
263
264
|
|
|
@@ -283,11 +284,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
283
284
|
optimizer.step()
|
|
284
285
|
lrs = []
|
|
285
286
|
for _ in range(begin_epoch, epochs):
|
|
286
|
-
for _ in range(
|
|
287
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
287
288
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
288
289
|
scheduler.step()
|
|
289
290
|
|
|
290
|
-
plt.plot(
|
|
291
|
+
plt.plot(
|
|
292
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
293
|
+
)
|
|
291
294
|
plt.show()
|
|
292
295
|
raise SystemExit(0)
|
|
293
296
|
|
birder/scripts/train_byol.py
CHANGED
|
@@ -151,7 +151,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
151
151
|
logger.info(f"Training on {len(training_dataset):,} samples")
|
|
152
152
|
|
|
153
153
|
batch_size: int = args.batch_size
|
|
154
|
-
|
|
154
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
155
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
155
156
|
|
|
156
157
|
# Data loaders and samplers
|
|
157
158
|
if args.distributed is True:
|
|
@@ -183,6 +184,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
183
184
|
drop_last=args.drop_last,
|
|
184
185
|
)
|
|
185
186
|
|
|
187
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
186
188
|
last_batch_idx = len(training_loader) - 1
|
|
187
189
|
begin_epoch = 1
|
|
188
190
|
epochs = args.epochs + 1
|
|
@@ -256,20 +258,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
256
258
|
|
|
257
259
|
# Learning rate scaling
|
|
258
260
|
lr = training_utils.scale_lr(args)
|
|
259
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
260
261
|
|
|
261
262
|
if args.lr_scheduler_update == "epoch":
|
|
262
263
|
step_update = False
|
|
263
|
-
|
|
264
|
+
scheduler_steps_per_epoch = 1
|
|
264
265
|
elif args.lr_scheduler_update == "step":
|
|
265
266
|
step_update = True
|
|
266
|
-
|
|
267
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
267
268
|
else:
|
|
268
269
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
269
270
|
|
|
270
271
|
# Optimizer and learning rate scheduler
|
|
271
272
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
272
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
273
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
273
274
|
if args.compile_opt is True:
|
|
274
275
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
275
276
|
|
|
@@ -295,11 +296,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
295
296
|
optimizer.step()
|
|
296
297
|
lrs = []
|
|
297
298
|
for _ in range(begin_epoch, epochs):
|
|
298
|
-
for _ in range(
|
|
299
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
299
300
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
300
301
|
scheduler.step()
|
|
301
302
|
|
|
302
|
-
plt.plot(
|
|
303
|
+
plt.plot(
|
|
304
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
305
|
+
)
|
|
303
306
|
plt.show()
|
|
304
307
|
raise SystemExit(0)
|
|
305
308
|
|
birder/scripts/train_capi.py
CHANGED
|
@@ -115,7 +115,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
115
115
|
torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
|
|
116
116
|
|
|
117
117
|
batch_size: int = args.batch_size
|
|
118
|
-
|
|
118
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
119
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
119
120
|
|
|
120
121
|
begin_epoch = 1
|
|
121
122
|
epochs = args.epochs + 1
|
|
@@ -278,6 +279,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
278
279
|
drop_last=args.drop_last,
|
|
279
280
|
)
|
|
280
281
|
|
|
282
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
281
283
|
last_batch_idx = len(training_loader) - 1
|
|
282
284
|
|
|
283
285
|
#
|
|
@@ -300,22 +302,21 @@ def train(args: argparse.Namespace) -> None:
|
|
|
300
302
|
# Learning rate scaling
|
|
301
303
|
lr = training_utils.scale_lr(args)
|
|
302
304
|
clustering_lr = lr / 2
|
|
303
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
304
305
|
|
|
305
306
|
if args.lr_scheduler_update == "epoch":
|
|
306
307
|
step_update = False
|
|
307
|
-
|
|
308
|
+
scheduler_steps_per_epoch = 1
|
|
308
309
|
elif args.lr_scheduler_update == "step":
|
|
309
310
|
step_update = True
|
|
310
|
-
|
|
311
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
311
312
|
else:
|
|
312
313
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
313
314
|
|
|
314
315
|
# Optimizer and learning rate scheduler
|
|
315
316
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
316
317
|
clustering_optimizer = torch.optim.AdamW(teacher.head.parameters(), lr=clustering_lr, betas=[0.9, 0.95])
|
|
317
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
318
|
-
clustering_scheduler = training_utils.get_scheduler(clustering_optimizer,
|
|
318
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
319
|
+
clustering_scheduler = training_utils.get_scheduler(clustering_optimizer, scheduler_steps_per_epoch, args)
|
|
319
320
|
if args.compile_opt is True:
|
|
320
321
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
321
322
|
clustering_optimizer.step = torch.compile(clustering_optimizer.step, fullgraph=False)
|
|
@@ -324,7 +325,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
324
325
|
if args.warmup_epochs is not None:
|
|
325
326
|
warmup_epochs = args.warmup_epochs
|
|
326
327
|
elif args.warmup_steps is not None:
|
|
327
|
-
warmup_epochs = args.warmup_steps /
|
|
328
|
+
warmup_epochs = args.warmup_steps / scheduler_steps_per_epoch
|
|
328
329
|
else:
|
|
329
330
|
warmup_epochs = 0.0
|
|
330
331
|
|
|
@@ -353,11 +354,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
353
354
|
optimizer.step()
|
|
354
355
|
lrs = []
|
|
355
356
|
for _ in range(begin_epoch, epochs):
|
|
356
|
-
for _ in range(
|
|
357
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
357
358
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
358
359
|
scheduler.step()
|
|
359
360
|
|
|
360
|
-
plt.plot(
|
|
361
|
+
plt.plot(
|
|
362
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
363
|
+
)
|
|
361
364
|
plt.show()
|
|
362
365
|
raise SystemExit(0)
|
|
363
366
|
|
birder/scripts/train_data2vec.py
CHANGED
|
@@ -106,7 +106,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
106
106
|
torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
|
|
107
107
|
|
|
108
108
|
batch_size: int = args.batch_size
|
|
109
|
-
|
|
109
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
110
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
110
111
|
|
|
111
112
|
begin_epoch = 1
|
|
112
113
|
epochs = args.epochs + 1
|
|
@@ -245,6 +246,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
245
246
|
drop_last=args.drop_last,
|
|
246
247
|
)
|
|
247
248
|
|
|
249
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
248
250
|
last_batch_idx = len(training_loader) - 1
|
|
249
251
|
|
|
250
252
|
#
|
|
@@ -266,20 +268,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
266
268
|
|
|
267
269
|
# Learning rate scaling
|
|
268
270
|
lr = training_utils.scale_lr(args)
|
|
269
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
270
271
|
|
|
271
272
|
if args.lr_scheduler_update == "epoch":
|
|
272
273
|
step_update = False
|
|
273
|
-
|
|
274
|
+
scheduler_steps_per_epoch = 1
|
|
274
275
|
elif args.lr_scheduler_update == "step":
|
|
275
276
|
step_update = True
|
|
276
|
-
|
|
277
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
277
278
|
else:
|
|
278
279
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
279
280
|
|
|
280
281
|
# Optimizer and learning rate scheduler
|
|
281
282
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
282
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
283
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
283
284
|
if args.compile_opt is True:
|
|
284
285
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
285
286
|
|
|
@@ -310,11 +311,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
310
311
|
optimizer.step()
|
|
311
312
|
lrs = []
|
|
312
313
|
for _ in range(begin_epoch, epochs):
|
|
313
|
-
for _ in range(
|
|
314
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
314
315
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
315
316
|
scheduler.step()
|
|
316
317
|
|
|
317
|
-
plt.plot(
|
|
318
|
+
plt.plot(
|
|
319
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
320
|
+
)
|
|
318
321
|
plt.show()
|
|
319
322
|
raise SystemExit(0)
|
|
320
323
|
|
birder/scripts/train_data2vec2.py
CHANGED
|
@@ -112,7 +112,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
112
112
|
torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
|
|
113
113
|
|
|
114
114
|
batch_size: int = args.batch_size
|
|
115
|
-
|
|
115
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
116
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
116
117
|
|
|
117
118
|
begin_epoch = 1
|
|
118
119
|
epochs = args.epochs + 1
|
|
@@ -254,6 +255,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
254
255
|
drop_last=args.drop_last,
|
|
255
256
|
)
|
|
256
257
|
|
|
258
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
257
259
|
last_batch_idx = len(training_loader) - 1
|
|
258
260
|
|
|
259
261
|
#
|
|
@@ -275,20 +277,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
275
277
|
|
|
276
278
|
# Learning rate scaling
|
|
277
279
|
lr = training_utils.scale_lr(args)
|
|
278
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
279
280
|
|
|
280
281
|
if args.lr_scheduler_update == "epoch":
|
|
281
282
|
step_update = False
|
|
282
|
-
|
|
283
|
+
scheduler_steps_per_epoch = 1
|
|
283
284
|
elif args.lr_scheduler_update == "step":
|
|
284
285
|
step_update = True
|
|
285
|
-
|
|
286
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
286
287
|
else:
|
|
287
288
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
288
289
|
|
|
289
290
|
# Optimizer and learning rate scheduler
|
|
290
291
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
291
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
292
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
292
293
|
if args.compile_opt is True:
|
|
293
294
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
294
295
|
|
|
@@ -319,11 +320,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
319
320
|
optimizer.step()
|
|
320
321
|
lrs = []
|
|
321
322
|
for _ in range(begin_epoch, epochs):
|
|
322
|
-
for _ in range(
|
|
323
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
323
324
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
324
325
|
scheduler.step()
|
|
325
326
|
|
|
326
|
-
plt.plot(
|
|
327
|
+
plt.plot(
|
|
328
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
329
|
+
)
|
|
327
330
|
plt.show()
|
|
328
331
|
raise SystemExit(0)
|
|
329
332
|
|
birder/scripts/train_detection.py
CHANGED
|
@@ -192,8 +192,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
192
192
|
|
|
193
193
|
num_outputs = len(class_to_idx) # Does not include background class
|
|
194
194
|
batch_size: int = args.batch_size
|
|
195
|
-
|
|
196
|
-
|
|
195
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
196
|
+
model_ema_steps: int = args.model_ema_steps
|
|
197
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
197
198
|
|
|
198
199
|
# Data loaders and samplers
|
|
199
200
|
(train_sampler, validation_sampler) = training_utils.get_samplers(args, training_dataset, validation_dataset)
|
|
@@ -224,8 +225,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
224
225
|
drop_last=args.drop_last,
|
|
225
226
|
)
|
|
226
227
|
|
|
227
|
-
optimizer_steps_per_epoch = math.ceil(len(training_loader) /
|
|
228
|
-
assert args.model_ema is False or
|
|
228
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
229
|
+
assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch
|
|
229
230
|
|
|
230
231
|
last_batch_idx = len(training_loader) - 1
|
|
231
232
|
begin_epoch = 1
|
|
@@ -369,20 +370,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
369
370
|
|
|
370
371
|
# Learning rate scaling
|
|
371
372
|
lr = training_utils.scale_lr(args)
|
|
372
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
373
373
|
|
|
374
374
|
if args.lr_scheduler_update == "epoch":
|
|
375
375
|
step_update = False
|
|
376
|
-
|
|
376
|
+
scheduler_steps_per_epoch = 1
|
|
377
377
|
elif args.lr_scheduler_update == "step":
|
|
378
378
|
step_update = True
|
|
379
|
-
|
|
379
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
380
380
|
else:
|
|
381
381
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
382
382
|
|
|
383
383
|
# Optimizer and learning rate scheduler
|
|
384
384
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
385
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
385
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
386
386
|
if args.compile_opt is True:
|
|
387
387
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
388
388
|
|
|
@@ -408,11 +408,14 @@ def train(args: argparse.Namespace) -> None:
|
|
|
408
408
|
optimizer.step()
|
|
409
409
|
lrs = []
|
|
410
410
|
for _ in range(begin_epoch, epochs):
|
|
411
|
-
for _ in range(
|
|
411
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
412
412
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
413
413
|
scheduler.step()
|
|
414
414
|
|
|
415
|
-
plt.plot(
|
|
415
|
+
plt.plot(
|
|
416
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
|
|
417
|
+
lrs,
|
|
418
|
+
)
|
|
416
419
|
plt.show()
|
|
417
420
|
raise SystemExit(0)
|
|
418
421
|
|
|
@@ -420,15 +423,15 @@ def train(args: argparse.Namespace) -> None:
|
|
|
420
423
|
# Distributed (DDP) and Model EMA
|
|
421
424
|
#
|
|
422
425
|
if args.model_ema_warmup is not None:
|
|
423
|
-
|
|
426
|
+
ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
|
|
424
427
|
elif args.warmup_epochs is not None:
|
|
425
|
-
|
|
428
|
+
ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
|
|
426
429
|
elif args.warmup_steps is not None:
|
|
427
|
-
|
|
430
|
+
ema_warmup_steps = args.warmup_steps
|
|
428
431
|
else:
|
|
429
|
-
|
|
432
|
+
ema_warmup_steps = 0
|
|
430
433
|
|
|
431
|
-
logger.debug(f"EMA warmup
|
|
434
|
+
logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
|
|
432
435
|
net_without_ddp = net
|
|
433
436
|
if args.distributed is True:
|
|
434
437
|
net = torch.nn.parallel.DistributedDataParallel(
|
|
@@ -532,11 +535,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
532
535
|
#
|
|
533
536
|
# Training loop
|
|
534
537
|
#
|
|
538
|
+
optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
|
|
535
539
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
536
540
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
537
541
|
tic = time.time()
|
|
538
542
|
net.train()
|
|
539
543
|
running_loss = training_utils.SmoothedValue()
|
|
544
|
+
loss_trackers: dict[str, training_utils.SmoothedValue] = {}
|
|
540
545
|
validation_metrics.reset()
|
|
541
546
|
|
|
542
547
|
if args.distributed is True:
|
|
@@ -598,16 +603,28 @@ def train(args: argparse.Namespace) -> None:
|
|
|
598
603
|
if step_update is True:
|
|
599
604
|
scheduler.step()
|
|
600
605
|
|
|
606
|
+
if optimizer_update is True:
|
|
607
|
+
optimizer_step += 1
|
|
608
|
+
|
|
601
609
|
# Exponential moving average
|
|
602
|
-
if args.model_ema is True and
|
|
610
|
+
if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
|
|
603
611
|
model_ema.update_parameters(net)
|
|
604
|
-
if
|
|
612
|
+
if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
|
|
605
613
|
# Reset ema buffer to keep copying weights during warmup period
|
|
606
614
|
model_ema.n_averaged.fill_(0) # pylint: disable=no-member
|
|
607
615
|
|
|
608
616
|
# Statistics
|
|
609
617
|
running_loss.update(loss.detach())
|
|
610
618
|
|
|
619
|
+
# Dynamically create trackers on first batch
|
|
620
|
+
if len(loss_trackers) == 0:
|
|
621
|
+
for key in losses.keys():
|
|
622
|
+
loss_trackers[key] = training_utils.SmoothedValue()
|
|
623
|
+
|
|
624
|
+
# Update individual loss trackers
|
|
625
|
+
for key, value in losses.items():
|
|
626
|
+
loss_trackers[key].update(value.detach())
|
|
627
|
+
|
|
611
628
|
# Write statistics
|
|
612
629
|
if (i == last_batch_idx) or (i + 1) % args.log_interval == 0:
|
|
613
630
|
time_now = time.time()
|
|
@@ -624,6 +641,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
624
641
|
cur_lr = float(max(scheduler.get_last_lr()))
|
|
625
642
|
|
|
626
643
|
running_loss.synchronize_between_processes(device)
|
|
644
|
+
for tracker in loss_trackers.values():
|
|
645
|
+
tracker.synchronize_between_processes(device)
|
|
646
|
+
|
|
627
647
|
with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
|
|
628
648
|
log.info(
|
|
629
649
|
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
@@ -636,9 +656,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
636
656
|
)
|
|
637
657
|
|
|
638
658
|
if training_utils.is_local_primary(args) is True:
|
|
659
|
+
loss_dict = {"training": running_loss.avg}
|
|
660
|
+
loss_dict.update({k: v.avg for k, v in loss_trackers.items()})
|
|
639
661
|
summary_writer.add_scalars(
|
|
640
662
|
"loss",
|
|
641
|
-
|
|
663
|
+
loss_dict,
|
|
642
664
|
((epoch - 1) * len(training_dataset)) + (i * batch_size * args.world_size),
|
|
643
665
|
)
|
|
644
666
|
|
|
@@ -649,6 +671,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
649
671
|
|
|
650
672
|
# Epoch training metrics
|
|
651
673
|
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
|
|
674
|
+
for key, tracker in loss_trackers.items():
|
|
675
|
+
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} {key}: {tracker.global_avg:.4f}")
|
|
652
676
|
|
|
653
677
|
# Validation
|
|
654
678
|
eval_model.eval()
|
birder/scripts/train_dino_v1.py
CHANGED
|
@@ -186,7 +186,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
186
186
|
logger.info(f"Training on {len(training_dataset):,} samples")
|
|
187
187
|
|
|
188
188
|
batch_size: int = args.batch_size
|
|
189
|
-
|
|
189
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
190
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
190
191
|
|
|
191
192
|
# Data loaders and samplers
|
|
192
193
|
if args.distributed is True:
|
|
@@ -218,6 +219,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
218
219
|
drop_last=args.drop_last,
|
|
219
220
|
)
|
|
220
221
|
|
|
222
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
221
223
|
last_batch_idx = len(training_loader) - 1
|
|
222
224
|
begin_epoch = 1
|
|
223
225
|
epochs = args.epochs + 1
|
|
@@ -352,20 +354,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
352
354
|
|
|
353
355
|
# Learning rate scaling
|
|
354
356
|
lr = training_utils.scale_lr(args)
|
|
355
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
356
357
|
|
|
357
358
|
if args.lr_scheduler_update == "epoch":
|
|
358
359
|
step_update = False
|
|
359
|
-
|
|
360
|
+
scheduler_steps_per_epoch = 1
|
|
360
361
|
elif args.lr_scheduler_update == "step":
|
|
361
362
|
step_update = True
|
|
362
|
-
|
|
363
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
363
364
|
else:
|
|
364
365
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
365
366
|
|
|
366
367
|
# Optimizer and learning rate scheduler
|
|
367
368
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
368
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
369
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
369
370
|
if args.compile_opt is True:
|
|
370
371
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
371
372
|
|
|
@@ -398,11 +399,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
398
399
|
optimizer.step()
|
|
399
400
|
lrs = []
|
|
400
401
|
for _ in range(begin_epoch, epochs):
|
|
401
|
-
for _ in range(
|
|
402
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
402
403
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
403
404
|
scheduler.step()
|
|
404
405
|
|
|
405
|
-
plt.plot(
|
|
406
|
+
plt.plot(
|
|
407
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
408
|
+
)
|
|
406
409
|
plt.show()
|
|
407
410
|
raise SystemExit(0)
|
|
408
411
|
|