birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PlantDoc benchmark using SimpleShot for plant disease classification
|
|
3
|
+
|
|
4
|
+
Paper "PlantDoc: A Dataset for Visual Plant Disease Detection",
|
|
5
|
+
https://arxiv.org/abs/1911.10317
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import time
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import numpy.typing as npt
|
|
17
|
+
import polars as pl
|
|
18
|
+
from rich.console import Console
|
|
19
|
+
from rich.table import Table
|
|
20
|
+
|
|
21
|
+
from birder.common import cli
|
|
22
|
+
from birder.common import lib
|
|
23
|
+
from birder.conf import settings
|
|
24
|
+
from birder.data.datasets.directory import make_image_dataset
|
|
25
|
+
from birder.datahub.evaluation import PlantDoc
|
|
26
|
+
from birder.eval._embeddings import load_embeddings
|
|
27
|
+
from birder.eval.methods.simpleshot import evaluate_simpleshot
|
|
28
|
+
from birder.eval.methods.simpleshot import sample_k_shot
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _print_summary_table(results: list[dict[str, Any]], k_shots: list[int]) -> None:
    """Render a rich table of SimpleShot accuracies, one row per embeddings file."""

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("PlantDoc (SimpleShot)", style="dim")
    for k in k_shots:
        table.add_column(f"{k}-shot", justify="right")

    table.add_column("Runs", justify="right")

    for result in results:
        cells = [Path(result["embeddings_file"]).name]
        per_k = result["accuracies"]
        for k in k_shots:
            acc = per_k.get(k)
            if acc is None:
                # No accuracy recorded for this k-shot setting
                cells.append("-")
            else:
                cells.append(f"{acc:.4f}")

        cells.append(f"{result['num_runs']}")
        table.add_row(*cells)

    Console().print(table)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _write_results_csv(results: list[dict[str, Any]], k_shots: list[int], output_path: Path) -> None:
    """Flatten per-file results into one CSV row each (acc/std column per k) and save."""

    flat_rows: list[dict[str, Any]] = []
    for result in results:
        record: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
            "num_runs": result["num_runs"],
        }
        # One accuracy and one std column per requested k-shot value
        for k in k_shots:
            record[f"{k}_shot_acc"] = result["accuracies"].get(k)
            record[f"{k}_shot_std"] = result["accuracies_std"].get(k)

        flat_rows.append(record)

    pl.DataFrame(flat_rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _load_plantdoc_metadata(dataset: PlantDoc) -> pl.DataFrame:
    """
    Load metadata using make_image_dataset with a fixed class_to_idx

    A unified class_to_idx is built from the union of class directories in both
    splits, so labels stay consistent even when a class appears in only one split.

    Returns DataFrame with columns: id (filename stem), label, split
    """

    # Build unified class_to_idx from the union of both splits
    all_classes: set[str] = set()
    for split_dir in [dataset.train_dir, dataset.test_dir]:
        # Close the scandir iterator deterministically instead of relying on GC
        # (avoids a ResourceWarning / leaked directory handle)
        with os.scandir(str(split_dir)) as dir_iter:
            all_classes.update(entry.name for entry in dir_iter if entry.is_dir())

    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(sorted(all_classes))}

    rows: list[dict[str, Any]] = []
    for split, split_dir in [("train", dataset.train_dir), ("test", dataset.test_dir)]:
        image_dataset = make_image_dataset([str(split_dir)], class_to_idx)
        for i in range(len(image_dataset)):
            # paths are stored as bytes; stems are used as join keys against embeddings
            path = image_dataset.paths[i].decode("utf-8")
            label = image_dataset.labels[i].item()
            rows.append({"id": Path(path).stem, "label": label, "split": split})

    return pl.DataFrame(rows)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """Join stored embeddings with dataset metadata and return (x_train, y_train, x_test, y_test)."""
    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    # Inner join: metadata rows without a matching embedding are dropped
    joined = metadata_df.join(emb_df, on="id", how="inner")
    if joined.height < metadata_df.height:
        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")

    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    all_labels = joined.get_column("label").to_numpy().astype(np.int_)
    split_tags = joined.get_column("split").to_list()

    is_train = np.array([tag == "train" for tag in split_tags], dtype=bool)
    is_test = np.array([tag == "test" for tag in split_tags], dtype=bool)

    num_classes = all_labels.max() + 1
    logger.info(
        f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train, y_train = all_features[is_train], all_labels[is_train]
    x_test, y_test = all_features[is_test], all_labels[is_test]

    logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")

    return (x_train, y_train, x_test, y_test)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def evaluate_plantdoc_single(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_test: npt.NDArray[np.float32],
    y_test: npt.NDArray[np.int_],
    k_shot: int,
    num_runs: int,
    seed: int,
    embeddings_path: str,
) -> dict[str, Any]:
    """Run SimpleShot num_runs times at a fixed k and return mean/std accuracy."""
    logger.info(f"Evaluating {k_shot}-shot")

    run_accuracies: list[float] = []
    for run in range(num_runs):
        # Each run gets its own seeded generator for reproducibility
        rng = np.random.default_rng(seed + run)

        # Draw a fresh k-shot support set per run
        support_x, support_y = sample_k_shot(x_train, y_train, k_shot, rng)

        # Nearest-centroid (SimpleShot) evaluation on the full test split
        y_pred, y_true = evaluate_simpleshot(support_x, support_y, x_test, y_test)

        acc = float(np.mean(y_pred == y_true))
        run_accuracies.append(acc)
        logger.info(f"Run {run + 1}/{num_runs} - Accuracy: {acc:.4f}")

    scores_arr = np.array(run_accuracies)
    mean_acc = float(scores_arr.mean())
    # Sample std (ddof=1) needs at least two runs; report 0.0 otherwise
    std_acc = float(scores_arr.std(ddof=1)) if len(run_accuracies) > 1 else 0.0

    logger.info(f"Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")

    return {
        "method": "simpleshot",
        "k_shot": k_shot,
        "accuracy": mean_acc,
        "accuracy_std": std_acc,
        "num_runs": num_runs,
        "embeddings_file": str(embeddings_path),
    }
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def evaluate_plantdoc(args: argparse.Namespace) -> None:
    """Run the PlantDoc SimpleShot benchmark over every embeddings file given on the CLI."""
    tic = time.time()

    logger.info(f"Loading PlantDoc dataset from {args.dataset_path}")
    dataset = PlantDoc(args.dataset_path)
    metadata_df = _load_plantdoc_metadata(dataset)
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)

        # Accumulate per-k accuracies directly into this file's result record
        file_result: dict[str, Any] = {
            "embeddings_file": str(embeddings_path),
            "method": "simpleshot",
            "num_runs": args.runs,
            "accuracies": {},
            "accuracies_std": {},
        }
        for k_shot in args.k_shot:
            single_result = evaluate_plantdoc_single(
                x_train, y_train, x_test, y_test, k_shot, args.runs, args.seed, embeddings_path
            )
            file_result["accuracies"][k_shot] = single_result["accuracy"]
            file_result["accuracies_std"][k_shot] = single_result["accuracy_std"]

        results.append(file_result)

    _print_summary_table(results, args.k_shot)

    if not args.dry_run:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k_shot, output_dir.joinpath("plantdoc.csv"))

    toc = time.time()
    logger.info(f"PlantDoc benchmark completed in {lib.format_duration(toc - tic)}")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def set_parser(subparsers: Any) -> None:
    """Register the 'plantdoc' sub-command and its arguments on the eval CLI."""
    # help and description are intentionally identical
    description = "run PlantDoc benchmark - 27 class plant disease classification using SimpleShot"
    subparser = subparsers.add_parser(
        "plantdoc",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval plantdoc --embeddings "
            "results/plantdoc_embeddings.parquet --dataset-path ~/Datasets/PlantDoc --dry-run\n"
            "python -m birder.eval plantdoc --embeddings results/plantdoc/*.parquet "
            "--dataset-path ~/Datasets/PlantDoc\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to PlantDoc dataset root")
    subparser.add_argument(
        "--k-shot", type=int, nargs="+", default=[2, 5], help="number of examples per class for few-shot learning"
    )
    subparser.add_argument("--runs", type=int, default=5, help="number of evaluation runs")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument(
        "--dir", type=str, default="plantdoc", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Fail fast with a ValidationError when a required CLI argument is missing."""
    for attr_name, flag in (("embeddings", "--embeddings"), ("dataset_path", "--dataset-path")):
        if getattr(args, attr_name) is None:
            raise cli.ValidationError(f"{flag} is required")
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def main(args: argparse.Namespace) -> None:
    """CLI entry point (wired via subparser.set_defaults): validate arguments, then run the benchmark."""
    validate_args(args)
    evaluate_plantdoc(args)
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PlantNet-300K benchmark using SimpleShot for plant species classification
|
|
3
|
+
|
|
4
|
+
Paper "Pl@ntNet-300K: a plant image dataset with high label ambiguity and a long-tailed distribution"
|
|
5
|
+
https://openreview.net/forum?id=eLYinD0TtIt
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import numpy.typing as npt
|
|
16
|
+
import polars as pl
|
|
17
|
+
from rich.console import Console
|
|
18
|
+
from rich.table import Table
|
|
19
|
+
from torchvision.datasets import ImageFolder
|
|
20
|
+
|
|
21
|
+
from birder.common import cli
|
|
22
|
+
from birder.common import lib
|
|
23
|
+
from birder.conf import settings
|
|
24
|
+
from birder.datahub.evaluation import PlantNet
|
|
25
|
+
from birder.eval._embeddings import load_embeddings
|
|
26
|
+
from birder.eval.methods.simpleshot import evaluate_simpleshot
|
|
27
|
+
from birder.eval.methods.simpleshot import sample_k_shot
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _print_summary_table(results: list[dict[str, Any]], k_shots: list[int]) -> None:
    """Print a summary table: one column per k-shot setting, one row per embeddings file."""

    table = Table(show_header=True, header_style="bold dark_magenta")
    table.add_column("PlantNet (SimpleShot)", style="dim")
    for k in k_shots:
        table.add_column(f"{k}-shot", justify="right")

    table.add_column("Runs", justify="right")

    for result in results:
        acc_cells = []
        for k in k_shots:
            acc = result["accuracies"].get(k)
            # "-" marks a k-shot setting with no recorded accuracy
            acc_cells.append("-" if acc is None else f"{acc:.4f}")

        name_cell = Path(result["embeddings_file"]).name
        table.add_row(name_cell, *acc_cells, f"{result['num_runs']}")

    console = Console()
    console.print(table)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _write_results_csv(results: list[dict[str, Any]], k_shots: list[int], output_path: Path) -> None:
    """Write one CSV row per embeddings file with accuracy/std columns per k-shot value."""

    rows: list[dict[str, Any]] = []
    for result in results:
        # Per-k columns built first, then merged behind the fixed columns
        per_k: dict[str, Any] = {}
        for k in k_shots:
            per_k[f"{k}_shot_acc"] = result["accuracies"].get(k)
            per_k[f"{k}_shot_std"] = result["accuracies_std"].get(k)

        rows.append(
            {
                "embeddings_file": result["embeddings_file"],
                "method": result["method"],
                "num_runs": result["num_runs"],
                **per_k,
            }
        )

    pl.DataFrame(rows).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _load_plantnet_metadata(dataset: PlantNet) -> pl.DataFrame:
    """
    Load metadata from ImageFolder structure

    ImageFolder assigns class indices per directory tree, independently for each
    split. The mapping is verified to be identical across splits; otherwise labels
    from different splits would not be comparable.

    Returns DataFrame with columns: id (filename stem), label, split

    Raises RuntimeError if the splits do not share the same class directories.
    """

    rows: list[dict[str, Any]] = []
    reference_class_to_idx: dict[str, int] | None = None
    for split, split_dir in [("train", dataset.train_dir), ("val", dataset.val_dir), ("test", dataset.test_dir)]:
        # Some releases may ship without every split
        if not split_dir.exists():
            continue

        image_dataset = ImageFolder(str(split_dir))

        # Guard against silently inconsistent labels between splits
        if reference_class_to_idx is None:
            reference_class_to_idx = image_dataset.class_to_idx
        elif image_dataset.class_to_idx != reference_class_to_idx:
            raise RuntimeError(f"Class directories of split '{split}' do not match the other splits")

        for path, label in image_dataset.samples:
            rows.append({"id": Path(path).stem, "label": label, "split": split})

    return pl.DataFrame(rows)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """
    Join embeddings with metadata and return (x_train, y_train, x_test, y_test).

    Rows tagged "val" survive the join but are excluded from both returned splits.
    """
    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})

    # Inner join: metadata rows without a matching embedding are dropped
    joined = metadata_df.join(emb_df, on="id", how="inner")
    if joined.height < metadata_df.height:
        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")

    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    all_labels = joined.get_column("label").to_numpy().astype(np.int_)
    split_arr = np.asarray(joined.get_column("split").to_list())

    train_mask = split_arr == "train"
    test_mask = split_arr == "test"

    num_classes = all_labels.max() + 1
    logger.info(
        f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train, y_train = all_features[train_mask], all_labels[train_mask]
    x_test, y_test = all_features[test_mask], all_labels[test_mask]

    logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")

    return (x_train, y_train, x_test, y_test)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def evaluate_plantnet_single(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_test: npt.NDArray[np.float32],
    y_test: npt.NDArray[np.int_],
    k_shot: int,
    num_runs: int,
    seed: int,
    embeddings_path: str,
) -> dict[str, Any]:
    """Run SimpleShot num_runs times at a fixed k and return mean/std accuracy."""
    logger.info(f"Evaluating {k_shot}-shot")

    run_accuracies: list[float] = []
    for run in range(num_runs):
        # Per-run seeded generator keeps runs reproducible yet independent
        rng = np.random.default_rng(seed + run)

        # Sample k examples per class as the support set
        support_x, support_y = sample_k_shot(x_train, y_train, k_shot, rng)

        # SimpleShot (nearest-centroid) evaluation on the full test split
        y_pred, y_true = evaluate_simpleshot(support_x, support_y, x_test, y_test)

        acc = float(np.mean(y_pred == y_true))
        run_accuracies.append(acc)
        logger.info(f"Run {run + 1}/{num_runs} - Accuracy: {acc:.4f}")

    scores_arr = np.array(run_accuracies)
    mean_acc = float(scores_arr.mean())
    # Sample std (ddof=1) is undefined for a single run; report 0.0 then
    std_acc = float(scores_arr.std(ddof=1)) if len(run_accuracies) > 1 else 0.0

    logger.info(f"Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")

    return {
        "method": "simpleshot",
        "k_shot": k_shot,
        "accuracy": mean_acc,
        "accuracy_std": std_acc,
        "num_runs": num_runs,
        "embeddings_file": str(embeddings_path),
    }
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def evaluate_plantnet(args: argparse.Namespace) -> None:
    """Run the PlantNet-300K SimpleShot benchmark over every embeddings file given on the CLI."""
    tic = time.time()

    logger.info(f"Loading PlantNet dataset from {args.dataset_path}")
    dataset = PlantNet(args.dataset_path)
    metadata_df = _load_plantnet_metadata(dataset)
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)

        # Accumulate per-k accuracies for this embeddings file
        acc_by_k: dict[int, float] = {}
        std_by_k: dict[int, float] = {}
        for k_shot in args.k_shot:
            k_result = evaluate_plantnet_single(
                x_train, y_train, x_test, y_test, k_shot, args.runs, args.seed, embeddings_path
            )
            acc_by_k[k_shot] = k_result["accuracy"]
            std_by_k[k_shot] = k_result["accuracy_std"]

        results.append(
            {
                "embeddings_file": str(embeddings_path),
                "method": "simpleshot",
                "num_runs": args.runs,
                "accuracies": acc_by_k,
                "accuracies_std": std_by_k,
            }
        )

    _print_summary_table(results, args.k_shot)

    if not args.dry_run:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k_shot, output_dir.joinpath("plantnet.csv"))

    toc = time.time()
    logger.info(f"PlantNet benchmark completed in {lib.format_duration(toc - tic)}")
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def set_parser(subparsers: Any) -> None:
    """Register the 'plantnet' sub-command and its arguments on the eval CLI."""
    # help and description are intentionally identical
    description = "run PlantNet-300K benchmark - 1081 species classification using SimpleShot"
    subparser = subparsers.add_parser(
        "plantnet",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval plantnet --embeddings "
            "results/plantnet_embeddings.parquet --dataset-path ~/Datasets/plantnet_300K --dry-run\n"
            "python -m birder.eval plantnet --embeddings results/plantnet/*.parquet "
            "--dataset-path ~/Datasets/plantnet_300K\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to PlantNet dataset root")
    subparser.add_argument(
        "--k-shot", type=int, nargs="+", default=[2, 5], help="number of examples per class for few-shot learning"
    )
    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument(
        "--dir", type=str, default="plantnet", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def validate_args(args: argparse.Namespace) -> None:
    """Fail fast with a ValidationError when a required CLI argument is missing."""
    error_msg = None
    if args.embeddings is None:
        error_msg = "--embeddings is required"
    elif args.dataset_path is None:
        error_msg = "--dataset-path is required"

    if error_msg is not None:
        raise cli.ValidationError(error_msg)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def main(args: argparse.Namespace) -> None:
    """CLI entry point (wired via subparser.set_defaults): validate arguments, then run the benchmark."""
    validate_args(args)
    evaluate_plantnet(args)
evaluate_plantnet(args)
|