birder 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
@@ -147,7 +147,8 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"Training on {len(training_dataset):,} samples")

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -179,6 +180,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
  epochs = args.epochs + 1
@@ -249,20 +251,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -288,11 +289,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
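Note: the touched training scripts now hoist grad_accum_steps and compute optimizer_steps_per_epoch up front, so the effective-batch-size log line and the step-based LR scheduler share one definition. A minimal, self-contained sketch of that relationship (generic PyTorch-style loop, not the birder training loop; model, loader, and loss_fn are placeholders):

import math


def optimizer_steps_per_epoch(num_batches: int, grad_accum_steps: int) -> int:
    # One optimizer update per grad_accum_steps loader batches; a trailing partial group still updates
    return math.ceil(num_batches / grad_accum_steps)


def train_one_epoch(model, loader, optimizer, loss_fn, grad_accum_steps: int) -> None:
    # Effective batch size = batch_size * grad_accum_steps * world_size
    optimizer.zero_grad()
    for batch_idx, (inputs, targets) in enumerate(loader):
        loss = loss_fn(model(inputs), targets) / grad_accum_steps  # average over the accumulation window
        loss.backward()
        if (batch_idx + 1) % grad_accum_steps == 0 or batch_idx == len(loader) - 1:
            optimizer.step()
            optimizer.zero_grad()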
 
@@ -150,7 +150,8 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"Training on {len(training_dataset):,} samples")

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -182,6 +183,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
  epochs = args.epochs + 1
@@ -255,20 +257,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -294,11 +295,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
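Note: the renamed scheduler_steps_per_epoch also appears in the learning-rate preview branch above, which steps the scheduler without training just to plot the schedule and then exits. A standalone sketch of that idea (the optimizer and scheduler choices here are illustrative, not necessarily what the birder scripts configure):

import matplotlib.pyplot as plt
import numpy as np
import torch

epochs = 10
scheduler_steps_per_epoch = 50  # e.g. math.ceil(len(training_loader) / grad_accum_steps)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * scheduler_steps_per_epoch)

optimizer.step()  # dummy step so the scheduler does not warn about stepping before the optimizer
lrs = []
for _ in range(epochs):
    for _ in range(scheduler_steps_per_epoch):
        lrs.append(max(scheduler.get_last_lr()))
        scheduler.step()

plt.plot(np.linspace(1, epochs + 1, len(lrs), endpoint=False), lrs)
plt.show()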
 
birder/tools/__main__.py CHANGED
@@ -2,6 +2,7 @@ import argparse

  from birder.common import cli
  from birder.tools import adversarial
+ from birder.tools import auto_anchors
  from birder.tools import avg_model
  from birder.tools import convert_model
  from birder.tools import det_results
@@ -30,8 +31,10 @@ def main() -> None:
  description="Tool to run auxiliary commands",
  epilog=(
  "Usage examples:\n"
- "python -m birder.tools adversarial --method fgsm -n swin_transformer_v1_s -e 0 "
- "--image 'data/training/Mallard/000112.jpeg'\n"
+ "python -m birder.tools adversarial --method pgd -n swin_transformer_v1_s -e 0 --eps 0.02 --steps 10 "
+ "data/training/Mallard/000112.jpeg\n"
+ "python -m birder.tools auto-anchors --preset yolo_v4 --size 640 "
+ "--coco-json-path data/detection_data/training_annotations_coco.json\n"
  "python -m birder.tools avg-model --network resnet_v2_50 --epochs 95 95 100\n"
  "python -m birder.tools convert-model --network convnext_v2_base --epoch 0 --pt2\n"
  "python -m birder.tools det-results "
@@ -60,6 +63,7 @@ def main() -> None:
  )
  subparsers = parser.add_subparsers(dest="cmd", required=True)
  adversarial.set_parser(subparsers)
+ auto_anchors.set_parser(subparsers)
  avg_model.set_parser(subparsers)
  convert_model.set_parser(subparsers)
  det_results.set_parser(subparsers)
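Note: the new auto-anchors command plugs into the CLI the same way as the existing tool modules: each module exposes a set_parser() that registers a subparser and binds its entry point via set_defaults(func=...). A hypothetical, trimmed-down sketch of that pattern (the argument names come from the usage example above; their types and help texts are assumptions, not the real birder/tools/auto_anchors.py):

import argparse
from typing import Any


def set_parser(subparsers: Any) -> None:
    # Register a subcommand the same way adversarial.set_parser() and friends do
    subparser = subparsers.add_parser("auto-anchors", allow_abbrev=False, help="estimate detection anchor boxes")
    subparser.add_argument("--preset", type=str, help="detection network preset, e.g. yolo_v4")
    subparser.add_argument("--size", type=int, help="target image size")
    subparser.add_argument("--coco-json-path", type=str, help="COCO annotations to take box statistics from")
    subparser.set_defaults(func=main)


def main(args: argparse.Namespace) -> None:
    print(f"would estimate anchors for preset={args.preset} at size={args.size}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="cmd", required=True)
    set_parser(subparsers)
    args = parser.parse_args()
    args.func(args)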
birder/tools/adversarial.py CHANGED
@@ -1,76 +1,134 @@
  import argparse
  import logging
+ from collections.abc import Callable
  from typing import Any
+ from typing import Optional

  import matplotlib.pyplot as plt
  import numpy as np
+ import numpy.typing as npt
  import torch
  from PIL import Image

+ from birder.adversarial.base import Attack
+ from birder.adversarial.base import AttackResult
+ from birder.adversarial.deepfool import DeepFool
  from birder.adversarial.fgsm import FGSM
  from birder.adversarial.pgd import PGD
+ from birder.adversarial.simba import SimBA
  from birder.common import cli
  from birder.common import fs_ops
  from birder.common import lib
+ from birder.data.transforms.classification import RGBType
  from birder.data.transforms.classification import inference_preset
  from birder.data.transforms.classification import reverse_preset

  logger = logging.getLogger(__name__)


- def show_pgd(args: argparse.Namespace) -> None:
- if args.gpu is True:
- device = torch.device("cuda")
- else:
- device = torch.device("cpu")
-
- if args.gpu_id is not None:
- torch.cuda.set_device(args.gpu_id)
+ def _load_model_and_transform(
+ args: argparse.Namespace, device: torch.device
+ ) -> tuple[torch.nn.Module, dict[str, int], RGBType, Callable[..., torch.Tensor], Callable[..., torch.Tensor]]:
+ (net, model_info) = fs_ops.load_model(
+ device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
+ )

- logger.info(f"Using device {device}")
+ class_to_idx = model_info.class_to_idx
+ rgb_stats = model_info.rgb_stats

- (net, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
- device,
- args.network,
- tag=args.tag,
- epoch=args.epoch,
- inference=True,
- reparameterized=args.reparameterized,
- )
- label_names = list(class_to_idx.keys())
- size = lib.get_size_from_signature(signature)
+ size = lib.get_size_from_signature(model_info.signature)
  transform = inference_preset(size, rgb_stats, 1.0)
  reverse_transform = reverse_preset(rgb_stats)

- img: Image.Image = Image.open(args.image_path)
- input_tensor = transform(img).unsqueeze(dim=0).to(device)
+ return (net, class_to_idx, rgb_stats, transform, reverse_transform)

- pgd = PGD(net, eps=args.eps, max_delta=0.012, steps=10, random_start=True)
- if args.target is not None:
- target = torch.tensor(class_to_idx[args.target]).unsqueeze(dim=0).to(device)
- else:
- target = None

- img = img.resize(size)
- pgd_response = pgd(input_tensor, target=target)
- perturbation = reverse_transform(pgd_response.adv_img).cpu().detach().numpy().squeeze()
- pgd_img = np.moveaxis(perturbation, 0, 2)
+ def _resolve_target(
+ target_name: Optional[str], class_to_idx: dict[str, int], device: torch.device
+ ) -> Optional[torch.Tensor]:
+ if target_name is None:
+ return None
+ if target_name not in class_to_idx:
+ raise ValueError(f"Unknown target class '{target_name}'")

- # Get predictions and probabilities
- prob = pgd_response.out.cpu().detach().numpy().squeeze()
- adv_prob = pgd_response.adv_out.cpu().detach().numpy().squeeze()
- idx = np.argmax(prob)
- adv_idx = np.argmax(adv_prob)
+ return torch.tensor([class_to_idx[target_name]], device=device, dtype=torch.long)

+
+ def _build_attack(args: argparse.Namespace, net: torch.nn.Module, rgb_stats: RGBType) -> Attack:
+ if args.method == "fgsm":
+ return FGSM(net, eps=args.eps, rgb_stats=rgb_stats)
+ if args.method == "pgd":
+ return PGD(
+ net,
+ eps=args.eps,
+ steps=args.steps,
+ step_size=args.step_size,
+ random_start=args.random_start,
+ rgb_stats=rgb_stats,
+ )
+ if args.method == "deepfool":
+ return DeepFool(
+ net,
+ num_classes=args.deepfool_num_classes,
+ overshoot=args.deepfool_overshoot,
+ max_iter=args.deepfool_max_iter,
+ rgb_stats=rgb_stats,
+ )
+ if args.method == "simba":
+ return SimBA(
+ net,
+ step_size=args.step_size if args.step_size is not None else args.eps,
+ max_iter=args.steps,
+ rgb_stats=rgb_stats,
+ )
+
+ raise ValueError(f"Unsupported attack method '{args.method}'")
+
+
+ def _tensor_to_image(tensor: torch.Tensor, reverse_transform: Callable[..., torch.Tensor]) -> npt.NDArray[np.uint8]:
+ img_tensor = reverse_transform(tensor).cpu()
+ img = img_tensor.numpy()
+ return np.moveaxis(img, 0, 2)
+
+
+ def _get_prediction(logits: torch.Tensor, label_names: list[str]) -> tuple[str, float]:
+ probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze()
+ idx = int(np.argmax(probs))
+ return (label_names[idx], float(probs[idx]))
+
+
+ def _display_results(
+ original_img: npt.NDArray[np.uint8],
+ adv_img: npt.NDArray[np.uint8],
+ original_pred: tuple[str, float],
+ adv_pred: tuple[str, float],
+ success: Optional[bool],
+ result: AttackResult,
+ ) -> None:
+ (orig_label, orig_prob) = original_pred
+ (adv_label, adv_prob) = adv_pred
+
+ # Log results
+ logger.info(f"Original: {orig_label} ({orig_prob * 100:.2f}%)")
+ logger.info(f"Adversarial: {adv_label} ({adv_prob * 100:.2f}%)")
+ if success is not None:
+ logger.info(f"Attack success: {success}")
+ if result.num_queries is not None:
+ logger.info(f"Model queries: {result.num_queries}")
+
+ # Display images
  _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
- ax1.imshow(img)
- ax1.set_title(f"{label_names[idx]} {100 * prob[idx]:.2f}%")
- ax2.imshow(pgd_img)
- ax2.set_title(f"{label_names[adv_idx]} {100 * adv_prob[adv_idx]:.2f}%")
+ ax1.imshow(original_img)
+ ax1.set_title(f"{orig_label} {100 * orig_prob:.2f}%")
+ ax1.axis("off")
+ ax2.imshow(adv_img)
+ ax2.set_title(f"{adv_label} {100 * adv_prob:.2f}%")
+ ax2.axis("off")
+ plt.tight_layout()
  plt.show()


- def show_fgsm(args: argparse.Namespace) -> None:
+ def run_attack(args: argparse.Namespace) -> None:
  if args.gpu is True:
  device = torch.device("cuda")
  else:
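Note: for context on what the refactored tool dispatches to, here is a generic sketch of the two gradient-based attacks by name: FGSM (a single signed-gradient step) and PGD (iterated FGSM projected back onto an L-inf ball). It operates on [0, 1] pixel tensors and is untargeted; this is not the birder.adversarial implementation, which additionally handles normalization via rgb_stats:

import torch
import torch.nn.functional as F


def fgsm(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, eps: float) -> torch.Tensor:
    # Single step in the direction of the loss gradient sign: x_adv = x + eps * sign(dL/dx)
    x = x.clone().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    (grad,) = torch.autograd.grad(loss, x)
    return (x + eps * grad.sign()).clamp(0.0, 1.0).detach()


def pgd(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, eps: float, step_size: float, steps: int) -> torch.Tensor:
    # Iterated signed-gradient steps, projected onto the L-inf ball of radius eps around x
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = x + (x_adv - x).clamp(-eps, eps)  # project back into the eps-ball
        x_adv = x_adv.clamp(0.0, 1.0)
    return x_adv.detach()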
@@ -81,83 +139,76 @@ def show_fgsm(args: argparse.Namespace) -> None:

  logger.info(f"Using device {device}")

- (net, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
- device,
- args.network,
- tag=args.tag,
- epoch=args.epoch,
- inference=False,
- reparameterized=args.reparameterized,
- )
- label_names = list(class_to_idx.keys())
- size = lib.get_size_from_signature(signature)
- transform = inference_preset(size, rgb_stats, 1.0)
-
- img: Image.Image = Image.open(args.image_path)
+ (net, class_to_idx, rgb_stats, transform, reverse_transform) = _load_model_and_transform(args, device)
+ label_names = [name for name, _idx in sorted(class_to_idx.items(), key=lambda item: item[1])]
+ img = Image.open(args.image_path)
  input_tensor = transform(img).unsqueeze(dim=0).to(device)

- fgsm = FGSM(net, eps=args.eps)
- if args.target is not None:
- target = torch.tensor(class_to_idx[args.target]).unsqueeze(dim=0).to(device)
- else:
- target = None
+ target = _resolve_target(args.target, class_to_idx, device)
+ attack = _build_attack(args, net, rgb_stats)
+ result = attack(input_tensor, target=target)

- img = img.resize(size)
- fgsm_response = fgsm(input_tensor, target=target)
- perturbation = fgsm_response.perturbation.cpu().detach().numpy().squeeze()
- fgsm_img = (np.array(img).astype(np.float32) / 255.0) + np.moveaxis(perturbation, 0, 2)
- fgsm_img = np.clip(fgsm_img, 0, 1)
+ original_img = _tensor_to_image(input_tensor.squeeze(0).cpu(), reverse_transform)
+ adv_img = _tensor_to_image(result.adv_inputs.squeeze(0).cpu(), reverse_transform)
+ original_logits = result.logits
+ if original_logits is None:
+ with torch.no_grad():
+ original_logits = net(input_tensor)

- # Get predictions and probabilities
- prob = fgsm_response.out.cpu().detach().numpy().squeeze()
- adv_prob = fgsm_response.adv_out.cpu().detach().numpy().squeeze()
- idx = np.argmax(prob)
- adv_idx = np.argmax(adv_prob)
+ original_pred = _get_prediction(original_logits, label_names)
+ adv_pred = _get_prediction(result.adv_logits, label_names)
+ success = bool(result.success.item()) if result.success is not None else None

- _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
- ax1.imshow(img)
- ax1.set_title(f"{label_names[idx]} {100 * prob[idx]:.2f}%")
- ax2.imshow(fgsm_img)
- ax2.set_title(f"{label_names[adv_idx]} {100 * adv_prob[adv_idx]:.2f}%")
- plt.show()
+ _display_results(original_img, adv_img, original_pred, adv_pred, success, result)


  def set_parser(subparsers: Any) -> None:
  subparser = subparsers.add_parser(
  "adversarial",
  allow_abbrev=False,
- help="deep learning adversarial attacks",
- description="deep learning adversarial attacks",
+ help="generate and visualize adversarial examples",
+ description="generate and visualize adversarial examples",
  epilog=(
  "Usage examples:\n"
- "python -m birder.tools adversarial --method fgsm --network efficientnet_v2_s "
- "--epoch 0 --target Bluethroat 'data/training/Mallard/000117.jpeg'\n"
- "python -m birder.tools adversarial --method fgsm --network efficientnet_v2_m "
- "--epoch 0 --eps 0.02 --target Mallard 'data/validation/White-tailed eagle/000006.jpeg'\n"
- "python tool.py adversarial --method pgd --network caformer_s18 -e 0 "
- "data/validation/Arabian babbler/000001.jpeg\n"
+ "python -m birder.tools adversarial -n resnet_v2_50 -e 0 --method fgsm --eps 0.02 "
+ "data/validation/Mallard/000112.jpeg\n"
+ "python -m birder.tools adversarial -n efficientnet_v2_m -e 0 --method pgd --eps 0.02 --steps 10 "
+ "data/validation/Mallard/000002.jpeg\n"
+ "python -m birder.tools adversarial -n convnext_v2_tiny -e 0 --method deepfool "
+ "data/validation/Bluethroat/000013.jpeg\n"
+ "python -m birder.tools adversarial -n convnext_v2_tiny -e 0 --method simba --steps 1000 --step-size 0.1 "
+ "data/validation/Bluethroat/000043.jpeg\n"
  ),
  formatter_class=cli.ArgumentHelpFormatter,
  )
- subparser.add_argument("--method", type=str, choices=["fgsm", "pgd"], help="introspection method")
+ subparser.add_argument("-n", "--network", type=str, required=True, help="neural network to attack")
+ subparser.add_argument("-t", "--tag", type=str, help="model tag")
+ subparser.add_argument("-e", "--epoch", type=int, required=True, help="model checkpoint epoch")
+ subparser.add_argument("--reparameterized", default=False, action="store_true", help="load reparameterized model")
+ subparser.add_argument("--gpu", default=False, action="store_true", help="use GPU")
+ subparser.add_argument("--gpu-id", type=int, metavar="ID", help="GPU device ID")
  subparser.add_argument(
- "-n", "--network", type=str, required=True, help="the neural network to use (i.e. resnet_v2)"
+ "--method",
+ type=str,
+ required=True,
+ choices=["fgsm", "pgd", "deepfool", "simba"],
+ help="adversarial attack method",
  )
- subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint to load")
- subparser.add_argument("-t", "--tag", type=str, help="model tag (from the training phase)")
+ subparser.add_argument("--eps", type=float, default=0.007, help="perturbation budget in pixel space [0, 1]")
+ subparser.add_argument("--target", type=str, help="target class name for targeted attack (omit for untargeted)")
+ subparser.add_argument("--steps", type=int, default=10, help="number of iterations for iterative attacks")
+ subparser.add_argument("--step-size", type=float, help="step size in pixel space (defaults to eps/steps for PGD)")
  subparser.add_argument(
- "-r", "--reparameterized", default=False, action="store_true", help="load reparameterized model"
+ "--random-start", default=False, action="store_true", help="use random initialization for PGD"
  )
- subparser.add_argument("--gpu", default=False, action="store_true", help="use gpu")
- subparser.add_argument("--gpu-id", type=int, metavar="ID", help="gpu id to use")
- subparser.add_argument("--eps", type=float, default=0.007, help="fgsm epsilon")
- subparser.add_argument("--target", type=str, help="target class, leave empty to use predicted class")
- subparser.add_argument("image_path", type=str, help="input image path")
+ subparser.add_argument(
+ "--deepfool-num-classes", type=int, default=10, help="number of top classes to consider for DeepFool"
+ )
+ subparser.add_argument("--deepfool-overshoot", type=float, default=0.02, help="overshoot parameter for DeepFool")
+ subparser.add_argument("--deepfool-max-iter", type=int, default=50, help="max iterations for DeepFool")
+ subparser.add_argument("image_path", type=str, help="path to input image")
  subparser.set_defaults(func=main)


  def main(args: argparse.Namespace) -> None:
- if args.method == "fgsm":
- show_fgsm(args)
- elif args.method == "pgd":
- show_pgd(args)
+ run_attack(args)
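Note: SimBA is the one black-box method in the new lineup. It only queries output probabilities, trying one orthonormal direction per iteration and keeping whichever sign of the step lowers the true-class probability, which is why the tool now reports num_queries. A compact, untargeted pixel-basis sketch of the algorithm (independent of the birder.adversarial.simba implementation):

import torch


@torch.no_grad()
def simba(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, step_size: float, max_iter: int) -> tuple[torch.Tensor, int]:
    # x: (1, C, H, W) image in [0, 1]; y: (1,) true label index
    x_adv = x.clone()
    prob = torch.softmax(model(x_adv), dim=1)[0, y.item()]
    perm = torch.randperm(x.numel())  # random order over pixel coordinates (the "pixel basis")
    queries = 0
    for i in range(min(max_iter, x.numel())):
        direction = torch.zeros_like(x_adv).view(-1)
        direction[perm[i]] = step_size
        direction = direction.view_as(x_adv)
        for sign in (1.0, -1.0):
            candidate = (x_adv + sign * direction).clamp(0.0, 1.0)
            new_prob = torch.softmax(model(candidate), dim=1)[0, y.item()]
            queries += 1
            if new_prob < prob:  # keep only steps that reduce confidence in the true class
                x_adv = candidate
                prob = new_prob
                break
    return x_adv, queries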