birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/net/detection/yolo_anchors.py (new file)
@@ -0,0 +1,205 @@
+ """
+ Shared YOLO anchor definitions and helpers.
+ """
+
+ import json
+ from collections.abc import Sequence
+ from typing import Any
+ from typing import Literal
+ from typing import NotRequired
+ from typing import TypedDict
+ from typing import overload
+
+ AnchorPair = tuple[float, float]
+ AnchorGroup = list[AnchorPair]
+ AnchorGroups = list[AnchorGroup]
+ AnchorLike = AnchorGroups | AnchorGroup
+
+ # Default anchors from yolo.cfg (COCO dataset), in grid units
+ YOLO_V2_ANCHORS: AnchorGroup = [
+     (0.57273, 0.677385),
+     (1.87446, 2.06253),
+     (3.33843, 5.47434),
+     (7.88282, 3.52778),
+     (9.77052, 9.16828),
+ ]
+
+ # Default anchors from YOLO v3 paper (sorted by area, small to large)
+ # These values are in absolute pixels (width, height) computed using K-Means
+ # on the COCO dataset with a reference input size of 416x416.
+ YOLO_V3_ANCHORS: AnchorGroups = [
+     [(10.0, 13.0), (16.0, 30.0), (33.0, 23.0)],  # Small objects (stride 8)
+     [(30.0, 61.0), (62.0, 45.0), (59.0, 119.0)],  # Medium objects (stride 16)
+     [(116.0, 90.0), (156.0, 198.0), (373.0, 326.0)],  # Large objects (stride 32)
+ ]
+
+ # Default anchors from YOLO v4 (COCO), in pixels
+ YOLO_V4_ANCHORS: AnchorGroups = [
+     [(12.0, 16.0), (19.0, 36.0), (40.0, 28.0)],  # Small
+     [(36.0, 75.0), (76.0, 55.0), (72.0, 146.0)],  # Medium
+     [(142.0, 110.0), (192.0, 243.0), (459.0, 401.0)],  # Large
+ ]
+
+ # Default anchors from YOLO v4 Tiny (COCO), in pixels
+ YOLO_V4_TINY_ANCHORS: AnchorGroups = [
+     [(10.0, 14.0), (23.0, 27.0), (37.0, 58.0)],  # Medium
+     [(81.0, 82.0), (135.0, 169.0), (344.0, 319.0)],  # Large
+ ]
+
+
+ class AnchorPreset(TypedDict):
+     anchors: AnchorLike
+     format: Literal["grid", "pixels"]
+     size: tuple[int, int]
+     strides: NotRequired[Sequence[int]]
+
+
+ ANCHOR_PRESETS: dict[str, AnchorPreset] = {
+     "yolo_v2": {"anchors": YOLO_V2_ANCHORS, "format": "grid", "size": (416, 416), "strides": (32,)},
+     "yolo_v3": {"anchors": YOLO_V3_ANCHORS, "format": "pixels", "size": (416, 416)},
+     "yolo_v4": {"anchors": YOLO_V4_ANCHORS, "format": "pixels", "size": (608, 608)},
+     "yolo_v4_tiny": {"anchors": YOLO_V4_TINY_ANCHORS, "format": "pixels", "size": (416, 416)},
+ }
+
+
+ @overload
+ def scale_anchors(anchors: AnchorGroup, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorGroup: ...
+
+
+ @overload
+ def scale_anchors(anchors: AnchorGroups, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorGroups: ...
+
+
+ def scale_anchors(anchors: AnchorLike, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorLike:
+     (anchor_groups, single) = _normalize_anchor_groups(anchors)
+
+     if from_size == to_size:
+         # Avoid aliasing default anchors in case they are mutated later
+         scaled: AnchorGroups = [list(group) for group in anchor_groups]
+         if single is True:
+             return scaled[0]
+
+         return scaled
+
+     scale_h = to_size[0] / from_size[0]
+     scale_w = to_size[1] / from_size[1]
+     scaled = [[(w * scale_w, h * scale_h) for (w, h) in group] for group in anchor_groups]
+
+     if single is True:
+         return scaled[0]
+
+     return scaled
+
+
+ @overload
+ def pixels_to_grid(anchors: AnchorGroup, strides: Sequence[int]) -> AnchorGroup: ...
+
+
+ @overload
+ def pixels_to_grid(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroups: ...
+
+
+ def pixels_to_grid(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
+     (anchor_groups, single) = _normalize_anchor_groups(anchors)
+     if len(anchor_groups) != len(strides):
+         raise ValueError("strides must provide one value per anchor scale")
+
+     converted: AnchorGroups = []
+     for group, stride in zip(anchor_groups, strides):
+         converted.append([(w / stride, h / stride) for (w, h) in group])
+
+     if single is True:
+         return converted[0]
+
+     return converted
+
+
+ @overload
+ def grid_to_pixels(anchors: AnchorGroup, strides: Sequence[int]) -> AnchorGroup: ...
+
+
+ @overload
+ def grid_to_pixels(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroups: ...
+
+
+ def grid_to_pixels(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
+     (anchor_groups, single) = _normalize_anchor_groups(anchors)
+     if len(anchor_groups) != len(strides):
+         raise ValueError("strides must provide one value per anchor scale")
+
+     converted: AnchorGroups = []
+     for group, stride in zip(anchor_groups, strides):
+         converted.append([(w * stride, h * stride) for (w, h) in group])
+
+     if single is True:
+         return converted[0]
+
+     return converted
+
+
+ def _normalize_anchor_groups(anchors: AnchorLike) -> tuple[AnchorGroups, bool]:
+     if len(anchors) > 0 and _is_anchor_pair(anchors[0]) is True:
+         return ([anchors], True)  # type: ignore[list-item]
+
+     return (anchors, False)  # type: ignore[return-value]
+
+
+ def _is_anchor_pair(value: Any) -> bool:
+     if not isinstance(value, Sequence) or len(value) != 2:
+         return False
+
+     return all(isinstance(item, (float, int)) for item in value)
+
+
+ def _resolve_anchors(
+     preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
+ ) -> AnchorLike:
+     if preset.endswith(".json") is True:
+         with open(preset, "r", encoding="utf-8") as handle:
+             preset_spec = json.load(handle)
+     else:
+         if preset not in ANCHOR_PRESETS:
+             raise ValueError(f"Unknown anchor preset: {preset}")
+
+         preset_spec = ANCHOR_PRESETS[preset]
+
+     anchors = preset_spec["anchors"]
+     preset_size = tuple(preset_spec["size"])
+     preset_format = preset_spec["format"]
+     if preset_format == "grid":
+         if "strides" not in preset_spec:
+             raise ValueError("Preset is missing strides required for grid anchors")
+
+         preset_strides = preset_spec["strides"]
+         anchors = grid_to_pixels(anchors, preset_strides)
+
+     anchors = scale_anchors(anchors, preset_size, model_size)
+     if anchor_format == "pixels":
+         return anchors
+
+     if anchor_format == "grid":
+         return pixels_to_grid(anchors, model_strides)
+
+     raise ValueError(f"Unsupported anchor format: {anchor_format}")
+
+
+ def resolve_anchor_group(
+     preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
+ ) -> AnchorGroup:
+     anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
+     (anchor_groups, single) = _normalize_anchor_groups(anchors)
+     if single is False:
+         raise ValueError("Expected a single anchor group for this model")
+
+     return anchor_groups[0]
+
+
+ def resolve_anchor_groups(
+     preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
+ ) -> AnchorGroups:
+     anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
+     (anchor_groups, single) = _normalize_anchor_groups(anchors)
+     if single is True:
+         raise ValueError("Expected multiple anchor groups for this model")
+
+     return anchor_groups
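Note: `_resolve_anchors` accepts either a key of `ANCHOR_PRESETS` or a path ending in `.json`, so custom anchors can be supplied without touching the module. A minimal sketch of such a file and its resolution; the filename `custom_anchors.json` and the anchor values are illustrative, not shipped with the package:

import json

from birder.net.detection.yolo_anchors import resolve_anchor_groups

# Hypothetical custom preset mirroring the AnchorPreset structure above:
# pixel-space anchors measured against a 640x640 reference input
custom_preset = {
    "anchors": [
        [[8.0, 11.0], [14.0, 28.0], [31.0, 21.0]],
        [[28.0, 57.0], [58.0, 42.0], [55.0, 112.0]],
        [[110.0, 85.0], [148.0, 188.0], [354.0, 310.0]],
    ],
    "format": "pixels",
    "size": [640, 640],
}
with open("custom_anchors.json", "w", encoding="utf-8") as handle:
    json.dump(custom_preset, handle)

# Rescaled for a 512x512 model; strides are only consulted when grid units are requested
anchors = resolve_anchor_groups(
    "custom_anchors.json", anchor_format="pixels", model_size=(512, 512), model_strides=(8, 16, 32)
)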
birder/net/detection/yolo_v2.py
@@ -17,18 +17,11 @@ from torch import nn
  from torchvision.ops import Conv2dNormActivation
  from torchvision.ops import boxes as box_ops

+ from birder.model_registry import registry
  from birder.net.base import DetectorBackbone
  from birder.net.detection.base import DetectionBaseNet
  from birder.net.detection.base import ImageList
-
- # Default anchors from yolo.cfg (COCO dataset)
- DEFAULT_ANCHORS = [
-     (0.57273, 0.677385),
-     (1.87446, 2.06253),
-     (3.33843, 5.47434),
-     (7.88282, 3.52778),
-     (9.77052, 9.16828),
- ]
+ from birder.net.detection.yolo_anchors import resolve_anchor_group


  def decode_predictions(
@@ -102,8 +95,8 @@ def decode_predictions(
  class YOLOAnchorGenerator(nn.Module):
      def __init__(self, anchors: list[tuple[float, float]]) -> None:
          super().__init__()
-         self.anchors = anchors
-         self.num_anchors = len(anchors)
+         self.anchors = nn.Buffer(torch.tensor(anchors, dtype=torch.float32))
+         self.num_anchors: int = self.anchors.size(0)

      def num_anchors_per_location(self) -> int:
          return self.num_anchors
@@ -134,8 +127,7 @@ class YOLOAnchorGenerator(nn.Module):
          grid = torch.stack([grid_x, grid_y], dim=-1)

          # Scale anchors to feature map stride (anchors are in grid units)
-         anchors_tensor = torch.tensor(self.anchors, device=device, dtype=dtype)
-         anchors_tensor = anchors_tensor * torch.tensor([stride_w, stride_h], device=device, dtype=dtype)
+         anchors_tensor = self.anchors * torch.tensor([stride_w, stride_h], device=device, dtype=dtype)

          # Store strides as tensor
          stride = torch.tensor([stride_h, stride_w], device=device, dtype=dtype)
@@ -222,7 +214,6 @@ class YOLOHead(nn.Module):
  # pylint: disable=invalid-name
  class YOLO_v2(DetectionBaseNet):
      default_size = (416, 416)
-     auto_register = True

      def __init__(
          self,
@@ -234,7 +225,7 @@ class YOLO_v2(DetectionBaseNet):
          export_mode: bool = False,
      ) -> None:
          super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
-         assert self.config is None, "config not supported"
+         assert self.config is not None, "must set config"

          self.num_classes = self.num_classes - 1

@@ -242,13 +233,19 @@ class YOLO_v2(DetectionBaseNet):
          nms_thresh = 0.45
          detections_per_img = 300
          mid_channels = 1024
-         self.ignore_thresh = 0.5
-         self.noobj_coeff = 0.5
-         self.coord_coeff = 5.0
-         self.obj_coeff = 1.0
-         self.cls_coeff = 1.0
+         ignore_thresh = 0.5
+         noobj_coeff = 0.5
+         coord_coeff = 5.0
+         obj_coeff = 1.0
+         cls_coeff = 1.0
+         anchor_spec = self.config["anchors"]
+
+         self.ignore_thresh = ignore_thresh
+         self.noobj_coeff = noobj_coeff
+         self.coord_coeff = coord_coeff
+         self.obj_coeff = obj_coeff
+         self.cls_coeff = cls_coeff

-         self.anchors = DEFAULT_ANCHORS
          self.score_thresh = score_thresh
          self.nms_thresh = nms_thresh
          self.detections_per_img = detections_per_img
@@ -258,7 +255,8 @@ class YOLO_v2(DetectionBaseNet):
 
          self.neck = YOLONeck(self.backbone.return_channels, mid_channels)

-         self.anchor_generator = YOLOAnchorGenerator(self.anchors)
+         anchors = resolve_anchor_group(anchor_spec, anchor_format="grid", model_size=self.size, model_strides=(32,))
+         self.anchor_generator = YOLOAnchorGenerator(anchors)
          num_anchors = self.anchor_generator.num_anchors_per_location()
          self.head = YOLOHead(self.neck.out_channels, num_anchors, self.num_classes)

@@ -319,7 +317,7 @@ class YOLO_v2(DetectionBaseNet):
          device = predictions.device
          dtype = predictions.dtype
          (batch_size, _, H, W) = predictions.size()
-         num_anchors = len(self.anchors)
+         num_anchors = self.anchor_generator.num_anchors

          stride_h = stride[0]
          stride_w = stride[1]
@@ -423,7 +421,7 @@ class YOLO_v2(DetectionBaseNet):
 
          device = predictions.device
          (N, _, H, W) = predictions.size()
-         num_anchors = len(self.anchors)
+         num_anchors = self.anchor_generator.num_anchors

          predictions = predictions.view(N, num_anchors, 5 + self.num_classes, H, W)
          predictions = predictions.permute(0, 1, 3, 4, 2).contiguous()
@@ -552,3 +550,6 @@ class YOLO_v2(DetectionBaseNet):
          detections = self.postprocess_detections(decoded_predictions, images.image_sizes)

          return (detections, losses)
+
+
+ registry.register_model_config("yolo_v2", YOLO_v2, config={"anchors": "yolo_v2"})
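Note: yolo_v2 now resolves the shared "yolo_v2" preset in grid units instead of keeping a module-level DEFAULT_ANCHORS. A small sanity check derived from the resolution path above (not part of the package's test suite): at the preset's reference size the grid-to-pixels / pixels-to-grid round trip reproduces the original grid-unit anchors, while other input sizes scale them by the size ratio.

from birder.net.detection.yolo_anchors import YOLO_V2_ANCHORS
from birder.net.detection.yolo_anchors import resolve_anchor_group

# At the reference size (416x416, stride 32) the preset round-trips unchanged
anchors = resolve_anchor_group("yolo_v2", anchor_format="grid", model_size=(416, 416), model_strides=(32,))
assert all(
    abs(w - ref_w) < 1e-6 and abs(h - ref_h) < 1e-6
    for (w, h), (ref_w, ref_h) in zip(anchors, YOLO_V2_ANCHORS)
)

# At 608x608 every anchor is scaled by 608 / 416 before being converted back to grid units
anchors_608 = resolve_anchor_group("yolo_v2", anchor_format="grid", model_size=(608, 608), model_strides=(32,))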
birder/net/detection/yolo_v3.py
@@ -17,32 +17,11 @@ from torch import nn
  from torchvision.ops import Conv2dNormActivation
  from torchvision.ops import boxes as box_ops

+ from birder.model_registry import registry
  from birder.net.base import DetectorBackbone
  from birder.net.detection.base import DetectionBaseNet
  from birder.net.detection.base import ImageList
-
- # Default anchors from YOLO v3 paper (sorted by area, small to large)
- # These values are in absolute pixels (width, height) computed using K-Means
- # on the COCO dataset with a reference input size of 416x416.
- DEFAULT_ANCHORS = [
-     [(10.0, 13.0), (16.0, 30.0), (33.0, 23.0)],  # Small objects (stride 8)
-     [(30.0, 61.0), (62.0, 45.0), (59.0, 119.0)],  # Medium objects (stride 16)
-     [(116.0, 90.0), (156.0, 198.0), (373.0, 326.0)],  # Large objects (stride 32)
- ]
-
-
- def scale_anchors(
-     anchors: list[list[tuple[float, float]]],
-     from_size: tuple[int, int],
-     to_size: tuple[int, int],
- ) -> list[list[tuple[float, float]]]:
-     if from_size == to_size:
-         # Avoid aliasing default anchors in case they are mutated later
-         return [list(scale) for scale in anchors]
-
-     scale_h = to_size[0] / from_size[0]
-     scale_w = to_size[1] / from_size[1]
-     return [[(w * scale_w, h * scale_h) for (w, h) in scale] for scale in anchors]
+ from birder.net.detection.yolo_anchors import resolve_anchor_groups


  def decode_predictions(
@@ -116,11 +95,20 @@ def decode_predictions(
  class YOLOAnchorGenerator(nn.Module):
      def __init__(self, anchors: list[list[tuple[float, float]]]) -> None:
          super().__init__()
-         self.anchors = anchors
-         self.num_scales = len(anchors)
+         self.anchors = nn.Buffer(torch.tensor(anchors, dtype=torch.float32))
+         self.num_scales = self.anchors.size(0)

      def num_anchors_per_location(self) -> list[int]:
-         return [len(a) for a in self.anchors]
+         return [a.size(0) for a in self.anchors]
+
+     def scale_anchors(self, from_size: tuple[int, int], to_size: tuple[int, int]) -> None:
+         if from_size == to_size:
+             return
+
+         scale_h = to_size[0] / from_size[0]
+         scale_w = to_size[1] / from_size[1]
+         self.anchors[..., 0].mul_(scale_w)
+         self.anchors[..., 1].mul_(scale_h)

      def forward(
          self, image_list: ImageList, feature_maps: list[torch.Tensor]
@@ -152,7 +140,7 @@ class YOLOAnchorGenerator(nn.Module):
              grid = torch.stack([grid_x, grid_y], dim=-1)

              # Select anchors for this scale
-             anchors_for_scale = torch.tensor(self.anchors[idx], device=device, dtype=dtype)
+             anchors_for_scale = self.anchors[idx]

              # Store strides as tensor
              strides = torch.tensor([stride_h, stride_w], device=device, dtype=dtype)
@@ -321,7 +309,6 @@ class YOLONeck(nn.Module):
  # pylint: disable=invalid-name
  class YOLO_v3(DetectionBaseNet):
      default_size = (416, 416)
-     auto_register = True

      def __init__(
          self,
@@ -333,20 +320,26 @@ class YOLO_v3(DetectionBaseNet):
          export_mode: bool = False,
      ) -> None:
          super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
-         assert self.config is None, "config not supported"
+         assert self.config is not None, "must set config"

          self.num_classes = self.num_classes - 1

          score_thresh = 0.05
          nms_thresh = 0.45
          detections_per_img = 300
-         self.ignore_thresh = 0.5
-         self.noobj_coeff = 0.2
-         self.coord_coeff = 5.0
-         self.obj_coeff = 1.0
-         self.cls_coeff = 1.0
+         ignore_thresh = 0.5
+         noobj_coeff = 0.2
+         coord_coeff = 5.0
+         obj_coeff = 1.0
+         cls_coeff = 1.0
+         anchor_spec = self.config["anchors"]
+
+         self.ignore_thresh = ignore_thresh
+         self.noobj_coeff = noobj_coeff
+         self.coord_coeff = coord_coeff
+         self.obj_coeff = obj_coeff
+         self.cls_coeff = cls_coeff

-         self.anchors = scale_anchors(DEFAULT_ANCHORS, self.default_size, self.size)
          self.score_thresh = score_thresh
          self.nms_thresh = nms_thresh
          self.detections_per_img = detections_per_img
@@ -356,7 +349,10 @@ class YOLO_v3(DetectionBaseNet):
 
          self.neck = YOLONeck(self.backbone.return_channels)

-         self.anchor_generator = YOLOAnchorGenerator(self.anchors)
+         anchors = resolve_anchor_groups(
+             anchor_spec, anchor_format="pixels", model_size=self.size, model_strides=(8, 16, 32)
+         )
+         self.anchor_generator = YOLOAnchorGenerator(anchors)
          num_anchors = self.anchor_generator.num_anchors_per_location()
          self.head = YOLOHead(self.neck.out_channels, num_anchors, self.num_classes)

@@ -376,8 +372,7 @@ class YOLO_v3(DetectionBaseNet):
          super().adjust_size(new_size)

          if adjust_anchors is True:
-             self.anchors = scale_anchors(self.anchors, old_size, new_size)
-             self.anchor_generator.anchors = self.anchors
+             self.anchor_generator.scale_anchors(old_size, new_size)

      def freeze(self, freeze_classifier: bool = True) -> None:
          for param in self.parameters():
@@ -435,7 +430,7 @@ class YOLO_v3(DetectionBaseNet):
 
          # Build flat list of all anchors with their scale indices
          all_anchors = torch.concat(anchors, dim=0)
-         anchors_per_scale = [len(self.anchors[i]) for i in range(num_scales)]
+         anchors_per_scale = self.anchor_generator.num_anchors_per_location()
          cumsum_anchors = torch.tensor([0] + anchors_per_scale, device=device).cumsum(0)

          # Get grid sizes and strides for each scale
@@ -586,6 +581,7 @@ class YOLO_v3(DetectionBaseNet):
          (target_tensors, obj_masks, noobj_masks) = self._build_targets(predictions, targets, anchors, strides)

          device = predictions[0].device
+         anchors_per_scale = self.anchor_generator.num_anchors_per_location()
          coord_loss = torch.tensor(0.0, device=device)
          obj_loss = torch.tensor(0.0, device=device)
          noobj_loss = torch.tensor(0.0, device=device)
@@ -594,7 +590,7 @@ class YOLO_v3(DetectionBaseNet):
          num_obj = 0
          for scale_idx, pred in enumerate(predictions):
              (N, _, H, W) = pred.size()
-             num_anchors_scale = len(self.anchors[scale_idx])
+             num_anchors_scale = anchors_per_scale[scale_idx]

              pred = pred.view(N, num_anchors_scale, 5 + self.num_classes, H, W)
              pred = pred.permute(0, 1, 3, 4, 2).contiguous()
@@ -730,3 +726,6 @@ class YOLO_v3(DetectionBaseNet):
          detections = self.postprocess_detections(decoded_predictions, images.image_sizes)

          return (detections, losses)
+
+
+ registry.register_model_config("yolo_v3", YOLO_v3, config={"anchors": "yolo_v3"})
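Note: in yolo_v3 the anchors now live in an nn.Buffer of shape (num_scales, anchors_per_scale, 2), and adjust_size rescales that buffer in place instead of rebuilding nested Python lists. A standalone sketch of the same tensor operation, using the v3 default anchors and a hypothetical target size:

import torch

# Pixel-space anchors, shape (num_scales, anchors_per_scale, 2), as stored in the buffer
anchors = torch.tensor(
    [
        [[10.0, 13.0], [16.0, 30.0], [33.0, 23.0]],
        [[30.0, 61.0], [62.0, 45.0], [59.0, 119.0]],
        [[116.0, 90.0], [156.0, 198.0], [373.0, 326.0]],
    ]
)
(from_h, from_w) = (416, 416)
(to_h, to_w) = (640, 512)
anchors[..., 0].mul_(to_w / from_w)  # widths follow the width ratio
anchors[..., 1].mul_(to_h / from_h)  # heights follow the height ratio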
birder/net/detection/yolo_v4.py
@@ -17,18 +17,12 @@ from torch import nn
  from torchvision.ops import Conv2dNormActivation
  from torchvision.ops import boxes as box_ops

+ from birder.model_registry import registry
  from birder.net.base import DetectorBackbone
  from birder.net.detection.base import DetectionBaseNet
+ from birder.net.detection.yolo_anchors import resolve_anchor_groups
  from birder.net.detection.yolo_v3 import YOLOAnchorGenerator
  from birder.net.detection.yolo_v3 import YOLOHead
- from birder.net.detection.yolo_v3 import scale_anchors
-
- # Default anchors from YOLO v4 (COCO)
- DEFAULT_ANCHORS = [
-     [(12.0, 16.0), (19.0, 36.0), (40.0, 28.0)],  # Small
-     [(36.0, 75.0), (76.0, 55.0), (72.0, 146.0)],  # Medium
-     [(142.0, 110.0), (192.0, 243.0), (459.0, 401.0)],  # Large
- ]

  # Scale factors per detection scale to eliminate grid sensitivity
  DEFAULT_SCALE_XY = [1.2, 1.1, 1.05]  # [small, medium, large]
@@ -59,7 +53,6 @@ def decode_predictions(
          Number of classes.
      scale_xy
          Scale factor for grid sensitivity elimination.
-         YOLOv4 uses 1.05-1.2 depending on scale level.

      Returns
      -------
@@ -378,7 +371,6 @@ class YOLONeck(nn.Module):
  # pylint: disable=invalid-name
  class YOLO_v4(DetectionBaseNet):
      default_size = (608, 608)
-     auto_register = True

      def __init__(
          self,
@@ -390,22 +382,26 @@ class YOLO_v4(DetectionBaseNet):
          export_mode: bool = False,
      ) -> None:
          super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
-         assert self.config is None, "config not supported"
+         assert self.config is not None, "must set config"

          self.num_classes = self.num_classes - 1

          score_thresh = 0.05
          nms_thresh = 0.45
          detections_per_img = 300
-         self.ignore_thresh = 0.7
-
-         # Loss coefficients
-         self.noobj_coeff = 0.25
-         self.coord_coeff = 3.0
-         self.obj_coeff = 1.0
-         self.cls_coeff = 1.0
-
-         self.anchors = scale_anchors(DEFAULT_ANCHORS, self.default_size, self.size)
+         ignore_thresh = 0.7
+         noobj_coeff = 0.25
+         coord_coeff = 3.0
+         obj_coeff = 1.0
+         cls_coeff = 1.0
+         label_smoothing = 0.1
+         anchor_spec = self.config["anchors"]
+
+         self.ignore_thresh = ignore_thresh
+         self.noobj_coeff = noobj_coeff
+         self.coord_coeff = coord_coeff
+         self.obj_coeff = obj_coeff
+         self.cls_coeff = cls_coeff
          self.scale_xy = DEFAULT_SCALE_XY
          self.score_thresh = score_thresh
          self.nms_thresh = nms_thresh
@@ -414,13 +410,16 @@ class YOLO_v4(DetectionBaseNet):
          self.backbone.return_channels = self.backbone.return_channels[-3:]
          self.backbone.return_stages = self.backbone.return_stages[-3:]

-         self.label_smoothing = 0.1
+         self.label_smoothing = label_smoothing
          self.smooth_positive = 1.0 - self.label_smoothing
          self.smooth_negative = self.label_smoothing / self.num_classes

          self.neck = YOLONeck(self.backbone.return_channels)

-         self.anchor_generator = YOLOAnchorGenerator(self.anchors)
+         anchors = resolve_anchor_groups(
+             anchor_spec, anchor_format="pixels", model_size=self.size, model_strides=(8, 16, 32)
+         )
+         self.anchor_generator = YOLOAnchorGenerator(anchors)
          num_anchors = self.anchor_generator.num_anchors_per_location()
          self.head = YOLOHead(self.neck.out_channels, num_anchors, self.num_classes)

@@ -441,8 +440,7 @@ class YOLO_v4(DetectionBaseNet):
          super().adjust_size(new_size)

          if adjust_anchors is True:
-             self.anchors = scale_anchors(self.anchors, old_size, new_size)
-             self.anchor_generator = YOLOAnchorGenerator(self.anchors)
+             self.anchor_generator.scale_anchors(old_size, new_size)

      def freeze(self, freeze_classifier: bool = True) -> None:
          for param in self.parameters():
@@ -500,7 +498,7 @@ class YOLO_v4(DetectionBaseNet):
 
          # Build flat list of all anchors with their scale indices
          all_anchors = torch.concat(anchors, dim=0)
-         anchors_per_scale = [len(self.anchors[i]) for i in range(num_scales)]
+         anchors_per_scale = self.anchor_generator.num_anchors_per_location()
          cumsum_anchors = torch.tensor([0] + anchors_per_scale, device=device).cumsum(0)

          # Get grid sizes and strides for each scale
@@ -651,6 +649,7 @@ class YOLO_v4(DetectionBaseNet):
          (target_tensors, obj_masks, noobj_masks) = self._build_targets(predictions, targets, anchors, strides)

          device = predictions[0].device
+         anchors_per_scale = self.anchor_generator.num_anchors_per_location()
          coord_loss = torch.tensor(0.0, device=device)
          obj_loss = torch.tensor(0.0, device=device)
          noobj_loss = torch.tensor(0.0, device=device)
@@ -659,7 +658,7 @@ class YOLO_v4(DetectionBaseNet):
          num_obj = 0
          for scale_idx, pred in enumerate(predictions):
              (N, _, H, W) = pred.size()
-             num_anchors_scale = len(self.anchors[scale_idx])
+             num_anchors_scale = anchors_per_scale[scale_idx]
              stride_h = strides[scale_idx][0]
              stride_w = strides[scale_idx][1]
              scale_xy = self.scale_xy[scale_idx]
@@ -829,3 +828,6 @@ class YOLO_v4(DetectionBaseNet):
          detections = self.postprocess_detections(decoded_predictions, images.image_sizes)

          return (detections, losses)
+
+
+ registry.register_model_config("yolo_v4", YOLO_v4, config={"anchors": "yolo_v4"})
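Note: all three YOLO files drop the auto_register flag in favor of an explicit registry call whose config carries the anchor preset. Since the preset string may also be a JSON path (see yolo_anchors.py above), a hypothetical extra registration for a variant with custom anchors could look like the following sketch; the name "yolo_v4_custom" and the file my_anchors.json are illustrative, not part of the package:

from birder.model_registry import registry
from birder.net.detection.yolo_v4 import YOLO_v4

# Hypothetical variant: same architecture, anchors taken from a JSON preset file
registry.register_model_config("yolo_v4_custom", YOLO_v4, config={"anchors": "my_anchors.json"})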