birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.data import DataLoader
|
|
7
|
+
|
|
8
|
+
import birder
|
|
9
|
+
from birder.common import cli
|
|
10
|
+
from birder.conf import settings
|
|
11
|
+
from birder.data.dataloader.webdataset import make_wds_loader
|
|
12
|
+
from birder.data.datasets.directory import make_image_dataset
|
|
13
|
+
from birder.data.datasets.webdataset import make_wds_dataset
|
|
14
|
+
from birder.data.datasets.webdataset import prepare_wds_args
|
|
15
|
+
from birder.data.datasets.webdataset import wds_args_from_info
|
|
16
|
+
from birder.inference.data_parallel import InferenceDataParallel
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# pylint: disable=too-many-branches
def evaluate(args: argparse.Namespace) -> None:
    """
    Evaluate every pretrained classification model matching args.filter on the given dataset.

    For each model: loads it at the requested dtype/device, optionally wraps it for
    multi-GPU or compiled inference, builds the matching transform and dataloader
    (directory or webdataset), runs inference and writes a per-model results CSV.
    """

    # Resolve the compute device from the mutually exclusive CLI flags
    if args.gpu:
        device = torch.device("cuda")
    elif args.mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    multi_gpu = args.parallel and torch.cuda.device_count() > 1
    if multi_gpu:
        logger.info(f"Using {torch.cuda.device_count()} {device} devices")
    else:
        if args.gpu_id is not None:
            torch.cuda.set_device(args.gpu_id)

        logger.info(f"Using device {device}")

    if args.fast_matmul or args.amp:
        torch.set_float32_matmul_precision("high")

    model_dtype: torch.dtype = getattr(torch, args.model_dtype)
    amp_dtype: torch.dtype = getattr(torch, args.amp_dtype)
    for model_name in birder.list_pretrained_models(args.filter):
        net, (class_to_idx, signature, rgb_stats, *_) = birder.load_pretrained_model(
            model_name, inference=True, device=device, dtype=model_dtype
        )
        if args.channels_last:
            net = net.to(memory_format=torch.channels_last)
            logger.debug("Using channels-last memory format")

        if multi_gpu:
            # Replicas gather outputs on the CPU to avoid device-0 memory pressure
            net = InferenceDataParallel(net, output_device="cpu", compile_replicas=args.compile)
        elif args.compile:
            net = torch.compile(net)

        size = birder.get_size_from_signature(signature) if args.size is None else args.size
        transform = birder.classification_transform(size, rgb_stats, args.center_crop, args.simple_crop)

        if args.wds:
            wds_path: str | list[str]
            if args.wds_info is not None:
                wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
                if args.wds_size is not None:
                    # Explicit size wins over the value recorded in the info file
                    dataset_size = args.wds_size
            else:
                wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)

            num_samples = dataset_size
            dataset = make_wds_dataset(
                wds_path,
                dataset_size=dataset_size,
                shuffle=False,
                samples_names=True,
                transform=transform,
            )
            inference_loader = make_wds_loader(
                dataset,
                args.batch_size,
                num_workers=args.num_workers,
                prefetch_factor=None,
                collate_fn=None,
                world_size=1,
                pin_memory=False,
                exact=True,
            )
        else:
            dataset = make_image_dataset(args.data_path, class_to_idx, transforms=transform)
            num_samples = len(dataset)
            inference_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

        with torch.inference_mode():
            results = birder.evaluate_classification(
                device,
                net,
                inference_loader,
                class_to_idx,
                args.tta,
                args.channels_last,
                model_dtype,
                args.amp,
                amp_dtype,
                num_samples=num_samples,
                sparse=args.save_sparse_results,
            )

        logger.info(f"{model_name}: accuracy={results.accuracy:.4f}")
        base_output_path = (
            f"{args.dir}/{model_name}_{len(class_to_idx)}_{size[0]}px_crop{args.center_crop}_{num_samples}"
        )
        results_file_suffix = "_sparse.csv" if args.save_sparse_results else ".csv"
        results.save(f"{base_output_path}{results_file_suffix}")
