birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/mim/simmim.py
CHANGED
@@ -80,7 +80,6 @@ class SimMIM(MIMBaseNet):
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )

         self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder.stem_width))
@@ -112,7 +111,7 @@ class SimMIM(MIMBaseNet):
         """

         if x.ndim == 4:
-            (n, c, _, _) = x.shape
+            n, c, _, _ = x.shape
             x = x.reshape(n, c, -1)
             x = torch.einsum("ncl->nlc", x)

@@ -135,7 +134,7 @@ class SimMIM(MIMBaseNet):
         mask: 0 is keep, 1 is remove
         """

-        (N, C, _, _) = pred.shape
+        N, C, _, _ = pred.shape
         pred = pred.reshape(N, C, -1)
         pred = torch.einsum("ncl->nlc", pred)

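Two mechanical cleanups recur throughout this release: explicit `bias=True` arguments are dropped from `nn.Conv2d` calls, and redundant parentheses are removed from tuple unpacking. Both appear to be pure style changes; a minimal sketch (shapes illustrative, not taken from SimMIM):

import torch
from torch import nn

# nn.Conv2d creates a bias parameter by default, so removing an explicit
# bias=True argument does not change the module.
conv = nn.Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
assert conv.bias is not None

# Dropping redundant parentheses around tuple unpacking is likewise a no-op;
# both forms bind the same names.
x = torch.randn(2, 8, 4, 4)
(n, c, _, _) = x.shape  # old style
n2, c2, _, _ = x.shape  # new style
assert (n, c) == (n2, c2)

These two patterns appear to account for many of the small +/- counts in the per-file list above.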
birder/net/mobilenet_v4_hybrid.py
CHANGED
@@ -142,24 +142,24 @@ class MultiQueryAttention(nn.Module):
         self.output = nn.Sequential(*output_layers)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         q = self.query(x)
         q = q.reshape(B, self.num_heads, self.key_dim, -1)
         q = q.transpose(-1, -2).contiguous()

         k = self.key(x)
-        (B, C, _, _) = k.size()
+        B, C, _, _ = k.size()
         k = k.reshape(B, C, -1).transpose(1, 2)
         k = k.unsqueeze(1).contiguous()

         v = self.value(x)
-        (B, C, _, _) = v.size()
+        B, C, _, _ = v.size()
         v = v.reshape(B, C, -1).transpose(1, 2)
         v = v.unsqueeze(1).contiguous()

         # Calculate attention score
         attn_score = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
-        (B, _, _, C) = attn_score.size()
+        B, _, _, C = attn_score.size()
         feat_dim = C * self.num_heads
         attn_score = attn_score.transpose(1, 2)
         attn_score = (
birder/net/mobileone.py
CHANGED
@@ -61,13 +61,7 @@ class MobileOneBlock(nn.Module):

         if reparameterized is True:
             self.reparam_conv = nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                groups=groups,
-                bias=True,
+                in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups
             )
         else:
             self.reparam_conv = None
@@ -144,7 +138,7 @@ class MobileOneBlock(nn.Module):
         if self.reparameterized is True:
             return

-        (kernel, bias) = self._get_kernel_bias()
+        kernel, bias = self._get_kernel_bias()
         self.reparam_conv = nn.Conv2d(
             in_channels=self.in_channels,
             out_channels=self.out_channels,
@@ -152,7 +146,6 @@ class MobileOneBlock(nn.Module):
             stride=self.stride,
             padding=self.padding,
             groups=self.groups,
-            bias=True,
         )
         self.reparam_conv.weight.data = kernel
         self.reparam_conv.bias.data = bias
@@ -178,7 +171,7 @@ class MobileOneBlock(nn.Module):
         kernel_scale = 0
         bias_scale = 0
         if self.rbr_scale is not None:
-            (kernel_scale, bias_scale) = self._fuse_bn_tensor(self.rbr_scale)
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
             pad = self.kernel_size // 2
             kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])

@@ -186,13 +179,13 @@ class MobileOneBlock(nn.Module):
         kernel_identity = 0
         bias_identity = 0
         if self.rbr_skip is not None:
-            (kernel_identity, bias_identity) = self._fuse_bn_tensor(self.rbr_skip)
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

         # Get weights and bias of conv branches
         kernel_conv = 0
         bias_conv = 0
         for ix in range(self.num_conv_branches):
-            (_kernel, _bias) = self._fuse_bn_tensor(self.rbr_conv[ix])
+            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
             kernel_conv += _kernel
             bias_conv += _bias

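The `_get_kernel_bias()` path above folds each BatchNorm branch into its convolution via `_fuse_bn_tensor` and sums the results into the single `reparam_conv`. A minimal sketch of that conv+BN folding, assuming a bias-free `nn.Conv2d` followed by `nn.BatchNorm2d` (the helper name here is illustrative, not birder's API):

import torch
from torch import nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
    # Eval-mode BN computes gamma * (y - mean) / sqrt(var + eps) + beta, which
    # folds into the conv as a per-output-channel scale plus a bias term.
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std
    kernel = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * scale
    return (kernel, bias)

Summing the fused (kernel, bias) pairs across the scale, identity, and conv branches yields the single inference-time convolution the hunks above assemble.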
birder/net/mobilevit_v1.py
CHANGED
@@ -101,8 +101,8 @@ class MobileVitBlock(nn.Module):
         x = self.conv_1x1(x)

         # Unfold (feature map -> patches)
-        (patch_h, patch_w) = self.patch_size
-        (B, C, H, W) = x.shape
+        patch_h, patch_w = self.patch_size
+        B, C, H, W = x.shape
         new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
         num_patch_h = new_h // patch_h  # n_h, n_w
         num_patch_w = new_w // patch_w
birder/net/mobilevit_v2.py
CHANGED
@@ -63,7 +63,7 @@ class LinearSelfAttention(nn.Module):
         # Project x into query, key and value
         # Query --> [B, 1, P, N]
         # value, key --> [B, d, P, N]
-        (query, key, value) = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
+        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)

         # apply softmax along N dimension
         context_scores = F.softmax(query, dim=-1)
@@ -98,14 +98,10 @@ class LinearTransformerBlock(nn.Module):

         self.norm2 = nn.GroupNorm(num_groups=1, num_channels=embed_dim)
         self.mlp = nn.Sequential(
-            nn.Conv2d(
-                embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             nn.SiLU(),
             nn.Dropout(drop),
-            nn.Conv2d(
-                int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )
         self.drop_path2 = StochasticDepth(drop_path, mode="row")

@@ -166,8 +162,8 @@ class MobileVitBlock(nn.Module):
         self.patch_area = self.patch_size[0] * self.patch_size[1]

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.shape
-        (patch_h, patch_w) = self.patch_size
+        B, C, H, W = x.shape
+        patch_h, patch_w = self.patch_size
         new_h = math.ceil(H / patch_h) * patch_h
         new_w = math.ceil(W / patch_w) * patch_w
         num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
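Both MobileViT block variants round the feature map up to a multiple of the patch size before unfolding into patches, as in the `forward` hunks above. A small sketch of that grid arithmetic (the function name is illustrative):

import math

def padded_grid(h: int, w: int, patch_h: int, patch_w: int) -> tuple[int, int, int, int]:
    # Round spatial dims up to patch-size multiples, then count patches per axis.
    new_h = math.ceil(h / patch_h) * patch_h
    new_w = math.ceil(w / patch_w) * patch_w
    return (new_h, new_w, new_h // patch_h, new_w // patch_w)

assert padded_grid(10, 10, 2, 2) == (10, 10, 5, 5)
assert padded_grid(9, 10, 2, 2) == (10, 10, 5, 5)  # 9 rows padded up to 10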
birder/net/mvit_v2.py
CHANGED
@@ -36,7 +36,7 @@ from birder.net.base import TokenRetentionResultType
 def pre_pool(
     x: torch.Tensor, hw_shape: tuple[int, int], has_cls_token: bool
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    (H, W) = hw_shape
+    H, W = hw_shape
     if has_cls_token is True:
         cls_tok = x[:, :, :1, :]
         x = x[:, :, 1:, :]
@@ -68,8 +68,8 @@ def cal_rel_pos_spatial(
     rel_pos_w: torch.Tensor,
 ) -> torch.Tensor:
     sp_idx = 1 if has_cls_token is True else 0
-    (q_h, q_w) = q_shape
-    (k_h, k_w) = k_shape
+    q_h, q_w = q_shape
+    k_h, k_w = k_shape

     # Scale up rel pos if shapes for q and k are different.
     q_h_ratio = max(k_h / q_h, 1.0)
@@ -90,7 +90,7 @@ def cal_rel_pos_spatial(
     rel_h = rel_pos_h[dist_h.long()]
     rel_w = rel_pos_w[dist_w.long()]

-    (B, n_head, _, dim) = q.shape
+    B, n_head, _, dim = q.shape

     r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
     rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
@@ -108,7 +108,7 @@ class SequentialWithShape(nn.Sequential):
         self, x: torch.Tensor, hw_shape: tuple[int, int]
     ) -> tuple[torch.Tensor, tuple[int, int]]:
         for module in self:
-            (x, hw_shape) = module(x, hw_shape)
+            x, hw_shape = module(x, hw_shape)

         return (x, hw_shape)

@@ -129,7 +129,7 @@ class PatchEmbed(nn.Module):

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
         x = self.proj(x)
-        (H, W) = x.shape[2:4]
+        H, W = x.shape[2:4]

         x = x.flatten(2).transpose(1, 2)

@@ -227,31 +227,31 @@ class MultiScaleAttention(nn.Module):
         nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
-        (B, N, _) = x.size()
+        B, N, _ = x.size()

         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(dim=0)
+        q, k, v = qkv.unbind(dim=0)

         if self.pool_q is not None:
-            (q, q_tok) = pre_pool(q, hw_shape, self.has_cls_token)
+            q, q_tok = pre_pool(q, hw_shape, self.has_cls_token)
             q = self.pool_q(q)
-            (q, q_shape) = post_pool(q, self.num_heads, q_tok)
+            q, q_shape = post_pool(q, self.num_heads, q_tok)
             q = self.norm_q(q)
         else:
             q_shape = hw_shape

         if self.pool_k is not None:
-            (k, k_tok) = pre_pool(k, hw_shape, self.has_cls_token)
+            k, k_tok = pre_pool(k, hw_shape, self.has_cls_token)
             k = self.pool_k(k)
-            (k, k_shape) = post_pool(k, self.num_heads, k_tok)
+            k, k_shape = post_pool(k, self.num_heads, k_tok)
             k = self.norm_k(k)
         else:
             k_shape = hw_shape

         if self.pool_v is not None:
-            (v, v_tok) = pre_pool(v, hw_shape, self.has_cls_token)
+            v, v_tok = pre_pool(v, hw_shape, self.has_cls_token)
             v = self.pool_v(v)
-            (v, _) = post_pool(v, self.num_heads, v_tok)
+            v, _ = post_pool(v, self.num_heads, v_tok)
             v = self.norm_v(v)

         attn = (q * self.scale) @ k.transpose(-2, -1)
@@ -337,8 +337,8 @@ class MultiScaleBlock(nn.Module):
         else:
             cls_tok = None

-        (B, _, C) = x.size()
-        (H, W) = hw_shape
+        B, _, C = x.size()
+        H, W = hw_shape
         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
         x = self.pool_skip(x)
         x = x.reshape(B, C, -1).transpose(1, 2)
@@ -349,7 +349,7 @@ class MultiScaleBlock(nn.Module):

     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
         x_norm = self.norm1(x)
-        (x_block, hw_shape_new) = self.attn(x_norm, hw_shape)
+        x_block, hw_shape_new = self.attn(x_norm, hw_shape)

         if self.proj_attn is not None:
             x = self.proj_attn(x_norm)
@@ -421,7 +421,7 @@ class MultiScaleVitStage(nn.Module):

     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
         for blk in self.blocks:
-            (x, hw_shape) = blk(x, hw_shape)
+            x, hw_shape = blk(x, hw_shape)

         return (x, hw_shape)

@@ -523,14 +523,14 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             nn.init.zeros_(m.bias)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
             x = torch.concat((cls_tokens, x), dim=1)

         out = {}
         for name, module in self.body.named_children():
-            (x, hw_shape) = module(x, hw_shape)
+            x, hw_shape = module(x, hw_shape)
             if name in self.return_stages:
                 x_inter = x
                 if self.cls_token is not None:
@@ -561,7 +561,7 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
     ) -> TokenRetentionResultType:
         B = x.size(0)

-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         x = mask_tensor(
             x.permute(0, 2, 1).reshape(B, -1, hw_shape[0], hw_shape[1]),
             mask,
@@ -574,7 +574,7 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             cls_tokens = self.cls_token.expand(B, -1, -1)
             x = torch.concat((cls_tokens, x), dim=1)

-        (x, _) = self.body(x, hw_shape)
+        x, _ = self.body(x, hw_shape)
         x = self.norm(x)

         result: TokenRetentionResultType = {}
@@ -596,12 +596,12 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
             x = torch.concat((cls_tokens, x), dim=1)

-        (x, _) = self.body(x, hw_shape)
+        x, _ = self.body(x, hw_shape)
         x = self.norm(x)

         return x
birder/net/nextvit.py
CHANGED
@@ -165,7 +165,7 @@ class E_MHSA(nn.Module):
             self.norm = nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         q = self.q(x)
         q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

@@ -226,7 +226,7 @@ class NTB(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         out = self.norm1(x)

         out = out.reshape(B, C, H * W).permute(0, 2, 1)
birder/net/pit.py
CHANGED
@@ -29,12 +29,12 @@ class SequentialTuple(nn.Sequential):
         self, x: tuple[torch.Tensor, torch.Tensor]
     ) -> tuple[torch.Tensor, torch.Tensor]:
         for module in self:
-            x = module(x)
+            x = module(*x)

         return x


-class Transformer(nn.Module):
+class PiTStage(nn.Module):
     def __init__(
         self,
         base_dim: int,
@@ -59,13 +59,12 @@ class Transformer(nn.Module):
             dpr=drop_path_prob,
         )

-    def forward(self, xt: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        (x, cls_tokens) = xt
+    def forward(self, x: torch.Tensor, cls_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         token_length = cls_tokens.shape[1]
         if self.pool is not None:
-            (x, cls_tokens) = self.pool(x, cls_tokens)
+            x, cls_tokens = self.pool(x, cls_tokens)

-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = torch.concat((cls_tokens, x), dim=1)
         x = self.encoder(x)
@@ -142,7 +141,7 @@ class PiT(DetectorBackbone):
             if i > 0:
                 pool = Pooling(prev_dim, embed_dim)

-            stages[f"stage{i+1}"] = Transformer(
+            stages[f"stage{i+1}"] = PiTStage(
                 base_dims[i],
                 depth,
                 heads=heads[i],
@@ -158,7 +157,7 @@ class PiT(DetectorBackbone):
         self.body = SequentialTuple(stages)
         self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

-        self.return_stages =
+        self.return_stages = [f"stage{idx + 1}" for idx in range(len(depths))]
         self.return_channels = return_channels
         self.embedding_size = embed_dim
         self.dist_classifier = self.create_classifier()
@@ -197,7 +196,7 @@ class PiT(DetectorBackbone):

         out = {}
         for name, module in self.body.named_children():
-            (x, cls_tokens) = module((x, cls_tokens))
+            x, cls_tokens = module(x, cls_tokens)
             if name in self.return_stages:
                 out[name] = x

@@ -218,12 +217,13 @@ class PiT(DetectorBackbone):
         x = self.stem(x)
         x = x + self.pos_embed
         cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        (x, cls_tokens) = self.body((x, cls_tokens))
+        for stage in self.body.children():
+            x, cls_tokens = stage(x, cls_tokens)

         return (x, cls_tokens)

     def embedding(self, x: torch.Tensor) -> torch.Tensor:
-        (x, cls_tokens) = self.forward_features(x)
+        _, cls_tokens = self.forward_features(x)
         cls_tokens = self.norm(cls_tokens)

         return cls_tokens
@@ -312,18 +312,3 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
-
-registry.register_weights(
-    "pit_t_il-common",
-    {
-        "description": "PiT tiny model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 18.4,
-                "sha256": "5f6bd74b09c1ee541ee2ddae4844ce501b4b3218201ea6381fce0b8fc30257f2",
-            }
-        },
-        "net": {"network": "pit_t", "tag": "il-common"},
-    },
-)
birder/net/pvt_v1.py
CHANGED
@@ -56,7 +56,7 @@ class Attention(nn.Module):
             self.norm = None

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

         if self.sr is not None:
@@ -65,7 +65,7 @@ class Attention(nn.Module):
             x = self.norm(x)

         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -177,7 +177,7 @@ class PyramidVisionTransformerStage(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)  # B, C, H, W -> B, H, W, C
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, -1, C)
         x = x + self.pos_embed
         if self.cls_token is not None:
@@ -264,7 +264,7 @@ class PVT_v1(DetectorBackbone):

         out = {}
         for name, module in self.body.named_children():
-            (B, _, H, W) = x.size()
+            B, _, H, W = x.size()
             x = module(x)
             if name in self.return_stages:
                 if name == "stage4":
birder/net/pvt_v2.py
CHANGED
@@ -29,13 +29,7 @@ class MLP(nn.Module):
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.relu = nn.ReLU() if extra_relu else nn.Identity()
         self.dwconv = nn.Conv2d(
-            hidden_features,
-            hidden_features,
-            kernel_size=(3, 3),
-            stride=(1, 1),
-            padding=(1, 1),
-            groups=hidden_features,
-            bias=True,
+            hidden_features, hidden_features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=hidden_features
         )
         self.act = nn.GELU()
         self.fc2 = nn.Linear(hidden_features, in_features)
@@ -44,7 +38,7 @@ class MLP(nn.Module):
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         x = self.fc1(x)
         x = self.relu(x)
-        (B, _, C) = x.shape
+        B, _, C = x.shape
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -98,7 +92,7 @@ class Attention(nn.Module):
         assert (self.pool is None and self.act is None) or (self.pool is not None and self.act is not None)

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

         if self.pool is not None and self.act is not None:
@@ -114,7 +108,7 @@ class Attention(nn.Module):
             x = self.norm(x)

         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -238,7 +232,7 @@ class PyramidVisionTransformerStage(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)  # B, C, H, W -> B, H, W, C
-        (B, H, W, C) = x.shape
+        B, H, W, C = x.shape
         x = x.reshape(B, -1, C)
         for blk in self.blocks:
             x = blk(x, H, W)
birder/net/regionvit.py
CHANGED
@@ -30,8 +30,8 @@ def convert_to_flatten_layout(
     cls_tokens: torch.Tensor, patch_tokens: torch.Tensor, ws: int
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], int, int, int, int, int, int]:
     # Padding if added will be at the bottom right
-    (B, C, H, W) = patch_tokens.size()
-    (_, _, h_ks, w_ks) = cls_tokens.size()
+    B, C, H, W = patch_tokens.size()
+    _, _, h_ks, w_ks = cls_tokens.size()
     need_mask = False
     p_l = 0
     p_r = 0
@@ -43,13 +43,13 @@ def convert_to_flatten_layout(
         patch_tokens = F.pad(patch_tokens, (p_l, p_r, p_t, p_b))
         need_mask = True

-    (B, C, H, W) = patch_tokens.size()
+    B, C, H, W = patch_tokens.size()
     kernel_size = (H // h_ks, W // w_ks)
     tmp = F.unfold(patch_tokens, kernel_size=kernel_size, dilation=(1, 1), padding=(0, 0), stride=kernel_size)
     patch_tokens = tmp.transpose(1, 2).reshape(-1, C, kernel_size[0] * kernel_size[1]).transpose(-2, -1)

     if need_mask is True:
-        (bh_sk_s, ksks, C) = patch_tokens.size()
+        bh_sk_s, ksks, C = patch_tokens.size()
         h_s = H // ws
         w_s = W // ws
         mask = torch.ones(bh_sk_s // B, 1 + ksks, 1 + ksks, device=patch_tokens.device, dtype=torch.float)
@@ -116,7 +116,7 @@ class SequentialWithTwo(nn.Sequential):
         self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         for module in self:
-            (cls_tokens, patch_tokens) = module(cls_tokens, patch_tokens)
+            cls_tokens, patch_tokens = module(cls_tokens, patch_tokens)

         return (cls_tokens, patch_tokens)

@@ -178,9 +178,9 @@ class AttentionWithRelPos(nn.Module):
         nn.init.trunc_normal_(self.rel_pos, std=0.02)

     def forward(self, x: torch.Tensor, patch_attn: bool = False, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)

         attn = (q @ k.transpose(-2, -1)) * self.scale

@@ -242,7 +242,7 @@ class PatchEmbed(nn.Module):
             raise ValueError("Unknown patch_conv_type")

     def forward(self, x: torch.Tensor, extra_padding: bool = False) -> torch.Tensor:
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         if extra_padding and (H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0):
             p_l = (self.patch_size[1] - W % self.patch_size[1]) // 2
             p_r = (self.patch_size[1] - W % self.patch_size[1]) - p_l
@@ -384,12 +384,12 @@ class ConvAttStage(nn.Module):
         self.ws = window_size

     def forward(self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        (cls_tokens, patch_tokens) = self.proj(cls_tokens, patch_tokens)
-        (out, mask, p_r, p_b, B, C, H, W) = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
+        cls_tokens, patch_tokens = self.proj(cls_tokens, patch_tokens)
+        out, mask, p_r, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
         for blk in self.blocks:
             out = blk(out, mask, B)

-        (cls_tokens, patch_tokens) = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)
+        cls_tokens, patch_tokens = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)

         return (cls_tokens, patch_tokens)

@@ -480,7 +480,7 @@ class RegionViT(DetectorBackbone):

         out = {}
         for name, module in self.body.named_children():
-            (cls_tokens, x) = module(cls_tokens, x)
+            cls_tokens, x = module(cls_tokens, x)
             if name in self.return_stages:
                 out[name] = x

@@ -503,14 +503,14 @@ class RegionViT(DetectorBackbone):
         o_x = x
         x = self.patch_embed(x)
         cls_tokens = self.cls_token(o_x, extra_padding=True)
-        (cls_tokens, x) = self.body(cls_tokens, x)
+        cls_tokens, x = self.body(cls_tokens, x)

         return (cls_tokens, x)

     def embedding(self, x: torch.Tensor) -> torch.Tensor:
-        (cls_tokens, _) = self.forward_features(x)
+        cls_tokens, _ = self.forward_features(x)

-        (N, C, _, _) = cls_tokens.size()
+        N, C, _, _ = cls_tokens.size()
         cls_tokens = cls_tokens.reshape(N, C, -1).transpose(1, 2)
         cls_tokens = self.norm(cls_tokens)
         out = torch.mean(cls_tokens, dim=1)
birder/net/regnet.py
CHANGED
@@ -100,7 +100,7 @@ class BlockParams:
         group_widths = [group_width] * num_stages

         # Adjust the compatibility of stage widths and group widths
-        (stage_widths, group_widths) = cls._adjust_widths_groups_compatibility(
+        stage_widths, group_widths = cls._adjust_widths_groups_compatibility(
             stage_widths, bottleneck_multipliers, group_widths
         )

birder/net/repghost.py
CHANGED
@@ -79,7 +79,7 @@ class RepGhostModule(nn.Module):
         if self.reparameterized is True:
             return

-        (kernel, bias) = self._get_kernel_bias()
+        kernel, bias = self._get_kernel_bias()
         self.cheap_operation = nn.Conv2d(
             in_channels=self.cheap_operation[0].in_channels,
             out_channels=self.cheap_operation[0].out_channels,
@@ -87,7 +87,6 @@ class RepGhostModule(nn.Module):
             padding=self.cheap_operation[0].padding,
             dilation=self.cheap_operation[0].dilation,
             groups=self.cheap_operation[0].groups,
-            bias=True,
         )

         self.cheap_operation.weight.data = kernel
@@ -98,9 +97,9 @@ class RepGhostModule(nn.Module):
         self.reparameterized = True

     def _get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
-        (kernel, bias) = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
+        kernel, bias = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
         if self.fusion_bn is not None:
-            (kernel1x1, bias_bn) = self._fuse_bn_tensor(nn.Identity(), self.fusion_bn, kernel.shape[0])
+            kernel1x1, bias_bn = self._fuse_bn_tensor(nn.Identity(), self.fusion_bn, kernel.shape[0])
             kernel += F.pad(kernel1x1, [1, 1, 1, 1])
             bias += bias_bn

@@ -299,7 +298,7 @@ class RepGhost(DetectorBackbone):
         out_channels = 1280
         self.features = nn.Sequential(
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
-            nn.Conv2d(prev_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
+            nn.Conv2d(prev_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             nn.ReLU(inplace=True),
             nn.Flatten(1),
             nn.Dropout(p=0.2),