birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/eval/benchmarks/awa2.py
@@ -0,0 +1,357 @@
+"""
+AwA2 benchmark using MLP probe for multi-label attribute prediction
+
+Paper "Zero-Shot Learning -- A Comprehensive Evaluation of the Good, the Bad and the Ugly"
+https://arxiv.org/abs/1707.00600
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import polars as pl
+import torch
+from rich.console import Console
+from rich.table import Table
+from sklearn.metrics import f1_score
+from torchvision.datasets import ImageFolder
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.datahub.evaluation import AwA2
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.mlp import evaluate_mlp
+from birder.eval.methods.mlp import train_mlp
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("AwA2 (MLP)", style="dim")
+    table.add_column("Macro F1", justify="right")
+    table.add_column("Std", justify="right")
+    table.add_column("Runs", justify="right")
+
+    for result in results:
+        table.add_row(
+            Path(result["embeddings_file"]).name,
+            f"{result['macro_f1']:.4f}",
+            f"{result['macro_f1_std']:.4f}",
+            f"{result['num_runs']}",
+        )
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], attribute_names: list[str], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        row: dict[str, Any] = {
+            "embeddings_file": result["embeddings_file"],
+            "method": result["method"],
+            "metric_mode": result["metric_mode"],
+            "macro_f1": result["macro_f1"],
+            "macro_f1_std": result["macro_f1_std"],
+            "num_runs": result["num_runs"],
+        }
+        for attr in attribute_names:
+            row[f"f1_{attr}"] = result["per_attribute_f1"].get(attr)
+
+        rows.append(row)
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_awa2_metadata(dataset: AwA2) -> tuple[pl.DataFrame, npt.NDArray[np.float32], list[str], list[str], list[str]]:
+    """
+    Load AwA2 metadata: image paths, class assignments, and attribute matrix.
+
+    Returns
+    -------
+    metadata_df
+        DataFrame with columns: id (image stem), class_name
+    attribute_matrix
+        Binary attribute matrix of shape (num_classes, num_attributes)
+    class_names
+        List of class names in order
+    train_classes
+        List of training class names
+    test_classes
+        List of test class names
+    """
+
+    # Load class names (1-indexed in file)
+    class_names: list[str] = []
+    with open(dataset.classes_path, encoding="utf-8") as f:
+        for line in f:
+            parts = line.strip().split("\t")
+            class_names.append(parts[1])
+
+    # Load attribute matrix (one row per class, 85 attributes)
+    attribute_matrix = np.loadtxt(dataset.predicate_matrix_binary_path, dtype=np.float32)
+
+    # Load train/test class split
+    with open(dataset.trainclasses_path, encoding="utf-8") as f:
+        train_classes = [line.strip() for line in f if line.strip()]
+
+    with open(dataset.testclasses_path, encoding="utf-8") as f:
+        test_classes = [line.strip() for line in f if line.strip()]
+
+    # Load image paths using ImageFolder
+    image_dataset = ImageFolder(str(dataset.images_dir))
+    rows: list[dict[str, Any]] = []
+    for path, class_idx in image_dataset.samples:
+        class_name = image_dataset.classes[class_idx]
+        rows.append({"id": Path(path).stem, "class_name": class_name})
+
+    metadata_df = pl.DataFrame(rows)
+
+    return (metadata_df, attribute_matrix, class_names, train_classes, test_classes)
+
+
+def _load_embeddings_with_labels(
+    embeddings_path: str,
+    metadata_df: pl.DataFrame,
+    attribute_matrix: npt.NDArray[np.float32],
+    class_names: list[str],
+    train_classes: list[str],
+    test_classes: list[str],
+) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    # Join embeddings with metadata
+    joined = metadata_df.join(emb_df, on="id", how="inner")
+    if joined.height < metadata_df.height:
+        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+    # Create class name to index mapping
+    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
+
+    # Split into train/test based on class membership
+    train_mask = joined.get_column("class_name").is_in(train_classes)
+    test_mask = joined.get_column("class_name").is_in(test_classes)
+
+    train_data = joined.filter(train_mask)
+    test_data = joined.filter(test_mask)
+
+    # Extract features
+    x_train = np.array(train_data.get_column("embedding").to_list(), dtype=np.float32)
+    x_test = np.array(test_data.get_column("embedding").to_list(), dtype=np.float32)
+
+    # Get labels from attribute matrix (class-level attributes)
+    train_class_indices = [class_to_idx[name] for name in train_data.get_column("class_name").to_list()]
+    test_class_indices = [class_to_idx[name] for name in test_data.get_column("class_name").to_list()]
+
+    y_train = attribute_matrix[train_class_indices]
+    y_test = attribute_matrix[test_class_indices]
+
+    logger.info(f"Train: {x_train.shape[0]} samples ({len(train_classes)} classes)")
+    logger.info(f"Test: {x_test.shape[0]} samples ({len(test_classes)} classes)")
+    logger.info(f"Features: {x_train.shape[1]} dims, Attributes: {attribute_matrix.shape[1]}")
+
+    return (x_train, y_train, x_test, y_test)
+
+
+def _compute_macro_f1(
+    y_true: npt.NDArray[np.int_], y_pred: npt.NDArray[np.int_], metric_mode: str
+) -> tuple[float, int]:
+    if metric_mode == "all":
+        num_attrs = y_true.shape[1]
+        score = f1_score(y_true, y_pred, average="macro", zero_division=0.0)
+        return (float(score), num_attrs)
+
+    if metric_mode == "present-only":
+        present_attrs = np.where(y_true.sum(axis=0) > 0)[0]
+        if len(present_attrs) == 0:
+            logger.warning("No positive attributes in y_true, falling back to --metric-mode all")
+            score = f1_score(y_true, y_pred, average="macro", zero_division=0.0)
+            return (float(score), y_true.shape[1])
+
+        score = f1_score(y_true, y_pred, average="macro", labels=present_attrs, zero_division=0.0)
+        return (float(score), int(len(present_attrs)))
+
+    raise ValueError(f"Unsupported metric mode: {metric_mode}")
+
+
+# pylint: disable=too-many-locals
+def evaluate_awa2_single(
+    x_train: npt.NDArray[np.float32],
+    y_train: npt.NDArray[np.float32],
+    x_test: npt.NDArray[np.float32],
+    y_test: npt.NDArray[np.float32],
+    attribute_names: list[str],
+    args: argparse.Namespace,
+    embeddings_path: str,
+    device: torch.device,
+) -> dict[str, Any]:
+    num_attributes = len(attribute_names)
+
+    scores: list[float] = []
+    per_attribute_f1_runs: list[dict[str, float]] = []
+
+    for run in range(args.runs):
+        run_seed = args.seed + run
+        logger.info(f"Run {run + 1}/{args.runs} (seed={run_seed})")
+
+        # Train MLP
+        model = train_mlp(
+            x_train,
+            y_train,
+            num_classes=num_attributes,
+            device=device,
+            epochs=args.epochs,
+            batch_size=args.batch_size,
+            lr=args.lr,
+            hidden_dim=args.hidden_dim,
+            dropout=args.dropout,
+            seed=run_seed,
+        )
+
+        # Evaluate
+        y_pred, y_true, _ = evaluate_mlp(model, x_test, y_test, batch_size=args.batch_size, device=device)
+        macro_f1, num_attrs_scored = _compute_macro_f1(y_true, y_pred, args.metric_mode)
+        scores.append(macro_f1)
+
+        # Per-attribute F1
+        per_attribute_f1: dict[str, float] = {}
+        for i, attr in enumerate(attribute_names):
+            attr_f1 = f1_score(y_true[:, i], y_pred[:, i], average="binary", zero_division=0.0)
+            per_attribute_f1[attr] = float(attr_f1)
+
+        per_attribute_f1_runs.append(per_attribute_f1)
+        logger.info(
+            f"Run {run + 1}/{args.runs} - Macro F1 ({args.metric_mode}, {num_attrs_scored} attrs): {macro_f1:.4f}"
+        )
+
+    # Average results
+    scores_arr = np.array(scores)
+    mean_f1 = float(scores_arr.mean())
+    std_f1 = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+    # Average per-attribute F1 across runs
+    avg_per_attribute_f1: dict[str, float] = {}
+    for attr in attribute_names:
+        attr_scores = [run_f1[attr] for run_f1 in per_attribute_f1_runs]
+        avg_per_attribute_f1[attr] = float(np.mean(attr_scores))
+
+    logger.info(f"Mean Macro F1 over {args.runs} runs: {mean_f1:.4f} +/- {std_f1:.4f} (std)")
+
+    return {
+        "method": "mlp",
+        "metric_mode": args.metric_mode,
+        "macro_f1": mean_f1,
+        "macro_f1_std": std_f1,
+        "num_runs": args.runs,
+        "per_attribute_f1": avg_per_attribute_f1,
+        "embeddings_file": str(embeddings_path),
+    }
+
+
+def evaluate_awa2(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    if args.gpu is True:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    if args.gpu_id is not None:
+        torch.cuda.set_device(args.gpu_id)
+
+    logger.info(f"Using device {device}")
+    logger.info(f"Loading AwA2 dataset from {args.dataset_path}")
+    logger.info(f"Metric mode: {args.metric_mode}")
+    dataset = AwA2(args.dataset_path)
+    attribute_names = dataset.attribute_names
+
+    metadata_df, attribute_matrix, class_names, train_classes, test_classes = _load_awa2_metadata(dataset)
+    logger.info(f"Loaded metadata for {metadata_df.height} images")
+    logger.info(f"Train classes: {len(train_classes)}, Test classes: {len(test_classes)}")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        x_train, y_train, x_test, y_test = _load_embeddings_with_labels(
+            embeddings_path, metadata_df, attribute_matrix, class_names, train_classes, test_classes
+        )

+        result = evaluate_awa2_single(x_train, y_train, x_test, y_test, attribute_names, args, embeddings_path, device)
+        results.append(result)
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("awa2.csv")
+        _write_results_csv(results, attribute_names, output_path)
+
+    toc = time.time()
+    logger.info(f"AwA2 benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "awa2",
+        allow_abbrev=False,
+        help="run AwA2 benchmark - 85 attribute multi-label classification using MLP probe",
+        description="run AwA2 benchmark - 85 attribute multi-label classification using MLP probe",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval awa2 --embeddings "
+            "results/awa2_embeddings.parquet "
+            "--dataset-path ~/Datasets/Animals_with_Attributes2 --dry-run\n"
+            "python -m birder.eval awa2 --embeddings results/awa2_*.parquet "
+            "--dataset-path ~/Datasets/Animals_with_Attributes2 --gpu\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to AwA2 dataset root")
+    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+    subparser.add_argument("--epochs", type=int, default=100, help="training epochs per run")
+    subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
+    subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+    subparser.add_argument("--hidden-dim", type=int, default=512, help="MLP hidden layer dimension")
+    subparser.add_argument("--dropout", type=float, default=0.5, help="dropout probability")
+    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+    subparser.add_argument(
+        "--metric-mode",
+        type=str,
+        choices=["all", "present-only"],
+        default="present-only",
+        help="macro F1 mode: all attributes or only attributes present in test split",
+    )
+    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
+    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+    subparser.add_argument(
+        "--dir", type=str, default="awa2", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.dataset_path is None:
+        raise cli.ValidationError("--dataset-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_awa2(args)
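The probe itself lives in birder/eval/methods/mlp.py (also added in this release, per the file list above); only the call signatures of train_mlp and evaluate_mlp are visible in this hunk. For orientation, a minimal multi-label MLP probe consistent with the parameters passed here (hidden_dim, dropout, epochs, batch_size, lr, seed, device) might look like the sketch below. This is an illustrative assumption, not the package's actual implementation: the helper names train_attribute_probe and predict_attributes, the BCE-with-logits loss, the AdamW optimizer, and the 0.5 sigmoid threshold are guesses that merely match the binarized per-attribute F1 computed above.

# Hypothetical sketch of a multi-label MLP probe; not birder's train_mlp/evaluate_mlp.
import numpy as np
import numpy.typing as npt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_attribute_probe(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.float32],
    num_classes: int,
    hidden_dim: int = 512,
    dropout: float = 0.5,
    epochs: int = 100,
    batch_size: int = 128,
    lr: float = 1e-4,
    seed: int = 0,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    torch.manual_seed(seed)
    # Single hidden layer over frozen embeddings, one output logit per attribute
    model = nn.Sequential(
        nn.Linear(x_train.shape[1], hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()  # multi-label targets: independent sigmoid per attribute
    loader = DataLoader(
        TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)),
        batch_size=batch_size,
        shuffle=True,
    )
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb.to(device)), yb.to(device))
            loss.backward()
            optimizer.step()

    return model


@torch.no_grad()
def predict_attributes(
    model: nn.Module, x_test: npt.NDArray[np.float32], device: torch.device = torch.device("cpu")
) -> npt.NDArray[np.int_]:
    model.eval()
    logits = model(torch.from_numpy(x_test).to(device))
    # Threshold sigmoid outputs at 0.5 to obtain binary attribute predictions for macro F1
    return (logits.sigmoid() >= 0.5).int().cpu().numpy()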
birder/eval/benchmarks/bioscan5m.py
@@ -0,0 +1,198 @@
+"""
+BIOSCAN-5M benchmark using AMI clustering for unsupervised embedding evaluation
+
+Paper "BIOSCAN-5M: A Multimodal Dataset for Insect Biodiversity",
+https://arxiv.org/abs/2406.12723
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import polars as pl
+from rich.console import Console
+from rich.table import Table
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.data.datasets.directory import class_to_idx_from_paths
+from birder.data.datasets.directory import make_image_dataset
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.ami import evaluate_ami
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("BIOSCAN-5M (AMI)", style="dim")
+    table.add_column("AMI Score", justify="right")
+    table.add_column("Classes", justify="right")
+    table.add_column("Samples", justify="right")
+
+    for result in results:
+        table.add_row(
+            Path(result["embeddings_file"]).name,
+            f"{result['ami_score']:.4f}",
+            str(result["num_classes"]),
+            str(result["num_samples"]),
+        )
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        rows.append(
+            {
+                "embeddings_file": result["embeddings_file"],
+                "method": result["method"],
+                "ami_score": result["ami_score"],
+                "l2_normalize": result["l2_normalize"],
+                "num_classes": result["num_classes"],
+                "num_samples": result["num_samples"],
+            }
+        )
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_bioscan5m_metadata(data_path: str) -> pl.DataFrame:
+    """
+    Load metadata from an ImageFolder-compatible directory
+
+    Returns DataFrame with columns: id (filename stem), label
+    """
+
+    class_to_idx = class_to_idx_from_paths([data_path])
+    image_dataset = make_image_dataset([data_path], class_to_idx)
+
+    rows: list[dict[str, Any]] = []
+    for i in range(len(image_dataset)):
+        path = image_dataset.paths[i].decode("utf-8")
+        label = image_dataset.labels[i].item()
+        rows.append({"id": Path(path).stem, "label": label})
+
+    return pl.DataFrame(rows)
+
+
+def _load_embeddings_with_labels(embeddings_path: str, metadata_df: pl.DataFrame) -> tuple[np.ndarray, np.ndarray, int]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    joined = metadata_df.join(emb_df, on="id", how="inner")
+    if joined.height < metadata_df.height:
+        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+    labels = joined.get_column("label").to_numpy().astype(np.int_)
+
+    num_classes = len(metadata_df.get_column("label").unique())
+    logger.info(f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes")
+
+    return (features, labels, num_classes)
+
+
+def evaluate_bioscan5m(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    logger.info(f"Loading dataset from {args.data_path}")
+    metadata_df = _load_bioscan5m_metadata(args.data_path)
+    logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        features, labels, num_classes = _load_embeddings_with_labels(embeddings_path, metadata_df)
+
+        logger.info(
+            f"Evaluating AMI with umap_dim={args.umap_dim}, seed={args.seed}, "
+            f"l2_normalize={not args.no_l2_normalize}"
+        )
+        ami_score = evaluate_ami(
+            features,
+            labels,
+            n_clusters=num_classes,
+            umap_dim=args.umap_dim,
+            l2_normalize_features=not args.no_l2_normalize,
+            seed=args.seed,
+        )
+        logger.info(f"AMI Score: {ami_score:.4f}")
+
+        results.append(
+            {
+                "embeddings_file": str(embeddings_path),
+                "method": "ami",
+                "ami_score": ami_score,
+                "l2_normalize": not args.no_l2_normalize,
+                "num_classes": num_classes,
+                "num_samples": len(labels),
+            }
+        )
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("bioscan5m.csv")
+        _write_results_csv(results, output_path)
+
+    toc = time.time()
+    logger.info(f"BIOSCAN-5M benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "bioscan5m",
+        allow_abbrev=False,
+        help="run BIOSCAN-5M benchmark - unsupervised embedding evaluation using AMI clustering",
+        description="run BIOSCAN-5M benchmark - unsupervised embedding evaluation using AMI clustering",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval bioscan5m --embeddings "
+            "results/embeddings.parquet --data-path ~/Datasets/BIOSCAN-5M/species/testing_unseen --dry-run\n"
+            "python -m birder.eval bioscan5m --embeddings results/bioscan5m/*.parquet "
+            "--data-path ~/Datasets/BIOSCAN-5M/species/testing_unseen --seed 0\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--data-path", type=str, metavar="PATH", help="path to ImageFolder-compatible directory")
+    subparser.add_argument("--umap-dim", type=int, default=50, help="target dimensionality for UMAP reduction")
+    subparser.add_argument(
+        "--no-l2-normalize",
+        default=False,
+        action="store_true",
+        help="disable L2 normalization of embeddings before UMAP",
+    )
+    subparser.add_argument("--seed", type=int, help="random seed for UMAP")
+    subparser.add_argument(
+        "--dir", type=str, default="bioscan5m", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.data_path is None:
+        raise cli.ValidationError("--data-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_bioscan5m(args)
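evaluate_ami comes from birder/eval/methods/ami.py, which is not shown in this diff; only its parameters (n_clusters, umap_dim, l2_normalize_features, seed) appear here. A minimal sketch of that style of evaluation is given below, assuming L2 normalization, UMAP reduction via umap-learn, k-means clustering into n_clusters groups, and scikit-learn's adjusted_mutual_info_score. The function name ami_clustering_score and the choice of k-means are assumptions; the shipped implementation may use a different clustering algorithm.

# Hypothetical sketch of an AMI clustering evaluation; not birder's evaluate_ami.
import numpy as np
import numpy.typing as npt
import umap  # umap-learn
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.preprocessing import normalize


def ami_clustering_score(
    features: npt.NDArray[np.float32],
    labels: npt.NDArray[np.int_],
    n_clusters: int,
    umap_dim: int = 50,
    l2_normalize_features: bool = True,
    seed: int | None = 0,
) -> float:
    if l2_normalize_features:
        features = normalize(features)  # L2-normalize each embedding

    # Reduce the embeddings to umap_dim dimensions before clustering
    reduced = umap.UMAP(n_components=umap_dim, random_state=seed).fit_transform(features)

    # Cluster into as many groups as there are ground-truth classes
    cluster_ids = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(reduced)

    # Chance-adjusted agreement between cluster assignments and labels
    return float(adjusted_mutual_info_score(labels, cluster_ids))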