birder-0.4.1-py3-none-any.whl → birder-0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
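
The headline change in this range is structural: the monolithic `birder/scripts/evaluate.py` is removed (the `+0 -176` entry above) in favor of a new `birder.eval` package that pairs per-benchmark subcommands (`awa2`, `bioscan5m`, `fishnet`, `flowers102`, `fungiclef`, `nabirds`, `newt`, `plankton`, `plantdoc`, `plantnet`) with shared probe methods (`ami`, `knn`, `linear`, `mlp`, `simpleshot`, `svm`). Going by the epilog strings in the new modules, the benchmarks are driven through the package's `__main__`, e.g.:

```
python -m birder.eval awa2 --embeddings results/awa2_embeddings.parquet --dataset-path ~/Datasets/Animals_with_Attributes2 --dry-run
python -m birder.eval bioscan5m --embeddings results/embeddings.parquet --data-path ~/Datasets/BIOSCAN-5M/species/testing_unseen --dry-run
```

Two of the new benchmark modules, awa2.py and bioscan5m.py, are reproduced in full below.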
--- /dev/null
+++ b/birder/eval/benchmarks/awa2.py
@@ -0,0 +1,357 @@
+"""
+AwA2 benchmark using MLP probe for multi-label attribute prediction
+
+Paper "Zero-Shot Learning -- A Comprehensive Evaluation of the Good, the Bad and the Ugly"
+https://arxiv.org/abs/1707.00600
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import polars as pl
+import torch
+from rich.console import Console
+from rich.table import Table
+from sklearn.metrics import f1_score
+from torchvision.datasets import ImageFolder
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.datahub.evaluation import AwA2
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.mlp import evaluate_mlp
+from birder.eval.methods.mlp import train_mlp
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("AwA2 (MLP)", style="dim")
+    table.add_column("Macro F1", justify="right")
+    table.add_column("Std", justify="right")
+    table.add_column("Runs", justify="right")
+
+    for result in results:
+        table.add_row(
+            Path(result["embeddings_file"]).name,
+            f"{result['macro_f1']:.4f}",
+            f"{result['macro_f1_std']:.4f}",
+            f"{result['num_runs']}",
+        )
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], attribute_names: list[str], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        row: dict[str, Any] = {
+            "embeddings_file": result["embeddings_file"],
+            "method": result["method"],
+            "metric_mode": result["metric_mode"],
+            "macro_f1": result["macro_f1"],
+            "macro_f1_std": result["macro_f1_std"],
+            "num_runs": result["num_runs"],
+        }
+        for attr in attribute_names:
+            row[f"f1_{attr}"] = result["per_attribute_f1"].get(attr)
+
+        rows.append(row)
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_awa2_metadata(dataset: AwA2) -> tuple[pl.DataFrame, npt.NDArray[np.float32], list[str], list[str], list[str]]:
+    """
+    Load AwA2 metadata: image paths, class assignments, and attribute matrix.
+
+    Returns
+    -------
+    metadata_df
+        DataFrame with columns: id (image stem), class_name
+    attribute_matrix
+        Binary attribute matrix of shape (num_classes, num_attributes)
+    class_names
+        List of class names in order
+    train_classes
+        List of training class names
+    test_classes
+        List of test class names
+    """
+
+    # Load class names (1-indexed in file)
+    class_names: list[str] = []
+    with open(dataset.classes_path, encoding="utf-8") as f:
+        for line in f:
+            parts = line.strip().split("\t")
+            class_names.append(parts[1])
+
+    # Load attribute matrix (one row per class, 85 attributes)
+    attribute_matrix = np.loadtxt(dataset.predicate_matrix_binary_path, dtype=np.float32)
+
+    # Load train/test class split
+    with open(dataset.trainclasses_path, encoding="utf-8") as f:
+        train_classes = [line.strip() for line in f if line.strip()]
+
+    with open(dataset.testclasses_path, encoding="utf-8") as f:
+        test_classes = [line.strip() for line in f if line.strip()]
+
+    # Load image paths using ImageFolder
+    image_dataset = ImageFolder(str(dataset.images_dir))
+    rows: list[dict[str, Any]] = []
+    for path, class_idx in image_dataset.samples:
+        class_name = image_dataset.classes[class_idx]
+        rows.append({"id": Path(path).stem, "class_name": class_name})
+
+    metadata_df = pl.DataFrame(rows)
+
+    return (metadata_df, attribute_matrix, class_names, train_classes, test_classes)
+
+
+def _load_embeddings_with_labels(
+    embeddings_path: str,
+    metadata_df: pl.DataFrame,
+    attribute_matrix: npt.NDArray[np.float32],
+    class_names: list[str],
+    train_classes: list[str],
+    test_classes: list[str],
+) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    # Join embeddings with metadata
+    joined = metadata_df.join(emb_df, on="id", how="inner")
+    if joined.height < metadata_df.height:
+        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+    # Create class name to index mapping
+    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
+
+    # Split into train/test based on class membership
+    train_mask = joined.get_column("class_name").is_in(train_classes)
+    test_mask = joined.get_column("class_name").is_in(test_classes)
+
+    train_data = joined.filter(train_mask)
+    test_data = joined.filter(test_mask)
+
+    # Extract features
+    x_train = np.array(train_data.get_column("embedding").to_list(), dtype=np.float32)
+    x_test = np.array(test_data.get_column("embedding").to_list(), dtype=np.float32)
+
+    # Get labels from attribute matrix (class-level attributes)
+    train_class_indices = [class_to_idx[name] for name in train_data.get_column("class_name").to_list()]
+    test_class_indices = [class_to_idx[name] for name in test_data.get_column("class_name").to_list()]
+
+    y_train = attribute_matrix[train_class_indices]
+    y_test = attribute_matrix[test_class_indices]
+
+    logger.info(f"Train: {x_train.shape[0]} samples ({len(train_classes)} classes)")
+    logger.info(f"Test: {x_test.shape[0]} samples ({len(test_classes)} classes)")
+    logger.info(f"Features: {x_train.shape[1]} dims, Attributes: {attribute_matrix.shape[1]}")
+
+    return (x_train, y_train, x_test, y_test)
+
+
+def _compute_macro_f1(
+    y_true: npt.NDArray[np.int_], y_pred: npt.NDArray[np.int_], metric_mode: str
+) -> tuple[float, int]:
+    if metric_mode == "all":
+        num_attrs = y_true.shape[1]
+        score = f1_score(y_true, y_pred, average="macro", zero_division=0.0)
+        return (float(score), num_attrs)
+
+    if metric_mode == "present-only":
+        present_attrs = np.where(y_true.sum(axis=0) > 0)[0]
+        if len(present_attrs) == 0:
+            logger.warning("No positive attributes in y_true, falling back to --metric-mode all")
+            score = f1_score(y_true, y_pred, average="macro", zero_division=0.0)
+            return (float(score), y_true.shape[1])
+
+        score = f1_score(y_true, y_pred, average="macro", labels=present_attrs, zero_division=0.0)
+        return (float(score), int(len(present_attrs)))
+
+    raise ValueError(f"Unsupported metric mode: {metric_mode}")
+
+
+# pylint: disable=too-many-locals
+def evaluate_awa2_single(
+    x_train: npt.NDArray[np.float32],
+    y_train: npt.NDArray[np.float32],
+    x_test: npt.NDArray[np.float32],
+    y_test: npt.NDArray[np.float32],
+    attribute_names: list[str],
+    args: argparse.Namespace,
+    embeddings_path: str,
+    device: torch.device,
+) -> dict[str, Any]:
+    num_attributes = len(attribute_names)
+
+    scores: list[float] = []
+    per_attribute_f1_runs: list[dict[str, float]] = []
+
+    for run in range(args.runs):
+        run_seed = args.seed + run
+        logger.info(f"Run {run + 1}/{args.runs} (seed={run_seed})")
+
+        # Train MLP
+        model = train_mlp(
+            x_train,
+            y_train,
+            num_classes=num_attributes,
+            device=device,
+            epochs=args.epochs,
+            batch_size=args.batch_size,
+            lr=args.lr,
+            hidden_dim=args.hidden_dim,
+            dropout=args.dropout,
+            seed=run_seed,
+        )
+
+        # Evaluate
+        y_pred, y_true, _ = evaluate_mlp(model, x_test, y_test, batch_size=args.batch_size, device=device)
+        macro_f1, num_attrs_scored = _compute_macro_f1(y_true, y_pred, args.metric_mode)
+        scores.append(macro_f1)
+
+        # Per-attribute F1
+        per_attribute_f1: dict[str, float] = {}
+        for i, attr in enumerate(attribute_names):
+            attr_f1 = f1_score(y_true[:, i], y_pred[:, i], average="binary", zero_division=0.0)
+            per_attribute_f1[attr] = float(attr_f1)
+
+        per_attribute_f1_runs.append(per_attribute_f1)
+        logger.info(
+            f"Run {run + 1}/{args.runs} - Macro F1 ({args.metric_mode}, {num_attrs_scored} attrs): {macro_f1:.4f}"
+        )
+
+    # Average results
+    scores_arr = np.array(scores)
+    mean_f1 = float(scores_arr.mean())
+    std_f1 = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+    # Average per-attribute F1 across runs
+    avg_per_attribute_f1: dict[str, float] = {}
+    for attr in attribute_names:
+        attr_scores = [run_f1[attr] for run_f1 in per_attribute_f1_runs]
+        avg_per_attribute_f1[attr] = float(np.mean(attr_scores))
+
+    logger.info(f"Mean Macro F1 over {args.runs} runs: {mean_f1:.4f} +/- {std_f1:.4f} (std)")
+
+    return {
+        "method": "mlp",
+        "metric_mode": args.metric_mode,
+        "macro_f1": mean_f1,
+        "macro_f1_std": std_f1,
+        "num_runs": args.runs,
+        "per_attribute_f1": avg_per_attribute_f1,
+        "embeddings_file": str(embeddings_path),
+    }
+
+
+def evaluate_awa2(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    if args.gpu is True:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    if args.gpu_id is not None:
+        torch.cuda.set_device(args.gpu_id)
+
+    logger.info(f"Using device {device}")
+    logger.info(f"Loading AwA2 dataset from {args.dataset_path}")
+    logger.info(f"Metric mode: {args.metric_mode}")
+    dataset = AwA2(args.dataset_path)
+    attribute_names = dataset.attribute_names
+
+    metadata_df, attribute_matrix, class_names, train_classes, test_classes = _load_awa2_metadata(dataset)
+    logger.info(f"Loaded metadata for {metadata_df.height} images")
+    logger.info(f"Train classes: {len(train_classes)}, Test classes: {len(test_classes)}")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        x_train, y_train, x_test, y_test = _load_embeddings_with_labels(
+            embeddings_path, metadata_df, attribute_matrix, class_names, train_classes, test_classes
+        )
+
+        result = evaluate_awa2_single(x_train, y_train, x_test, y_test, attribute_names, args, embeddings_path, device)
+        results.append(result)
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("awa2.csv")
+        _write_results_csv(results, attribute_names, output_path)
+
+    toc = time.time()
+    logger.info(f"AwA2 benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "awa2",
+        allow_abbrev=False,
+        help="run AwA2 benchmark - 85 attribute multi-label classification using MLP probe",
+        description="run AwA2 benchmark - 85 attribute multi-label classification using MLP probe",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval awa2 --embeddings "
+            "results/awa2_embeddings.parquet "
+            "--dataset-path ~/Datasets/Animals_with_Attributes2 --dry-run\n"
+            "python -m birder.eval awa2 --embeddings results/awa2_*.parquet "
+            "--dataset-path ~/Datasets/Animals_with_Attributes2 --gpu\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to AwA2 dataset root")
+    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+    subparser.add_argument("--epochs", type=int, default=100, help="training epochs per run")
+    subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
+    subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+    subparser.add_argument("--hidden-dim", type=int, default=512, help="MLP hidden layer dimension")
+    subparser.add_argument("--dropout", type=float, default=0.5, help="dropout probability")
+    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+    subparser.add_argument(
+        "--metric-mode",
+        type=str,
+        choices=["all", "present-only"],
+        default="present-only",
+        help="macro F1 mode: all attributes or only attributes present in test split",
+    )
+    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
+    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+    subparser.add_argument(
+        "--dir", type=str, default="awa2", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.dataset_path is None:
+        raise cli.ValidationError("--dataset-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_awa2(args)
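
To make the `--metric-mode` choice above concrete: `present-only` restricts the macro average to attributes that have at least one positive in the test split, which matters because attributes absent from `y_true` contribute a zero F1 under `zero_division=0.0` and drag the `all` average down. A minimal self-contained sketch of the same scikit-learn calls `_compute_macro_f1` makes (the toy arrays are illustrative, not benchmark data):

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy multi-label data (2 samples, 3 attributes); attribute 2 has no positives in y_true
y_true = np.array([[1, 0, 0], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])

# "all": average over every attribute; the absent attribute scores 0 via zero_division
f1_all = f1_score(y_true, y_pred, average="macro", zero_division=0.0)

# "present-only": average only over attributes with at least one positive in y_true
present = np.where(y_true.sum(axis=0) > 0)[0]
f1_present = f1_score(y_true, y_pred, average="macro", labels=present, zero_division=0.0)

print(f"all={f1_all:.4f} present-only={f1_present:.4f}")  # all=0.6667 present-only=1.0000
```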
--- /dev/null
+++ b/birder/eval/benchmarks/bioscan5m.py
@@ -0,0 +1,198 @@
+"""
+BIOSCAN-5M benchmark using AMI clustering for unsupervised embedding evaluation
+
+Paper "BIOSCAN-5M: A Multimodal Dataset for Insect Biodiversity",
+https://arxiv.org/abs/2406.12723
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import polars as pl
+from rich.console import Console
+from rich.table import Table
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.data.datasets.directory import class_to_idx_from_paths
+from birder.data.datasets.directory import make_image_dataset
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.ami import evaluate_ami
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("BIOSCAN-5M (AMI)", style="dim")
+    table.add_column("AMI Score", justify="right")
+    table.add_column("Classes", justify="right")
+    table.add_column("Samples", justify="right")
+
+    for result in results:
+        table.add_row(
+            Path(result["embeddings_file"]).name,
+            f"{result['ami_score']:.4f}",
+            str(result["num_classes"]),
+            str(result["num_samples"]),
+        )
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        rows.append(
+            {
+                "embeddings_file": result["embeddings_file"],
+                "method": result["method"],
+                "ami_score": result["ami_score"],
+                "l2_normalize": result["l2_normalize"],
+                "num_classes": result["num_classes"],
+                "num_samples": result["num_samples"],
+            }
+        )
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_bioscan5m_metadata(data_path: str) -> pl.DataFrame:
+    """
+    Load metadata from an ImageFolder-compatible directory
+
+    Returns DataFrame with columns: id (filename stem), label
+    """
+
+    class_to_idx = class_to_idx_from_paths([data_path])
+    image_dataset = make_image_dataset([data_path], class_to_idx)
+
+    rows: list[dict[str, Any]] = []
+    for i in range(len(image_dataset)):
+        path = image_dataset.paths[i].decode("utf-8")
+        label = image_dataset.labels[i].item()
+        rows.append({"id": Path(path).stem, "label": label})
+
+    return pl.DataFrame(rows)
+
+
+def _load_embeddings_with_labels(embeddings_path: str, metadata_df: pl.DataFrame) -> tuple[np.ndarray, np.ndarray, int]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    joined = metadata_df.join(emb_df, on="id", how="inner")
+    if joined.height < metadata_df.height:
+        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+    labels = joined.get_column("label").to_numpy().astype(np.int_)
+
+    num_classes = len(metadata_df.get_column("label").unique())
+    logger.info(f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes")
+
+    return (features, labels, num_classes)
+
+
+def evaluate_bioscan5m(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    logger.info(f"Loading dataset from {args.data_path}")
+    metadata_df = _load_bioscan5m_metadata(args.data_path)
+    logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        features, labels, num_classes = _load_embeddings_with_labels(embeddings_path, metadata_df)
+
+        logger.info(
+            f"Evaluating AMI with umap_dim={args.umap_dim}, seed={args.seed}, "
+            f"l2_normalize={not args.no_l2_normalize}"
+        )
+        ami_score = evaluate_ami(
+            features,
+            labels,
+            n_clusters=num_classes,
+            umap_dim=args.umap_dim,
+            l2_normalize_features=not args.no_l2_normalize,
+            seed=args.seed,
+        )
+        logger.info(f"AMI Score: {ami_score:.4f}")
+
+        results.append(
+            {
+                "embeddings_file": str(embeddings_path),
+                "method": "ami",
+                "ami_score": ami_score,
+                "l2_normalize": not args.no_l2_normalize,
+                "num_classes": num_classes,
+                "num_samples": len(labels),
+            }
+        )
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("bioscan5m.csv")
+        _write_results_csv(results, output_path)
+
+    toc = time.time()
+    logger.info(f"BIOSCAN-5M benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "bioscan5m",
+        allow_abbrev=False,
+        help="run BIOSCAN-5M benchmark - unsupervised embedding evaluation using AMI clustering",
+        description="run BIOSCAN-5M benchmark - unsupervised embedding evaluation using AMI clustering",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval bioscan5m --embeddings "
+            "results/embeddings.parquet --data-path ~/Datasets/BIOSCAN-5M/species/testing_unseen --dry-run\n"
+            "python -m birder.eval bioscan5m --embeddings results/bioscan5m/*.parquet "
+            "--data-path ~/Datasets/BIOSCAN-5M/species/testing_unseen --seed 0\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--data-path", type=str, metavar="PATH", help="path to ImageFolder-compatible directory")
+    subparser.add_argument("--umap-dim", type=int, default=50, help="target dimensionality for UMAP reduction")
+    subparser.add_argument(
+        "--no-l2-normalize",
+        default=False,
+        action="store_true",
+        help="disable L2 normalization of embeddings before UMAP",
+    )
+    subparser.add_argument("--seed", type=int, help="random seed for UMAP")
+    subparser.add_argument(
+        "--dir", type=str, default="bioscan5m", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.data_path is None:
+        raise cli.ValidationError("--data-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_bioscan5m(args)
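
The metric implementation itself lives in `birder/eval/methods/ami.py` (+78 lines), which is not part of this diff. Judging only from the call site above (`n_clusters`, `umap_dim`, `l2_normalize_features`, `seed`), a plausible flow is L2 normalization, UMAP reduction, k-means clustering, then adjusted mutual information; the sketch below is an assumption about that flow, not the package's actual code:

```python
# Hypothetical stand-in for birder.eval.methods.ami.evaluate_ami; the real
# implementation is not shown in this diff, so treat every step as assumed.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.preprocessing import normalize
from umap import UMAP  # provided by the umap-learn package


def evaluate_ami_sketch(
    features: np.ndarray,
    labels: np.ndarray,
    n_clusters: int,
    umap_dim: int = 50,
    l2_normalize_features: bool = True,
    seed: int | None = None,
) -> float:
    if l2_normalize_features is True:
        features = normalize(features, norm="l2")  # unit-length rows before reduction

    # Reduce to umap_dim dimensions, then cluster into one group per ground-truth class
    reduced = UMAP(n_components=umap_dim, random_state=seed).fit_transform(features)
    cluster_ids = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(reduced)

    # AMI compares the unsupervised cluster assignments against the true labels
    return float(adjusted_mutual_info_score(labels, cluster_ids))
```

AMI is adjusted for chance, so a score near 0 means the clustering recovers the species labels no better than random assignment, and 1 means perfect agreement.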