birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,261 @@
1
+ """
2
+ FungiCLEF2023 benchmark using KNN for fungi species classification
3
+
4
+ Link: https://www.imageclef.org/FungiCLEF2023
5
+ """
6
+
7
+ import argparse
8
+ import logging
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import polars as pl
16
+ from rich.console import Console
17
+ from rich.table import Table
18
+
19
+ from birder.common import cli
20
+ from birder.common import lib
21
+ from birder.conf import settings
22
+ from birder.datahub.evaluation import FungiCLEF2023
23
+ from birder.eval._embeddings import load_embeddings
24
+ from birder.eval.methods.knn import evaluate_knn
25
+ from birder.eval.methods.simpleshot import sample_k_shot
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def _print_summary_table(results: list[dict[str, Any]], k_values: list[int]) -> None:
    """Render a rich table of mean accuracies: one row per embeddings file, one column per k, plus run count."""

    summary = Table(show_header=True, header_style="bold dark_magenta")
    summary.add_column("FungiCLEF2023 (KNN)", style="dim")
    for k in k_values:
        summary.add_column(f"k={k}", justify="right")

    summary.add_column("Runs", justify="right")

    for entry in results:
        cells = [Path(entry["embeddings_file"]).name]
        for k in k_values:
            mean_acc = entry["accuracies"].get(k)
            if mean_acc is None:
                # No result recorded for this k
                cells.append("-")
            else:
                cells.append(f"{mean_acc:.4f}")

        cells.append(str(entry["num_runs"]))
        summary.add_row(*cells)

    Console().print(summary)
50
+
51
+
52
def _write_results_csv(results: list[dict[str, Any]], k_values: list[int], output_path: Path) -> None:
    """Flatten each result's per-k mean/std accuracies into one CSV row per embeddings file."""

    def _flatten(result: dict[str, Any]) -> dict[str, Any]:
        # One wide row: fixed columns first, then k_<k>_acc / k_<k>_std pairs
        row: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
            "num_runs": result["num_runs"],
        }
        for k in k_values:
            row[f"k_{k}_acc"] = result["accuracies"].get(k)
            row[f"k_{k}_std"] = result["accuracies_std"].get(k)

        return row

    pl.DataFrame([_flatten(r) for r in results]).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
68
+
69
+
70
def _load_fungiclef_metadata(dataset: FungiCLEF2023) -> pl.DataFrame:
    """
    Load metadata from FungiCLEF2023 CSV files

    Returns DataFrame with columns: id (filename stem), label, split (train/val/test).
    Filters out validation samples with unknown species (class_id == -1).
    Test samples have label=-1 (no ground truth available) and are excluded from evaluation.
    """

    columns = ["id", "class_id", "split"]

    train_df = (
        pl.read_csv(dataset.train_metadata_path)
        .with_columns(
            pl.col("image_path").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id"),
            pl.lit("train").alias("split"),
        )
        .select(columns)
    )

    val_df = (
        pl.read_csv(dataset.val_metadata_path)
        .filter(pl.col("class_id") >= 0)  # Drop validation samples with unknown species
        .with_columns(
            pl.col("filename").alias("id"),
            pl.lit("val").alias("split"),
        )
        .select(columns)
    )

    # Include test IDs so they are properly excluded when embeddings contain all samples
    test_df = (
        pl.read_csv(dataset.test_metadata_path)
        .with_columns(
            pl.col("filename").alias("id"),
            pl.lit("test").alias("split"),
            pl.lit(-1, dtype=pl.Int64).alias("class_id"),
        )
        .select(columns)
    )

    return pl.concat([train_df, val_df, test_df]).rename({"class_id": "label"})
