birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/eval/_embeddings.py
@@ -0,0 +1,50 @@
+ import logging
+ from pathlib import Path
+
+ import numpy as np
+ import numpy.typing as npt
+ import polars as pl
+
+ logger = logging.getLogger(__name__)
+
+
+ def l2_normalize(x: npt.NDArray[np.float32], eps: float = 1e-12) -> npt.NDArray[np.float32]:
+     norms = np.linalg.norm(x, axis=1, keepdims=True)
+     return x / np.clip(norms, eps, None)  # type: ignore[no-any-return]
+
+
+ def load_embeddings(path: Path | str) -> tuple[list[str], npt.NDArray[np.float32]]:
+     """
+     Load embeddings from parquet file
+
+     Auto-detects format:
+     - If 'embedding' column exists: use directly
+     - If numeric column names (0, 1, 2, ...): treat as logits, convert to array
+
+     Returns
+     -------
+     sample_ids
+         List of sample identifiers (stem of 'sample' column path).
+     features
+         Array of shape (n_samples, embedding_dim), dtype float32.
+     """
+
+     if isinstance(path, str):
+         path = Path(path)
+
+     df = pl.read_parquet(path)
+     df = df.with_columns(pl.col("sample").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id"))
+
+     if "embedding" in df.columns:
+         df = df.select(["id", "embedding"])
+     else:
+         # Logits format - numeric column names
+         embed_cols = sorted([c for c in df.columns if c.isdigit()], key=int)
+         df = df.with_columns(
+             pl.concat_list(pl.col(embed_cols)).cast(pl.Array(pl.Float32, len(embed_cols))).alias("embedding")
+         ).select(["id", "embedding"])
+
+     sample_ids = df.get_column("id").to_list()
+     features = df.get_column("embedding").to_numpy().astype(np.float32, copy=False)
+
+     return (sample_ids, features)
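
As a reading aid, here is a minimal usage sketch of the two helpers added above; the parquet path is hypothetical, and the file is assumed to follow one of the two column layouts described in the docstring. Since the module is underscore-prefixed, it is presumably intended for internal use by the birder.eval benchmarks.

from birder.eval._embeddings import l2_normalize, load_embeddings

# "results/embeddings.parquet" is an illustrative path, not one shipped with the package
sample_ids, features = load_embeddings("results/embeddings.parquet")
features = l2_normalize(features)  # row-wise unit norm; eps guards against zero vectors
print(len(sample_ids), features.shape)  # N sample ids and an (N, embedding_dim) float32 array
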
birder/eval/adversarial.py
@@ -0,0 +1,315 @@
+ import argparse
+ import json
+ import logging
+ from typing import Any
+
+ import torch
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from birder.adversarial.deepfool import DeepFool
+ from birder.adversarial.fgsm import FGSM
+ from birder.adversarial.pgd import PGD
+ from birder.adversarial.simba import SimBA
+ from birder.common import cli
+ from birder.common import fs_ops
+ from birder.common import lib
+ from birder.conf import settings
+ from birder.data.dataloader.webdataset import make_wds_loader
+ from birder.data.datasets.directory import make_image_dataset
+ from birder.data.datasets.webdataset import make_wds_dataset
+ from birder.data.datasets.webdataset import prepare_wds_args
+ from birder.data.datasets.webdataset import wds_args_from_info
+ from birder.data.transforms.classification import RGBType
+ from birder.data.transforms.classification import inference_preset
+
+ logger = logging.getLogger(__name__)
+
+
+ def _build_attack(
+     method: str,
+     net: torch.nn.Module,
+     rgb_stats: RGBType,
+     eps: float,
+     steps: int,
+     step_size: float | None,
+     deepfool_num_classes: int,
+ ) -> FGSM | PGD | DeepFool | SimBA:
+     if method == "fgsm":
+         return FGSM(net, eps=eps, rgb_stats=rgb_stats)
+
+     if method == "pgd":
+         return PGD(
+             net,
+             eps=eps,
+             steps=steps,
+             step_size=step_size,
+             random_start=False,
+             rgb_stats=rgb_stats,
+         )
+
+     if method == "deepfool":
+         return DeepFool(net, num_classes=deepfool_num_classes, overshoot=0.02, max_iter=steps, rgb_stats=rgb_stats)
+
+     if method == "simba":
+         return SimBA(
+             net,
+             step_size=step_size if step_size is not None else eps,
+             max_iter=steps,
+             rgb_stats=rgb_stats,
+         )
+
+     raise ValueError(f"Unsupported attack method '{method}'")
+
+
+ # pylint: disable=too-many-locals,too-many-branches
+ def evaluate_robust(args: argparse.Namespace) -> None:
+     if args.gpu is True:
+         device = torch.device("cuda")
+     elif args.mps is True:
+         device = torch.device("mps")
+     else:
+         device = torch.device("cpu")
+
+     if args.gpu_id is not None:
+         torch.cuda.set_device(args.gpu_id)
+
+     logger.info(f"Using device {device}")
+
+     if args.fast_matmul is True or args.amp is True:
+         torch.set_float32_matmul_precision("high")
+
+     if args.amp_dtype is None:
+         amp_dtype = torch.get_autocast_dtype(device.type)
+         logger.debug(f"AMP: {args.amp}, AMP dtype: {amp_dtype}")
+     else:
+         amp_dtype = getattr(torch, args.amp_dtype)
+
+     network_name = lib.get_network_name(args.network, tag=args.tag)
+     net, model_info = fs_ops.load_model(
+         device,
+         args.network,
+         tag=args.tag,
+         epoch=args.epoch,
+         new_size=args.size,
+         inference=True,
+         reparameterized=args.reparameterized,
+     )
+
+     class_to_idx = model_info.class_to_idx
+     rgb_stats = model_info.rgb_stats
+     if args.size is None:
+         size = lib.get_size_from_signature(model_info.signature)
+     else:
+         size = args.size
+
+     transform = inference_preset(size, rgb_stats, args.center_crop, args.simple_crop)
+
+     if args.wds is True:
+         wds_path: str | list[str]
+         if args.wds_info is not None:
+             wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
+             if args.wds_size is not None:
+                 dataset_size = args.wds_size
+         else:
+             wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
+
+         num_samples = dataset_size
+         dataset = make_wds_dataset(
+             wds_path,
+             dataset_size=dataset_size,
+             shuffle=False,
+             samples_names=True,
+             transform=transform,
+         )
+         dataloader = make_wds_loader(
+             dataset,
+             args.batch_size,
+             num_workers=args.num_workers,
+             prefetch_factor=None,
+             collate_fn=None,
+             world_size=1,
+             pin_memory=False,
+             exact=True,
+         )
+     else:
+         dataset = make_image_dataset(args.data_path, class_to_idx, transforms=transform)
+         num_samples = len(dataset)
+         dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+
+     attack = _build_attack(args.method, net, rgb_stats, args.eps, args.steps, args.step_size, args.deepfool_num_classes)
+
+     clean_correct = 0
+     adv_correct = 0
+     total = 0
+     skipped_unlabeled = 0
+     with tqdm(total=num_samples, unit="images", leave=False) as progress:
+         for _, inputs, targets in dataloader:
+             inputs = inputs.to(device)
+             targets = targets.to(device)
+             batch_size = inputs.size(0)
+
+             valid_mask = targets != settings.NO_LABEL
+             num_valid = int(valid_mask.sum().item())
+             skipped_unlabeled += batch_size - num_valid
+             if num_valid == 0:
+                 progress.update(batch_size)
+                 continue
+
+             inputs = inputs[valid_mask]
+             targets = targets[valid_mask]
+
+             with torch.no_grad():
+                 with torch.amp.autocast(device.type, enabled=args.amp, dtype=amp_dtype):
+                     clean_logits = net(inputs)
+
+             clean_preds = clean_logits.argmax(dim=1)
+             clean_correct += (clean_preds == targets).sum().item()
+
+             result = attack(inputs, target=None)
+             adv_logits = result.adv_logits
+             adv_preds = adv_logits.argmax(dim=1)
+             adv_correct += (adv_preds == targets).sum().item()
+
+             total += num_valid
+             progress.update(batch_size)
+
+     if total == 0:
+         raise RuntimeError(f"No labeled samples found (all labels are {settings.NO_LABEL})")
+
+     if skipped_unlabeled > 0:
+         logger.warning(f"Skipped {skipped_unlabeled} unlabeled samples (label={settings.NO_LABEL})")
+
+     clean_accuracy = clean_correct / total
+     adv_accuracy = adv_correct / total
+     accuracy_drop = clean_accuracy - adv_accuracy
+
+     logger.info(
+         f"{network_name}: clean={clean_accuracy:.4f}, adv={adv_accuracy:.4f}, drop={accuracy_drop:.4f} "
+         f"(evaluated on {total} labeled samples)"
+     )
+
+     if args.dry_run is False:
+         output = {
+             "method": "robust",
+             "accuracy": adv_accuracy,
+             "clean_accuracy": clean_accuracy,
+             "accuracy_drop": accuracy_drop,
+             "attack_method": args.method,
+             "epsilon": args.eps,
+             "num_samples": total,
+             "num_skipped_unlabeled": skipped_unlabeled,
+         }
+
+         output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         epoch_str = f"_e{args.epoch}" if args.epoch is not None else ""
+         output_path = output_dir.joinpath(f"{network_name}{epoch_str}_{args.method}_eps{args.eps}.json")
+
+         with open(output_path, "w", encoding="utf-8") as f:
+             json.dump(output, f, indent=2)
+
+         logger.info(f"Results saved to {output_path}")
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "adversarial",
+         allow_abbrev=False,
+         help="evaluate adversarial robustness of a model on a dataset",
+         description="evaluate adversarial robustness of a model on a dataset",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.eval adversarial -n resnet_v2_50 -e 100 --method pgd --eps 0.02 --gpu data/validation\n"
+             "python -m birder.eval adversarial -n vovnet_v2_39 -t il-common --method pgd --batch-size 4 "
+             "--gpu --gpu-id 1 --fast-matmul data/validation_il-common_packed\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument("-n", "--network", type=str, help="neural network to evaluate")
+     subparser.add_argument("-t", "--tag", type=str, help="model tag")
+     subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint epoch")
+     subparser.add_argument("--reparameterized", default=False, action="store_true", help="load reparameterized model")
+     subparser.add_argument(
+         "--method",
+         type=str,
+         choices=["fgsm", "pgd", "deepfool", "simba"],
+         help="adversarial attack method",
+     )
+     subparser.add_argument("--eps", type=float, default=0.007, help="perturbation budget in pixel space [0, 1]")
+     subparser.add_argument("--steps", type=int, default=10, help="number of iterations for iterative attacks")
+     subparser.add_argument("--step-size", type=float, help="step size in pixel space (defaults to eps/steps for PGD)")
+     subparser.add_argument(
+         "--deepfool-num-classes", type=int, default=10, help="number of top classes to consider for DeepFool"
+     )
+     subparser.add_argument(
+         "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
+     )
+     subparser.add_argument(
+         "--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference"
+     )
+     subparser.add_argument(
+         "--amp-dtype",
+         type=str,
+         choices=["float16", "bfloat16"],
+         help="whether to use float16 or bfloat16 for mixed precision",
+     )
+     subparser.add_argument(
+         "--fast-matmul", default=False, action="store_true", help="use fast matrix multiplication (affects precision)"
+     )
+     subparser.add_argument("--batch-size", type=int, default=32, metavar="N", help="the batch size")
+     subparser.add_argument(
+         "-j", "--num-workers", type=int, default=8, metavar="N", help="number of preprocessing workers"
+     )
+     subparser.add_argument("--center-crop", type=float, default=1.0, help="center crop ratio to use during inference")
+     subparser.add_argument(
+         "--simple-crop",
+         default=False,
+         action="store_true",
+         help="use a simple crop that preserves aspect ratio but may trim parts of the image",
+     )
+     subparser.add_argument(
+         "--dir", type=str, default="robust", help="place all outputs in a sub-directory (relative to results)"
+     )
+     subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+     subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
+     subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+     subparser.add_argument(
+         "--mps", default=False, action="store_true", help="use mps (Metal Performance Shaders) device"
+     )
+     subparser.add_argument("--wds", default=False, action="store_true", help="evaluate a webdataset directory")
+     subparser.add_argument("--wds-size", type=int, metavar="N", help="size of the wds dataset")
+     subparser.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
+     subparser.add_argument(
+         "--wds-split", type=str, default="validation", metavar="NAME", help="wds dataset split to load"
+     )
+     subparser.add_argument("data_path", nargs="*", help="data files path (directories and files)")
+     subparser.set_defaults(func=main)
+
+
+ def validate_args(args: argparse.Namespace) -> None:
+     args.size = cli.parse_size(args.size)
+     if args.network is None:
+         raise cli.ValidationError("--network is required")
+     if args.method is None:
+         raise cli.ValidationError("--method is required")
+     if args.center_crop > 1 or args.center_crop <= 0.0:
+         raise cli.ValidationError(f"--center-crop must be in range of (0, 1.0], got {args.center_crop}")
+
+     if args.wds is False and len(args.data_path) == 0:
+         raise cli.ValidationError("Must provide at least one data source, --data-path or --wds")
+
+     if args.wds is True:
+         if args.wds_info is None and len(args.data_path) == 0:
+             raise cli.ValidationError("--wds requires a data path unless --wds-info is provided")
+         if len(args.data_path) > 1:
+             raise cli.ValidationError(f"--wds can have at most 1 --data-path, got {len(args.data_path)}")
+         if args.wds_info is None and len(args.data_path) == 1:
+             data_path = args.data_path[0]
+             if "://" in data_path and args.wds_size is None:
+                 raise cli.ValidationError("--wds-size is required for remote --data-path")
+
+
+ def main(args: argparse.Namespace) -> None:
+     validate_args(args)
+     evaluate_robust(args)
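
For orientation, a small sketch of how the attack objects above are driven outside the CLI wrapper; net and rgb_stats are assumed to come from fs_ops.load_model() as in evaluate_robust(), and inputs is a placeholder batch of preprocessed images:

# Build an untargeted PGD attack with the CLI defaults shown above (eps=0.007, steps=10)
attack = _build_attack("pgd", net, rgb_stats, eps=0.007, steps=10, step_size=None, deepfool_num_classes=10)

result = attack(inputs, target=None)         # untargeted, as in the evaluation loop
adv_preds = result.adv_logits.argmax(dim=1)  # compare against labels to get robust accuracy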