birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +24 -0
- birder/common/training_utils.py +338 -41
- birder/data/collators/detection.py +11 -3
- birder/data/dataloader/webdataset.py +12 -2
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/kernels/load_kernel.py +16 -11
- birder/kernels/soft_nms/soft_nms.cpp +17 -18
- birder/net/__init__.py +8 -0
- birder/net/cait.py +4 -3
- birder/net/convnext_v1.py +5 -0
- birder/net/crossformer.py +33 -30
- birder/net/crossvit.py +4 -3
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/detection/deformable_detr.py +2 -5
- birder/net/detection/detr.py +2 -5
- birder/net/detection/efficientdet.py +67 -93
- birder/net/detection/fcos.py +2 -7
- birder/net/detection/retinanet.py +2 -7
- birder/net/detection/rt_detr_v1.py +2 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/efficientformer_v1.py +15 -9
- birder/net/efficientformer_v2.py +39 -29
- birder/net/efficientvit_msft.py +9 -7
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +1 -0
- birder/net/flexivit.py +5 -4
- birder/net/gc_vit.py +671 -0
- birder/net/hiera.py +12 -9
- birder/net/hornet.py +9 -7
- birder/net/iformer.py +8 -6
- birder/net/levit.py +42 -30
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +357 -0
- birder/net/lit_v2.py +436 -0
- birder/net/maxvit.py +67 -55
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/mobileone.py +1 -0
- birder/net/mvit_v2.py +13 -12
- birder/net/pit.py +4 -3
- birder/net/pvt_v1.py +4 -1
- birder/net/repghost.py +1 -0
- birder/net/repvgg.py +1 -0
- birder/net/repvit.py +1 -0
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/rope_deit3.py +5 -3
- birder/net/rope_flexivit.py +7 -4
- birder/net/rope_vit.py +10 -5
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +11 -8
- birder/net/swin_transformer_v1.py +71 -68
- birder/net/swin_transformer_v2.py +38 -31
- birder/net/tiny_vit.py +20 -10
- birder/net/transnext.py +38 -28
- birder/net/vit.py +5 -19
- birder/net/vit_parallel.py +5 -4
- birder/net/vit_sam.py +38 -37
- birder/net/vovnet_v1.py +15 -0
- birder/net/vovnet_v2.py +31 -1
- birder/ops/msda.py +108 -43
- birder/ops/swattention.py +124 -61
- birder/results/detection.py +4 -0
- birder/scripts/benchmark.py +110 -32
- birder/scripts/predict.py +8 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +48 -46
- birder/scripts/train_barlow_twins.py +44 -45
- birder/scripts/train_byol.py +44 -45
- birder/scripts/train_capi.py +50 -49
- birder/scripts/train_data2vec.py +45 -47
- birder/scripts/train_data2vec2.py +45 -47
- birder/scripts/train_detection.py +83 -50
- birder/scripts/train_dino_v1.py +60 -47
- birder/scripts/train_dino_v2.py +86 -52
- birder/scripts/train_dino_v2_dist.py +84 -50
- birder/scripts/train_franca.py +51 -52
- birder/scripts/train_i_jepa.py +45 -47
- birder/scripts/train_ibot.py +51 -53
- birder/scripts/train_kd.py +194 -76
- birder/scripts/train_mim.py +44 -45
- birder/scripts/train_mmcr.py +44 -45
- birder/scripts/train_rotnet.py +45 -46
- birder/scripts/train_simclr.py +44 -45
- birder/scripts/train_vicreg.py +44 -45
- birder/tools/auto_anchors.py +20 -1
- birder/tools/convert_model.py +18 -15
- birder/tools/det_results.py +114 -2
- birder/tools/pack.py +172 -103
- birder/tools/quantize_model.py +73 -67
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/data/dataloader/webdataset.py
CHANGED

@@ -22,9 +22,19 @@ def make_wds_loader(
     shuffle: bool = False,
     *,
     exact: bool = False,
+    infinite: bool = False,
 ) -> DataLoader:
+    assert exact is False or infinite is False
+
+    if infinite is True:
+        dataset_iterable = dataset.repeat()
+    elif exact is False:
+        dataset_iterable = dataset.repeat()
+    else:
+        dataset_iterable = dataset
+
     dataloader = wds.WebLoader(
-
+        dataset_iterable,
         batch_size=batch_size,
         num_workers=num_workers,
         prefetch_factor=prefetch_factor,
@@ -43,7 +53,7 @@ def make_wds_loader(
     epoch_size = math.ceil(len(dataset) / (batch_size * world_size))

     dataloader = dataloader.with_length(epoch_size, silent=True)
-    if exact is False:
+    if exact is False and infinite is False:
         dataloader = dataloader.with_epoch(epoch_size)

     return dataloader
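A minimal usage sketch of the new `infinite` mode (the dataset object and the other keyword arguments are placeholders; only the `exact`/`infinite` behaviour comes from the hunk above, and the full `make_wds_loader` signature is not shown here): `infinite=True` repeats the underlying WebDataset and skips the `with_epoch()` cap, `exact=True` keeps a single exact pass, and the new assert forbids combining the two.

```python
# Hedged sketch: streaming an "infinite" WebDataset loader. Arguments other than
# exact/infinite are assumptions; the real signature may require more parameters.
from birder.data.dataloader.webdataset import make_wds_loader

loader = make_wds_loader(
    dataset,            # a wds.WebDataset built elsewhere (placeholder)
    batch_size=32,
    num_workers=4,
    prefetch_factor=2,
    infinite=True,      # repeat forever, no with_epoch() truncation
)

for step, batch in enumerate(loader):
    ...                 # caller decides when to stop, e.g. after N optimizer steps
```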
birder/data/datasets/coco.py
CHANGED

@@ -98,10 +98,14 @@ class CocoTraining(CocoBase):
 class CocoInference(CocoBase):
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any, list[int]]:
         coco_id = self.dataset.ids[index]
-
+        img_info = self.dataset.coco.loadImgs(coco_id)[0]
+        path = img_info["file_name"]
         (sample, labels) = self.dataset[index]

-
+        # Get original image size (height, width) before transforms
+        orig_size = [img_info["height"], img_info["width"]]
+
+        return (path, sample, labels, orig_size)


 class CocoMosaicTraining(CocoBase):
@@ -127,9 +131,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_epochs: Optional[int] = None
         self._mosaic_decay_start: Optional[int] = None

-    def configure_mosaic_linear_decay(
-        self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1
-    ) -> None:
+    def configure_mosaic_linear_decay(self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> None:
         if total_epochs <= 0:
             raise ValueError("total_epochs must be positive")
         if decay_fraction <= 0.0 or decay_fraction > 1.0:
@@ -141,11 +143,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_start = max(1, total_epochs - decay_epochs + 1)

     def update_mosaic_prob(self, epoch: int) -> Optional[float]:
-        if (
-            self._mosaic_base_prob is None
-            or self._mosaic_decay_epochs is None
-            or self._mosaic_decay_start is None
-        ):
+        if self._mosaic_base_prob is None or self._mosaic_decay_epochs is None or self._mosaic_decay_start is None:
             return None

         if epoch >= self._mosaic_decay_start:
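In short, `CocoInference.__getitem__` now also looks up the COCO image record so it can return the original `[height, width]` alongside the sample; downstream code can use it to map predictions back to source-image coordinates. A minimal consumption sketch (the dataset constructor arguments are omitted placeholders; only the 4-tuple layout comes from the hunk above):

```python
# Hedged sketch: reading the new CocoInference item layout.
from birder.data.datasets.coco import CocoInference

dataset = CocoInference(...)  # root / annotation / transform arguments omitted
(path, sample, labels, orig_size) = dataset[0]
(height, width) = orig_size   # original image size, taken from the COCO image info
```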
birder/data/transforms/detection.py
CHANGED

@@ -1,3 +1,4 @@
+import math
 import random
 from collections.abc import Callable
 from typing import Any
@@ -10,6 +11,24 @@ from torchvision.transforms import v2

 from birder.data.transforms.classification import RGBType

+MULTISCALE_STEP = 32
+DEFAULT_MULTISCALE_MIN_SIZE = 480
+DEFAULT_MULTISCALE_MAX_SIZE = 800
+
+
+def build_multiscale_sizes(
+    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+) -> tuple[int, ...]:
+    if min_size is None:
+        min_size = DEFAULT_MULTISCALE_MIN_SIZE
+
+    start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    if end < start:
+        return (start,)
+
+    return tuple(range(start, end + 1, MULTISCALE_STEP))
+

 class ResizeWithRandomInterpolation(nn.Module):
     def __init__(
@@ -39,6 +58,7 @@ def get_birder_augment(
     dynamic_size: bool,
     multiscale: bool,
     max_size: Optional[int],
+    multiscale_min_size: Optional[int],
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     if dynamic_size is True:
@@ -78,9 +98,7 @@ def get_birder_augment(
     # Resize
     if multiscale is True:
         transformations.append(
-            v2.RandomShortestSize(
-                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-            ),
+            v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
         )
     else:
         transformations.append(
@@ -132,6 +150,7 @@ def get_birder_augment(
 AugType = Literal["birder", "lsj", "multiscale", "ssd", "ssdlite", "yolo", "detr"]


+# pylint: disable=too-many-return-statements
 def training_preset(
     size: tuple[int, int],
     aug_type: AugType,
@@ -140,6 +159,7 @@ def training_preset(
     dynamic_size: bool = False,
     multiscale: bool = False,
     max_size: Optional[int] = None,
+    multiscale_min_size: Optional[int] = None,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     mean = rgv_values["mean"]
@@ -159,7 +179,9 @@ def training_preset(
         return v2.Compose(  # type:ignore
             [
                 v2.ToImage(),
-                get_birder_augment(
+                get_birder_augment(
+                    size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+                ),
                 v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize(mean=mean, std=std),
                 v2.ToPureTensor(),
@@ -190,9 +212,7 @@ def training_preset(
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
-                v2.RandomShortestSize(
-                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                ),
+                v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
                 v2.RandomHorizontalFlip(0.5),
                 v2.SanitizeBoundingBoxes(),
                 v2.ToDtype(torch.float32, scale=True),
@@ -264,21 +284,18 @@ def training_preset(
         )

     if aug_type == "detr":
+        multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
                 v2.RandomChoice(
                     [
-                        v2.RandomShortestSize(
-                            (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                        ),
+                        v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                         v2.Compose(
                             [
                                 v2.RandomShortestSize((400, 500, 600)),
                                 v2.RandomIoUCrop() if post_mosaic is False else v2.Identity(),  # RandomSizeCrop
-                                v2.RandomShortestSize(
-                                    (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                                ),
+                                v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                             ]
                         ),
                     ]
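Because `build_multiscale_sizes` is shown in full above, its behaviour can be stated directly: it rounds the minimum size up and the maximum size down to multiples of `MULTISCALE_STEP` (32) and returns every step in between; the defaults reproduce the previously hard-coded `(480, ..., 800)` tuple, and an empty range falls back to a single size.

```python
# Expected outputs of build_multiscale_sizes() as defined in the hunk above.
from birder.data.transforms.detection import build_multiscale_sizes

print(build_multiscale_sizes())          # (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
print(build_multiscale_sizes(600, 704))  # (608, 640, 672, 704)
print(build_multiscale_sizes(700, 600))  # (704,) -- end < start, single-size fallback
```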
birder/inference/detection.py
CHANGED

@@ -5,17 +5,99 @@ from typing import Optional
 import torch
 import torch.amp
 from PIL import Image
+from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from tqdm import tqdm

 from birder.conf import settings
+from birder.data.collators.detection import batch_images
 from birder.data.transforms.detection import InferenceTransform
+from birder.inference.wbf import fuse_detections_wbf
+from birder.net.base import make_divisible
+
+
+def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list[int]]]) -> list[list[int]]:
+    if image_sizes is not None:
+        return image_sizes
+
+    (_, _, height, width) = inputs.shape
+    return [[height, width] for _ in range(inputs.size(0))]
+
+
+def _hflip_inputs(inputs: torch.Tensor, image_sizes: list[list[int]]) -> torch.Tensor:
+    # Detection collator pads on the right/bottom, so flip only the valid region to keep padding aligned.
+    flipped = inputs.clone()
+    for idx, (height, width) in enumerate(image_sizes):
+        flipped[idx, :, :height, :width] = torch.flip(inputs[idx, :, :height, :width], dims=[2])
+
+    return flipped
+
+
+def _resize_batch(
+    inputs: torch.Tensor, image_sizes: list[list[int]], scale: float, size_divisible: int
+) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
+    resized_images: list[torch.Tensor] = []
+    for idx, (height, width) in enumerate(image_sizes):
+        target_h = make_divisible(height * scale, size_divisible)
+        target_w = make_divisible(width * scale, size_divisible)
+        image = inputs[idx, :, :height, :width]
+        resized = F.interpolate(image.unsqueeze(0), size=(target_h, target_w), mode="bilinear", align_corners=False)
+        resized_images.append(resized.squeeze(0))
+
+    return batch_images(resized_images, size_divisible)
+
+
+def _rescale_boxes(boxes: torch.Tensor, from_size: list[int], to_size: list[int]) -> torch.Tensor:
+    scale_w = to_size[1] / from_size[1]
+    scale_h = to_size[0] / from_size[0]
+    scale = boxes.new_tensor([scale_w, scale_h, scale_w, scale_h])
+    return boxes * scale
+
+
+def _rescale_detections(
+    detections: list[dict[str, torch.Tensor]],
+    from_sizes: list[list[int]],
+    to_sizes: list[list[int]],
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, from_size, to_size) in enumerate(zip(detections, from_sizes, to_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _rescale_boxes(boxes, from_size, to_size)
+
+    return detections
+
+
+def _invert_hflip_boxes(boxes: torch.Tensor, image_size: list[int]) -> torch.Tensor:
+    width = boxes.new_tensor(image_size[1])
+    x1 = boxes[:, 0]
+    x2 = boxes[:, 2]
+    flipped = boxes.clone()
+    flipped[:, 0] = width - x2
+    flipped[:, 2] = width - x1
+
+    return flipped
+
+
+def _invert_detections(
+    detections: list[dict[str, torch.Tensor]], image_sizes: list[list[int]]
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, image_size) in enumerate(zip(detections, image_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _invert_hflip_boxes(boxes, image_size)
+
+    return detections


 def infer_image(
     net: torch.nn.Module | torch.ScriptModule,
     sample: Image.Image | str,
     transform: Callable[..., torch.Tensor],
+    tta: bool = False,
     device: Optional[torch.device] = None,
     score_threshold: Optional[float] = None,
     **kwargs: Any,
@@ -43,7 +125,7 @@ def infer_image(
         device = torch.device("cpu")

     input_tensor = transform(image).unsqueeze(dim=0).to(device)
-    detections = infer_batch(net, input_tensor, **kwargs)
+    detections = infer_batch(net, input_tensor, tta=tta, **kwargs)
     if score_threshold is not None:
         for i, detection in enumerate(detections):
             idxs = torch.where(detection["scores"] > score_threshold)
@@ -63,16 +145,36 @@ def infer_batch(
     inputs: torch.Tensor,
     masks: Optional[torch.Tensor] = None,
     image_sizes: Optional[list[list[int]]] = None,
+    tta: bool = False,
     **kwargs: Any,
 ) -> list[dict[str, torch.Tensor]]:
-
-
+    if tta is False:
+        (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
+        return detections  # type: ignore[no-any-return]
+
+    normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
+    detections_list: list[list[dict[str, torch.Tensor]]] = []
+
+    for scale in (0.8, 1.0, 1.2):
+        (scaled_inputs, scaled_masks, scaled_sizes) = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+        (detections, _) = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
+        detections_list.append(detections)
+
+        flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
+        (flipped_detections, _) = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
+        flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
+        detections_list.append(flipped_detections)
+
+    return fuse_detections_wbf(detections_list, iou_thr=0.55, conf_type="avg")


 def infer_dataloader(
     device: torch.device,
     net: torch.nn.Module | torch.ScriptModule,
     dataloader: DataLoader,
+    tta: bool = False,
     model_dtype: torch.dtype = torch.float32,
     amp: bool = False,
     amp_dtype: Optional[torch.dtype] = None,
@@ -97,6 +199,8 @@ def infer_dataloader(
         The model to use for inference.
     dataloader
         The DataLoader containing the dataset to perform inference on.
+    tta
+        Run inference with multi-scale and horizontal flip test time augmentation and fuse results with WBF.
     model_dtype
         The base dtype to use.
     amp
@@ -142,7 +246,7 @@ def infer_dataloader(
                 masks = masks.to(device, non_blocking=True)

             with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
-                detections = infer_batch(net, inputs, masks, image_sizes)
+                detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes, tta=tta)

             detections = InferenceTransform.postprocess(detections, image_sizes, orig_sizes)
             if targets[0] != settings.NO_LABEL:
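A minimal sketch of exercising the new `tta` path (the model and transform objects are assumed to come from birder's usual setup and are not part of the diff): with `tta=True`, `infer_batch` runs the network at scales 0.8/1.0/1.2, each with and without a horizontal flip, maps all boxes back to the original resolution, and fuses the six result sets with Weighted Boxes Fusion.

```python
# Hedged sketch: calling detection inference with the new test-time augmentation flag.
# `net` and `transform` are placeholders built elsewhere.
import torch

from birder.inference.detection import infer_batch, infer_image

detections = infer_image(net, "example.jpeg", transform, tta=True, score_threshold=0.5)

# Or on an already-transformed batch:
batch = torch.randn(2, 3, 640, 640)
detections = infer_batch(net, batch, tta=True)  # list of {"boxes", "scores", "labels"} dicts
```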
birder/inference/wbf.py
ADDED

@@ -0,0 +1,226 @@
+"""
+Weighted Boxes Fusion, adapted from
+https://github.com/ZFTurbo/Weighted-Boxes-Fusion
+
+Paper "Weighted boxes fusion: Ensembling boxes from different object detection models",
+https://arxiv.org/abs/1910.13302
+"""
+
+# Reference license: MIT
+
+from dataclasses import dataclass
+from typing import Literal
+from typing import Optional
+
+import torch
+from torchvision.ops import box_iou
+
+ConfType = Literal["avg", "max", "box_and_model_avg", "absent_model_aware_avg"]
+
+
+@dataclass
+class BoxCluster:
+    box: torch.Tensor
+    score_weight_sum: torch.Tensor
+    weight_sum: torch.Tensor
+    max_score: torch.Tensor
+    boxes_count: int
+
+    @classmethod
+    def from_entry(cls, box: torch.Tensor, score: torch.Tensor, weight: torch.Tensor) -> "BoxCluster":
+        score_weight = score * weight
+        return cls(
+            box=box.clone(),
+            score_weight_sum=score_weight,
+            weight_sum=weight,
+            max_score=score,
+            boxes_count=1,
+        )
+
+    def add(self, box: torch.Tensor, score: torch.Tensor, weight: torch.Tensor) -> None:
+        score_weight = score * weight
+        total_weight = self.score_weight_sum + score_weight
+        self.box = (self.box * self.score_weight_sum + box * score_weight) / total_weight
+        self.score_weight_sum = total_weight
+        self.weight_sum += weight
+        self.max_score = torch.maximum(self.max_score, score)
+        self.boxes_count += 1
+
+
+# pylint: disable=too-many-locals,too-many-branches
+def weighted_boxes_fusion(
+    boxes_list: list[torch.Tensor],
+    scores_list: list[torch.Tensor],
+    labels_list: list[torch.Tensor],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if weights is None:
+        weights = [1.0] * len(boxes_list)
+    if len(weights) != len(boxes_list):
+        raise ValueError("weights must match number of box sets")
+
+    if len(boxes_list) > 0:
+        device = boxes_list[0].device
+    else:
+        device = torch.device("cpu")
+
+    boxes_all: list[torch.Tensor] = []
+    scores_all: list[torch.Tensor] = []
+    labels_all: list[torch.Tensor] = []
+    weights_all: list[torch.Tensor] = []
+    for boxes, scores, labels, weight in zip(boxes_list, scores_list, labels_list, weights):
+        if boxes.numel() == 0:
+            continue
+
+        boxes_tensor = boxes.detach().to(dtype=torch.float32)
+        scores_tensor = scores.detach().to(dtype=torch.float32)
+        labels_tensor = labels.detach().to(dtype=torch.int64)
+
+        keep = scores_tensor >= skip_box_thr
+        if not keep.any():
+            continue
+
+        boxes_tensor = boxes_tensor[keep]
+        scores_tensor = scores_tensor[keep]
+        labels_tensor = labels_tensor[keep]
+        weights_tensor = scores_tensor.new_full(scores_tensor.shape, weight)
+
+        boxes_all.append(boxes_tensor)
+        scores_all.append(scores_tensor)
+        labels_all.append(labels_tensor)
+        weights_all.append(weights_tensor)
+
+    if len(boxes_all) == 0:
+        empty_boxes = torch.zeros((0, 4), dtype=torch.float32, device=device)
+        empty_scores = torch.zeros((0,), dtype=torch.float32, device=device)
+        empty_labels = torch.zeros((0,), dtype=torch.int64, device=device)
+        return (empty_boxes, empty_scores, empty_labels)
+
+    boxes_tensor = torch.concat(boxes_all, dim=0)
+    scores_tensor = torch.concat(scores_all, dim=0)
+    labels_tensor = torch.concat(labels_all, dim=0)
+    weights_tensor = torch.concat(weights_all, dim=0)
+    labels_unique = torch.unique(labels_tensor)
+
+    total_weight = float(sum(weights))
+    num_models = len(weights)
+    fused_boxes: list[torch.Tensor] = []
+    fused_scores: list[torch.Tensor] = []
+    fused_labels: list[torch.Tensor] = []
+
+    for label in labels_unique:
+        label_mask = labels_tensor == label
+        label_boxes = boxes_tensor[label_mask]
+        label_scores = scores_tensor[label_mask]
+        label_weights = weights_tensor[label_mask]
+        order = torch.argsort(label_scores, descending=True)
+        clusters: list[BoxCluster] = []
+        for idx in order:
+            box = label_boxes[idx]
+            score = label_scores[idx]
+            weight = label_weights[idx]
+            if len(clusters) == 0:
+                clusters.append(BoxCluster.from_entry(box, score, weight))
+                continue
+
+            cluster_boxes = torch.stack([cluster.box for cluster in clusters], dim=0)
+            ious = box_iou(box.unsqueeze(0), cluster_boxes).squeeze(0)
+            max_iou, max_idx = torch.max(ious, dim=0)
+            if max_iou > iou_thr:
+                clusters[int(max_idx)].add(box, score, weight)
+            else:
+                clusters.append(BoxCluster.from_entry(box, score, weight))
+
+        for cluster in clusters:
+            if conf_type == "avg":
+                score = cluster.score_weight_sum / cluster.weight_sum
+            elif conf_type == "max":
+                score = cluster.max_score
+            elif conf_type == "box_and_model_avg":
+                score = (cluster.score_weight_sum / cluster.weight_sum) * (cluster.boxes_count / num_models)
+            elif conf_type == "absent_model_aware_avg":
+                score = cluster.score_weight_sum / total_weight
+            else:
+                raise ValueError(f"Unsupported conf_type: {conf_type}")
+
+            if allows_overflow is False:
+                score = score.clamp(max=1.0)
+
+            fused_boxes.append(cluster.box)
+            fused_scores.append(score)
+            fused_labels.append(label)
+
+    fused_scores_tensor = torch.stack(fused_scores)
+    order = torch.argsort(fused_scores_tensor, descending=True)
+    fused_boxes_tensor = torch.stack(fused_boxes, dim=0)[order]
+    fused_scores_tensor = fused_scores_tensor[order]
+    fused_labels_tensor = torch.stack(fused_labels)[order]
+
+    return (fused_boxes_tensor, fused_scores_tensor, fused_labels_tensor)
+
+
+def fuse_detections_wbf_single(
+    detections: list[dict[str, torch.Tensor]],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> dict[str, torch.Tensor]:
+    if len(detections) == 0:
+        return {
+            "boxes": torch.zeros((0, 4)),
+            "scores": torch.zeros((0,)),
+            "labels": torch.zeros((0,), dtype=torch.int64),
+        }
+
+    boxes_list = [detection["boxes"] for detection in detections]
+    scores_list = [detection["scores"] for detection in detections]
+    labels_list = [detection["labels"] for detection in detections]
+
+    (boxes, scores, labels) = weighted_boxes_fusion(
+        boxes_list,
+        scores_list,
+        labels_list,
+        weights=weights,
+        iou_thr=iou_thr,
+        skip_box_thr=skip_box_thr,
+        conf_type=conf_type,
+        allows_overflow=allows_overflow,
+    )
+
+    return {"boxes": boxes, "scores": scores, "labels": labels}
+
+
+def fuse_detections_wbf(
+    detections_list: list[list[dict[str, torch.Tensor]]],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> list[dict[str, torch.Tensor]]:
+    if len(detections_list) == 0:
+        return []
+
+    # Outer list is the augmentations, inner is the batch
+    batch_size = len(detections_list[0])
+    fused: list[dict[str, torch.Tensor]] = []
+    for idx in range(batch_size):
+        per_image = [detections[idx] for detections in detections_list]
+        fused.append(
+            fuse_detections_wbf_single(
+                per_image,
+                weights=weights,
+                iou_thr=iou_thr,
+                skip_box_thr=skip_box_thr,
+                conf_type=conf_type,
+                allows_overflow=allows_overflow,
+            )
+        )
+
+    return fused
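A small worked example of the fusion math in `weighted_boxes_fusion` above: two overlapping boxes of the same label form one cluster, the fused box is the score-weighted average of the member boxes, and with `conf_type="avg"` the fused score is the weight-normalized mean.

```python
# Worked example for weighted_boxes_fusion() as defined above (single box set, equal weights).
import torch

from birder.inference.wbf import weighted_boxes_fusion

boxes = [torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])]
scores = [torch.tensor([0.9, 0.6])]
labels = [torch.tensor([1, 1])]

(fused_boxes, fused_scores, fused_labels) = weighted_boxes_fusion(boxes, scores, labels, iou_thr=0.55)
# IoU of the two boxes is ~0.68 > 0.55, so they merge into one cluster:
#   fused box   = (0.9 * [0, 0, 10, 10] + 0.6 * [1, 1, 11, 11]) / 1.5 = [0.4, 0.4, 10.4, 10.4]
#   fused score = (0.9 + 0.6) / 2 = 0.75  (conf_type="avg")
```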
birder/kernels/load_kernel.py
CHANGED

@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)


 _CACHED_KERNELS: dict[str, ModuleType] = {}
+_CUSTOM_KERNELS_ENABLED = True
+
+
+def set_custom_kernels_enabled(enabled: bool) -> None:
+    global _CUSTOM_KERNELS_ENABLED  # pylint: disable=global-statement
+    _CUSTOM_KERNELS_ENABLED = enabled
+
+
+def is_custom_kernels_enabled() -> bool:
+    if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
+        return False
+
+    return _CUSTOM_KERNELS_ENABLED


 def load_msda() -> Optional[ModuleType]:
     name = "msda"
-    if torch.cuda.is_available() is False or
+    if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:

 def load_swattention() -> Optional[ModuleType]:
     name = "swattention"
-    if torch.cuda.is_available() is False or
+    if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:

 def load_soft_nms() -> Optional[ModuleType]:
     name = "soft_nms"
-    if
+    if is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
     soft_nms: Optional[ModuleType] = load(
         "soft_nms",
         src_files,
-        with_cuda=True,
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
     )

     if soft_nms is not None:
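Both new toggles gate the JIT-compiled extensions the same way; a short sketch of combining them (only `set_custom_kernels_enabled`, `is_custom_kernels_enabled`, and the `DISABLE_CUSTOM_KERNELS` environment variable appear in the diff, the rest is illustrative):

```python
# Sketch: disabling the custom CUDA/C++ kernels programmatically or via the environment.
import os

from birder.kernels.load_kernel import is_custom_kernels_enabled, set_custom_kernels_enabled

set_custom_kernels_enabled(False)            # force pure-PyTorch fallbacks
assert is_custom_kernels_enabled() is False

set_custom_kernels_enabled(True)
os.environ["DISABLE_CUSTOM_KERNELS"] = "1"   # the env var wins even when the flag is True
assert is_custom_kernels_enabled() is False
```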
birder/kernels/soft_nms/soft_nms.cpp
CHANGED

@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
     std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);

     // max_idx is computed from sliced data, therefore need to convert it to "global" max idx
-    auto max_idx = t_max_idx
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    }
+    auto max_idx = t_max_idx + (idx + 1);
+    auto should_swap = scores.index({idx}) < max_score;
+
+    auto boxes_idx = boxes.index({idx}).clone();
+    auto boxes_max = boxes.index({max_idx}).clone();
+    boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
+    boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
+
+    auto scores_idx = scores.index({idx}).clone();
+    auto scores_max = scores.index({max_idx}).clone();
+    scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
+    scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
+
+    auto areas_idx = areas.index({idx}).clone();
+    auto areas_max = areas.index({max_idx}).clone();
+    areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
+    areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
 }

 std::tuple<torch::Tensor, torch::Tensor> soft_nms(
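The rewritten C++ above replaces an explicit branch with a `torch::where`-based conditional swap, so the box, score, and area rows at `idx` and `max_idx` are exchanged only when a higher-scoring candidate exists further down the sort order. An equivalent Python/PyTorch sketch of that swap (names chosen for illustration):

```python
# Equivalent of the conditional swap in update_sorting_order(), sketched in Python.
import torch


def conditional_swap(values: torch.Tensor, idx: int, max_idx: torch.Tensor, should_swap: torch.Tensor) -> None:
    row_idx = values[idx].clone()
    row_max = values[max_idx].clone()
    values[idx] = torch.where(should_swap, row_max, row_idx)
    values[max_idx] = torch.where(should_swap, row_idx, row_max)


scores = torch.tensor([0.2, 0.9, 0.5])
idx = 0
(max_score, t_max_idx) = torch.max(scores[idx + 1:], dim=0)
max_idx = t_max_idx + (idx + 1)          # convert the sliced index back to a global index
should_swap = scores[idx] < max_score
conditional_swap(scores, idx, max_idx, should_swap)
# scores is now [0.9, 0.2, 0.5]
```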