birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,235 @@
1
+ import argparse
2
+ import logging
3
+ from typing import Any
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+
8
+ import birder
9
+ from birder.common import cli
10
+ from birder.conf import settings
11
+ from birder.data.dataloader.webdataset import make_wds_loader
12
+ from birder.data.datasets.directory import make_image_dataset
13
+ from birder.data.datasets.webdataset import make_wds_dataset
14
+ from birder.data.datasets.webdataset import prepare_wds_args
15
+ from birder.data.datasets.webdataset import wds_args_from_info
16
+ from birder.inference.data_parallel import InferenceDataParallel
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
# pylint: disable=too-many-branches
def evaluate(args: argparse.Namespace) -> None:
    """
    Run classification evaluation for every pretrained model matching args.filter.

    For each model: load it (optionally channels-last / compiled / data-parallel),
    build the matching transform and dataloader (image directory or webdataset),
    run inference and save a per-model results CSV under args.dir.
    """

    # Device selection by explicit flags: CUDA > MPS > CPU
    if args.gpu is True:
        device = torch.device("cuda")
    elif args.mps is True:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if args.parallel is True and torch.cuda.device_count() > 1:
        logger.info(f"Using {torch.cuda.device_count()} {device} devices")
    else:
        # Single-device path: honor an explicit GPU id if one was given
        if args.gpu_id is not None:
            torch.cuda.set_device(args.gpu_id)

        logger.info(f"Using device {device}")

    # "high" matmul precision trades a little float32 accuracy for speed
    if args.fast_matmul is True or args.amp is True:
        torch.set_float32_matmul_precision("high")

    model_dtype: torch.dtype = getattr(torch, args.model_dtype)
    amp_dtype: torch.dtype = getattr(torch, args.amp_dtype)
    model_list = birder.list_pretrained_models(args.filter)
    for model_name in model_list:
        net, (class_to_idx, signature, rgb_stats, *_) = birder.load_pretrained_model(
            model_name, inference=True, device=device, dtype=model_dtype
        )
        if args.channels_last is True:
            net = net.to(memory_format=torch.channels_last)
            logger.debug("Using channels-last memory format")

        if args.parallel is True and torch.cuda.device_count() > 1:
            # Replicate across GPUs, gathering outputs on the CPU
            net = InferenceDataParallel(net, output_device="cpu", compile_replicas=args.compile)
        elif args.compile is True:
            net = torch.compile(net)

        # Fall back to the input size recorded in the model signature when not overridden
        if args.size is None:
            size = birder.get_size_from_signature(signature)
        else:
            size = args.size

        transform = birder.classification_transform(size, rgb_stats, args.center_crop, args.simple_crop)

        if args.wds is True:
            wds_path: str | list[str]
            if args.wds_info is not None:
                wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
                if args.wds_size is not None:
                    # Explicit --wds-size overrides the size taken from the info file
                    dataset_size = args.wds_size
            else:
                wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)

            num_samples = dataset_size
            dataset = make_wds_dataset(
                wds_path,
                dataset_size=dataset_size,
                shuffle=False,
                samples_names=True,
                transform=transform,
            )
            inference_loader = make_wds_loader(
                dataset,
                args.batch_size,
                num_workers=args.num_workers,
                prefetch_factor=None,
                collate_fn=None,
                world_size=1,
                pin_memory=False,
                exact=True,
            )
        else:
            dataset = make_image_dataset(args.data_path, class_to_idx, transforms=transform)
            num_samples = len(dataset)
            inference_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

        with torch.inference_mode():
            results = birder.evaluate_classification(
                device,
                net,
                inference_loader,
                class_to_idx,
                args.tta,
                args.channels_last,
                model_dtype,
                args.amp,
                amp_dtype,
                num_samples=num_samples,
                sparse=args.save_sparse_results,
            )

        logger.info(f"{model_name}: accuracy={results.accuracy:.4f}")
        # Output name encodes model, class count, input height, crop ratio and sample count
        base_output_path = (
            f"{args.dir}/{model_name}_{len(class_to_idx)}_{size[0]}px_crop{args.center_crop}_{num_samples}"
        )
        if args.save_sparse_results is True:
            results_file_suffix = "_sparse.csv"
        else:
            results_file_suffix = ".csv"

        results.save(f"{base_output_path}{results_file_suffix}")
