birder-0.3.3-py3-none-any.whl → birder-0.4.1-py3-none-any.whl
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/adversarial/base.py
CHANGED
@@ -56,7 +56,7 @@ def pixel_eps_to_normalized(
 
 
 def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
-
+    min_val, max_val = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
     return torch.clamp(inputs, min=min_val, max=max_val)
 
 
birder/adversarial/simba.py
CHANGED
@@ -87,7 +87,7 @@ class SimBA:
         if self._is_successful(current_logits, label, target_label):
             return adv_inputs.detach(), num_queries
 
-
+        _, channels, height, width = adv_inputs.shape
         num_dims = channels * height * width
         step = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=adv_inputs.device, dtype=adv_inputs.dtype)
         step_vals = step.view(-1)  # Per-channel steps
@@ -98,11 +98,11 @@ class SimBA:
 
         # Coordinate-wise search in random order
         for flat_idx in perm[:num_steps]:
-
-
+            c, rem = divmod(int(flat_idx.item()), stride)
+            h, w = divmod(rem, width)
             step_val = step_vals[c]
 
-
+            candidate_inputs, candidate_logits, candidate_objective = self._best_candidate(
                 adv_inputs, c, h, w, step_val, label, target_label
             )
            num_queries += 2
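The new index arithmetic walks SimBA's flat pixel index back to a (channel, row, column) coordinate. A minimal standalone sketch of the decomposition, assuming `stride` is `height * width` (the helper name `unravel_chw` is hypothetical):

```python
def unravel_chw(flat_idx: int, height: int, width: int) -> tuple[int, int, int]:
    stride = height * width
    c, rem = divmod(flat_idx, stride)  # channel, remainder within the channel plane
    h, w = divmod(rem, width)          # row, column within that plane
    return c, h, w

# Example: for a (C=3, H=4, W=5) tensor, flat index 27 = 1 * 20 + 7, and 7 = 1 * 5 + 2
assert unravel_chw(27, 4, 5) == (1, 1, 2)
```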
birder/common/cli.py
CHANGED
@@ -49,7 +49,7 @@ class FlexibleDictAction(argparse.Action):
         new_dict = {}
         for pair in pairs:
             # Split each pair into key and value
-
+            key, value = pair.split("=", 1)
             key = key.strip()
 
             # Try to safely evaluate the value (handles ints and strings mostly)
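The `maxsplit=1` argument matters because values may themselves contain `=`; only the first one separates key from value. For example:

```python
pair = "url=https://example.com/?a=1"
key, value = pair.split("=", 1)
assert key == "url" and value == "https://example.com/?a=1"
```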
birder/common/fs_ops.py
CHANGED
@@ -158,7 +158,7 @@ def model_path(
         file_name = f"{file_name}_quantized"
 
     if states is True:
-        file_name = f"{file_name}_states"
+        file_name = f"{file_name}_states.pt"
     elif lite is True:
         file_name = f"{file_name}.ptl"
     elif pt2 is True:
@@ -254,7 +254,7 @@ def clean_checkpoints(network_name: str, keep_last: int) -> None:
     models_glob = str(model_path(network_name, epoch=epoch))
     states_glob = str(model_path(network_name, epoch=epoch, states=True))
     model_pattern = re.compile(r".*_([1-9][0-9]*)\.pt$")
-    states_pattern = re.compile(r".*_([1-9][0-9]*)_states$")
+    states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")
 
     model_paths = list(settings.BASE_DIR.glob(models_glob))
     for p in sorted(model_paths, key=lambda p: p.stat().st_mtime)[:-keep_last]:
@@ -384,7 +384,7 @@ def load_checkpoint(
     )
 
     # Initialize network and restore checkpoint state
-    net = registry.net_factory(network,
+    net = registry.net_factory(network, num_classes, input_channels, config=config, size=size)
 
     # When a checkpoint was trained with EMA:
     # The primary weights in the checkpoint file are the EMA weights
@@ -437,7 +437,7 @@ def load_mim_checkpoint(
     size = lib.get_size_from_signature(signature)
 
     # Initialize network and restore checkpoint state
-    net_encoder = registry.net_factory(encoder,
+    net_encoder = registry.net_factory(encoder, num_classes, input_channels, config=encoder_config, size=size)
     net = registry.mim_net_factory(
         network, net_encoder, config=config, size=size, mask_ratio=mask_ratio, min_mask_size=min_mask_size
     )
@@ -488,7 +488,7 @@ def load_detection_checkpoint(
     size = lib.get_size_from_signature(signature)
 
     # Initialize network and restore checkpoint state
-    net_backbone = registry.net_factory(backbone,
+    net_backbone = registry.net_factory(backbone, num_classes, input_channels, config=backbone_config, size=size)
     net = registry.detection_net_factory(network, num_classes, net_backbone, config=config, size=size)
 
     # When a checkpoint was trained with EMA:
@@ -584,7 +584,7 @@ def load_model(
         merged_config = None  # type: ignore[assignment]
 
         model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
-        net = registry.net_factory(network,
+        net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
         if reparameterized is True:
            net.reparameterize_model()
 
@@ -611,7 +611,7 @@ def load_model(
         if len(merged_config) == 0:
             merged_config = None
 
-        net = registry.net_factory(network,
+        net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
         if reparameterized is True:
             net.reparameterize_model()
 
@@ -733,7 +733,7 @@ def load_detection_model(
 
         model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
         net_backbone = registry.net_factory(
-            backbone,
+            backbone, num_classes, input_channels, config=backbone_merged_config, size=size
         )
         if backbone_reparameterized is True:
             net_backbone.reparameterize_model()
@@ -776,7 +776,7 @@ def load_detection_model(
             merged_config = None
 
         net_backbone = registry.net_factory(
-            backbone,
+            backbone, num_classes, input_channels, config=backbone_merged_config, size=size
         )
         if backbone_reparameterized is True:
             net_backbone.reparameterize_model()
@@ -959,7 +959,7 @@ def load_model_with_cfg(
         encoder_name = cfg["encoder"]
 
         encoder_config = cfg.get("encoder_config", None)
-        encoder = registry.net_factory(encoder_name,
+        encoder = registry.net_factory(encoder_name, 0, input_channels, config=encoder_config, size=size)
         net = registry.mim_net_factory(name, encoder, config=model_config, size=size)
 
     elif cfg["task"] == Task.OBJECT_DETECTION:
@@ -969,14 +969,14 @@ def load_model_with_cfg(
         backbone_name = cfg["backbone"]
 
         backbone_config = cfg.get("backbone_config", None)
-        backbone = registry.net_factory(backbone_name,
+        backbone = registry.net_factory(backbone_name, num_classes, input_channels, config=backbone_config, size=size)
         if cfg.get("backbone_reparameterized", False) is True:
             backbone.reparameterize_model()
 
         net = registry.detection_net_factory(name, num_classes, backbone, config=model_config, size=size)
 
     elif cfg["task"] == Task.IMAGE_CLASSIFICATION:
-        net = registry.net_factory(name,
+        net = registry.net_factory(name, num_classes, input_channels, config=model_config, size=size)
 
     else:
         raise ValueError(f"Configuration not supported: {cfg['task']}")
@@ -1019,7 +1019,7 @@ def download_model_by_weights(
             f"Requested format '{file_format}' not available for {weights}, available formats are: {available_formats}"
         )
 
-
+    model_file, url = get_pretrained_model_url(weights, file_format)
    if dst is None:
        dst = settings.MODELS_DIR.joinpath(model_file)
 
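Across this file the `registry.net_factory(...)` call sites now pass `num_classes` and `input_channels` positionally after the network name. A usage sketch based only on the call shape visible in these hunks (the import path, network name and argument values are assumptions):

```python
from birder.model_registry import registry  # assumed import path

# Positional num_classes and input_channels, as in the updated call sites
net = registry.net_factory(
    "vit_b16",   # network name (hypothetical)
    100,         # num_classes
    3,           # input_channels
    config=None,
    size=(224, 224),
)
```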
birder/common/lib.py
CHANGED
@@ -157,6 +157,6 @@ def get_pretrained_model_url(weights: str, file_format: str) -> tuple[str, str]:
 
 def format_duration(seconds: float) -> str:
     s = int(seconds)
-
-
+    mm, ss = divmod(s, 60)
+    hh, mm = divmod(mm, 60)
     return f"{hh:d}:{mm:02d}:{ss:02d}"
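A worked example of the `divmod` chain that now implements `format_duration`:

```python
seconds = 3725.9
s = int(seconds)         # 3725
mm, ss = divmod(s, 60)   # (62, 5): total minutes, leftover seconds
hh, mm = divmod(mm, 60)  # (1, 2): hours, leftover minutes
assert f"{hh:d}:{mm:02d}:{ss:02d}" == "1:02:05"
```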
birder/common/masking.py
CHANGED
@@ -16,7 +16,7 @@ def _mask_token_omission(
     Parameters
     ----------
     x
-        Tensor of shape (N, L, D), where N is the batch size, L is the sequence length
+        Tensor of shape (N, L, D), where N is the batch size, L is the sequence length and D is the feature dimension.
     mask_ratio
         The ratio of the sequence length to be masked. This value should be between 0 and 1.
     kept_mask_ratio
@@ -48,7 +48,7 @@ def _mask_token_omission(
     # Masking: length -> length * mask_ratio
     # Perform per-sample random masking by per-sample shuffling.
     # Per-sample shuffling is done by argsort random noise.
-
+    N, L, D = x.size()  # batch, length, dim
     len_keep = int(L * (1 - mask_ratio))
     len_masked = int(L * (mask_ratio - kept_mask_ratio))
 
@@ -82,7 +82,7 @@ def mask_tensor(
     if channels_last is False:
         x = x.permute(0, 2, 3, 1)
 
-
+    B, H, W, _ = x.size()
 
     shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
     shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
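The restored comments describe MAE-style per-sample masking: shuffle each sample's tokens by argsort over random noise and keep the first `len_keep`. A minimal sketch of that trick (not birder's exact implementation, which additionally tracks `kept_mask_ratio`):

```python
import torch

def random_masking(x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.Tensor]:
    N, L, D = x.size()
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # uniform noise per token
    ids_shuffle = torch.argsort(noise, dim=1)  # a random permutation per sample
    ids_keep = ids_shuffle[:, :len_keep]       # tokens with the smallest noise survive

    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(N, L, device=x.device)   # 1 = masked, 0 = kept
    mask.scatter_(1, ids_keep, 0.0)
    return x_kept, mask
```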
birder/common/training_cli.py
CHANGED
@@ -13,6 +13,7 @@ from birder.conf import settings
 from birder.data.datasets.coco import MosaicType
 from birder.data.transforms.classification import AugType
 from birder.data.transforms.classification import RGBMode
+from birder.data.transforms.detection import MULTISCALE_STEP
 from birder.data.transforms.detection import AugType as DetAugType
 
 logger = logging.getLogger(__name__)
@@ -199,10 +200,16 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
         action="store_true",
         help="enable random square resize once per batch (capped by max(--size))",
     )
+    group.add_argument(
+        "--multiscale-step",
+        type=int,
+        default=MULTISCALE_STEP,
+        help="step size for multiscale size lists and collator padding divisibility (size_divisible)",
+    )
     group.add_argument(
         "--multiscale-min-size",
         type=int,
-        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of
+        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of --multiscale-step)",
     )
 
 
@@ -515,7 +522,10 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
 
 
 def add_logging_and_debug_args(
-    parser: argparse.ArgumentParser,
+    parser: argparse.ArgumentParser,
+    default_log_interval: int = 50,
+    fake_data: bool = True,
+    classification: bool = False,
 ) -> None:
     group = parser.add_argument_group("Logging and debugging parameters")
     group.add_argument(
@@ -525,6 +535,11 @@ def add_logging_and_debug_args(
         metavar="NAME",
         help="experiment name for logging (creates dedicated directory for the run)",
     )
+    if classification is True:
+        group.add_argument(
+            "--top-k", type=int, metavar="K", help="additional top-k accuracy value to track (top-1 is always tracked)"
+        )
+
     group.add_argument(
         "--log-interval",
         type=int,
@@ -746,3 +761,10 @@ def common_args_validation(args: argparse.Namespace) -> None:
     # Precision_args, shared by all scripts
     if args.amp is True and args.model_dtype != "float32":
         raise ValidationError("--amp can only be used with --model-dtype float32")
+
+    if hasattr(args, "top_k") is True and args.top_k is not None:
+        if args.top_k == 1:
+            raise ValidationError("Top-1 accuracy is tracked by default, please remove 1 from --top-k argument")
+
+        if args.top_k <= 0:
+            raise ValidationError("--top-k value must be a positive integer")
birder/common/training_utils.py
CHANGED
@@ -11,6 +11,7 @@ from collections import deque
 from collections.abc import Callable
 from collections.abc import Generator
 from collections.abc import Iterator
+from collections.abc import Sequence
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -361,7 +362,7 @@ def optimizer_parameter_groups(
     Return parameter groups for optimizers with per-parameter group weight decay.
 
     This function creates parameter groups with customizable weight decay, layer-wise
-    learning rate scaling
+    learning rate scaling and special handling for different parameter types. It supports
     advanced optimization techniques like layer decay and custom weight decay rules.
 
     Referenced from https://github.com/pytorch/vision/blob/main/references/classification/utils.py and from
@@ -450,7 +451,7 @@ def optimizer_parameter_groups(
     visited_modules = []
     while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
-
+        module, prefix = module_stack_with_prefix.pop()
         if id(module) in visited_modules:
             skip_module = True
 
@@ -884,6 +885,11 @@ class SmoothedValue:
         self.total: torch.Tensor | float = 0.0
         self.count: int = 0
 
+    def clear(self) -> None:
+        self.deque.clear()
+        self.total = 0.0
+        self.count = 0
+
     def update(self, value: torch.Tensor | float, n: int = 1) -> None:
         self.deque.append(value)
         self.count += n
@@ -927,14 +933,32 @@ class SmoothedValue:
         return to_tensor(v, torch.device("cpu")).item()  # type: ignore[no-any-return]
 
 
-
+@torch.no_grad()  # type: ignore[untyped-decorator]
+def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
     if y_pred.dim() > 1 and y_pred.size(1) > 1:
         y_pred = y_pred.argmax(dim=1)
 
     y_true = y_true.flatten()
     y_pred = y_pred.flatten()
 
-    return (y_true == y_pred).
+    return (y_true == y_pred).sum() / y_true.numel()
+
+
+@torch.no_grad()  # type: ignore[untyped-decorator]
+def topk_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, topk: Sequence[int]) -> list[torch.Tensor]:
+    maxk = min(max(topk), y_pred.size(1))
+    batch_size = y_true.size(0)
+
+    _, pred = y_pred.topk(maxk, dim=1, largest=True, sorted=True)
+    correct = pred.eq(y_true.unsqueeze(1))
+
+    res: list[torch.Tensor] = []
+    for k in topk:
+        k = min(k, maxk)
+        correct_k = correct[:, :k].any(dim=1).sum(dtype=torch.float32)
+        res.append(correct_k / batch_size)
+
+    return res
 
 
 ###############################################################################
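A quick sanity check of the new `topk_accuracy` helper with toy logits (batch of 3, three classes); per the hunk above it returns one tensor per requested k:

```python
import torch

y_true = torch.tensor([0, 1, 2])
y_pred = torch.tensor(
    [
        [0.9, 0.05, 0.05],  # true class ranked 1st
        [0.6, 0.3, 0.1],    # true class ranked 2nd: counted only for k >= 2
        [0.2, 0.5, 0.3],    # true class ranked 2nd: counted only for k >= 2
    ]
)
top1, top2 = topk_accuracy(y_true, y_pred, topk=(1, 2))
assert torch.isclose(top1, torch.tensor(1 / 3))
assert torch.isclose(top2, torch.tensor(1.0))
```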
birder/data/collators/detection.py
CHANGED
@@ -70,13 +70,21 @@ class BatchRandomResizeCollator(DetectionCollator):
         size: tuple[int, int],
         size_divisible: int = 32,
         multiscale_min_size: Optional[int] = None,
+        multiscale_step: Optional[int] = None,
     ) -> None:
         super().__init__(input_offset, size_divisible=size_divisible)
         if size is None:
             raise ValueError("size must be provided for batch multiscale")
 
         max_side = max(size)
-
+        if multiscale_step is None:
+            multiscale_step = size_divisible
+
+        sizes = []
+        for side in build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step):
+            if side <= max_side:
+                sizes.append(side)
+
         if len(sizes) == 0:
             sizes = [max_side]
 
birder/data/transforms/detection.py
CHANGED
@@ -17,17 +17,20 @@ DEFAULT_MULTISCALE_MAX_SIZE = 800
 
 
 def build_multiscale_sizes(
-    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE, multiscale_step: int = MULTISCALE_STEP
 ) -> tuple[int, ...]:
+    if multiscale_step <= 0:
+        raise ValueError("multiscale_step must be positive")
+
     if min_size is None:
         min_size = DEFAULT_MULTISCALE_MIN_SIZE
 
-    start = int(math.ceil(min_size /
-    end = int(math.floor(max_size /
+    start = int(math.ceil(min_size / multiscale_step) * multiscale_step)
+    end = int(math.floor(max_size / multiscale_step) * multiscale_step)
     if end < start:
         return (start,)
 
-    return tuple(range(start, end + 1,
+    return tuple(range(start, end + 1, multiscale_step))
 
 
 class ResizeWithRandomInterpolation(nn.Module):
@@ -59,6 +62,7 @@ def get_birder_augment(
     multiscale: bool,
     max_size: Optional[int],
     multiscale_min_size: Optional[int],
+    multiscale_step: int = MULTISCALE_STEP,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     if dynamic_size is True:
@@ -98,7 +102,10 @@ def get_birder_augment(
     # Resize
     if multiscale is True:
         transformations.append(
-            v2.RandomShortestSize(
+            v2.RandomShortestSize(
+                min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
+                max_size=max_size or 1333,
+            ),
         )
     else:
         transformations.append(
@@ -160,6 +167,7 @@ def training_preset(
     multiscale: bool = False,
     max_size: Optional[int] = None,
     multiscale_min_size: Optional[int] = None,
+    multiscale_step: int = MULTISCALE_STEP,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     mean = rgv_values["mean"]
@@ -180,7 +188,15 @@ def training_preset(
         [
             v2.ToImage(),
             get_birder_augment(
-                size,
+                size,
+                level,
+                fill_value,
+                dynamic_size,
+                multiscale,
+                max_size,
+                multiscale_min_size,
+                multiscale_step,
+                post_mosaic,
             ),
             v2.ToDtype(torch.float32, scale=True),
             v2.Normalize(mean=mean, std=std),
@@ -212,7 +228,10 @@ def training_preset(
        return v2.Compose(  # type: ignore
            [
                v2.ToImage(),
-                v2.RandomShortestSize(
+                v2.RandomShortestSize(
+                    min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
+                    max_size=max_size or 1333,
+                ),
                v2.RandomHorizontalFlip(0.5),
                v2.SanitizeBoundingBoxes(),
                v2.ToDtype(torch.float32, scale=True),
@@ -284,7 +303,7 @@ def training_preset(
         )
 
     if aug_type == "detr":
-        multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
+        multiscale_sizes = build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step)
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
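A worked example of the rounding logic now parameterized by `multiscale_step` (a standalone restatement of the lines shown above, not an import from birder):

```python
import math

min_size, max_size, multiscale_step = 480, 800, 32
start = int(math.ceil(min_size / multiscale_step) * multiscale_step)  # 480
end = int(math.floor(max_size / multiscale_step) * multiscale_step)   # 800
sizes = tuple(range(start, end + 1, multiscale_step))
assert sizes == (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
```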
birder/data/transforms/mosaic.py
CHANGED
@@ -19,7 +19,7 @@ def mosaic_random_center(
     Create a mosaic augmentation by combining 4 images into a single image.
 
     This augmentation places 4 images on a canvas, meeting at a randomly selected
-    center point. Each image is scaled to fit, cropped as needed
+    center point. Each image is scaled to fit, cropped as needed and their bounding
     boxes are transformed accordingly.
 
     Parameters
birder/datahub/classification.py
CHANGED
@@ -63,7 +63,7 @@ class TestDataset(ImageFolder):
         super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
-
+        path, target = self.samples[index]
         sample = self.loader(path)
         if self.transform is not None:
             sample = self.transform(sample)
@@ -122,7 +122,7 @@ class Flowers102(ImageFolder):
         super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
-
+        path, target = self.samples[index]
         sample = self.loader(path)
         if self.transform is not None:
             sample = self.transform(sample)
@@ -182,7 +182,7 @@ class CUB_200_2011(ImageFolder):
         super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
-
+        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
birder/inference/classification.py
CHANGED
@@ -75,7 +75,7 @@ def infer_batch(
         embedding = embedding_tensor.cpu().float().numpy()
 
     elif tta is True:
-
+        _, _, H, W = inputs.size()
         crop_h = int(H * 0.8)
         crop_w = int(W * 0.8)
         tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
@@ -137,7 +137,7 @@ def infer_dataloader_iter(
         inputs = inputs.to(device, dtype=model_dtype)
 
         with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
-
+            out, embedding = infer_batch(
                 net, inputs, return_embedding=return_embedding, tta=tta, return_logits=return_logits, **kwargs
             )
 
@@ -394,7 +394,7 @@ def evaluate(
     num_samples: Optional[int] = None,
     sparse: bool = False,
 ) -> Results | SparseResults:
-
+    sample_paths, outs, labels, _ = infer_dataloader(
        device, net, dataloader, tta=tta, model_dtype=model_dtype, amp=amp, amp_dtype=amp_dtype, num_samples=num_samples
    )
    if sparse is True:
birder/inference/data_parallel.py
CHANGED
@@ -253,7 +253,7 @@ class InferenceDataParallel(nn.Module):
 
         This allows custom methods (e.g., model.embedding()) to be called
         on the InferenceDataParallel instance, which then scatters inputs,
-        calls the method on each replica
+        calls the method on each replica and gathers the results.
 
         Parameters
         ----------
birder/inference/detection.py
CHANGED
@@ -20,7 +20,7 @@ def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list
     if image_sizes is not None:
         return image_sizes
 
-
+    _, _, height, width = inputs.shape
     return [[height, width] for _ in range(inputs.size(0))]
 
 
@@ -149,20 +149,20 @@ def infer_batch(
     **kwargs: Any,
 ) -> list[dict[str, torch.Tensor]]:
     if tta is False:
-
+        detections, _ = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
         return detections  # type: ignore[no-any-return]
 
     normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
     detections_list: list[list[dict[str, torch.Tensor]]] = []
 
     for scale in (0.8, 1.0, 1.2):
-
-
+        scaled_inputs, scaled_masks, scaled_sizes = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+        detections, _ = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
         detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
         detections_list.append(detections)
 
         flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
-
+        flipped_detections, _ = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
         flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
         flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
         detections_list.append(flipped_detections)
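The TTA path runs each scale twice, once plain and once horizontally flipped, then maps the flipped boxes back via `_invert_detections`. A sketch of the box inversion that step presumably performs (hypothetical helper; XYXY boxes, image width `width`):

```python
import torch

def hflip_boxes(boxes: torch.Tensor, width: int) -> torch.Tensor:
    # Mirror XYXY boxes around the vertical axis; x1 and x2 swap after reflection
    x1, y1, x2, y2 = boxes.unbind(dim=1)
    return torch.stack([width - x2, y1, width - x1, y2], dim=1)

boxes = torch.tensor([[10.0, 5.0, 30.0, 25.0]])
assert torch.equal(hflip_boxes(boxes, 100), torch.tensor([[70.0, 5.0, 90.0, 25.0]]))
```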
birder/inference/wbf.py
CHANGED
@@ -182,7 +182,7 @@ def fuse_detections_wbf_single(
     scores_list = [detection["scores"] for detection in detections]
     labels_list = [detection["labels"] for detection in detections]
 
-
+    boxes, scores, labels = weighted_boxes_fusion(
         boxes_list,
         scores_list,
         labels_list,
birder/introspection/attention_rollout.py
CHANGED
@@ -70,7 +70,7 @@ def compute_rollout(
     num_to_discard = int(num_allowed * discard_ratio)
     if num_to_discard > 0:
         # Drop the smallest allowed values
-
+        _, low_idx = torch.topk(allowed_values, num_to_discard, largest=False)
         allowed_values[low_idx] = 0
         attn[allow] = allowed_values
         attention_heads_fused[0] = attn
@@ -97,7 +97,7 @@ def compute_rollout(
 
     # Normalize and reshape to 2D map using actual patch grid dimensions
     mask = mask / (mask.max() + 1e-8)
-
+    grid_h, grid_w = patch_grid_shape
     mask = mask.reshape(grid_h, grid_w)
 
     return mask
@@ -141,7 +141,7 @@ class AttentionRollout:
         net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str = "
+        attention_layer_name: str = "attn",
         discard_ratio: float = 0.9,
         head_fusion: Literal["mean", "max", "min"] = "max",
     ) -> None:
@@ -156,11 +156,11 @@ class AttentionRollout:
         self.attention_gatherer = AttentionGatherer(net, attention_layer_name)
 
     def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-
+        input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
 
-
+        attentions, logits = self.attention_gatherer(input_tensor)
 
-
+        _, _, H, W = input_tensor.shape
         patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
 
         attention_map = compute_rollout(
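For context, `compute_rollout` implements attention rollout (Abnar & Zuidema, 2020): fuse heads per layer, add the identity for the residual connection, re-normalize, and multiply across layers. A minimal sketch of that recurrence (not birder's exact code, which also applies the `discard_ratio` filtering shown above):

```python
import torch

def rollout(attentions: list[torch.Tensor]) -> torch.Tensor:
    # attentions: per-layer head-fused maps of shape (tokens, tokens)
    result = torch.eye(attentions[0].size(-1))
    for attn in attentions:
        attn = attn + torch.eye(attn.size(-1))        # identity models the residual path
        attn = attn / attn.sum(dim=-1, keepdim=True)  # re-normalize rows
        result = attn @ result                        # accumulate across layers
    return result
```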
birder/introspection/feature_pca.py
CHANGED
@@ -17,7 +17,7 @@ class FeaturePCA:
     Visualizes feature maps using Principal Component Analysis
 
     This method extracts feature maps from a specified stage of a DetectorBackbone model,
-    applies PCA to reduce the channel dimension to 3 components
+    applies PCA to reduce the channel dimension to 3 components and visualizes them as an RGB image where:
     - R channel = 1st principal component (most important)
     - G channel = 2nd principal component
     - B channel = 3rd principal component
@@ -40,7 +40,7 @@ class FeaturePCA:
         self.stage = stage
 
     def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
-
+        input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
 
         with torch.inference_mode():
             features_dict = self.net.detection_features(input_tensor)
@@ -54,11 +54,11 @@ class FeaturePCA:
 
         # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
         if self.channels_last is True:
-
+            B, H, W, C = features_np.shape
             # Already in (B, H, W, C), just reshape to (B*H*W, C)
             features_reshaped = features_np.reshape(-1, C)
         else:
-
+            B, C, H, W = features_np.shape
             # Reshape to (spatial_points, channels) for PCA
             features_reshaped = features_np.reshape(B, C, -1)
             features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
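The restored docstring describes mapping the top three principal components of a feature map to R, G and B. A minimal sketch of that mapping using numpy's SVD (an assumption; the actual implementation may use a different PCA routine):

```python
import numpy as np

def pca_rgb(features: np.ndarray) -> np.ndarray:
    # features: (H, W, C) feature map -> (H, W, 3) pseudo-color image
    H, W, C = features.shape
    flat = features.reshape(-1, C).astype(np.float64)
    flat -= flat.mean(axis=0)                    # center before PCA
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    components = flat @ vt[:3].T                 # project onto the top 3 components
    components -= components.min(axis=0)
    components /= components.max(axis=0) + 1e-8  # scale each channel to [0, 1]
    return components.reshape(H, W, 3)           # R, G, B = PC1, PC2, PC3
```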
birder/introspection/gradcam.py
CHANGED
@@ -98,7 +98,7 @@ class GradCAM:
         self.activation_capture = ActivationCapture(net, target_layer, reshape_transform)
 
     def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-
+        input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
         input_tensor.requires_grad_(True)
 
         # Forward pass
|