birder-0.4.1-py3-none-any.whl → birder-0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
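
The deleted `birder/scripts/evaluate.py` together with the new `birder/eval` package (including `birder/eval/__main__.py`) indicates the standalone evaluation script was replaced by a subcommand-based `python -m birder.eval` CLI. Each benchmark module exposes `set_parser(subparsers)` and registers its `main` via `set_defaults(func=main)`, as seen in the two modules reproduced below. A hypothetical dispatcher sketch of that contract follows; the actual `birder/eval/__main__.py` is not shown in this diff:

```python
# Hypothetical sketch only -- birder/eval/__main__.py is not included in this
# diff. It illustrates the set_parser()/set_defaults(func=main) contract that
# the benchmark modules below follow.
import argparse

from birder.eval.benchmarks import newt
from birder.eval.benchmarks import plankton


def run() -> None:
    parser = argparse.ArgumentParser(prog="birder.eval", allow_abbrev=False)
    subparsers = parser.add_subparsers(required=True)

    # Each benchmark module registers its own subcommand and default handler
    newt.set_parser(subparsers)
    plankton.set_parser(subparsers)

    args = parser.parse_args()
    args.func(args)  # dispatches to the selected module's main(args)


if __name__ == "__main__":
    run()
```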
birder/eval/benchmarks/newt.py (new file, 262 lines):

```python
"""
NeWT benchmark, adapted from
https://github.com/samuelstevens/biobench/blob/main/src/biobench/newt/__init__.py

Paper "Benchmarking Representation Learning for Natural World Image Collections",
https://arxiv.org/abs/2103.16483
"""

# Reference license: MIT

import argparse
import logging
import time
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
import polars as pl
from rich.console import Console
from rich.table import Table

from birder.common import cli
from birder.common import lib
from birder.conf import settings
from birder.datahub.evaluation import NeWT
from birder.eval._embeddings import l2_normalize
from birder.eval._embeddings import load_embeddings
from birder.eval.methods.simpleshot import normalize_features
from birder.eval.methods.svm import evaluate_svm

logger = logging.getLogger(__name__)


def _print_summary_table(results: list[dict[str, Any]]) -> None:
    console = Console()

    cluster_names = sorted({cluster for result in results for cluster in result["per_cluster_accuracy"].keys()})

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("NeWT (SVM)", style="dim")
    table.add_column("Accuracy", justify="right")
    table.add_column("Std", justify="right")
    table.add_column("Runs", justify="right")
    for cluster in cluster_names:
        table.add_column(cluster.replace("_", " ").title(), justify="right")

    for result in results:
        row = [
            Path(result["embeddings_file"]).name,
            f"{result['accuracy']:.4f}",
            f"{result['accuracy_std']:.4f}",
            f"{result['num_runs']}",
        ]
        for cluster in cluster_names:
            acc = result["per_cluster_accuracy"].get(cluster)
            row.append(f"{acc:.4f}" if acc is not None else "-")

        table.add_row(*row)

    console.print(table)


def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
    cluster_names = sorted({cluster for result in results for cluster in result["per_cluster_accuracy"].keys()})
    rows: list[dict[str, Any]] = []
    for result in results:
        row: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
            "accuracy": result["accuracy"],
            "accuracy_std": result["accuracy_std"],
            "num_runs": result["num_runs"],
        }
        for cluster in cluster_names:
            row[f"cluster_{cluster}"] = result["per_cluster_accuracy"].get(cluster)

        rows.append(row)

    pl.DataFrame(rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")


# pylint: disable=too-many-locals
def evaluate_newt_single(embeddings_path: str, labels_df: pl.DataFrame, args: argparse.Namespace) -> dict[str, Any]:
    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    joined = labels_df.join(emb_df, on="id", how="inner").sort("index")
    if joined.height < labels_df.height:
        logger.warning(f"Join dropped {labels_df.height - joined.height} samples (missing embeddings)")

    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    logger.info(f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions")

    # Global L2 normalization
    all_features = l2_normalize(all_features)
    joined = joined.with_columns(pl.Series("embedding", all_features.tolist()))

    tasks = joined.get_column("task").unique().to_list()
    logger.info(f"Found {len(tasks)} tasks")

    scores: list[float] = []
    cluster_scores: dict[str, list[float]] = {}
    for run in range(args.runs):
        run_seed = args.seed + run

        y_preds_all: list[npt.NDArray[np.int_]] = []
        y_trues_all: list[npt.NDArray[np.int_]] = []
        cluster_preds: dict[str, list[npt.NDArray[np.int_]]] = {}
        cluster_trues: dict[str, list[npt.NDArray[np.int_]]] = {}
        for task_name in tasks:
            tdf = joined.filter(pl.col("task") == task_name)

            features = np.array(tdf.get_column("embedding").to_list(), dtype=np.float32)

            labels = tdf.get_column("label").to_numpy()
            is_train = (tdf.get_column("split") == "train").to_numpy()
            cluster = tdf.item(0, "task_cluster")

            x_train = features[is_train]
            y_train = labels[is_train]
            x_test = features[~is_train]
            y_test = labels[~is_train]

            if x_train.size == 0 or x_test.size == 0:
                logger.warning(f"Skipping task {task_name}: empty train or test split")
                continue

            # Per-task centering and L2 normalization
            x_train, x_test = normalize_features(x_train, x_test)

            # Train and evaluate SVM
            y_pred, y_true = evaluate_svm(
                x_train,
                y_train,
                x_test,
                y_test,
                n_iter=args.n_iter,
                n_jobs=args.n_jobs,
                seed=run_seed,
            )

            y_preds_all.append(y_pred)
            y_trues_all.append(y_true)

            # Track per-cluster predictions
            if cluster not in cluster_preds:
                cluster_preds[cluster] = []
                cluster_trues[cluster] = []
            cluster_preds[cluster].append(y_pred)
            cluster_trues[cluster].append(y_true)

        # Micro-averaged accuracy
        y_preds = np.concatenate(y_preds_all)
        y_trues = np.concatenate(y_trues_all)
        acc = float(np.mean(y_preds == y_trues))
        scores.append(acc)
        logger.info(f"Run {run + 1}/{args.runs} - Accuracy: {acc:.4f}")

        # Compute per-cluster accuracy for this run
        for cluster, preds_list in cluster_preds.items():
            preds = np.concatenate(preds_list)
            trues = np.concatenate(cluster_trues[cluster])
            cluster_acc = float(np.mean(preds == trues))
            if cluster not in cluster_scores:
                cluster_scores[cluster] = []
            cluster_scores[cluster].append(cluster_acc)

    # Average per-cluster accuracy across runs
    per_cluster_accuracy: dict[str, float] = {}
    for cluster, accs in cluster_scores.items():
        per_cluster_accuracy[cluster.lower()] = float(np.mean(accs))

    scores_arr = np.array(scores)
    mean_acc = float(scores_arr.mean())
    std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0

    logger.info(f"Mean accuracy over {args.runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")
    for cluster, acc in sorted(per_cluster_accuracy.items()):
        logger.info(f"  {cluster}: {acc:.4f}")

    return {
        "method": "svm",
        "accuracy": mean_acc,
        "accuracy_std": std_acc,
        "num_runs": args.runs,
        "per_cluster_accuracy": per_cluster_accuracy,
        "embeddings_file": str(embeddings_path),
    }


def evaluate_newt(args: argparse.Namespace) -> None:
    tic = time.time()

    logger.info(f"Loading NeWT dataset from {args.dataset_path}")
    dataset = NeWT(args.dataset_path)
    labels_path = dataset.labels_path
    logger.info(f"Loading labels from {labels_path}")
    labels_df = pl.read_csv(labels_path).with_row_index(name="index").with_columns(pl.col("id").cast(pl.Utf8))

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        result = evaluate_newt_single(embeddings_path, labels_df, args)
        results.append(result)

    _print_summary_table(results)

    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir.joinpath("newt.csv")
        _write_results_csv(results, output_path)

    toc = time.time()
    logger.info(f"NeWT benchmark completed in {lib.format_duration(toc - tic)}")


def set_parser(subparsers: Any) -> None:
    subparser = subparsers.add_parser(
        "newt",
        allow_abbrev=False,
        help="run NeWT benchmark - 164 binary classification tasks evaluated with SVM",
        description="run NeWT benchmark - 164 binary classification tasks evaluated with SVM",
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval newt --embeddings "
            "results/hieradet_d_small_dino-v2_0_224px_crop1.0_36032_output.parquet "
            "--dataset-path ~/Datasets/NeWT --dry-run\n"
            "python -m birder.eval newt --embeddings results/newt/*.parquet "
            "--dataset-path ~/Datasets/NeWT\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings/logits parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to NeWT dataset root")
    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
    subparser.add_argument("--n-iter", type=int, default=100, help="SVM hyperparameter search iterations")
    subparser.add_argument("--n-jobs", type=int, default=8, help="parallel jobs for RandomizedSearchCV")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument(
        "--dir", type=str, default="newt", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)


def validate_args(args: argparse.Namespace) -> None:
    if args.embeddings is None:
        raise cli.ValidationError("--embeddings is required")
    if args.dataset_path is None:
        raise cli.ValidationError("--dataset-path is required")


def main(args: argparse.Namespace) -> None:
    validate_args(args)
    evaluate_newt(args)
```
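
The NeWT protocol above centers and L2-normalizes features per task, fits a hyperparameter-searched SVM per task, and micro-averages accuracy by pooling predictions across all 164 binary tasks. A minimal synthetic walk-through of a single task, using the helper signatures as they appear in the calls above (the random data is illustrative only, not from the dataset):

```python
# Sketch of one NeWT task evaluation; signatures mirror the calls in newt.py,
# the embeddings and labels here are made up.
import numpy as np

from birder.eval.methods.simpleshot import normalize_features
from birder.eval.methods.svm import evaluate_svm

rng = np.random.default_rng(0)
x_train = rng.normal(size=(40, 256)).astype(np.float32)  # 40 train embeddings
y_train = rng.integers(0, 2, size=40)                    # binary task labels
x_test = rng.normal(size=(10, 256)).astype(np.float32)
y_test = rng.integers(0, 2, size=10)

# Per-task centering and L2 normalization, as in evaluate_newt_single()
x_train, x_test = normalize_features(x_train, x_test)

# Hyperparameter-searched SVM; returns aligned (predictions, targets)
y_pred, y_true = evaluate_svm(x_train, y_train, x_test, y_test, n_iter=100, n_jobs=8, seed=0)
print(f"task accuracy: {np.mean(y_pred == y_true):.4f}")
```

In the full benchmark, the per-task `y_pred`/`y_true` arrays are concatenated before computing accuracy, so larger tasks weigh proportionally more (micro-averaging), and the same pooling is repeated within each `task_cluster` for the per-cluster columns.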
birder/eval/benchmarks/plankton.py (new file, 255 lines):

```python
"""
Plankton benchmark using linear probing for phytoplankton classification

Dataset "SYKE-plankton_IFCB_2022", https://b2share.eudat.eu/records/xvnrp-7ga56
"""

import argparse
import logging
import os
import time
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
import polars as pl
import torch
from rich.console import Console
from rich.table import Table

from birder.common import cli
from birder.common import lib
from birder.conf import settings
from birder.data.datasets.directory import make_image_dataset
from birder.datahub.evaluation import Plankton
from birder.eval._embeddings import load_embeddings
from birder.eval.methods.linear import evaluate_linear_probe
from birder.eval.methods.linear import train_linear_probe

logger = logging.getLogger(__name__)


def _print_summary_table(results: list[dict[str, Any]]) -> None:
    console = Console()

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("Plankton (Linear)", style="dim")
    table.add_column("Accuracy", justify="right")
    table.add_column("Std", justify="right")
    table.add_column("Runs", justify="right")

    for result in results:
        table.add_row(
            Path(result["embeddings_file"]).name,
            f"{result['accuracy']:.4f}",
            f"{result['accuracy_std']:.4f}",
            f"{result['num_runs']}",
        )

    console.print(table)


def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
    rows: list[dict[str, Any]] = []
    for result in results:
        rows.append(
            {
                "embeddings_file": result["embeddings_file"],
                "method": result["method"],
                "accuracy": result["accuracy"],
                "accuracy_std": result["accuracy_std"],
                "num_runs": result["num_runs"],
            }
        )

    pl.DataFrame(rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")


def _load_plankton_metadata(dataset: Plankton) -> pl.DataFrame:
    """
    Load metadata using make_image_dataset with a fixed class_to_idx

    Returns DataFrame with columns: id (filename stem), label, split
    """

    # Build class_to_idx from train classes only
    class_to_idx = {
        entry.name: idx
        for idx, entry in enumerate(sorted(os.scandir(str(dataset.train_dir)), key=lambda e: e.name))
        if entry.is_dir()
    }

    rows: list[dict[str, Any]] = []
    for split, split_dir in [("train", dataset.train_dir), ("val", dataset.val_dir)]:
        image_dataset = make_image_dataset([str(split_dir)], class_to_idx)
        for i in range(len(image_dataset)):
            path = image_dataset.paths[i].decode("utf-8")
            label = image_dataset.labels[i].item()
            rows.append({"id": Path(path).stem, "label": label, "split": split})

    return pl.DataFrame(rows)


def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    joined = metadata_df.join(emb_df, on="id", how="inner")
    if joined.height < metadata_df.height:
        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")

    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    all_labels = joined.get_column("label").to_numpy().astype(np.int_)
    splits = joined.get_column("split").to_list()

    is_train = np.array([s == "train" for s in splits], dtype=bool)
    is_val = np.array([s == "val" for s in splits], dtype=bool)

    num_classes = all_labels.max() + 1
    logger.info(
        f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train = all_features[is_train]
    y_train = all_labels[is_train]
    x_val = all_features[is_val]
    y_val = all_labels[is_val]

    logger.info(f"Train: {len(y_train)} samples, Val: {len(y_val)} samples")

    return (x_train, y_train, x_val, y_val)


def evaluate_plankton_single(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_val: npt.NDArray[np.float32],
    y_val: npt.NDArray[np.int_],
    args: argparse.Namespace,
    embeddings_path: str,
    device: torch.device,
) -> dict[str, Any]:
    num_classes = int(y_train.max() + 1)

    scores: list[float] = []
    for run in range(args.runs):
        run_seed = args.seed + run
        logger.info(f"Run {run + 1}/{args.runs} (seed={run_seed})")

        model = train_linear_probe(
            x_train,
            y_train,
            num_classes,
            device=device,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            seed=run_seed,
        )

        y_pred, y_true = evaluate_linear_probe(model, x_val, y_val, batch_size=args.batch_size, device=device)
        acc = float(np.mean(y_pred == y_true))
        scores.append(acc)
        logger.info(f"Run {run + 1}/{args.runs} - Accuracy: {acc:.4f}")

    scores_arr = np.array(scores)
    mean_acc = float(scores_arr.mean())
    std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0

    logger.info(f"Mean accuracy over {args.runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")

    return {
        "method": "linear",
        "accuracy": mean_acc,
        "accuracy_std": std_acc,
        "num_runs": args.runs,
        "embeddings_file": str(embeddings_path),
    }


def evaluate_plankton(args: argparse.Namespace) -> None:
    tic = time.time()

    if args.gpu is True:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if args.gpu_id is not None:
        torch.cuda.set_device(args.gpu_id)

    logger.info(f"Using device {device}")
    logger.info(f"Loading Plankton dataset from {args.dataset_path}")
    dataset = Plankton(args.dataset_path)
    metadata_df = _load_plankton_metadata(dataset)
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_val, y_val = _load_embeddings_with_split(embeddings_path, metadata_df)

        result = evaluate_plankton_single(x_train, y_train, x_val, y_val, args, embeddings_path, device)
        results.append(result)

    _print_summary_table(results)

    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir.joinpath("plankton.csv")
        _write_results_csv(results, output_path)

    toc = time.time()
    logger.info(f"Plankton benchmark completed in {lib.format_duration(toc - tic)}")


def set_parser(subparsers: Any) -> None:
    subparser = subparsers.add_parser(
        "plankton",
        allow_abbrev=False,
        help="run Plankton benchmark - 50 class phytoplankton classification using linear probing",
        description="run Plankton benchmark - 50 class phytoplankton classification using linear probing",
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval plankton --embeddings "
            "results/plankton_embeddings.parquet --dataset-path ~/Datasets/plankton --dry-run\n"
            "python -m birder.eval plankton --embeddings results/plankton/*.parquet "
            "--dataset-path ~/Datasets/plankton --gpu\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to Plankton dataset root")
    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
    subparser.add_argument("--epochs", type=int, default=75, help="training epochs per run")
    subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
    subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
    subparser.add_argument(
        "--dir", type=str, default="plankton", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)


def validate_args(args: argparse.Namespace) -> None:
    if args.embeddings is None:
        raise cli.ValidationError("--embeddings is required")
    if args.dataset_path is None:
        raise cli.ValidationError("--dataset-path is required")


def main(args: argparse.Namespace) -> None:
    validate_args(args)
    evaluate_plankton(args)
```
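
The Plankton protocol trains a fresh linear probe per run (each with a different seed) and reports the mean and sample standard deviation of validation accuracy across runs. A minimal synthetic walk-through mirroring the `train_linear_probe`/`evaluate_linear_probe` calls above (the random data stands in for real parquet embeddings):

```python
# Sketch of one linear-probe run; signatures mirror the calls in plankton.py,
# the embeddings and labels here are made up.
import numpy as np
import torch

from birder.eval.methods.linear import evaluate_linear_probe
from birder.eval.methods.linear import train_linear_probe

rng = np.random.default_rng(0)
x_train = rng.normal(size=(200, 256)).astype(np.float32)
y_train = rng.integers(0, 50, size=200)  # 50 phytoplankton classes
x_val = rng.normal(size=(50, 256)).astype(np.float32)
y_val = rng.integers(0, 50, size=50)

device = torch.device("cpu")  # plankton.py selects cuda when --gpu is set
model = train_linear_probe(
    x_train, y_train, 50, device=device, epochs=75, batch_size=128, lr=1e-4, seed=0
)
y_pred, y_true = evaluate_linear_probe(model, x_val, y_val, batch_size=128, device=device)
print(f"val accuracy: {np.mean(y_pred == y_true):.4f}")
```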