birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -75,7 +75,7 @@ def train(args: argparse.Namespace) -> None:
75
75
  #
76
76
  # Initialize
77
77
  #
78
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
78
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
79
79
 
80
80
  if args.size is None:
81
81
  # Prefer mim size over encoder default size
@@ -105,7 +105,7 @@ def train(args: argparse.Namespace) -> None:
105
105
 
106
106
  network_name = get_mim_network_name("data2vec2", encoder=args.network, tag=args.tag)
107
107
 
108
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
108
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
109
109
  net = Data2Vec2(
110
110
  backbone,
111
111
  config={
@@ -121,7 +121,7 @@ def train(args: argparse.Namespace) -> None:
121
121
 
122
122
  if args.resume_epoch is not None:
123
123
  begin_epoch = args.resume_epoch + 1
124
- (net, training_states) = fs_ops.load_simple_checkpoint(
124
+ net, training_states = fs_ops.load_simple_checkpoint(
125
125
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
126
126
  )
127
127
 
@@ -169,11 +169,11 @@ def train(args: argparse.Namespace) -> None:
169
169
  elif args.wds is True:
170
170
  wds_path: str | list[str]
171
171
  if args.wds_info is not None:
172
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
172
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
173
173
  if args.wds_size is not None:
174
174
  dataset_size = args.wds_size
175
175
  else:
176
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
176
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
177
177
 
178
178
  training_dataset = make_wds_dataset(
179
179
  wds_path,
@@ -199,7 +199,7 @@ def train(args: argparse.Namespace) -> None:
199
199
 
200
200
  # Data loaders and samplers
201
201
  virtual_epoch_mode = args.steps_per_epoch is not None
202
- (train_sampler, _) = training_utils.get_samplers(
202
+ train_sampler, _ = training_utils.get_samplers(
203
203
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
204
204
  )
205
205
 
@@ -288,7 +288,7 @@ def train(args: argparse.Namespace) -> None:
288
288
  )
289
289
 
290
290
  # Gradient scaler and AMP related tasks
291
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
291
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
292
292
 
293
293
  # Load states
294
294
  if args.load_states is True:
@@ -400,6 +400,9 @@ def train(args: argparse.Namespace) -> None:
400
400
  tic = time.time()
401
401
  net.train()
402
402
 
403
+ # Clear metrics
404
+ running_loss.clear()
405
+
403
406
  if args.distributed is True or virtual_epoch_mode is True:
404
407
  train_sampler.set_epoch(epoch)
405
408
 
@@ -27,7 +27,7 @@ from birder.common import training_cli
27
27
  from birder.common import training_utils
28
28
  from birder.conf import settings
29
29
  from birder.data.collators.detection import BatchRandomResizeCollator
30
- from birder.data.collators.detection import training_collate_fn
30
+ from birder.data.collators.detection import DetectionCollator
31
31
  from birder.data.datasets.coco import CocoMosaicTraining
32
32
  from birder.data.datasets.coco import CocoTraining
33
33
  from birder.data.transforms.classification import get_rgb_stats
@@ -63,7 +63,7 @@ def train(args: argparse.Namespace) -> None:
63
63
  )
64
64
  model_dynamic_size = transform_dynamic_size or args.batch_multiscale is True
65
65
 
66
- (device, device_id, disable_tqdm) = training_utils.init_training(
66
+ device, device_id, disable_tqdm = training_utils.init_training(
67
67
  args, logger, cudnn_dynamic_size=transform_dynamic_size
68
68
  )
69
69
 
@@ -92,6 +92,7 @@ def train(args: argparse.Namespace) -> None:
92
92
  args.multiscale,
93
93
  args.max_size,
94
94
  args.multiscale_min_size,
95
+ args.multiscale_step,
95
96
  )
96
97
  mosaic_dataset = None
97
98
  if args.mosaic_prob > 0.0:
@@ -104,6 +105,7 @@ def train(args: argparse.Namespace) -> None:
104
105
  args.multiscale,
105
106
  args.max_size,
106
107
  args.multiscale_min_size,
108
+ args.multiscale_step,
107
109
  post_mosaic=True,
108
110
  )
109
111
  if args.dynamic_size is True or args.multiscale is True:
@@ -177,14 +179,22 @@ def train(args: argparse.Namespace) -> None:
177
179
 
178
180
  # Data loaders and samplers
179
181
  virtual_epoch_mode = args.steps_per_epoch is not None
180
- (train_sampler, validation_sampler) = training_utils.get_samplers(
182
+ train_sampler, validation_sampler = training_utils.get_samplers(
181
183
  args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
182
184
  )
183
185
 
184
186
  if args.batch_multiscale is True:
185
- train_collate_fn: Any = BatchRandomResizeCollator(0, args.size, multiscale_min_size=args.multiscale_min_size)
187
+ train_collate_fn: Any = BatchRandomResizeCollator(
188
+ 0,
189
+ args.size,
190
+ size_divisible=args.multiscale_step,
191
+ multiscale_min_size=args.multiscale_min_size,
192
+ multiscale_step=args.multiscale_step,
193
+ )
186
194
  else:
187
- train_collate_fn = training_collate_fn
195
+ train_collate_fn = DetectionCollator(0, size_divisible=args.multiscale_step)
196
+
197
+ validation_collate_fn = DetectionCollator(0, size_divisible=args.multiscale_step)
188
198
 
189
199
  training_loader = DataLoader(
190
200
  training_dataset,
@@ -202,7 +212,7 @@ def train(args: argparse.Namespace) -> None:
202
212
  sampler=validation_sampler,
203
213
  num_workers=args.num_workers,
204
214
  prefetch_factor=args.prefetch_factor,
205
- collate_fn=training_collate_fn,
215
+ collate_fn=validation_collate_fn,
206
216
  pin_memory=True,
207
217
  drop_last=args.drop_last,
208
218
  )
@@ -243,7 +253,7 @@ def train(args: argparse.Namespace) -> None:
243
253
 
244
254
  if args.resume_epoch is not None:
245
255
  begin_epoch = args.resume_epoch + 1
246
- (net, class_to_idx_saved, training_states) = fs_ops.load_detection_checkpoint(
256
+ net, class_to_idx_saved, training_states = fs_ops.load_detection_checkpoint(
247
257
  device,
248
258
  args.network,
249
259
  config=args.model_config,
@@ -262,7 +272,7 @@ def train(args: argparse.Namespace) -> None:
262
272
 
263
273
  elif args.pretrained is True:
264
274
  fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
265
- (net, class_to_idx_saved, training_states) = fs_ops.load_detection_checkpoint(
275
+ net, class_to_idx_saved, training_states = fs_ops.load_detection_checkpoint(
266
276
  device,
267
277
  args.network,
268
278
  config=args.model_config,
@@ -282,7 +292,7 @@ def train(args: argparse.Namespace) -> None:
282
292
  else:
283
293
  if args.backbone_epoch is not None:
284
294
  backbone: DetectorBackbone
285
- (backbone, class_to_idx_saved, _) = fs_ops.load_checkpoint(
295
+ backbone, class_to_idx_saved, _ = fs_ops.load_checkpoint(
286
296
  device,
287
297
  args.backbone,
288
298
  config=args.backbone_model_config,
@@ -297,7 +307,7 @@ def train(args: argparse.Namespace) -> None:
297
307
  lib.get_network_name(args.backbone, tag=args.backbone_tag),
298
308
  progress_bar=training_utils.is_local_primary(args),
299
309
  )
300
- (backbone, class_to_idx_saved, _) = fs_ops.load_checkpoint(
310
+ backbone, class_to_idx_saved, _ = fs_ops.load_checkpoint(
301
311
  device,
302
312
  args.backbone,
303
313
  config=args.backbone_model_config,
@@ -309,7 +319,7 @@ def train(args: argparse.Namespace) -> None:
309
319
 
310
320
  else:
311
321
  backbone = registry.net_factory(
312
- args.backbone, sample_shape[1], num_outputs, config=args.backbone_model_config, size=args.size
322
+ args.backbone, num_outputs, sample_shape[1], config=args.backbone_model_config, size=args.size
313
323
  )
314
324
 
315
325
  net = registry.detection_net_factory(
@@ -386,7 +396,7 @@ def train(args: argparse.Namespace) -> None:
386
396
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
387
397
 
388
398
  # Gradient scaler and AMP related tasks
389
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
399
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
390
400
 
391
401
  # Load states
392
402
  if args.load_states is True:
@@ -546,6 +556,11 @@ def train(args: argparse.Namespace) -> None:
546
556
  tic = time.time()
547
557
  net.train()
548
558
 
559
+ # Clear metrics
560
+ running_loss.clear()
561
+ for tracker in loss_trackers.values():
562
+ tracker.clear()
563
+
549
564
  validation_metrics.reset()
550
565
 
551
566
  if args.distributed is True or virtual_epoch_mode is True:
@@ -586,7 +601,7 @@ def train(args: argparse.Namespace) -> None:
586
601
 
587
602
  # Forward, backward and optimize
588
603
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
589
- (_detections, losses) = net(inputs, targets, masks, image_sizes)
604
+ _detections, losses = net(inputs, targets, masks, image_sizes)
590
605
  loss = sum(v for v in losses.values())
591
606
 
592
607
  if scaler is not None:
@@ -708,7 +723,7 @@ def train(args: argparse.Namespace) -> None:
708
723
  masks = masks.to(device, non_blocking=True)
709
724
 
710
725
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
711
- (detections, losses) = eval_model(inputs, masks=masks, image_sizes=image_sizes)
726
+ detections, losses = eval_model(inputs, masks=masks, image_sizes=image_sizes)
712
727
 
713
728
  for target in targets:
714
729
  # TorchMetrics can't handle "empty" images
@@ -101,7 +101,7 @@ def train(args: argparse.Namespace) -> None:
101
101
  #
102
102
  # Initialize
103
103
  #
104
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
104
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
105
105
 
106
106
  if args.size is None:
107
107
  args.size = registry.get_default_size(args.network)
@@ -129,11 +129,11 @@ def train(args: argparse.Namespace) -> None:
129
129
  elif args.wds is True:
130
130
  wds_path: str | list[str]
131
131
  if args.wds_info is not None:
132
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
132
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
133
133
  if args.wds_size is not None:
134
134
  dataset_size = args.wds_size
135
135
  else:
136
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
136
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
137
137
 
138
138
  training_dataset = make_wds_dataset(
139
139
  wds_path,
@@ -163,7 +163,7 @@ def train(args: argparse.Namespace) -> None:
163
163
 
164
164
  # Data loaders and samplers
165
165
  virtual_epoch_mode = args.steps_per_epoch is not None
166
- (train_sampler, _) = training_utils.get_samplers(
166
+ train_sampler, _ = training_utils.get_samplers(
167
167
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
168
168
  )
169
169
 
@@ -226,9 +226,9 @@ def train(args: argparse.Namespace) -> None:
226
226
 
227
227
  network_name = get_mim_network_name("dino_v1", encoder=args.network, tag=args.tag)
228
228
 
229
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
229
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
230
230
  if args.backbone_epoch is not None:
231
- (student_backbone, _) = fs_ops.load_simple_checkpoint(
231
+ student_backbone, _ = fs_ops.load_simple_checkpoint(
232
232
  device, student_backbone, backbone_name, epoch=args.backbone_epoch, strict=not args.non_strict_weights
233
233
  )
234
234
 
@@ -239,7 +239,7 @@ def train(args: argparse.Namespace) -> None:
239
239
  teacher_model_config = {"drop_path_rate": 0.0}
240
240
 
241
241
  teacher_backbone = registry.net_factory(
242
- args.network, sample_shape[1], 0, config=teacher_model_config, size=args.size
242
+ args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
243
243
  )
244
244
  if args.freeze_body is True:
245
245
  student_backbone.freeze(freeze_classifier=False, unfreeze_features=True)
@@ -293,7 +293,7 @@ def train(args: argparse.Namespace) -> None:
293
293
 
294
294
  if args.resume_epoch is not None:
295
295
  begin_epoch = args.resume_epoch + 1
296
- (net, training_states) = fs_ops.load_simple_checkpoint(
296
+ net, training_states = fs_ops.load_simple_checkpoint(
297
297
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
298
298
  )
299
299
  student = net["student"]
@@ -368,7 +368,7 @@ def train(args: argparse.Namespace) -> None:
368
368
  wd_schedule = None
369
369
 
370
370
  # Gradient scaler and AMP related tasks
371
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
371
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
372
372
 
373
373
  # Load states
374
374
  if args.load_states is True:
@@ -488,6 +488,10 @@ def train(args: argparse.Namespace) -> None:
488
488
  tic = time.time()
489
489
  net.train()
490
490
 
491
+ # Clear metrics
492
+ running_loss.clear()
493
+ train_proto_agreement.clear()
494
+
491
495
  if args.distributed is True or virtual_epoch_mode is True:
492
496
  train_sampler.set_epoch(epoch)
493
497
 
@@ -178,7 +178,7 @@ def train(args: argparse.Namespace) -> None:
178
178
  #
179
179
  # Initialize
180
180
  #
181
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
181
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
182
182
 
183
183
  if args.size is None:
184
184
  args.size = registry.get_default_size(args.network)
@@ -207,7 +207,7 @@ def train(args: argparse.Namespace) -> None:
207
207
 
208
208
  network_name = get_mim_network_name("dino_v2", encoder=args.network, tag=args.tag)
209
209
 
210
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
210
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
211
211
  if args.model_config is not None:
212
212
  teacher_model_config = args.model_config.copy()
213
213
  teacher_model_config.update({"drop_path_rate": 0.0})
@@ -215,7 +215,7 @@ def train(args: argparse.Namespace) -> None:
215
215
  teacher_model_config = {"drop_path_rate": 0.0}
216
216
 
217
217
  teacher_backbone = registry.net_factory(
218
- args.network, sample_shape[1], 0, config=teacher_model_config, size=args.size
218
+ args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
219
219
  )
220
220
  student_backbone.set_dynamic_size()
221
221
  if args.ibot_separate_head is False:
@@ -267,7 +267,7 @@ def train(args: argparse.Namespace) -> None:
267
267
 
268
268
  if args.resume_epoch is not None:
269
269
  begin_epoch = args.resume_epoch + 1
270
- (net, training_states) = fs_ops.load_simple_checkpoint(
270
+ net, training_states = fs_ops.load_simple_checkpoint(
271
271
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
272
272
  )
273
273
  student = net["student"]
@@ -336,11 +336,11 @@ def train(args: argparse.Namespace) -> None:
336
336
  elif args.wds is True:
337
337
  wds_path: str | list[str]
338
338
  if args.wds_info is not None:
339
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
339
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
340
340
  if args.wds_size is not None:
341
341
  dataset_size = args.wds_size
342
342
  else:
343
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
343
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
344
344
 
345
345
  training_dataset = make_wds_dataset(
346
346
  wds_path,
@@ -366,7 +366,7 @@ def train(args: argparse.Namespace) -> None:
366
366
 
367
367
  # Data loaders and samplers
368
368
  virtual_epoch_mode = args.steps_per_epoch is not None
369
- (train_sampler, _) = training_utils.get_samplers(
369
+ train_sampler, _ = training_utils.get_samplers(
370
370
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
371
371
  )
372
372
 
@@ -466,7 +466,7 @@ def train(args: argparse.Namespace) -> None:
466
466
  wd_schedule = None
467
467
 
468
468
  # Gradient scaler and AMP related tasks
469
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
469
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
470
470
 
471
471
  # Load states
472
472
  if args.load_states is True:
@@ -603,6 +603,19 @@ def train(args: argparse.Namespace) -> None:
603
603
  tic = time.time()
604
604
  net.train()
605
605
 
606
+ # Clear metrics
607
+ running_loss.clear()
608
+ running_loss_dino_local.clear()
609
+ running_loss_dino_global.clear()
610
+ running_loss_koleo.clear()
611
+ running_loss_ibot_patch.clear()
612
+ if track_extended_metrics is True:
613
+ train_proto_agreement.clear()
614
+ train_patch_agreement.clear()
615
+ running_target_entropy.clear()
616
+ running_dino_center_drift.clear()
617
+ running_ibot_center_drift.clear()
618
+
606
619
  if args.sinkhorn_queue_size is not None:
607
620
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs
608
621
  dino_loss.set_queue_active(queue_active)
@@ -661,7 +674,7 @@ def train(args: argparse.Namespace) -> None:
661
674
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
662
675
  with torch.no_grad():
663
676
  # Teacher
664
- (teacher_embedding_after_head, teacher_masked_patch_tokens_after_head) = teacher(
677
+ teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
665
678
  global_crops, n_global_crops, upper_bound, mask_indices_list
666
679
  )
667
680
  teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
@@ -671,7 +684,7 @@ def train(args: argparse.Namespace) -> None:
671
684
  prev_dino_center = dino_loss.center.clone()
672
685
  prev_ibot_center = ibot_patch_loss.center.clone()
673
686
 
674
- teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
687
+ teacher_dino_softmax_centered = dino_loss.softmax_center_teacher(
675
688
  teacher_embedding_after_head, teacher_temp=teacher_temp
676
689
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
677
690
  dino_loss.update_center(teacher_embedding_after_head)
@@ -684,7 +697,7 @@ def train(args: argparse.Namespace) -> None:
684
697
  ibot_patch_loss.update_center(teacher_masked_patch_tokens_after_head[:, :n_masked_patches])
685
698
 
686
699
  else: # sinkhorn_knopp
687
- teacher_dino_softmax_centered_list = dino_loss.sinkhorn_knopp_teacher(
700
+ teacher_dino_softmax_centered = dino_loss.sinkhorn_knopp_teacher(
688
701
  teacher_embedding_after_head, teacher_temp=teacher_temp
689
702
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
690
703
 
@@ -705,7 +718,7 @@ def train(args: argparse.Namespace) -> None:
705
718
  # Local DINO loss
706
719
  loss_dino_local_crops = dino_loss(
707
720
  student_local_embedding_after_head.chunk(n_local_crops),
708
- teacher_dino_softmax_centered_list,
721
+ teacher_dino_softmax_centered.unbind(0),
709
722
  ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)
710
723
  loss = args.dino_loss_weight * loss_dino_local_crops
711
724
 
@@ -715,7 +728,7 @@ def train(args: argparse.Namespace) -> None:
715
728
  dino_loss(
716
729
  [student_global_embedding_after_head],
717
730
  [
718
- teacher_dino_softmax_centered_list.flatten(0, 1)
731
+ teacher_dino_softmax_centered.flatten(0, 1)
719
732
  ], # These were chunked and stacked in reverse so A is matched to B
720
733
  )
721
734
  * loss_scales
@@ -809,7 +822,7 @@ def train(args: argparse.Namespace) -> None:
809
822
  train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
810
823
 
811
824
  with torch.no_grad():
812
- p = teacher_dino_softmax_centered_list.detach()
825
+ p = teacher_dino_softmax_centered.detach()
813
826
  p = p.reshape(-1, p.size(-1)) # (N, D)
814
827
 
815
828
  # Mean distribution over prototypes (marginal)
@@ -179,7 +179,7 @@ def train(args: argparse.Namespace) -> None:
179
179
  #
180
180
  # Initialize
181
181
  #
182
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
182
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
183
183
 
184
184
  if args.size is None:
185
185
  args.size = registry.get_default_size(args.network)
@@ -208,17 +208,17 @@ def train(args: argparse.Namespace) -> None:
208
208
 
209
209
  network_name = get_mim_network_name("dino_v2_dist", encoder=args.network, tag=args.tag)
210
210
 
211
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
211
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
212
212
  student_backbone_ema = registry.net_factory(
213
- args.network, sample_shape[1], 0, config=args.model_config, size=args.size
213
+ args.network, 0, sample_shape[1], config=args.model_config, size=args.size
214
214
  )
215
215
  student_backbone_ema.load_state_dict(student_backbone.state_dict())
216
216
  student_backbone_ema.requires_grad_(False)
217
217
 
218
218
  teacher_backbone = registry.net_factory(
219
219
  args.teacher,
220
- sample_shape[1],
221
220
  0,
221
+ sample_shape[1],
222
222
  config=args.teacher_model_config,
223
223
  size=args.size,
224
224
  )
@@ -277,7 +277,7 @@ def train(args: argparse.Namespace) -> None:
277
277
 
278
278
  if args.resume_epoch is not None:
279
279
  begin_epoch = args.resume_epoch + 1
280
- (net, training_states) = fs_ops.load_simple_checkpoint(
280
+ net, training_states = fs_ops.load_simple_checkpoint(
281
281
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
282
282
  )
283
283
  student = net["student"]
@@ -358,11 +358,11 @@ def train(args: argparse.Namespace) -> None:
358
358
  elif args.wds is True:
359
359
  wds_path: str | list[str]
360
360
  if args.wds_info is not None:
361
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
361
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
362
362
  if args.wds_size is not None:
363
363
  dataset_size = args.wds_size
364
364
  else:
365
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
365
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
366
366
 
367
367
  training_dataset = make_wds_dataset(
368
368
  wds_path,
@@ -388,7 +388,7 @@ def train(args: argparse.Namespace) -> None:
388
388
 
389
389
  # Data loaders and samplers
390
390
  virtual_epoch_mode = args.steps_per_epoch is not None
391
- (train_sampler, _) = training_utils.get_samplers(
391
+ train_sampler, _ = training_utils.get_samplers(
392
392
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
393
393
  )
394
394
 
@@ -487,7 +487,7 @@ def train(args: argparse.Namespace) -> None:
487
487
  wd_schedule = None
488
488
 
489
489
  # Gradient scaler and AMP related tasks
490
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
490
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
491
491
 
492
492
  # Load states
493
493
  if args.load_states is True:
@@ -625,6 +625,19 @@ def train(args: argparse.Namespace) -> None:
625
625
  net.train()
626
626
  teacher.eval()
627
627
 
628
+ # Clear metrics
629
+ running_loss.clear()
630
+ running_loss_dino_local.clear()
631
+ running_loss_dino_global.clear()
632
+ running_loss_koleo.clear()
633
+ running_loss_ibot_patch.clear()
634
+ if track_extended_metrics is True:
635
+ train_proto_agreement.clear()
636
+ train_patch_agreement.clear()
637
+ running_target_entropy.clear()
638
+ running_dino_center_drift.clear()
639
+ running_ibot_center_drift.clear()
640
+
628
641
  if args.sinkhorn_queue_size is not None:
629
642
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs
630
643
  dino_loss.set_queue_active(queue_active)
@@ -682,7 +695,7 @@ def train(args: argparse.Namespace) -> None:
682
695
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
683
696
  with torch.no_grad():
684
697
  # Teacher
685
- (teacher_embedding_after_head, teacher_masked_patch_tokens_after_head) = teacher(
698
+ teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
686
699
  global_crops, n_global_crops, upper_bound, mask_indices_list
687
700
  )
688
701
  teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
@@ -692,7 +705,7 @@ def train(args: argparse.Namespace) -> None:
692
705
  prev_dino_center = dino_loss.center.clone()
693
706
  prev_ibot_center = ibot_patch_loss.center.clone()
694
707
 
695
- teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
708
+ teacher_dino_softmax_centered = dino_loss.softmax_center_teacher(
696
709
  teacher_embedding_after_head, teacher_temp=teacher_temp
697
710
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
698
711
  dino_loss.update_center(teacher_embedding_after_head)
@@ -705,7 +718,7 @@ def train(args: argparse.Namespace) -> None:
705
718
  ibot_patch_loss.update_center(teacher_masked_patch_tokens_after_head[:, :n_masked_patches])
706
719
 
707
720
  else: # sinkhorn_knopp
708
- teacher_dino_softmax_centered_list = dino_loss.sinkhorn_knopp_teacher(
721
+ teacher_dino_softmax_centered = dino_loss.sinkhorn_knopp_teacher(
709
722
  teacher_embedding_after_head, teacher_temp=teacher_temp
710
723
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
711
724
 
@@ -726,7 +739,7 @@ def train(args: argparse.Namespace) -> None:
726
739
  # Local DINO loss
727
740
  loss_dino_local_crops = dino_loss(
728
741
  student_local_embedding_after_head.chunk(n_local_crops),
729
- teacher_dino_softmax_centered_list,
742
+ teacher_dino_softmax_centered.unbind(0),
730
743
  ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)
731
744
  loss = args.dino_loss_weight * loss_dino_local_crops
732
745
 
@@ -736,7 +749,7 @@ def train(args: argparse.Namespace) -> None:
736
749
  dino_loss(
737
750
  [student_global_embedding_after_head],
738
751
  [
739
- teacher_dino_softmax_centered_list.flatten(0, 1)
752
+ teacher_dino_softmax_centered.flatten(0, 1)
740
753
  ], # These were chunked and stacked in reverse so A is matched to B
741
754
  )
742
755
  * loss_scales
@@ -830,7 +843,7 @@ def train(args: argparse.Namespace) -> None:
830
843
  train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
831
844
 
832
845
  with torch.no_grad():
833
- p = teacher_dino_softmax_centered_list.detach()
846
+ p = teacher_dino_softmax_centered.detach()
834
847
  p = p.reshape(-1, p.size(-1)) # (N, D)
835
848
 
836
849
  # Mean distribution over prototypes (marginal)
@@ -205,7 +205,7 @@ def train(args: argparse.Namespace) -> None:
205
205
  #
206
206
  # Initialize
207
207
  #
208
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
208
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
209
209
 
210
210
  if args.size is None:
211
211
  args.size = registry.get_default_size(args.network)
@@ -234,7 +234,7 @@ def train(args: argparse.Namespace) -> None:
234
234
 
235
235
  network_name = get_mim_network_name("franca", encoder=args.network, tag=args.tag)
236
236
 
237
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
237
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
238
238
  if args.model_config is not None:
239
239
  teacher_model_config = args.model_config.copy()
240
240
  teacher_model_config.update({"drop_path_rate": 0.0})
@@ -242,7 +242,7 @@ def train(args: argparse.Namespace) -> None:
242
242
  teacher_model_config = {"drop_path_rate": 0.0}
243
243
 
244
244
  teacher_backbone = registry.net_factory(
245
- args.network, sample_shape[1], 0, config=teacher_model_config, size=args.size
245
+ args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
246
246
  )
247
247
  student_backbone.set_dynamic_size()
248
248
  if args.ibot_separate_head is False:
@@ -296,7 +296,7 @@ def train(args: argparse.Namespace) -> None:
296
296
 
297
297
  if args.resume_epoch is not None:
298
298
  begin_epoch = args.resume_epoch + 1
299
- (net, training_states) = fs_ops.load_simple_checkpoint(
299
+ net, training_states = fs_ops.load_simple_checkpoint(
300
300
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
301
301
  )
302
302
  student = net["student"]
@@ -363,11 +363,11 @@ def train(args: argparse.Namespace) -> None:
363
363
  elif args.wds is True:
364
364
  wds_path: str | list[str]
365
365
  if args.wds_info is not None:
366
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
366
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
367
367
  if args.wds_size is not None:
368
368
  dataset_size = args.wds_size
369
369
  else:
370
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
370
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
371
371
 
372
372
  training_dataset = make_wds_dataset(
373
373
  wds_path,
@@ -393,7 +393,7 @@ def train(args: argparse.Namespace) -> None:
393
393
 
394
394
  # Data loaders and samplers
395
395
  virtual_epoch_mode = args.steps_per_epoch is not None
396
- (train_sampler, _) = training_utils.get_samplers(
396
+ train_sampler, _ = training_utils.get_samplers(
397
397
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
398
398
  )
399
399
 
@@ -493,7 +493,7 @@ def train(args: argparse.Namespace) -> None:
493
493
  wd_schedule = None
494
494
 
495
495
  # Gradient scaler and AMP related tasks
496
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
496
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
497
497
 
498
498
  # Load states
499
499
  if args.load_states is True:
@@ -623,6 +623,13 @@ def train(args: argparse.Namespace) -> None:
623
623
  tic = time.time()
624
624
  net.train()
625
625
 
626
+ # Clear metrics
627
+ running_loss.clear()
628
+ running_loss_dino_local.clear()
629
+ running_loss_dino_global.clear()
630
+ running_loss_koleo.clear()
631
+ running_loss_ibot_patch.clear()
632
+
626
633
  if args.sinkhorn_queue_size is not None:
627
634
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs
628
635
  dino_loss.set_queue_active(queue_active)
@@ -681,7 +688,7 @@ def train(args: argparse.Namespace) -> None:
681
688
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
682
689
  with torch.no_grad():
683
690
  # Teacher
684
- (teacher_embedding_after_head, teacher_masked_patch_tokens_after_head) = teacher(
691
+ teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
685
692
  global_crops, n_global_crops, upper_bound, mask_indices_list
686
693
  )
687
694