birder 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/scripts/train_simclr.py
CHANGED
|
@@ -147,7 +147,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
147
147
|
logger.info(f"Training on {len(training_dataset):,} samples")
|
|
148
148
|
|
|
149
149
|
batch_size: int = args.batch_size
|
|
150
|
-
|
|
150
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
151
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
151
152
|
|
|
152
153
|
# Data loaders and samplers
|
|
153
154
|
if args.distributed is True:
|
|
@@ -179,6 +180,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
179
180
|
drop_last=args.drop_last,
|
|
180
181
|
)
|
|
181
182
|
|
|
183
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
182
184
|
last_batch_idx = len(training_loader) - 1
|
|
183
185
|
begin_epoch = 1
|
|
184
186
|
epochs = args.epochs + 1
|
|
@@ -249,20 +251,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
249
251
|
|
|
250
252
|
# Learning rate scaling
|
|
251
253
|
lr = training_utils.scale_lr(args)
|
|
252
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
253
254
|
|
|
254
255
|
if args.lr_scheduler_update == "epoch":
|
|
255
256
|
step_update = False
|
|
256
|
-
|
|
257
|
+
scheduler_steps_per_epoch = 1
|
|
257
258
|
elif args.lr_scheduler_update == "step":
|
|
258
259
|
step_update = True
|
|
259
|
-
|
|
260
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
260
261
|
else:
|
|
261
262
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
262
263
|
|
|
263
264
|
# Optimizer and learning rate scheduler
|
|
264
265
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
265
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
266
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
266
267
|
if args.compile_opt is True:
|
|
267
268
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
268
269
|
|
|
@@ -288,11 +289,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
288
289
|
optimizer.step()
|
|
289
290
|
lrs = []
|
|
290
291
|
for _ in range(begin_epoch, epochs):
|
|
291
|
-
for _ in range(
|
|
292
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
292
293
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
293
294
|
scheduler.step()
|
|
294
295
|
|
|
295
|
-
plt.plot(
|
|
296
|
+
plt.plot(
|
|
297
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
298
|
+
)
|
|
296
299
|
plt.show()
|
|
297
300
|
raise SystemExit(0)
|
|
298
301
|
|
birder/scripts/train_vicreg.py
CHANGED
|
@@ -150,7 +150,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
150
150
|
logger.info(f"Training on {len(training_dataset):,} samples")
|
|
151
151
|
|
|
152
152
|
batch_size: int = args.batch_size
|
|
153
|
-
|
|
153
|
+
grad_accum_steps: int = args.grad_accum_steps
|
|
154
|
+
logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
|
|
154
155
|
|
|
155
156
|
# Data loaders and samplers
|
|
156
157
|
if args.distributed is True:
|
|
@@ -182,6 +183,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
182
183
|
drop_last=args.drop_last,
|
|
183
184
|
)
|
|
184
185
|
|
|
186
|
+
optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
|
|
185
187
|
last_batch_idx = len(training_loader) - 1
|
|
186
188
|
begin_epoch = 1
|
|
187
189
|
epochs = args.epochs + 1
|
|
@@ -255,20 +257,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
255
257
|
|
|
256
258
|
# Learning rate scaling
|
|
257
259
|
lr = training_utils.scale_lr(args)
|
|
258
|
-
grad_accum_steps: int = args.grad_accum_steps
|
|
259
260
|
|
|
260
261
|
if args.lr_scheduler_update == "epoch":
|
|
261
262
|
step_update = False
|
|
262
|
-
|
|
263
|
+
scheduler_steps_per_epoch = 1
|
|
263
264
|
elif args.lr_scheduler_update == "step":
|
|
264
265
|
step_update = True
|
|
265
|
-
|
|
266
|
+
scheduler_steps_per_epoch = optimizer_steps_per_epoch
|
|
266
267
|
else:
|
|
267
268
|
raise ValueError("Unsupported lr_scheduler_update")
|
|
268
269
|
|
|
269
270
|
# Optimizer and learning rate scheduler
|
|
270
271
|
optimizer = training_utils.get_optimizer(parameters, lr, args)
|
|
271
|
-
scheduler = training_utils.get_scheduler(optimizer,
|
|
272
|
+
scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
|
|
272
273
|
if args.compile_opt is True:
|
|
273
274
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
274
275
|
|
|
@@ -294,11 +295,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
294
295
|
optimizer.step()
|
|
295
296
|
lrs = []
|
|
296
297
|
for _ in range(begin_epoch, epochs):
|
|
297
|
-
for _ in range(
|
|
298
|
+
for _ in range(scheduler_steps_per_epoch):
|
|
298
299
|
lrs.append(float(max(scheduler.get_last_lr())))
|
|
299
300
|
scheduler.step()
|
|
300
301
|
|
|
301
|
-
plt.plot(
|
|
302
|
+
plt.plot(
|
|
303
|
+
np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
|
|
304
|
+
)
|
|
302
305
|
plt.show()
|
|
303
306
|
raise SystemExit(0)
|
|
304
307
|
|
birder/tools/__main__.py
CHANGED
|
@@ -2,6 +2,7 @@ import argparse
|
|
|
2
2
|
|
|
3
3
|
from birder.common import cli
|
|
4
4
|
from birder.tools import adversarial
|
|
5
|
+
from birder.tools import auto_anchors
|
|
5
6
|
from birder.tools import avg_model
|
|
6
7
|
from birder.tools import convert_model
|
|
7
8
|
from birder.tools import det_results
|
|
@@ -30,8 +31,10 @@ def main() -> None:
|
|
|
30
31
|
description="Tool to run auxiliary commands",
|
|
31
32
|
epilog=(
|
|
32
33
|
"Usage examples:\n"
|
|
33
|
-
"python -m birder.tools adversarial --method
|
|
34
|
-
"
|
|
34
|
+
"python -m birder.tools adversarial --method pgd -n swin_transformer_v1_s -e 0 --eps 0.02 --steps 10 "
|
|
35
|
+
"data/training/Mallard/000112.jpeg\n"
|
|
36
|
+
"python -m birder.tools auto-anchors --preset yolo_v4 --size 640 "
|
|
37
|
+
"--coco-json-path data/detection_data/training_annotations_coco.json\n"
|
|
35
38
|
"python -m birder.tools avg-model --network resnet_v2_50 --epochs 95 95 100\n"
|
|
36
39
|
"python -m birder.tools convert-model --network convnext_v2_base --epoch 0 --pt2\n"
|
|
37
40
|
"python -m birder.tools det-results "
|
|
@@ -60,6 +63,7 @@ def main() -> None:
|
|
|
60
63
|
)
|
|
61
64
|
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
|
62
65
|
adversarial.set_parser(subparsers)
|
|
66
|
+
auto_anchors.set_parser(subparsers)
|
|
63
67
|
avg_model.set_parser(subparsers)
|
|
64
68
|
convert_model.set_parser(subparsers)
|
|
65
69
|
det_results.set_parser(subparsers)
|
birder/tools/adversarial.py
CHANGED
|
@@ -1,76 +1,134 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import logging
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
4
6
|
|
|
5
7
|
import matplotlib.pyplot as plt
|
|
6
8
|
import numpy as np
|
|
9
|
+
import numpy.typing as npt
|
|
7
10
|
import torch
|
|
8
11
|
from PIL import Image
|
|
9
12
|
|
|
13
|
+
from birder.adversarial.base import Attack
|
|
14
|
+
from birder.adversarial.base import AttackResult
|
|
15
|
+
from birder.adversarial.deepfool import DeepFool
|
|
10
16
|
from birder.adversarial.fgsm import FGSM
|
|
11
17
|
from birder.adversarial.pgd import PGD
|
|
18
|
+
from birder.adversarial.simba import SimBA
|
|
12
19
|
from birder.common import cli
|
|
13
20
|
from birder.common import fs_ops
|
|
14
21
|
from birder.common import lib
|
|
22
|
+
from birder.data.transforms.classification import RGBType
|
|
15
23
|
from birder.data.transforms.classification import inference_preset
|
|
16
24
|
from birder.data.transforms.classification import reverse_preset
|
|
17
25
|
|
|
18
26
|
logger = logging.getLogger(__name__)
|
|
19
27
|
|
|
20
28
|
|
|
21
|
-
def
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
device =
|
|
26
|
-
|
|
27
|
-
if args.gpu_id is not None:
|
|
28
|
-
torch.cuda.set_device(args.gpu_id)
|
|
29
|
+
def _load_model_and_transform(
|
|
30
|
+
args: argparse.Namespace, device: torch.device
|
|
31
|
+
) -> tuple[torch.nn.Module, dict[str, int], RGBType, Callable[..., torch.Tensor], Callable[..., torch.Tensor]]:
|
|
32
|
+
(net, model_info) = fs_ops.load_model(
|
|
33
|
+
device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
|
|
34
|
+
)
|
|
29
35
|
|
|
30
|
-
|
|
36
|
+
class_to_idx = model_info.class_to_idx
|
|
37
|
+
rgb_stats = model_info.rgb_stats
|
|
31
38
|
|
|
32
|
-
|
|
33
|
-
device,
|
|
34
|
-
args.network,
|
|
35
|
-
tag=args.tag,
|
|
36
|
-
epoch=args.epoch,
|
|
37
|
-
inference=True,
|
|
38
|
-
reparameterized=args.reparameterized,
|
|
39
|
-
)
|
|
40
|
-
label_names = list(class_to_idx.keys())
|
|
41
|
-
size = lib.get_size_from_signature(signature)
|
|
39
|
+
size = lib.get_size_from_signature(model_info.signature)
|
|
42
40
|
transform = inference_preset(size, rgb_stats, 1.0)
|
|
43
41
|
reverse_transform = reverse_preset(rgb_stats)
|
|
44
42
|
|
|
45
|
-
|
|
46
|
-
input_tensor = transform(img).unsqueeze(dim=0).to(device)
|
|
43
|
+
return (net, class_to_idx, rgb_stats, transform, reverse_transform)
|
|
47
44
|
|
|
48
|
-
pgd = PGD(net, eps=args.eps, max_delta=0.012, steps=10, random_start=True)
|
|
49
|
-
if args.target is not None:
|
|
50
|
-
target = torch.tensor(class_to_idx[args.target]).unsqueeze(dim=0).to(device)
|
|
51
|
-
else:
|
|
52
|
-
target = None
|
|
53
45
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
46
|
+
def _resolve_target(
|
|
47
|
+
target_name: Optional[str], class_to_idx: dict[str, int], device: torch.device
|
|
48
|
+
) -> Optional[torch.Tensor]:
|
|
49
|
+
if target_name is None:
|
|
50
|
+
return None
|
|
51
|
+
if target_name not in class_to_idx:
|
|
52
|
+
raise ValueError(f"Unknown target class '{target_name}'")
|
|
58
53
|
|
|
59
|
-
|
|
60
|
-
prob = pgd_response.out.cpu().detach().numpy().squeeze()
|
|
61
|
-
adv_prob = pgd_response.adv_out.cpu().detach().numpy().squeeze()
|
|
62
|
-
idx = np.argmax(prob)
|
|
63
|
-
adv_idx = np.argmax(adv_prob)
|
|
54
|
+
return torch.tensor([class_to_idx[target_name]], device=device, dtype=torch.long)
|
|
64
55
|
|
|
56
|
+
|
|
57
|
+
def _build_attack(args: argparse.Namespace, net: torch.nn.Module, rgb_stats: RGBType) -> Attack:
|
|
58
|
+
if args.method == "fgsm":
|
|
59
|
+
return FGSM(net, eps=args.eps, rgb_stats=rgb_stats)
|
|
60
|
+
if args.method == "pgd":
|
|
61
|
+
return PGD(
|
|
62
|
+
net,
|
|
63
|
+
eps=args.eps,
|
|
64
|
+
steps=args.steps,
|
|
65
|
+
step_size=args.step_size,
|
|
66
|
+
random_start=args.random_start,
|
|
67
|
+
rgb_stats=rgb_stats,
|
|
68
|
+
)
|
|
69
|
+
if args.method == "deepfool":
|
|
70
|
+
return DeepFool(
|
|
71
|
+
net,
|
|
72
|
+
num_classes=args.deepfool_num_classes,
|
|
73
|
+
overshoot=args.deepfool_overshoot,
|
|
74
|
+
max_iter=args.deepfool_max_iter,
|
|
75
|
+
rgb_stats=rgb_stats,
|
|
76
|
+
)
|
|
77
|
+
if args.method == "simba":
|
|
78
|
+
return SimBA(
|
|
79
|
+
net,
|
|
80
|
+
step_size=args.step_size if args.step_size is not None else args.eps,
|
|
81
|
+
max_iter=args.steps,
|
|
82
|
+
rgb_stats=rgb_stats,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
raise ValueError(f"Unsupported attack method '{args.method}'")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _tensor_to_image(tensor: torch.Tensor, reverse_transform: Callable[..., torch.Tensor]) -> npt.NDArray[np.uint8]:
|
|
89
|
+
img_tensor = reverse_transform(tensor).cpu()
|
|
90
|
+
img = img_tensor.numpy()
|
|
91
|
+
return np.moveaxis(img, 0, 2)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _get_prediction(logits: torch.Tensor, label_names: list[str]) -> tuple[str, float]:
|
|
95
|
+
probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze()
|
|
96
|
+
idx = int(np.argmax(probs))
|
|
97
|
+
return (label_names[idx], float(probs[idx]))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _display_results(
|
|
101
|
+
original_img: npt.NDArray[np.uint8],
|
|
102
|
+
adv_img: npt.NDArray[np.uint8],
|
|
103
|
+
original_pred: tuple[str, float],
|
|
104
|
+
adv_pred: tuple[str, float],
|
|
105
|
+
success: Optional[bool],
|
|
106
|
+
result: AttackResult,
|
|
107
|
+
) -> None:
|
|
108
|
+
(orig_label, orig_prob) = original_pred
|
|
109
|
+
(adv_label, adv_prob) = adv_pred
|
|
110
|
+
|
|
111
|
+
# Log results
|
|
112
|
+
logger.info(f"Original: {orig_label} ({orig_prob * 100:.2f}%)")
|
|
113
|
+
logger.info(f"Adversarial: {adv_label} ({adv_prob * 100:.2f}%)")
|
|
114
|
+
if success is not None:
|
|
115
|
+
logger.info(f"Attack success: {success}")
|
|
116
|
+
if result.num_queries is not None:
|
|
117
|
+
logger.info(f"Model queries: {result.num_queries}")
|
|
118
|
+
|
|
119
|
+
# Display images
|
|
65
120
|
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
|
|
66
|
-
ax1.imshow(
|
|
67
|
-
ax1.set_title(f"{
|
|
68
|
-
|
|
69
|
-
ax2.
|
|
121
|
+
ax1.imshow(original_img)
|
|
122
|
+
ax1.set_title(f"{orig_label} {100 * orig_prob:.2f}%")
|
|
123
|
+
ax1.axis("off")
|
|
124
|
+
ax2.imshow(adv_img)
|
|
125
|
+
ax2.set_title(f"{adv_label} {100 * adv_prob:.2f}%")
|
|
126
|
+
ax2.axis("off")
|
|
127
|
+
plt.tight_layout()
|
|
70
128
|
plt.show()
|
|
71
129
|
|
|
72
130
|
|
|
73
|
-
def
|
|
131
|
+
def run_attack(args: argparse.Namespace) -> None:
|
|
74
132
|
if args.gpu is True:
|
|
75
133
|
device = torch.device("cuda")
|
|
76
134
|
else:
|
|
@@ -81,83 +139,76 @@ def show_fgsm(args: argparse.Namespace) -> None:
|
|
|
81
139
|
|
|
82
140
|
logger.info(f"Using device {device}")
|
|
83
141
|
|
|
84
|
-
(net,
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
tag=args.tag,
|
|
88
|
-
epoch=args.epoch,
|
|
89
|
-
inference=False,
|
|
90
|
-
reparameterized=args.reparameterized,
|
|
91
|
-
)
|
|
92
|
-
label_names = list(class_to_idx.keys())
|
|
93
|
-
size = lib.get_size_from_signature(signature)
|
|
94
|
-
transform = inference_preset(size, rgb_stats, 1.0)
|
|
95
|
-
|
|
96
|
-
img: Image.Image = Image.open(args.image_path)
|
|
142
|
+
(net, class_to_idx, rgb_stats, transform, reverse_transform) = _load_model_and_transform(args, device)
|
|
143
|
+
label_names = [name for name, _idx in sorted(class_to_idx.items(), key=lambda item: item[1])]
|
|
144
|
+
img = Image.open(args.image_path)
|
|
97
145
|
input_tensor = transform(img).unsqueeze(dim=0).to(device)
|
|
98
146
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
else:
|
|
103
|
-
target = None
|
|
147
|
+
target = _resolve_target(args.target, class_to_idx, device)
|
|
148
|
+
attack = _build_attack(args, net, rgb_stats)
|
|
149
|
+
result = attack(input_tensor, target=target)
|
|
104
150
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
151
|
+
original_img = _tensor_to_image(input_tensor.squeeze(0).cpu(), reverse_transform)
|
|
152
|
+
adv_img = _tensor_to_image(result.adv_inputs.squeeze(0).cpu(), reverse_transform)
|
|
153
|
+
original_logits = result.logits
|
|
154
|
+
if original_logits is None:
|
|
155
|
+
with torch.no_grad():
|
|
156
|
+
original_logits = net(input_tensor)
|
|
110
157
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
idx = np.argmax(prob)
|
|
115
|
-
adv_idx = np.argmax(adv_prob)
|
|
158
|
+
original_pred = _get_prediction(original_logits, label_names)
|
|
159
|
+
adv_pred = _get_prediction(result.adv_logits, label_names)
|
|
160
|
+
success = bool(result.success.item()) if result.success is not None else None
|
|
116
161
|
|
|
117
|
-
|
|
118
|
-
ax1.imshow(img)
|
|
119
|
-
ax1.set_title(f"{label_names[idx]} {100 * prob[idx]:.2f}%")
|
|
120
|
-
ax2.imshow(fgsm_img)
|
|
121
|
-
ax2.set_title(f"{label_names[adv_idx]} {100 * adv_prob[adv_idx]:.2f}%")
|
|
122
|
-
plt.show()
|
|
162
|
+
_display_results(original_img, adv_img, original_pred, adv_pred, success, result)
|
|
123
163
|
|
|
124
164
|
|
|
125
165
|
def set_parser(subparsers: Any) -> None:
|
|
126
166
|
subparser = subparsers.add_parser(
|
|
127
167
|
"adversarial",
|
|
128
168
|
allow_abbrev=False,
|
|
129
|
-
help="
|
|
130
|
-
description="
|
|
169
|
+
help="generate and visualize adversarial examples",
|
|
170
|
+
description="generate and visualize adversarial examples",
|
|
131
171
|
epilog=(
|
|
132
172
|
"Usage examples:\n"
|
|
133
|
-
"python -m birder.tools adversarial --method fgsm --
|
|
134
|
-
"
|
|
135
|
-
"python -m birder.tools adversarial --method
|
|
136
|
-
"
|
|
137
|
-
"python
|
|
138
|
-
"data/validation/
|
|
173
|
+
"python -m birder.tools adversarial -n resnet_v2_50 -e 0 --method fgsm --eps 0.02 "
|
|
174
|
+
"data/validation/Mallard/000112.jpeg\n"
|
|
175
|
+
"python -m birder.tools adversarial -n efficientnet_v2_m -e 0 --method pgd --eps 0.02 --steps 10 "
|
|
176
|
+
"data/validation/Mallard/000002.jpeg\n"
|
|
177
|
+
"python -m birder.tools adversarial -n convnext_v2_tiny -e 0 --method deepfool "
|
|
178
|
+
"data/validation/Bluethroat/000013.jpeg\n"
|
|
179
|
+
"python -m birder.tools adversarial -n convnext_v2_tiny -e 0 --method simba --steps 1000 --step-size 0.1 "
|
|
180
|
+
"data/validation/Bluethroat/000043.jpeg\n"
|
|
139
181
|
),
|
|
140
182
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
141
183
|
)
|
|
142
|
-
subparser.add_argument("--
|
|
184
|
+
subparser.add_argument("-n", "--network", type=str, required=True, help="neural network to attack")
|
|
185
|
+
subparser.add_argument("-t", "--tag", type=str, help="model tag")
|
|
186
|
+
subparser.add_argument("-e", "--epoch", type=int, required=True, help="model checkpoint epoch")
|
|
187
|
+
subparser.add_argument("--reparameterized", default=False, action="store_true", help="load reparameterized model")
|
|
188
|
+
subparser.add_argument("--gpu", default=False, action="store_true", help="use GPU")
|
|
189
|
+
subparser.add_argument("--gpu-id", type=int, metavar="ID", help="GPU device ID")
|
|
143
190
|
subparser.add_argument(
|
|
144
|
-
"
|
|
191
|
+
"--method",
|
|
192
|
+
type=str,
|
|
193
|
+
required=True,
|
|
194
|
+
choices=["fgsm", "pgd", "deepfool", "simba"],
|
|
195
|
+
help="adversarial attack method",
|
|
145
196
|
)
|
|
146
|
-
subparser.add_argument("
|
|
147
|
-
subparser.add_argument("
|
|
197
|
+
subparser.add_argument("--eps", type=float, default=0.007, help="perturbation budget in pixel space [0, 1]")
|
|
198
|
+
subparser.add_argument("--target", type=str, help="target class name for targeted attack (omit for untargeted)")
|
|
199
|
+
subparser.add_argument("--steps", type=int, default=10, help="number of iterations for iterative attacks")
|
|
200
|
+
subparser.add_argument("--step-size", type=float, help="step size in pixel space (defaults to eps/steps for PGD)")
|
|
148
201
|
subparser.add_argument(
|
|
149
|
-
"-
|
|
202
|
+
"--random-start", default=False, action="store_true", help="use random initialization for PGD"
|
|
150
203
|
)
|
|
151
|
-
subparser.add_argument(
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
subparser.add_argument("--
|
|
155
|
-
subparser.add_argument("
|
|
204
|
+
subparser.add_argument(
|
|
205
|
+
"--deepfool-num-classes", type=int, default=10, help="number of top classes to consider for DeepFool"
|
|
206
|
+
)
|
|
207
|
+
subparser.add_argument("--deepfool-overshoot", type=float, default=0.02, help="overshoot parameter for DeepFool")
|
|
208
|
+
subparser.add_argument("--deepfool-max-iter", type=int, default=50, help="max iterations for DeepFool")
|
|
209
|
+
subparser.add_argument("image_path", type=str, help="path to input image")
|
|
156
210
|
subparser.set_defaults(func=main)
|
|
157
211
|
|
|
158
212
|
|
|
159
213
|
def main(args: argparse.Namespace) -> None:
|
|
160
|
-
|
|
161
|
-
show_fgsm(args)
|
|
162
|
-
elif args.method == "pgd":
|
|
163
|
-
show_pgd(args)
|
|
214
|
+
run_attack(args)
|