birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/data/dataloader/webdataset.py

@@ -22,9 +22,19 @@ def make_wds_loader(
     shuffle: bool = False,
     *,
     exact: bool = False,
+    infinite: bool = False,
 ) -> DataLoader:
+    assert exact is False or infinite is False
+
+    if infinite is True:
+        dataset_iterable = dataset.repeat()
+    elif exact is False:
+        dataset_iterable = dataset.repeat()
+    else:
+        dataset_iterable = dataset
+
     dataloader = wds.WebLoader(
-        dataset.repeat() if exact is False else dataset,
+        dataset_iterable,
         batch_size=batch_size,
         num_workers=num_workers,
         prefetch_factor=prefetch_factor,

@@ -43,7 +53,7 @@ def make_wds_loader(
     epoch_size = math.ceil(len(dataset) / (batch_size * world_size))

     dataloader = dataloader.with_length(epoch_size, silent=True)
-    if exact is False:
+    if exact is False and infinite is False:
         dataloader = dataloader.with_epoch(epoch_size)

     return dataloader
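The new `infinite` flag keeps the underlying WebDataset repeating and skips the `with_epoch` cap, which suits iteration-driven (rather than epoch-driven) training loops. A minimal usage sketch; the argument names mirror those referenced inside `make_wds_loader`, but the full signature is not shown in this diff:

    # Iteration-based loop over a never-ending loader (sketch, not the exact API)
    loader = make_wds_loader(dataset, infinite=True)  # exact=True here would trip the assert
    it = iter(loader)
    for step in range(total_steps):
        batch = next(it)  # no StopIteration at epoch boundaries; .repeat() with no epoch cap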
birder/data/datasets/coco.py

@@ -98,10 +98,14 @@ class CocoTraining(CocoBase):
 class CocoInference(CocoBase):
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any, list[int]]:
         coco_id = self.dataset.ids[index]
-        path = self.dataset.coco.loadImgs(coco_id)[0]["file_name"]
+        img_info = self.dataset.coco.loadImgs(coco_id)[0]
+        path = img_info["file_name"]
         (sample, labels) = self.dataset[index]

-        return (path, sample, labels, F.get_size(sample))
+        # Get original image size (height, width) before transforms
+        orig_size = [img_info["height"], img_info["width"]]
+
+        return (path, sample, labels, orig_size)


 class CocoMosaicTraining(CocoBase):

@@ -127,9 +131,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_epochs: Optional[int] = None
         self._mosaic_decay_start: Optional[int] = None

-    def configure_mosaic_linear_decay(
-        self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1
-    ) -> None:
+    def configure_mosaic_linear_decay(self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> None:
         if total_epochs <= 0:
             raise ValueError("total_epochs must be positive")
         if decay_fraction <= 0.0 or decay_fraction > 1.0:

@@ -141,11 +143,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_start = max(1, total_epochs - decay_epochs + 1)

     def update_mosaic_prob(self, epoch: int) -> Optional[float]:
-        if (
-            self._mosaic_base_prob is None
-            or self._mosaic_decay_epochs is None
-            or self._mosaic_decay_start is None
-        ):
+        if self._mosaic_base_prob is None or self._mosaic_decay_epochs is None or self._mosaic_decay_start is None:
            return None

        if epoch >= self._mosaic_decay_start:
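For context: this schedule holds the mosaic probability at `base_prob` for most of the run and decays it linearly over roughly the last `decay_fraction` of epochs. A worked sketch, assuming `decay_epochs` works out to about `total_epochs * decay_fraction` (its computation falls outside these hunks):

    dataset.configure_mosaic_linear_decay(base_prob=1.0, total_epochs=100, decay_fraction=0.1)
    # decay_epochs ~= 10, so _mosaic_decay_start = max(1, 100 - 10 + 1) = 91
    dataset.update_mosaic_prob(50)  # epoch before the decay window; presumably returns base_prob
    dataset.update_mosaic_prob(95)  # inside epochs 91-100; probability ramps down linearly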
birder/data/transforms/detection.py

@@ -1,3 +1,4 @@
+import math
 import random
 from collections.abc import Callable
 from typing import Any

@@ -10,6 +11,24 @@ from torchvision.transforms import v2

 from birder.data.transforms.classification import RGBType

+MULTISCALE_STEP = 32
+DEFAULT_MULTISCALE_MIN_SIZE = 480
+DEFAULT_MULTISCALE_MAX_SIZE = 800
+
+
+def build_multiscale_sizes(
+    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+) -> tuple[int, ...]:
+    if min_size is None:
+        min_size = DEFAULT_MULTISCALE_MIN_SIZE
+
+    start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    if end < start:
+        return (start,)
+
+    return tuple(range(start, end + 1, MULTISCALE_STEP))
+

 class ResizeWithRandomInterpolation(nn.Module):
     def __init__(
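With the defaults, `build_multiscale_sizes` reproduces exactly the tuple that was previously hardcoded at every call site, since 480 and 800 are already multiples of the 32-pixel step; non-aligned bounds get rounded inward:

    build_multiscale_sizes()
    # -> (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
    build_multiscale_sizes(min_size=600, max_size=700)
    # -> (608, 640, 672); 600 rounds up to 608, 700 rounds down to 672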
@@ -39,6 +58,7 @@ def get_birder_augment(
     dynamic_size: bool,
     multiscale: bool,
     max_size: Optional[int],
+    multiscale_min_size: Optional[int],
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     if dynamic_size is True:

@@ -78,9 +98,7 @@ def get_birder_augment(
     # Resize
     if multiscale is True:
         transformations.append(
-            v2.RandomShortestSize(
-                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-            ),
+            v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
         )
     else:
         transformations.append(

@@ -132,6 +150,7 @@ def get_birder_augment(
 AugType = Literal["birder", "lsj", "multiscale", "ssd", "ssdlite", "yolo", "detr"]


+# pylint: disable=too-many-return-statements
 def training_preset(
     size: tuple[int, int],
     aug_type: AugType,

@@ -140,6 +159,7 @@ def training_preset(
     dynamic_size: bool = False,
     multiscale: bool = False,
     max_size: Optional[int] = None,
+    multiscale_min_size: Optional[int] = None,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     mean = rgv_values["mean"]

@@ -159,7 +179,9 @@ def training_preset(
         return v2.Compose(  # type:ignore
             [
                 v2.ToImage(),
-                get_birder_augment(size, level, fill_value, dynamic_size, multiscale, max_size, post_mosaic),
+                get_birder_augment(
+                    size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+                ),
                 v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize(mean=mean, std=std),
                 v2.ToPureTensor(),

@@ -190,9 +212,7 @@ def training_preset(
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
-                v2.RandomShortestSize(
-                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                ),
+                v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
                 v2.RandomHorizontalFlip(0.5),
                 v2.SanitizeBoundingBoxes(),
                 v2.ToDtype(torch.float32, scale=True),

@@ -264,21 +284,18 @@ def training_preset(
         )

     if aug_type == "detr":
+        multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
                 v2.RandomChoice(
                     [
-                        v2.RandomShortestSize(
-                            (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                        ),
+                        v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                         v2.Compose(
                             [
                                 v2.RandomShortestSize((400, 500, 600)),
                                 v2.RandomIoUCrop() if post_mosaic is False else v2.Identity(),  # RandomSizeCrop
-                                v2.RandomShortestSize(
-                                    (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                                ),
+                                v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                             ]
                         ),
                     ]
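Downstream, callers can now tighten or widen the multiscale range per preset. A hedged call sketch; only `size`, `aug_type`, `max_size`, and `multiscale_min_size` appear in these hunks, and the remaining arguments (`level`, `rgv_values`) are shown with illustrative values:

    transform = training_preset(
        size=(800, 800),
        aug_type="detr",
        level=1,  # illustrative; valid levels are defined elsewhere in the module
        rgv_values={"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
        max_size=1333,
        multiscale_min_size=608,  # short sides sampled from (608, 640, ..., 800)
    )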
birder/inference/detection.py

@@ -5,17 +5,99 @@ from typing import Optional
 import torch
 import torch.amp
 from PIL import Image
+from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from tqdm import tqdm

 from birder.conf import settings
+from birder.data.collators.detection import batch_images
 from birder.data.transforms.detection import InferenceTransform
+from birder.inference.wbf import fuse_detections_wbf
+from birder.net.base import make_divisible
+
+
+def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list[int]]]) -> list[list[int]]:
+    if image_sizes is not None:
+        return image_sizes
+
+    (_, _, height, width) = inputs.shape
+    return [[height, width] for _ in range(inputs.size(0))]
+
+
+def _hflip_inputs(inputs: torch.Tensor, image_sizes: list[list[int]]) -> torch.Tensor:
+    # Detection collator pads on the right/bottom, so flip only the valid region to keep padding aligned.
+    flipped = inputs.clone()
+    for idx, (height, width) in enumerate(image_sizes):
+        flipped[idx, :, :height, :width] = torch.flip(inputs[idx, :, :height, :width], dims=[2])
+
+    return flipped
+
+
+def _resize_batch(
+    inputs: torch.Tensor, image_sizes: list[list[int]], scale: float, size_divisible: int
+) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
+    resized_images: list[torch.Tensor] = []
+    for idx, (height, width) in enumerate(image_sizes):
+        target_h = make_divisible(height * scale, size_divisible)
+        target_w = make_divisible(width * scale, size_divisible)
+        image = inputs[idx, :, :height, :width]
+        resized = F.interpolate(image.unsqueeze(0), size=(target_h, target_w), mode="bilinear", align_corners=False)
+        resized_images.append(resized.squeeze(0))
+
+    return batch_images(resized_images, size_divisible)
+
+
+def _rescale_boxes(boxes: torch.Tensor, from_size: list[int], to_size: list[int]) -> torch.Tensor:
+    scale_w = to_size[1] / from_size[1]
+    scale_h = to_size[0] / from_size[0]
+    scale = boxes.new_tensor([scale_w, scale_h, scale_w, scale_h])
+    return boxes * scale
+
+
+def _rescale_detections(
+    detections: list[dict[str, torch.Tensor]],
+    from_sizes: list[list[int]],
+    to_sizes: list[list[int]],
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, from_size, to_size) in enumerate(zip(detections, from_sizes, to_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _rescale_boxes(boxes, from_size, to_size)
+
+    return detections
+
+
+def _invert_hflip_boxes(boxes: torch.Tensor, image_size: list[int]) -> torch.Tensor:
+    width = boxes.new_tensor(image_size[1])
+    x1 = boxes[:, 0]
+    x2 = boxes[:, 2]
+    flipped = boxes.clone()
+    flipped[:, 0] = width - x2
+    flipped[:, 2] = width - x1
+
+    return flipped
+
+
+def _invert_detections(
+    detections: list[dict[str, torch.Tensor]], image_sizes: list[list[int]]
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, image_size) in enumerate(zip(detections, image_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _invert_hflip_boxes(boxes, image_size)
+
+    return detections


 def infer_image(
     net: torch.nn.Module | torch.ScriptModule,
     sample: Image.Image | str,
     transform: Callable[..., torch.Tensor],
+    tta: bool = False,
     device: Optional[torch.device] = None,
     score_threshold: Optional[float] = None,
     **kwargs: Any,
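Undoing a horizontal flip mirrors the x coordinates against the image width and swaps them so x1 <= x2 still holds: x1' = W - x2 and x2' = W - x1. A quick check against `_invert_hflip_boxes` (illustrative values):

    import torch

    boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])  # (x1, y1, x2, y2)
    # image_size is (height, width); here the image is 200 px wide
    _invert_hflip_boxes(boxes, image_size=[100, 200])
    # -> tensor([[150., 20., 190., 60.]])  since 200 - 50 = 150 and 200 - 10 = 190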
@@ -43,7 +125,7 @@ def infer_image(
         device = torch.device("cpu")

     input_tensor = transform(image).unsqueeze(dim=0).to(device)
-    detections = infer_batch(net, input_tensor, **kwargs)
+    detections = infer_batch(net, input_tensor, tta=tta, **kwargs)
     if score_threshold is not None:
         for i, detection in enumerate(detections):
             idxs = torch.where(detection["scores"] > score_threshold)

@@ -63,16 +145,36 @@ def infer_batch(
     inputs: torch.Tensor,
     masks: Optional[torch.Tensor] = None,
     image_sizes: Optional[list[list[int]]] = None,
+    tta: bool = False,
     **kwargs: Any,
 ) -> list[dict[str, torch.Tensor]]:
-    (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
-    return detections  # type: ignore[no-any-return]
+    if tta is False:
+        (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
+        return detections  # type: ignore[no-any-return]
+
+    normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
+    detections_list: list[list[dict[str, torch.Tensor]]] = []
+
+    for scale in (0.8, 1.0, 1.2):
+        (scaled_inputs, scaled_masks, scaled_sizes) = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+        (detections, _) = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
+        detections_list.append(detections)
+
+        flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
+        (flipped_detections, _) = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
+        flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
+        detections_list.append(flipped_detections)
+
+    return fuse_detections_wbf(detections_list, iou_thr=0.55, conf_type="avg")


 def infer_dataloader(
     device: torch.device,
     net: torch.nn.Module | torch.ScriptModule,
     dataloader: DataLoader,
+    tta: bool = False,
     model_dtype: torch.dtype = torch.float32,
     amp: bool = False,
     amp_dtype: Optional[torch.dtype] = None,

@@ -97,6 +199,8 @@ def infer_dataloader(
         The model to use for inference.
     dataloader
         The DataLoader containing the dataset to perform inference on.
+    tta
+        Run inference with multi-scale and horizontal flip test time augmentation and fuse results with WBF.
     model_dtype
         The base dtype to use.
     amp

@@ -142,7 +246,7 @@ def infer_dataloader(
             masks = masks.to(device, non_blocking=True)

         with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
-            detections = infer_batch(net, inputs, masks, image_sizes)
+            detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes, tta=tta)

         detections = InferenceTransform.postprocess(detections, image_sizes, orig_sizes)
         if targets[0] != settings.NO_LABEL:
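From the caller's side, TTA stays a single flag on `infer_image`, `infer_batch`, and `infer_dataloader`. A hedged sketch; constructing `net` and `transform` is outside this diff:

    detections = infer_image(
        net,
        "example.jpeg",   # a path or a PIL image
        transform,
        tta=True,         # 3 scales x {identity, hflip} = 6 forward passes, fused with WBF
        score_threshold=0.5,
    )

Note that `tta=True` runs six forward passes per batch instead of one, so expect roughly a sixfold inference cost.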
birder/inference/wbf.py (new file)

@@ -0,0 +1,226 @@
+"""
+Weighted Boxes Fusion, adapted from
+https://github.com/ZFTurbo/Weighted-Boxes-Fusion
+
+Paper "Weighted boxes fusion: Ensembling boxes from different object detection models",
+https://arxiv.org/abs/1910.13302
+"""
+
+# Reference license: MIT
+
+from dataclasses import dataclass
+from typing import Literal
+from typing import Optional
+
+import torch
+from torchvision.ops import box_iou
+
+ConfType = Literal["avg", "max", "box_and_model_avg", "absent_model_aware_avg"]
+
+
+@dataclass
+class BoxCluster:
+    box: torch.Tensor
+    score_weight_sum: torch.Tensor
+    weight_sum: torch.Tensor
+    max_score: torch.Tensor
+    boxes_count: int
+
+    @classmethod
+    def from_entry(cls, box: torch.Tensor, score: torch.Tensor, weight: torch.Tensor) -> "BoxCluster":
+        score_weight = score * weight
+        return cls(
+            box=box.clone(),
+            score_weight_sum=score_weight,
+            weight_sum=weight,
+            max_score=score,
+            boxes_count=1,
+        )
+
+    def add(self, box: torch.Tensor, score: torch.Tensor, weight: torch.Tensor) -> None:
+        score_weight = score * weight
+        total_weight = self.score_weight_sum + score_weight
+        self.box = (self.box * self.score_weight_sum + box * score_weight) / total_weight
+        self.score_weight_sum = total_weight
+        self.weight_sum += weight
+        self.max_score = torch.maximum(self.max_score, score)
+        self.boxes_count += 1
+
+
+# pylint: disable=too-many-locals,too-many-branches
+def weighted_boxes_fusion(
+    boxes_list: list[torch.Tensor],
+    scores_list: list[torch.Tensor],
+    labels_list: list[torch.Tensor],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if weights is None:
+        weights = [1.0] * len(boxes_list)
+    if len(weights) != len(boxes_list):
+        raise ValueError("weights must match number of box sets")
+
+    if len(boxes_list) > 0:
+        device = boxes_list[0].device
+    else:
+        device = torch.device("cpu")
+
+    boxes_all: list[torch.Tensor] = []
+    scores_all: list[torch.Tensor] = []
+    labels_all: list[torch.Tensor] = []
+    weights_all: list[torch.Tensor] = []
+    for boxes, scores, labels, weight in zip(boxes_list, scores_list, labels_list, weights):
+        if boxes.numel() == 0:
+            continue
+
+        boxes_tensor = boxes.detach().to(dtype=torch.float32)
+        scores_tensor = scores.detach().to(dtype=torch.float32)
+        labels_tensor = labels.detach().to(dtype=torch.int64)
+
+        keep = scores_tensor >= skip_box_thr
+        if not keep.any():
+            continue
+
+        boxes_tensor = boxes_tensor[keep]
+        scores_tensor = scores_tensor[keep]
+        labels_tensor = labels_tensor[keep]
+        weights_tensor = scores_tensor.new_full(scores_tensor.shape, weight)
+
+        boxes_all.append(boxes_tensor)
+        scores_all.append(scores_tensor)
+        labels_all.append(labels_tensor)
+        weights_all.append(weights_tensor)
+
+    if len(boxes_all) == 0:
+        empty_boxes = torch.zeros((0, 4), dtype=torch.float32, device=device)
+        empty_scores = torch.zeros((0,), dtype=torch.float32, device=device)
+        empty_labels = torch.zeros((0,), dtype=torch.int64, device=device)
+        return (empty_boxes, empty_scores, empty_labels)
+
+    boxes_tensor = torch.concat(boxes_all, dim=0)
+    scores_tensor = torch.concat(scores_all, dim=0)
+    labels_tensor = torch.concat(labels_all, dim=0)
+    weights_tensor = torch.concat(weights_all, dim=0)
+    labels_unique = torch.unique(labels_tensor)
+
+    total_weight = float(sum(weights))
+    num_models = len(weights)
+    fused_boxes: list[torch.Tensor] = []
+    fused_scores: list[torch.Tensor] = []
+    fused_labels: list[torch.Tensor] = []
+
+    for label in labels_unique:
+        label_mask = labels_tensor == label
+        label_boxes = boxes_tensor[label_mask]
+        label_scores = scores_tensor[label_mask]
+        label_weights = weights_tensor[label_mask]
+        order = torch.argsort(label_scores, descending=True)
+        clusters: list[BoxCluster] = []
+        for idx in order:
+            box = label_boxes[idx]
+            score = label_scores[idx]
+            weight = label_weights[idx]
+            if len(clusters) == 0:
+                clusters.append(BoxCluster.from_entry(box, score, weight))
+                continue
+
+            cluster_boxes = torch.stack([cluster.box for cluster in clusters], dim=0)
+            ious = box_iou(box.unsqueeze(0), cluster_boxes).squeeze(0)
+            max_iou, max_idx = torch.max(ious, dim=0)
+            if max_iou > iou_thr:
+                clusters[int(max_idx)].add(box, score, weight)
+            else:
+                clusters.append(BoxCluster.from_entry(box, score, weight))
+
+        for cluster in clusters:
+            if conf_type == "avg":
+                score = cluster.score_weight_sum / cluster.weight_sum
+            elif conf_type == "max":
+                score = cluster.max_score
+            elif conf_type == "box_and_model_avg":
+                score = (cluster.score_weight_sum / cluster.weight_sum) * (cluster.boxes_count / num_models)
+            elif conf_type == "absent_model_aware_avg":
+                score = cluster.score_weight_sum / total_weight
+            else:
+                raise ValueError(f"Unsupported conf_type: {conf_type}")
+
+            if allows_overflow is False:
+                score = score.clamp(max=1.0)
+
+            fused_boxes.append(cluster.box)
+            fused_scores.append(score)
+            fused_labels.append(label)
+
+    fused_scores_tensor = torch.stack(fused_scores)
+    order = torch.argsort(fused_scores_tensor, descending=True)
+    fused_boxes_tensor = torch.stack(fused_boxes, dim=0)[order]
+    fused_scores_tensor = fused_scores_tensor[order]
+    fused_labels_tensor = torch.stack(fused_labels)[order]
+
+    return (fused_boxes_tensor, fused_scores_tensor, fused_labels_tensor)
+
+
+def fuse_detections_wbf_single(
+    detections: list[dict[str, torch.Tensor]],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> dict[str, torch.Tensor]:
+    if len(detections) == 0:
+        return {
+            "boxes": torch.zeros((0, 4)),
+            "scores": torch.zeros((0,)),
+            "labels": torch.zeros((0,), dtype=torch.int64),
+        }
+
+    boxes_list = [detection["boxes"] for detection in detections]
+    scores_list = [detection["scores"] for detection in detections]
+    labels_list = [detection["labels"] for detection in detections]
+
+    (boxes, scores, labels) = weighted_boxes_fusion(
+        boxes_list,
+        scores_list,
+        labels_list,
+        weights=weights,
+        iou_thr=iou_thr,
+        skip_box_thr=skip_box_thr,
+        conf_type=conf_type,
+        allows_overflow=allows_overflow,
+    )
+
+    return {"boxes": boxes, "scores": scores, "labels": labels}
+
+
+def fuse_detections_wbf(
+    detections_list: list[list[dict[str, torch.Tensor]]],
+    weights: Optional[list[float]] = None,
+    iou_thr: float = 0.55,
+    skip_box_thr: float = 0.0,
+    conf_type: ConfType = "avg",
+    allows_overflow: bool = False,
+) -> list[dict[str, torch.Tensor]]:
+    if len(detections_list) == 0:
+        return []
+
+    # Outer list is the augmentations, inner is the batch
+    batch_size = len(detections_list[0])
+    fused: list[dict[str, torch.Tensor]] = []
+    for idx in range(batch_size):
+        per_image = [detections[idx] for detections in detections_list]
+        fused.append(
+            fuse_detections_wbf_single(
+                per_image,
+                weights=weights,
+                iou_thr=iou_thr,
+                skip_box_thr=skip_box_thr,
+                conf_type=conf_type,
+                allows_overflow=allows_overflow,
+            )
+        )
+
+    return fused
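Because each cluster's box is a score-weighted average and `conf_type="avg"` divides by the summed weights, two strongly overlapping boxes merge into one intermediate box. A small worked example (illustrative values):

    import torch

    boxes = [torch.tensor([[0.0, 0.0, 10.0, 10.0]]), torch.tensor([[2.0, 0.0, 12.0, 10.0]])]
    scores = [torch.tensor([0.9]), torch.tensor([0.6])]
    labels = [torch.tensor([1]), torch.tensor([1])]
    # IoU = 80 / 120 ~= 0.67 > 0.55, so the boxes land in one cluster:
    #   x1 = (0 * 0.9 + 2 * 0.6) / 1.5 = 0.8,  x2 = (10 * 0.9 + 12 * 0.6) / 1.5 = 10.8
    #   score ("avg", unit weights) = (0.9 + 0.6) / 2 = 0.75
    (fused_boxes, fused_scores, fused_labels) = weighted_boxes_fusion(boxes, scores, labels, iou_thr=0.55)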
birder/kernels/load_kernel.py

@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)


 _CACHED_KERNELS: dict[str, ModuleType] = {}
+_CUSTOM_KERNELS_ENABLED = True
+
+
+def set_custom_kernels_enabled(enabled: bool) -> None:
+    global _CUSTOM_KERNELS_ENABLED  # pylint: disable=global-statement
+    _CUSTOM_KERNELS_ENABLED = enabled
+
+
+def is_custom_kernels_enabled() -> bool:
+    if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
+        return False
+
+    return _CUSTOM_KERNELS_ENABLED


 def load_msda() -> Optional[ModuleType]:
     name = "msda"
-    if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
+    if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:

@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:

 def load_swattention() -> Optional[ModuleType]:
     name = "swattention"
-    if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
+    if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:

@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:

 def load_soft_nms() -> Optional[ModuleType]:
     name = "soft_nms"
-    if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
+    if is_custom_kernels_enabled() is False:
         return None

     if name in _CACHED_KERNELS:

@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
     soft_nms: Optional[ModuleType] = load(
         "soft_nms",
         src_files,
-        with_cuda=True,
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
     )

     if soft_nms is not None:
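Custom kernels can now be disabled programmatically as well as through the environment, and the env var takes precedence since `is_custom_kernels_enabled` checks it first. A short sketch (import path inferred from the file location):

    import os
    from birder.kernels.load_kernel import load_soft_nms, set_custom_kernels_enabled

    set_custom_kernels_enabled(False)  # runtime opt-out; loaders short-circuit to None
    assert load_soft_nms() is None

    set_custom_kernels_enabled(True)
    os.environ["DISABLE_CUSTOM_KERNELS"] = "1"  # env override still wins
    assert load_soft_nms() is None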
birder/kernels/soft_nms/soft_nms.cpp

@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
     std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);

     // max_idx is computed from sliced data, therefore need to convert it to "global" max idx
-    auto max_idx = t_max_idx.item<int>() + idx + 1;
-
-    if (scores.index({idx}).item<float>() < max_score.item<float>()) {
-        auto boxes_idx = boxes.index({idx}).clone();
-        auto boxes_max = boxes.index({max_idx}).clone();
-        boxes.index({idx}) = boxes_max;
-        boxes.index({max_idx}) = boxes_idx;
-
-        auto scores_idx = scores.index({idx}).clone();
-        auto scores_max = scores.index({max_idx}).clone();
-        scores.index({idx}) = scores_max;
-        scores.index({max_idx}) = scores_idx;
-
-        auto areas_idx = areas.index({idx}).clone();
-        auto areas_max = areas.index({max_idx}).clone();
-        areas.index({idx}) = areas_max;
-        areas.index({max_idx}) = areas_idx;
-    }
+    auto max_idx = t_max_idx + (idx + 1);
+    auto should_swap = scores.index({idx}) < max_score;
+
+    auto boxes_idx = boxes.index({idx}).clone();
+    auto boxes_max = boxes.index({max_idx}).clone();
+    boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
+    boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
+
+    auto scores_idx = scores.index({idx}).clone();
+    auto scores_max = scores.index({max_idx}).clone();
+    scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
+    scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
+
+    auto areas_idx = areas.index({idx}).clone();
+    auto areas_max = areas.index({max_idx}).clone();
+    areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
+    areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
 }

 std::tuple<torch::Tensor, torch::Tensor> soft_nms(
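The rewrite removes the `.item<...>()` host synchronizations and the data-dependent branch: every iteration now performs the same tensor ops, with `torch::where` selecting between the swapped and unswapped values. The same conditional-swap pattern in Python, as a hedged analogue of one selection-sort pass (not the C++ source itself):

    import torch

    def conditional_swap(scores: torch.Tensor, idx: int) -> None:
        # Find the best score after position idx and convert to a "global" index
        (max_score, t_max_idx) = torch.max(scores[idx + 1:], dim=0)
        max_idx = t_max_idx + (idx + 1)
        should_swap = scores[idx] < max_score  # 0-dim bool tensor, no .item() sync
        a = scores[idx].clone()
        b = scores[max_idx].clone()
        # Both positions are written every iteration; where() picks swap vs. no-op
        scores[idx] = torch.where(should_swap, b, a)
        scores[max_idx] = torch.where(should_swap, a, b)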