104
+
105
+
106
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """
    Join pre-computed embeddings with the metadata and split into train/val arrays.

    Test samples (label=-1, no ground truth) are counted for logging only and not returned.
    """

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    joined = metadata_df.join(
        pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()}), on="id", how="inner"
    )
    dropped = metadata_df.height - joined.height
    if dropped > 0:
        logger.warning(f"Join dropped {dropped} samples (missing embeddings)")

    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    labels = joined.get_column("label").to_numpy().astype(np.int_)
    split_col = np.array(joined.get_column("split").to_list())

    train_mask = split_col == "train"
    val_mask = split_col == "val"
    test_mask = split_col == "test"

    # Class count comes from the train split only (val/test carry no new classes for reporting)
    num_classes = len(np.unique(labels[train_mask]))
    logger.info(
        f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train = features[train_mask]
    y_train = labels[train_mask]
    x_val = features[val_mask]
    y_val = labels[val_mask]

    logger.info(
        f"Train: {len(y_train)} samples, Val: {len(y_val)} samples, "
        f"Test: {int(test_mask.sum())} samples (no labels, excluded)"
    )

    return (x_train, y_train, x_val, y_val)
141
+
142
+
143
def _evaluate_single_k(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_val: npt.NDArray[np.float32],
    y_val: npt.NDArray[np.int_],
    k: int,
    num_runs: int,
    seed: int,
) -> tuple[float, float]:
    """
    Run num_runs k-shot KNN evaluations with different seeds and return (mean, std) accuracy.

    The same k is used both for per-class sampling and as the number of KNN neighbors.
    Std is the sample standard deviation (ddof=1), or 0.0 for a single run.
    """

    logger.info(f"Evaluating k={k} ({k}-shot sampling, KNN k={k})")

    run_accuracies: list[float] = []
    for run_idx in range(num_runs):
        # Each run gets its own deterministic generator so results are reproducible
        rng = np.random.default_rng(seed + run_idx)

        # Sample k examples per class, then classify the validation set with k neighbors
        sampled_x, sampled_y = sample_k_shot(x_train, y_train, k, rng)
        y_pred, y_true = evaluate_knn(sampled_x, sampled_y, x_val, y_val, k=k)

        run_acc = float(np.mean(y_pred == y_true))
        run_accuracies.append(run_acc)
        logger.info(f"Run {run_idx + 1}/{num_runs} - Accuracy: {run_acc:.4f}")

    mean_acc = float(np.mean(run_accuracies))
    std_acc = float(np.std(run_accuracies, ddof=1)) if len(run_accuracies) > 1 else 0.0

    logger.info(f"k={k} - Mean accuracy over {num_runs} runs: {mean_acc:.4f} +/- {std_acc:.4f} (std)")

    return (mean_acc, std_acc)
176
+
177
+
178
def evaluate_fungiclef(args: argparse.Namespace) -> None:
    """
    Run the FungiCLEF2023 KNN benchmark over every embeddings file in args.embeddings.

    For each file, evaluates every k in args.k with args.runs repeated k-shot samplings,
    prints a summary table and, unless --dry-run is given, writes a CSV under the results directory.
    """

    tic = time.time()

    logger.info(f"Loading FungiCLEF2023 dataset from {args.dataset_path}")
    dataset = FungiCLEF2023(args.dataset_path)
    metadata_df = _load_fungiclef_metadata(dataset)
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_val, y_val = _load_embeddings_with_split(embeddings_path, metadata_df)

        accuracies: dict[int, float] = {}
        accuracies_std: dict[int, float] = {}
        for k in args.k:
            mean_acc, std_acc = _evaluate_single_k(x_train, y_train, x_val, y_val, k, args.runs, args.seed)
            accuracies[k] = mean_acc
            accuracies_std[k] = std_acc

        results.append(
            {
                "embeddings_file": str(embeddings_path),
                "method": "knn",
                "num_runs": args.runs,
                "accuracies": accuracies,
                "accuracies_std": accuracies_std,
            }
        )

    _print_summary_table(results, args.k)

    # Idiomatic truthiness check (was `args.dry_run is False`; store_true guarantees a bool)
    if not args.dry_run:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k, output_dir.joinpath("fungiclef.csv"))

    toc = time.time()
    logger.info(f"FungiCLEF2023 benchmark completed in {lib.format_duration(toc - tic)}")
