birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ from birder.common import cli
13
13
  from birder.conf import settings
14
14
  from birder.model_registry import Task
15
15
  from birder.model_registry import registry
16
+ from birder.net.base import DetectorBackbone
16
17
 
17
18
  logger = logging.getLogger(__name__)
18
19
 
@@ -27,6 +28,23 @@ def prepare_model(net: torch.nn.Module) -> None:
27
28
  param.requires_grad_(False)
28
29
 
29
30
 
31
+ def init_plain_model(
32
+ model_name: str, sample_shape: tuple[int, ...], device: torch.device, args: argparse.Namespace
33
+ ) -> torch.nn.Module:
34
+ size = (sample_shape[2], sample_shape[3])
35
+ input_channels = sample_shape[1]
36
+ if args.backbone is not None:
37
+ backbone = registry.net_factory(args.backbone, args.num_classes, input_channels, size=size)
38
+ net = registry.detection_net_factory(model_name, args.num_classes, backbone, size=size)
39
+ else:
40
+ net = registry.net_factory(model_name, args.num_classes, input_channels, size=size)
41
+
42
+ net.to(device)
43
+ prepare_model(net)
44
+
45
+ return net
46
+
47
+
30
48
  def throughput_benchmark(
31
49
  net: torch.nn.Module, device: torch.device, sample_shape: tuple[int, ...], model_name: str, args: argparse.Namespace
32
50
  ) -> tuple[float, int]:
@@ -110,14 +128,10 @@ def memory_benchmark(
110
128
  )
111
129
 
112
130
  if args.plain is True:
113
- size = (sample_shape[2], sample_shape[3])
114
- input_channels = sample_shape[1]
115
- net = registry.net_factory(model_name, input_channels, 0, size=size)
116
- net.to(device)
117
- prepare_model(net)
131
+ net = init_plain_model(model_name, sample_shape, device, args)
118
132
 
119
133
  else:
120
- (net, _) = birder.load_pretrained_model(model_name, inference=True, device=device)
134
+ net, _ = birder.load_pretrained_model(model_name, inference=True, device=device)
121
135
  if args.size is not None:
122
136
  size = (sample_shape[2], sample_shape[3])
123
137
  net.adjust_size(size)
@@ -182,7 +196,8 @@ def benchmark(args: argparse.Namespace) -> None:
182
196
  if args.plain is True:
183
197
  model_list = args.models or []
184
198
  if len(model_list) == 0:
185
- model_list = registry.list_models(include_filter=args.filter, task=Task.IMAGE_CLASSIFICATION)
199
+ task = Task.OBJECT_DETECTION if args.backbone is not None else Task.IMAGE_CLASSIFICATION
200
+ model_list = registry.list_models(include_filter=args.filter, task=task)
186
201
 
187
202
  else:
188
203
  model_list = birder.list_pretrained_models(args.filter)
@@ -234,11 +249,9 @@ def benchmark(args: argparse.Namespace) -> None:
234
249
  else:
235
250
  # Initialize model
236
251
  if args.plain is True:
237
- net = registry.net_factory(model_name, input_channels, 0, size=size)
238
- net.to(device)
239
- prepare_model(net)
252
+ net = init_plain_model(model_name, sample_shape, device, args)
240
253
  else:
241
- (net, _) = birder.load_pretrained_model(model_name, inference=True, device=device)
254
+ net, _ = birder.load_pretrained_model(model_name, inference=True, device=device)
242
255
  if args.size is not None:
243
256
  net.adjust_size(size)
244
257
 
@@ -247,7 +260,7 @@ def benchmark(args: argparse.Namespace) -> None:
247
260
  net = torch.compile(net)
248
261
 
249
262
  peak_memory = None
250
- (t_elapsed, batch_size) = throughput_benchmark(net, device, sample_shape, model_name, args)
263
+ t_elapsed, batch_size = throughput_benchmark(net, device, sample_shape, model_name, args)
251
264
  if t_elapsed < 0.0:
252
265
  continue
253
266
 
@@ -305,12 +318,18 @@ def get_args_parser() -> argparse.ArgumentParser:
305
318
  "--compile --suffix il-common --append\n"
306
319
  "python -m birder.scripts.benchmark --plain --models rdnet_t convnext_v1_tiny --bench-iter 50 --repeats 1 "
