birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/vit.py
CHANGED
@@ -40,6 +40,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices


 def adjust_position_embedding(
@@ -73,12 +74,10 @@ adjust_position_embedding(
 class PatchEmbed(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-
+        This is equivalent (in output) to: x.flatten(2).transpose(1, 2)
         """

-
-
-        # (n, hidden_dim, h, w) -> (n, hidden_dim, (h * w))
+        n, hidden_dim, h, w = x.size()
         x = x.reshape(n, hidden_dim, h * w)

         # (n, hidden_dim, (h * w)) -> (n, (h * w), hidden_dim)
@@ -155,9 +154,9 @@ class Attention(nn.Module):
         - attn_weights: If need_weights is True attention weights, otherwise, None.
         """

-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)
         q = self.q_norm(q)
         k = self.k_norm(k)

@@ -245,7 +244,7 @@ class EncoderBlock(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.size()}")
-
+        attn_out, _ = self.attn(
             self.norm1(x),
             need_weights=self.need_attn,
             average_attn_weights=False,
@@ -317,13 +316,15 @@ class Encoder(nn.Module):
         x = self.pre_block(x)
         return self.block(x)

-    def forward_features(self, x: torch.Tensor) -> list[torch.Tensor]:
+    def forward_features(self, x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
         x = self.pre_block(x)

+        out_indices_set = set(out_indices) if out_indices is not None else None
         xs = []
-        for blk in self.block:
+        for idx, blk in enumerate(self.block):
             x = blk(x)
-            xs.append(x)
+            if out_indices_set is None or idx in out_indices_set:
+                xs.append(x)

         return xs

@@ -340,7 +341,7 @@ class Encoder(nn.Module):
 class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
     block_group_regex = r"encoder\.block\.(\d+)"

-    # pylint: disable=too-many-locals,too-many-branches
+    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def __init__(
         self,
         input_channels: int,
@@ -375,6 +376,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
         mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
         act_layer_type: Optional[str] = self.config.get("act_layer_type", None)  # Default according to mlp type
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]

         if norm_layer_type == "LayerNorm":
@@ -405,6 +407,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         self.hidden_dim = hidden_dim
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # Stochastic depth decay rule

         self.conv_proj = nn.Conv2d(
@@ -472,8 +475,9 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok

         self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -537,8 +541,12 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)

@@ -558,15 +566,20 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         if self.pos_embed_special_tokens is True:
             x = x + self._get_pos_embed(H, W)

-
-
+        if self.out_indices is None:
+            xs = [self.encoder(x)]
+        else:
+            xs = self.encoder.forward_features(x, out_indices=self.out_indices)

-
-
-
-
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x

-        return
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -589,7 +602,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -663,7 +676,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-
+        H, W = x.shape[-2:]

         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -694,7 +707,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         if return_keys in ("all", "features"):
             features = x[:, self.num_special_tokens :]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
             result["features"] = features

@@ -714,7 +727,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
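
The main functional change in vit.py (mirrored in vit_parallel.py below) is the new out_indices config key: intermediate encoder blocks can now be tapped as named detection stages instead of only the final block. A minimal self-contained sketch of that pattern, using stock PyTorch layers rather than birder's actual classes (every name and value here is illustrative):

    import torch
    import torch.nn as nn

    # Four generic transformer blocks stand in for birder's Encoder
    blocks = nn.ModuleList(nn.TransformerEncoderLayer(64, 4, 128, batch_first=True) for _ in range(4))
    out_indices = [1, 3]    # e.g. a config of {"out_indices": [1, 3]}
    num_special_tokens = 1  # a class token prepended to the patch tokens
    patch_size, H, W = 4, 16, 16

    x = torch.randn(2, num_special_tokens + (H // patch_size) * (W // patch_size), 64)
    xs = []
    for idx, blk in enumerate(blocks):  # same loop shape as Encoder.forward_features above
        x = blk(x)
        if idx in out_indices:
            xs.append(x)

    out = {}
    for stage_name, stage_x in zip([f"stage{i + 1}" for i in range(len(xs))], xs):
        stage_x = stage_x[:, num_special_tokens:].permute(0, 2, 1)  # drop special tokens -> (B, C, N)
        B, C, _ = stage_x.size()
        out[stage_name] = stage_x.reshape(B, C, H // patch_size, W // patch_size)

    print({k: tuple(v.shape) for k, v in out.items()})
    # {'stage1': (2, 64, 4, 4), 'stage2': (2, 64, 4, 4)}

Note also the new transform_to_backbone override, which replaces the final norm with nn.Identity() when the model is converted to a detection backbone.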
birder/net/vit_parallel.py
CHANGED
@@ -31,6 +31,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.vit import PatchEmbed
 from birder.net.vit import adjust_position_embedding

@@ -51,9 +52,9 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, is_causal=self.is_causal, scale=self.scale
@@ -172,11 +173,13 @@ class Encoder(nn.Module):

         return x

-    def forward_features(self, x: torch.Tensor) -> list[torch.Tensor]:
+    def forward_features(self, x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
         xs = []
-        for blk in self.block:
+        out_indices_set = set(out_indices) if out_indices is not None else None
+        for idx, blk in enumerate(self.block):
             x = blk(x)
-            xs.append(x)
+            if out_indices_set is None or idx in out_indices_set:
+                xs.append(x)

         return xs

@@ -213,6 +216,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         norm_layer_type: str = self.config.get("norm_layer_type", "LayerNorm")
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]

         if norm_layer_type == "LayerNorm":
@@ -230,6 +234,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.hidden_dim = hidden_dim
         self.layer_scale_init_value = layer_scale_init_value
         self.num_reg_tokens = num_reg_tokens
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # Stochastic depth decay rule

         self.conv_proj = nn.Conv2d(
@@ -238,7 +243,6 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             kernel_size=(patch_size, patch_size),
             stride=(patch_size, patch_size),
             padding=(0, 0),
-            bias=True,
         )
         self.patch_embed = PatchEmbed()

@@ -278,8 +282,9 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         )
         self.norm = norm_layer(hidden_dim, eps=1e-6)

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -338,8 +343,12 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)

@@ -354,15 +363,21 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         x = torch.concat([batch_reg_tokens, x], dim=1)

         x = x + self._get_pos_embed(H, W)
-        x = self.encoder(x)
-        x = self.norm(x)

-
-
-
-
+        if self.out_indices is None:
+            xs = [self.encoder(x)]
+        else:
+            xs = self.encoder.forward_features(x, out_indices=self.out_indices)
+
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x

-        return
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -384,7 +399,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -441,7 +456,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-
+        H, W = x.shape[-2:]

         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -467,7 +482,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         if return_keys in ("all", "features"):
             features = x[:, self.num_special_tokens :]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
             result["features"] = features

@@ -481,7 +496,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
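
Both ViT variants route the raw config value through normalize_out_indices from birder/net/base.py; that helper's body is not part of this diff (base.py appears above as +28 -3). A hypothetical stand-in, assuming it resolves negative indices the way Python list indexing does:

    from typing import Optional

    def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
        # Assumed semantics only -- the real implementation lives in birder/net/base.py
        # and may also validate bounds. Negative indices count from the last block.
        if out_indices is None:
            return None
        return [idx if idx >= 0 else num_layers + idx for idx in out_indices]

    print(normalize_out_indices([-4, -1], 12))  # [8, 11]
    print(normalize_out_indices(None, 12))      # None -> single-stage behavior is kept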
birder/net/vit_sam.py
CHANGED
@@ -35,7 +35,7 @@ from birder.net.vit import EncoderBlock as MAEDecoderBlock

 # pylint: disable=invalid-name
 def window_partition(x: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
-
+    B, H, W, C = x.shape

     pad_h = (window_size - H % window_size) % window_size
     pad_w = (window_size - W % window_size) % window_size
@@ -55,8 +55,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> tuple[torch.Tensor, t
 def window_unpartition(
     windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]
 ) -> torch.Tensor:
-
-
+    Hp, Wp = pad_hw
+    H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
     x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
@@ -91,12 +91,12 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
 def get_decomposed_rel_pos_bias(
     q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: tuple[int, int], k_size: tuple[int, int]
 ) -> torch.Tensor:
-
-
+    q_h, q_w = q_size
+    k_h, k_w = k_size
     Rh = get_rel_pos(q_h, k_h, rel_pos_h)
     Rw = get_rel_pos(q_w, k_w, rel_pos_w)

-
+    B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
@@ -139,9 +139,9 @@ class Attention(nn.Module):
         self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, H, W, _ = x.shape
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

         if self.use_rel_pos is True:
             attn_bias = get_decomposed_rel_pos_bias(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
@@ -216,13 +216,13 @@ class EncoderBlock(nn.Module):
         self.layer_scale_2 = nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        _, H, W, _ = x.shape
         shortcut = x

         x = self.norm1(x)
         pad_hw = (0, 0)
         if self.window_size > 0:
-
+            x, pad_hw = window_partition(x, self.window_size)

         x = self.attn(x)
         if self.window_size > 0:
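
The vit_sam.py hunks restore the unpacking lines inside the SAM-style windowed-attention helpers. The round trip below reuses the window_unpartition lines shown above; the parts of both bodies not visible in this diff follow the standard SAM pattern, so treat this as a sketch rather than birder's exact code:

    import torch
    import torch.nn.functional as F

    def window_partition(x, window_size):
        B, H, W, C = x.shape
        pad_h = (window_size - H % window_size) % window_size
        pad_w = (window_size - W % window_size) % window_size
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad W and H on the bottom/right
        Hp, Wp = H + pad_h, W + pad_w
        x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C), (Hp, Wp)

    def window_unpartition(windows, window_size, pad_hw, hw):
        Hp, Wp = pad_hw
        H, W = hw
        B = windows.shape[0] // (Hp * Wp // window_size // window_size)
        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
        return x[:, :H, :W, :].contiguous()  # crop the padding back off

    x = torch.randn(2, 14, 14, 32)            # channels-last, as in EncoderBlock.forward
    windows, pad_hw = window_partition(x, 8)  # 14x14 is padded to 16x16 -> 4 windows per image
    assert windows.shape == (8, 8, 8, 32)
    assert torch.equal(window_unpartition(windows, 8, pad_hw, (14, 14)), x)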
birder/net/vovnet_v2.py
CHANGED
@@ -27,7 +27,7 @@ class EffectiveSE(nn.Module):

     def __init__(self, channels: int) -> None:
         super().__init__()
-        self.fc = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
+        self.fc = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_se = x.mean(dim=(2, 3), keepdim=True)
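
The vovnet_v2.py change rewrites the nn.Conv2d call in EffectiveSE as a single self-closing line (the removed line appears truncated in this view). For context, an Effective-SE block assembled from the lines shown; the hardsigmoid gate is an assumption, since the hunk only shows the pooling and the 1x1 conv:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class EffectiveSE(nn.Module):
        def __init__(self, channels: int) -> None:
            super().__init__()
            self.fc = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x_se = x.mean(dim=(2, 3), keepdim=True)  # global average pool -> (B, C, 1, 1)
            x_se = self.fc(x_se)
            return x * F.hardsigmoid(x_se)           # assumed channel gate

    print(EffectiveSE(8)(torch.randn(1, 8, 4, 4)).shape)  # torch.Size([1, 8, 4, 4])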
birder/net/xcit.py
CHANGED
@@ -30,6 +30,7 @@ from birder.net.base import DetectorBackbone
 from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.cait import ClassAttention


@@ -212,7 +213,7 @@ class LPI(nn.Module):
         )

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-
+        B, N, C = x.shape
         x = x.permute(0, 2, 1).reshape(B, C, H, W)
         x = self.conv_bn_act(x)
         x = self.conv(x)
@@ -236,10 +237,10 @@ class XCA(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
         qkv = qkv.permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)

         q = F.normalize(q, dim=-1) * self.temperature
         k = F.normalize(k, dim=-1)
@@ -311,6 +312,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         else:
             raise ValueError(f"depth={depth} is not supported")

+        out_indices = normalize_out_indices(out_indices, depth)
         self.patch_embed = ConvPatchEmbed(patch_size, self.input_channels, dim=embed_dim)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # Stochastic depth decay rule
@@ -381,7 +383,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         B = x.size(0)

-
+        x, H, W = self.patch_embed(x)

         pos_encoding = self.pos_embed(B, H, W).reshape(B, -1, x.size(1)).permute(0, 2, 1)
         x = x + pos_encoding
@@ -414,7 +416,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
     ) -> TokenRetentionResultType:
         B = x.size(0)

-
+        x, H, W = self.patch_embed(x)
         x = mask_tensor(
             x.permute(0, 2, 1).reshape(B, -1, H, W),
             mask,
@@ -435,7 +437,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         if return_keys in ("all", "features"):
             features = x[:, 1:]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H, W)
             result["features"] = features

@@ -447,7 +449,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         B = x.size(0)

-
+        x, H, W = self.patch_embed(x)

         pos_encoding = self.pos_embed(B, H, W).reshape(B, -1, x.size(1)).permute(0, 2, 1)
         x = x + pos_encoding
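
Most xcit.py hunks likewise restore shape unpackings. The LPI one feeds the tokens-to-2D fold used for its local convolutions; a minimal round trip of that fold (the depthwise conv is a generic stand-in, not birder's LPI layers):

    import torch
    import torch.nn as nn

    B, H, W, C = 2, 14, 14, 64
    conv = nn.Conv2d(C, C, kernel_size=(3, 3), padding=(1, 1), groups=C)  # depthwise, keeps H x W

    x = torch.randn(B, H * W, C)                 # token sequence, N = H * W
    B, N, C = x.shape                            # the unpacking the hunk restores
    x = x.permute(0, 2, 1).reshape(B, C, H, W)   # (B, N, C) -> (B, C, H, W)
    x = conv(x)
    x = x.reshape(B, C, N).permute(0, 2, 1)      # back to (B, N, C)
    print(x.shape)  # torch.Size([2, 196, 64])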
birder/ops/msda.py
CHANGED
@@ -91,8 +91,8 @@ def _ms_deform_attn_setup_context( # type: ignore[no-untyped-def] # pylint: dis


 def _ms_deform_attn_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-
-
+    value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+    grad_value, grad_sampling_loc, grad_attn_weight = ms_deform_attn_backward_op(
         value,
         value_spatial_shapes,
         value_level_start_index,
@@ -160,8 +160,8 @@ def multi_scale_deformable_attention(
     attention_weights: torch.Tensor,
     im2col_step: int,  # pylint: disable=unused-argument
 ) -> torch.Tensor:
-
-
+    batch_size, _, num_heads, hidden_dim = value.size()
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.size()
     areas: list[int] = value_spatial_shapes.prod(dim=1).tolist()
     value_list = value.split(areas, dim=1)
     sampling_grids = 2 * sampling_locations - 1
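
The msda.py hunks restore the unpacking at the top of the pure-PyTorch fallback for multi-scale deformable attention; the unpacked names spell out the expected input shapes. Illustrative values below (level sizes and counts are arbitrary):

    import torch

    batch_size, num_heads, hidden_dim = 2, 8, 32
    spatial_shapes = torch.tensor([[32, 32], [16, 16]])     # (num_levels, 2), two feature levels
    num_levels = spatial_shapes.size(0)
    num_queries, num_points = 100, 4

    len_in = int(spatial_shapes.prod(dim=1).sum())          # 32*32 + 16*16 = 1280 positions
    value = torch.randn(batch_size, len_in, num_heads, hidden_dim)
    sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points)

    areas: list[int] = spatial_shapes.prod(dim=1).tolist()  # [1024, 256], as in the hunk
    value_list = value.split(areas, dim=1)                  # one value tensor per level
    sampling_grids = 2 * sampling_locations - 1             # [0, 1] -> [-1, 1] for grid_sample
    print([tuple(v.shape) for v in value_list])
    # [(2, 1024, 8, 32), (2, 256, 8, 32)]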
birder/ops/swattention.py
CHANGED
@@ -38,7 +38,7 @@ def _swattention_qk_rpb_fake( # pylint: disable=unused-argument
 def _swattention_qk_rpb_setup_context(  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
     ctx, inputs, output
 ) -> None:
-
+    query, key, _rpb, height, width, kernel_size = inputs
     ctx.save_for_backward(query, key)
     ctx.height = height
     ctx.width = width
@@ -46,8 +46,8 @@ def _swattention_qk_rpb_setup_context( # type: ignore[no-untyped-def] # pylint:


 def _swattention_qk_rpb_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-
-
+    query, key = ctx.saved_tensors
+    d_query, d_key, d_rpb = swattention_qk_rpb_backward_op(
         grad_output.contiguous(), query, key, ctx.height, ctx.width, ctx.kernel_size
     )
     return (d_query, d_key, d_rpb, None, None, None)
@@ -107,8 +107,8 @@ def _swattention_av_setup_context( # type: ignore[no-untyped-def] # pylint: dis


 def _swattention_av_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-
-
+    attn_weight, value = ctx.saved_tensors
+    d_attn_weight, d_value = swattention_av_backward_op(
         grad_output.contiguous(), attn_weight, value, ctx.height, ctx.width, ctx.kernel_size
     )
     return (d_attn_weight, d_value, None, None, None)
@@ -184,10 +184,10 @@ class SWAttention_QK_RPB(nn.Module):
         )

         # Custom kernel
-
+        B, N, _ = kv.size()

         # Generate unfolded keys and values and l2-normalize them
-
+        k_local, v_local = kv.reshape(B, N, 2 * num_heads, head_dim).permute(0, 2, 1, 3).chunk(2, dim=1)

         # Compute local similarity
         attn_local = swattention_qk_rpb_op(
@@ -254,14 +254,14 @@ def swattention_qk_rpb(
     H: int,
     W: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-
+    B, N, _ = kv.size()

     # Generate unfolded keys and values and l2-normalize them
-
+    k_local, v_local = kv.chunk(2, dim=-1)
     k_local = F.normalize(k_local.reshape(B, N, num_heads, head_dim), dim=-1).reshape(B, N, -1)
     kv_local = torch.concat([k_local, v_local], dim=-1).permute(0, 2, 1).reshape(B, -1, H, W)

-
+    k_local, v_local = (
         F.unfold(kv_local, kernel_size=window_size, padding=window_size // 2, stride=1)
         .reshape(B, 2 * num_heads, head_dim, local_len, N)
         .permute(0, 1, 4, 2, 3)
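
The swattention.py hunks restore the unpacking inside the setup_context/backward pairs that wire the CUDA kernels into autograd through torch.library. A self-contained sketch of the same registration pattern (requires torch >= 2.4; the trivial op here stands in for the swattention kernels):

    import torch

    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    @scale.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        return torch.empty_like(x)  # fake/meta kernel, mirroring _swattention_*_fake

    def _setup_context(ctx, inputs, output) -> None:
        x, factor = inputs          # same unpacking style the hunks restore
        ctx.factor = factor

    def _backward(ctx, grad_output):
        return grad_output * ctx.factor, None  # None for the non-tensor input

    torch.library.register_autograd("demo::scale", _backward, setup_context=_setup_context)

    x = torch.randn(3, requires_grad=True)
    scale(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])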
birder/results/classification.py
CHANGED
@@ -30,7 +30,7 @@ def top_k_accuracy_score(y_true: npt.NDArray[Any], y_pred: npt.NDArray[np.float6
     if len(y_true.shape) == 2:
         y_true = np.argmax(y_true, axis=1)

-
+    num_samples, _num_labels = y_pred.shape
     indices: list[int] = []
     arg_sorted = np.argpartition(y_pred, -top_k, axis=1)[:, -top_k:]
     for i in range(num_samples):
@@ -693,7 +693,7 @@ class SparseResults(Results):
             For sparse files, this value is ignored.
         """

-
+        label_names, detected_sparse_k = detect_file_format(path)

         if detected_sparse_k is not None:
             schema_overrides = {
@@ -817,7 +817,7 @@ def load_results(path: str, lazy: bool = True) -> Results | SparseResults:
     <class 'birder.results.classification.SparseResults'>
     """

-
+    _, sparse_k = detect_file_format(path)

     # Load using appropriate class
     if sparse_k is not None:
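
The first classification.py hunk restores the shape unpacking that feeds top_k_accuracy_score's np.argpartition trick: take the top_k highest-scoring columns per row (order does not matter) and test membership of the true label. A small worked example:

    import numpy as np

    y_true = np.array([0, 2, 1])
    y_pred = np.array([
        [0.7, 0.2, 0.1],  # top-2 = {0, 1} -> hit
        [0.1, 0.5, 0.4],  # top-2 = {1, 2} -> hit
        [0.5, 0.2, 0.3],  # top-2 = {0, 2} -> miss
    ])
    top_k = 2

    num_samples, _num_labels = y_pred.shape
    arg_sorted = np.argpartition(y_pred, -top_k, axis=1)[:, -top_k:]  # unordered top-k indices
    hits = [y_true[i] in arg_sorted[i] for i in range(num_samples)]
    print(sum(hits) / num_samples)  # 0.666...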
birder/results/gui.py
CHANGED
@@ -212,7 +212,7 @@ class ConfusionMatrix:
         )

         offset = 0.5
-
+        height, width = cnf_matrix.shape
         ax.hlines(
             y=np.arange(height + 1) - offset,
             xmin=-offset,
@@ -261,7 +261,7 @@ class ROC:
         roc_auc = {}
         for i in results.unique_labels:
             binary_labels = results.labels == i
-
+            fpr[i], tpr[i], _ = roc_curve(binary_labels, results.output[:, i])
             if np.sum(binary_labels) == 0:
                 tpr[i] = np.zeros_like(fpr[i])

@@ -324,7 +324,7 @@ class PrecisionRecall:
         labels = label_binarize(results.labels, classes=range(len(results.label_names)))

         # A "micro-average" quantifying score on all classes jointly
-
+        precision, recall, _ = precision_recall_curve(labels.ravel(), results.output.ravel())
         average_precision = average_precision_score(labels.ravel(), results.output.ravel(), average="micro")

         line = ax.step(recall, precision, linestyle=":", where="post")
@@ -334,7 +334,7 @@ class PrecisionRecall:
         # Per selected class
         for cls in pr_classes:
             i = results.label_names.index(cls)
-
+            precision, recall, _ = precision_recall_curve(labels[:, i], results.output[:, i])
             average_precision = average_precision_score(labels[:, i], results.output[:, i])
             line = ax.plot(recall, precision, lw=2)
             legend_lines.append(line[0])
@@ -372,8 +372,8 @@ class ProbabilityHistogram:
         cls_a_df = results_df.filter(pl.col("label_name") == cls_a)
         cls_b_df = results_df.filter(pl.col("label_name") == cls_b)

-
-
+        cls_a_prob_a_counts, cls_a_prob_a_bins = hist(cls_a_df[str(self.results.label_names.index(cls_a))])
+        cls_a_prob_b_counts, cls_a_prob_b_bins = hist(cls_b_df[str(self.results.label_names.index(cls_a))])
         plt.subplot(2, 1, 1)
         plt.stairs(
             cls_a_prob_a_counts,
@@ -391,8 +391,8 @@ class ProbabilityHistogram:
         )
         plt.legend(loc="upper center")

-
-
+        cls_b_prob_a_counts, cls_b_prob_a_bins = hist(cls_a_df[str(self.results.label_names.index(cls_b))])
+        cls_b_prob_b_counts, cls_b_prob_b_bins = hist(cls_b_df[str(self.results.label_names.index(cls_b))])
         plt.subplot(2, 1, 2)
         plt.stairs(
             cls_b_prob_b_counts,
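
The PrecisionRecall hunks restore the precision_recall_curve calls for the micro-averaged curve and the per-class curves. Standalone, the micro-average variant amounts to binarizing the labels and flattening both arrays so all classes are scored jointly (toy data below):

    import numpy as np
    from sklearn.metrics import average_precision_score, precision_recall_curve
    from sklearn.preprocessing import label_binarize

    labels = label_binarize([0, 1, 2, 1], classes=range(3))  # (4, 3) one-hot matrix
    output = np.array([
        [0.8, 0.1, 0.1],
        [0.2, 0.6, 0.2],
        [0.3, 0.3, 0.4],
        [0.1, 0.7, 0.2],
    ])

    precision, recall, _ = precision_recall_curve(labels.ravel(), output.ravel())
    average_precision = average_precision_score(labels.ravel(), output.ravel(), average="micro")
    print(round(float(average_precision), 3))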