219
+
220
+
221
def set_parser(subparsers: Any) -> None:
    """Register the 'fungiclef' sub-command and its arguments on the given subparsers object."""

    # help and description are intentionally identical
    description = "run FungiCLEF2023 benchmark - 1,604 species classification using KNN"
    subparser = subparsers.add_parser(
        "fungiclef",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval fungiclef --embeddings "
            "results/fungiclef_embeddings.parquet --dataset-path ~/Datasets/FungiCLEF2023 --dry-run\n"
            "python -m birder.eval fungiclef --embeddings results/fungiclef/*.parquet "
            "--dataset-path ~/Datasets/FungiCLEF2023\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to FungiCLEF2023 dataset root")
    subparser.add_argument(
        "--k", type=int, nargs="+", default=[1, 3], help="k value for k-shot sampling and KNN neighbors"
    )
    subparser.add_argument("--runs", type=int, default=5, help="number of evaluation runs")
    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
    subparser.add_argument(
        "--dir", type=str, default="fungiclef", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
250
+
251
+
252
def validate_args(args: argparse.Namespace) -> None:
    """Ensure the required --embeddings and --dataset-path arguments were provided."""

    for value, flag in ((args.embeddings, "--embeddings"), (args.dataset_path, "--dataset-path")):
        if value is None:
            raise cli.ValidationError(f"{flag} is required")
257
+
258
+
259
def main(args: argparse.Namespace) -> None:
    """Entry point for the sub-command: validate CLI arguments, then run the benchmark."""

    validate_args(args)
    evaluate_fungiclef(args)
@@ -0,0 +1,202 @@
1
+ """
2
+ NABirds benchmark using KNN for bird species classification
3
+
4
+ Website: https://dl.allaboutbirds.org/nabirds
5
+ """
6
+
7
+ import argparse
8
+ import logging
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import polars as pl
16
+ from rich.console import Console
17
+ from rich.table import Table
18
+
19
+ from birder.common import cli
20
+ from birder.common import lib
21
+ from birder.conf import settings
22
+ from birder.datahub.evaluation import NABirds
23
+ from birder.eval._embeddings import load_embeddings
24
+ from birder.eval.methods.knn import evaluate_knn
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def _print_summary_table(results: list[dict[str, Any]], k_values: list[int]) -> None:
    """Render a rich table of KNN accuracies: one row per embeddings file, one column per k."""

    summary = Table(show_header=True, header_style="bold dark_magenta")
    summary.add_column("NABirds (KNN)", style="dim")
    for k in k_values:
        summary.add_column(f"k={k}", justify="right")

    for entry in results:
        cells = [Path(entry["embeddings_file"]).name]
        for k in k_values:
            acc = entry["accuracies"].get(k)
            if acc is None:
                # No result recorded for this k
                cells.append("-")
            else:
                cells.append(f"{acc:.4f}")

        summary.add_row(*cells)

    Console().print(summary)
46
+
47
+
48
def _write_results_csv(results: list[dict[str, Any]], k_values: list[int], output_path: Path) -> None:
    """Flatten each result's per-k accuracies into one CSV row per embeddings file."""

    def _flatten(result: dict[str, Any]) -> dict[str, Any]:
        # One wide row: fixed columns first, then a k_<k>_acc column per k
        row: dict[str, Any] = {
            "embeddings_file": result["embeddings_file"],
            "method": result["method"],
        }
        for k in k_values:
            row[f"k_{k}_acc"] = result["accuracies"].get(k)

        return row

    pl.DataFrame([_flatten(r) for r in results]).write_csv(output_path)
    logger.info(f"Results saved to {output_path}")
62
+
63
+
64
def _load_nabirds_metadata(dataset: NABirds) -> pl.DataFrame:
    """
    Join the NABirds metadata text files into a single DataFrame.

    Returns columns: id (filename stem), class_id, class_name, is_train, plus a
    contiguous "label" column (0 to num_classes-1) derived from the sparse class ids.
    """

    def _read(path: Any, columns: list[str]) -> pl.DataFrame:
        # NOTE(review): assumes space-separated two-column files — verify class names
        # containing spaces parse as expected for classes.txt
        return pl.read_csv(path, separator=" ", has_header=False, new_columns=columns)

    images_df = _read(dataset.images_path, ["image_id", "filepath"]).with_columns(
        pl.col("filepath").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id")
    )
    labels_df = _read(dataset.labels_path, ["image_id", "class_id"])
    classes_df = _read(dataset.classes_path, ["class_id", "class_name"])
    split_df = _read(dataset.train_test_split_path, ["image_id", "is_train"])

    metadata_df = (
        images_df.join(labels_df, on="image_id")
        .join(classes_df, on="class_id")
        .join(split_df, on="image_id")
        .select(["id", "class_id", "class_name", "is_train"])
    )

    # Map the sparse NABirds class ids to contiguous label indices (0 to num_classes-1)
    sorted_class_ids = metadata_df.get_column("class_id").unique().sort().to_list()
    remap = {cid: idx for idx, cid in enumerate(sorted_class_ids)}

    return metadata_df.with_columns(pl.col("class_id").replace_strict(remap).alias("label"))
90
+
91
+
92
def _load_embeddings_with_split(
    embeddings_path: str, metadata_df: pl.DataFrame
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_], npt.NDArray[np.float32], npt.NDArray[np.int_]]:
    """Join pre-computed embeddings with the metadata and split into train/test arrays."""

    logger.info(f"Loading embeddings from {embeddings_path}")
    sample_ids, all_features = load_embeddings(embeddings_path)
    joined = metadata_df.join(
        pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()}), on="id", how="inner"
    )
    dropped = metadata_df.height - joined.height
    if dropped > 0:
        logger.warning(f"Join dropped {dropped} samples (missing embeddings)")

    features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
    labels = joined.get_column("label").to_numpy().astype(np.int_)
    train_mask = joined.get_column("is_train").to_numpy().astype(bool)

    # Labels are contiguous (see _load_nabirds_metadata), so max+1 is the class count
    num_classes = labels.max() + 1
    logger.info(
        f"Loaded {features.shape[0]} samples with {features.shape[1]} dimensions, {num_classes} classes"
    )

    x_train = features[train_mask]
    y_train = labels[train_mask]
    x_test = features[~train_mask]
    y_test = labels[~train_mask]

    logger.info(f"Train: {len(y_train)} samples, Test: {len(y_test)} samples")

    return (x_train, y_train, x_test, y_test)
