birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/fs_ops.py +2 -2
- birder/common/masking.py +13 -4
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +4 -2
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/base.py +3 -3
- birder/net/biformer.py +2 -2
- birder/net/cas_vit.py +6 -6
- birder/net/coat.py +8 -8
- birder/net/conv2former.py +2 -2
- birder/net/convnext_v1.py +22 -2
- birder/net/convnext_v2.py +2 -2
- birder/net/crossformer.py +2 -2
- birder/net/cspnet.py +2 -2
- birder/net/cswin_transformer.py +2 -2
- birder/net/darknet.py +2 -2
- birder/net/davit.py +2 -2
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/densenet.py +2 -2
- birder/net/detection/deformable_detr.py +2 -2
- birder/net/detection/detr.py +2 -2
- birder/net/detection/efficientdet.py +2 -2
- birder/net/detection/faster_rcnn.py +2 -2
- birder/net/detection/fcos.py +2 -2
- birder/net/detection/retinanet.py +2 -2
- birder/net/detection/rt_detr_v1.py +4 -4
- birder/net/detection/ssd.py +2 -2
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +2 -2
- birder/net/detection/yolo_v3.py +2 -2
- birder/net/detection/yolo_v4.py +2 -2
- birder/net/edgenext.py +2 -2
- birder/net/edgevit.py +1 -1
- birder/net/efficientformer_v1.py +4 -4
- birder/net/efficientformer_v2.py +6 -6
- birder/net/efficientnet_lite.py +2 -2
- birder/net/efficientnet_v1.py +2 -2
- birder/net/efficientnet_v2.py +2 -2
- birder/net/efficientvim.py +3 -3
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +2 -2
- birder/net/fasternet.py +2 -2
- birder/net/fastvit.py +2 -3
- birder/net/flexivit.py +11 -6
- birder/net/focalnet.py +2 -3
- birder/net/gc_vit.py +17 -2
- birder/net/ghostnet_v1.py +2 -2
- birder/net/ghostnet_v2.py +2 -2
- birder/net/groupmixformer.py +2 -2
- birder/net/hgnet_v1.py +2 -2
- birder/net/hgnet_v2.py +2 -2
- birder/net/hiera.py +2 -2
- birder/net/hieradet.py +2 -2
- birder/net/hornet.py +2 -2
- birder/net/iformer.py +2 -2
- birder/net/inception_next.py +2 -2
- birder/net/inception_resnet_v1.py +2 -2
- birder/net/inception_resnet_v2.py +2 -2
- birder/net/inception_v3.py +2 -2
- birder/net/inception_v4.py +2 -2
- birder/net/levit.py +4 -4
- birder/net/lit_v1.py +2 -2
- birder/net/lit_v1_tiny.py +2 -2
- birder/net/lit_v2.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/metaformer.py +2 -2
- birder/net/mnasnet.py +2 -2
- birder/net/mobilenet_v1.py +2 -2
- birder/net/mobilenet_v2.py +2 -2
- birder/net/mobilenet_v3_large.py +2 -2
- birder/net/mobilenet_v4.py +2 -2
- birder/net/mobilenet_v4_hybrid.py +2 -2
- birder/net/mobileone.py +2 -2
- birder/net/mobilevit_v2.py +2 -2
- birder/net/moganet.py +2 -2
- birder/net/mvit_v2.py +2 -2
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +2 -2
- birder/net/pit.py +6 -6
- birder/net/pvt_v1.py +2 -2
- birder/net/pvt_v2.py +2 -2
- birder/net/rdnet.py +2 -2
- birder/net/regionvit.py +6 -6
- birder/net/regnet.py +2 -2
- birder/net/regnet_z.py +2 -2
- birder/net/repghost.py +2 -2
- birder/net/repvgg.py +2 -2
- birder/net/repvit.py +6 -6
- birder/net/resnest.py +2 -2
- birder/net/resnet_v1.py +2 -2
- birder/net/resnet_v2.py +2 -2
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +3 -3
- birder/net/rope_flexivit.py +13 -6
- birder/net/rope_vit.py +69 -10
- birder/net/shufflenet_v1.py +2 -2
- birder/net/shufflenet_v2.py +2 -2
- birder/net/smt.py +1 -2
- birder/net/squeezenext.py +2 -2
- birder/net/ssl/byol.py +3 -2
- birder/net/ssl/capi.py +156 -11
- birder/net/ssl/data2vec.py +3 -1
- birder/net/ssl/data2vec2.py +3 -1
- birder/net/ssl/dino_v1.py +1 -1
- birder/net/ssl/dino_v2.py +140 -18
- birder/net/ssl/franca.py +145 -13
- birder/net/ssl/ibot.py +1 -2
- birder/net/ssl/mmcr.py +3 -1
- birder/net/starnet.py +2 -2
- birder/net/swiftformer.py +6 -6
- birder/net/swin_transformer_v1.py +2 -2
- birder/net/swin_transformer_v2.py +2 -2
- birder/net/tiny_vit.py +2 -2
- birder/net/transnext.py +1 -1
- birder/net/uniformer.py +1 -1
- birder/net/van.py +1 -1
- birder/net/vgg.py +1 -1
- birder/net/vgg_reduced.py +1 -1
- birder/net/vit.py +172 -8
- birder/net/vit_parallel.py +5 -5
- birder/net/vit_sam.py +3 -3
- birder/net/vovnet_v1.py +2 -2
- birder/net/vovnet_v2.py +2 -2
- birder/net/wide_resnet.py +2 -2
- birder/net/xception.py +2 -2
- birder/net/xcit.py +2 -2
- birder/results/detection.py +104 -0
- birder/results/gui.py +10 -8
- birder/scripts/benchmark.py +1 -1
- birder/scripts/train.py +13 -18
- birder/scripts/train_barlow_twins.py +10 -14
- birder/scripts/train_byol.py +11 -15
- birder/scripts/train_capi.py +38 -17
- birder/scripts/train_data2vec.py +11 -15
- birder/scripts/train_data2vec2.py +13 -17
- birder/scripts/train_detection.py +11 -14
- birder/scripts/train_dino_v1.py +20 -22
- birder/scripts/train_dino_v2.py +126 -63
- birder/scripts/train_dino_v2_dist.py +127 -64
- birder/scripts/train_franca.py +49 -34
- birder/scripts/train_i_jepa.py +11 -14
- birder/scripts/train_ibot.py +16 -18
- birder/scripts/train_kd.py +14 -20
- birder/scripts/train_mim.py +10 -13
- birder/scripts/train_mmcr.py +11 -15
- birder/scripts/train_rotnet.py +12 -16
- birder/scripts/train_simclr.py +10 -14
- birder/scripts/train_vicreg.py +10 -14
- birder/tools/avg_model.py +24 -8
- birder/tools/det_results.py +91 -0
- birder/tools/introspection.py +35 -9
- birder/tools/results.py +11 -7
- birder/tools/show_iterator.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
- birder-0.3.2.dist-info/RECORD +299 -0
- birder-0.3.0.dist-info/RECORD +0 -298
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py
CHANGED
birder/adversarial/simba.py
CHANGED
birder/common/fs_ops.py
CHANGED
|
@@ -627,7 +627,7 @@ def load_model(
|
|
|
627
627
|
net.to(dtype)
|
|
628
628
|
if inference is True:
|
|
629
629
|
for param in net.parameters():
|
|
630
|
-
param.
|
|
630
|
+
param.requires_grad_(False)
|
|
631
631
|
|
|
632
632
|
if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
|
|
633
633
|
net.eval()
|
|
@@ -799,7 +799,7 @@ def load_detection_model(
|
|
|
799
799
|
net.to(dtype)
|
|
800
800
|
if inference is True:
|
|
801
801
|
for param in net.parameters():
|
|
802
|
-
param.
|
|
802
|
+
param.requires_grad_(False)
|
|
803
803
|
|
|
804
804
|
net.eval()
|
|
805
805
|
|
birder/common/masking.py
CHANGED
|
@@ -84,8 +84,8 @@ def mask_tensor(
|
|
|
84
84
|
|
|
85
85
|
(B, H, W, _) = x.size()
|
|
86
86
|
|
|
87
|
-
shaped_mask = mask.reshape(
|
|
88
|
-
shaped_mask = shaped_mask.repeat_interleave(patch_factor,
|
|
87
|
+
shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
|
|
88
|
+
shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
|
|
89
89
|
shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
|
|
90
90
|
|
|
91
91
|
if mask_token is not None:
|
|
@@ -228,14 +228,23 @@ class Masking:
|
|
|
228
228
|
|
|
229
229
|
|
|
230
230
|
class UniformMasking(Masking):
|
|
231
|
-
def __init__(
|
|
231
|
+
def __init__(
|
|
232
|
+
self,
|
|
233
|
+
input_size: tuple[int, int],
|
|
234
|
+
mask_ratio: float,
|
|
235
|
+
min_mask_size: int = 1,
|
|
236
|
+
device: Optional[torch.device] = None,
|
|
237
|
+
) -> None:
|
|
232
238
|
self.h = input_size[0]
|
|
233
239
|
self.w = input_size[1]
|
|
234
240
|
self.mask_ratio = mask_ratio
|
|
241
|
+
self.min_mask_size = min_mask_size
|
|
235
242
|
self.device = device
|
|
236
243
|
|
|
237
244
|
def __call__(self, batch_size: int) -> torch.Tensor:
|
|
238
|
-
return uniform_mask(
|
|
245
|
+
return uniform_mask(
|
|
246
|
+
batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
|
|
247
|
+
)[0]
|
|
239
248
|
|
|
240
249
|
|
|
241
250
|
class BlockMasking(Masking):
|
birder/common/training_cli.py
CHANGED
|
@@ -39,6 +39,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
|
|
|
39
39
|
group = parser.add_argument_group("Optimization parameters")
|
|
40
40
|
group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
|
|
41
41
|
group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
|
|
42
|
+
group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
|
|
42
43
|
group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
|
|
43
44
|
group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
|
|
44
45
|
group.add_argument("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
|
|
@@ -249,6 +250,7 @@ def add_data_aug_args(
|
|
|
249
250
|
default_level: int = 4,
|
|
250
251
|
default_min_scale: Optional[float] = None,
|
|
251
252
|
default_re_prob: Optional[float] = None,
|
|
253
|
+
smoothing_alpha: bool = False,
|
|
252
254
|
mixup_cutmix: bool = False,
|
|
253
255
|
) -> None:
|
|
254
256
|
group = parser.add_argument_group("Data augmentation parameters")
|
|
@@ -285,6 +287,8 @@ def add_data_aug_args(
|
|
|
285
287
|
group.add_argument(
|
|
286
288
|
"--simple-crop", default=False, action="store_true", help="use simple random crop (SRC) instead of RRC"
|
|
287
289
|
)
|
|
290
|
+
if smoothing_alpha is True:
|
|
291
|
+
group.add_argument("--smoothing-alpha", type=float, default=0.0, help="label smoothing alpha")
|
|
288
292
|
if mixup_cutmix is True:
|
|
289
293
|
group.add_argument("--mixup-alpha", type=float, help="mixup alpha")
|
|
290
294
|
group.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
|
|
@@ -565,9 +569,9 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
|
|
|
565
569
|
group.add_argument("--wds", default=False, action="store_true", help="use webdataset for training")
|
|
566
570
|
group.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
|
|
567
571
|
group.add_argument("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
|
|
568
|
-
group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
|
|
569
572
|
if unsupervised is False:
|
|
570
573
|
group.add_argument("--wds-class-file", type=str, metavar="FILE", help="class list file")
|
|
574
|
+
group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
|
|
571
575
|
group.add_argument("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
|
|
572
576
|
group.add_argument(
|
|
573
577
|
"--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split"
|
|
@@ -576,6 +580,7 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
|
|
|
576
580
|
"--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split"
|
|
577
581
|
)
|
|
578
582
|
else:
|
|
583
|
+
group.add_argument("--wds-size", type=int, metavar="N", help="size of the wds")
|
|
579
584
|
group.add_argument(
|
|
580
585
|
"--wds-split", type=str, default="training", metavar="NAME", help="wds dataset split to load"
|
|
581
586
|
)
|
birder/common/training_utils.py
CHANGED
|
@@ -593,12 +593,14 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
|
|
|
593
593
|
kwargs["betas"] = args.opt_betas
|
|
594
594
|
if getattr(args, "opt_alpha", None) is not None:
|
|
595
595
|
kwargs["alpha"] = args.opt_alpha
|
|
596
|
+
if getattr(args, "opt_fused", False) is True:
|
|
597
|
+
kwargs["fused"] = True
|
|
596
598
|
|
|
597
599
|
# For optimizer compilation
|
|
598
600
|
# lr = torch.tensor(l_rate) - Causes weird LR scheduling bugs
|
|
599
601
|
lr = l_rate
|
|
600
|
-
if getattr(args, "compile_opt", False) is
|
|
601
|
-
if opt not in ("lamb", "lambw", "lars"):
|
|
602
|
+
if getattr(args, "compile_opt", False) is True:
|
|
603
|
+
if opt not in ("sgd", "lamb", "lambw", "lars"):
|
|
602
604
|
logger.debug("Setting optimizer capturable to True")
|
|
603
605
|
kwargs["capturable"] = True
|
|
604
606
|
|
birder/inference/classification.py
CHANGED
|
@@ -85,7 +85,7 @@ def infer_batch(
|
|
|
85
85
|
logits = net(t(tta_input), **kwargs)
|
|
86
86
|
outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
|
|
87
87
|
|
|
88
|
-
out = torch.stack(outs).mean(
|
|
88
|
+
out = torch.stack(outs).mean(dim=0)
|
|
89
89
|
|
|
90
90
|
else:
|
|
91
91
|
logits = net(inputs, **kwargs)
|
birder/introspection/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from birder.introspection.attention_rollout import AttentionRollout
|
|
2
2
|
from birder.introspection.base import InterpretabilityResult
|
|
3
|
+
from birder.introspection.feature_pca import FeaturePCA
|
|
3
4
|
from birder.introspection.gradcam import GradCAM
|
|
4
5
|
from birder.introspection.guided_backprop import GuidedBackprop
|
|
5
6
|
from birder.introspection.transformer_attribution import TransformerAttribution
|
|
@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
|
|
|
7
8
|
__all__ = [
|
|
8
9
|
"InterpretabilityResult",
|
|
9
10
|
"AttentionRollout",
|
|
11
|
+
"FeaturePCA",
|
|
10
12
|
"GradCAM",
|
|
11
13
|
"GuidedBackprop",
|
|
12
14
|
"TransformerAttribution",
|
birder/introspection/base.py
CHANGED
|
@@ -2,7 +2,6 @@ from collections.abc import Callable
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Optional
|
|
5
|
-
from typing import Protocol
|
|
6
5
|
|
|
7
6
|
import matplotlib
|
|
8
7
|
import matplotlib.pyplot as plt
|
|
@@ -27,12 +26,6 @@ class InterpretabilityResult:
|
|
|
27
26
|
plt.show()
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
class Interpreter(Protocol):
|
|
31
|
-
def __call__(
|
|
32
|
-
self, image: str | Path | Image.Image, target_class: Optional[int] = None
|
|
33
|
-
) -> InterpretabilityResult: ...
|
|
34
|
-
|
|
35
|
-
|
|
36
29
|
def load_image(image: str | Path | Image.Image) -> Image.Image:
|
|
37
30
|
if isinstance(image, (str, Path)):
|
|
38
31
|
return Image.open(image)
|
birder/introspection/feature_pca.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from PIL import Image
|
|
8
|
+
from sklearn.decomposition import PCA
|
|
9
|
+
|
|
10
|
+
from birder.introspection.base import InterpretabilityResult
|
|
11
|
+
from birder.introspection.base import preprocess_image
|
|
12
|
+
from birder.net.base import DetectorBackbone
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FeaturePCA:
|
|
16
|
+
"""
|
|
17
|
+
Visualizes feature maps using Principal Component Analysis
|
|
18
|
+
|
|
19
|
+
This method extracts feature maps from a specified stage of a DetectorBackbone model,
|
|
20
|
+
applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
|
|
21
|
+
- R channel = 1st principal component (most important)
|
|
22
|
+
- G channel = 2nd principal component
|
|
23
|
+
- B channel = 3rd principal component
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
net: DetectorBackbone,
|
|
29
|
+
device: torch.device,
|
|
30
|
+
transform: Callable[..., torch.Tensor],
|
|
31
|
+
normalize: bool = False,
|
|
32
|
+
channels_last: bool = False,
|
|
33
|
+
stage: Optional[str] = None,
|
|
34
|
+
) -> None:
|
|
35
|
+
self.net = net.eval()
|
|
36
|
+
self.device = device
|
|
37
|
+
self.transform = transform
|
|
38
|
+
self.normalize = normalize
|
|
39
|
+
self.channels_last = channels_last
|
|
40
|
+
self.stage = stage
|
|
41
|
+
|
|
42
|
+
def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
|
|
43
|
+
(input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
|
|
44
|
+
|
|
45
|
+
with torch.inference_mode():
|
|
46
|
+
features_dict = self.net.detection_features(input_tensor)
|
|
47
|
+
|
|
48
|
+
if self.stage is not None:
|
|
49
|
+
features = features_dict[self.stage]
|
|
50
|
+
else:
|
|
51
|
+
features = list(features_dict.values())[-1] # Use the last stage by default
|
|
52
|
+
|
|
53
|
+
features_np = features.cpu().numpy()
|
|
54
|
+
|
|
55
|
+
# Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
|
|
56
|
+
if self.channels_last is True:
|
|
57
|
+
(B, H, W, C) = features_np.shape
|
|
58
|
+
# Already in (B, H, W, C), just reshape to (B*H*W, C)
|
|
59
|
+
features_reshaped = features_np.reshape(-1, C)
|
|
60
|
+
else:
|
|
61
|
+
(B, C, H, W) = features_np.shape
|
|
62
|
+
# Reshape to (spatial_points, channels) for PCA
|
|
63
|
+
features_reshaped = features_np.reshape(B, C, -1)
|
|
64
|
+
features_reshaped = features_reshaped.transpose(0, 2, 1) # (B, H*W, C)
|
|
65
|
+
features_reshaped = features_reshaped.reshape(-1, C) # (B*H*W, C)
|
|
66
|
+
|
|
67
|
+
x = features_reshaped
|
|
68
|
+
if self.normalize is True:
|
|
69
|
+
x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
|
|
70
|
+
|
|
71
|
+
pca = PCA(n_components=3)
|
|
72
|
+
pca_features = pca.fit_transform(x)
|
|
73
|
+
pca_features = pca_features.reshape(B, H, W, 3)
|
|
74
|
+
|
|
75
|
+
# Extract all 3 components (B=1)
|
|
76
|
+
pca_rgb = pca_features[0] # (H, W, 3)
|
|
77
|
+
|
|
78
|
+
# Normalize each channel independently to [0, 1]
|
|
79
|
+
for i in range(3):
|
|
80
|
+
channel = pca_rgb[:, :, i]
|
|
81
|
+
channel = channel - channel.min()
|
|
82
|
+
channel = channel / (channel.max() + 1e-7)
|
|
83
|
+
pca_rgb[:, :, i] = channel
|
|
84
|
+
|
|
85
|
+
target_size = (input_tensor.size(-1), input_tensor.size(-2)) # PIL expects (width, height)
|
|
86
|
+
pca_rgb_resized = (
|
|
87
|
+
np.array(
|
|
88
|
+
Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
|
|
89
|
+
).astype(np.float32)
|
|
90
|
+
/ 255.0
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
visualization = (pca_rgb_resized * 255).astype(np.uint8)
|
|
94
|
+
|
|
95
|
+
return InterpretabilityResult(
|
|
96
|
+
original_image=rgb_img,
|
|
97
|
+
visualization=visualization,
|
|
98
|
+
raw_output=pca_rgb.astype(np.float32),
|
|
99
|
+
logits=None,
|
|
100
|
+
predicted_class=None,
|
|
101
|
+
)
|
birder/kernels/soft_nms/soft_nms.cpp
CHANGED
|
@@ -4,6 +4,9 @@
|
|
|
4
4
|
* Taken from:
|
|
5
5
|
* https://github.com/MrParosk/soft_nms
|
|
6
6
|
* Licensed under the MIT License
|
|
7
|
+
*
|
|
8
|
+
* Modified by:
|
|
9
|
+
* Ofer Hasson — 2026-01-10
|
|
7
10
|
**************************************************************************************************
|
|
8
11
|
*/
|
|
9
12
|
|
|
@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
|
|
|
40
43
|
auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
|
|
41
44
|
auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
|
|
42
45
|
|
|
43
|
-
auto w =
|
|
44
|
-
auto h =
|
|
46
|
+
auto w = (xx2 - xx1).clamp_min(0);
|
|
47
|
+
auto h = (yy2 - yy1).clamp_min(0);
|
|
45
48
|
|
|
46
49
|
auto intersection = w * h;
|
|
47
50
|
auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
|
birder/model_registry/model_registry.py
CHANGED
|
@@ -87,14 +87,15 @@ class ModelRegistry:
|
|
|
87
87
|
no further registration is needed.
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
|
+
alias_key = alias.lower()
|
|
90
91
|
if net_type.auto_register is False:
|
|
91
92
|
# Register the model manually, as the base class doesn't take care of that for us
|
|
92
|
-
|
|
93
|
+
self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
|
|
93
94
|
|
|
94
95
|
if alias in self.aliases:
|
|
95
96
|
warnings.warn(f"Alias {alias} is already registered", UserWarning)
|
|
96
97
|
|
|
97
|
-
self.aliases[
|
|
98
|
+
self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
|
|
98
99
|
|
|
99
100
|
def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
|
|
100
101
|
if name in self._pretrained_nets:
|
birder/net/base.py
CHANGED
|
@@ -173,14 +173,14 @@ class BaseNet(nn.Module):
|
|
|
173
173
|
|
|
174
174
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
175
175
|
for param in self.parameters():
|
|
176
|
-
param.
|
|
176
|
+
param.requires_grad_(False)
|
|
177
177
|
|
|
178
178
|
if freeze_classifier is False:
|
|
179
179
|
for param in self.classifier.parameters():
|
|
180
|
-
param.
|
|
180
|
+
param.requires_grad_(True)
|
|
181
181
|
if unfreeze_features is True and hasattr(self, "features") is True:
|
|
182
182
|
for param in self.features.parameters():
|
|
183
|
-
param.
|
|
183
|
+
param.requires_grad_(True)
|
|
184
184
|
|
|
185
185
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
186
186
|
"""
|
birder/net/biformer.py
CHANGED
|
@@ -468,14 +468,14 @@ class BiFormer(DetectorBackbone):
|
|
|
468
468
|
|
|
469
469
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
470
470
|
for param in self.stem.parameters():
|
|
471
|
-
param.
|
|
471
|
+
param.requires_grad_(False)
|
|
472
472
|
|
|
473
473
|
for idx, module in enumerate(self.body.children()):
|
|
474
474
|
if idx >= up_to_stage:
|
|
475
475
|
break
|
|
476
476
|
|
|
477
477
|
for param in module.parameters():
|
|
478
|
-
param.
|
|
478
|
+
param.requires_grad_(False)
|
|
479
479
|
|
|
480
480
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
481
481
|
x = self.stem(x)
|
birder/net/cas_vit.py
CHANGED
|
@@ -269,18 +269,18 @@ class CAS_ViT(DetectorBackbone):
|
|
|
269
269
|
|
|
270
270
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
271
271
|
for param in self.parameters():
|
|
272
|
-
param.
|
|
272
|
+
param.requires_grad_(False)
|
|
273
273
|
|
|
274
274
|
if freeze_classifier is False:
|
|
275
275
|
for param in self.classifier.parameters():
|
|
276
|
-
param.
|
|
276
|
+
param.requires_grad_(True)
|
|
277
277
|
|
|
278
278
|
for param in self.dist_classifier.parameters():
|
|
279
|
-
param.
|
|
279
|
+
param.requires_grad_(True)
|
|
280
280
|
|
|
281
281
|
if unfreeze_features is True:
|
|
282
282
|
for param in self.features.parameters():
|
|
283
|
-
param.
|
|
283
|
+
param.requires_grad_(True)
|
|
284
284
|
|
|
285
285
|
def transform_to_backbone(self) -> None:
|
|
286
286
|
self.features = nn.Identity()
|
|
@@ -300,14 +300,14 @@ class CAS_ViT(DetectorBackbone):
|
|
|
300
300
|
|
|
301
301
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
302
302
|
for param in self.stem.parameters():
|
|
303
|
-
param.
|
|
303
|
+
param.requires_grad_(False)
|
|
304
304
|
|
|
305
305
|
for idx, module in enumerate(self.body.children()):
|
|
306
306
|
if idx >= up_to_stage:
|
|
307
307
|
break
|
|
308
308
|
|
|
309
309
|
for param in module.parameters():
|
|
310
|
-
param.
|
|
310
|
+
param.requires_grad_(False)
|
|
311
311
|
|
|
312
312
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
313
313
|
x = self.stem(x)
|
birder/net/coat.py
CHANGED
|
@@ -563,24 +563,24 @@ class CoaT(DetectorBackbone):
|
|
|
563
563
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
564
564
|
if up_to_stage >= 1:
|
|
565
565
|
for param in self.patch_embed1.parameters():
|
|
566
|
-
param.
|
|
566
|
+
param.requires_grad_(False)
|
|
567
567
|
for param in self.serial_blocks1.parameters():
|
|
568
|
-
param.
|
|
568
|
+
param.requires_grad_(False)
|
|
569
569
|
if up_to_stage >= 2:
|
|
570
570
|
for param in self.patch_embed2.parameters():
|
|
571
|
-
param.
|
|
571
|
+
param.requires_grad_(False)
|
|
572
572
|
for param in self.serial_blocks2.parameters():
|
|
573
|
-
param.
|
|
573
|
+
param.requires_grad_(False)
|
|
574
574
|
if up_to_stage >= 3:
|
|
575
575
|
for param in self.patch_embed3.parameters():
|
|
576
|
-
param.
|
|
576
|
+
param.requires_grad_(False)
|
|
577
577
|
for param in self.serial_blocks3.parameters():
|
|
578
|
-
param.
|
|
578
|
+
param.requires_grad_(False)
|
|
579
579
|
if up_to_stage >= 4:
|
|
580
580
|
for param in self.patch_embed4.parameters():
|
|
581
|
-
param.
|
|
581
|
+
param.requires_grad_(False)
|
|
582
582
|
for param in self.serial_blocks4.parameters():
|
|
583
|
-
param.
|
|
583
|
+
param.requires_grad_(False)
|
|
584
584
|
|
|
585
585
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
586
586
|
features = self._features(x)
|
birder/net/conv2former.py
CHANGED
|
@@ -218,14 +218,14 @@ class Conv2Former(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
218
218
|
|
|
219
219
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
220
220
|
for param in self.stem.parameters():
|
|
221
|
-
param.
|
|
221
|
+
param.requires_grad_(False)
|
|
222
222
|
|
|
223
223
|
for idx, module in enumerate(self.body.children()):
|
|
224
224
|
if idx >= up_to_stage:
|
|
225
225
|
break
|
|
226
226
|
|
|
227
227
|
for param in module.parameters():
|
|
228
|
-
param.
|
|
228
|
+
param.requires_grad_(False)
|
|
229
229
|
|
|
230
230
|
def masked_encoding_retention(
|
|
231
231
|
self,
|
birder/net/convnext_v1.py
CHANGED
|
@@ -158,14 +158,14 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
158
158
|
|
|
159
159
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
160
160
|
for param in self.stem.parameters():
|
|
161
|
-
param.
|
|
161
|
+
param.requires_grad_(False)
|
|
162
162
|
|
|
163
163
|
for idx, module in enumerate(self.body.children()):
|
|
164
164
|
if idx >= up_to_stage:
|
|
165
165
|
break
|
|
166
166
|
|
|
167
167
|
for param in module.parameters():
|
|
168
|
-
param.
|
|
168
|
+
param.requires_grad_(False)
|
|
169
169
|
|
|
170
170
|
def masked_encoding_retention(
|
|
171
171
|
self,
|
|
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
195
195
|
return self.features(x)
|
|
196
196
|
|
|
197
197
|
|
|
198
|
+
registry.register_model_config(
|
|
199
|
+
"convnext_v1_atto", # Not in the original v1, taken from v2
|
|
200
|
+
ConvNeXt_v1,
|
|
201
|
+
config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
|
|
202
|
+
)
|
|
203
|
+
registry.register_model_config(
|
|
204
|
+
"convnext_v1_femto", # Not in the original v1, taken from v2
|
|
205
|
+
ConvNeXt_v1,
|
|
206
|
+
config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
|
|
207
|
+
)
|
|
208
|
+
registry.register_model_config(
|
|
209
|
+
"convnext_v1_pico", # Not in the original v1, taken from v2
|
|
210
|
+
ConvNeXt_v1,
|
|
211
|
+
config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
|
|
212
|
+
)
|
|
198
213
|
registry.register_model_config(
|
|
199
214
|
"convnext_v1_nano", # Not in the original v1, taken from v2
|
|
200
215
|
ConvNeXt_v1,
|
|
@@ -220,6 +235,11 @@ registry.register_model_config(
|
|
|
220
235
|
ConvNeXt_v1,
|
|
221
236
|
config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
|
|
222
237
|
)
|
|
238
|
+
registry.register_model_config(
|
|
239
|
+
"convnext_v1_huge", # Not in the original v1, taken from v2
|
|
240
|
+
ConvNeXt_v1,
|
|
241
|
+
config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
|
|
242
|
+
)
|
|
223
243
|
|
|
224
244
|
registry.register_weights(
|
|
225
245
|
"convnext_v1_tiny_eu-common256px",
|
birder/net/convnext_v2.py
CHANGED
|
@@ -180,14 +180,14 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
180
180
|
|
|
181
181
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
182
182
|
for param in self.stem.parameters():
|
|
183
|
-
param.
|
|
183
|
+
param.requires_grad_(False)
|
|
184
184
|
|
|
185
185
|
for idx, module in enumerate(self.body.children()):
|
|
186
186
|
if idx >= up_to_stage:
|
|
187
187
|
break
|
|
188
188
|
|
|
189
189
|
for param in module.parameters():
|
|
190
|
-
param.
|
|
190
|
+
param.requires_grad_(False)
|
|
191
191
|
|
|
192
192
|
def masked_encoding_retention(
|
|
193
193
|
self,
|
birder/net/crossformer.py
CHANGED
|
@@ -404,14 +404,14 @@ class CrossFormer(DetectorBackbone):
|
|
|
404
404
|
|
|
405
405
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
406
406
|
for param in self.patch_embed.parameters():
|
|
407
|
-
param.
|
|
407
|
+
param.requires_grad_(False)
|
|
408
408
|
|
|
409
409
|
for idx, module in enumerate(self.body.children()):
|
|
410
410
|
if idx >= up_to_stage:
|
|
411
411
|
break
|
|
412
412
|
|
|
413
413
|
for param in module.parameters():
|
|
414
|
-
param.
|
|
414
|
+
param.requires_grad_(False)
|
|
415
415
|
|
|
416
416
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
417
417
|
x = self.patch_embed(x)
|
birder/net/cspnet.py
CHANGED
|
@@ -342,14 +342,14 @@ class CSPNet(DetectorBackbone):
|
|
|
342
342
|
|
|
343
343
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
344
344
|
for param in self.stem.parameters():
|
|
345
|
-
param.
|
|
345
|
+
param.requires_grad_(False)
|
|
346
346
|
|
|
347
347
|
for idx, module in enumerate(self.body.children()):
|
|
348
348
|
if idx >= up_to_stage:
|
|
349
349
|
break
|
|
350
350
|
|
|
351
351
|
for param in module.parameters():
|
|
352
|
-
param.
|
|
352
|
+
param.requires_grad_(False)
|
|
353
353
|
|
|
354
354
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
355
355
|
x = self.stem(x)
|
birder/net/cswin_transformer.py
CHANGED
|
@@ -359,14 +359,14 @@ class CSWin_Transformer(DetectorBackbone):
|
|
|
359
359
|
|
|
360
360
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
361
361
|
for param in self.stem.parameters():
|
|
362
|
-
param.
|
|
362
|
+
param.requires_grad_(False)
|
|
363
363
|
|
|
364
364
|
for idx, module in enumerate(self.body.children()):
|
|
365
365
|
if idx >= up_to_stage:
|
|
366
366
|
break
|
|
367
367
|
|
|
368
368
|
for param in module.parameters():
|
|
369
|
-
param.
|
|
369
|
+
param.requires_grad_(False)
|
|
370
370
|
|
|
371
371
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
372
372
|
x = self.stem(x)
|
birder/net/darknet.py
CHANGED
|
@@ -115,14 +115,14 @@ class Darknet(DetectorBackbone):
|
|
|
115
115
|
|
|
116
116
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
117
117
|
for param in self.stem.parameters():
|
|
118
|
-
param.
|
|
118
|
+
param.requires_grad_(False)
|
|
119
119
|
|
|
120
120
|
for idx, module in enumerate(self.body.children()):
|
|
121
121
|
if idx >= up_to_stage:
|
|
122
122
|
break
|
|
123
123
|
|
|
124
124
|
for param in module.parameters():
|
|
125
|
-
param.
|
|
125
|
+
param.requires_grad_(False)
|
|
126
126
|
|
|
127
127
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
128
128
|
x = self.stem(x)
|
birder/net/davit.py
CHANGED
|
@@ -391,14 +391,14 @@ class DaViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
391
391
|
|
|
392
392
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
393
393
|
for param in self.stem.parameters():
|
|
394
|
-
param.
|
|
394
|
+
param.requires_grad_(False)
|
|
395
395
|
|
|
396
396
|
for idx, module in enumerate(self.body.children()):
|
|
397
397
|
if idx >= up_to_stage:
|
|
398
398
|
break
|
|
399
399
|
|
|
400
400
|
for param in module.parameters():
|
|
401
|
-
param.
|
|
401
|
+
param.requires_grad_(False)
|
|
402
402
|
|
|
403
403
|
def masked_encoding_retention(
|
|
404
404
|
self,
|
birder/net/deit.py
CHANGED
|
@@ -117,14 +117,14 @@ class DeiT(BaseNet):
|
|
|
117
117
|
|
|
118
118
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
119
119
|
for param in self.parameters():
|
|
120
|
-
param.
|
|
120
|
+
param.requires_grad_(False)
|
|
121
121
|
|
|
122
122
|
if freeze_classifier is False:
|
|
123
123
|
for param in self.classifier.parameters():
|
|
124
|
-
param.
|
|
124
|
+
param.requires_grad_(True)
|
|
125
125
|
|
|
126
126
|
for param in self.dist_classifier.parameters():
|
|
127
|
-
param.
|
|
127
|
+
param.requires_grad_(True)
|
|
128
128
|
|
|
129
129
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
130
130
|
self.encoder.set_causal_attention(is_causal)
|
birder/net/deit3.py
CHANGED
|
@@ -182,16 +182,16 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
182
182
|
|
|
183
183
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
184
184
|
for param in self.conv_proj.parameters():
|
|
185
|
-
param.
|
|
185
|
+
param.requires_grad_(False)
|
|
186
186
|
|
|
187
|
-
self.pos_embedding.
|
|
187
|
+
self.pos_embedding.requires_grad_(False)
|
|
188
188
|
|
|
189
189
|
for idx, module in enumerate(self.encoder.children()):
|
|
190
190
|
if idx >= up_to_stage:
|
|
191
191
|
break
|
|
192
192
|
|
|
193
193
|
for param in module.parameters():
|
|
194
|
-
param.
|
|
194
|
+
param.requires_grad_(False)
|
|
195
195
|
|
|
196
196
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
197
197
|
self.encoder.set_causal_attention(is_causal)
|