birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -74,7 +74,7 @@ class TrainCollator:
74
74
  def __call__(self, batch: Any) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
75
75
  B = len(batch)
76
76
  collated_batch = torch.utils.data.default_collate(batch)
77
- (enc_masks, pred_masks) = self.mask_generator(B)
77
+ enc_masks, pred_masks = self.mask_generator(B)
78
78
 
79
79
  return (collated_batch, enc_masks, pred_masks)
80
80
 
@@ -84,7 +84,7 @@ def train(args: argparse.Namespace) -> None:
84
84
  #
85
85
  # Initialize
86
86
  #
87
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
87
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
88
88
 
89
89
  if args.size is None:
90
90
  args.size = registry.get_default_size(args.network)
@@ -119,9 +119,9 @@ def train(args: argparse.Namespace) -> None:
119
119
  else:
120
120
  model_config = {"drop_path_rate": 0.0}
121
121
 
122
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=model_config, size=args.size)
122
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=model_config, size=args.size)
123
123
  num_special_tokens = backbone.num_special_tokens
124
- target_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=model_config, size=args.size)
124
+ target_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=model_config, size=args.size)
125
125
  encoder = I_JEPA(backbone)
126
126
  target_encoder = I_JEPA(target_backbone)
127
127
  target_encoder.load_state_dict(encoder.state_dict())
@@ -148,7 +148,7 @@ def train(args: argparse.Namespace) -> None:
148
148
 
149
149
  if args.resume_epoch is not None:
150
150
  begin_epoch = args.resume_epoch + 1
