birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/tools/auto_anchors.py
ADDED

@@ -0,0 +1,361 @@
+"""
+Fit YOLO-style anchor boxes using k-means based on COCO-format annotations.
+
+Generated by gpt-5.2-codex xhigh.
+"""
+
+import argparse
+import json
+import logging
+import math
+from pathlib import Path
+from pprint import pformat
+from typing import Any
+from typing import Literal
+from typing import Optional
+from typing import TypedDict
+
+import torch
+
+from birder.common import cli
+from birder.conf import settings
+
+logger = logging.getLogger(__name__)
+
+
+class AnchorPreset(TypedDict):
+    num_scales: int
+    num_anchors: int
+    default_size: tuple[int, int]
+    format: Literal["pixels", "grid"]
+
+
+MODEL_PRESETS: dict[str, AnchorPreset] = {
+    "yolo_v2": {"num_scales": 1, "num_anchors": 5, "default_size": (416, 416), "format": "grid"},
+    "yolo_v3": {"num_scales": 3, "num_anchors": 9, "default_size": (416, 416), "format": "pixels"},
+    "yolo_v4": {"num_scales": 3, "num_anchors": 9, "default_size": (608, 608), "format": "pixels"},
+    "yolo_v4_tiny": {"num_scales": 2, "num_anchors": 6, "default_size": (416, 416), "format": "pixels"},
+}
+
+
+def _load_ignore_list(ignore_file: Optional[str]) -> set[str]:
+    if ignore_file is None:
+        return set()
+
+    with open(ignore_file, "r", encoding="utf-8") as handle:
+        return {line.strip() for line in handle if line.strip()}
+
+
+def _load_coco_boxes(
+    coco_json_path: str, target_size: tuple[int, int], ignore_list: set[str], min_size: float, ignore_crowd: bool
+) -> tuple[torch.Tensor, dict[str, int]]:
+    coco_path = Path(coco_json_path)
+    if coco_path.exists() is False:
+        raise ValueError(f"COCO json not found at {coco_path}")
+
+    with open(coco_path, "r", encoding="utf-8") as handle:
+        data = json.load(handle)
+
+    images = {}
+    for image in data.get("images", []):
+        image_id = image.get("id")
+        if image_id is None:
+            continue
+        images[image_id] = (image.get("width"), image.get("height"), image.get("file_name", ""))
+
+    stats = {
+        "total_annotations": 0,
+        "used_annotations": 0,
+        "crowd_annotations": 0,
+        "invalid_bbox": 0,
+        "ignored_images": 0,
+        "missing_images": 0,
+        "missing_size": 0,
+        "too_small": 0,
+    }
+    boxes: list[tuple[float, float]] = []
+    target_h = float(target_size[0])
+    target_w = float(target_size[1])
+    for annotation in data.get("annotations", []):
+        stats["total_annotations"] += 1
+        if ignore_crowd is True and annotation.get("iscrowd", 0) == 1:
+            stats["crowd_annotations"] += 1
+            continue
+
+        bbox = annotation.get("bbox")
+        if bbox is None or len(bbox) != 4:
+            stats["invalid_bbox"] += 1
+            continue
+
+        image_id = annotation.get("image_id")
+        if image_id not in images:
+            stats["missing_images"] += 1
+            continue
+
+        (img_w, img_h, file_name) = images[image_id]
+        if file_name in ignore_list:
+            stats["ignored_images"] += 1
+            continue
+
+        if img_w in {None, 0} or img_h in {None, 0}:
+            stats["missing_size"] += 1
+            continue
+
+        bbox_w = float(bbox[2])
+        bbox_h = float(bbox[3])
+        if bbox_w <= 0.0 or bbox_h <= 0.0:
+            stats["invalid_bbox"] += 1
+            continue
+
+        scaled_w = bbox_w / float(img_w) * target_w
+        scaled_h = bbox_h / float(img_h) * target_h
+        if scaled_w < min_size or scaled_h < min_size:
+            stats["too_small"] += 1
+            continue
+
+        boxes.append((scaled_w, scaled_h))
+
+    stats["used_annotations"] = len(boxes)
+    if len(boxes) == 0:
+        raise ValueError("No valid bounding boxes found for anchor fitting")
+
+    return (torch.tensor(boxes, dtype=torch.float32), stats)
+
+
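As an illustrative aside (not part of the package code): every COCO box is rescaled from its source image to the requested target size before clustering, so anchors come out in target-input pixels. A hand-worked example of that scaling:

    # Illustrative arithmetic only, mirroring the scaled_w / scaled_h lines in _load_coco_boxes
    img_w, img_h = 1000, 750          # source image size from the COCO "images" entry
    bbox_w, bbox_h = 200.0, 150.0     # COCO bbox is [x, y, w, h]; only w and h are used
    target_h, target_w = 416.0, 416.0

    scaled_w = bbox_w / img_w * target_w  # 0.2 * 416 = 83.2
    scaled_h = bbox_h / img_h * target_h  # 0.2 * 416 = 83.2
    # (83.2, 83.2) is the (width, height) pair that enters k-means; boxes below --min-size are dropped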
+def _wh_iou(boxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
+    boxes = boxes[:, None, :]
+    anchors = anchors[None, :, :]
+    inter = torch.min(boxes, anchors).prod(dim=2)
+    union = boxes.prod(dim=2) + anchors.prod(dim=2) - inter
+    return inter / (union + 1e-9)
+
+
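For clustering, _wh_iou treats each box and anchor as a (width, height) pair pinned to the same corner, and the k-means distance used below is 1 - IoU. A small hand-checked example (illustrative, not from the package):

    import torch

    boxes = torch.tensor([[10.0, 20.0]])                  # one 10x20 box
    anchors = torch.tensor([[20.0, 20.0], [10.0, 10.0]])  # two candidate anchors

    inter = torch.min(boxes[:, None, :], anchors[None, :, :]).prod(dim=2)      # [[200., 100.]]
    union = boxes.prod(dim=1)[:, None] + anchors.prod(dim=1)[None, :] - inter  # [[400., 200.]]
    print(inter / union)  # tensor([[0.5000, 0.5000]]) -> distance 1 - IoU = 0.5 to both anchors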
+def _kmeans_plusplus_init(boxes: torch.Tensor, num_anchors: int, generator: torch.Generator) -> torch.Tensor:
+    n_local_trials = 2 + int(math.log(num_anchors))
+
+    anchors = []
+
+    first_idx = torch.randint(0, boxes.size(0), (1,), generator=generator).item()
+    anchors.append(boxes[first_idx])
+    for _ in range(num_anchors - 1):
+        anchors_tensor = torch.stack(anchors)
+        ious = _wh_iou(boxes, anchors_tensor)
+        max_ious = ious.max(dim=1).values
+
+        min_distances = 1.0 - max_ious
+        squared_distances = min_distances**2
+
+        probs = squared_distances / (squared_distances.sum() + 1e-9)
+        cumulative_probs = torch.cumsum(probs, dim=0)
+        r = torch.rand(n_local_trials, generator=generator)
+        candidate_indices = torch.searchsorted(cumulative_probs, r, right=True)
+
+        candidate_indices = torch.clamp(candidate_indices, 0, boxes.size(0) - 1)
+        candidate_boxes = boxes[candidate_indices]
+        candidate_ious = _wh_iou(boxes, candidate_boxes)  # (n_boxes, n_trials)
+        candidate_distances = 1.0 - candidate_ious
+
+        min_distances_expanded = min_distances.unsqueeze(1)
+        candidate_potentials = torch.min(min_distances_expanded, candidate_distances).sum(dim=0)
+
+        best_trial_idx = torch.argmin(candidate_potentials).item()
+        best_idx = candidate_indices[best_trial_idx].item()
+        anchors.append(boxes[best_idx])
+
+    return torch.stack(anchors)
+
+
+def _kmeans_anchors(
+    boxes: torch.Tensor, num_anchors: int, seed: Optional[int], max_iter: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if boxes.size(0) < num_anchors:
+        raise ValueError(
+            f"Not enough boxes ({boxes.size(0)}) to fit {num_anchors} anchors, Reduce --num-anchors or add more data"
+        )
+
+    generator = torch.Generator()
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    anchors = _kmeans_plusplus_init(boxes, num_anchors, generator)
+    assignments = torch.full((boxes.size(0),), -1, dtype=torch.int64)
+
+    for _ in range(max_iter):
+        ious = _wh_iou(boxes, anchors)
+        new_assignments = torch.argmax(ious, dim=1)
+        if torch.equal(assignments, new_assignments):
+            break
+
+        assignments = new_assignments
+        for idx in range(num_anchors):
+            mask = assignments == idx
+            if mask.any():
+                anchors[idx] = boxes[mask].median(dim=0).values
+            else:
+                rand_idx = torch.randint(0, boxes.size(0), (1,), generator=generator).item()
+                anchors[idx] = boxes[rand_idx]
+
+    return (anchors, assignments)
+
+
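A minimal sketch of exercising these two helpers directly on synthetic boxes, assuming the module is importable as birder.tools.auto_anchors and that calling its private helpers this way is acceptable for experimentation; the supported entry point is the CLI defined further down:

    import torch

    from birder.tools.auto_anchors import _kmeans_anchors, _wh_iou

    # Synthetic (width, height) pairs standing in for the tensor returned by _load_coco_boxes
    boxes = torch.rand(500, 2) * 300.0 + 10.0

    (anchors, _assignments) = _kmeans_anchors(boxes, 6, 0, 1000)  # 6 anchors, seed 0, up to 1000 iterations
    mean_iou = _wh_iou(boxes, anchors).max(dim=1).values.mean()
    print(anchors.shape, round(mean_iou.item(), 4))  # torch.Size([6, 2]) plus the fit quality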
+def _format_anchor_groups(anchor_groups: list[torch.Tensor], precision: int) -> list[list[tuple[float, float]]]:
+    formatted: list[list[tuple[float, float]]] = []
+    for group in anchor_groups:
+        formatted.append([(round(float(anchor[0]), precision), round(float(anchor[1]), precision)) for anchor in group])
+
+    return formatted
+
+
+def _validate_args(
+    args: argparse.Namespace,
+) -> tuple[tuple[int, int], int, int, Literal["pixels", "grid"], list[float]]:
+    preset = MODEL_PRESETS.get(args.preset) if args.preset is not None else None
+    size = cli.parse_size(args.size) if args.size is not None else (preset["default_size"] if preset else None)
+    if size is None:
+        raise cli.ValidationError("Missing --size. Provide --size or use a --preset")
+
+    num_scales = args.num_scales if args.num_scales is not None else (preset["num_scales"] if preset else None)
+    num_anchors = args.num_anchors if args.num_anchors is not None else (preset["num_anchors"] if preset else None)
+    output_format = args.format if args.format is not None else (preset["format"] if preset else None)
+    if num_scales is None or num_anchors is None or output_format is None:
+        raise cli.ValidationError(
+            "Missing configuration. Provide --num-scales, --num-anchors, and --format or use a --preset"
+        )
+    if num_scales < 1:
+        raise cli.ValidationError("--num-scales must be >= 1")
+    if num_anchors < 1:
+        raise cli.ValidationError("--num-anchors must be >= 1")
+    if num_anchors % num_scales != 0:
+        raise cli.ValidationError("--num-anchors must be divisible by --num-scales")
+
+    strides: list[float] = []
+    if output_format == "grid":
+        if args.stride is None:
+            raise cli.ValidationError("--format grid requires --stride values per scale")
+
+        strides = [float(value) for value in args.stride]
+        if len(strides) != num_scales:
+            raise cli.ValidationError("--stride must provide one value per scale when --format grid is used")
+        if any(value <= 0 for value in strides):
+            raise cli.ValidationError("--stride values must be > 0")
+
+    return (size, num_scales, num_anchors, output_format, strides)
+
+
+def auto_anchors(args: argparse.Namespace) -> None:
+    (size, num_scales, num_anchors, output_format, strides) = _validate_args(args)
+
+    ignore_list = _load_ignore_list(args.ignore_file)
+    (boxes, stats) = _load_coco_boxes(
+        args.coco_json_path, size, ignore_list, args.min_size, ignore_crowd=not args.include_crowd
+    )
+
+    if args.preset is not None:
+        logger.info(f"Using preset {args.preset}")
+
+    logger.info(f"Fitting anchors using size={size[0]}x{size[1]}")
+    logger.info(
+        f"Annotations: total={stats['total_annotations']}, used={stats['used_annotations']}, "
+        f"crowd={stats['crowd_annotations']}, invalid={stats['invalid_bbox']}, "
+        f"ignored={stats['ignored_images']}, missing={stats['missing_images']}, "
+        f"missing_size={stats['missing_size']}, too_small={stats['too_small']}"
+    )
+
+    (anchors, _assignments) = _kmeans_anchors(boxes, num_anchors, args.seed, args.max_iter)
+    areas = anchors.prod(dim=1)
+    anchors = anchors[torch.argsort(areas)]
+    anchors_per_scale = num_anchors // num_scales
+    anchor_groups = [anchors[i : i + anchors_per_scale] for i in range(0, num_anchors, anchors_per_scale)]
+
+    ious = _wh_iou(boxes, anchors)
+    best_iou = ious.max(dim=1).values
+    logger.info(f"Mean IoU: {best_iou.mean().item():.4f}")
+
+    formatted_groups = _format_anchor_groups(anchor_groups, args.precision)
+    if output_format == "pixels":
+        if num_scales == 1:
+            formatted_anchors: Any = formatted_groups[0]
+        else:
+            formatted_anchors = formatted_groups
+
+        print("Anchors (pixels):")
+        print(pformat(formatted_anchors))
+
+    if output_format == "grid":
+        grid_groups: list[torch.Tensor] = []
+        for group, stride in zip(anchor_groups, strides):
+            grid_group = group.clone()
+            grid_group[:, 0] = grid_group[:, 0] / stride
+            grid_group[:, 1] = grid_group[:, 1] / stride
+            grid_groups.append(grid_group)
+
+        formatted_grid = _format_anchor_groups(grid_groups, args.precision)
+        if num_scales == 1:
+            formatted_grid_output: Any = formatted_grid[0]
+        else:
+            formatted_grid_output = formatted_grid
+
+        print("Anchors (grid units):")
+        print(pformat(formatted_grid_output))
+
+
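After fitting, auto_anchors() sorts the anchors by area and chunks them into one group per detection scale. With 9 anchors over 3 scales, the grouping step behaves like this sketch (anchor values are illustrative, not tool output):

    import torch

    anchors = torch.tensor(
        [[12.0, 16.0], [19.0, 36.0], [40.0, 28.0],
         [36.0, 75.0], [76.0, 55.0], [72.0, 146.0],
         [142.0, 110.0], [192.0, 243.0], [459.0, 401.0]]
    )
    anchors = anchors[torch.argsort(anchors.prod(dim=1))]         # sort by area, smallest first
    anchor_groups = [anchors[i : i + 3] for i in range(0, 9, 3)]  # 9 anchors / 3 scales = 3 per group
    print([group.tolist() for group in anchor_groups])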
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "auto-anchors",
+        allow_abbrev=False,
+        help="fit YOLO anchors with k-means on a COCO dataset",
+        description="fit YOLO anchors with k-means on a COCO dataset",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.tools auto-anchors --preset yolo_v4 --size 640 "
+            "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+            "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 --format pixels "
+            "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+            "python -m birder.tools auto-anchors --preset yolo_v4_tiny --size 416 416 "
+            "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json\n"
+            "python -m birder.tools auto-anchors --preset yolo_v2 --stride 32 "
+            "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+            "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 "
+            "--format grid --stride 8 16 32 --coco-json-path data/detection_data/training_annotations_coco.json\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--preset", type=str, choices=sorted(MODEL_PRESETS.keys()), help="YOLO preset for anchor formatting"
+    )
+    subparser.add_argument(
+        "--size",
+        type=int,
+        nargs="+",
+        metavar=("H", "W"),
+        help="target image size as [height, width], required without --preset",
+    )
+    subparser.add_argument("--num-anchors", type=int, help="number of anchors to fit, required without --preset")
+    subparser.add_argument("--num-scales", type=int, help="number of output scales, required without --preset")
+    subparser.add_argument(
+        "--format", type=str, choices=["pixels", "grid"], help="anchor output format, required without --preset"
+    )
+    subparser.add_argument(
+        "--stride", type=int, nargs="+", default=[32], help="strides per scale used to convert anchors to grid units"
+    )
+    subparser.add_argument(
+        "--min-size", type=float, default=1.0, help="minimum scaled box size to include in anchor fitting"
+    )
+    subparser.add_argument("--include-crowd", default=False, action="store_true", help="include crowd annotations")
+    subparser.add_argument("--seed", type=int, help="random seed for k-means initialization")
+    subparser.add_argument("--max-iter", type=int, default=1000, help="maximum k-means iterations")
+    subparser.add_argument("--precision", type=int, default=1, help="number of decimals to keep in anchor output")
+    subparser.add_argument(
+        "--ignore-file", type=str, metavar="FILE", help="file containing image names to skip (one per line)"
+    )
+    subparser.add_argument(
+        "--coco-json-path",
+        type=str,
+        default=f"{settings.TRAINING_DETECTION_ANNOTATIONS_PATH}_coco.json",
+        help="training COCO json path",
+    )
+    subparser.set_defaults(func=main)
+
+
+def main(args: argparse.Namespace) -> None:
+    auto_anchors(args)
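When --format grid is selected, the fitted pixel anchors are divided by each scale's stride, as in the grid branch of auto_anchors() above. A hand-worked conversion for a single stride-32 scale (values are illustrative, not tool output):

    stride = 32.0
    pixel_anchors = [(24.0, 32.0), (96.0, 64.0), (320.0, 288.0)]

    grid_anchors = [(w / stride, h / stride) for (w, h) in pixel_anchors]
    print(grid_anchors)  # [(0.75, 1.0), (3.0, 2.0), (10.0, 9.0)]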
birder/tools/ensemble_model.py
CHANGED

@@ -37,7 +37,7 @@ def set_parser(subparsers: Any) -> None:
             "python -m birder.tools ensemble-model --networks convnext_v2_4_0 focalnet_3_0 "
             "swin_transformer_v2_1_0 --pts\n"
             "python -m birder.tools ensemble-model --networks mobilevit_v2_1_5_intermediate_80 "
-            "edgevit_2_intermediate_100 --pt2"
+            "edgevit_2_intermediate_100 --pt2\n"
         ),
         formatter_class=cli.ArgumentHelpFormatter,
     )
birder/tools/introspection.py
CHANGED

@@ -9,9 +9,10 @@ from birder.common import cli
 from birder.common import fs_ops
 from birder.common import lib
 from birder.data.transforms.classification import inference_preset
-from birder.introspection import
-from birder.introspection import
-from birder.introspection import
+from birder.introspection import AttentionRollout
+from birder.introspection import GradCAM
+from birder.introspection import GuidedBackprop
+from birder.introspection import TransformerAttribution
 from birder.net.base import BaseNet

 logger = logging.getLogger(__name__)
@@ -21,19 +22,18 @@ def _nhwc_reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
     return tensor.permute(0, 3, 1, 2).contiguous()


-def
+def _show_attn_rollout(
     args: argparse.Namespace,
     net: BaseNet,
-    _class_to_idx: dict[str, int],
     transform: Callable[..., torch.Tensor],
     device: torch.device,
 ) -> None:
-    ar =
-    result = ar
+    ar = AttentionRollout(net, device, transform, args.attn_layer_name, args.discard_ratio, args.head_fusion)
+    result = ar(args.image_path)
     result.show()


-def
+def _show_transformer_attribution(
     args: argparse.Namespace,
     net: BaseNet,
     class_to_idx: dict[str, int],
@@ -45,12 +45,29 @@ def show_guided_backprop(
     else:
         target = None

-
-    result =
+    ta = TransformerAttribution(net, device, transform, args.attn_layer_name)
+    result = ta(args.image_path, target_class=target)
     result.show()


-def
+def _show_guided_backprop(
+    args: argparse.Namespace,
+    net: BaseNet,
+    class_to_idx: dict[str, int],
+    transform: Callable[..., torch.Tensor],
+    device: torch.device,
+) -> None:
+    if args.target is not None:
+        target = class_to_idx[args.target]
+    else:
+        target = None
+
+    guided_bp = GuidedBackprop(net, device, transform)
+    result = guided_bp(args.image_path, target_class=target)
+    result.show()
+
+
+def _show_grad_cam(
     args: argparse.Namespace,
     net: BaseNet,
     class_to_idx: dict[str, int],
@@ -70,8 +87,8 @@ def show_grad_cam(
     else:
         target = None

-    grad_cam =
-    result = grad_cam
+    grad_cam = GradCAM(net, device, transform, target_layer, reshape_transform=reshape_transform)
+    result = grad_cam(args.image_path, target_class=target)
     result.show()


@@ -83,25 +100,22 @@ def set_parser(subparsers: Any) -> None:
         description="computer vision introspection and explainability",
         epilog=(
             "Usage examples:\n"
-            "python -m birder.tools introspection --
-            "
-            "python -m birder.tools introspection
+            "python -m birder.tools introspection --network efficientnet_v2_m -e 200 --method gradcam "
+            "'data/training/European goldfinch/000300.jpeg'\n"
+            "python -m birder.tools introspection -n resnest_50 --epoch 300 --method gradcam "
             "data/index5.jpeg --target 'Grey heron'\n"
-            "python -m birder.tools introspection --method guided-backprop
-            "
-            "python -m birder.tools introspection
+            "python -m birder.tools introspection -n efficientnet_v2_s --method guided-backprop "
+            "'data/training/European goldfinch/000300.jpeg'\n"
+            "python -m birder.tools introspection -n swin_transformer_v1_b -e 85 --layer-num -4 --method gradcam "
             "--channels-last data/training/Fieldfare/000002.jpeg\n"
-            "python -m birder.tools introspection
+            "python -m birder.tools introspection -n vit_reg4_b16 -t mim -e 100 --method attn-rollout "
             " data/validation/Bluethroat/000013.jpeg\n"
+            "python -m birder.tools introspection -n deit3_t16 -t il-common --method transformer-attribution "
+            "--target 'Black-crowned night heron' data/detection_data/training/0002/000544.jpeg\n"
         ),
         formatter_class=cli.ArgumentHelpFormatter,
     )
-    subparser.add_argument(
-        "--method", type=str, choices=["gradcam", "guided-backprop", "attn-rollout"], help="introspection method"
-    )
-    subparser.add_argument(
-        "-n", "--network", type=str, required=True, help="the neural network to use (i.e. resnet_v2)"
-    )
+    subparser.add_argument("-n", "--network", type=str, required=True, help="the neural network to use")
     subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint to load")
     subparser.add_argument("-t", "--tag", type=str, help="model tag (from the training phase)")
     subparser.add_argument(
@@ -109,11 +123,19 @@ def set_parser(subparsers: Any) -> None:
     )
     subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
     subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+    subparser.add_argument(
+        "--method",
+        type=str,
+        choices=["gradcam", "guided-backprop", "attn-rollout", "transformer-attribution"],
+        help="introspection method",
+    )
     subparser.add_argument(
         "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
     )
     subparser.add_argument(
-        "--target",
+        "--target",
+        type=str,
+        help="target class, leave empty to use predicted class (gradcam, guided-backprop, and transformer-attribution)",
     )
     subparser.add_argument("--block-name", type=str, default="body", help="target block (gradcam only)")
     subparser.add_argument(
@@ -123,7 +145,10 @@ def set_parser(subparsers: Any) -> None:
         "--channels-last", default=False, action="store_true", help="channels last model, like swin (gradcam only)"
     )
     subparser.add_argument(
-        "--attn-layer-name",
+        "--attn-layer-name",
+        type=str,
+        default="self_attention",
+        help="attention layer name (attn-rollout and transformer-attribution)",
     )
     subparser.add_argument(
         "--head-fusion",
@@ -169,8 +194,10 @@ def main(args: argparse.Namespace) -> None:
     transform = inference_preset(args.size, model_info.rgb_stats, 1.0)

     if args.method == "gradcam":
-
+        _show_grad_cam(args, net, model_info.class_to_idx, transform, device)
     elif args.method == "guided-backprop":
-
+        _show_guided_backprop(args, net, model_info.class_to_idx, transform, device)
     elif args.method == "attn-rollout":
-
+        _show_attn_rollout(args, net, transform, device)
+    elif args.method == "transformer-attribution":
+        _show_transformer_attribution(args, net, model_info.class_to_idx, transform, device)
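The refactor above exposes each introspection method as a callable class that the tool's _show_* helpers wrap. A minimal sketch of the same call outside the CLI, assuming net, device, and transform have already been prepared the way main() does (via fs_ops and inference_preset):

    from birder.introspection import TransformerAttribution

    # "self_attention" is the new --attn-layer-name default; the image path comes from the tool's epilog examples
    ta = TransformerAttribution(net, device, transform, "self_attention")
    result = ta("data/validation/Bluethroat/000013.jpeg", target_class=None)  # None -> use the predicted class
    result.show()

    # A specific class can be targeted through the model's class mapping, e.g. target_class=class_to_idx["Bluethroat"]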
birder/version.py
CHANGED

@@ -1 +1 @@
-__version__ = "v0.2.1"
+__version__ = "v0.2.2"

{birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: birder
-Version: 0.2.1
+Version: 0.2.2
 Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
 Author: Ofer Hasson
 License-Expression: Apache-2.0
@@ -62,6 +62,7 @@ Requires-Dist: MonkeyType~=23.3.0; extra == "dev"
 Requires-Dist: mypy~=1.19.1; extra == "dev"
 Requires-Dist: parameterized~=0.9.0; extra == "dev"
 Requires-Dist: pylint~=4.0.4; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
 Requires-Dist: requests~=2.32.5; extra == "dev"
 Requires-Dist: safetensors~=0.7.0; extra == "dev"
 Requires-Dist: setuptools; extra == "dev"