birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FishNet benchmark using MLP probe for multi-label trait prediction
|
|
3
|
+
|
|
4
|
+
Paper "FishNet: A Large-scale Dataset and Benchmark for Fish Recognition, Detection, and Functional Trait Prediction"
|
|
5
|
+
https://fishnet-2023.github.io/
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import numpy.typing as npt
|
|
16
|
+
import polars as pl
|
|
17
|
+
import torch
|
|
18
|
+
from rich.console import Console
|
|
19
|
+
from rich.table import Table
|
|
20
|
+
from sklearn.metrics import f1_score
|
|
21
|
+
|
|
22
|
+
from birder.common import cli
|
|
23
|
+
from birder.common import lib
|
|
24
|
+
from birder.conf import settings
|
|
25
|
+
from birder.datahub.evaluation import FishNet
|
|
26
|
+
from birder.eval._embeddings import load_embeddings
|
|
27
|
+
from birder.eval.methods.mlp import evaluate_mlp
|
|
28
|
+
from birder.eval.methods.mlp import train_mlp
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _print_summary_table(results: list[dict[str, Any]]) -> None:
    """Render a one-row-per-embeddings-file summary of the FishNet MLP results to the console."""

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("FishNet (MLP)", style="dim")
    for header in ("Macro F1", "Std", "Exact Match", "Std", "Runs"):
        table.add_column(header, justify="right")

    for entry in results:
        table.add_row(
            Path(entry["embeddings_file"]).name,
            f"{entry['macro_f1']:.4f}",
            f"{entry['macro_f1_std']:.4f}",
            f"{entry['exact_match_acc']:.4f}",
            f"{entry['exact_match_acc_std']:.4f}",
            f"{entry['num_runs']}",
        )

    Console().print(table)