|
|
123
|
+
def set_parser(subparsers: Any) -> None:
    """Register the 'classification' sub-command and all of its CLI arguments."""

    subparser = subparsers.add_parser(
        "classification",
        allow_abbrev=False,
        help="evaluate pretrained classification models on a dataset",
        description="evaluate pretrained classification models on a dataset",
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval classification --filter '*il-all*' --fast-matmul --gpu "
            "data/validation_il-all_packed\n"
            "python -m birder.eval classification --amp --compile --gpu --gpu-id 1 data/testing\n"
            "python -m birder.eval classification --filter '*inat21*' --amp --compile --gpu "
            "--parallel ~/Datasets/inat2021/val\n"
            "python -m birder.eval classification --wds --wds-info data/validation_wds/info.json --gpu\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )

    # Shorthand; arguments are registered in the order they appear in --help
    add = subparser.add_argument
    add("--filter", type=str, help="models to evaluate (fnmatch type filter)")
    add("--compile", default=False, action="store_true", help="enable compilation")
    add("--channels-last", default=False, action="store_true", help="use channels-last memory format")
    add(
        "--model-dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="model dtype to use",
    )
    add("--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference")
    add(
        "--amp-dtype",
        type=str,
        choices=["float16", "bfloat16"],
        default="float16",
        help="whether to use float16 or bfloat16 for mixed precision",
    )
    add(
        "--fast-matmul", default=False, action="store_true", help="use fast matrix multiplication (affects precision)"
    )
    add("--tta", default=False, action="store_true", help="test time augmentation (oversampling)")
    add(
        "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
    )
    add("--batch-size", type=int, default=64, metavar="N", help="the batch size")
    add("-j", "--num-workers", type=int, default=8, metavar="N", help="number of preprocessing workers")
    add("--center-crop", type=float, default=1.0, help="center crop ratio to use during inference")
    add(
        "--simple-crop",
        default=False,
        action="store_true",
        help="use a simple crop that preserves aspect ratio but may trim parts of the image",
    )
    add("--dir", type=str, default="evaluate", help="place all outputs in a sub-directory (relative to results)")
    add("--gpu", default=False, action="store_true", help="use gpu")
    add("--gpu-id", type=int, metavar="ID", help="gpu id to use")
    add("--mps", default=False, action="store_true", help="use mps (Metal Performance Shaders) device")
    add("--parallel", default=False, action="store_true", help="use multiple gpus")
    add(
        "--save-sparse-results",
        default=False,
        action="store_true",
        help="save results object in memory-efficient sparse format (only top-k probabilities)",
    )
    add("--wds", default=False, action="store_true", help="evaluate a webdataset directory")
    add("--wds-size", type=int, metavar="N", help="size of the wds dataset")
    add("--wds-info", type=str, metavar="FILE", help="wds info file path")
    add("--wds-split", type=str, default="validation", metavar="NAME", help="wds dataset split to load")
    add("data_path", nargs="*", help="data files path (directories and files)")
    subparser.set_defaults(func=main)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def validate_args(args: argparse.Namespace) -> None:
    """
    Validate CLI argument combinations, raising cli.ValidationError on conflicts.

    Also normalizes args.size in place via cli.parse_size.
    """

    if args.amp and args.model_dtype != "float32":
        raise cli.ValidationError("--amp can only be used with --model-dtype float32")
    if args.center_crop <= 0.0 or args.center_crop > 1:
        raise cli.ValidationError(f"--center-crop must be in range of (0, 1.0], got {args.center_crop}")
    if args.parallel and not args.gpu:
        raise cli.ValidationError("--parallel requires --gpu to be set")
    if not args.wds and len(args.data_path) == 0:
        raise cli.ValidationError("Must provide at least one data source, --data-path or --wds")

    if args.wds:
        num_paths = len(args.data_path)
        if args.wds_info is None and num_paths == 0:
            raise cli.ValidationError("--wds requires a data path unless --wds-info is provided")
        if num_paths > 1:
            raise cli.ValidationError(f"--wds can have at most 1 --data-path, got {len(args.data_path)}")
        if args.wds_info is None and num_paths == 1:
            # Remote (URL-style) paths cannot be sized automatically
            if "://" in args.data_path[0] and args.wds_size is None:
                raise cli.ValidationError("--wds-size is required for remote --data-path")

    args.size = cli.parse_size(args.size)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def main(args: argparse.Namespace) -> None:
    """
    Entry point for the 'classification' sub-command.

    Validates the parsed arguments, ensures the output directory exists under the
    results root, then runs the evaluation loop.
    """

    validate_args(args)

    output_dir = settings.RESULTS_DIR.joinpath(args.dir)
    if not output_dir.exists():
        logger.info(f"Creating {output_dir} directory...")

    # exist_ok avoids a TOCTOU race if the directory appears between the check above
    # and this call (e.g. concurrent evaluation runs sharing the same results dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    evaluate(args)
|
|
File without changes
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AMI clustering for unsupervised embedding evaluation
|
|
3
|
+
|
|
4
|
+
Paper "An Empirical Study into Clustering of Unseen Datasets with Self-Supervised Encoders",
|
|
5
|
+
https://arxiv.org/abs/2406.02465
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import numpy.typing as npt
|
|
13
|
+
from sklearn.cluster import AgglomerativeClustering
|
|
14
|
+
from sklearn.metrics import adjusted_mutual_info_score
|
|
15
|
+
|
|
16
|
+
from birder.eval._embeddings import l2_normalize
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import umap
|
|
22
|
+
|
|
23
|
+
_HAS_UMAP = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
_HAS_UMAP = False
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def evaluate_ami(
    features: npt.NDArray[np.float32],
    labels: npt.NDArray[np.int_],
    n_clusters: int,
    umap_dim: int = 50,
    l2_normalize_features: bool = True,
    seed: Optional[int] = None,
) -> float:
    """
    Score embedding quality via UMAP reduction, Ward clustering and AMI

    Parameters
    ----------
    features
        Feature array of shape (n_samples, embedding_dim).
    labels
        True labels of shape (n_samples,).
    n_clusters
        Number of clusters (should match number of true classes).
    umap_dim
        Target dimensionality for UMAP reduction.
    l2_normalize_features
        If True, applies row-wise L2 normalization before UMAP.
    seed
        Random seed for UMAP reproducibility. When None, uses all available
        cores (n_jobs=-1) but results are non-deterministic. When set, forces n_jobs=1.

    Returns
    -------
    ami_score
        Adjusted Mutual Information score between true labels and cluster assignments.
    """

    assert _HAS_UMAP, "'pip install umap-learn' to use AMI evaluation"

    # A fixed seed forces single-threaded UMAP; otherwise use every core
    n_jobs = 1
    if seed is None:
        logger.warning("No seed set, UMAP results will be non-deterministic (using n_jobs=-1)")
        n_jobs = -1

    if l2_normalize_features:
        features = l2_normalize(features)

    embedded = umap.UMAP(
        n_components=umap_dim, min_dist=0.0, n_jobs=n_jobs, random_state=seed
    ).fit_transform(features)

    assignments = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward").fit_predict(embedded)

    return float(adjusted_mutual_info_score(labels, assignments))
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
K-Nearest Neighbors classifier for few-shot learning evaluation
|
|
3
|
+
|
|
4
|
+
Uses cosine similarity (dot product of L2-normalized features) with temperature-scaled softmax voting.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
|
|
10
|
+
from birder.eval._embeddings import l2_normalize
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def evaluate_knn(
    train_features: npt.NDArray[np.float32],
    train_labels: npt.NDArray[np.int_],
    test_features: npt.NDArray[np.float32],
    test_labels: npt.NDArray[np.int_],
    k: int,
    temperature: float = 0.07,
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
    """
    Classify test samples by soft k-NN voting over cosine similarity

    Parameters
    ----------
    train_features
        Training features of shape (n_train, embedding_dim).
    train_labels
        Training labels of shape (n_train,).
    test_features
        Test features of shape (n_test, embedding_dim).
    test_labels
        Test labels of shape (n_test,).
    k
        Number of nearest neighbors.
    temperature
        Temperature for softmax scaling.

    Returns
    -------
    y_pred
        Predicted labels for test samples.
    y_true
        True labels for test samples (same as test_labels).
    """

    # Cosine similarity is the dot product of unit-norm rows
    sims = l2_normalize(test_features) @ l2_normalize(train_features).T

    # Indices, similarities and labels of the k most similar training samples per query
    neighbor_idx = np.argsort(-sims, axis=1)[:, :k]
    neighbor_sims = np.take_along_axis(sims, neighbor_idx, axis=1)
    neighbor_labels = train_labels[neighbor_idx]

    # Softmax over temperature-scaled similarities (max-subtracted for numerical stability)
    scaled = neighbor_sims / temperature
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    weights = np.exp(scaled)
    weights = weights / weights.sum(axis=1, keepdims=True)

    # Scatter-add the per-neighbor weights into per-class vote totals
    n_test = len(test_features)
    num_classes = train_labels.max() + 1
    votes = np.zeros((n_test, num_classes), dtype=np.float32)
    np.add.at(votes, (np.arange(n_test)[:, None], neighbor_labels), weights)

    y_pred = votes.argmax(axis=1).astype(np.int_)

    return (y_pred, test_labels)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Linear probing on frozen embeddings
|
|
3
|
+
|
|
4
|
+
Trains a single linear layer with cross-entropy loss for classification evaluation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import numpy.typing as npt
|
|
11
|
+
import torch
|
|
12
|
+
from torch import nn
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
from torch.utils.data import TensorDataset
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def train_linear_probe(
|
|
20
|
+
train_features: npt.NDArray[np.float32],
|
|
21
|
+
train_labels: npt.NDArray[np.int_],
|
|
22
|
+
num_classes: int,
|
|
23
|
+
device: torch.device,
|
|
24
|
+
epochs: int,
|
|
25
|
+
batch_size: int,
|
|
26
|
+
lr: float = 1e-4,
|
|
27
|
+
step_size: int = 20,
|
|
28
|
+
seed: int = 0,
|
|
29
|
+
) -> nn.Linear:
|
|
30
|
+
"""
|
|
31
|
+
Train a linear probe on frozen embeddings
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
train_features
|
|
36
|
+
Training features of shape (n_train, embedding_dim).
|
|
37
|
+
train_labels
|
|
38
|
+
Training labels of shape (n_train,), integer class indices.
|
|
39
|
+
num_classes
|
|
40
|
+
Number of output classes.
|
|
41
|
+
device
|
|
42
|
+
Device to train on.
|
|
43
|
+
epochs
|
|
44
|
+
Number of training epochs.
|
|
45
|
+
batch_size
|
|
46
|
+
Batch size for training.
|
|
47
|
+
lr
|
|
48
|
+
Learning rate.
|
|
49
|
+
step_size
|
|
50
|
+
Number of epochs between learning rate decay steps.
|
|
51
|
+
seed
|
|
52
|
+
Random seed for reproducibility.
|
|
53
|
+
|
|
54
|
+
Returns
|
|
55
|
+
-------
|
|
56
|
+
Trained linear layer.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
torch.manual_seed(seed)
|
|
60
|
+
np.random.seed(seed)
|
|
61
|
+
|
|
62
|
+
input_dim = train_features.shape[1]
|
|
63
|
+
model = nn.Linear(input_dim, num_classes).to(device)
|
|
64
|
+
|
|
65
|
+
x_train = torch.from_numpy(train_features).float()
|
|
66
|
+
y_train = torch.from_numpy(train_labels).long()
|
|
67
|
+
|
|
68
|
+
dataset = TensorDataset(x_train, y_train)
|
|
69
|
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
70
|
+
|
|
71
|
+
criterion = nn.CrossEntropyLoss()
|
|
72
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
|
73
|
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.5)
|
|
74
|
+
|
|
75
|
+
model.train()
|
|
76
|
+
for epoch in range(epochs):
|
|
77
|
+
total_loss = 0.0
|
|
78
|
+
for batch_x, batch_y in loader:
|
|
79
|
+
batch_x = batch_x.to(device)
|
|
80
|
+
batch_y = batch_y.to(device)
|
|
81
|
+
|
|
82
|
+
optimizer.zero_grad()
|
|
83
|
+
logits = model(batch_x)
|
|
84
|
+
loss = criterion(logits, batch_y)
|
|
85
|
+
loss.backward()
|
|
86
|
+
optimizer.step()
|
|
87
|
+
|
|
88
|
+
total_loss += loss.item() * batch_x.size(0)
|
|
89
|
+
|
|
90
|
+
scheduler.step()
|
|
91
|
+
|
|
92
|
+
if (epoch + 1) % 10 == 0:
|
|
93
|
+
avg_loss = total_loss / len(dataset)
|
|
94
|
+
logger.debug(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}")
|
|
95
|
+
|
|
96
|
+
return model
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def evaluate_linear_probe(
    model: nn.Linear,
    test_features: npt.NDArray[np.float32],
    test_labels: npt.NDArray[np.int_],
    batch_size: int = 128,
    device: torch.device = torch.device("cpu"),
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
    """
    Run a trained linear probe over the test set and collect argmax predictions

    Parameters
    ----------
    model
        Trained linear layer.
    test_features
        Test features of shape (n_test, embedding_dim).
    test_labels
        Test labels of shape (n_test,), integer class indices.
    batch_size
        Batch size for inference.
    device
        Device to run on.

    Returns
    -------
    y_pred
        Predicted labels for test samples.
    y_true
        True labels for test samples (same as test_labels).
    """

    eval_ds = TensorDataset(torch.from_numpy(test_features).float(), torch.from_numpy(test_labels).long())
    eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False)

    model.to(device)
    model.eval()

    pred_batches: list[torch.Tensor] = []
    label_batches: list[torch.Tensor] = []
    with torch.inference_mode():
        for xb, yb in eval_loader:
            # Labels stay on the CPU; only the features move to the inference device
            pred_batches.append(model(xb.to(device)).argmax(dim=1).cpu())
            label_batches.append(yb)

    y_pred = torch.concat(pred_batches, dim=0).numpy().astype(np.int_)
    y_true = torch.concat(label_batches, dim=0).numpy().astype(np.int_)

    return (y_pred, y_true)
|