birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/eval/benchmarks/fishnet.py
@@ -0,0 +1,318 @@
+"""
+FishNet benchmark using MLP probe for multi-label trait prediction
+
+Paper "FishNet: A Large-scale Dataset and Benchmark for Fish Recognition, Detection, and Functional Trait Prediction"
+https://fishnet-2023.github.io/
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import polars as pl
+import torch
+from rich.console import Console
+from rich.table import Table
+from sklearn.metrics import f1_score
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.datahub.evaluation import FishNet
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.mlp import evaluate_mlp
+from birder.eval.methods.mlp import train_mlp
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("FishNet (MLP)", style="dim")
+    table.add_column("Macro F1", justify="right")
+    table.add_column("Std", justify="right")
+    table.add_column("Exact Match", justify="right")
+    table.add_column("Std", justify="right")
+    table.add_column("Runs", justify="right")
+    for result in results:
+        row = [
+            Path(result["embeddings_file"]).name,
+            f"{result['macro_f1']:.4f}",
+            f"{result['macro_f1_std']:.4f}",
+            f"{result['exact_match_acc']:.4f}",
+            f"{result['exact_match_acc_std']:.4f}",
+            f"{result['num_runs']}",
+        ]
+
+        table.add_row(*row)
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], trait_names: list[str], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        row: dict[str, Any] = {
+            "embeddings_file": result["embeddings_file"],
+            "method": result["method"],
+            "macro_f1": result["macro_f1"],
+            "macro_f1_std": result["macro_f1_std"],
+            "exact_match_acc": result["exact_match_acc"],
+            "exact_match_acc_std": result["exact_match_acc_std"],
+            "num_runs": result["num_runs"],
+        }
+        for trait in trait_names:
+            row[f"f1_{trait}"] = result["per_trait_f1"].get(trait)
+
+        rows.append(row)
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_fishnet_data(csv_path: Path, trait_columns: list[str]) -> pl.DataFrame:
+    """
+    Load FishNet CSV and prepare metadata
+
+    Returns DataFrame with columns: id, trait labels (0/1)
+    """
+
+    df = pl.read_csv(csv_path)
+    df = df.with_columns(
+        pl.col("image")
+        .str.extract(r"([^/]+)$")  # Get filename (last path segment)
+        .str.replace(r"\.[^.]+$", "")  # Remove extension
+        .alias("id")
+    )
+
+    # Encode FeedingPath: benthic=0, pelagic=1
+    df = df.with_columns(
+        pl.when(pl.col("FeedingPath") == "pelagic")
+        .then(pl.lit(1))
+        .when(pl.col("FeedingPath") == "benthic")
+        .then(pl.lit(0))
+        .otherwise(pl.lit(None))
+        .alias("FeedingPath_encoded")
+    )
+
+    # Select relevant columns
+    other_traits = [t for t in trait_columns if t != "FeedingPath"]
+    select_cols = ["id", "FeedingPath_encoded"] + other_traits
+    df = df.select(select_cols)
+
+    # Rename FeedingPath_encoded back to FeedingPath
+    df = df.rename({"FeedingPath_encoded": "FeedingPath"})
+
+    # Filter rows with any null trait values
+    for trait in trait_columns:
+        df = df.filter(pl.col(trait).is_not_null())
+
+    return df
+
+
+def _load_embeddings_with_labels(
+    embeddings_path: str, train_df: pl.DataFrame, test_df: pl.DataFrame, trait_columns: list[str]
+) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    # Join with train data
+    train_joined = train_df.join(emb_df, on="id", how="inner")
+    if train_joined.height < train_df.height:
+        logger.warning(f"Train: dropped {train_df.height - train_joined.height} samples (missing embeddings)")
+
+    # Join with test data
+    test_joined = test_df.join(emb_df, on="id", how="inner")
+    if test_joined.height < test_df.height:
+        logger.warning(f"Test: dropped {test_df.height - test_joined.height} samples (missing embeddings)")
+
+    # Extract features and labels
+    x_train = np.array(train_joined.get_column("embedding").to_list(), dtype=np.float32)
+    y_train = train_joined.select(trait_columns).to_numpy().astype(np.float32)
+
+    x_test = np.array(test_joined.get_column("embedding").to_list(), dtype=np.float32)
+    y_test = test_joined.select(trait_columns).to_numpy().astype(np.float32)
+
+    logger.info(f"Train: {x_train.shape[0]} samples, Test: {x_test.shape[0]} samples")
+    logger.info(f"Features: {x_train.shape[1]} dims, Traits: {len(trait_columns)}")
+
+    return (x_train, y_train, x_test, y_test)
+
+
+# pylint: disable=too-many-locals
+def evaluate_fishnet_single(
+    x_train: npt.NDArray[np.float32],
+    y_train: npt.NDArray[np.float32],
+    x_test: npt.NDArray[np.float32],
+    y_test: npt.NDArray[np.float32],
+    trait_columns: list[str],
+    args: argparse.Namespace,
+    embeddings_path: str,
+    device: torch.device,
+) -> dict[str, Any]:
+    num_classes = len(trait_columns)
+
+    scores: list[float] = []
+    exact_match_scores: list[float] = []
+    per_trait_f1_runs: list[dict[str, float]] = []
+
+    for run in range(args.runs):
+        run_seed = args.seed + run
+        logger.info(f"Run {run + 1}/{args.runs} (seed={run_seed})")
+
+        # Train MLP
+        model = train_mlp(
+            x_train,
+            y_train,
+            num_classes=num_classes,
+            device=device,
+            epochs=args.epochs,
+            batch_size=args.batch_size,
+            lr=args.lr,
+            hidden_dim=args.hidden_dim,
+            dropout=args.dropout,
+            seed=run_seed,
+        )
+
+        # Evaluate
+        y_pred, y_true, macro_f1 = evaluate_mlp(model, x_test, y_test, batch_size=args.batch_size, device=device)
+        scores.append(macro_f1)
+        exact_match_acc = float(np.mean(np.all(y_pred == y_true, axis=1)))
+        exact_match_scores.append(exact_match_acc)
+
+        # Per-trait F1
+        per_trait_f1: dict[str, float] = {}
+        for i, trait in enumerate(trait_columns):
+            trait_f1 = f1_score(y_true[:, i], y_pred[:, i], average="binary", zero_division=0.0)
+            per_trait_f1[trait] = float(trait_f1)
+
+        per_trait_f1_runs.append(per_trait_f1)
+        logger.info(f"Run {run + 1}/{args.runs} - Macro F1: {macro_f1:.4f}, Exact Match: {exact_match_acc:.4f}")
+
+    # Average results
+    scores_arr = np.array(scores)
+    mean_f1 = float(scores_arr.mean())
+    std_f1 = float(scores_arr.std(ddof=1)) if len(scores) > 1 else 0.0
+    exact_scores_arr = np.array(exact_match_scores)
+    mean_exact = float(exact_scores_arr.mean())
+    std_exact = float(exact_scores_arr.std(ddof=1)) if len(exact_match_scores) > 1 else 0.0
+
+    # Average per-trait F1 across runs
+    avg_per_trait_f1: dict[str, float] = {}
+    for trait in trait_columns:
+        trait_scores = [run_f1[trait] for run_f1 in per_trait_f1_runs]
+        avg_per_trait_f1[trait] = float(np.mean(trait_scores))
+
+    logger.info(f"Mean Macro F1 over {args.runs} runs: {mean_f1:.4f} +/- {std_f1:.4f} (std)")
+    logger.info(f"Mean Exact Match over {args.runs} runs: {mean_exact:.4f} +/- {std_exact:.4f} (std)")
+    for trait, f1 in avg_per_trait_f1.items():
+        logger.info(f"  {trait}: {f1:.4f}")
+
+    return {
+        "method": "mlp",
+        "macro_f1": mean_f1,
+        "macro_f1_std": std_f1,
+        "exact_match_acc": mean_exact,
+        "exact_match_acc_std": std_exact,
+        "num_runs": args.runs,
+        "per_trait_f1": avg_per_trait_f1,
+        "embeddings_file": str(embeddings_path),
+    }
+
+
+def evaluate_fishnet(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    if args.gpu is True:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    if args.gpu_id is not None:
+        torch.cuda.set_device(args.gpu_id)
+
+    logger.info(f"Using device {device}")
+    logger.info(f"Loading FishNet dataset from {args.dataset_path}")
+    dataset = FishNet(args.dataset_path)
+    trait_columns = dataset.trait_columns
+
+    train_df = _load_fishnet_data(dataset.train_csv, trait_columns)
+    test_df = _load_fishnet_data(dataset.test_csv, trait_columns)
+    logger.info(f"Train samples: {train_df.height}, Test samples: {test_df.height}")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        x_train, y_train, x_test, y_test = _load_embeddings_with_labels(
+            embeddings_path, train_df, test_df, trait_columns
+        )
+
+        result = evaluate_fishnet_single(x_train, y_train, x_test, y_test, trait_columns, args, embeddings_path, device)
+        results.append(result)
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("fishnet.csv")
+        _write_results_csv(results, trait_columns, output_path)

+    toc = time.time()
+    logger.info(f"FishNet benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "fishnet",
+        allow_abbrev=False,
+        help="run FishNet benchmark - 9 trait multi-label classification using MLP probe",
+        description="run FishNet benchmark - 9 trait multi-label classification using MLP probe",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval fishnet --embeddings "
+            "results/vit_b16_224px_embeddings.parquet "
+            "--dataset-path ~/Datasets/fishnet --dry-run\n"
+            "python -m birder.eval fishnet --embeddings results/fishnet/*.parquet "
+            "--dataset-path ~/Datasets/fishnet --gpu --gpu-id 1\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to FishNet dataset root")
+    subparser.add_argument("--runs", type=int, default=3, help="number of evaluation runs")
+    subparser.add_argument("--epochs", type=int, default=100, help="training epochs per run")
+    subparser.add_argument("--batch-size", type=int, default=128, help="batch size for training and inference")
+    subparser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+    subparser.add_argument("--hidden-dim", type=int, default=512, help="MLP hidden layer dimension")
+    subparser.add_argument("--dropout", type=float, default=0.5, help="dropout probability")
+    subparser.add_argument("--seed", type=int, default=0, help="base random seed")
+    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
+    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
+    subparser.add_argument(
+        "--dir", type=str, default="fishnet", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.dataset_path is None:
+        raise cli.ValidationError("--dataset-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_fishnet(args)
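
Both new benchmark files end with the same registration convention: set_parser() attaches a subparser and binds the module's main() through subparser.set_defaults(func=main). The dispatcher that collects these hooks, birder/eval/__main__.py (+74 lines), is not shown in this diff; the following is a minimal hypothetical sketch of how such a dispatcher could wire the subcommands, assuming only the convention visible above. The shipped module may differ.

# Hypothetical dispatcher sketch; not the shipped birder/eval/__main__.py
import argparse

from birder.eval.benchmarks import fishnet
from birder.eval.benchmarks import flowers102


def main() -> None:
    parser = argparse.ArgumentParser(prog="birder.eval", allow_abbrev=False)
    subparsers = parser.add_subparsers(dest="benchmark", required=True)

    # Each benchmark module attaches its own subparser and binds its
    # entry point via subparser.set_defaults(func=main)
    fishnet.set_parser(subparsers)
    flowers102.set_parser(subparsers)

    args = parser.parse_args()
    args.func(args)  # Dispatch to the selected benchmark's main()


if __name__ == "__main__":
    main()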
birder/eval/benchmarks/flowers102.py
@@ -0,0 +1,210 @@
+"""
+Flowers102 benchmark using SimpleShot for flower species classification
+
+Paper "Automated Flower Classification over a Large Number of Classes"
+https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import polars as pl
+from rich.console import Console
+from rich.table import Table
+from torchvision.datasets import ImageFolder
+
+from birder.common import cli
+from birder.common import lib
+from birder.conf import settings
+from birder.eval._embeddings import load_embeddings
+from birder.eval.methods.simpleshot import evaluate_simpleshot
+
+logger = logging.getLogger(__name__)
+
+
+def _print_summary_table(results: list[dict[str, Any]]) -> None:
+    console = Console()
+
+    table = Table(show_header=True, header_style="bold dark_magenta")
+    table.add_column("Flowers102 (SimpleShot)", style="dim")
+    table.add_column("Val Acc", justify="right")
+    table.add_column("Test Acc", justify="right")
+
+    for result in results:
+        table.add_row(
+            Path(result["embeddings_file"]).name,
+            f"{result['val_accuracy']:.4f}",
+            f"{result['test_accuracy']:.4f}",
+        )
+
+    console.print(table)
+
+
+def _write_results_csv(results: list[dict[str, Any]], output_path: Path) -> None:
+    rows: list[dict[str, Any]] = []
+    for result in results:
+        rows.append(
+            {
+                "embeddings_file": result["embeddings_file"],
+                "method": result["method"],
+                "val_accuracy": result["val_accuracy"],
+                "test_accuracy": result["test_accuracy"],
+            }
+        )
+
+    pl.DataFrame(rows).write_csv(output_path)
+    logger.info(f"Results saved to {output_path}")
+
+
+def _load_flowers102_metadata(dataset_path: Path) -> pl.DataFrame:
+    rows: list[dict[str, Any]] = []
+    for split in ["training", "validation", "testing"]:
+        split_dir = dataset_path.joinpath(split)
+        if not split_dir.exists():
+            continue
+
+        dataset = ImageFolder(str(split_dir))
+        for path, label in dataset.samples:
+            rows.append({"id": Path(path).stem, "label": label, "split": split})
+
+    return pl.DataFrame(rows)
+
+
+def _load_embeddings_with_split(embeddings_path: str, metadata_df: pl.DataFrame) -> tuple[
+    npt.NDArray[np.float32],
+    npt.NDArray[np.int_],
+    npt.NDArray[np.float32],
+    npt.NDArray[np.int_],
+    npt.NDArray[np.float32],
+    npt.NDArray[np.int_],
+]:
+    logger.info(f"Loading embeddings from {embeddings_path}")
+    sample_ids, all_features = load_embeddings(embeddings_path)
+    emb_df = pl.DataFrame({"id": sample_ids, "embedding": all_features.tolist()})
+
+    joined = metadata_df.join(emb_df, on="id", how="inner")
+    if joined.height < metadata_df.height:
+        logger.warning(f"Join dropped {metadata_df.height - joined.height} samples (missing embeddings)")
+
+    all_features = np.array(joined.get_column("embedding").to_list(), dtype=np.float32)
+    all_labels = joined.get_column("label").to_numpy().astype(np.int_)
+    splits = joined.get_column("split").to_list()
+
+    is_train = np.array([s == "training" for s in splits], dtype=bool)
+    is_val = np.array([s == "validation" for s in splits], dtype=bool)
+    is_test = np.array([s == "testing" for s in splits], dtype=bool)
+
+    num_classes = all_labels.max() + 1
+    logger.info(
+        f"Loaded {all_features.shape[0]} samples with {all_features.shape[1]} dimensions, {num_classes} classes"
+    )
+
+    x_train = all_features[is_train]
+    y_train = all_labels[is_train]
+    x_val = all_features[is_val]
+    y_val = all_labels[is_val]
+    x_test = all_features[is_test]
+    y_test = all_labels[is_test]
+
+    logger.info(f"Train: {len(y_train)}, Val: {len(y_val)}, Test: {len(y_test)} samples")
+
+    return (x_train, y_train, x_val, y_val, x_test, y_test)
+
+
+def evaluate_flowers102_single(
+    x_train: npt.NDArray[np.float32],
+    y_train: npt.NDArray[np.int_],
+    x_val: npt.NDArray[np.float32],
+    y_val: npt.NDArray[np.int_],
+    x_test: npt.NDArray[np.float32],
+    y_test: npt.NDArray[np.int_],
+    embeddings_path: str,
+) -> dict[str, Any]:
+    # Evaluate on validation set
+    y_pred_val, y_true_val = evaluate_simpleshot(x_train, y_train, x_val, y_val)
+    val_acc = float(np.mean(y_pred_val == y_true_val))
+    logger.info(f"Validation accuracy: {val_acc:.4f}")
+
+    # Evaluate on test set
+    y_pred_test, y_true_test = evaluate_simpleshot(x_train, y_train, x_test, y_test)
+    test_acc = float(np.mean(y_pred_test == y_true_test))
+    logger.info(f"Test accuracy: {test_acc:.4f}")
+
+    return {
+        "method": "simpleshot",
+        "val_accuracy": val_acc,
+        "test_accuracy": test_acc,
+        "embeddings_file": str(embeddings_path),
+    }
+
+
+def evaluate_flowers102(args: argparse.Namespace) -> None:
+    tic = time.time()
+
+    logger.info(f"Loading Flowers102 dataset from {args.dataset_path}")
+    dataset_path = Path(args.dataset_path)
+    metadata_df = _load_flowers102_metadata(dataset_path)
+    logger.info(f"Loaded metadata for {metadata_df.height} images")
+
+    results: list[dict[str, Any]] = []
+    total = len(args.embeddings)
+    for idx, embeddings_path in enumerate(args.embeddings, start=1):
+        logger.info(f"Processing embeddings {idx}/{total}: {embeddings_path}")
+        x_train, y_train, x_val, y_val, x_test, y_test = _load_embeddings_with_split(embeddings_path, metadata_df)
+
+        result = evaluate_flowers102_single(x_train, y_train, x_val, y_val, x_test, y_test, embeddings_path)
+        results.append(result)
+
+    _print_summary_table(results)
+
+    if args.dry_run is False:
+        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir.joinpath("flowers102.csv")
+        _write_results_csv(results, output_path)
+
+    toc = time.time()
+    logger.info(f"Flowers102 benchmark completed in {lib.format_duration(toc - tic)}")
+
+
+def set_parser(subparsers: Any) -> None:
+    subparser = subparsers.add_parser(
+        "flowers102",
+        allow_abbrev=False,
+        help="run Flowers102 benchmark - 102 class classification using SimpleShot",
+        description="run Flowers102 benchmark - 102 class classification using SimpleShot",
+        epilog=(
+            "Usage examples:\n"
+            "python -m birder.eval flowers102 --embeddings "
+            "results/flowers102_embeddings.parquet --dataset-path ~/Datasets/Flowers102 --dry-run\n"
+            "python -m birder.eval flowers102 --embeddings results/flowers102/*.parquet "
+            "--dataset-path ~/Datasets/Flowers102\n"
+        ),
+        formatter_class=cli.ArgumentHelpFormatter,
+    )
+    subparser.add_argument(
+        "--embeddings", type=str, nargs="+", metavar="FILE", help="paths to embeddings parquet files"
+    )
+    subparser.add_argument("--dataset-path", type=str, metavar="PATH", help="path to Flowers102 dataset root")
+    subparser.add_argument(
+        "--dir", type=str, default="flowers102", help="place all outputs in a sub-directory (relative to results)"
+    )
+    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
+    subparser.set_defaults(func=main)
+
+
+def validate_args(args: argparse.Namespace) -> None:
+    if args.embeddings is None:
+        raise cli.ValidationError("--embeddings is required")
+    if args.dataset_path is None:
+        raise cli.ValidationError("--dataset-path is required")
+
+
+def main(args: argparse.Namespace) -> None:
+    validate_args(args)
+    evaluate_flowers102(args)
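
For context, SimpleShot classifies a query embedding by the nearest class centroid after centering features on the training mean and L2-normalizing them. The actual method implementation, birder/eval/methods/simpleshot.py (+100 lines), is not part of this hunk; the following is a minimal numpy sketch consistent with the evaluate_simpleshot call signature used above, not the shipped code.

# Hypothetical SimpleShot sketch; birder.eval.methods.simpleshot may differ
import numpy as np
import numpy.typing as npt


def evaluate_simpleshot(
    x_train: npt.NDArray[np.float32],
    y_train: npt.NDArray[np.int_],
    x_test: npt.NDArray[np.float32],
    y_test: npt.NDArray[np.int_],
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
    # Center by the training mean, then L2-normalize (the "CL2N" transform)
    mean = x_train.mean(axis=0, keepdims=True)
    x_train = (x_train - mean) / (np.linalg.norm(x_train - mean, axis=1, keepdims=True) + 1e-12)
    x_test = (x_test - mean) / (np.linalg.norm(x_test - mean, axis=1, keepdims=True) + 1e-12)

    # One centroid per class in the training set
    classes = np.unique(y_train)
    centroids = np.stack([x_train[y_train == c].mean(axis=0) for c in classes])

    # Nearest centroid by Euclidean distance; (N, C) matrix is fine at benchmark scale
    dists = np.linalg.norm(x_test[:, None, :] - centroids[None, :, :], axis=2)
    y_pred = classes[dists.argmin(axis=1)]

    return (y_pred, y_test)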