birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -67,7 +67,7 @@ def train(args: argparse.Namespace) -> None:
67
67
  #
68
68
  # Initialize
69
69
  #
70
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
70
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
71
71
 
72
72
  if args.size is None:
73
73
  args.size = registry.get_default_size(args.network)
@@ -90,11 +90,11 @@ def train(args: argparse.Namespace) -> None:
90
90
  elif args.wds is True:
91
91
  wds_path: str | list[str]
92
92
  if args.wds_info is not None:
93
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
93
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
94
94
  if args.wds_size is not None:
95
95
  dataset_size = args.wds_size
96
96
  else:
97
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
97
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
98
98
 
99
99
  training_dataset = make_wds_dataset(
100
100
  wds_path,
@@ -124,7 +124,7 @@ def train(args: argparse.Namespace) -> None:
124
124
 
125
125
  # Data loaders and samplers
126
126
  virtual_epoch_mode = args.steps_per_epoch is not None
127
- (train_sampler, _) = training_utils.get_samplers(
127
+ train_sampler, _ = training_utils.get_samplers(
128
128
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
129
129
  )
130
130
 
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
187
187
 
188
188
  network_name = get_mim_network_name("simclr", encoder=args.network, tag=args.tag)
189
189
 
190
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
190
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
191
191
  net = SimCLR(
192
192
  backbone,
193
193
  config={
@@ -199,7 +199,7 @@ def train(args: argparse.Namespace) -> None:
199
199
 
200
200
  if args.resume_epoch is not None:
201
201
  begin_epoch = args.resume_epoch + 1
202
- (net, training_states) = fs_ops.load_simple_checkpoint(
202
+ net, training_states = fs_ops.load_simple_checkpoint(
203
203
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
204
204
  )
205
205
 
@@ -258,7 +258,7 @@ def train(args: argparse.Namespace) -> None:
258
258
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
259
259
 
260
260
  # Gradient scaler and AMP related tasks
261
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
261
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
262
262
 
263
263
  # Load states
264
264
  if args.load_states is True:
@@ -370,6 +370,9 @@ def train(args: argparse.Namespace) -> None:
370
370
  tic = time.time()
371
371
  net.train()
372
372
 
373
+ # Clear metrics
374
+ running_loss.clear()
375
+
373
376
  if args.distributed is True or virtual_epoch_mode is True:
374
377
  train_sampler.set_epoch(epoch)
375
378
 
@@ -70,7 +70,7 @@ def train(args: argparse.Namespace) -> None:
70
70
  #
71
71
  # Initialize
72
72
  #
73
- (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)
73
+ device, device_id, disable_tqdm = training_utils.init_training(args, logger)
74
74
 
75
75
  if args.size is None:
76
76
  args.size = registry.get_default_size(args.network)
@@ -93,11 +93,11 @@ def train(args: argparse.Namespace) -> None:
93
93
  elif args.wds is True:
94
94
  wds_path: str | list[str]
95
95
  if args.wds_info is not None:
96
- (wds_path, dataset_size) = wds_args_from_info(args.wds_info, args.wds_split)
96
+ wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
97
97
  if args.wds_size is not None:
98
98
  dataset_size = args.wds_size
99
99
  else:
100
- (wds_path, dataset_size) = prepare_wds_args(args.data_path[0], args.wds_size, device)
100
+ wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
101
101
 
102
102
  training_dataset = make_wds_dataset(
103
103
  wds_path,
@@ -127,7 +127,7 @@ def train(args: argparse.Namespace) -> None:
127
127
 
128
128
  # Data loaders and samplers
129
129
  virtual_epoch_mode = args.steps_per_epoch is not None
130
- (train_sampler, _) = training_utils.get_samplers(
130
+ train_sampler, _ = training_utils.get_samplers(
131
131
  args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
132
132
  )
133
133
 
@@ -190,7 +190,7 @@ def train(args: argparse.Namespace) -> None:
190
190
 
191
191
  network_name = get_mim_network_name("vicreg", encoder=args.network, tag=args.tag)
192
192
 
193
- backbone = registry.net_factory(args.network, sample_shape[1], 0, config=args.model_config, size=args.size)
193
+ backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
194
194
  net = VICReg(
195
195
  backbone,
196
196
  config={
@@ -205,7 +205,7 @@ def train(args: argparse.Namespace) -> None:
205
205
 
206
206
  if args.resume_epoch is not None:
207
207
  begin_epoch = args.resume_epoch + 1
208
- (net, training_states) = fs_ops.load_simple_checkpoint(
208
+ net, training_states = fs_ops.load_simple_checkpoint(
209
209
  device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
210
210
  )
211
211
 
@@ -264,7 +264,7 @@ def train(args: argparse.Namespace) -> None:
264
264
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)
265
265
 
266
266
  # Gradient scaler and AMP related tasks
267
- (scaler, amp_dtype) = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
267
+ scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
268
268
 
269
269
  # Load states
270
270
  if args.load_states is True:
@@ -376,6 +376,9 @@ def train(args: argparse.Namespace) -> None:
376
376
  tic = time.time()
377
377
  net.train()
378
378
 
379
+ # Clear metrics
380
+ running_loss.clear()
381
+
379
382
  if args.distributed is True or virtual_epoch_mode is True:
380
383
  train_sampler.set_epoch(epoch)
381
384
 
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
29
29
  def _load_model_and_transform(
30
30
  args: argparse.Namespace, device: torch.device
31
31
  ) -> tuple[torch.nn.Module, dict[str, int], RGBType, Callable[..., torch.Tensor], Callable[..., torch.Tensor]]:
32
- (net, model_info) = fs_ops.load_model(
32
+ net, model_info = fs_ops.load_model(
33
33
  device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
34
34
  )
35
35
 
@@ -105,8 +105,8 @@ def _display_results(
105
105
  success: Optional[bool],
106
106
  result: AttackResult,
107
107
  ) -> None:
108
- (orig_label, orig_prob) = original_pred
109
- (adv_label, adv_prob) = adv_pred
108
+ orig_label, orig_prob = original_pred
109
+ adv_label, adv_prob = adv_pred
110
110
 
111
111
  # Log results
112
112
  logger.info(f"Original: {orig_label} ({orig_prob * 100:.2f}%)")
@@ -139,7 +139,7 @@ def run_attack(args: argparse.Namespace) -> None:
139
139
 
140
140
  logger.info(f"Using device {device}")
141
141
 
142
- (net, class_to_idx, rgb_stats, transform, reverse_transform) = _load_model_and_transform(args, device)
142
+ net, class_to_idx, rgb_stats, transform, reverse_transform = _load_model_and_transform(args, device)
143
143
  label_names = [name for name, _idx in sorted(class_to_idx.items(), key=lambda item: item[1])]
144
144
  img = Image.open(args.image_path)
145
145
  input_tensor = transform(img).unsqueeze(dim=0).to(device)
@@ -92,7 +92,7 @@ def _load_coco_boxes(
92
92
  stats["missing_images"] += 1
93
93
  continue
94
94
 
95
- (img_w, img_h, file_name) = images[image_id]
95
+ img_w, img_h, file_name = images[image_id]
96
96
  if file_name in ignore_list:
97
97
  stats["ignored_images"] += 1
98
98
  continue
@@ -219,7 +219,7 @@ def _validate_args(
219
219
  output_format = args.format if args.format is not None else (preset["format"] if preset else None)
220
220
  if num_scales is None or num_anchors is None or output_format is None:
221
221
  raise cli.ValidationError(
222
- "Missing configuration. Provide --num-scales, --num-anchors, and --format or use a --preset"
222
+ "Missing configuration. Provide --num-scales, --num-anchors and --format or use a --preset"
223
223
  )
224
224
  if num_scales < 1:
225
225
  raise cli.ValidationError("--num-scales must be >= 1")
@@ -244,10 +244,10 @@ def _validate_args(
244
244
 
245
245
  # pylint: disable=too-many-locals
246
246
  def auto_anchors(args: argparse.Namespace) -> None:
247
- (size, num_scales, num_anchors, output_format, strides) = _validate_args(args)
247
+ size, num_scales, num_anchors, output_format, strides = _validate_args(args)
248
248
 
249
249
  ignore_list = _load_ignore_list(args.ignore_file)
250
- (boxes, stats) = _load_coco_boxes(
250
+ boxes, stats = _load_coco_boxes(
251
251
  args.coco_json_path, size, ignore_list, args.min_size, ignore_crowd=not args.include_crowd
252
252
  )
253
253
 
@@ -262,7 +262,7 @@ def auto_anchors(args: argparse.Namespace) -> None:
262
262
  f"missing_size={stats['missing_size']}, too_small={stats['too_small']}"
263
263
  )
264
264
 
265
- (anchors, _assignments) = _kmeans_anchors(boxes, num_anchors, args.seed, args.max_iter)
265
+ anchors, _assignments = _kmeans_anchors(boxes, num_anchors, args.seed, args.max_iter)
266
266
  areas = anchors.prod(dim=1)
267
267
  anchors = anchors[torch.argsort(areas)]
268
268
  anchors_per_scale = num_anchors // num_scales
birder/tools/avg_model.py CHANGED
@@ -44,7 +44,7 @@ def avg_models(
44
44
  num_classes = lib.get_num_labels_from_signature(signature)
45
45
  size = lib.get_size_from_signature(signature)
46
46
 
47
- net = registry.net_factory(network, input_channels, num_classes, size=size)
47
+ net = registry.net_factory(network, num_classes, input_channels, size=size)
48
48
  if reparameterized is True:
49
49
  net.reparameterize_model()
50
50
 
@@ -74,6 +74,7 @@ def onnx_export(
74
74
  net: torch.nn.Module,
75
75
  signature: SignatureType | DetectionSignatureType,
76
76
  class_to_idx: dict[str, int],
77
+ rgb_stats: RGBType,
77
78
  model_path: str | Path,
78
79
  dynamo: bool,
79
80
  trace: bool,
@@ -117,9 +118,19 @@ def onnx_export(
117
118
 
118
119
  signature["inputs"][0]["data_shape"][0] = 0
119
120
 
120
- logger.info("Saving class to index json...")
121
- with open(f"{model_path}_class_to_idx.json", "w", encoding="utf-8") as handle:
122
- json.dump(class_to_idx, handle, indent=2)
121
+ logger.info("Saving model data json...")
122
+ with open(f"{model_path}_data.json", "w", encoding="utf-8") as handle:
123
+ json.dump(
124
+ {
125
+ "birder_version": __version__,
126
+ "task": net.task,
127
+ "class_to_idx": class_to_idx,
128
+ "signature": signature,
129
+ "rgb_stats": rgb_stats,
130
+ },
131
+ handle,
132
+ indent=2,
133
+ )
123
134
 
124
135
  # Test exported model
125
136
  onnx_model = onnx.load(str(model_path))
@@ -238,7 +249,7 @@ def main(args: argparse.Namespace) -> None:
238
249
  signature: SignatureType | DetectionSignatureType
239
250
  backbone_custom_config = None
240
251
  if args.backbone is None:
241
- (net, (class_to_idx, signature, rgb_stats, custom_config)) = fs_ops.load_model(
252
+ net, (class_to_idx, signature, rgb_stats, custom_config) = fs_ops.load_model(
242
253
  device,
243
254
  args.network,
244
255
  config=args.model_config,
@@ -251,22 +262,20 @@ def main(args: argparse.Namespace) -> None:
251
262
  network_name = lib.get_network_name(args.network, tag=args.tag)
252
263
 
253
264
  else:
254
- (net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config)) = (
255
- fs_ops.load_detection_model(
256
- device,
257
- args.network,
258
- config=args.model_config,
259
- tag=args.tag,
260
- reparameterized=args.reparameterized,
261
- backbone=args.backbone,
262
- backbone_config=args.backbone_model_config,
263
- backbone_tag=args.backbone_tag,
264
- backbone_reparameterized=args.backbone_reparameterized,
265
- epoch=args.epoch,
266
- new_size=args.resize,
267
- inference=True,
268
- export_mode=True,
269
- )
265
+ net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config) = fs_ops.load_detection_model(
266
+ device,
267
+ args.network,
268
+ config=args.model_config,
269
+ tag=args.tag,
270
+ reparameterized=args.reparameterized,
271
+ backbone=args.backbone,
272
+ backbone_config=args.backbone_model_config,
273
+ backbone_tag=args.backbone_tag,
274
+ backbone_reparameterized=args.backbone_reparameterized,
275
+ epoch=args.epoch,
276
+ new_size=args.resize,
277
+ inference=True,
278
+ export_mode=True,
270
279
  )
271
280
  network_name = lib.get_detection_network_name(
272
281
  args.network, tag=args.tag, backbone=args.backbone, backbone_tag=args.backbone_tag
@@ -407,8 +416,7 @@ def main(args: argparse.Namespace) -> None:
407
416
  )
408
417
 
409
418
  elif args.onnx is True or args.onnx_dynamo is True:
410
- config_export(net, signature, rgb_stats, model_path)
411
- onnx_export(net, signature, class_to_idx, model_path, args.onnx_dynamo, args.trace)
419
+ onnx_export(net, signature, class_to_idx, rgb_stats, model_path, args.onnx_dynamo, args.trace)
412
420
 
413
421
  elif args.config is True:
414
422
  config_export(net, signature, rgb_stats, model_path)
@@ -239,7 +239,7 @@ def main(args: argparse.Namespace) -> None:
239
239
  logger.warning("Cannot compare confusion matrix, processing only the first file")
240
240
 
241
241
  results = next(iter(results_dict.values()))
242
- (cnf_matrix, label_names) = confusion_matrix_data(
242
+ cnf_matrix, label_names = confusion_matrix_data(
243
243
  results, args.cnf_score_threshold, args.cnf_iou_threshold, args.classes, args.cnf_errors_only
244
244
  )
245
245
  title = f"Confusion matrix (score >= {args.cnf_score_threshold:.2f}, IoU >= {args.cnf_iou_threshold:.2f})"
@@ -52,7 +52,7 @@ def main(args: argparse.Namespace) -> None:
52
52
  )
53
53
  raise SystemExit(1)
54
54
 
55
- (model_file, url) = get_pretrained_model_url(args.model_name, args.format)
55
+ model_file, url = get_pretrained_model_url(args.model_name, args.format)
56
56
  dst = settings.MODELS_DIR.joinpath(model_file)
57
57
  if dst.exists() is True and args.force is False:
58
58
  logger.warning(f"File {model_file} already exists... aborting")
@@ -58,7 +58,7 @@ def main(args: argparse.Namespace) -> None:
58
58
  signature_list = []
59
59
  rgb_stats_list = []
60
60
  for network in args.networks:
61
- (net, model_info) = fs_ops.load_model(device, network, inference=True, pts=args.pts, pt2=args.pt2)
61
+ net, model_info = fs_ops.load_model(device, network, inference=True, pts=args.pts, pt2=args.pt2)
62
62
  nets.append(net)
63
63
  class_to_idx_list.append(model_info.class_to_idx)
64
64
  signature_list.append(model_info.signature)
@@ -126,6 +126,14 @@ def set_parser(subparsers: Any) -> None:
126
126
  formatter_class=cli.ArgumentHelpFormatter,
127
127
  )
128
128
  subparser.add_argument("-n", "--network", type=str, required=True, help="the neural network to use")
129
+ subparser.add_argument(
130
+ "--model-config",
131
+ action=cli.FlexibleDictAction,
132
+ help=(
133
+ "override the model default configuration, accepts key-value pairs or JSON "
134
+ "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
135
+ ),
136
+ )
129
137
  subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint to load")
130
138
  subparser.add_argument("-t", "--tag", type=str, help="model tag (from the training phase)")
131
139
  subparser.add_argument(
@@ -145,7 +153,7 @@ def set_parser(subparsers: Any) -> None:
145
153
  subparser.add_argument(
146
154
  "--target",
147
155
  type=str,
148
- help="target class, leave empty to use predicted class (gradcam, guided-backprop, and transformer-attribution)",
156
+ help="target class, leave empty to use predicted class (gradcam, guided-backprop and transformer-attribution)",
149
157
  )
150
158
  subparser.add_argument("--block-name", type=str, default="body", help="target block (gradcam only)")
151
159
  subparser.add_argument(
@@ -203,9 +211,10 @@ def main(args: argparse.Namespace) -> None:
203
211
 
204
212
  logger.info(f"Using device {device}")
205
213
 
206
- (net, model_info) = fs_ops.load_model(
214
+ net, model_info = fs_ops.load_model(
207
215
  device,
208
216
  args.network,
217
+ config=args.model_config,
209
218
  tag=args.tag,
210
219
  epoch=args.epoch,
211
220
  new_size=args.size,
@@ -25,8 +25,8 @@ def _create_annotation(
25
25
  annotation["image_id"] = image_id
26
26
 
27
27
  # Bounding box in (x, y, w, h) format
28
- (x0, y0) = points[0]
29
- (x1, y1) = points[1]
28
+ x0, y0 = points[0]
29
+ x1, y1 = points[1]
30
30
  x = min(x0, x1)
31
31
  y = min(y0, y1)
32
32
  w = abs(x0 - x1)
@@ -73,7 +73,7 @@ def main(args: argparse.Namespace) -> None:
73
73
  signature: SignatureType | DetectionSignatureType
74
74
  backbone_custom_config = None
75
75
  if args.backbone is None:
76
- (net, (class_to_idx, signature, rgb_stats, custom_config)) = fs_ops.load_model(
76
+ net, (class_to_idx, signature, rgb_stats, custom_config) = fs_ops.load_model(
77
77
  device,
78
78
  args.network,
79
79
  tag=args.tag,
@@ -86,19 +86,17 @@ def main(args: argparse.Namespace) -> None:
86
86
  )
87
87
 
88
88
  else:
89
- (net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config)) = (
90
- fs_ops.load_detection_model(
91
- device,
92
- args.network,
93
- tag=args.tag,
94
- backbone=args.backbone,
95
- backbone_tag=args.backbone_tag,
96
- epoch=args.epoch,
97
- inference=True,
98
- pts=args.pts,
99
- pt2=args.pt2,
100
- st=args.st,
101
- )
89
+ net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config) = fs_ops.load_detection_model(
90
+ device,
91
+ args.network,
92
+ tag=args.tag,
93
+ backbone=args.backbone,
94
+ backbone_tag=args.backbone_tag,
95
+ epoch=args.epoch,
96
+ inference=True,
97
+ pts=args.pts,
98
+ pt2=args.pt2,
99
+ st=args.st,
102
100
  )
103
101
 
104
102
  model_info = get_model_info(net)
birder/tools/pack.py CHANGED
@@ -114,7 +114,7 @@ def read_worker(q_in: Any, q_out: Any, error_event: Any, size: Optional[int], fi
114
114
  break
115
115
 
116
116
  try:
117
- (idx, path, target) = deq
117
+ idx, path, target = deq
118
118
  if size is None:
119
119
  suffix = Path(path).suffix[1:]
120
120
  if file_format != suffix:
@@ -172,7 +172,7 @@ def wds_write_worker(
172
172
  while more:
173
173
  deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
174
174
  if deq is not None:
175
- (idx, sample, suffix, target) = deq
175
+ idx, sample, suffix, target = deq
176
176
  buf[idx] = (sample, suffix, target)
177
177
 
178
178
  else:
@@ -180,7 +180,7 @@ def wds_write_worker(
180
180
 
181
181
  # Ensures ordered write
182
182
  while count in buf:
183
- (sample, suffix, target) = buf[count]
183
+ sample, suffix, target = buf[count]
184
184
  del buf[count]
185
185
 
186
186
  if args.no_cls is True:
@@ -238,7 +238,7 @@ def directory_write_worker(
238
238
  while more:
239
239
  deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
240
240
  if deq is not None:
241
- (idx, sample, suffix, target) = deq
241
+ idx, sample, suffix, target = deq
242
242
  buf[idx] = (sample, suffix, target)
243
243
 
244
244
  else:
@@ -246,7 +246,7 @@ def directory_write_worker(
246
246
 
247
247
  # Ensures ordered write
248
248
  while count in buf:
249
- (sample, suffix, target) = buf[count]
249
+ sample, suffix, target = buf[count]
250
250
  del buf[count]
251
251
  with open(
252
252
  pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb"
@@ -274,7 +274,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
274
274
  if len(line.strip()) == 0 or line.strip().startswith("#") is True:
275
275
  continue
276
276
 
277
- (data_path, r) = line.split()
277
+ data_path, r = line.split()
278
278
  data_path = os.path.expanduser(data_path)
279
279
  repeats = int(r)
280
280
  for _ in range(repeats):
@@ -391,7 +391,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
391
391
  cleanup_processes()
392
392
  raise RuntimeError()
393
393
 
394
- (path, target) = dataset[sample_idx]
394
+ path, target = dataset[sample_idx]
395
395
 
396
396
  while True:
397
397
  try:
@@ -430,7 +430,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
430
430
  raise RuntimeError()
431
431
 
432
432
  if args.type == "wds":
433
- (wds_path, num_shards) = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
433
+ wds_path, num_shards = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
434
434
  logger.info(f"Packed {len(dataset):,} samples into {num_shards} shards at {wds_path}")
435
435
  elif args.type == "directory":
436
436
  logger.info(f"Packed {len(dataset):,} samples")
@@ -1,7 +1,9 @@
1
1
  import argparse
2
2
  import itertools
3
+ import json
3
4
  import logging
4
5
  import time
6
+ from pathlib import Path
5
7
  from typing import Any
6
8
 
7
9
  import torch
@@ -15,7 +17,11 @@ from birder.common import fs_ops
15
17
  from birder.common import lib
16
18
  from birder.common.lib import get_network_name
17
19
  from birder.conf import settings
20
+ from birder.data.transforms.classification import RGBType
18
21
  from birder.data.transforms.classification import inference_preset
22
+ from birder.net.base import SignatureType
23
+ from birder.net.detection.base import DetectionSignatureType
24
+ from birder.version import __version__
19
25
 
20
26
  try:
21
27
  from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
@@ -28,8 +34,10 @@ except ImportError:
28
34
  _HAS_TORCHAO = False
29
35
 
30
36
  try:
37
+ from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
31
38
  from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer
32
39
  from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
40
+ from executorch.exir import to_edge_transform_and_lower
33
41
 
34
42
  _HAS_EXECUTORCH = True
35
43
  except ImportError:
@@ -54,6 +62,33 @@ def _build_quantizer(backend: str) -> Any:
54
62
  raise ValueError(f"Unsupported backend: {backend}")
55
63
 
56
64
 
65
+ def _save_pte(
66
+ exported_net: torch.export.ExportedProgram,
67
+ dst: str | Path,
68
+ task: str,
69
+ class_to_idx: dict[str, int],
70
+ signature: SignatureType | DetectionSignatureType,
71
+ rgb_stats: RGBType,
72
+ ) -> None:
73
+ edge_program = to_edge_transform_and_lower(exported_net, partitioner=[XnnpackPartitioner()])
74
+ executorch_program = edge_program.to_executorch()
75
+ with open(dst, "wb") as f:
76
+ f.write(executorch_program.buffer)
77
+
78
+ with open(f"{dst}_data.json", "w", encoding="utf-8") as handle:
79
+ json.dump(
80
+ {
81
+ "birder_version": __version__,
82
+ "task": task,
83
+ "class_to_idx": class_to_idx,
84
+ "signature": signature,
85
+ "rgb_stats": rgb_stats,
86
+ },
87
+ handle,
88
+ indent=2,
89
+ )
90
+
91
+
57
92
  def set_parser(subparsers: Any) -> None:
58
93
  subparser = subparsers.add_parser(
59
94
  "quantize-model",
@@ -65,6 +100,7 @@ def set_parser(subparsers: Any) -> None:
65
100
  "python -m birder.tools quantize-model -n convnext_v2_tiny -t eu-common\n"
66
101
  "python -m birder.tools quantize-model --network densenet_121 -e 100 --num-calibration-batches 256\n"
67
102
  "python -m birder.tools quantize-model -n efficientnet_v2_s -e 200 --qbackend xnnpack --batch-size 1\n"
103
+ "python -m birder.tools quantize-model -n hgnet_v2_b4 --qbackend xnnpack --pte\n"
68
104
  ),
69
105
  formatter_class=cli.ArgumentHelpFormatter,
70
106
  )
@@ -81,6 +117,9 @@ def set_parser(subparsers: Any) -> None:
81
117
  subparser.add_argument(
82
118
  "--qbackend", type=str, choices=["x86", "xnnpack"], default="x86", help="quantization backend"
83
119
  )
120
+ subparser.add_argument(
121
+ "--pte", default=False, action="store_true", help="lower quantized model to ExecuTorch PTE format"
122
+ )
84
123
  subparser.add_argument("--batch-size", type=int, default=1, metavar="N", help="the batch size")
85
124
  subparser.add_argument(
86
125
  "--num-calibration-batches",
@@ -96,8 +135,13 @@ def set_parser(subparsers: Any) -> None:
96
135
 
97
136
  # pylint: disable=too-many-locals
98
137
  def main(args: argparse.Namespace) -> None:
138
+ if args.pte is True and args.qbackend != "xnnpack":
139
+ raise cli.ValidationError("--pte requires --qbackend xnnpack")
140
+
99
141
  network_name = get_network_name(args.network, tag=args.tag)
100
142
  model_path = fs_ops.model_path(network_name, epoch=args.epoch, quantized=True, pt2=True)
143
+ if args.pte is True:
144
+ model_path = model_path.with_suffix(".pte")
101
145
  if model_path.exists() is True and args.force is False:
102
146
  logger.warning("Quantized model already exists... aborting")
103
147
  raise SystemExit(1)
@@ -105,7 +149,7 @@ def main(args: argparse.Namespace) -> None:
105
149
  device = torch.device("cpu")
106
150
 
107
151
  # Load model
108
- (net, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
152
+ net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
109
153
  device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
110
154
  )
111
155
  net.eval()
@@ -154,9 +198,14 @@ def main(args: argparse.Namespace) -> None:
154
198
  exported_quantized_net = torch.export.export(quantized_net, example_inputs)
155
199
 
156
200
  toc = time.time()
157
- (minutes, seconds) = divmod(toc - tic, 60)
201
+ minutes, seconds = divmod(toc - tic, 60)
158
202
  logger.info(f"{int(minutes):0>2}m{seconds:04.1f}s to quantize model")
159
203
 
160
204
  model_path = fs_ops.model_path(network_name, epoch=args.epoch, quantized=True, pt2=True)
161
- logger.info(f"Saving quantized PT2 model {model_path}...")
162
- fs_ops.save_pt2(exported_quantized_net, model_path, task, class_to_idx, signature, rgb_stats)
205
+ if args.pte is True:
206
+ model_path = model_path.with_suffix(".pte")
207
+ logger.info(f"Lowering quantized model to PTE {model_path}...")
208
+ _save_pte(exported_quantized_net, model_path, task, class_to_idx, signature, rgb_stats)
209
+ else:
210
+ logger.info(f"Saving quantized PT2 model {model_path}...")
211
+ fs_ops.save_pt2(exported_quantized_net, model_path, task, class_to_idx, signature, rgb_stats)
birder/tools/results.py CHANGED
@@ -125,7 +125,7 @@ def print_most_confused_pairs(most_confused_df: pl.DataFrame) -> None:
125
125
 
126
126
  def convert_to_sparse(results_file: str, sparse_k: int) -> None:
127
127
  logger.info(f"Converting {results_file} to sparse format (k={sparse_k})...")
128
- (_, detected_sparse_k) = detect_file_format(results_file)
128
+ _, detected_sparse_k = detect_file_format(results_file)
129
129
 
130
130
  if detected_sparse_k is not None:
131
131
  logger.info(f"File is already in sparse format (with k={detected_sparse_k}). Skipping conversion.")
@@ -233,7 +233,7 @@ def main(args: argparse.Namespace) -> None:
233
233
  logger.warning("Cannot print mistakes in compare mode. processing only the first file")
234
234
 
235
235
  if args.imperfect_only is True:
236
- (result_name, results) = next(iter(results_dict.items()))
236
+ result_name, results = next(iter(results_dict.items()))
237
237
  mistake_prediction_indices = results.mistakes["prediction"].unique().to_numpy().tolist()
238
238
  mistake_label_indices = results.mistakes["label"].unique().to_numpy().tolist()
239
239
  imperfect_class_indices = np.unique(mistake_prediction_indices + mistake_label_indices).tolist()