birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FungiCLEF2023 benchmark using KNN for fungi species classification
|
|
3
|
+
|
|
4
|
+
Link: https://www.imageclef.org/FungiCLEF2023
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import time
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
import polars as pl
|
|
16
|
+
from rich.console import Console
|
|
17
|
+
from rich.table import Table
|
|
18
|
+
|
|
19
|
+
from birder.common import cli
|
|
20
|
+
from birder.common import lib
|
|
21
|
+
from birder.conf import settings
|
|
22
|
+
from birder.datahub.evaluation import FungiCLEF2023
|
|
23
|
+
from birder.eval._embeddings import load_embeddings
|
|
24
|
+
from birder.eval.methods.knn import evaluate_knn
|
|
25
|
+
from birder.eval.methods.simpleshot import sample_k_shot
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _print_summary_table(results: list[dict[str, Any]], k_values: list[int]) -> None:
    """Render a rich table of per-embeddings-file accuracy, one column per k value plus a run count"""

    summary = Table(show_header=True, header_style="bold dark_magenta")
    summary.add_column("FungiCLEF2023 (KNN)", style="dim")
    for k in k_values:
        summary.add_column(f"k={k}", justify="right")

    summary.add_column("Runs", justify="right")

    for entry in results:
        accuracies = entry["accuracies"]
        # One cell per k value; "-" marks a k that was never evaluated for this file
        cells = [Path(entry["embeddings_file"]).name]
        cells.extend(f"{accuracies[k]:.4f}" if accuracies.get(k) is not None else "-" for k in k_values)
        cells.append(f"{entry['num_runs']}")
        summary.add_row(*cells)

    Console().print(summary)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _write_results_csv(results: list[dict[str, Any]], k_values: list[int], output_path: Path) -> None:
    """Flatten results into one row per embeddings file and write them to output_path as CSV"""

    def as_row(result: dict[str, Any]) -> dict[str, Any]:
        # Fixed identity columns first, then one acc/std pair of columns per k value
        flat: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
            "num_runs": result["num_runs"],
        }
        for k in k_values:
            flat[f"k_{k}_acc"] = result["accuracies"].get(k)
            flat[f"k_{k}_std"] = result["accuracies_std"].get(k)

        return flat

    pl.DataFrame([as_row(result) for result in results]).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _load_fungiclef_metadata(dataset: FungiCLEF2023) -> pl.DataFrame:
    """
    Load metadata from FungiCLEF2023 CSV files

    Returns DataFrame with columns: id (filename stem), label, split (train/val/test).
    Filters out validation samples with unknown species (class_id == -1).
    Test samples have label=-1 (no ground truth available) and are excluded from evaluation.
    """

    columns = ["id", "class_id", "split"]

    train_part = (
        pl.read_csv(dataset.train_metadata_path)
        .with_columns(
            pl.col("image_path").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id"),
            pl.lit("train").alias("split"),
        )
        .select(columns)
    )

    # Validation rows with class_id == -1 are unknown species and carry no usable label
    val_part = (
        pl.read_csv(dataset.val_metadata_path)
        .filter(pl.col("class_id") >= 0)
        .with_columns(
            pl.col("filename").alias("id"),
            pl.lit("val").alias("split"),
        )
        .select(columns)
    )

    # Include test IDs so they are properly excluded when embeddings contain all samples
    test_part = (
        pl.read_csv(dataset.test_metadata_path)
        .with_columns(
            pl.col("filename").alias("id"),
            pl.lit("test").alias("split"),
            pl.lit(-1, dtype=pl.Int64).alias("class_id"),
        )
        .select(columns)
    )

    return pl.concat([train_part, val_part, test_part]).rename({"class_id": "label"})
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """
    Join embeddings with the FungiCLEF metadata and split into train/val arrays

    Test-split rows are intentionally dropped (they have no ground-truth labels).
    Returns (x_train, y_train, x_val, y_val).
    """

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    # Inner join: samples without a matching embedding are dropped (and reported)
    joined = metadata_df.join(emb_df, on="id", how="inner")
    if joined.height < metadata_df.height:
        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")

    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    labels = joined.get_column("label").to_numpy().astype(np.int_)
    split_col = np.array(joined.get_column("split").to_list())

    train_mask = split_col == "train"
    val_mask = split_col == "val"
    test_mask = split_col == "test"

    num_classes = len(np.unique(labels[train_mask]))
    logger.info(
        f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train = features[train_mask]
    y_train = labels[train_mask]
    x_val = features[val_mask]
    y_val = labels[val_mask]

    logger.info(
        f"Train: {len(y_train)} samples, Val: {len(y_val)} samples, "
        f"Test: {int(test_mask.sum())} samples (no labels, excluded)"
    )

    return (x_train, y_train, x_val, y_val)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _evaluate_single_k(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_val: npt.NDArray[np.float32],
    y_val: npt.NDArray[np.int_],
    k: int,
    num_runs: int,
    seed: int,
) -> tuple[float, float]:
    """
    Run the k-shot KNN evaluation num_runs times with different random seeds

    Each run samples k examples per class from the train split and classifies the
    validation split with a KNN using k neighbors.

    Returns (mean accuracy across runs, sample standard deviation across runs).
    Raises ValueError when num_runs < 1.
    """

    # Guard explicitly: with zero runs np.mean over an empty array would silently
    # return nan (with only a RuntimeWarning) instead of failing loudly
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    logger.info(f"Evaluating k={k} ({k}-shot sampling, KNN k={k})")

    scores: list[float] = []
    for run in range(num_runs):
        # Deterministic per-run seed so results are reproducible yet vary across runs
        run_seed = seed + run
        rng = np.random.default_rng(run_seed)

        # Sample k examples per class
        x_train_k, y_train_k = sample_k_shot(x_train, y_train, k, rng)

        # Evaluate using KNN with k neighbors
        y_pred, y_true = evaluate_knn(x_train_k, y_train_k, x_val, y_val, k=k)

        acc = float(np.mean(y_pred == y_true))
        scores.append(acc)
        logger.info(f"Run {run + 1}/{num_runs} - Accuracy: {acc:.4f}")

    scores_arr = np.array(scores)
    mean_acc = float(scores_arr.mean())
    # Sample std (ddof=1) is undefined for a single run; report 0.0 in that case
    std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0

    logger.info(f"k={k} - Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")

    return (mean_acc, std_acc)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def evaluate_fungiclef(args: argparse.Namespace) -> None:
    """Run the FungiCLEF2023 KNN benchmark over every embeddings file given on the command line"""

    tic = time.time()

    logger.info(f"Loading FungiCLEF2023 dataset from {args.dataset_path}")
    metadata_df = _load_fungiclef_metadata(FungiCLEF2023(args.dataset_path))
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    total = len(args.embeddings)
    results: list[dict[str, Any]] = []
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_val, y_val = _load_embeddings_with_split(embeddings_path, metadata_df)

        mean_by_k: dict[int, float] = {}
        std_by_k: dict[int, float] = {}
        for k in args.k:
            mean_by_k[k], std_by_k[k] = _evaluate_single_k(x_train, y_train, x_val, y_val, k, args.runs, args.seed)

        results.append(
            {
                "embeddings_file": str(embeddings_path),
                "method": "knn",
                "num_runs": args.runs,
                "accuracies": mean_by_k,
                "accuracies_std": std_by_k,
            }
        )

    _print_summary_table(results, args.k)

    # Persist a CSV under the results directory unless --dry-run was requested
    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k, output_dir.joinpath("fungiclef.csv"))

    toc = time.time()
    logger.info(f"FungiCLEF2023 benchmark completed in {lib.format_duration(toc - tic)}")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def set_parser(subparsers: Any) -> None:
    """Register the 'fungiclef' sub-command and its arguments on the given subparsers object"""

    # Same text serves as both the short help and the long description
    description = "run FungiCLEF2023 benchmark - 1,604 species classification using KNN"
    subparser = subparsers.add_parser(
        "fungiclef",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval fungiclef --embeddings "
            "results/fungiclef_embeddings.parquet --dataset-path ~/Datasets/FungiCLEF2023 --dry-run\n"
            "python -m birder.eval fungiclef --embeddings results/fungiclef/*.parquet "
            "--dataset-path ~/Datasets/FungiCLEF2023\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to FungiCLEF2023 dataset root")
    subparser.add_argument(
        "--k", type=int, nargs="+", default=[1, 3], help="k value for k-shot sampling and KNN neighbors"
    )
    subparser.add_argument("--runs", type=int, default=5, help="number of evaluation runs")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument(
        "--dir", type=str, default="fungiclef", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Ensure the required CLI arguments were supplied, raising cli.ValidationError otherwise"""

    # Checked in declaration order so the first missing argument is the one reported
    required = (
        (args.embeddings, "--embeddings is required"),
        (args.dataset_path, "--dataset-path is required"),
    )
    for value, message in required:
        if value is None:
            raise cli.ValidationError(message)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def main(args: argparse.Namespace) -> None:
    # Sub-command entry point, wired via subparser.set_defaults(func=main):
    # validate parsed arguments, then run the benchmark
    validate_args(args)
    evaluate_fungiclef(args)
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NABirds benchmark using KNN for bird species classification
|
|
3
|
+
|
|
4
|
+
Website: https://dl.allaboutbirds.org/nabirds
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import time
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
import polars as pl
|
|
16
|
+
from rich.console import Console
|
|
17
|
+
from rich.table import Table
|
|
18
|
+
|
|
19
|
+
from birder.common import cli
|
|
20
|
+
from birder.common import lib
|
|
21
|
+
from birder.conf import settings
|
|
22
|
+
from birder.datahub.evaluation import NABirds
|
|
23
|
+
from birder.eval._embeddings import load_embeddings
|
|
24
|
+
from birder.eval.methods.knn import evaluate_knn
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _print_summary_table(results: list[dict[str, Any]], k_values: list[int]) -> None:
    """Render a rich table of KNN accuracy per embeddings file, one column per k value"""

    summary = Table(show_header=True, header_style="bold dark_magenta")
    summary.add_column("NABirds (KNN)", style="dim")
    for k in k_values:
        summary.add_column(f"k={k}", justify="right")

    for entry in results:
        accuracies = entry["accuracies"]
        # One cell per k value; "-" marks a k that was never evaluated for this file
        cells = [Path(entry["embeddings_file"]).name]
        cells.extend(f"{accuracies[k]:.4f}" if accuracies.get(k) is not None else "-" for k in k_values)
        summary.add_row(*cells)

    Console().print(summary)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _write_results_csv(results: list[dict[str, Any]], k_values: list[int], output_path: Path) -> None:
    """Flatten results into one CSV row per embeddings file and write them to output_path"""

    def as_row(result: dict[str, Any]) -> dict[str, Any]:
        # Fixed identity columns first, then one accuracy column per k value
        flat: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
        }
        for k in k_values:
            flat[f"k_{k}_acc"] = result["accuracies"].get(k)

        return flat

    pl.DataFrame([as_row(result) for result in results]).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _load_nabirds_metadata(dataset: NABirds) -> pl.DataFrame:
    """
    Load and join the NABirds annotation files

    Returns a DataFrame with columns: id (filename stem), class_id, class_name,
    is_train, plus a contiguous 'label' column (0 to num_classes-1).
    """

    def read_pairs(path: Any, columns: list[str]) -> pl.DataFrame:
        # All NABirds annotation files are space-separated, header-less two-column files
        return pl.read_csv(path, separator=" ", has_header=False, new_columns=columns)

    images_df = read_pairs(dataset.images_path, ["image_id", "filepath"]).with_columns(
        pl.col("filepath").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id")
    )
    labels_df = read_pairs(dataset.labels_path, ["image_id", "class_id"])
    classes_df = read_pairs(dataset.classes_path, ["class_id", "class_name"])
    split_df = read_pairs(dataset.train_test_split_path, ["image_id", "is_train"])

    metadata_df = (
        images_df.join(labels_df, on="image_id")
        .join(classes_df, on="class_id")
        .join(split_df, on="image_id")
        .select(["id", "class_id", "class_name", "is_train"])
    )

    # Create contiguous label indices (0 to num_classes-1)
    sorted_class_ids = metadata_df.get_column("class_id").unique().sort().to_list()
    class_id_to_label = {cid: idx for idx, cid in enumerate(sorted_class_ids)}

    return metadata_df.with_columns(pl.col("class_id").replace_strict(class_id_to_label).alias("label"))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """
    Join embeddings with the NABirds metadata and split into train/test arrays

    Returns (x_train, y_train, x_test, y_test).
    """

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    # Inner join: samples without a matching embedding are dropped (and reported)
    joined = metadata_df.join(emb_df, on="id", how="inner")
    if joined.height < metadata_df.height:
        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")

    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    labels = joined.get_column("label").to_numpy().astype(np.int_)
    train_mask = joined.get_column("is_train").to_numpy().astype(bool)

    # Labels were remapped to a contiguous range, so max+1 is the class count
    num_classes = labels.max() + 1
    logger.info(
        f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train = features[train_mask]
    y_train = labels[train_mask]
    x_test = features[~train_mask]
    y_test = labels[~train_mask]

    logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")

    return (x_train, y_train, x_test, y_test)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def evaluate_nabirds(args: argparse.Namespace) -> None:
    """Run the NABirds KNN benchmark over every embeddings file given on the command line"""

    tic = time.time()

    logger.info(f"Loading NABirds dataset from {args.dataset_path}")
    metadata_df = _load_nabirds_metadata(NABirds(args.dataset_path))
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    total = len(args.embeddings)
    results: list[dict[str, Any]] = []
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)

        acc_by_k: dict[int, float] = {}
        for k in args.k:
            logger.info(f"Evaluating KNN with k={k}")
            y_pred, y_true = evaluate_knn(x_train, y_train, x_test, y_test, k=k)
            acc = float(np.mean(y_pred == y_true))
            acc_by_k[k] = acc
            logger.info(f"k={k} - Accuracy: {acc:.4f}")

        results.append(
            {
                "embeddings_file": str(embeddings_path),
                "method": "knn",
                "accuracies": acc_by_k,
            }
        )

    _print_summary_table(results, args.k)

    # Persist a CSV under the results directory unless --dry-run was requested
    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k, output_dir.joinpath("nabirds.csv"))

    toc = time.time()
    logger.info(f"NABirds benchmark completed in {lib.format_duration(toc - tic)}")
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def set_parser(subparsers: Any) -> None:
    """Register the 'nabirds' sub-command and its arguments on the given subparsers object"""

    # Same text serves as both the short help and the long description
    description = "run NABirds benchmark - 555 class classification using KNN"
    subparser = subparsers.add_parser(
        "nabirds",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval nabirds --embeddings "
            "results/vit_b16_224px_crop1.0_48562_embeddings.parquet "
            "--dataset-path ~/Datasets/nabirds --dry-run\n"
            "python -m birder.eval nabirds --embeddings results/nabirds/*.parquet --dataset-path ~/Datasets/nabirds\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to NABirds dataset root")
    subparser.add_argument(
        "--k", type=int, nargs="+", default=[10, 20, 100], help="number of nearest neighbors for KNN"
    )
    subparser.add_argument(
        "--dir", type=str, default="nabirds", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Ensure the required CLI arguments were supplied, raising cli.ValidationError otherwise"""

    # Checked in declaration order so the first missing argument is the one reported
    required = (
        (args.embeddings, "--embeddings is required"),
        (args.dataset_path, "--dataset-path is required"),
    )
    for value, message in required:
        if value is None:
            raise cli.ValidationError(message)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def main(args: argparse.Namespace) -> None:
    # Sub-command entry point, wired via subparser.set_defaults(func=main):
    # validate parsed arguments, then run the benchmark
    validate_args(args)
    evaluate_nabirds(args)
|