120
+
121
+
122
def evaluate_nabirds(args: argparse.Namespace) -> None:
    """
    Run the NABirds KNN benchmark over every embeddings file in args.embeddings.

    For each file, evaluates every k in args.k on the official train/test split,
    prints a summary table and, unless --dry-run is given, writes a CSV under the results directory.
    """

    tic = time.time()

    logger.info(f"Loading NABirds dataset from {args.dataset_path}")
    dataset = NABirds(args.dataset_path)
    metadata_df = _load_nabirds_metadata(dataset)
    logger.info(f"Loaded metadata for {metadata_df.height} images")

    results: list[dict[str, Any]] = []
    total = len(args.embeddings)
    for idx, embeddings_path in enumerate(args.embeddings, start=1):
        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
        x_train, y_train, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)

        accuracies: dict[int, float] = {}
        for k in args.k:
            logger.info(f"Evaluating KNN with k={k}")
            y_pred, y_true = evaluate_knn(x_train, y_train, x_test, y_test, k=k)
            acc = float(np.mean(y_pred == y_true))
            accuracies[k] = acc
            logger.info(f"k={k} - Accuracy: {acc:.4f}")

        results.append(
            {
                "embeddings_file": str(embeddings_path),
                "method": "knn",
                "accuracies": accuracies,
            }
        )

    _print_summary_table(results, args.k)

    # Idiomatic truthiness check (was `args.dry_run is False`; store_true guarantees a bool)
    if not args.dry_run:
        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _write_results_csv(results, args.k, output_dir.joinpath("nabirds.csv"))

    toc = time.time()
    logger.info(f"NABirds benchmark completed in {lib.format_duration(toc - tic)}")
162
+
163
+
164
def set_parser(subparsers: Any) -> None:
    """Register the 'nabirds' sub-command and its arguments on the given subparsers object."""

    # help and description are intentionally identical
    description = "run NABirds benchmark - 555 class classification using KNN"
    subparser = subparsers.add_parser(
        "nabirds",
        allow_abbrev=False,
        help=description,
        description=description,
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval nabirds --embeddings "
            "results/vit_b16_224px_crop1.0_48562_embeddings.parquet "
            "--dataset-path ~/Datasets/nabirds --dry-run\n"
            "python -m birder.eval nabirds --embeddings results/nabirds/*.parquet --dataset-path ~/Datasets/nabirds\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument(
        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
    )
    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to NABirds dataset root")
    subparser.add_argument(
        "--k", type=int, nargs="+", default=[10, 20, 100], help="number of nearest neighbors for KNN"
    )
    subparser.add_argument(
        "--dir", type=str, default="nabirds", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.set_defaults(func=main)
191
+
192
+
193
def validate_args(args: argparse.Namespace) -> None:
    """Ensure the required --embeddings and --dataset-path arguments were provided."""

    for value, flag in ((args.embeddings, "--embeddings"), (args.dataset_path, "--dataset-path")):
        if value is None:
            raise cli.ValidationError(f"{flag} is required")
198
+
199
+
200
def main(args: argparse.Namespace) -> None:
    """Entry point for the sub-command: validate CLI arguments, then run the benchmark."""

    validate_args(args)
    evaluate_nabirds(args)