307
320
  "--gpu --size 416 --dry-run\n"
321
+ "python -m birder.scripts.benchmark --plain --models retinanet --backbone resnet_v1_50 --num-classes 91 "
322
+ "--size 640 --gpu --dry-run\n"
308
323
  ),
309
324
  formatter_class=cli.ArgumentHelpFormatter,
310
325
  )
311
326
  parser.add_argument("--filter", type=str, help="models to benchmark (fnmatch type filter)")
312
327
  parser.add_argument("--models", nargs="+", help="plain network names to benchmark")
313
328
  parser.add_argument("--plain", default=False, action="store_true", help="benchmark plain networks without weights")
329
+ parser.add_argument("--backbone", type=str, help="backbone name for plain detection benchmarks")
330
+ parser.add_argument(
331
+ "--num-classes", type=int, default=0, metavar="N", help="number of classes for plain benchmarks"
332
+ )
314
333
  parser.add_argument("--compile", default=False, action="store_true", help="enable compilation")
315
334
  parser.add_argument(
316
335
  "--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference"
@@ -353,6 +372,12 @@ def validate_args(args: argparse.Namespace) -> None:
353
372
  raise cli.ValidationError("--memory cannot be used with --compile")
354
373
  if args.plain is False and args.models is not None:
355
374
  raise cli.ValidationError("--models can only be used with --plain")
375
+ if args.backbone is not None and args.plain is False:
376
+ raise cli.ValidationError("--backbone can only be used with --plain")
377
+ if args.backbone is not None and registry.exists(args.backbone, net_type=DetectorBackbone) is False:
378
+ raise cli.ValidationError(
379
+ f"--backbone {args.backbone} not supported, see list-models tool for available options"
380
+ )
356
381
 
357
382
 
358
383
  def args_from_dict(**kwargs: Any) -> argparse.Namespace:
@@ -37,7 +37,7 @@ def evaluate(args: argparse.Namespace) -> None:
37
37
  amp_dtype: torch.dtype = getattr(torch, args.amp_dtype)
38
38
  model_list = birder.list_pretrained_models(args.filter)
39
39
  for model_name in model_list:
40
- (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(
40
+ net, (class_to_idx, signature, rgb_stats, *_) = birder.load_pretrained_model(
41
41
  model_name, inference=True, device=device, dtype=model_dtype
42
42
  )
43
43
  if args.parallel is True and torch.cuda.device_count() > 1:
birder/scripts/predict.py CHANGED
@@ -204,7 +204,7 @@ def predict(args: argparse.Namespace) -> None:
204
204
  raise RuntimeError("'pip install torchao' to load quantization operators") from exc
205
205
 
206
206
  network_name = lib.get_network_name(args.network, tag=args.tag)
207
- (net, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
207
+ net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
208
208
  device,
209
209
  args.network,
210
210
  config=args.model_config,
@@ -261,11 +261,11 @@ def predict(args: argparse.Namespace) -> None:
261
261
  if args.wds is True:
262
262
  wds_path: str | list[str]
263
263
  if args.wds_info is not None:
264
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
264
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
265
265
  if args.wds_size is not None:
266
266
  dataset_size = args.wds_size
267
267
  else:
268
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
268
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
269
269
 
270
270
  num_samples = dataset_size
271
271
  dataset = make_wds_dataset(
@@ -60,7 +60,7 @@ def predict(args: argparse.Namespace) -> None:
60
60
  network_name = lib.get_detection_network_name(
61
61
  args.network, tag=args.tag, backbone=args.backbone, backbone_tag=args.backbone_tag
62
62
  )
63
- (net, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_detection_model(
63
+ net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_detection_model(
64
64
  device,
65
65
  args.network,
66
66
  config=args.model_config,
@@ -197,7 +197,7 @@ def predict(args: argparse.Namespace) -> None:
197
197
  # Inference
198
198
  tic = time.time()
199
199
  with torch.inference_mode():
200
- (sample_paths, detections, targets) = infer_dataloader(
200
+ sample_paths, detections, targets = infer_dataloader(
201
201
  device,
202
202
  net,
203
203
  inference_loader,
birder/scripts/train.py CHANGED
@@ -7,6 +7,7 @@ import time
7
7
  from collections.abc import Iterator
8
8
  from pathlib import Path
9
9
  from typing import Any
10
+ from typing import Optional
10
11
 
11
12
  import matplotlib.pyplot as plt
12
13
  import numpy as np
@@ -52,7 +53,7 @@ def train(args: argparse.Namespace) -> None:
52
53
  #
53
54
  # Initialize
54
55
  #
55
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
56
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
56
57
 
57
58
  if args.size is None:
58
59
  args.size = registry.get_default_size(args.network)
@@ -77,15 +78,15 @@ def train(args: argparse.Namespace) -> None:
77
78
  training_wds_path: str | list[str]
78
79
  val_wds_path: str | list[str]
79
80
  if args.wds_info is not None:
80
- (training_wds_path, training_size) = wds_args_from_info(args.wds_info, args.wds_training_split)
81
- (val_wds_path, val_size) = wds_args_from_info(args.wds_info, args.wds_val_split)
81
+ training_wds_path, training_size = wds_args_from_info(args.wds_info, args.wds_training_split)
82
+ val_wds_path, val_size = wds_args_from_info(args.wds_info, args.wds_val_split)
82
83
  if args.wds_train_size is not None:
83
84
  training_size = args.wds_train_size
84
85
  if args.wds_val_size is not None:
85
86
  val_size = args.wds_val_size
86
87
  else:
87
- (training_wds_path, training_size) = prepare_wds_args(args.data_path, args.wds_train_size, device)
88
- (val_wds_path, val_size) = prepare_wds_args(args.val_path, args.wds_val_size, device)
88
+ training_wds_path, training_size = prepare_wds_args(args.data_path, args.wds_train_size, device)
89
+ val_wds_path, val_size = prepare_wds_args(args.val_path, args.wds_val_size, device)
89
90
 
90
91
  training_dataset = make_wds_dataset(
91
92
  training_wds_path,
@@ -149,7 +150,7 @@ def train(args: argparse.Namespace) -> None:
149
150
 
150
151
  # Data loaders and samplers
151
152
  virtual_epoch_mode = args.steps_per_epoch is not None
152
- (train_sampler, validation_sampler) = training_utils.get_samplers(
153
+ train_sampler, validation_sampler = training_utils.get_samplers(
153
154
  args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
154
155
  )
155
156
 
@@ -231,7 +232,7 @@ def train(args: argparse.Namespace) -> None:
231
232
 
232
233
  if args.resume_epoch is not None:
233
234
  begin_epoch = args.resume_epoch + 1
234
- (net, class_to_idx_saved, training_states) = fs_ops.load_checkpoint(
235
+ net, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
235
236
  device,
236
237
  args.network,
237
238
  config=args.model_config,
@@ -247,7 +248,7 @@ def train(args: argparse.Namespace) -> None:
247
248
 
248
249
  elif args.pretrained is True:
249
250
  fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
250
- (net, class_to_idx_saved, training_states) = fs_ops.load_checkpoint(
251
+ net, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
251
252
  device,
252
253
  args.network,
253
254
  config=args.model_config,
@@ -262,7 +263,7 @@ def train(args: argparse.Namespace) -> None:
262
263
  assert class_to_idx == class_to_idx_saved
263
264
 
264
265
  else:
265
- net = registry.net_factory(args.network, sample_shape[1], num_outputs, config=args.model_config, size=args.size)
266
+ net = registry.net_factory(args.network, num_outputs, sample_shape[1], config=args.model_config, size=args.size)
266
267
  training_states = fs_ops.TrainingStates.empty()
267
268
 
268
269
  net.to(device, dtype=model_dtype)
@@ -328,7 +329,7 @@ def train(args: argparse.Namespace) -> None:
328
329
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
329
330
 
330
331
  # Gradient scaler and AMP related tasks
331
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
332
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
332
333
 
333
334
  # Load states
334
335
  if args.load_states is True:
@@ -474,16 +475,32 @@ def train(args: argparse.Namespace) -> None:
474
475
  if virtual_epoch_mode is True:
475
476
  train_iter = iter(training_loader)
476
477
 
478
+ top_k = args.top_k
477
479
  running_loss = training_utils.SmoothedValue(window_size=64)
478
480
  running_val_loss = training_utils.SmoothedValue()
479
481
  train_accuracy = training_utils.SmoothedValue(window_size=64)
480
482
  val_accuracy = training_utils.SmoothedValue()
483
+ train_topk: Optional[training_utils.SmoothedValue] = None
484
+ val_topk: Optional[training_utils.SmoothedValue] = None
485
+ if top_k is not None:
486
+ train_topk = training_utils.SmoothedValue(window_size=64)
487
+ val_topk = training_utils.SmoothedValue()
481
488
 
482
489
  logger.info(f"Starting training with learning rate of {last_lr}")
483
490
  for epoch in range(begin_epoch, args.stop_epoch):
484
491
  tic = time.time()
485
492
  net.train()
486
493
 
494
+ # Clear metrics
495
+ running_loss.clear()
496
+ running_val_loss.clear()
497
+ train_accuracy.clear()
498
+ val_accuracy.clear()
499
+ if train_topk is not None:
500
+ train_topk.clear()
501
+ if val_topk is not None:
502
+ val_topk.clear()
503
+
487
504
  if args.distributed is True or virtual_epoch_mode is True:
488
505
  train_sampler.set_epoch(epoch)
489
506
 
@@ -565,6 +582,9 @@ def train(args: argparse.Namespace) -> None:
565
582
  targets = targets.argmax(dim=1)
566
583
 
567
584
  train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
585
+ if train_topk is not None:
586
+ topk_val = training_utils.topk_accuracy(targets, outputs.detach(), topk=(top_k,))[0]
587
+ train_topk.update(topk_val)
568
588
 
569
589
  # Write statistics
570
590
  if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
@@ -583,6 +603,9 @@ def train(args: argparse.Namespace) -> None:
583
603
 
584
604
  running_loss.synchronize_between_processes(device)
585
605
  train_accuracy.synchronize_between_processes(device)
606
+ if train_topk is not None:
607
+ train_topk.synchronize_between_processes(device)
608
+
586
609
  with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
587
610
  log.info(
588
611
  f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
@@ -597,8 +620,17 @@ def train(args: argparse.Namespace) -> None:
597
620
  f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
598
621
  f"Accuracy: {train_accuracy.avg:.4f}"
599
622
  )
623
+ if train_topk is not None:
624
+ log.info(
625
+ f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
626
+ f"Accuracy@{top_k}: {train_topk.avg:.4f}"
627
+ )
600
628
 
601
629
  if training_utils.is_local_primary(args) is True:
630
+ performance = {"training_accuracy": train_accuracy.avg}
631
+ if train_topk is not None:
632
+ performance[f"training_accuracy@{top_k}"] = train_topk.avg
633
+
602
634
  summary_writer.add_scalars(
603
635
  "loss",
604
636
  {"training": running_loss.avg},
@@ -606,7 +638,7 @@ def train(args: argparse.Namespace) -> None:
606
638
  )
607
639
  summary_writer.add_scalars(
608
640
  "performance",
609
- {"training_accuracy": train_accuracy.avg},
641
+ performance,
610
642
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
611
643
  )
612
644
 
@@ -618,6 +650,8 @@ def train(args: argparse.Namespace) -> None:
618
650
  # Epoch training metrics
619
651
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
620
652
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy: {train_accuracy.global_avg:.4f}")
653
+ if train_topk is not None:
654
+ logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy@{top_k}: {train_topk.global_avg:.4f}")
621
655
 
622
656
  # Validation
623
657
  eval_model.eval()
@@ -649,6 +683,9 @@ def train(args: argparse.Namespace) -> None:
649
683
  # Statistics
650
684
  running_val_loss.update(val_loss.detach())
651
685
  val_accuracy.update(training_utils.accuracy(targets, outputs), n=outputs.size(0))
686
+ if val_topk is not None:
687
+ topk_val = training_utils.topk_accuracy(targets, outputs, topk=(top_k,))[0]
688
+ val_topk.update(topk_val, n=outputs.size(0))
652
689
 
653
690
  # Update progress bar
654
691
  progress.update(n=batch_size * args.world_size)
@@ -666,19 +703,30 @@ def train(args: argparse.Namespace) -> None:
666
703
 
667
704
  running_val_loss.synchronize_between_processes(device)
668
705
  val_accuracy.synchronize_between_processes(device)
706
+ if val_topk is not None:
707
+ val_topk.synchronize_between_processes(device)
708
+
669
709
  epoch_val_loss = running_val_loss.global_avg
670
710
  epoch_val_accuracy = val_accuracy.global_avg
711
+ if val_topk is not None:
712
+ epoch_val_topk = val_topk.global_avg
713
+ else:
714
+ epoch_val_topk = None
671
715
 
672
716
  # Write statistics
673
717
  if training_utils.is_local_primary(args) is True:
674
718
  summary_writer.add_scalars("loss", {"validation": epoch_val_loss}, epoch * epoch_samples)
675
- summary_writer.add_scalars(
676
- "performance", {"validation_accuracy": epoch_val_accuracy}, epoch * epoch_samples
677
- )
719
+ performance = {"validation_accuracy": epoch_val_accuracy}
720
+ if epoch_val_topk is not None:
721
+ performance[f"validation_accuracy@{top_k}"] = epoch_val_topk
722
+
723
+ summary_writer.add_scalars("performance", performance, epoch * epoch_samples)
678
724
 
679
725
  # Epoch validation metrics
680
726
  logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_loss: {epoch_val_loss:.4f}")
681
727
  logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy: {epoch_val_accuracy:.4f}")
728
+ if epoch_val_topk is not None:
729
+ logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy@{top_k}: {epoch_val_topk:.4f}")
682
730
 
683
731
  # Learning rate scheduler update
684
732
  if step_update is False:
@@ -849,7 +897,7 @@ def get_args_parser() -> argparse.ArgumentParser:
849
897
  training_cli.add_compile_args(parser)
850
898
  training_cli.add_checkpoint_args(parser, default_save_frequency=5, pretrained=True)
851
899
  training_cli.add_distributed_args(parser)
852
- training_cli.add_logging_and_debug_args(parser)
900
+ training_cli.add_logging_and_debug_args(parser, classification=True)
853
901
  training_cli.add_training_data_args(parser)
854
902
 
855
903
  return parser
@@ -69,7 +69,7 @@ def train(args: argparse.Namespace) -> None:
69
69
  #
70
70
  # Initialize
71
71
  #
72
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
72
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
73
73
 
74
74
  if args.size is None:
75
75
  args.size = registry.get_default_size(args.network)
@@ -92,11 +92,11 @@ def train(args: argparse.Namespace) -> None:
92
92
  elif args.wds is True:
93
93
  wds_path: str | list[str]
94
94
  if args.wds_info is not None:
95
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
95
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
96
96
  if args.wds_size is not None:
97
97
  dataset_size = args.wds_size
98
98
  else:
99
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
99
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
100
100
 
101
101
  training_dataset = make_wds_dataset(
102
102
  wds_path,
@@ -126,7 +126,7 @@ def train(args: argparse.Namespace) -> None:
126
126
 
127
127
  # Data loaders and samplers
128
128
  virtual_epoch_mode = args.steps_per_epoch is not None
129
- (train_sampler, _) = training_utils.get_samplers(
129
+ train_sampler, _ = training_utils.get_samplers(
130
130
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
131
131
  )
132
132
 
@@ -189,12 +189,12 @@ def train(args: argparse.Namespace) -> None:
189
189
 
190
190
  network_name = get_mim_network_name("barlow_twins", encoder=args.network, tag=args.tag)
191
191
 
192
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
192
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
193
193
  net = BarlowTwins(backbone, config={"projector_sizes": args.projector_dims, "off_lambda": args.off_lambda})
194
194
 
195
195
  if args.resume_epoch is not None:
196
196
  begin_epoch = args.resume_epoch + 1
197
- (net, training_states) = fs_ops.load_simple_checkpoint(
197
+ net, training_states = fs_ops.load_simple_checkpoint(
198
198
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
199
199
  )
200
200
 
@@ -253,7 +253,7 @@ def train(args: argparse.Namespace) -> None:
253
253
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
254
254
 
255
255
  # Gradient scaler and AMP related tasks
256
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
256
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
257
257
 
258
258
  # Load states
259
259
  if args.load_states is True:
@@ -365,6 +365,9 @@ def train(args: argparse.Namespace) -> None:
365
365
  tic = time.time()
366
366
  net.train()
367
367
 
368
+ # Clear metrics
369
+ running_loss.clear()
370
+
368
371
  if args.distributed is True or virtual_epoch_mode is True:
369
372
  train_sampler.set_epoch(epoch)
370
373
 
@@ -70,7 +70,7 @@ def train(args: argparse.Namespace) -> None:
70
70
  #
71
71
  # Initialize
72
72
  #
73
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
73
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
74
74
 
75
75
  if args.size is None:
76
76
  # Prefer mim size over encoder default size
@@ -94,11 +94,11 @@ def train(args: argparse.Namespace) -> None:
94
94
  elif args.wds is True:
95
95
  wds_path: str | list[str]
96
96
  if args.wds_info is not None:
97
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
97
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
98
98
  if args.wds_size is not None:
99
99
  dataset_size = args.wds_size
100
100
  else:
101
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
101
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
102
102
 
103
103
  training_dataset = make_wds_dataset(
104
104
  wds_path,
@@ -128,7 +128,7 @@ def train(args: argparse.Namespace) -> None:
128
128
 
129
129
  # Data loaders and samplers
130
130
  virtual_epoch_mode = args.steps_per_epoch is not None
131
- (train_sampler, _) = training_utils.get_samplers(
131
+ train_sampler, _ = training_utils.get_samplers(
132
132
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
133
133
  )
134
134
 
@@ -191,7 +191,7 @@ def train(args: argparse.Namespace) -> None:
191
191
 
192
192
  network_name = get_mim_network_name("byol", encoder=args.network, tag=args.tag)
193
193
 
194
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
194
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
195
195
  net = BYOL(
196
196
  backbone,
197
197
  config={
@@ -202,7 +202,7 @@ def train(args: argparse.Namespace) -> None:
202
202
 
203
203
  if args.resume_epoch is not None:
204
204
  begin_epoch = args.resume_epoch + 1
205
- (net, training_states) = fs_ops.load_simple_checkpoint(
205
+ net, training_states = fs_ops.load_simple_checkpoint(
206
206
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
207
207
  )
208
208
 
@@ -265,7 +265,7 @@ def train(args: argparse.Namespace) -> None:
265
265
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
266
266
 
267
267
  # Gradient scaler and AMP related tasks
268
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
268
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
269
269
 
270
270
  # Load states
271
271
  if args.load_states is True:
@@ -377,6 +377,9 @@ def train(args: argparse.Namespace) -> None:
377
377
  tic = time.time()
378
378
  net.train()
379
379
 
380
+ # Clear metrics
381
+ running_loss.clear()
382
+
380
383
  if args.distributed is True or virtual_epoch_mode is True:
381
384
  train_sampler.set_epoch(epoch)
382
385
 
@@ -79,7 +79,7 @@ def train(args: argparse.Namespace) -> None:
79
79
  #
80
80
  # Initialize
81
81
  #
82
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
82
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
83
83
 
84
84
  if args.size is None:
85
85
  args.size = registry.get_default_size(args.network)
@@ -108,8 +108,8 @@ def train(args: argparse.Namespace) -> None:
108
108
 
109
109
  network_name = get_mim_network_name("capi", encoder=args.network, tag=args.tag)
110
110
 
111
- student_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
112
- teacher_backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
111
+ student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
112
+ teacher_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
113
113
 
114
114
  teacher_backbone.load_state_dict(student_backbone.state_dict())
115
115
 
@@ -144,7 +144,7 @@ def train(args: argparse.Namespace) -> None:
144
144
 
145
145
  if args.resume_epoch is not None:
146
146
  begin_epoch = args.resume_epoch + 1
147
- (net, training_states) = fs_ops.load_simple_checkpoint(
147
+ net, training_states = fs_ops.load_simple_checkpoint(
148
148
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
149
149
  )
150
150
  student = net["student"]
@@ -194,11 +194,11 @@ def train(args: argparse.Namespace) -> None:
194
194
  elif args.wds is True:
195
195
  wds_path: str | list[str]
196
196
  if args.wds_info is not None:
197
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
197
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
198
198
  if args.wds_size is not None:
199
199
  dataset_size = args.wds_size
200
200
  else:
201
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
201
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
202
202
 
203
203
  training_dataset = make_wds_dataset(
204
204
  wds_path,
@@ -224,7 +224,7 @@ def train(args: argparse.Namespace) -> None:
224
224
 
225
225
  # Data loaders and samplers
226
226
  virtual_epoch_mode = args.steps_per_epoch is not None
227
- (train_sampler, _) = training_utils.get_samplers(
227
+ train_sampler, _ = training_utils.get_samplers(
228
228
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
229
229
  )
230
230
 
@@ -326,8 +326,8 @@ def train(args: argparse.Namespace) -> None:
326
326
  student_temp = 0.12
327
327
 
328
328
  # Gradient scaler and AMP related tasks
329
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
330
- (clustering_scaler, _) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
329
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
330
+ clustering_scaler, _ = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
331
331
 
332
332
  # Load states
333
333
  if args.load_states is True:
@@ -453,6 +453,11 @@ def train(args: argparse.Namespace) -> None:
453
453
  tic = time.time()
454
454
  net.train()
455
455
 
456
+ # Clear metrics
457
+ running_loss.clear()
458
+ running_clustering_loss.clear()
459
+ running_target_entropy.clear()
460
+
456
461
  if args.sinkhorn_queue_size is not None:
457
462
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs
458
463
  teacher_without_ddp.head.set_queue_active(queue_active)
@@ -499,7 +504,7 @@ def train(args: argparse.Namespace) -> None:
499
504
 
500
505
  # Forward, backward and optimize
501
506
  with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
502
- (selected_assignments, clustering_loss) = teacher(images, None, predict_indices)
507
+ selected_assignments, clustering_loss = teacher(images, None, predict_indices)
503
508
 
504
509
  if clustering_scaler is not None:
505
510
  clustering_scaler.scale(clustering_loss).backward()
@@ -69,7 +69,7 @@ def train(args: argparse.Namespace) -> None:
69
69
  #
70
70
  # Initialize
71
71
  #
72
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
72
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
73
73
 
74
74
  if args.size is None:
75
75
  # Prefer mim size over encoder default size
@@ -99,7 +99,7 @@ def train(args: argparse.Namespace) -> None:
99
99
 
100
100
  network_name = get_mim_network_name("data2vec", encoder=args.network, tag=args.tag)
101
101
 
102
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
102
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
103
103
  net = Data2Vec(
104
104
  backbone,
105
105
  config={
@@ -112,7 +112,7 @@ def train(args: argparse.Namespace) -> None:
112
112
 
113
113
  if args.resume_epoch is not None:
114
114
  begin_epoch = args.resume_epoch + 1
115
- (net, training_states) = fs_ops.load_simple_checkpoint(
115
+ net, training_states = fs_ops.load_simple_checkpoint(
116
116
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
117
117
  )
118
118
 
@@ -160,11 +160,11 @@ def train(args: argparse.Namespace) -> None:
160
160
  elif args.wds is True:
161
161
  wds_path: str | list[str]
162
162
  if args.wds_info is not None:
163
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
163
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
164
164
  if args.wds_size is not None:
165
165
  dataset_size = args.wds_size
166
166
  else:
167
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
167
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
168
168
 
169
169
  training_dataset = make_wds_dataset(
170
170
  wds_path,
@@ -190,7 +190,7 @@ def train(args: argparse.Namespace) -> None:
190
190
 
191
191
  # Data loaders and samplers
192
192
  virtual_epoch_mode = args.steps_per_epoch is not None
193
- (train_sampler, _) = training_utils.get_samplers(
193
+ train_sampler, _ = training_utils.get_samplers(
194
194
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
195
195
  )
196
196
 
@@ -279,7 +279,7 @@ def train(args: argparse.Namespace) -> None:
279
279
  )
280
280
 
281
281
  # Gradient scaler and AMP related tasks
282
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
282
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
283
283
 
284
284
  # Load states
285
285
  if args.load_states is True:
@@ -391,6 +391,9 @@ def train(args: argparse.Namespace) -> None:
391
391
  tic = time.time()
392
392
  net.train()
393
393
 
394
+ # Clear metrics
395
+ running_loss.clear()
396
+
394
397
  if args.distributed is True or virtual_epoch_mode is True:
395
398
  train_sampler.set_epoch(epoch)
396
399