121
+
122
+
123
def set_parser(subparsers: Any) -> None:
    """
    Register the "classification" sub-command and all of its CLI arguments
    on the given argparse subparsers object.
    """

    subparser = subparsers.add_parser(
        "classification",
        allow_abbrev=False,
        help="evaluate pretrained classification models on a dataset",
        description="evaluate pretrained classification models on a dataset",
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval classification --filter '*il-all*' --fast-matmul --gpu "
            "data/validation_il-all_packed\n"
            "python -m birder.eval classification --amp --compile --gpu --gpu-id 1 data/testing\n"
            "python -m birder.eval classification --filter '*inat21*' --amp --compile --gpu "
            "--parallel ~/Datasets/inat2021/val\n"
            "python -m birder.eval classification --wds --wds-info data/validation_wds/info.json --gpu\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument("--filter", type=str, help="models to evaluate (fnmatch type filter)")
    subparser.add_argument("--compile", default=False, action="store_true", help="enable compilation")
    subparser.add_argument(
        "--channels-last", default=False, action="store_true", help="use channels-last memory format"
    )
    subparser.add_argument(
        "--model-dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="model dtype to use",
    )
    subparser.add_argument(
        "--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference"
    )
    subparser.add_argument(
        "--amp-dtype",
        type=str,
        choices=["float16", "bfloat16"],
        default="float16",
        help="whether to use float16 or bfloat16 for mixed precision",
    )
    subparser.add_argument(
        "--fast-matmul", default=False, action="store_true", help="use fast matrix multiplication (affects precision)"
    )
    subparser.add_argument("--tta", default=False, action="store_true", help="test time augmentation (oversampling)")
    subparser.add_argument(
        "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
    )
    subparser.add_argument("--batch-size", type=int, default=64, metavar="N", help="the batch size")
    subparser.add_argument(
        "-j", "--num-workers", type=int, default=8, metavar="N", help="number of preprocessing workers"
    )
    subparser.add_argument("--center-crop", type=float, default=1.0, help="center crop ratio to use during inference")
    subparser.add_argument(
        "--simple-crop",
        default=False,
        action="store_true",
        help="use a simple crop that preserves aspect ratio but may trim parts of the image",
    )
    subparser.add_argument(
        "--dir", type=str, default="evaluate", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
    subparser.add_argument(
        "--mps", default=False, action="store_true", help="use mps (Metal Performance Shaders) device"
    )
    subparser.add_argument("--parallel", default=False, action="store_true", help="use multiple gpus")
    subparser.add_argument(
        "--save-sparse-results",
        default=False,
        action="store_true",
        help="save results object in memory-efficient sparse format (only top-k probabilities)",
    )
    subparser.add_argument("--wds", default=False, action="store_true", help="evaluate a webdataset directory")
    subparser.add_argument("--wds-size", type=int, metavar="N", help="size of the wds dataset")
    subparser.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
    subparser.add_argument(
        "--wds-split", type=str, default="validation", metavar="NAME", help="wds dataset split to load"
    )
    subparser.add_argument("data_path", nargs="*", help="data files path (directories and files)")
    # Dispatch through main() so argument validation runs before evaluation
    subparser.set_defaults(func=main)
203
+
204
+
205
def validate_args(args: argparse.Namespace) -> None:
    """
    Enforce cross-argument constraints (raising cli.ValidationError on conflicts)
    and normalize args.size in place via cli.parse_size.
    """

    if args.amp is True and args.model_dtype != "float32":
        raise cli.ValidationError("--amp can only be used with --model-dtype float32")
    if args.center_crop > 1 or args.center_crop <= 0.0:
        raise cli.ValidationError(f"--center-crop must be in range of (0, 1.0], got {args.center_crop}")
    if args.parallel is True and args.gpu is False:
        raise cli.ValidationError("--parallel requires --gpu to be set")

    num_paths = len(args.data_path)
    if args.wds is False and num_paths == 0:
        raise cli.ValidationError("Must provide at least one data source, --data-path or --wds")

    if args.wds is True:
        if args.wds_info is None and num_paths == 0:
            raise cli.ValidationError("--wds requires a data path unless --wds-info is provided")
        if num_paths > 1:
            raise cli.ValidationError(f"--wds can have at most 1 --data-path, got {num_paths}")
        # A remote (URL-style) shard path has no local listing, so its size must be given
        if args.wds_info is None and num_paths == 1 and "://" in args.data_path[0] and args.wds_size is None:
            raise cli.ValidationError("--wds-size is required for remote --data-path")

    args.size = cli.parse_size(args.size)
226
+
227
+
228
def main(args: argparse.Namespace) -> None:
    """
    Entry point for the sub-command: validate arguments, ensure the output
    directory exists, then run the evaluation.
    """

    validate_args(args)

    output_dir = settings.RESULTS_DIR.joinpath(args.dir)
    if output_dir.exists() is False:
        logger.info(f"Creating {output_dir} directory...")
        output_dir.mkdir(parents=True)

    evaluate(args)
File without changes
@@ -0,0 +1,78 @@
1
+ """
2
+ AMI clustering for unsupervised embedding evaluation
3
+
4
+ Paper "An Empirical Study into Clustering of Unseen Datasets with Self-Supervised Encoders",
5
+ https://arxiv.org/abs/2406.02465
6
+ """
7
+
8
+ import logging
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ from sklearn.cluster import AgglomerativeClustering
14
+ from sklearn.metrics import adjusted_mutual_info_score
15
+
16
+ from birder.eval._embeddings import l2_normalize
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ try:
21
+ import umap
22
+
23
+ _HAS_UMAP = True
24
+ except ImportError:
25
+ _HAS_UMAP = False
26
+
27
+
28
def evaluate_ami(
    features: npt.NDArray[np.float32],
    labels: npt.NDArray[np.int_],
    n_clusters: int,
    umap_dim: int = 50,
    l2_normalize_features: bool = True,
    seed: Optional[int] = None,
) -> float:
    """
    Evaluate embedding quality using UMAP + Agglomerative Clustering + AMI

    Parameters
    ----------
    features
        Feature array of shape (n_samples, embedding_dim).
    labels
        True labels of shape (n_samples,).
    n_clusters
        Number of clusters (should match number of true classes).
    umap_dim
        Target dimensionality for UMAP reduction.
    l2_normalize_features
        If True, applies row-wise L2 normalization before UMAP.
    seed
        Random seed for UMAP reproducibility. When None, uses all available
        cores (n_jobs=-1) but results are non-deterministic. When set, forces n_jobs=1.

    Returns
    -------
    ami_score
        Adjusted Mutual Information score between true labels and cluster assignments.

    Raises
    ------
    ImportError
        If the optional umap-learn dependency is not installed.
    """

    # Explicit raise instead of assert: asserts are stripped under "python -O",
    # which would otherwise surface a missing dependency as a NameError on umap.UMAP
    if _HAS_UMAP is False:
        raise ImportError("'pip install umap-learn' to use AMI evaluation")

    if seed is None:
        logger.warning("No seed set, UMAP results will be non-deterministic (using n_jobs=-1)")
        n_jobs = -1
    else:
        # Reproducible UMAP requires single-threaded execution (see docstring)
        n_jobs = 1

    if l2_normalize_features is True:
        features = l2_normalize(features)

    # min_dist=0.0 packs embedded points tightly, which suits clustering
    # rather than visualization use of UMAP
    reducer = umap.UMAP(n_components=umap_dim, min_dist=0.0, n_jobs=n_jobs, random_state=seed)
    features_reduced = reducer.fit_transform(features)

    clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward")
    cluster_assignments = clustering.fit_predict(features_reduced)

    return float(adjusted_mutual_info_score(labels, cluster_assignments))
@@ -0,0 +1,71 @@
1
+ """
2
+ K-Nearest Neighbors classifier for few-shot learning evaluation
3
+
4
+ Uses cosine similarity (dot product of L2-normalized features) with temperature-scaled softmax voting.
5
+ """
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ from birder.eval._embeddings import l2_normalize
11
+
12
+
13
def evaluate_knn(
    train_features: npt.NDArray[np.float32],
    train_labels: npt.NDArray[np.int_],
    test_features: npt.NDArray[np.float32],
    test_labels: npt.NDArray[np.int_],
    k: int,
    temperature: float = 0.07,
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
    """
    Classify test samples by temperature-weighted voting over the k nearest training samples

    Neighbors are ranked by cosine similarity (dot product of L2-normalized features),
    and each neighbor's vote is weighted by a softmax over the scaled similarities.

    Parameters
    ----------
    train_features
        Training features of shape (n_train, embedding_dim).
    train_labels
        Training labels of shape (n_train,).
    test_features
        Test features of shape (n_test, embedding_dim).
    test_labels
        Test labels of shape (n_test,).
    k
        Number of nearest neighbors.
    temperature
        Temperature for softmax scaling.

    Returns
    -------
    y_pred
        Predicted labels for test samples.
    y_true
        True labels for test samples (same as test_labels).
    """

    # Cosine similarity between every test and every train sample
    sim = l2_normalize(test_features) @ l2_normalize(train_features).T

    # Per test sample: indices, similarities and labels of the k closest train samples
    neighbor_idx = np.argsort(-sim, axis=1)[:, :k]
    neighbor_sim = np.take_along_axis(sim, neighbor_idx, axis=1)
    neighbor_labels = train_labels[neighbor_idx]

    # Softmax over temperature-scaled similarities (max-subtracted for numerical stability)
    scaled = neighbor_sim / temperature
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    weights = np.exp(scaled)
    weights = weights / weights.sum(axis=1, keepdims=True)

    # Accumulate the weighted votes into a (n_test, num_classes) tally
    n_test = len(test_features)
    num_classes = train_labels.max() + 1
    votes = np.zeros((n_test, num_classes), dtype=np.float32)
    rows = np.arange(n_test)
    for j in range(k):
        np.add.at(votes, (rows, neighbor_labels[:, j]), weights[:, j])

    y_pred = votes.argmax(axis=1).astype(np.int_)

    return (y_pred, test_labels)
@@ -0,0 +1,152 @@
1
+ """
2
+ Linear probing on frozen embeddings
3
+
4
+ Trains a single linear layer with cross-entropy loss for classification evaluation.
5
+ """
6
+
7
+ import logging
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import torch
12
+ from torch import nn
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.data import TensorDataset
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def train_linear_probe(
    train_features: npt.NDArray[np.float32],
    train_labels: npt.NDArray[np.int_],
    num_classes: int,
    device: torch.device,
    epochs: int,
    batch_size: int,
    lr: float = 1e-4,
    step_size: int = 20,
    seed: int = 0,
) -> nn.Linear:
    """
    Fit a single linear layer on precomputed (frozen) embeddings

    Optimizes cross-entropy with Adam, halving the learning rate every
    ``step_size`` epochs via StepLR. Seeds both torch and numpy so that
    initialization and shuffling are reproducible.

    Parameters
    ----------
    train_features
        Training features of shape (n_train, embedding_dim).
    train_labels
        Training labels of shape (n_train,), integer class indices.
    num_classes
        Number of output classes.
    device
        Device to train on.
    epochs
        Number of training epochs.
    batch_size
        Batch size for training.
    lr
        Learning rate.
    step_size
        Number of epochs between learning rate decay steps.
    seed
        Random seed for reproducibility.

    Returns
    -------
    Trained linear layer.
    """

    torch.manual_seed(seed)
    np.random.seed(seed)

    probe = nn.Linear(train_features.shape[1], num_classes).to(device)

    dataset = TensorDataset(
        torch.from_numpy(train_features).float(),
        torch.from_numpy(train_labels).long(),
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.5)

    probe.train()
    for epoch_idx in range(epochs):
        running_loss = 0.0
        for batch_features, batch_targets in loader:
            batch_features = batch_features.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            loss = loss_fn(probe(batch_features), batch_targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average below is per-sample
            running_loss += loss.item() * batch_features.size(0)

        scheduler.step()

        # Report progress every 10 epochs
        if (epoch_idx + 1) % 10 == 0:
            logger.debug(f"Epoch {epoch_idx + 1}/{epochs} - Loss: {running_loss / len(dataset):.4f}")

    return probe
97
+
98
+
99
def evaluate_linear_probe(
    model: nn.Linear,
    test_features: npt.NDArray[np.float32],
    test_labels: npt.NDArray[np.int_],
    batch_size: int = 128,
    device: torch.device = torch.device("cpu"),
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
    """
    Run a trained linear probe over the test set and collect its predictions

    Parameters
    ----------
    model
        Trained linear layer.
    test_features
        Test features of shape (n_test, embedding_dim).
    test_labels
        Test labels of shape (n_test,), integer class indices.
    batch_size
        Batch size for inference.
    device
        Device to run on.

    Returns
    -------
    y_pred
        Predicted labels for test samples.
    y_true
        True labels for test samples (same as test_labels).
    """

    loader = DataLoader(
        TensorDataset(
            torch.from_numpy(test_features).float(),
            torch.from_numpy(test_labels).long(),
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    model.to(device)
    model.eval()

    pred_batches: list[torch.Tensor] = []
    label_batches: list[torch.Tensor] = []
    with torch.inference_mode():
        for batch_features, batch_targets in loader:
            # Class prediction is the argmax over the probe's logits
            batch_preds = model(batch_features.to(device)).argmax(dim=1)
            pred_batches.append(batch_preds.cpu())
            label_batches.append(batch_targets)

    y_pred = torch.concat(pred_batches, dim=0).numpy().astype(np.int_)
    y_true = torch.concat(label_batches, dim=0).numpy().astype(np.int_)

    return (y_pred, y_true)