151
- (net, training_states) = fs_ops.load_simple_checkpoint(
151
+ net, training_states = fs_ops.load_simple_checkpoint(
152
152
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
153
153
  )
154
154
  encoder = net["encoder"]
@@ -198,11 +198,11 @@ def train(args: argparse.Namespace) -> None:
198
198
  elif args.wds is True:
199
199
  wds_path: str | list[str]
200
200
  if args.wds_info is not None:
201
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
201
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
202
202
  if args.wds_size is not None:
203
203
  dataset_size = args.wds_size
204
204
  else:
205
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
205
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
206
206
 
207
207
  training_dataset = make_wds_dataset(
208
208
  wds_path,
@@ -228,7 +228,7 @@ def train(args: argparse.Namespace) -> None:
228
228
 
229
229
  # Data loaders and samplers
230
230
  virtual_epoch_mode = args.steps_per_epoch is not None
231
- (train_sampler, _) = training_utils.get_samplers(
231
+ train_sampler, _ = training_utils.get_samplers(
232
232
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
233
233
  )
234
234
 
@@ -320,7 +320,7 @@ def train(args: argparse.Namespace) -> None:
320
320
  wd_schedule = None
321
321
 
322
322
  # Gradient scaler and AMP related tasks
323
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
323
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
324
324
 
325
325
  # Load states
326
326
  if args.load_states is True:
@@ -440,6 +440,9 @@ def train(args: argparse.Namespace) -> None:
440
440
  tic = time.time()
441
441
  net.train()
442
442
 
443
+ # Clear metrics
444
+ running_loss.clear()
445
+
443
446
  if args.distributed is True or virtual_epoch_mode is True:
444
447
  train_sampler.set_epoch(epoch)
445
448
 
@@ -107,7 +107,7 @@ def train(args: argparse.Namespace) -> None:
107
107
  #
108
108
  # Initialize
109
109
  #
110
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
110
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
111
111
 
112
112
  if args.size is None:
113
113
  args.size = registry.get_default_size(args.network)
@@ -136,7 +136,7 @@ def train(args: argparse.Namespace) -> None:
136
136
 
137
137
  network_name = get_mim_network_name("ibot", encoder=args.network, tag=args.tag)
138
138
 
139
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
139
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
140
140
  if args.model_config is not None:
141
141
  teacher_model_config = args.model_config.copy()
142
142
  teacher_model_config.update({"drop_path_rate": 0.0})
@@ -144,7 +144,7 @@ def train(args: argparse.Namespace) -> None:
144
144
  teacher_model_config = {"drop_path_rate": 0.0}
145
145
 
146
146
  teacher_backbone = registry.net_factory(
147
- args.network, sample_shape[1], 0, config=teacher_model_config, size=args.size
147
+ args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
148
148
  )
149
149
  student_backbone.set_dynamic_size()
150
150
  student = iBOT(
@@ -204,7 +204,7 @@ def train(args: argparse.Namespace) -> None:
204
204
 
205
205
  if args.resume_epoch is not None:
206
206
  begin_epoch = args.resume_epoch + 1
207
- (net, training_states) = fs_ops.load_simple_checkpoint(
207
+ net, training_states = fs_ops.load_simple_checkpoint(
208
208
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
209
209
  )
210
210
  student = net["student"]
@@ -266,11 +266,11 @@ def train(args: argparse.Namespace) -> None:
266
266
  elif args.wds is True:
267
267
  wds_path: str | list[str]
268
268
  if args.wds_info is not None:
269
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
269
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
270
270
  if args.wds_size is not None:
271
271
  dataset_size = args.wds_size
272
272
  else:
273
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
273
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
274
274
 
275
275
  training_dataset = make_wds_dataset(
276
276
  wds_path,
@@ -296,7 +296,7 @@ def train(args: argparse.Namespace) -> None:
296
296
 
297
297
  # Data loaders and samplers
298
298
  virtual_epoch_mode = args.steps_per_epoch is not None
299
- (train_sampler, _) = training_utils.get_samplers(
299
+ train_sampler, _ = training_utils.get_samplers(
300
300
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
301
301
  )
302
302
 
@@ -387,7 +387,7 @@ def train(args: argparse.Namespace) -> None:
387
387
  wd_schedule = None
388
388
 
389
389
  # Gradient scaler and AMP related tasks
390
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
390
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
391
391
 
392
392
  # Load states
393
393
  if args.load_states is True:
@@ -507,6 +507,10 @@ def train(args: argparse.Namespace) -> None:
507
507
  tic = time.time()
508
508
  net.train()
509
509
 
510
+ # Clear metrics
511
+ running_loss.clear()
512
+ train_proto_agreement.clear()
513
+
510
514
  if args.distributed is True or virtual_epoch_mode is True:
511
515
  train_sampler.set_epoch(epoch)
512
516
 
@@ -553,12 +557,12 @@ def train(args: argparse.Namespace) -> None:
553
557
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
554
558
  # Global views
555
559
  with torch.no_grad():
556
- (teacher_embedding, teacher_features) = teacher(torch.concat(images[:2], dim=0), None)
560
+ teacher_embedding, teacher_features = teacher(torch.concat(images[:2], dim=0), None)
557
561
 
558
- (student_embedding, student_features) = student(torch.concat(images[:2], dim=0), masks)
562
+ student_embedding, student_features = student(torch.concat(images[:2], dim=0), masks)
559
563
 
560
564
  # Local views
561
- (student_local_embedding, _) = student(torch.concat(images[2:], dim=0), None, return_keys="embedding")
565
+ student_local_embedding, _ = student(torch.concat(images[2:], dim=0), None, return_keys="embedding")
562
566
 
563
567
  loss = ibot_loss(
564
568
  student_embedding,
@@ -76,13 +76,13 @@ def train(args: argparse.Namespace) -> None:
76
76
  #
77
77
  # Initialize
78
78
  #
79
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
79
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
80
80
 
81
81
  if args.type != "soft":
82
82
  args.temperature = 1.0
83
83
 
84
84
  # Using the teacher rgb values for the student
85
- (teacher, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
85
+ teacher, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
86
86
  device,
87
87
  args.teacher,
88
88
  config=args.teacher_model_config,
@@ -113,15 +113,15 @@ def train(args: argparse.Namespace) -> None:
113
113
  training_wds_path: str | list[str]
114
114
  val_wds_path: str | list[str]
115
115
  if args.wds_info is not None:
116
- (training_wds_path, training_size) = wds_args_from_info(args.wds_info, args.wds_training_split)
117
- (val_wds_path, val_size) = wds_args_from_info(args.wds_info, args.wds_val_split)
116
+ training_wds_path, training_size = wds_args_from_info(args.wds_info, args.wds_training_split)
117
+ val_wds_path, val_size = wds_args_from_info(args.wds_info, args.wds_val_split)
118
118
  if args.wds_train_size is not None:
119
119
  training_size = args.wds_train_size
120
120
  if args.wds_val_size is not None:
121
121
  val_size = args.wds_val_size
122
122
  else:
123
- (training_wds_path, training_size) = prepare_wds_args(args.data_path, args.wds_train_size, device)
124
- (val_wds_path, val_size) = prepare_wds_args(args.val_path, args.wds_val_size, device)
123
+ training_wds_path, training_size = prepare_wds_args(args.data_path, args.wds_train_size, device)
124
+ val_wds_path, val_size = prepare_wds_args(args.val_path, args.wds_val_size, device)
125
125
 
126
126
  training_dataset = make_wds_dataset(
127
127
  training_wds_path,
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
187
187
 
188
188
  # Data loaders and samplers
189
189
  virtual_epoch_mode = args.steps_per_epoch is not None
190
- (train_sampler, validation_sampler) = training_utils.get_samplers(
190
+ train_sampler, validation_sampler = training_utils.get_samplers(
191
191
  args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
192
192
  )
193
193
 
@@ -269,7 +269,7 @@ def train(args: argparse.Namespace) -> None:
269
269
 
270
270
  if args.resume_epoch is not None:
271
271
  begin_epoch = args.resume_epoch + 1
272
- (student, class_to_idx_saved, training_states) = fs_ops.load_checkpoint(
272
+ student, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
273
273
  device,
274
274
  args.student,
275
275
  config=args.student_model_config,
@@ -283,8 +283,8 @@ def train(args: argparse.Namespace) -> None:
283
283
  else:
284
284
  student = registry.net_factory(
285
285
  args.student,
286
- sample_shape[1],
287
286
  num_outputs,
287
+ sample_shape[1],
288
288
  config=args.student_model_config,
289
289
  size=args.size,
290
290
  )
@@ -383,7 +383,7 @@ def train(args: argparse.Namespace) -> None:
383
383
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
384
384
 
385
385
  # Gradient scaler and AMP related tasks
386
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
386
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
387
387
 
388
388
  # Load states
389
389
  if args.load_states is True:
@@ -567,10 +567,16 @@ def train(args: argparse.Namespace) -> None:
567
567
  if virtual_epoch_mode is True:
568
568
  train_iter = iter(training_loader)
569
569
 
570
+ top_k = args.top_k
570
571
  running_loss = training_utils.SmoothedValue(window_size=64)
571
572
  running_val_loss = training_utils.SmoothedValue()
572
573
  train_accuracy = training_utils.SmoothedValue(window_size=64)
573
574
  val_accuracy = training_utils.SmoothedValue()
575
+ train_topk: Optional[training_utils.SmoothedValue] = None
576
+ val_topk: Optional[training_utils.SmoothedValue] = None
577
+ if top_k is not None:
578
+ train_topk = training_utils.SmoothedValue(window_size=64)
579
+ val_topk = training_utils.SmoothedValue()
574
580
 
575
581
  logger.info(f"Starting training with learning rate of {last_lr}")
576
582
  for epoch in range(begin_epoch, args.stop_epoch):
@@ -579,6 +585,16 @@ def train(args: argparse.Namespace) -> None:
579
585
  if embedding_projection is not None:
580
586
  embedding_projection.train()
581
587
 
588
+ # Clear metrics
589
+ running_loss.clear()
590
+ running_val_loss.clear()
591
+ train_accuracy.clear()
592
+ val_accuracy.clear()
593
+ if train_topk is not None:
594
+ train_topk.clear()
595
+ if val_topk is not None:
596
+ val_topk.clear()
597
+
582
598
  if args.distributed is True or virtual_epoch_mode is True:
583
599
  train_sampler.set_epoch(epoch)
584
600
 
@@ -616,7 +632,7 @@ def train(args: argparse.Namespace) -> None:
616
632
  teacher_embedding = teacher.embedding(inputs)
617
633
  teacher_embedding = F.normalize(teacher_embedding, dim=-1)
618
634
 
619
- (outputs, student_embedding) = train_student(inputs)
635
+ outputs, student_embedding = train_student(inputs)
620
636
  student_embedding = embedding_projection(student_embedding) # type: ignore[misc]
621
637
  student_embedding = F.normalize(student_embedding, dim=-1)
622
638
  dist_loss = distillation_criterion(student_embedding, teacher_embedding)
@@ -637,7 +653,7 @@ def train(args: argparse.Namespace) -> None:
637
653
  outputs = train_student(inputs)
638
654
  dist_loss = distillation_criterion(outputs, teacher_targets)
639
655
  elif distillation_type == "deit":
640
- (outputs, dist_output) = torch.unbind(train_student(inputs), dim=1)
656
+ outputs, dist_output = torch.unbind(train_student(inputs), dim=1)
641
657
  dist_loss = distillation_criterion(dist_output, teacher_targets)
642
658
  else:
643
659
  raise RuntimeError
@@ -693,6 +709,9 @@ def train(args: argparse.Namespace) -> None:
693
709
  targets = targets.argmax(dim=1)
694
710
 
695
711
  train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
712
+ if train_topk is not None:
713
+ topk_val = training_utils.topk_accuracy(targets, outputs.detach(), topk=(top_k,))[0]
714
+ train_topk.update(topk_val)
696
715
 
697
716
  # Write statistics
698
717
  if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
@@ -711,6 +730,9 @@ def train(args: argparse.Namespace) -> None:
711
730
 
712
731
  running_loss.synchronize_between_processes(device)
713
732
  train_accuracy.synchronize_between_processes(device)
733
+ if train_topk is not None:
734
+ train_topk.synchronize_between_processes(device)
735
+
714
736
  with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
715
737
  log.info(
716
738
  f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
@@ -725,8 +747,17 @@ def train(args: argparse.Namespace) -> None:
725
747
  f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
726
748
  f"Accuracy: {train_accuracy.avg:.4f}"
727
749
  )
750
+ if train_topk is not None:
751
+ log.info(
752
+ f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
753
+ f"Accuracy@{top_k}: {train_topk.avg:.4f}"
754
+ )
728
755
 
729
756
  if training_utils.is_local_primary(args) is True:
757
+ performance = {"training_accuracy": train_accuracy.avg}
758
+ if train_topk is not None:
759
+ performance[f"training_accuracy@{top_k}"] = train_topk.avg
760
+
730
761
  summary_writer.add_scalars(
731
762
  "loss",
732
763
  {"training": running_loss.avg},
@@ -734,7 +765,7 @@ def train(args: argparse.Namespace) -> None:
734
765
  )
735
766
  summary_writer.add_scalars(
736
767
  "performance",
737
- {"training_accuracy": train_accuracy.avg},
768
+ performance,
738
769
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
739
770
  )
740
771
 
@@ -746,6 +777,8 @@ def train(args: argparse.Namespace) -> None:
746
777
  # Epoch training metrics
747
778
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
748
779
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy: {train_accuracy.global_avg:.4f}")
780
+ if train_topk is not None:
781
+ logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy@{top_k}: {train_topk.global_avg:.4f}")
749
782
 
750
783
  # Validation
751
784
  eval_model.eval()
@@ -772,6 +805,9 @@ def train(args: argparse.Namespace) -> None:
772
805
  # Statistics
773
806
  running_val_loss.update(val_loss.detach())
774
807
  val_accuracy.update(training_utils.accuracy(targets, outputs), n=outputs.size(0))
808
+ if val_topk is not None:
809
+ topk_val = training_utils.topk_accuracy(targets, outputs, topk=(top_k,))[0]
810
+ val_topk.update(topk_val, n=outputs.size(0))
775
811
 
776
812
  # Update progress bar
777
813
  progress.update(n=batch_size * args.world_size)
@@ -789,19 +825,30 @@ def train(args: argparse.Namespace) -> None:
789
825
 
790
826
  running_val_loss.synchronize_between_processes(device)
791
827
  val_accuracy.synchronize_between_processes(device)
828
+ if val_topk is not None:
829
+ val_topk.synchronize_between_processes(device)
830
+
792
831
  epoch_val_loss = running_val_loss.global_avg
793
832
  epoch_val_accuracy = val_accuracy.global_avg
833
+ if val_topk is not None:
834
+ epoch_val_topk = val_topk.global_avg
835
+ else:
836
+ epoch_val_topk = None
794
837
 
795
838
  # Write statistics
796
839
  if training_utils.is_local_primary(args) is True:
797
840
  summary_writer.add_scalars("loss", {"validation": epoch_val_loss}, epoch * epoch_samples)
798
- summary_writer.add_scalars(
799
- "performance", {"validation_accuracy": epoch_val_accuracy}, epoch * epoch_samples
800
- )
841
+ performance = {"validation_accuracy": epoch_val_accuracy}
842
+ if epoch_val_topk is not None:
843
+ performance[f"validation_accuracy@{top_k}"] = epoch_val_topk
844
+
845
+ summary_writer.add_scalars("performance", performance, epoch * epoch_samples)
801
846
 
802
847
  # Epoch validation metrics
803
848
  logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_loss (target only): {epoch_val_loss:.4f}")
804
849
  logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy: {epoch_val_accuracy:.4f}")
850
+ if epoch_val_topk is not None:
851
+ logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy@{top_k}: {epoch_val_topk:.4f}")
805
852
 
806
853
  # Learning rate scheduler update
807
854
  if step_update is False:
@@ -989,7 +1036,7 @@ def get_args_parser() -> argparse.ArgumentParser:
989
1036
  training_cli.add_compile_args(parser, teacher=True)
990
1037
  training_cli.add_checkpoint_args(parser, default_save_frequency=5)
991
1038
  training_cli.add_distributed_args(parser)
992
- training_cli.add_logging_and_debug_args(parser)
1039
+ training_cli.add_logging_and_debug_args(parser, classification=True)
993
1040
  training_cli.add_training_data_args(parser)
994
1041
 
995
1042
  return parser
@@ -49,7 +49,7 @@ def train(args: argparse.Namespace) -> None:
49
49
  #
50
50
  # Initialize
51
51
  #
52
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
52
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
53
53
 
54
54
  if args.size is None:
55
55
  # Prefer mim size over encoder default size
@@ -73,11 +73,11 @@ def train(args: argparse.Namespace) -> None:
73
73
  elif args.wds is True:
74
74
  wds_path: str | list[str]
75
75
  if args.wds_info is not None:
76
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
76
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
77
77
  if args.wds_size is not None:
78
78
  dataset_size = args.wds_size
79
79
  else:
80
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
80
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
81
81
 
82
82
  training_dataset = make_wds_dataset(
83
83
  wds_path,
@@ -107,7 +107,7 @@ def train(args: argparse.Namespace) -> None:
107
107
 
108
108
  # Data loaders and samplers
109
109
  virtual_epoch_mode = args.steps_per_epoch is not None
110
- (train_sampler, _) = training_utils.get_samplers(
110
+ train_sampler, _ = training_utils.get_samplers(
111
111
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
112
112
  )
113
113
 
@@ -172,7 +172,7 @@ def train(args: argparse.Namespace) -> None:
172
172
 
173
173
  if args.resume_epoch is not None:
174
174
  begin_epoch = args.resume_epoch + 1
175
- (net, training_states) = fs_ops.load_mim_checkpoint(
175
+ net, training_states = fs_ops.load_mim_checkpoint(
176
176
  device,
177
177
  args.network,
178
178
  config=args.model_config,
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
187
187
 
188
188
  elif args.pretrained is True:
189
189
  fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
190
- (net, training_states) = fs_ops.load_mim_checkpoint(
190
+ net, training_states = fs_ops.load_mim_checkpoint(
191
191
  device,
192
192
  args.network,
193
193
  config=args.model_config,
@@ -202,7 +202,7 @@ def train(args: argparse.Namespace) -> None:
202
202
 
203
203
  else:
204
204
  encoder = registry.net_factory(
205
- args.encoder, sample_shape[1], 0, config=args.encoder_model_config, size=args.size
205
+ args.encoder, 0, sample_shape[1], config=args.encoder_model_config, size=args.size
206
206
  )
207
207
  net = registry.mim_net_factory(
208
208
  args.network,
@@ -263,7 +263,7 @@ def train(args: argparse.Namespace) -> None:
263
263
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
264
264
 
265
265
  # Gradient scaler and AMP related tasks
266
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
266
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
267
267
 
268
268
  # Load states
269
269
  if args.load_states is True:
@@ -375,6 +375,9 @@ def train(args: argparse.Namespace) -> None:
375
375
  tic = time.time()
376
376
  net.train()
377
377
 
378
+ # Clear metrics
379
+ running_loss.clear()
380
+
378
381
  if args.distributed is True or virtual_epoch_mode is True:
379
382
  train_sampler.set_epoch(epoch)
380
383
 
@@ -74,7 +74,7 @@ def train(args: argparse.Namespace) -> None:
74
74
  #
75
75
  # Initialize
76
76
  #
77
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
77
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
78
78
 
79
79
  if args.size is None:
80
80
  args.size = registry.get_default_size(args.network)
@@ -97,11 +97,11 @@ def train(args: argparse.Namespace) -> None:
97
97
  elif args.wds is True:
98
98
  wds_path: str | list[str]
99
99
  if args.wds_info is not None:
100
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
100
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
101
101
  if args.wds_size is not None:
102
102
  dataset_size = args.wds_size
103
103
  else:
104
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
104
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
105
105
 
106
106
  training_dataset = make_wds_dataset(
107
107
  wds_path,
@@ -131,7 +131,7 @@ def train(args: argparse.Namespace) -> None:
131
131
 
132
132
  # Data loaders and samplers
133
133
  virtual_epoch_mode = args.steps_per_epoch is not None
134
- (train_sampler, _) = training_utils.get_samplers(
134
+ train_sampler, _ = training_utils.get_samplers(
135
135
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
136
136
  )
137
137
 
@@ -194,12 +194,12 @@ def train(args: argparse.Namespace) -> None:
194
194
 
195
195
  network_name = get_mim_network_name("mmcr", encoder=args.network, tag=args.tag)
196
196
 
197
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
197
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
198
198
  net = MMCR(backbone, config={"projector_dims": args.projector_dims})
199
199
 
200
200
  if args.resume_epoch is not None:
201
201
  begin_epoch = args.resume_epoch + 1
202
- (net, training_states) = fs_ops.load_simple_checkpoint(
202
+ net, training_states = fs_ops.load_simple_checkpoint(
203
203
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
204
204
  )
205
205
 
@@ -265,7 +265,7 @@ def train(args: argparse.Namespace) -> None:
265
265
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
266
266
 
267
267
  # Gradient scaler and AMP related tasks
268
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
268
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
269
269
 
270
270
  # Load states
271
271
  if args.load_states is True:
@@ -377,6 +377,9 @@ def train(args: argparse.Namespace) -> None:
377
377
  tic = time.time()
378
378
  net.train()
379
379
 
380
+ # Clear metrics
381
+ running_loss.clear()
382
+
380
383
  if args.distributed is True or virtual_epoch_mode is True:
381
384
  train_sampler.set_epoch(epoch)
382
385
 
@@ -407,7 +410,7 @@ def train(args: argparse.Namespace) -> None:
407
410
 
408
411
  # Forward, backward and optimize
409
412
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
410
- (z, z_m) = net(images)
413
+ z, z_m = net(images)
411
414
  loss = mmcr_loss(z, z_m)
412
415
 
413
416
  if scaler is not None:
@@ -83,7 +83,7 @@ def train(args: argparse.Namespace) -> None:
83
83
  #
84
84
  # Initialize
85
85
  #
86
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
86
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
87
87
 
88
88
  if args.size is None:
89
89
  args.size = registry.get_default_size(args.network)
@@ -111,11 +111,11 @@ def train(args: argparse.Namespace) -> None:
111
111
  elif args.wds is True:
112
112
  wds_path: str | list[str]
113
113
  if args.wds_info is not None:
114
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
114
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
115
115
  if args.wds_size is not None:
116
116
  dataset_size = args.wds_size
117
117
  else:
118
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
118
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
119
119
 
120
120
  training_dataset = make_wds_dataset(
121
121
  wds_path,
@@ -145,7 +145,7 @@ def train(args: argparse.Namespace) -> None:
145
145
 
146
146
  # Data loaders and samplers
147
147
  virtual_epoch_mode = args.steps_per_epoch is not None
148
- (train_sampler, _) = training_utils.get_samplers(
148
+ train_sampler, _ = training_utils.get_samplers(
149
149
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
150
150
  )
151
151
 
@@ -207,12 +207,12 @@ def train(args: argparse.Namespace) -> None:
207
207
  network_name = f"{network_name}-{args.tag}"
208
208
 
209
209
  net = registry.net_factory(
210
- args.network, sample_shape[1], len(class_to_idx), config=args.model_config, size=args.size
210
+ args.network, len(class_to_idx), sample_shape[1], config=args.model_config, size=args.size
211
211
  )
212
212
 
213
213
  if args.resume_epoch is not None:
214
214
  begin_epoch = args.resume_epoch + 1
215
- (net, training_states) = fs_ops.load_simple_checkpoint(
215
+ net, training_states = fs_ops.load_simple_checkpoint(
216
216
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
217
217
  )
218
218
 
@@ -277,7 +277,7 @@ def train(args: argparse.Namespace) -> None:
277
277
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
278
278
 
279
279
  # Gradient scaler and AMP related tasks
280
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
280
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
281
281
 
282
282
  # Load states
283
283
  if args.load_states is True:
@@ -389,6 +389,10 @@ def train(args: argparse.Namespace) -> None:
389
389
  tic = time.time()
390
390
  net.train()
391
391
 
392
+ # Clear metrics
393
+ running_loss.clear()
394
+ train_accuracy.clear()
395
+
392
396
  if args.distributed is True or virtual_epoch_mode is True:
393
397
  train_sampler.set_epoch(epoch)
394
398