birder 0.4.1-py3-none-any.whl → 0.4.4-py3-none-any.whl

This diff reflects the contents of publicly available package versions released to one of the supported registries, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
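
Among the files listed above, this release adds a new birder.eval package (birder/eval/__main__.py, per-benchmark modules under birder/eval/benchmarks/, and shared methods under birder/eval/methods/). Each benchmark module defines set_parser() and main(), and its epilog shows invocation via "python -m birder.eval <benchmark>". The sketch below is a hypothetical minimal dispatcher illustrating how that wiring presumably fits together; it is not the actual contents of birder/eval/__main__.py and registers only the two benchmarks whose diffs appear in this section.

# Hypothetical sketch of a "python -m birder.eval" dispatcher (not the actual
# birder/eval/__main__.py); based on the set_parser()/set_defaults(func=main)
# pattern visible in the benchmark modules below.
import argparse

from birder.eval.benchmarks import plantdoc
from birder.eval.benchmarks import plantnet


def run() -> None:
    parser = argparse.ArgumentParser(prog="birder.eval", allow_abbrev=False)
    subparsers = parser.add_subparsers(dest="benchmark", required=True)

    # Each benchmark module registers its own sub-command and default handler
    plantdoc.set_parser(subparsers)
    plantnet.set_parser(subparsers)

    args = parser.parse_args()
    args.func(args)  # dispatch to the selected benchmark's main(args)


if __name__ == "__main__":
    run()

The two new benchmark modules shown below follow the same layout: load dataset metadata, join it with precomputed embeddings by filename stem, and run repeated k-shot SimpleShot evaluation.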
birder/eval/benchmarks/plantdoc.py
@@ -0,0 +1,259 @@
+ """
+ PlantDoc benchmark using SimpleShot for plant disease classification
+
+ Paper "PlantDoc: A Dataset for Visual Plant Disease Detection",
+ https://arxiv.org/abs/1911.10317
+ """
+
+ import argparse
+ import logging
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import numpy.typing as npt
+ import polars as pl
+ from rich.console import Console
+ from rich.table import Table
+
+ from birder.common import cli
+ from birder.common import lib
+ from birder.conf import settings
+ from birder.data.datasets.directory import make_image_dataset
+ from birder.datahub.evaluation import PlantDoc
+ from birder.eval._embeddings import load_embeddings
+ from birder.eval.methods.simpleshot import evaluate_simpleshot
+ from birder.eval.methods.simpleshot import sample_k_shot
+
+ logger = logging.getLogger(__name__)
+
+
+ def _print_summary_table(results: list[dict[str, Any]], k_shots: list[int]) -> None:
+     console = Console()
+
+     table = Table(show_header=True, header_style="bold dark_magenta")
+     table.add_column("PlantDoc (SimpleShot)", style="dim")
+     for k in k_shots:
+         table.add_column(f"{k}-shot", justify="right")
+
+     table.add_column("Runs", justify="right")
+
+     for result in results:
+         row = [Path(result["embeddings_file"]).name]
+         for k in k_shots:
+             acc = result["accuracies"].get(k)
+             row.append(f"{acc:.4f}" if acc is not None else "-")
+
+         row.append(f"{result['num_runs']}")
+         table.add_row(*row)
+
+     console.print(table)
+
+
+ def _write_results_csv(results: list[dict[str, Any]], k_shots: list[int], output_path: Path) -> None:
+     rows: list[dict[str, Any]] = []
+     for result in results:
+         row: dict[str, Any] = {
+             "embeddings_file": result["embeddings_file"],
+             "method": result["method"],
+             "num_runs": result["num_runs"],
+         }
+         for k in k_shots:
+             row[f"{k}_shot_acc"] = result["accuracies"].get(k)
+             row[f"{k}_shot_std"] = result["accuracies_std"].get(k)
+
+         rows.append(row)
+
+     pl.DataFrame(rows).write_csv(output_path)
+     logger.info(f"Results saved to {output_path}")
+
+
+ def _load_plantdoc_metadata(dataset: PlantDoc) -> pl.DataFrame:
+     """
+     Load metadata using make_image_dataset with a fixed class_to_idx
+
+     Returns DataFrame with columns: id (filename stem), label, split
+     """
+
+     # Build unified class_to_idx from the union of both splits
+     all_classes: set[str] = set()
+     for split_dir in [dataset.train_dir, dataset.test_dir]:
+         all_classes.update(entry.name for entry in os.scandir(str(split_dir)) if entry.is_dir())
+
+     class_to_idx = {cls_name: idx for idx, cls_name in enumerate(sorted(all_classes))}
+
+     rows: list[dict[str, Any]] = []
+     for split, split_dir in [("train", dataset.train_dir), ("test", dataset.test_dir)]:
+         image_dataset = make_image_dataset([str(split_dir)], class_to_idx)
+         for i in range(len(image_dataset)):
+             path = image_dataset.paths[i].decode("utf-8")
+             label = image_dataset.labels[i].item()
+             rows.append({"id": Path(path).stem, "label": label, "split": split})
+
+     return pl.DataFrame(rows)
+
+
+ def _load_embeddings_with_split(
+     embeddings_path: str, metadata_df: pl.DataFrame
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
+     logger.info(f"Loading embeddings from {embeddings_path}")
+     sample_ids, all_features = load_embeddings(embeddings_path)
+     emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+     joined = metadata_df.join(emb_df, on="id", how="inner")
+     if joined.height < metadata_df.height:
+         logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+     all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+     all_labels = joined.get_column("label").to_numpy().astype(np.int_)
+     splits = joined.get_column("split").to_list()
+
+     is_train = np.array([s == "train" for s in splits], dtype=bool)
+     is_test = np.array([s == "test" for s in splits], dtype=bool)
+
+     num_classes = all_labels.max() + 1
+     logger.info(
+         f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
+     )
+
+     x_train = all_features[is_train]
+     y_train = all_labels[is_train]
+     x_test = all_features[is_test]
+     y_test = all_labels[is_test]
+
+     logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")
+
+     return (x_train, y_train, x_test, y_test)
+
+
+ def evaluate_plantdoc_single(
+     x_train: npt.NDArray[np.float32],
+     y_train: npt.NDArray[np.int_],
+     x_test: npt.NDArray[np.float32],
+     y_test: npt.NDArray[np.int_],
+     k_shot: int,
+     num_runs: int,
+     seed: int,
+     embeddings_path: str,
+ ) -> dict[str, Any]:
+     logger.info(f"Evaluating {k_shot}-shot")
+
+     scores: list[float] = []
+     for run in range(num_runs):
+         run_seed = seed + run
+         rng = np.random.default_rng(run_seed)
+
+         # Sample k examples per class
+         x_train_k_shot, y_train_k_shot = sample_k_shot(x_train, y_train, k_shot, rng)
+
+         # Evaluate using SimpleShot
+         y_pred, y_true = evaluate_simpleshot(x_train_k_shot, y_train_k_shot, x_test, y_test)
+
+         acc = float(np.mean(y_pred == y_true))
+         scores.append(acc)
+         logger.info(f"Run {run + 1}/{num_runs} - Accuracy: {acc:.4f}")
+
+     scores_arr = np.array(scores)
+     mean_acc = float(scores_arr.mean())
+     std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+     logger.info(f"Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")
+
+     return {
+         "method": "simpleshot",
+         "k_shot": k_shot,
+         "accuracy": mean_acc,
+         "accuracy_std": std_acc,
+         "num_runs": num_runs,
+         "embeddings_file": str(embeddings_path),
+     }
+
+
+ def evaluate_plantdoc(args: argparse.Namespace) -> None:
+     tic = time.time()
+
+     logger.info(f"Loading PlantDoc dataset from {args.dataset_path}")
+     dataset = PlantDoc(args.dataset_path)
+     metadata_df = _load_plantdoc_metadata(dataset)
+     logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+     results: list[dict[str, Any]] = []
+     total = len(args.embeddings)
+     for idx, embeddings_path in enumerate(args.embeddings, start=1):
+         logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+         x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)
+
+         accuracies: dict[int, float] = {}
+         accuracies_std: dict[int, float] = {}
+         for k_shot in args.k_shot:
+             single_result = evaluate_plantdoc_single(
+                 x_train, y_train, x_test, y_test, k_shot, args.runs, args.seed, embeddings_path
+             )
+             accuracies[k_shot] = single_result["accuracy"]
+             accuracies_std[k_shot] = single_result["accuracy_std"]
+
+         results.append(
+             {
+                 "embeddings_file": str(embeddings_path),
+                 "method": "simpleshot",
+                 "num_runs": args.runs,
+                 "accuracies": accuracies,
+                 "accuracies_std": accuracies_std,
+             }
+         )
+
+     _print_summary_table(results, args.k_shot)
+
+     if args.dry_run is False:
+         output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_path = output_dir.joinpath("plantdoc.csv")
+         _write_results_csv(results, args.k_shot, output_path)
+
+     toc = time.time()
+     logger.info(f"PlantDoc benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "plantdoc",
+         allow_abbrev=False,
+         help="run PlantDoc benchmark - 27 class plant disease classification using SimpleShot",
+         description="run PlantDoc benchmark - 27 class plant disease classification using SimpleShot",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.eval plantdoc --embeddings "
+             "results/plantdoc_embeddings.parquet --dataset-path ~/Datasets/PlantDoc --dry-run\n"
+             "python -m birder.eval plantdoc --embeddings results/plantdoc/*.parquet "
+             "--dataset-path ~/Datasets/PlantDoc\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument(
+         "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+     )
+     subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to PlantDoc dataset root")
+     subparser.add_argument(
+         "--k-shot", type=int, nargs="+", default=[2, 5], help="number of examples per class for few-shot learning"
+     )
+     subparser.add_argument("--runs", type=int, default=5, help="number of evaluation runs")
+     subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+     subparser.add_argument(
+         "--dir", type=str, default="plantdoc", help="place all outputs in a sub-directory (relative to results)"
+     )
+     subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+     subparser.set_defaults(func=main)
+
+
+ def validate_args(args: argparse.Namespace) -> None:
+     if args.embeddings is None:
+         raise cli.ValidationError("--embeddings is required")
+     if args.dataset_path is None:
+         raise cli.ValidationError("--dataset-path is required")
+
+
+ def main(args: argparse.Namespace) -> None:
+     validate_args(args)
+     evaluate_plantdoc(args)
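
Both benchmarks delegate the actual classification to evaluate_simpleshot() from birder.eval.methods.simpleshot. For readers unfamiliar with the method, the following is an illustrative sketch of the SimpleShot protocol (nearest class centroid on centered, L2-normalized embeddings, as described by Wang et al., 2019); the library's implementation may differ in details such as the distance metric or normalization.

# Illustrative SimpleShot sketch; not the birder implementation.
import numpy as np
import numpy.typing as npt


def simpleshot_predict(
    x_support: npt.NDArray[np.float32],
    y_support: npt.NDArray[np.int_],
    x_query: npt.NDArray[np.float32],
) -> npt.NDArray[np.int_]:
    # CL2N: center on the support mean, then L2-normalize
    mean = x_support.mean(axis=0, keepdims=True)
    xs = x_support - mean
    xq = x_query - mean
    xs = xs / (np.linalg.norm(xs, axis=1, keepdims=True) + 1e-12)
    xq = xq / (np.linalg.norm(xq, axis=1, keepdims=True) + 1e-12)

    # One centroid per class from the k support examples
    classes = np.unique(y_support)
    centroids = np.stack([xs[y_support == c].mean(axis=0) for c in classes])

    # Assign each query embedding to its nearest centroid (Euclidean distance)
    dists = np.linalg.norm(xq[:, None, :] - centroids[None, :, :], axis=-1)
    return classes[np.argmin(dists, axis=1)]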
birder/eval/benchmarks/plantnet.py
@@ -0,0 +1,252 @@
+ """
+ PlantNet-300K benchmark using SimpleShot for plant species classification
+
+ Paper "Pl@ntNet-300K: a plant image dataset with high label ambiguity and a long-tailed distribution"
+ https://openreview.net/forum?id=eLYinD0TtIt
+ """
+
+ import argparse
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import numpy.typing as npt
+ import polars as pl
+ from rich.console import Console
+ from rich.table import Table
+ from torchvision.datasets import ImageFolder
+
+ from birder.common import cli
+ from birder.common import lib
+ from birder.conf import settings
+ from birder.datahub.evaluation import PlantNet
+ from birder.eval._embeddings import load_embeddings
+ from birder.eval.methods.simpleshot import evaluate_simpleshot
+ from birder.eval.methods.simpleshot import sample_k_shot
+
+ logger = logging.getLogger(__name__)
+
+
+ def _print_summary_table(results: list[dict[str, Any]], k_shots: list[int]) -> None:
+     console = Console()
+
+     table = Table(show_header=True, header_style="bold dark_magenta")
+     table.add_column("PlantNet (SimpleShot)", style="dim")
+     for k in k_shots:
+         table.add_column(f"{k}-shot", justify="right")
+
+     table.add_column("Runs", justify="right")
+
+     for result in results:
+         row = [Path(result["embeddings_file"]).name]
+         for k in k_shots:
+             acc = result["accuracies"].get(k)
+             row.append(f"{acc:.4f}" if acc is not None else "-")
+
+         row.append(f"{result['num_runs']}")
+         table.add_row(*row)
+
+     console.print(table)
+
+
+ def _write_results_csv(results: list[dict[str, Any]], k_shots: list[int], output_path: Path) -> None:
+     rows: list[dict[str, Any]] = []
+     for result in results:
+         row: dict[str, Any] = {
+             "embeddings_file": result["embeddings_file"],
+             "method": result["method"],
+             "num_runs": result["num_runs"],
+         }
+         for k in k_shots:
+             row[f"{k}_shot_acc"] = result["accuracies"].get(k)
+             row[f"{k}_shot_std"] = result["accuracies_std"].get(k)
+
+         rows.append(row)
+
+     pl.DataFrame(rows).write_csv(output_path)
+     logger.info(f"Results saved to {output_path}")
+
+
+ def _load_plantnet_metadata(dataset: PlantNet) -> pl.DataFrame:
+     """
+     Load metadata from ImageFolder structure
+
+     Returns DataFrame with columns: id (filename stem), label, split
+     """
+
+     rows: list[dict[str, Any]] = []
+     for split, split_dir in [("train", dataset.train_dir), ("val", dataset.val_dir), ("test", dataset.test_dir)]:
+         if not split_dir.exists():
+             continue
+
+         image_dataset = ImageFolder(str(split_dir))
+         for path, label in image_dataset.samples:
+             rows.append({"id": Path(path).stem, "label": label, "split": split})
+
+     return pl.DataFrame(rows)
+
+
+ def _load_embeddings_with_split(
+     embeddings_path: str, metadata_df: pl.DataFrame
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
+     logger.info(f"Loading embeddings from {embeddings_path}")
+     sample_ids, all_features = load_embeddings(embeddings_path)
+     emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+     joined = metadata_df.join(emb_df, on="id", how="inner")
+     if joined.height < metadata_df.height:
+         logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+     all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+     all_labels = joined.get_column("label").to_numpy().astype(np.int_)
+     splits = joined.get_column("split").to_list()
+
+     is_train = np.array([s == "train" for s in splits], dtype=bool)
+     is_test = np.array([s == "test" for s in splits], dtype=bool)
+
+     num_classes = all_labels.max() + 1
+     logger.info(
+         f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
+     )
+
+     x_train = all_features[is_train]
+     y_train = all_labels[is_train]
+     x_test = all_features[is_test]
+     y_test = all_labels[is_test]
+
+     logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")
+
+     return (x_train, y_train, x_test, y_test)
+
+
+ def evaluate_plantnet_single(
+     x_train: npt.NDArray[np.float32],
+     y_train: npt.NDArray[np.int_],
+     x_test: npt.NDArray[np.float32],
+     y_test: npt.NDArray[np.int_],
+     k_shot: int,
+     num_runs: int,
+     seed: int,
+     embeddings_path: str,
+ ) -> dict[str, Any]:
+     logger.info(f"Evaluating {k_shot}-shot")
+
+     scores: list[float] = []
+     for run in range(num_runs):
+         run_seed = seed + run
+         rng = np.random.default_rng(run_seed)
+
+         # Sample k examples per class
+         x_train_k_shot, y_train_k_shot = sample_k_shot(x_train, y_train, k_shot, rng)
+
+         # Evaluate using SimpleShot
+         y_pred, y_true = evaluate_simpleshot(x_train_k_shot, y_train_k_shot, x_test, y_test)
+
+         acc = float(np.mean(y_pred == y_true))
+         scores.append(acc)
+         logger.info(f"Run {run + 1}/{num_runs} - Accuracy: {acc:.4f}")
+
+     scores_arr = np.array(scores)
+     mean_acc = float(scores_arr.mean())
+     std_acc = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+
+     logger.info(f"Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")
+
+     return {
+         "method": "simpleshot",
+         "k_shot": k_shot,
+         "accuracy": mean_acc,
+         "accuracy_std": std_acc,
+         "num_runs": num_runs,
+         "embeddings_file": str(embeddings_path),
+     }
+
+
+ def evaluate_plantnet(args: argparse.Namespace) -> None:
+     tic = time.time()
+
+     logger.info(f"Loading PlantNet dataset from {args.dataset_path}")
+     dataset = PlantNet(args.dataset_path)
+     metadata_df = _load_plantnet_metadata(dataset)
+     logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+     results: list[dict[str, Any]] = []
+     total = len(args.embeddings)
+     for idx, embeddings_path in enumerate(args.embeddings, start=1):
+         logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+         x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)
+
+         accuracies: dict[int, float] = {}
+         accuracies_std: dict[int, float] = {}
+         for k_shot in args.k_shot:
+             single_result = evaluate_plantnet_single(
+                 x_train, y_train, x_test, y_test, k_shot, args.runs, args.seed, embeddings_path
+             )
+             accuracies[k_shot] = single_result["accuracy"]
+             accuracies_std[k_shot] = single_result["accuracy_std"]
+
+         results.append(
+             {
+                 "embeddings_file": str(embeddings_path),
+                 "method": "simpleshot",
+                 "num_runs": args.runs,
+                 "accuracies": accuracies,
+                 "accuracies_std": accuracies_std,
+             }
+         )
+
+     _print_summary_table(results, args.k_shot)
+
+     if args.dry_run is False:
+         output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_path = output_dir.joinpath("plantnet.csv")
+         _write_results_csv(results, args.k_shot, output_path)
+
+     toc = time.time()
+     logger.info(f"PlantNet benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+ def set_parser(subparsers: Any) -> None:
+     subparser = subparsers.add_parser(
+         "plantnet",
+         allow_abbrev=False,
+         help="run PlantNet-300K benchmark - 1081 species classification using SimpleShot",
+         description="run PlantNet-300K benchmark - 1081 species classification using SimpleShot",
+         epilog=(
+             "Usage examples:\n"
+             "python -m birder.eval plantnet --embeddings "
+             "results/plantnet_embeddings.parquet --dataset-path ~/Datasets/plantnet_300K --dry-run\n"
+             "python -m birder.eval plantnet --embeddings results/plantnet/*.parquet "
+             "--dataset-path ~/Datasets/plantnet_300K\n"
+         ),
+         formatter_class=cli.ArgumentHelpFormatter,
+     )
+     subparser.add_argument(
+         "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+     )
+     subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to PlantNet dataset root")
+     subparser.add_argument(
+         "--k-shot", type=int, nargs="+", default=[2, 5], help="number of examples per class for few-shot learning"
+     )
+     subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+     subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+     subparser.add_argument(
+         "--dir", type=str, default="plantnet", help="place all outputs in a sub-directory (relative to results)"
+     )
+     subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+     subparser.set_defaults(func=main)
+
+
+ def validate_args(args: argparse.Namespace) -> None:
+     if args.embeddings is None:
+         raise cli.ValidationError("--embeddings is required")
+     if args.dataset_path is None:
+         raise cli.ValidationError("--dataset-path is required")
+
+
+ def main(args: argparse.Namespace) -> None:
+     validate_args(args)
+     evaluate_plantnet(args)
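
Each run re-samples the support set with numpy.random.default_rng(seed + run) and calls sample_k_shot() before scoring, so the reported numbers are a mean and standard deviation over runs. The helper below is an illustrative sketch of what sample_k_shot() presumably does (draw k training examples per class with the supplied RNG); the actual birder.eval.methods.simpleshot.sample_k_shot may differ, for example in how it treats classes with fewer than k examples.

# Illustrative k-shot sampling sketch; not the birder implementation.
import numpy as np
import numpy.typing as npt


def sample_k_shot_sketch(
    x: npt.NDArray[np.float32],
    y: npt.NDArray[np.int_],
    k: int,
    rng: np.random.Generator,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    picked: list[npt.NDArray[np.int_]] = []
    for c in np.unique(y):
        candidates = np.flatnonzero(y == c)
        take = min(k, len(candidates))  # assumption: small classes contribute all of their examples
        picked.append(rng.choice(candidates, size=take, replace=False))

    idx = np.concatenate(picked)
    return (x[idx], y[idx])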