|
57
|
+
|
|
58
|
+
def _write_results_csv(results: list[dict[str, Any]], trait_names: list[str], output_path: Path) -> None:
    """Write one CSV row per result, with the scalar metrics followed by a per-trait F1 column each."""

    scalar_keys = (
        "embeddings_file",
        "method",
        "macro_f1",
        "macro_f1_std",
        "exact_match_acc",
        "exact_match_acc_std",
        "num_runs",
    )
    rows: list[dict[str, Any]] = [
        {
            **{key: result[key] for key in scalar_keys},
            # Missing traits become null cells rather than raising
            **{f"f1_{trait}": result["per_trait_f1"].get(trait) for trait in trait_names},
        }
        for result in results
    ]

    pl.DataFrame(rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _load_fishnet_data(csv_path: Path, trait_columns: list[str]) -> pl.DataFrame:
    """
    Load FishNet CSV and prepare metadata

    Derives an "id" column from the image path (filename without extension),
    encodes FeedingPath as a binary label and drops any row with a null trait.

    Returns DataFrame with columns: id, trait labels (0/1)
    """

    df = pl.read_csv(csv_path)
    df = df.with_columns(
        pl.col("image")
        .str.extract(r"([^/]+)$")  # Get filename (last path segment)
        .str.replace(r"\.[^.]+$", "")  # Remove extension
        .alias("id")
    )

    # Encode FeedingPath: benthic=0, pelagic=1 (anything else becomes null and is dropped below)
    df = df.with_columns(
        pl.when(pl.col("FeedingPath") == "pelagic")
        .then(pl.lit(1))
        .when(pl.col("FeedingPath") == "benthic")
        .then(pl.lit(0))
        .otherwise(pl.lit(None))
        .alias("FeedingPath_encoded")
    )

    # Select relevant columns
    other_traits = [t for t in trait_columns if t != "FeedingPath"]
    select_cols = ["id", "FeedingPath_encoded"] + other_traits
    df = df.select(select_cols)

    # Rename FeedingPath_encoded back to FeedingPath
    df = df.rename({"FeedingPath_encoded": "FeedingPath"})

    # Filter rows with any null trait values in a single pass
    # (replaces a per-trait filter loop; behavior is identical)
    df = df.drop_nulls(subset=trait_columns)

    return df
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _load_embeddings_with_labels(
    embeddings_path: str, train_df: pl.DataFrame, test_df: pl.DataFrame, trait_columns: list[str]
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """Join the embeddings file onto the train/test metadata and return (x_train, y_train, x_test, y_test)."""

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    def _join_split(split_df: pl.DataFrame, split_name: str) -> pl.DataFrame:
        # Inner-join on "id"; warn about metadata rows without a matching embedding
        joined = split_df.join(emb_df, on="id", how="inner")
        dropped = split_df.height - joined.height
        if dropped > 0:
            logger.warning(f"{split_name}: dropped {dropped} samples (missing embeddings)")
        return joined

    def _to_arrays(joined: pl.DataFrame) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
        labels = joined.select(trait_columns).to_numpy().astype(np.float32)
        return (features, labels)

    x_train, y_train = _to_arrays(_join_split(train_df, "Train"))
    x_test, y_test = _to_arrays(_join_split(test_df, "Test"))

    logger.info(f"Train: {x_train.shape[0]} samples, Test: {x_test.shape[0]} samples")
    logger.info(f"Features: {x_train.shape[1]} dims, Traits: {len(trait_columns)}")

    return (x_train, y_train, x_test, y_test)
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# pylint: disable=too-many-locals
def evaluate_fishnet_single(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.float32],
    x_test: npt.NDArray[np.float32],
    y_test: npt.NDArray[np.float32],
    trait_columns: list[str],
    args: argparse.Namespace,
    embeddings_path: str,
    device: torch.device,
) -> dict[str, Any]:
    """
    Train and evaluate the MLP probe args.runs times (seed incremented per run)
    and return averaged macro F1, exact-match accuracy and per-trait F1.
    """

    def _mean_std(values: list[float]) -> tuple[float, float]:
        # Sample std (ddof=1) only makes sense with more than one run
        arr = np.array(values)
        std = float(arr.std(ddof=1)) if len(values) > 1 else 0.0
        return (float(arr.mean()), std)

    num_classes = len(trait_columns)

    macro_f1_scores: list[float] = []
    exact_scores: list[float] = []
    per_trait_runs: list[dict[str, float]] = []

    for run_idx in range(args.runs):
        run_seed = args.seed + run_idx
        logger.info(f"Run {run_idx + 1}/{args.runs} (seed={run_seed})")

        # Train MLP probe on the frozen embeddings
        model = train_mlp(
            x_train,
            y_train,
            num_classes=num_classes,
            device=device,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
            seed=run_seed,
        )

        # Evaluate on the held-out test split
        y_pred, y_true, macro_f1 = evaluate_mlp(model, x_test, y_test, batch_size=args.batch_size, device=device)
        exact_match_acc = float(np.mean(np.all(y_pred == y_true, axis=1)))
        macro_f1_scores.append(macro_f1)
        exact_scores.append(exact_match_acc)

        # Per-trait (binary) F1 for this run
        per_trait_runs.append(
            {
                trait: float(f1_score(y_true[:, i], y_pred[:, i], average="binary", zero_division=0.0))
                for i, trait in enumerate(trait_columns)
            }
        )
        logger.info(f"Run {run_idx + 1}/{args.runs} - Macro F1: {macro_f1:.4f}, Exact Match: {exact_match_acc:.4f}")

    # Aggregate across runs
    mean_f1, std_f1 = _mean_std(macro_f1_scores)
    mean_exact, std_exact = _mean_std(exact_scores)
    avg_per_trait_f1: dict[str, float] = {
        trait: float(np.mean([run_f1[trait] for run_f1 in per_trait_runs])) for trait in trait_columns
    }

    logger.info(f"Mean Macro F1 over {args.runs} runs: {mean_f1:.4f} +/- {std_f1:.4f} (std)")
    logger.info(f"Mean Exact Match over {args.runs} runs: {mean_exact:.4f} +/- {std_exact:.4f} (std)")
    for trait, f1 in avg_per_trait_f1.items():
        logger.info(f"  {trait}: {f1:.4f}")

    return {
        "method": "mlp",
        "macro_f1": mean_f1,
        "macro_f1_std": std_f1,
        "exact_match_acc": mean_exact,
        "exact_match_acc_std": std_exact,
        "num_runs": args.runs,
        "per_trait_f1": avg_per_trait_f1,
        "embeddings_file": str(embeddings_path),
    }
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def evaluate_fishnet(args: argparse.Namespace) -> None:
    """
    Run the FishNet MLP-probe benchmark for every embeddings file in args.embeddings.

    Loads the dataset CSVs once, joins each embeddings file against them, trains and
    evaluates the probe, prints a summary table and (unless --dry-run) writes the
    results CSV under settings.RESULTS_DIR / args.dir.
    """

    tic = time.time()

    if args.gpu is True:
        device = torch.device("cuda")
        # Fix: only select a specific GPU when GPU execution was actually requested.
        # Previously torch.cuda.set_device() ran unconditionally, so passing
        # --gpu-id without --gpu touched CUDA even on CPU-only hosts.
        if args.gpu_id is not None:
            torch.cuda.set_device(args.gpu_id)
    else:
        device = torch.device("cpu")

    logger.info(f"Using device {device}")
    logger.info(f"Loading FishNet dataset from {args.dataset_path}")
    dataset = FishNet(args.dataset_path)
    trait_columns = dataset.trait_columns

    train_df = _load_fishnet_data(dataset.train_csv, trait_columns)
    test_df = _load_fishnet_data(dataset.test_csv, trait_columns)
    logger.info(f"Train samples: {train_df.height}, Test samples: {test_df.height}")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_test, y_test = _load_embeddings_with_labels(
            embeddings_path, train_df, test_df, trait_columns
        )

        result = evaluate_fishnet_single(x_train, y_train, x_test, y_test, trait_columns, args, embeddings_path, device)
        results.append(result)

    _print_summary_table(results)

    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir.joinpath("fishnet.csv")
        _write_results_csv(results, trait_columns, output_path)

    toc = time.time()
    logger.info(f"FishNet benchmark completed in {lib.format_duration(toc - tic)}")
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def set_parser(subparsers: Any) -> None:
    """Register the "fishnet" sub-command and its arguments on the given subparsers object."""

    # Same text is used for both the short help and the long description
    summary = "run FishNet benchmark - 9 trait multi-label classification using MLP probe"
    subparser = subparsers.add_parser(
        "fishnet",
        allow_abbrev=False,
        help=summary,
        description=summary,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval fishnet --embeddings "
            "results/vit_b16_224px_embeddings.parquet "
            "--dataset-path ~/Datasets/fishnet --dry-run\n"
            "python -m birder.eval fishnet --embeddings results/fishnet/*.parquet "
            "--dataset-path ~/Datasets/fishnet --gpu --gpu-id 1\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to FishNet dataset root")
    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
    subparser.add_argument("--epochs", type=int, default=100, help="training epochs per run")
    subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
    subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    subparser.add_argument("--hidden-dim", type=int, default=512, help="MLP hidden layer dimension")
    subparser.add_argument("--dropout", type=float, default=0.5, help="dropout probability")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
    subparser.add_argument(
        "--dir", type=str, default="fishnet", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Raise cli.ValidationError if a required argument was not supplied."""

    required = (
        ("embeddings", "--embeddings is required"),
        ("dataset_path", "--dataset-path is required"),
    )
    for attr, message in required:
        if getattr(args, attr) is None:
            raise cli.ValidationError(message)
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def main(args: argparse.Namespace) -> None:
    """CLI entry point: validate the arguments, then run the FishNet benchmark."""

    validate_args(args)
    evaluate_fishnet(args)
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Flowers102 benchmark using SimpleShot for flower species classification
|
|
3
|
+
|
|
4
|
+
Paper "Automated Flower Classification over a Large Number of Classes"
|
|
5
|
+
https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import numpy.typing as npt
|
|
16
|
+
import polars as pl
|
|
17
|
+
from rich.console import Console
|
|
18
|
+
from rich.table import Table
|
|
19
|
+
from torchvision.datasets import ImageFolder
|
|
20
|
+
|
|
21
|
+
from birder.common import cli
|
|
22
|
+
from birder.common import lib
|
|
23
|
+
from birder.conf import settings
|
|
24
|
+
from birder.eval._embeddings import load_embeddings
|
|
25
|
+
from birder.eval.methods.simpleshot import evaluate_simpleshot
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _print_summary_table(results: list[dict[str, Any]]) -> None:
    """Render a one-row-per-embeddings-file summary of the Flowers102 SimpleShot results."""

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("Flowers102 (SimpleShot)", style="dim")
    for header in ("Val Acc", "Test Acc"):
        table.add_column(header, justify="right")

    for entry in results:
        table.add_row(
            Path(entry["embeddings_file"]).name,
            f"{entry['val_accuracy']:.4f}",
            f"{entry['test_accuracy']:.4f}",
        )

    Console().print(table)
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
    """Write the accuracy summary (one row per embeddings file) as a CSV file."""

    keys = ("embeddings_file", "method", "val_accuracy", "test_accuracy")
    rows: list[dict[str, Any]] = [{key: result[key] for key in keys} for result in results]

    pl.DataFrame(rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _load_flowers102_metadata(dataset_path: Path) -> pl.DataFrame:
    """
    Scan the training/validation/testing ImageFolder splits under dataset_path
    and return a frame with columns: id (file stem), label, split.
    Splits whose directory does not exist are skipped.
    """

    records: list[dict[str, Any]] = []
    for split in ("training", "validation", "testing"):
        split_dir = dataset_path.joinpath(split)
        if not split_dir.exists():
            continue

        records.extend(
            {"id": Path(path).stem, "label": label, "split": split}
            for path, label in ImageFolder(str(split_dir)).samples
        )

    return pl.DataFrame(records)
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _load_embeddings_with_split(embeddings_path: str, metadata_df: pl.DataFrame) -> tuple[
    npt.NDArray[np.float32],
    npt.NDArray[np.int_],
    npt.NDArray[np.float32],
    npt.NDArray[np.int_],
    npt.NDArray[np.float32],
    npt.NDArray[np.int_],
]:
    """Join the embeddings file onto the metadata and return (x, y) arrays per train/val/test split."""

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    joined = metadata_df.join(emb_df, on="id", how="inner")
    missing = metadata_df.height - joined.height
    if missing > 0:
        logger.warning(f"Join dropped {missing} samples (missing embeddings)")

    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    all_labels = joined.get_column("label").to_numpy().astype(np.int_)
    split_arr = np.array(joined.get_column("split").to_list())

    num_classes = all_labels.max() + 1
    logger.info(
        f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
    )

    # Slice features/labels by split name via boolean masks
    subsets = {}
    for name in ("training", "validation", "testing"):
        mask = split_arr == name
        subsets[name] = (all_features[mask], all_labels[mask])

    x_train, y_train = subsets["training"]
    x_val, y_val = subsets["validation"]
    x_test, y_test = subsets["testing"]

    logger.info(f"Train: {len(y_train)}, Val: {len(y_val)}, Test: {len(y_test)} samples")

    return (x_train, y_train, x_val, y_val, x_test, y_test)
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def evaluate_flowers102_single(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_val: npt.NDArray[np.float32],
    y_val: npt.NDArray[np.int_],
    x_test: npt.NDArray[np.float32],
    y_test: npt.NDArray[np.int_],
    embeddings_path: str,
) -> dict[str, Any]:
    """Run SimpleShot against the validation and test splits and return both accuracies."""

    def _split_accuracy(x_eval: npt.NDArray[np.float32], y_eval: npt.NDArray[np.int_]) -> float:
        y_pred, y_true = evaluate_simpleshot(x_train, y_train, x_eval, y_eval)
        return float(np.mean(y_pred == y_true))

    val_acc = _split_accuracy(x_val, y_val)
    logger.info(f"Validation accuracy: {val_acc:.4f}")

    test_acc = _split_accuracy(x_test, y_test)
    logger.info(f"Test accuracy: {test_acc:.4f}")

    return {
        "method": "simpleshot",
        "val_accuracy": val_acc,
        "test_accuracy": test_acc,
        "embeddings_file": str(embeddings_path),
    }
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def evaluate_flowers102(args: argparse.Namespace) -> None:
    """
    Run the Flowers102 SimpleShot benchmark for every embeddings file in args.embeddings,
    print a summary table, and (unless --dry-run) write the results CSV.
    """

    tic = time.time()

    logger.info(f"Loading Flowers102 dataset from {args.dataset_path}")
    metadata_df = _load_flowers102_metadata(Path(args.dataset_path))
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    total = len(args.embeddings)
    results: list[dict[str, Any]] = []
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        splits = _load_embeddings_with_split(embeddings_path, metadata_df)
        results.append(evaluate_flowers102_single(*splits, embeddings_path))

    _print_summary_table(results)

    if args.dry_run is False:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, output_dir.joinpath("flowers102.csv"))

    toc = time.time()
    logger.info(f"Flowers102 benchmark completed in {lib.format_duration(toc - tic)}")
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def set_parser(subparsers: Any) -> None:
    """Register the "flowers102" sub-command and its arguments on the given subparsers object."""

    # Same text is used for both the short help and the long description
    summary = "run Flowers102 benchmark - 102 class classification using SimpleShot"
    subparser = subparsers.add_parser(
        "flowers102",
        allow_abbrev=False,
        help=summary,
        description=summary,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval flowers102 --embeddings "
            "results/flowers102_embeddings.parquet --dataset-path ~/Datasets/Flowers102 --dry-run\n"
            "python -m birder.eval flowers102 --embeddings results/flowers102/*.parquet "
            "--dataset-path ~/Datasets/Flowers102\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to Flowers102 dataset root")
    subparser.add_argument(
        "--dir", type=str, default="flowers102", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Raise cli.ValidationError if a required argument was not supplied."""

    required = (
        ("embeddings", "--embeddings is required"),
        ("dataset_path", "--dataset-path is required"),
    )
    for attr, message in required:
        if getattr(args, attr) is None:
            raise cli.ValidationError(message)
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def main(args: argparse.Namespace) -> None:
    """CLI entry point: validate the arguments, then run the Flowers102 benchmark."""

    validate_args(args)
    evaluate_flowers102(args)