birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/eval/benchmarks/newt.py
@@ -0,0 +1,262 @@
+ """
+ NeWT benchmark, adapted from
+ https://github.com/samuelstevens/biobench/blob/main/src/biobench/newt/__init__.py
+
+ Paper "Benchmarking Representation Learning for Natural World Image Collections",
+ https://arxiv.org/abs/2103.16483
+ """
+
+ # Reference license: MIT
+
+ import argparse
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import numpy.typing as npt
+ import polars as pl
+ from rich.console import Console
+ from rich.table import Table
+
+ from birder.common import cli
+ from birder.common import lib
+ from birder.conf import settings
+ from birder.datahub.evaluation import NeWT
+ from birder.eval._embeddings import l2_normalize
+ from birder.eval._embeddings import load_embeddings
+ from birder.eval.methods.simpleshot import normalize_features
+ from birder.eval.methods.svm import evaluate_svm
+
+ logger = logging.getLogger(__name__)
+
+
+ def _print_summary_table(results: list[dict[str, Any]]) -> None:
+     console = Console()
+
+     cluster_names = sorted({cluster for result in results for cluster in result["per_cluster_accuracy"].keys()})
+
+     table = Table(show_header=True, header_style="bold dark_magenta")
+     table.add_column("NeWT (SVM)", style="dim")
+     table.add_column("Accuracy", justify="right")
+     table.add_column("Std", justify="right")
+     table.add_column("Runs", justify="right")
+     for cluster in cluster_names:
+         table.add_column(cluster.replace("_", " ").title(), justify="right")
+
+     for result in results:
+         row = [
+             Path(result["embeddings_file"]).name,
+             f"{result['accuracy']:.4f}",
+             f"{result['accuracy_std']:.4f}",
+             f"{result['num_runs']}",
+         ]
+         for cluster in cluster_names:
+             acc = result["per_cluster_accuracy"].get(cluster)
+             row.append(f"{acc:.4f}" if acc is not None else "-")
+
+         table.add_row(*row)
+
+     console.print(table)
+
+
+ def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
+     cluster_names = sorted({cluster for result in results for cluster in result["per_cluster_accuracy"].keys()})
+     rows: list[dict[str, Any]] = []
+     for result in results:
+         row: dict[str, Any] = {
+             "embeddings_file": result["embeddings_file"],
+             "method": result["method"],
+             "accuracy": result["accuracy"],
+             "accuracy_std": result["accuracy_std"],
+             "num_runs": result["num_runs"],
+         }
+         for cluster in cluster_names:
+             row[f"cluster_{cluster}"] = result["per_cluster_accuracy"].get(cluster)
+
+         rows.append(row)
+
+     pl.DataFrame(rows).write_csv(output_path)
+     logger.info(f"Results saved to {output_path}")
+
+
+ # pylint: disable=too-many-locals
+ def evaluate_newt_single(embeddings_path: str, labels_df: pl.DataFrame, args: argparse.Namespace) -> dict[str, Any]:
+     logger.info(f"Loading embeddings from {embeddings_path}")
+     sample_ids, all_features = load_embeddings(embeddings_path)
+     emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+     joined = labels_df.join(emb_df, on="id", how="inner").sort("index")
+     if joined.height < labels_df.height:
+         logger.warning(f"Join dropped {labels_df.height - joined.height} samples (missing embeddings)")
+
+     all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+     logger.info(f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions")
+
+     # Global L2 normalization
+     all_features = l2_normalize(all_features)
+     joined = joined.with_columns(pl.Series("embedding", all_features.tolist()))
+
+     tasks = joined.get_column("task").unique().to_list()
+     logger.info(f"Found {len(tasks)} tasks")
+
+     scores: list[float] = []
+     cluster_scores: dict[str, list[float]] = {}
+     for run in range(args.runs):
+         run_seed = args.seed + run
+
+         y_preds_all: list[npt.NDArray[np.int_]] = []
+         y_trues_all: list[npt.NDArray[np.int_]] = []
+         cluster_preds: dict[str, list[npt.NDArray[np.int_]]] = {}
+         cluster_trues: dict[str, list[npt.NDArray[np.int_]]] = {}
+         for task_name in tasks:
+             tdf = joined.filter(pl.col("task") == task_name)
+
+             features = np.array(tdf.get_column("embedding").to_list(), dtype=np.float32)
+
+             labels = tdf.get_column("label").to_numpy()
+             is_train = (tdf.get_column("split") == "train").to_numpy()
+             cluster = tdf.item(0, "task_cluster")
+
+             x_train = features[is_train]
+             y_train = labels[is_train]
+             x_test = features[~is_train]
+             y_test = labels[~is_train]
+
+             if x_train.size == 0 or x_test.size == 0:
+                 logger.warning(f"Skipping task {task_name}: empty train or test split")
+                 continue
+
+             # Per-task centering and L2 normalization
+             x_train, x_test = normalize_features(x_train, x_test)
+
+             # Train and evaluate SVM
+             y_pred, y_true = evaluate_svm(
+                 x_train,
+                 y_train,
+                 x_test,
+                 y_test,
+                 n_iter=args.n_iter,
+                 n_jobs=args.n_jobs,
+                 seed=run_seed,
+             )
+
+             y_preds_all.append(y_pred)
+             y_trues_all.append(y_true)
+
+             # Track per-cluster predictions
+             if cluster not in cluster_preds:
+                 cluster_preds[cluster] = []
+                 cluster_trues[cluster] = []
+             cluster_preds[cluster].append(y_pred)
+             cluster_trues[cluster].append(y_true)
+
+         # Micro-averaged accuracy
+         y_preds = np.concatenate(y_preds_all)
+         y_trues = np.concatenate(y_trues_all)
+         acc = float(np.mean(y_preds == y_trues))
+         scores.append(acc)
+         logger.info(f"Run {run + 1}/{args.runs} - Accuracy: {acc:.4f}")
+
+         # Compute per-cluster accuracy for this run
+         for cluster, preds_list in cluster_preds.items():
+             preds = np.concatenate(preds_list)
+             trues = np.concatenate(cluster_trues[cluster])
+             cluster_acc = float(np.mean(preds == trues))
+             if cluster not in cluster_scores:
+                 cluster_scores[cluster] = []
+             cluster_scores[cluster].append(cluster_acc)
+
+     # Average per-cluster accuracy across runs
+     per_cluster_accuracy: dict[str, float] = {}
+     for cluster, accs in cluster_scores.items():
+         per_cluster_accuracy[cluster.lower()] = float(np.mean(accs))
+
+     scores_arr = np.array(scores)
+     mean_acc = float(scores_arr.mean())
+     std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+     logger.info(f"Mean accuracy over {args.runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")
+     for cluster, acc in sorted(per_cluster_accuracy.items()):
+         logger.info(f" {cluster}: {acc:.4f}")
+
+     return {
+         "method": "svm",
+         "accuracy": mean_acc,
+         "accuracy_std": std_acc,
+         "num_runs": args.runs,
+         "per_cluster_accuracy": per_cluster_accuracy,
+         "embeddings_file": str(embeddings_path),
+     }
+
+
+ def evaluate_newt(args: argparse.Namespace) -> None:
+     tic = time.time()
+
+     logger.info(f"Loading NeWT dataset from {args.dataset_path}")
+     dataset = NeWT(args.dataset_path)
+     labels_path = dataset.labels_path
+     logger.info(f"Loading labels from {labels_path}")
+     labels_df = pl.read_csv(labels_path).with_row_index(name="index").with_columns(pl.col("id").cast(pl.Utf8))
+
+     results: list[dict[str, Any]] = []
+     total = len(args.embeddings)
+     for idx, embeddings_path in enumerate(args.embeddings, start=1):
+         logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+         result = evaluate_newt_single(embeddings_path, labels_df, args)
+         results.append(result)
+
+     _print_summary_table(results)
+
+     if args.dry_run is False:
+         output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_path = output_dir.joinpath("newt.csv")
+         _write_results_csv(results, output_path)
+
+     toc = time.time()
+     logger.info(f"NeWT benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "newt",
+         allow_abbrev=False,
+         help="run NeWT benchmark - 164 binary classification tasks evaluated with SVM",
+         description="run NeWT benchmark - 164 binary classification tasks evaluated with SVM",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.eval newt --embeddings "
+             "results/hieradet_d_small_dino-v2_0_224px_crop1.0_36032_output.parquet "
+             "--dataset-path ~/Datasets/NeWT --dry-run\n"
+             "python -m birder.eval newt --embeddings results/newt/*.parquet "
+             "--dataset-path ~/Datasets/NeWT\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument(
+         "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings/logits parquet files"
+     )
+     subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to NeWT dataset root")
+     subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+     subparser.add_argument("--n-iter", type=int, default=100, help="SVM hyperparameter search iterations")
+     subparser.add_argument("--n-jobs", type=int, default=8, help="parallel jobs for RandomizedSearchCV")
+     subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+     subparser.add_argument(
+         "--dir", type=str, default="newt", help="place all outputs in a sub-directory (relative to results)"
+     )
+     subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+     subparser.set_defaults(func=main)
+
+
+ def validate_args(args: argparse.Namespace) -> None:
+     if args.embeddings is None:
+         raise cli.ValidationError("--embeddings is required")
+     if args.dataset_path is None:
+         raise cli.ValidationError("--dataset-path is required")
+
+
+ def main(args: argparse.Namespace) -> None:
+     validate_args(args)
+     evaluate_newt(args)
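
For orientation, the protocol in this file (per-task centering and L2 normalization, one SVM fit per task, micro-averaged accuracy across tasks) can be sketched standalone. This is a minimal sketch only: scikit-learn's LinearSVC stands in for the package's evaluate_svm (which, per the --n-iter/--n-jobs flags, runs a RandomizedSearchCV hyperparameter search), and the normalize helper below only mirrors simpleshot.normalize_features in spirit; the real implementations may differ.

import numpy as np
from sklearn.svm import LinearSVC

def normalize(x_train, x_test):
    # SimpleShot-style: center on the train mean, then L2-normalize each row
    mean = x_train.mean(axis=0, keepdims=True)
    x_train, x_test = x_train - mean, x_test - mean
    x_train /= np.linalg.norm(x_train, axis=1, keepdims=True) + 1e-12
    x_test /= np.linalg.norm(x_test, axis=1, keepdims=True) + 1e-12
    return x_train, x_test

def run_task(x_train, y_train, x_test, y_test, seed=0):
    # One binary NeWT task: normalize, fit a linear SVM, predict
    x_train, x_test = normalize(x_train, x_test)
    clf = LinearSVC(random_state=seed).fit(x_train, y_train)
    return clf.predict(x_test), y_test

# Micro-averaged accuracy: concatenate predictions across all tasks before
# comparing, so larger tasks contribute proportionally more
rng = np.random.default_rng(0)
preds, trues = [], []
for _ in range(3):  # toy stand-ins for the 164 NeWT tasks
    x = rng.normal(size=(60, 16)).astype(np.float32)
    y = (x[:, 0] > 0).astype(np.int64)
    p, t = run_task(x[:40], y[:40], x[40:], y[40:])
    preds.append(p)
    trues.append(t)

acc = float(np.mean(np.concatenate(preds) == np.concatenate(trues)))
print(f"micro-averaged accuracy: {acc:.4f}")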
birder/eval/benchmarks/plankton.py
@@ -0,0 +1,255 @@
+ """
+ Plankton benchmark using linear probing for phytoplankton classification
+
+ Dataset "SYKE-plankton_IFCB_2022", https://b2share.eudat.eu/records/xvnrp-7ga56
+ """
+
+ import argparse
+ import logging
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import numpy.typing as npt
+ import polars as pl
+ import torch
+ from rich.console import Console
+ from rich.table import Table
+
+ from birder.common import cli
+ from birder.common import lib
+ from birder.conf import settings
+ from birder.data.datasets.directory import make_image_dataset
+ from birder.datahub.evaluation import Plankton
+ from birder.eval._embeddings import load_embeddings
+ from birder.eval.methods.linear import evaluate_linear_probe
+ from birder.eval.methods.linear import train_linear_probe
+
+ logger = logging.getLogger(__name__)
+
+
+ def _print_summary_table(results: list[dict[str, Any]]) -> None:
+     console = Console()
+
+     table = Table(show_header=True, header_style="bold dark_magenta")
+     table.add_column("Plankton (Linear)", style="dim")
+     table.add_column("Accuracy", justify="right")
+     table.add_column("Std", justify="right")
+     table.add_column("Runs", justify="right")
+
+     for result in results:
+         table.add_row(
+             Path(result["embeddings_file"]).name,
+             f"{result['accuracy']:.4f}",
+             f"{result['accuracy_std']:.4f}",
+             f"{result['num_runs']}",
+         )
+
+     console.print(table)
+
+
+ def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
+     rows: list[dict[str, Any]] = []
+     for result in results:
+         rows.append(
+             {
+                 "embeddings_file": result["embeddings_file"],
+                 "method": result["method"],
+                 "accuracy": result["accuracy"],
+                 "accuracy_std": result["accuracy_std"],
+                 "num_runs": result["num_runs"],
+             }
+         )
+
+     pl.DataFrame(rows).write_csv(output_path)
+     logger.info(f"Results saved to {output_path}")
+
+
+ def _load_plankton_metadata(dataset: Plankton) -> pl.DataFrame:
+     """
+     Load metadata using make_image_dataset with a fixed class_to_idx
+
+     Returns DataFrame with columns: id (filename stem), label, split
+     """
+
+     # Build class_to_idx from train classes only
+     class_to_idx = {
+         entry.name: idx
+         for idx, entry in enumerate(sorted(os.scandir(str(dataset.train_dir)), key=lambda e: e.name))
+         if entry.is_dir()
+     }
+
+     rows: list[dict[str, Any]] = []
+     for split, split_dir in [("train", dataset.train_dir), ("val", dataset.val_dir)]:
+         image_dataset = make_image_dataset([str(split_dir)], class_to_idx)
+         for i in range(len(image_dataset)):
+             path = image_dataset.paths[i].decode("utf-8")
+             label = image_dataset.labels[i].item()
+             rows.append({"id": Path(path).stem, "label": label, "split": split})
+
+     return pl.DataFrame(rows)
+
+
+ def _load_embeddings_with_split(
+     embeddings_path: str, metadata_df: pl.DataFrame
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
+     logger.info(f"Loading embeddings from {embeddings_path}")
+     sample_ids, all_features = load_embeddings(embeddings_path)
+     emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+     joined = metadata_df.join(emb_df, on="id", how="inner")
+     if joined.height < metadata_df.height:
+         logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+     all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+     all_labels = joined.get_column("label").to_numpy().astype(np.int_)
+     splits = joined.get_column("split").to_list()
+
+     is_train = np.array([s == "train" for s in splits], dtype=bool)
+     is_val = np.array([s == "val" for s in splits], dtype=bool)
+
+     num_classes = all_labels.max() + 1
+     logger.info(
+         f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
+     )
+
+     x_train = all_features[is_train]
+     y_train = all_labels[is_train]
+     x_val = all_features[is_val]
+     y_val = all_labels[is_val]
+
+     logger.info(f"Train: {len(y_train)} samples, Val: {len(y_val)} samples")
+
+     return (x_train, y_train, x_val, y_val)
+
+
+ def evaluate_plankton_single(
+     x_train: npt.NDArray[np.float32],
+     y_train: npt.NDArray[np.int_],
+     x_val: npt.NDArray[np.float32],
+     y_val: npt.NDArray[np.int_],
+     args: argparse.Namespace,
+     embeddings_path: str,
+     device: torch.device,
+ ) -> dict[str, Any]:
+     num_classes = int(y_train.max() + 1)
+
+     scores: list[float] = []
+     for run in range(args.runs):
+         run_seed = args.seed + run
+         logger.info(f"Run {run + 1}/{args.runs} (seed={run_seed})")
+
+         model = train_linear_probe(
+             x_train,
+             y_train,
+             num_classes,
+             device=device,
+             epochs=args.epochs,
+             batch_size=args.batch_size,
+             lr=args.lr,
+             seed=run_seed,
+         )
+
+         y_pred, y_true = evaluate_linear_probe(model, x_val, y_val, batch_size=args.batch_size, device=device)
+         acc = float(np.mean(y_pred == y_true))
+         scores.append(acc)
+         logger.info(f"Run {run + 1}/{args.runs} - Accuracy: {acc:.4f}")
+
+     scores_arr = np.array(scores)
+     mean_acc = float(scores_arr.mean())
+     std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+     logger.info(f"Mean accuracy over {args.runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")
+
+     return {
+         "method": "linear",
+         "accuracy": mean_acc,
+         "accuracy_std": std_acc,
+         "num_runs": args.runs,
+         "embeddings_file": str(embeddings_path),
+     }
+
+
+ def evaluate_plankton(args: argparse.Namespace) -> None:
+     tic = time.time()
+
+     if args.gpu is True:
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     if args.gpu_id is not None:
+         torch.cuda.set_device(args.gpu_id)
+
+     logger.info(f"Using device {device}")
+     logger.info(f"Loading Plankton dataset from {args.dataset_path}")
+     dataset = Plankton(args.dataset_path)
+     metadata_df = _load_plankton_metadata(dataset)
+     logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+     results: list[dict[str, Any]] = []
+     total = len(args.embeddings)
+     for idx, embeddings_path in enumerate(args.embeddings, start=1):
+         logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+         x_train, y_train, x_val, y_val = _load_embeddings_with_split(embeddings_path, metadata_df)
+
+         result = evaluate_plankton_single(x_train, y_train, x_val, y_val, args, embeddings_path, device)
+         results.append(result)
+
+     _print_summary_table(results)
+
+     if args.dry_run is False:
+         output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_path = output_dir.joinpath("plankton.csv")
+         _write_results_csv(results, output_path)
+
+     toc = time.time()
+     logger.info(f"Plankton benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "plankton",
+         allow_abbrev=False,
+         help="run Plankton benchmark - 50 class phytoplankton classification using linear probing",
+         description="run Plankton benchmark - 50 class phytoplankton classification using linear probing",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.eval plankton --embeddings "
+             "results/plankton_embeddings.parquet --dataset-path ~/Datasets/plankton --dry-run\n"
+             "python -m birder.eval plankton --embeddings results/plankton/*.parquet "
+             "--dataset-path ~/Datasets/plankton --gpu\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument(
+         "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+     )
+     subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to Plankton dataset root")
+     subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+     subparser.add_argument("--epochs", type=int, default=75, help="training epochs per run")
+     subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
+     subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+     subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+     subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
+     subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+     subparser.add_argument(
+         "--dir", type=str, default="plankton", help="place all outputs in a sub-directory (relative to results)"
+     )
+     subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+     subparser.set_defaults(func=main)
+
+
+ def validate_args(args: argparse.Namespace) -> None:
+     if args.embeddings is None:
+         raise cli.ValidationError("--embeddings is required")
+     if args.dataset_path is None:
+         raise cli.ValidationError("--dataset-path is required")
+
+
+ def main(args: argparse.Namespace) -> None:
+     validate_args(args)
+     evaluate_plankton(args)
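
The linear-probing path above delegates to train_linear_probe and evaluate_linear_probe from birder.eval.methods.linear, whose bodies are not shown in this diff. The following is a minimal sketch of what such a probe typically amounts to: a single nn.Linear trained with cross-entropy on frozen embeddings, with the same epochs/batch_size/lr/seed knobs. The helper names and training loop here are illustrative assumptions, not the package's actual code.

import numpy as np
import torch
from torch import nn

def train_probe(x, y, num_classes, epochs=75, batch_size=128, lr=1e-4, seed=0):
    # Single linear layer over frozen embeddings, cross-entropy loss
    torch.manual_seed(seed)
    model = nn.Linear(x.shape[1], num_classes)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    ds = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            nn.functional.cross_entropy(model(xb), yb).backward()
            opt.step()
    return model

@torch.no_grad()
def eval_probe(model, x, y):
    # Accuracy of argmax predictions on the validation split
    preds = model(torch.from_numpy(x)).argmax(dim=1).numpy()
    return float(np.mean(preds == y))

# Toy usage on random features; real inputs come from _load_embeddings_with_split
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 32)).astype(np.float32)
y = (x[:, 0] > 0).astype(np.int64)
model = train_probe(x[:200], y[:200], num_classes=2, epochs=5)
print(f"val accuracy: {eval_probe(model, x[200:], y[200:]):.4f}")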