birder 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/masking.py +13 -4
  4. birder/inference/classification.py +1 -1
  5. birder/introspection/__init__.py +2 -0
  6. birder/introspection/base.py +0 -7
  7. birder/introspection/feature_pca.py +101 -0
  8. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  9. birder/model_registry/model_registry.py +3 -2
  10. birder/net/convnext_v1.py +20 -0
  11. birder/net/fastvit.py +0 -1
  12. birder/net/flexivit.py +5 -0
  13. birder/net/focalnet.py +0 -1
  14. birder/net/rope_flexivit.py +7 -0
  15. birder/net/rope_vit.py +49 -4
  16. birder/net/smt.py +0 -1
  17. birder/net/ssl/ibot.py +0 -1
  18. birder/net/vit.py +166 -2
  19. birder/scripts/train.py +7 -6
  20. birder/scripts/train_barlow_twins.py +4 -3
  21. birder/scripts/train_byol.py +4 -3
  22. birder/scripts/train_capi.py +6 -5
  23. birder/scripts/train_data2vec.py +4 -3
  24. birder/scripts/train_data2vec2.py +4 -3
  25. birder/scripts/train_detection.py +7 -5
  26. birder/scripts/train_dino_v1.py +5 -4
  27. birder/scripts/train_dino_v2.py +69 -20
  28. birder/scripts/train_dino_v2_dist.py +70 -21
  29. birder/scripts/train_franca.py +8 -7
  30. birder/scripts/train_i_jepa.py +4 -3
  31. birder/scripts/train_ibot.py +5 -4
  32. birder/scripts/train_kd.py +8 -8
  33. birder/scripts/train_mim.py +4 -3
  34. birder/scripts/train_mmcr.py +4 -3
  35. birder/scripts/train_rotnet.py +5 -4
  36. birder/scripts/train_simclr.py +4 -3
  37. birder/scripts/train_vicreg.py +4 -3
  38. birder/tools/avg_model.py +24 -8
  39. birder/tools/introspection.py +35 -9
  40. birder/tools/show_iterator.py +1 -1
  41. birder/version.py +1 -1
  42. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  43. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/RECORD +47 -46
  44. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  45. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  46. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  47. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/vit.py CHANGED
@@ -91,6 +91,126 @@ class PatchEmbed(nn.Module):
         return x
 
 
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        attn_drop: float,
+        proj_drop: float,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        if qk_norm is True:
+            self.q_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+            self.k_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    # Match the nn.MultiheadAttention forward interface for TorchScript compatibility
+    def forward(
+        self,
+        x: torch.Tensor,
+        key: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+        value: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+        average_attn_weights: bool = False,
+        is_causal: bool = False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Apply multi-head self-attention to the input sequence.
+
+        This module implements scaled dot-product attention over x and returns the
+        projected output. The method signature intentionally matches
+        torch.nn.MultiheadAttention.forward for TorchScript compatibility.
+
+        Compatibility notes
+        -------------------
+        The following parameters are accepted for API compatibility but are ignored by this implementation:
+        - key: ignored (keys are computed from x)
+        - value: ignored (values are computed from x)
+        - attn_mask: ignored (no external attention mask is applied)
+
+        Parameters
+        ----------
+        x
+            Input tensor of shape (B, N, C) where B is batch size, N is sequence length,
+            and C is embedding dimension.
+        key
+            Unused. Present for nn.MultiheadAttention-compatible signature.
+        value
+            Unused. Present for nn.MultiheadAttention-compatible signature.
+        need_weights
+            If True, also return attention weights computed explicitly. If False, use
+            torch.nn.functional.scaled_dot_product_attention and return None for the attention weights.
+        attn_mask
+            Unused. Present for nn.MultiheadAttention-compatible signature.
+        average_attn_weights
+            If True and need_weights is True, average attention weights across heads
+            to shape (B, N, N). If False, return per-head weights of shape (B, num_heads, N, N).
+        is_causal
+            If True, apply a causal (upper-triangular) mask so positions cannot attend to future tokens.
+
+        Returns
+        -------
+        A tuple containing two elements:
+        - output: tensor of shape (B, N, C)
+        - attn_weights: the attention weights if need_weights is True, otherwise None.
+        """
+
+        (B, N, C) = x.size()
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        (q, k, v) = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        attn_weights: Optional[torch.Tensor] = None
+        if need_weights is True:
+            # Compute attention manually to get weights
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            if is_causal is True:
+                causal_mask = torch.triu(
+                    torch.full((N, N), float("-inf"), dtype=attn.dtype, device=attn.device),
+                    diagonal=1,
+                )
+                attn = attn + causal_mask
+
+            attn = attn.softmax(dim=-1)
+            attn_weights = attn
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+            if average_attn_weights is True:
+                # Average across heads: (B, num_heads, N, N) -> (B, N, N)
+                attn_weights = attn_weights.mean(dim=1)
+        else:
+            x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
+                q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, is_causal=is_causal, scale=self.scale
+            )
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return (x, attn_weights)
+
+
 class EncoderBlock(nn.Module):
     def __init__(
         self,
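A minimal usage sketch of the new Attention module, assuming only the constructor and forward signatures shown in the hunk above (the class is defined in birder.net.vit as of 0.3.2):

# Sketch: shapes and the need_weights switch of the new Attention module
import torch

from birder.net.vit import Attention

attn = Attention(dim=768, num_heads=12, attn_drop=0.0, proj_drop=0.0, qk_norm=True)
x = torch.randn(2, 197, 768)  # (B, N, C)

# Fast path: uses scaled_dot_product_attention, returns None for the weights
(out, weights) = attn(x)
assert out.shape == (2, 197, 768) and weights is None

# Explicit path: per-head attention weights of shape (B, num_heads, N, N)
(out, weights) = attn(x, need_weights=True)
assert weights.shape == (2, 12, 197, 197)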
@@ -105,17 +225,37 @@ class EncoderBlock(nn.Module):
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
     ) -> None:
         super().__init__()
         self.need_attn = False
         self.is_causal = False
+        self.use_custom_attn = qk_norm is True
 
         if mlp_dim is None:
             mlp_dim = hidden_dim * 4
 
         # Attention block
         self.ln1 = norm_layer(hidden_dim, eps=norm_layer_eps)
-        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
+
+        if self.use_custom_attn is False:
+            # Prefer PyTorch's built-in MultiheadAttention for the "standard" case
+            self.self_attention = nn.MultiheadAttention(
+                hidden_dim, num_heads, dropout=attention_dropout, bias=qkv_bias, batch_first=True
+            )
+        else:
+            self.self_attention = Attention(
+                hidden_dim,
+                num_heads=num_heads,
+                attn_drop=attention_dropout,
+                proj_drop=0.0,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                norm_layer=norm_layer,
+                norm_layer_eps=norm_layer_eps,
+            )
+
         self.drop_path1 = StochasticDepth(drop_path, mode="row")
         if layer_scale_init_value is not None:
             self.layer_scale_1 = LayerScale(hidden_dim, layer_scale_init_value)
@@ -148,10 +288,11 @@ class EncoderBlock(nn.Module):
             branch1,
             branch1,
             need_weights=self.need_attn,
-            attn_mask=attn_mask,
+            attn_mask=attn_mask,  # Ignored by the custom attention
             average_attn_weights=False,
             is_causal=self.is_causal,
         )
+
         branch1 = self.layer_scale_1(branch1)
         branch1 = self.drop_path1(branch1) + x
 
@@ -181,6 +322,8 @@ class Encoder(nn.Module):
         attention_dropout: float,
         dpr: list[float],
         pre_norm: bool = False,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         activation_layer: Callable[..., nn.Module] = nn.GELU,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -211,6 +354,8 @@ class Encoder(nn.Module):
                     norm_layer=norm_layer,
                     norm_layer_eps=norm_layer_eps,
                     mlp_layer=mlp_layer,
+                    qkv_bias=qkv_bias,
+                    qk_norm=qk_norm,
                 )
             )
 
@@ -267,6 +412,8 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -351,6 +498,8 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -389,6 +538,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
                 drop_path=0,
                 activation_layer=act_layer,
                 norm_layer=norm_layer,
+                norm_layer_eps=norm_layer_eps,
                 mlp_layer=mlp_layer,
             )
 
@@ -846,6 +996,20 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+registry.register_model_config(
+    "vit_b16_qkn_ls",
+    ViT,
+    config={
+        "patch_size": 16,
+        "num_layers": 12,
+        "num_heads": 12,
+        "hidden_dim": 768,
+        "mlp_dim": 3072,
+        "layer_scale_init_value": 1e-5,
+        "qk_norm": True,
+        "drop_path_rate": 0.1,
+    },
+)
 registry.register_model_config(
     "vit_b16_pn_quick_gelu",
     ViT,
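The new vit_b16_qkn_ls preset combines QK-Norm with LayerScale. As a rough sketch of the technique itself (not the package API), QK-Norm layer-normalizes the per-head queries and keys over head_dim before the scaled dot product; the module above does this with learnable nn.LayerNorm layers, while the functional form below is the non-affine equivalent on hypothetical standalone tensors:

# Sketch of QK-Norm on per-head q/k tensors of shape (B, num_heads, N, head_dim)
import torch
import torch.nn.functional as F

(B, num_heads, N, head_dim) = (2, 12, 197, 64)
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)

# LayerNorm over the head dimension, as self.q_norm / self.k_norm do above
# (the module version additionally carries learnable affine parameters)
q = F.layer_norm(q, (head_dim,), eps=1e-6)
k = F.layer_norm(k, (head_dim,), eps=1e-6)

attn = (q @ k.transpose(-2, -1)) * head_dim**-0.5  # (B, num_heads, N, N)
attn = attn.softmax(dim=-1)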
birder/scripts/train.py CHANGED
@@ -474,14 +474,15 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue(window_size=64)
+    running_val_loss = training_utils.SmoothedValue()
+    train_accuracy = training_utils.SmoothedValue(window_size=64)
+    val_accuracy = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue(window_size=64)
-        running_val_loss = training_utils.SmoothedValue()
-        train_accuracy = training_utils.SmoothedValue(window_size=64)
-        val_accuracy = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
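The hunk above constructs the SmoothedValue trackers once, before the epoch loop, so their smoothing windows now persist across epoch boundaries instead of resetting every epoch; the related change in the next hunk adds an `i > 0` guard so no statistics line is written at batch index 0. A minimal sketch of the windowed-average behavior, using a hypothetical stand-in since training_utils.SmoothedValue itself is not part of this diff:

# Hypothetical stand-in for training_utils.SmoothedValue (not the real class)
from collections import deque
from typing import Optional


class WindowedAverage:
    """Keep the last `window_size` values and average over them"""

    def __init__(self, window_size: Optional[int] = None) -> None:
        self.values: deque = deque(maxlen=window_size)

    def update(self, value: float) -> None:
        self.values.append(value)

    @property
    def avg(self) -> float:
        return sum(self.values) / len(self.values)


# Constructed outside the epoch loop, the window now spans epoch boundaries
running_loss = WindowedAverage(window_size=64)
for epoch in range(2):
    for i, loss in enumerate([0.9, 0.8, 0.7]):
        running_loss.update(loss)

print(round(running_loss.avg, 2))  # 0.8, averaged over all six updates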
@@ -566,7 +567,7 @@ def train(args: argparse.Namespace) -> None:
             train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -806,6 +807,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -814,7 +816,6 @@ def get_args_parser() -> argparse.ArgumentParser:
             "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument("--reset-head", default=False, action="store_true", help="reset the classification head")
     parser.add_argument(
         "--freeze-body",
birder/scripts/train_barlow_twins.py CHANGED
@@ -358,11 +358,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -426,7 +427,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -566,6 +567,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to train")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -583,7 +585,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         help="projector mlp dimensions",
     )
     parser.add_argument("--off-lambda", type=float, default=0.0051, help="weight on off-diagonal terms")
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_byol.py CHANGED
@@ -370,11 +370,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -449,7 +450,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -590,6 +591,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -611,7 +613,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0.99,
         help="base EMA parameter for teacher update, set a higher value with small batches",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_capi.py CHANGED
@@ -444,13 +444,14 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    running_clustering_loss = training_utils.SmoothedValue()
+    running_target_entropy = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        running_clustering_loss = training_utils.SmoothedValue()
-        running_target_entropy = training_utils.SmoothedValue()
 
         if args.sinkhorn_queue_size is not None:
             queue_active = epoch > args.sinkhorn_queue_warmup_epochs
@@ -564,7 +565,7 @@ def train(args: argparse.Namespace) -> None:
             running_target_entropy.update(target_entropy.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -737,6 +738,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -768,7 +770,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0,
         help="number of initial epochs to disable Sinkhorn queueing",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_data2vec.py CHANGED
@@ -384,11 +384,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -463,7 +464,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -603,6 +604,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -617,7 +619,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0.999,
         help="base EMA parameter for teacher update, set a higher value with small batches",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_data2vec2.py CHANGED
@@ -393,11 +393,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -473,7 +474,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -615,6 +616,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -635,7 +637,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0.9998,
         help="base EMA parameter for teacher update, set a higher value with small batches",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_detection.py CHANGED
@@ -538,12 +538,14 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
         validation_metrics.reset()
 
         if args.distributed is True or virtual_epoch_mode is True:
@@ -634,7 +636,7 @@ def train(args: argparse.Namespace) -> None:
                 loss_trackers[key].update(value.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -889,6 +891,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -897,8 +900,8 @@ def get_args_parser() -> argparse.ArgumentParser:
             "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument("--backbone", type=str, help="the neural network to used as backbone")
+    parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
     parser.add_argument(
         "--backbone-model-config",
         action=cli.FlexibleDictAction,
@@ -907,7 +910,6 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
     parser.add_argument("--backbone-epoch", type=int, help="load backbone weights from selected epoch")
     parser.add_argument(
         "--backbone-pretrained",
birder/scripts/train_dino_v1.py CHANGED
@@ -480,12 +480,13 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    train_proto_agreement = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        train_proto_agreement = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -581,7 +582,7 @@ def train(args: argparse.Namespace) -> None:
             train_proto_agreement.update(training_utils.accuracy(pred_teacher, pred_student))
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -733,6 +734,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -788,7 +790,6 @@ def get_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--local-crop-size", type=int, nargs="+", default=[96, 96], metavar=("H", "W"), help="local view size"
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--backbone-epoch",
         type=int,