birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/tools/auto_anchors.py ADDED
@@ -0,0 +1,361 @@
+ """
+ Fit YOLO-style anchor boxes using k-means based on COCO-format annotations.
+
+ Generated by gpt-5.2-codex xhigh.
+ """
+
+ import argparse
+ import json
+ import logging
+ import math
+ from pathlib import Path
+ from pprint import pformat
+ from typing import Any
+ from typing import Literal
+ from typing import Optional
+ from typing import TypedDict
+
+ import torch
+
+ from birder.common import cli
+ from birder.conf import settings
+
+ logger = logging.getLogger(__name__)
+
+
+ class AnchorPreset(TypedDict):
+     num_scales: int
+     num_anchors: int
+     default_size: tuple[int, int]
+     format: Literal["pixels", "grid"]
+
+
+ MODEL_PRESETS: dict[str, AnchorPreset] = {
+     "yolo_v2": {"num_scales": 1, "num_anchors": 5, "default_size": (416, 416), "format": "grid"},
+     "yolo_v3": {"num_scales": 3, "num_anchors": 9, "default_size": (416, 416), "format": "pixels"},
+     "yolo_v4": {"num_scales": 3, "num_anchors": 9, "default_size": (608, 608), "format": "pixels"},
+     "yolo_v4_tiny": {"num_scales": 2, "num_anchors": 6, "default_size": (416, 416), "format": "pixels"},
+ }
+
+
+ def _load_ignore_list(ignore_file: Optional[str]) -> set[str]:
+     if ignore_file is None:
+         return set()
+
+     with open(ignore_file, "r", encoding="utf-8") as handle:
+         return {line.strip() for line in handle if line.strip()}
+
+
+ def _load_coco_boxes(
+     coco_json_path: str, target_size: tuple[int, int], ignore_list: set[str], min_size: float, ignore_crowd: bool
+ ) -> tuple[torch.Tensor, dict[str, int]]:
+     coco_path = Path(coco_json_path)
+     if coco_path.exists() is False:
+         raise ValueError(f"COCO json not found at {coco_path}")
+
+     with open(coco_path, "r", encoding="utf-8") as handle:
+         data = json.load(handle)
+
+     images = {}
+     for image in data.get("images", []):
+         image_id = image.get("id")
+         if image_id is None:
+             continue
+         images[image_id] = (image.get("width"), image.get("height"), image.get("file_name", ""))
+
+     stats = {
+         "total_annotations": 0,
+         "used_annotations": 0,
+         "crowd_annotations": 0,
+         "invalid_bbox": 0,
+         "ignored_images": 0,
+         "missing_images": 0,
+         "missing_size": 0,
+         "too_small": 0,
+     }
+     boxes: list[tuple[float, float]] = []
+     target_h = float(target_size[0])
+     target_w = float(target_size[1])
+     for annotation in data.get("annotations", []):
+         stats["total_annotations"] += 1
+         if ignore_crowd is True and annotation.get("iscrowd", 0) == 1:
+             stats["crowd_annotations"] += 1
+             continue
+
+         bbox = annotation.get("bbox")
+         if bbox is None or len(bbox) != 4:
+             stats["invalid_bbox"] += 1
+             continue
+
+         image_id = annotation.get("image_id")
+         if image_id not in images:
+             stats["missing_images"] += 1
+             continue
+
+         (img_w, img_h, file_name) = images[image_id]
+         if file_name in ignore_list:
+             stats["ignored_images"] += 1
+             continue
+
+         if img_w in {None, 0} or img_h in {None, 0}:
+             stats["missing_size"] += 1
+             continue
+
+         bbox_w = float(bbox[2])
+         bbox_h = float(bbox[3])
+         if bbox_w <= 0.0 or bbox_h <= 0.0:
+             stats["invalid_bbox"] += 1
+             continue
+
+         scaled_w = bbox_w / float(img_w) * target_w
+         scaled_h = bbox_h / float(img_h) * target_h
+         if scaled_w < min_size or scaled_h < min_size:
+             stats["too_small"] += 1
+             continue
+
+         boxes.append((scaled_w, scaled_h))
+
+     stats["used_annotations"] = len(boxes)
+     if len(boxes) == 0:
+         raise ValueError("No valid bounding boxes found for anchor fitting")
+
+     return (torch.tensor(boxes, dtype=torch.float32), stats)
+
+
+ def _wh_iou(boxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
+     boxes = boxes[:, None, :]
+     anchors = anchors[None, :, :]
+     inter = torch.min(boxes, anchors).prod(dim=2)
+     union = boxes.prod(dim=2) + anchors.prod(dim=2) - inter
+     return inter / (union + 1e-9)
+
+
+ def _kmeans_plusplus_init(boxes: torch.Tensor, num_anchors: int, generator: torch.Generator) -> torch.Tensor:
+     n_local_trials = 2 + int(math.log(num_anchors))
+
+     anchors = []
+
+     first_idx = torch.randint(0, boxes.size(0), (1,), generator=generator).item()
+     anchors.append(boxes[first_idx])
+     for _ in range(num_anchors - 1):
+         anchors_tensor = torch.stack(anchors)
+         ious = _wh_iou(boxes, anchors_tensor)
+         max_ious = ious.max(dim=1).values
+
+         min_distances = 1.0 - max_ious
+         squared_distances = min_distances**2
+
+         probs = squared_distances / (squared_distances.sum() + 1e-9)
+         cumulative_probs = torch.cumsum(probs, dim=0)
+         r = torch.rand(n_local_trials, generator=generator)
+         candidate_indices = torch.searchsorted(cumulative_probs, r, right=True)
+
+         candidate_indices = torch.clamp(candidate_indices, 0, boxes.size(0) - 1)
+         candidate_boxes = boxes[candidate_indices]
+         candidate_ious = _wh_iou(boxes, candidate_boxes)  # (n_boxes, n_trials)
+         candidate_distances = 1.0 - candidate_ious
+
+         min_distances_expanded = min_distances.unsqueeze(1)
+         candidate_potentials = torch.min(min_distances_expanded, candidate_distances).sum(dim=0)
+
+         best_trial_idx = torch.argmin(candidate_potentials).item()
+         best_idx = candidate_indices[best_trial_idx].item()
+         anchors.append(boxes[best_idx])
+
+     return torch.stack(anchors)
+
+
+ def _kmeans_anchors(
+     boxes: torch.Tensor, num_anchors: int, seed: Optional[int], max_iter: int
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     if boxes.size(0) < num_anchors:
+         raise ValueError(
+             f"Not enough boxes ({boxes.size(0)}) to fit {num_anchors} anchors. Reduce --num-anchors or add more data"
+         )
+
+     generator = torch.Generator()
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     anchors = _kmeans_plusplus_init(boxes, num_anchors, generator)
+     assignments = torch.full((boxes.size(0),), -1, dtype=torch.int64)
+
+     for _ in range(max_iter):
+         ious = _wh_iou(boxes, anchors)
+         new_assignments = torch.argmax(ious, dim=1)
+         if torch.equal(assignments, new_assignments):
+             break
+
+         assignments = new_assignments
+         for idx in range(num_anchors):
+             mask = assignments == idx
+             if mask.any():
+                 anchors[idx] = boxes[mask].median(dim=0).values
+             else:
+                 rand_idx = torch.randint(0, boxes.size(0), (1,), generator=generator).item()
+                 anchors[idx] = boxes[rand_idx]
+
+     return (anchors, assignments)
+
+
+ def _format_anchor_groups(anchor_groups: list[torch.Tensor], precision: int) -> list[list[tuple[float, float]]]:
+     formatted: list[list[tuple[float, float]]] = []
+     for group in anchor_groups:
+         formatted.append([(round(float(anchor[0]), precision), round(float(anchor[1]), precision)) for anchor in group])
+
+     return formatted
+
+
+ def _validate_args(
+     args: argparse.Namespace,
+ ) -> tuple[tuple[int, int], int, int, Literal["pixels", "grid"], list[float]]:
+     preset = MODEL_PRESETS.get(args.preset) if args.preset is not None else None
+     size = cli.parse_size(args.size) if args.size is not None else (preset["default_size"] if preset else None)
+     if size is None:
+         raise cli.ValidationError("Missing --size. Provide --size or use a --preset")
+
+     num_scales = args.num_scales if args.num_scales is not None else (preset["num_scales"] if preset else None)
+     num_anchors = args.num_anchors if args.num_anchors is not None else (preset["num_anchors"] if preset else None)
+     output_format = args.format if args.format is not None else (preset["format"] if preset else None)
+     if num_scales is None or num_anchors is None or output_format is None:
+         raise cli.ValidationError(
+             "Missing configuration. Provide --num-scales, --num-anchors, and --format or use a --preset"
+         )
+     if num_scales < 1:
+         raise cli.ValidationError("--num-scales must be >= 1")
+     if num_anchors < 1:
+         raise cli.ValidationError("--num-anchors must be >= 1")
+     if num_anchors % num_scales != 0:
+         raise cli.ValidationError("--num-anchors must be divisible by --num-scales")
+
+     strides: list[float] = []
+     if output_format == "grid":
+         if args.stride is None:
+             raise cli.ValidationError("--format grid requires --stride values per scale")
+
+         strides = [float(value) for value in args.stride]
+         if len(strides) != num_scales:
+             raise cli.ValidationError("--stride must provide one value per scale when --format grid is used")
+         if any(value <= 0 for value in strides):
+             raise cli.ValidationError("--stride values must be > 0")
+
+     return (size, num_scales, num_anchors, output_format, strides)
+
+
+ def auto_anchors(args: argparse.Namespace) -> None:
+     (size, num_scales, num_anchors, output_format, strides) = _validate_args(args)
+
+     ignore_list = _load_ignore_list(args.ignore_file)
+     (boxes, stats) = _load_coco_boxes(
+         args.coco_json_path, size, ignore_list, args.min_size, ignore_crowd=not args.include_crowd
+     )
+
+     if args.preset is not None:
+         logger.info(f"Using preset {args.preset}")
+
+     logger.info(f"Fitting anchors using size={size[0]}x{size[1]}")
+     logger.info(
+         f"Annotations: total={stats['total_annotations']}, used={stats['used_annotations']}, "
+         f"crowd={stats['crowd_annotations']}, invalid={stats['invalid_bbox']}, "
+         f"ignored={stats['ignored_images']}, missing={stats['missing_images']}, "
+         f"missing_size={stats['missing_size']}, too_small={stats['too_small']}"
+     )
+
+     (anchors, _assignments) = _kmeans_anchors(boxes, num_anchors, args.seed, args.max_iter)
+     areas = anchors.prod(dim=1)
+     anchors = anchors[torch.argsort(areas)]
+     anchors_per_scale = num_anchors // num_scales
+     anchor_groups = [anchors[i : i + anchors_per_scale] for i in range(0, num_anchors, anchors_per_scale)]
+
+     ious = _wh_iou(boxes, anchors)
+     best_iou = ious.max(dim=1).values
+     logger.info(f"Mean IoU: {best_iou.mean().item():.4f}")
+
+     formatted_groups = _format_anchor_groups(anchor_groups, args.precision)
+     if output_format == "pixels":
+         if num_scales == 1:
+             formatted_anchors: Any = formatted_groups[0]
+         else:
+             formatted_anchors = formatted_groups
+
+         print("Anchors (pixels):")
+         print(pformat(formatted_anchors))
+
+     if output_format == "grid":
+         grid_groups: list[torch.Tensor] = []
+         for group, stride in zip(anchor_groups, strides):
+             grid_group = group.clone()
+             grid_group[:, 0] = grid_group[:, 0] / stride
+             grid_group[:, 1] = grid_group[:, 1] / stride
+             grid_groups.append(grid_group)
+
+         formatted_grid = _format_anchor_groups(grid_groups, args.precision)
+         if num_scales == 1:
+             formatted_grid_output: Any = formatted_grid[0]
+         else:
+             formatted_grid_output = formatted_grid
+
+         print("Anchors (grid units):")
+         print(pformat(formatted_grid_output))
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "auto-anchors",
+         allow_abbrev=False,
+         help="fit YOLO anchors with k-means on a COCO dataset",
+         description="fit YOLO anchors with k-means on a COCO dataset",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.tools auto-anchors --preset yolo_v4 --size 640 "
+             "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+             "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 --format pixels "
+             "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+             "python -m birder.tools auto-anchors --preset yolo_v4_tiny --size 416 416 "
+             "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json\n"
+             "python -m birder.tools auto-anchors --preset yolo_v2 --stride 32 "
+             "--coco-json-path data/detection_data/training_annotations_coco.json\n"
+             "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 "
+             "--format grid --stride 8 16 32 --coco-json-path data/detection_data/training_annotations_coco.json\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument(
+         "--preset", type=str, choices=sorted(MODEL_PRESETS.keys()), help="YOLO preset for anchor formatting"
+     )
+     subparser.add_argument(
+         "--size",
+         type=int,
+         nargs="+",
+         metavar=("H", "W"),
+         help="target image size as [height, width], required without --preset",
+     )
+     subparser.add_argument("--num-anchors", type=int, help="number of anchors to fit, required without --preset")
+     subparser.add_argument("--num-scales", type=int, help="number of output scales, required without --preset")
+     subparser.add_argument(
+         "--format", type=str, choices=["pixels", "grid"], help="anchor output format, required without --preset"
+     )
+     subparser.add_argument(
+         "--stride", type=int, nargs="+", default=[32], help="strides per scale used to convert anchors to grid units"
+     )
+     subparser.add_argument(
+         "--min-size", type=float, default=1.0, help="minimum scaled box size to include in anchor fitting"
+     )
+     subparser.add_argument("--include-crowd", default=False, action="store_true", help="include crowd annotations")
+     subparser.add_argument("--seed", type=int, help="random seed for k-means initialization")
+     subparser.add_argument("--max-iter", type=int, default=1000, help="maximum k-means iterations")
+     subparser.add_argument("--precision", type=int, default=1, help="number of decimals to keep in anchor output")
+     subparser.add_argument(
+         "--ignore-file", type=str, metavar="FILE", help="file containing image names to skip (one per line)"
+     )
+     subparser.add_argument(
+         "--coco-json-path",
+         type=str,
+         default=f"{settings.TRAINING_DETECTION_ANNOTATIONS_PATH}_coco.json",
+         help="training COCO json path",
+     )
+     subparser.set_defaults(func=main)
+
+
+ def main(args: argparse.Namespace) -> None:
+     auto_anchors(args)
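Note on the metric: the k-means above clusters boxes by the distance 1 − IoU, where IoU is computed on (width, height) pairs alone, as if every box and anchor were centered at the same point. A minimal standalone sketch of that metric and of the mean-IoU figure the tool logs; the wh_iou helper below is hypothetical, mirroring _wh_iou in the new file:

import torch

def wh_iou(boxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
    # Intersection of co-centered boxes is min(w) * min(h); union follows from the areas
    inter = torch.min(boxes[:, None, :], anchors[None, :, :]).prod(dim=2)
    union = boxes.prod(dim=1)[:, None] + anchors.prod(dim=1)[None, :] - inter
    return inter / (union + 1e-9)

boxes = torch.tensor([[30.0, 60.0], [100.0, 90.0], [12.0, 20.0]])  # (w, h) in pixels
anchors = torch.tensor([[25.0, 50.0], [110.0, 100.0]])
ious = wh_iou(boxes, anchors)             # shape (3, 2)
print(ious.argmax(dim=1))                 # best-matching anchor per box
print(ious.max(dim=1).values.mean())      # the "Mean IoU" the tool logs
print(anchors / 32.0)                     # pixel anchors -> grid units at stride 32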
birder/tools/ensemble_model.py CHANGED
@@ -37,7 +37,7 @@ def set_parser(subparsers: Any) -> None:
              "python -m birder.tools ensemble-model --networks convnext_v2_4_0 focalnet_3_0 "
              "swin_transformer_v2_1_0 --pts\n"
              "python -m birder.tools ensemble-model --networks mobilevit_v2_1_5_intermediate_80 "
-             "edgevit_2_intermediate_100 --pt2"
+             "edgevit_2_intermediate_100 --pt2\n"
          ),
          formatter_class=cli.ArgumentHelpFormatter,
      )
birder/tools/introspection.py CHANGED
@@ -9,9 +9,10 @@ from birder.common import cli
  from birder.common import fs_ops
  from birder.common import lib
  from birder.data.transforms.classification import inference_preset
- from birder.introspection import AttentionRolloutInterpreter
- from birder.introspection import GradCamInterpreter
- from birder.introspection import GuidedBackpropInterpreter
+ from birder.introspection import AttentionRollout
+ from birder.introspection import GradCAM
+ from birder.introspection import GuidedBackprop
+ from birder.introspection import TransformerAttribution
  from birder.net.base import BaseNet

  logger = logging.getLogger(__name__)
@@ -21,19 +22,18 @@ def _nhwc_reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
      return tensor.permute(0, 3, 1, 2).contiguous()


- def show_attn_rollout(
+ def _show_attn_rollout(
      args: argparse.Namespace,
      net: BaseNet,
-     _class_to_idx: dict[str, int],
      transform: Callable[..., torch.Tensor],
      device: torch.device,
  ) -> None:
-     ar = AttentionRolloutInterpreter(net, device, transform, args.attn_layer_name, args.discard_ratio, args.head_fusion)
-     result = ar.interpret(args.image_path)
+     ar = AttentionRollout(net, device, transform, args.attn_layer_name, args.discard_ratio, args.head_fusion)
+     result = ar(args.image_path)
      result.show()


- def show_guided_backprop(
+ def _show_transformer_attribution(
      args: argparse.Namespace,
      net: BaseNet,
      class_to_idx: dict[str, int],
@@ -45,12 +45,29 @@ def show_guided_backprop(
      else:
          target = None

-     guided_bp = GuidedBackpropInterpreter(net, device, transform)
-     result = guided_bp.interpret(args.image_path, target_class=target)
+     ta = TransformerAttribution(net, device, transform, args.attn_layer_name)
+     result = ta(args.image_path, target_class=target)
      result.show()


- def show_grad_cam(
+ def _show_guided_backprop(
+     args: argparse.Namespace,
+     net: BaseNet,
+     class_to_idx: dict[str, int],
+     transform: Callable[..., torch.Tensor],
+     device: torch.device,
+ ) -> None:
+     if args.target is not None:
+         target = class_to_idx[args.target]
+     else:
+         target = None
+
+     guided_bp = GuidedBackprop(net, device, transform)
+     result = guided_bp(args.image_path, target_class=target)
+     result.show()
+
+
+ def _show_grad_cam(
      args: argparse.Namespace,
      net: BaseNet,
      class_to_idx: dict[str, int],
@@ -70,8 +87,8 @@ def show_grad_cam(
      else:
          target = None

-     grad_cam = GradCamInterpreter(net, device, transform, target_layer, reshape_transform=reshape_transform)
-     result = grad_cam.interpret(args.image_path, target_class=target)
+     grad_cam = GradCAM(net, device, transform, target_layer, reshape_transform=reshape_transform)
+     result = grad_cam(args.image_path, target_class=target)
      result.show()


@@ -83,25 +100,22 @@ def set_parser(subparsers: Any) -> None:
          description="computer vision introspection and explainability",
          epilog=(
              "Usage examples:\n"
-             "python -m birder.tools introspection --method gradcam --network efficientnet_v2_m "
-             "--epoch 200 'data/training/European goldfinch/000300.jpeg'\n"
-             "python -m birder.tools introspection --method gradcam -n resnest_50 --epoch 300 "
+             "python -m birder.tools introspection --network efficientnet_v2_m -e 200 --method gradcam "
+             "'data/training/European goldfinch/000300.jpeg'\n"
+             "python -m birder.tools introspection -n resnest_50 --epoch 300 --method gradcam "
              "data/index5.jpeg --target 'Grey heron'\n"
-             "python -m birder.tools introspection --method guided-backprop -n efficientnet_v2_s "
-             "-e 0 'data/training/European goldfinch/000300.jpeg'\n"
-             "python -m birder.tools introspection --method gradcam -n swin_transformer_v1_b -e 85 --layer-num -4 "
+             "python -m birder.tools introspection -n efficientnet_v2_s --method guided-backprop "
+             "'data/training/European goldfinch/000300.jpeg'\n"
+             "python -m birder.tools introspection -n swin_transformer_v1_b -e 85 --layer-num -4 --method gradcam "
              "--channels-last data/training/Fieldfare/000002.jpeg\n"
-             "python -m birder.tools introspection --method attn-rollout -n vit_reg4_b16 -t mim -e 100 "
+             "python -m birder.tools introspection -n vit_reg4_b16 -t mim -e 100 --method attn-rollout "
              " data/validation/Bluethroat/000013.jpeg\n"
+             "python -m birder.tools introspection -n deit3_t16 -t il-common --method transformer-attribution "
+             "--target 'Black-crowned night heron' data/detection_data/training/0002/000544.jpeg\n"
          ),
          formatter_class=cli.ArgumentHelpFormatter,
      )
-     subparser.add_argument(
-         "--method", type=str, choices=["gradcam", "guided-backprop", "attn-rollout"], help="introspection method"
-     )
-     subparser.add_argument(
-         "-n", "--network", type=str, required=True, help="the neural network to use (i.e. resnet_v2)"
-     )
+     subparser.add_argument("-n", "--network", type=str, required=True, help="the neural network to use")
      subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint to load")
      subparser.add_argument("-t", "--tag", type=str, help="model tag (from the training phase)")
      subparser.add_argument(
@@ -109,11 +123,19 @@ def set_parser(subparsers: Any) -> None:
      )
      subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
      subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+     subparser.add_argument(
+         "--method",
+         type=str,
+         choices=["gradcam", "guided-backprop", "attn-rollout", "transformer-attribution"],
+         help="introspection method",
+     )
      subparser.add_argument(
          "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
      )
      subparser.add_argument(
-         "--target", type=str, help="target class, leave empty to use predicted class (gradcam and guided-backprop only)"
+         "--target",
+         type=str,
+         help="target class, leave empty to use predicted class (gradcam, guided-backprop, and transformer-attribution)",
      )
      subparser.add_argument("--block-name", type=str, default="body", help="target block (gradcam only)")
      subparser.add_argument(
@@ -123,7 +145,10 @@ def set_parser(subparsers: Any) -> None:
          "--channels-last", default=False, action="store_true", help="channels last model, like swin (gradcam only)"
      )
      subparser.add_argument(
-         "--attn-layer-name", type=str, default="self_attention", help="attention layer name (attn-rollout only)"
+         "--attn-layer-name",
+         type=str,
+         default="self_attention",
+         help="attention layer name (attn-rollout and transformer-attribution)",
      )
      subparser.add_argument(
          "--head-fusion",
@@ -169,8 +194,10 @@ def main(args: argparse.Namespace) -> None:
      transform = inference_preset(args.size, model_info.rgb_stats, 1.0)

      if args.method == "gradcam":
-         show_grad_cam(args, net, model_info.class_to_idx, transform, device)
+         _show_grad_cam(args, net, model_info.class_to_idx, transform, device)
      elif args.method == "guided-backprop":
-         show_guided_backprop(args, net, model_info.class_to_idx, transform, device)
+         _show_guided_backprop(args, net, model_info.class_to_idx, transform, device)
      elif args.method == "attn-rollout":
-         show_attn_rollout(args, net, model_info.class_to_idx, transform, device)
+         _show_attn_rollout(args, net, transform, device)
+     elif args.method == "transformer-attribution":
+         _show_transformer_attribution(args, net, model_info.class_to_idx, transform, device)
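The introspection hunks above make two API changes: the interpreter classes lose their Interpreter suffix (GradCamInterpreter → GradCAM, etc.) and become directly callable rather than exposing an interpret() method, and a fourth method, transformer-attribution, joins the CLI. A hedged sketch of the new call pattern, based only on the constructor calls visible in this diff; the run_introspection wrapper is hypothetical, and net, device, and transform are assumed to be prepared the same way birder/tools/introspection.py does:

from collections.abc import Callable

import torch

from birder.introspection import GuidedBackprop
from birder.introspection import TransformerAttribution
from birder.net.base import BaseNet


def run_introspection(
    net: BaseNet, device: torch.device, transform: Callable[..., torch.Tensor], image_path: str
) -> None:
    # target_class=None falls back to the predicted class, as in the tool
    GuidedBackprop(net, device, transform)(image_path, target_class=None).show()
    TransformerAttribution(net, device, transform, "self_attention")(image_path, target_class=None).show()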
birder/version.py CHANGED
@@ -1 +1 @@
- __version__ = "v0.2.1"
+ __version__ = "v0.2.2"
{birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: birder
- Version: 0.2.1
+ Version: 0.2.2
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
  Author: Ofer Hasson
  License-Expression: Apache-2.0
@@ -62,6 +62,7 @@ Requires-Dist: MonkeyType~=23.3.0; extra == "dev"
  Requires-Dist: mypy~=1.19.1; extra == "dev"
  Requires-Dist: parameterized~=0.9.0; extra == "dev"
  Requires-Dist: pylint~=4.0.4; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
  Requires-Dist: requests~=2.32.5; extra == "dev"
  Requires-Dist: safetensors~=0.7.0; extra == "dev"
  Requires-Dist: setuptools; extra == "dev"