birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
# Module-level logger named after this module
logger = logging.getLogger(name=__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def l2_normalize(x: npt.NDArray[np.float32], eps: float = 1e-12) -> npt.NDArray[np.float32]:
    """
    Scale x to unit L2 norm along axis 1 (each row, for a 2-D array)

    Norms smaller than eps are raised to eps before dividing, so all-zero rows come back as zeros instead of NaN.
    """

    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    denominator = np.maximum(row_norms, eps)
    return x / denominator  # type: ignore[no-any-return]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def load_embeddings(path: Path | str) -> tuple[list[str], npt.NDArray[np.float32]]:
    """
    Load embeddings from parquet file

    Auto-detects format:
    - If 'embedding' column exists: use directly
    - If numeric column names (0, 1, 2, ...): treat as logits, convert to array

    Parameters
    ----------
    path
        Parquet file with a 'sample' column (file paths) and either an 'embedding' column or
        numerically named per-dimension columns.

    Returns
    -------
    sample_ids
        List of sample identifiers (stem of 'sample' column path).
    features
        Array of shape (n_samples, embedding_dim), dtype float32.

    Raises
    ------
    ValueError
        If the file has neither an 'embedding' column nor any numeric column names, or if a
        list-typed 'embedding' column holds rows of differing lengths.
    """

    if isinstance(path, str):
        path = Path(path)

    df = pl.read_parquet(path)
    df = df.with_columns(pl.col("sample").map_elements(lambda p: Path(p).stem, return_dtype=pl.Utf8).alias("id"))

    if "embedding" in df.columns:
        df = df.select(["id", "embedding"])
    else:
        # Logits format - numeric column names, ordered by integer value so "10" comes after "9"
        embed_cols = sorted([c for c in df.columns if c.isdigit()], key=int)
        if len(embed_cols) == 0:
            raise ValueError(
                f"No embedding data found in {path}: expected an 'embedding' column or numeric column names"
            )

        df = df.with_columns(
            pl.concat_list(pl.col(embed_cols)).cast(pl.Array(pl.Float32, len(embed_cols))).alias("embedding")
        ).select(["id", "embedding"])

    sample_ids = df.get_column("id").to_list()
    embeddings = df.get_column("embedding")
    if isinstance(embeddings.dtype, pl.List):
        # A List column converts to a 1-D object array of per-row arrays, which astype() cannot turn into a
        # 2-D float matrix. Build the matrix from Python lists instead (ragged rows make numpy raise ValueError).
        features = np.asarray(embeddings.to_list(), dtype=np.float32)
    else:
        # Fixed-width Array columns convert directly to a 2-D array
        features = embeddings.to_numpy().astype(np.float32, copy=False)

    return (sample_ids, features)
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
from birder.adversarial.deepfool import DeepFool
|
|
11
|
+
from birder.adversarial.fgsm import FGSM
|
|
12
|
+
from birder.adversarial.pgd import PGD
|
|
13
|
+
from birder.adversarial.simba import SimBA
|
|
14
|
+
from birder.common import cli
|
|
15
|
+
from birder.common import fs_ops
|
|
16
|
+
from birder.common import lib
|
|
17
|
+
from birder.conf import settings
|
|
18
|
+
from birder.data.dataloader.webdataset import make_wds_loader
|
|
19
|
+
from birder.data.datasets.directory import make_image_dataset
|
|
20
|
+
from birder.data.datasets.webdataset import make_wds_dataset
|
|
21
|
+
from birder.data.datasets.webdataset import prepare_wds_args
|
|
22
|
+
from birder.data.datasets.webdataset import wds_args_from_info
|
|
23
|
+
from birder.data.transforms.classification import RGBType
|
|
24
|
+
from birder.data.transforms.classification import inference_preset
|
|
25
|
+
|
|
26
|
+
# Module-level logger named after this module
logger = logging.getLogger(name=__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _build_attack(
|
|
30
|
+
method: str,
|
|
31
|
+
net: torch.nn.Module,
|
|
32
|
+
rgb_stats: RGBType,
|
|
33
|
+
eps: float,
|
|
34
|
+
steps: int,
|
|
35
|
+
step_size: float | None,
|
|
36
|
+
deepfool_num_classes: int,
|
|
37
|
+
) -> FGSM | PGD | DeepFool | SimBA:
|
|
38
|
+
if method == "fgsm":
|
|
39
|
+
return FGSM(net, eps=eps, rgb_stats=rgb_stats)
|
|
40
|
+
|
|
41
|
+
if method == "pgd":
|
|
42
|
+
return PGD(
|
|
43
|
+
net,
|
|
44
|
+
eps=eps,
|
|
45
|
+
steps=steps,
|
|
46
|
+
step_size=step_size,
|
|
47
|
+
random_start=False,
|
|
48
|
+
rgb_stats=rgb_stats,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
if method == "deepfool":
|
|
52
|
+
return DeepFool(net, num_classes=deepfool_num_classes, overshoot=0.02, max_iter=steps, rgb_stats=rgb_stats)
|
|
53
|
+
|
|
54
|
+
if method == "simba":
|
|
55
|
+
return SimBA(
|
|
56
|
+
net,
|
|
57
|
+
step_size=step_size if step_size is not None else eps,
|
|
58
|
+
max_iter=steps,
|
|
59
|
+
rgb_stats=rgb_stats,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
raise ValueError(f"Unsupported attack method '{method}'")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# pylint: disable=too-many-locals,too-many-branches
def evaluate_robust(args: argparse.Namespace) -> None:
    """
    Measure clean and adversarial top-1 accuracy of a model on a labeled dataset

    The model is loaded with fs_ops.load_model. Each batch is filtered to samples whose label differs from
    settings.NO_LABEL, and the attack configured by the CLI arguments is run on those samples.

    A sample counts towards the adversarial ("robust") accuracy only when the model classifies it correctly
    both before and after the attack. The adversarial accuracy therefore never exceeds the clean accuracy, and
    accuracy_drop equals the fraction of labeled samples whose correct prediction was flipped by the attack.

    Unless --dry-run is set, a JSON summary is written under settings.RESULTS_DIR / args.dir.

    Parameters
    ----------
    args
        Parsed command line arguments of the "adversarial" sub-command, already checked by validate_args.

    Raises
    ------
    RuntimeError
        If the dataset contains no labeled samples.
    """

    if args.gpu is True:
        device = torch.device("cuda")
    elif args.mps is True:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if args.gpu_id is not None:
        torch.cuda.set_device(args.gpu_id)

    logger.info(f"Using device {device}")

    if args.fast_matmul is True or args.amp is True:
        torch.set_float32_matmul_precision("high")

    if args.amp_dtype is None:
        amp_dtype = torch.get_autocast_dtype(device.type)
        logger.debug(f"AMP: {args.amp}, AMP dtype: {amp_dtype}")
    else:
        amp_dtype = getattr(torch, args.amp_dtype)

    network_name = lib.get_network_name(args.network, tag=args.tag)
    net, model_info = fs_ops.load_model(
        device,
        args.network,
        tag=args.tag,
        epoch=args.epoch,
        new_size=args.size,
        inference=True,
        reparameterized=args.reparameterized,
    )

    class_to_idx = model_info.class_to_idx
    rgb_stats = model_info.rgb_stats
    if args.size is None:
        size = lib.get_size_from_signature(model_info.signature)
    else:
        size = args.size

    transform = inference_preset(size, rgb_stats, args.center_crop, args.simple_crop)

    if args.wds is True:
        wds_path: str | list[str]
        if args.wds_info is not None:
            wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
            if args.wds_size is not None:
                dataset_size = args.wds_size
        else:
            wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)

        num_samples = dataset_size
        dataset = make_wds_dataset(
            wds_path,
            dataset_size=dataset_size,
            shuffle=False,
            samples_names=True,
            transform=transform,
        )
        dataloader = make_wds_loader(
            dataset,
            args.batch_size,
            num_workers=args.num_workers,
            prefetch_factor=None,
            collate_fn=None,
            world_size=1,
            pin_memory=False,
            exact=True,
        )
    else:
        dataset = make_image_dataset(args.data_path, class_to_idx, transforms=transform)
        num_samples = len(dataset)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    attack = _build_attack(args.method, net, rgb_stats, args.eps, args.steps, args.step_size, args.deepfool_num_classes)

    clean_correct = 0
    adv_correct = 0
    total = 0
    skipped_unlabeled = 0
    with tqdm(total=num_samples, unit="images", leave=False) as progress:
        for _, inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            # Unlabeled samples cannot be scored, drop them before running the model and the attack
            valid_mask = targets != settings.NO_LABEL
            num_valid = int(valid_mask.sum().item())
            skipped_unlabeled += batch_size - num_valid
            if num_valid == 0:
                progress.update(batch_size)
                continue

            inputs = inputs[valid_mask]
            targets = targets[valid_mask]

            with torch.no_grad():
                with torch.amp.autocast(device.type, enabled=args.amp, dtype=amp_dtype):
                    clean_logits = net(inputs)

            clean_preds = clean_logits.argmax(dim=1)
            clean_hits = clean_preds == targets
            clean_correct += clean_hits.sum().item()

            # The attack runs outside no_grad; presumably it needs gradients (NOTE(review): confirm per attack)
            result = attack(inputs, target=None)
            adv_preds = result.adv_logits.argmax(dim=1)

            # A sample is robust only if it was correct on the clean input AND stays correct under attack.
            # Without the clean-correct condition, an attack on an already misclassified sample could land on
            # the true class and inflate the adversarial accuracy (e.g. if the untargeted attack moves away from
            # the model's own prediction - NOTE(review): depends on the attack implementation).
            adv_correct += (clean_hits & (adv_preds == targets)).sum().item()

            total += num_valid
            progress.update(batch_size)

    if total == 0:
        raise RuntimeError(f"No labeled samples found (all labels are {settings.NO_LABEL})")

    if skipped_unlabeled > 0:
        logger.warning(f"Skipped {skipped_unlabeled} unlabeled samples (label={settings.NO_LABEL})")

    clean_accuracy = clean_correct / total
    adv_accuracy = adv_correct / total
    accuracy_drop = clean_accuracy - adv_accuracy

    logger.info(
        f"{network_name}: clean={clean_accuracy:.4f}, adv={adv_accuracy:.4f}, drop={accuracy_drop:.4f} "
        f"(evaluated on {total} labeled samples)"
    )

    if args.dry_run is False:
        output = {
            "method": "robust",
            "accuracy": adv_accuracy,
            "clean_accuracy": clean_accuracy,
            "accuracy_drop": accuracy_drop,
            "attack_method": args.method,
            "epsilon": args.eps,
            "num_samples": total,
            "num_skipped_unlabeled": skipped_unlabeled,
        }

        output_dir = settings.RESULTS_DIR.joinpath(args.dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        epoch_str = f"_e{args.epoch}" if args.epoch is not None else ""
        output_path = output_dir.joinpath(f"{network_name}{epoch_str}_{args.method}_eps{args.eps}.json")

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)

        logger.info(f"Results saved to {output_path}")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def set_parser(subparsers: Any) -> None:
    """
    Register the "adversarial" sub-command and its arguments on subparsers

    The sub-command dispatches to main through set_defaults(func=main). The help texts of --eps, --steps and
    --step-size state which attacks actually use each option.
    """

    subparser = subparsers.add_parser(
        "adversarial",
        allow_abbrev=False,
        help="evaluate adversarial robustness of a model on a dataset",
        description="evaluate adversarial robustness of a model on a dataset",
        epilog=(
            "Usage examples:\n"
            "python -m birder.eval adversarial -n resnet_v2_50 -e 100 --method pgd --eps 0.02 --gpu data/validation\n"
            "python -m birder.eval adversarial -n vovnet_v2_39 -t il-common --method pgd --batch-size 4 "
            "--gpu --gpu-id 1 --fast-matmul data/validation_il-common_packed\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
    subparser.add_argument("-n", "--network", type=str, help="neural network to evaluate")
    subparser.add_argument("-t", "--tag", type=str, help="model tag")
    subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint epoch")
    subparser.add_argument("--reparameterized", default=False, action="store_true", help="load reparameterized model")
    subparser.add_argument(
        "--method",
        type=str,
        choices=["fgsm", "pgd", "deepfool", "simba"],
        help="adversarial attack method",
    )
    subparser.add_argument(
        "--eps",
        type=float,
        default=0.007,
        help="perturbation budget in pixel space [0, 1] (not used by DeepFool)",
    )
    subparser.add_argument(
        "--steps",
        type=int,
        default=10,
        help="number of iterations for iterative attacks (PGD, DeepFool and SimBA, ignored by FGSM)",
    )
    subparser.add_argument(
        "--step-size",
        type=float,
        help="step size in pixel space, PGD and SimBA only (defaults to eps/steps for PGD and to eps for SimBA)",
    )
    subparser.add_argument(
        "--deepfool-num-classes", type=int, default=10, help="number of top classes to consider for DeepFool"
    )
    subparser.add_argument(
        "--size", type=int, nargs="+", metavar=("H", "W"), help="image size for inference (defaults to model signature)"
    )
    subparser.add_argument(
        "--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference"
    )
    subparser.add_argument(
        "--amp-dtype",
        type=str,
        choices=["float16", "bfloat16"],
        help="whether to use float16 or bfloat16 for mixed precision",
    )
    subparser.add_argument(
        "--fast-matmul", default=False, action="store_true", help="use fast matrix multiplication (affects precision)"
    )
    subparser.add_argument("--batch-size", type=int, default=32, metavar="N", help="the batch size")
    subparser.add_argument(
        "-j", "--num-workers", type=int, default=8, metavar="N", help="number of preprocessing workers"
    )
    subparser.add_argument("--center-crop", type=float, default=1.0, help="center crop ratio to use during inference")
    subparser.add_argument(
        "--simple-crop",
        default=False,
        action="store_true",
        help="use a simple crop that preserves aspect ratio but may trim parts of the image",
    )
    subparser.add_argument(
        "--dir", type=str, default="robust", help="place all outputs in a sub-directory (relative to results)"
    )
    subparser.add_argument("--dry-run", default=False, action="store_true", help="skip saving results to file")
    subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
    subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
    subparser.add_argument(
        "--mps", default=False, action="store_true", help="use mps (Metal Performance Shaders) device"
    )
    subparser.add_argument("--wds", default=False, action="store_true", help="evaluate a webdataset directory")
    subparser.add_argument("--wds-size", type=int, metavar="N", help="size of the wds dataset")
    subparser.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
    subparser.add_argument(
        "--wds-split", type=str, default="validation", metavar="NAME", help="wds dataset split to load"
    )
    subparser.add_argument("data_path", nargs="*", help="data files path (directories and files)")
    subparser.set_defaults(func=main)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def validate_args(args: argparse.Namespace) -> None:
    """
    Normalize and validate the parsed arguments of the "adversarial" sub-command

    --size is converted in place with cli.parse_size. Missing, out-of-range or inconsistent options raise
    cli.ValidationError before any model or data is loaded.
    """

    args.size = cli.parse_size(args.size)
    if args.network is None:
        raise cli.ValidationError("--network is required")
    if args.method is None:
        raise cli.ValidationError("--method is required")
    if args.center_crop > 1 or args.center_crop <= 0.0:
        raise cli.ValidationError(f"--center-crop must be in range of (0, 1.0], got {args.center_crop}")

    # Attack hyper-parameters: reject values that would silently produce a meaningless evaluation
    if args.eps < 0.0:
        raise cli.ValidationError(f"--eps must be non-negative, got {args.eps}")
    if args.steps < 1:
        raise cli.ValidationError(f"--steps must be a positive integer, got {args.steps}")
    if args.step_size is not None and args.step_size <= 0.0:
        raise cli.ValidationError(f"--step-size must be positive, got {args.step_size}")
    if args.method == "deepfool" and args.deepfool_num_classes < 2:
        raise cli.ValidationError(f"--deepfool-num-classes must be at least 2, got {args.deepfool_num_classes}")
    if args.batch_size < 1:
        raise cli.ValidationError(f"--batch-size must be a positive integer, got {args.batch_size}")

    if args.wds is False and len(args.data_path) == 0:
        raise cli.ValidationError("Must provide at least one data source, --data-path or --wds")

    if args.wds is True:
        if args.wds_info is None and len(args.data_path) == 0:
            raise cli.ValidationError("--wds requires a data path unless --wds-info is provided")
        if len(args.data_path) > 1:
            raise cli.ValidationError(f"--wds can have at most 1 --data-path, got {len(args.data_path)}")
        if args.wds_info is None and len(args.data_path) == 1:
            data_path = args.data_path[0]
            # The dataset size cannot be discovered for remote locations, so it must be given explicitly
            if "://" in data_path and args.wds_size is None:
                raise cli.ValidationError("--wds-size is required for remote --data-path")
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def main(args: argparse.Namespace) -> None:
    """
    Entry point of the "adversarial" sub-command

    Arguments are validated (and --size normalized in place) before the robustness evaluation runs.
    """
    validate_args(args)
    evaluate_robust(args)
|
|
File without changes
|