birder 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/lib.py +2 -9
  8. birder/common/training_cli.py +29 -3
  9. birder/common/training_utils.py +141 -11
  10. birder/data/collators/detection.py +10 -3
  11. birder/data/datasets/coco.py +8 -10
  12. birder/data/transforms/detection.py +30 -13
  13. birder/inference/data_parallel.py +1 -2
  14. birder/inference/detection.py +108 -4
  15. birder/inference/wbf.py +226 -0
  16. birder/introspection/__init__.py +10 -6
  17. birder/introspection/attention_rollout.py +122 -54
  18. birder/introspection/base.py +73 -29
  19. birder/introspection/gradcam.py +71 -100
  20. birder/introspection/guided_backprop.py +146 -72
  21. birder/introspection/transformer_attribution.py +182 -0
  22. birder/net/__init__.py +8 -0
  23. birder/net/detection/deformable_detr.py +14 -12
  24. birder/net/detection/detr.py +7 -3
  25. birder/net/detection/efficientdet.py +65 -86
  26. birder/net/detection/rt_detr_v1.py +4 -3
  27. birder/net/detection/yolo_anchors.py +205 -0
  28. birder/net/detection/yolo_v2.py +25 -24
  29. birder/net/detection/yolo_v3.py +42 -48
  30. birder/net/detection/yolo_v4.py +31 -40
  31. birder/net/detection/yolo_v4_tiny.py +24 -20
  32. birder/net/fasternet.py +1 -1
  33. birder/net/fastvit.py +1 -1
  34. birder/net/gc_vit.py +671 -0
  35. birder/net/lit_v1.py +472 -0
  36. birder/net/lit_v1_tiny.py +342 -0
  37. birder/net/lit_v2.py +436 -0
  38. birder/net/mim/mae_vit.py +7 -8
  39. birder/net/mobilenet_v4_hybrid.py +1 -1
  40. birder/net/pit.py +1 -1
  41. birder/net/resnet_v1.py +95 -35
  42. birder/net/resnext.py +67 -25
  43. birder/net/se_resnet_v1.py +46 -0
  44. birder/net/se_resnext.py +3 -0
  45. birder/net/simple_vit.py +2 -2
  46. birder/net/ssl/data2vec.py +1 -1
  47. birder/net/ssl/data2vec2.py +4 -2
  48. birder/net/vit.py +0 -15
  49. birder/net/vovnet_v2.py +31 -1
  50. birder/results/gui.py +15 -2
  51. birder/scripts/benchmark.py +90 -21
  52. birder/scripts/predict.py +1 -0
  53. birder/scripts/predict_detection.py +48 -9
  54. birder/scripts/train.py +33 -50
  55. birder/scripts/train_barlow_twins.py +19 -40
  56. birder/scripts/train_byol.py +19 -40
  57. birder/scripts/train_capi.py +21 -43
  58. birder/scripts/train_data2vec.py +18 -40
  59. birder/scripts/train_data2vec2.py +18 -40
  60. birder/scripts/train_detection.py +89 -57
  61. birder/scripts/train_dino_v1.py +19 -40
  62. birder/scripts/train_dino_v2.py +18 -40
  63. birder/scripts/train_dino_v2_dist.py +25 -40
  64. birder/scripts/train_franca.py +18 -40
  65. birder/scripts/train_i_jepa.py +25 -46
  66. birder/scripts/train_ibot.py +18 -40
  67. birder/scripts/train_kd.py +179 -81
  68. birder/scripts/train_mim.py +20 -43
  69. birder/scripts/train_mmcr.py +19 -40
  70. birder/scripts/train_rotnet.py +19 -40
  71. birder/scripts/train_simclr.py +19 -40
  72. birder/scripts/train_vicreg.py +19 -40
  73. birder/tools/__main__.py +6 -2
  74. birder/tools/adversarial.py +147 -96
  75. birder/tools/auto_anchors.py +380 -0
  76. birder/tools/ensemble_model.py +1 -1
  77. birder/tools/introspection.py +58 -31
  78. birder/tools/pack.py +172 -103
  79. birder/tools/show_det_iterator.py +10 -1
  80. birder/version.py +1 -1
  81. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/METADATA +4 -3
  82. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/RECORD +86 -75
  83. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  84. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  85. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  86. {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/adversarial/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from birder.adversarial.base import AttackResult
2
+ from birder.adversarial.deepfool import DeepFool
3
+ from birder.adversarial.fgsm import FGSM
4
+ from birder.adversarial.pgd import PGD
5
+ from birder.adversarial.simba import SimBA
6
+
7
+ __all__ = [
8
+ "AttackResult",
9
+ "DeepFool",
10
+ "FGSM",
11
+ "PGD",
12
+ "SimBA",
13
+ ]
birder/adversarial/base.py ADDED
@@ -0,0 +1,101 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+ from typing import Protocol
4
+
5
+ import torch
6
+
7
+ from birder.data.transforms.classification import RGBType
8
+
9
+
10
@dataclass(frozen=True)
class AttackResult:
    """
    Immutable record of what an adversarial attack produced

    The first three fields are mandatory, the remaining ones default to ``None``
    for attacks that do not report them.
    """

    # Adversarial examples and the logits computed on them
    adv_inputs: torch.Tensor
    adv_logits: torch.Tensor

    # Difference between the adversarial and the original inputs
    perturbation: torch.Tensor

    # Logits of the unmodified inputs
    logits: Optional[torch.Tensor] = None

    # Success flags, see attack_success
    success: Optional[torch.Tensor] = None

    # Number of model queries spent by the attack
    num_queries: Optional[int] = None
18
+
19
+
20
class Attack(Protocol):
    """
    Structural type of an attack: a callable that receives an input tensor together with
    optional target labels and returns an AttackResult
    """

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        ...
22
+
23
+
24
def _to_channel_tensor(
    values: tuple[float, float, float], device: Optional[torch.device], dtype: Optional[torch.dtype]
) -> torch.Tensor:
    """
    Turn per-channel values into a tensor of shape (1, C, 1, 1)
    """

    channel_values = torch.tensor(values, device=device, dtype=dtype)
    return channel_values.reshape(1, -1, 1, 1)
28
+
29
+
30
def normalized_bounds(
    rgb_stats: RGBType, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Push the pixel limits 0 and 1 through the (x - mean) / std normalization

    Returns
    -------
    The (lower, upper) limits, each built from the per-channel "mean" and "std" entries of rgb_stats.
    """

    (mean, std) = (_to_channel_tensor(rgb_stats[key], device=device, dtype=dtype) for key in ("mean", "std"))
    lower = (0.0 - mean) / std
    upper = (1.0 - mean) / std

    return (lower, upper)
39
+
40
+
41
def pixel_eps_to_normalized(
    eps: float | torch.Tensor,
    rgb_stats: RGBType,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Rescale an epsilon given on the un-normalized pixel scale into normalized input space

    Parameters
    ----------
    eps
        A single value or a tensor of per-channel values.
    rgb_stats
        Normalization statistics, only the "std" entry is read.
    device
        Device of the result, forwarded to torch.as_tensor.
    dtype
        Data type of the result, forwarded to torch.as_tensor.

    Returns
    -------
    eps divided by the per-channel std, shaped for broadcasting over (N, C, H, W) inputs.
    """

    eps_tensor = torch.as_tensor(eps, device=device, dtype=dtype)
    if eps_tensor.is_floating_point() is False:
        # The std tensor below inherits the dtype of eps. With an integer eps (e.g. eps=1 and
        # no dtype) the std values would be truncated to zero and the division would return inf.
        eps_tensor = eps_tensor.to(torch.get_default_dtype())

    std = _to_channel_tensor(rgb_stats["std"], device=eps_tensor.device, dtype=eps_tensor.dtype)

    # A single value becomes (1, 1, 1, 1), per-channel values become (1, C, 1, 1)
    eps_tensor = eps_tensor.reshape(1, -1, 1, 1)

    return eps_tensor / std
56
+
57
+
58
def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
    """
    Clip inputs to the limits that normalized_bounds returns for rgb_stats
    """

    bounds = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
    return inputs.clamp(min=bounds[0], max=bounds[1])
61
+
62
+
63
def predict_labels(logits: torch.Tensor) -> torch.Tensor:
    """
    Return the index of the largest logit along dim 1
    """

    return logits.argmax(dim=1)
65
+
66
+
67
def validate_target(
    target: Optional[torch.Tensor], batch_size: int, num_classes: int, device: torch.device
) -> Optional[torch.Tensor]:
    """
    Check target labels and move them to the requested device as int64

    Parameters
    ----------
    target
        Class indices, either a scalar or a 1-D tensor with one entry per sample.
        None is passed through unchanged.
    batch_size
        Expected number of labels.
    num_classes
        Exclusive upper limit for the label values.
    device
        Device for the returned tensor.

    Returns
    -------
    A 1-D int64 tensor on device, or None when target is None.

    Raises
    ------
    ValueError
        If target has more than one dimension, its length differs from batch_size
        or any value lies outside [0, num_classes).
    """

    if target is None:
        return None

    target = target.to(device=device, dtype=torch.long)
    if target.ndim == 0:
        target = target.view(1)

    # A (B, 1) target would pass the length check below and later be broadcast
    # against the (B,) predictions instead of being compared element-wise
    if target.ndim != 1:
        raise ValueError(f"Target must be a scalar or a 1-D tensor, got {target.ndim} dimensions")

    if target.shape[0] != batch_size:
        raise ValueError(f"Target shape {target.shape[0]} must match batch size {batch_size}")

    out_of_range = (target < 0) | (target >= num_classes)
    if torch.any(out_of_range):
        raise ValueError(f"Target values must be in range [0, {num_classes})")

    return target
84
+
85
+
86
def attack_success(
    logits: torch.Tensor,
    adv_logits: torch.Tensor,
    targeted: bool,
    target: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compare the adversarial predictions with the attack goal

    A targeted attack succeeds where the adversarial prediction equals target. Otherwise it
    succeeds where the prediction differs from labels, falling back to the predictions
    on the original logits when labels is not given.

    Raises
    ------
    ValueError
        If targeted is True and target is None.
    """

    adv_pred = predict_labels(adv_logits)
    if targeted is not True:
        if labels is None:
            labels = predict_labels(logits)

        return adv_pred.ne(labels)

    if target is None:
        raise ValueError("Target labels required for targeted attacks")

    return adv_pred.eq(target)
birder/adversarial/deepfool.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ DeepFool
3
+
4
+ Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from birder.adversarial.base import AttackResult
13
+ from birder.adversarial.base import attack_success
14
+ from birder.adversarial.base import clamp_normalized
15
+ from birder.adversarial.base import predict_labels
16
+ from birder.adversarial.base import validate_target
17
+ from birder.data.transforms.classification import RGBType
18
+
19
# Gradient-difference norms below this value are treated as zero, it also pads the
# denominators of the step computations
GRAD_EPS = 1e-12


class DeepFool:
    """
    Iterative attack that repeatedly steps a sample across a linearized decision boundary

    Samples are attacked one at a time. Without a target the closest boundary among the
    highest-scoring classes is used, with a target only the boundary toward that class.
    """

    def __init__(
        self, net: nn.Module, num_classes: int = 10, overshoot: float = 0.02, max_iter: int = 50, *, rgb_stats: RGBType
    ) -> None:
        """
        Parameters
        ----------
        net
            Model under attack, switched to eval mode.
        num_classes
            Number of top-scoring classes examined per step in untargeted mode.
        overshoot
            Relative enlargement of every step.
        max_iter
            Maximum number of steps per sample.
        rgb_stats
            Normalization statistics used to clamp the adversarial inputs.

        Raises
        ------
        ValueError
            If num_classes < 2, max_iter <= 0 or overshoot < 0.
        """

        if num_classes < 2:
            raise ValueError("num_classes must be at least 2")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")
        if overshoot < 0:
            raise ValueError("overshoot must be non-negative")

        self.net = net.eval()
        self.num_classes = num_classes
        self.overshoot = overshoot
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack every sample of input_tensor, targeted when target is given
        """

        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        # validate_target passes None through unchanged
        target_labels = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        targeted = target_labels is not None

        adv_inputs_list = []
        for idx in range(inputs.size(0)):
            target_label = target_labels[idx : idx + 1] if target_labels is not None else None
            adv_input = self._attack_single(inputs[idx : idx + 1], logits[idx : idx + 1], target_label)
            adv_inputs_list.append(adv_input)

        adv_inputs = torch.concat(adv_inputs_list, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(logits, adv_logits, targeted, target=target_labels)

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
        )

    def _attack_single(
        self, inputs: torch.Tensor, logits: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """
        Run the iteration on one sample (batch dimension of size 1) and return the detached result
        """

        adv_inputs = inputs.clone()
        original_label = int(predict_labels(logits).item())

        # Read the target once instead of on every iteration
        target_value = int(target_label.item()) if target_label is not None else None

        for _ in range(self.max_iter):
            adv_inputs.requires_grad_(True)
            outputs = self.net(adv_inputs)
            current_label = int(predict_labels(outputs).item())

            if target_value is not None:
                if current_label == target_value:
                    break

                perturbation = self._targeted_perturbation(adv_inputs, outputs, current_label, target_value)

            else:
                if current_label != original_label:
                    break

                perturbation = self._untargeted_perturbation(adv_inputs, outputs, current_label)

            if perturbation is None:
                break

            # Overshoot helps ensure boundary crossing
            adv_inputs = adv_inputs.detach() + (1.0 + self.overshoot) * perturbation
            adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)

        return adv_inputs.detach()

    def _targeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int, target_label: int
    ) -> Optional[torch.Tensor]:
        """
        Step toward the boundary between the current and the target class, None if the gradients coincide
        """

        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]
        grad_target = torch.autograd.grad(outputs[0, target_label], adv_inputs, retain_graph=False)[0]

        # Direction toward the target boundary
        w = grad_target - grad_current
        w_norm = torch.linalg.vector_norm(w)
        if w_norm.item() < GRAD_EPS:
            return None

        # Distance to the decision boundary, detached so the returned step carries no
        # autograd history into the next iteration
        f = (outputs[0, target_label] - outputs[0, current_label]).detach()

        return (f.abs() / (w_norm**2 + GRAD_EPS)) * w

    def _untargeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int
    ) -> Optional[torch.Tensor]:
        """
        Step toward the closest boundary among the top-scoring classes, None if no usable class exists
        """

        # Search the top-k competing classes, a single transfer yields all indices
        top_k = min(self.num_classes, outputs.shape[1])
        top_indices = torch.topk(outputs, k=top_k, dim=1).indices[0].tolist()
        candidate_labels = [label for label in top_indices if label != current_label]

        if len(candidate_labels) == 0:
            return None

        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]

        # Track the closest decision boundary as (distance, w, w_norm, f)
        best: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None
        last_idx = len(candidate_labels) - 1
        for idx, label in enumerate(candidate_labels):
            # Keep the graph until the last class
            grad_other = torch.autograd.grad(outputs[0, label], adv_inputs, retain_graph=idx != last_idx)[0]

            w_k = grad_other - grad_current
            w_norm = torch.linalg.vector_norm(w_k)
            if w_norm.item() < GRAD_EPS:
                continue

            # Detached so the returned step carries no autograd history into the next iteration
            f_k = (outputs[0, label] - outputs[0, current_label]).detach()
            dist = f_k.abs() / (w_norm + GRAD_EPS)

            if best is None or dist < best[0]:
                best = (dist, w_k, w_norm, f_k)

        if best is None:
            return None

        # Minimal perturbation toward the closest boundary, the norm was already
        # checked against GRAD_EPS inside the loop
        (_, best_w, best_w_norm, best_f) = best

        return (best_f.abs() / (best_w_norm**2 + GRAD_EPS)) * best_w
birder/adversarial/fgsm.py CHANGED
@@ -1,34 +1,67 @@
1
- from typing import NamedTuple
1
+ """
2
+ Fast Gradient Sign Method (FGSM)
3
+
4
+ Paper "Explaining and Harnessing Adversarial Examples", https://arxiv.org/abs/1412.6572
5
+ """
6
+
2
7
  from typing import Optional
3
8
 
4
9
  import torch
5
10
  import torch.nn.functional as F
6
11
  from torch import nn
7
12
 
8
- FGSMResponse = NamedTuple(
9
- "FGSMResponse", [("out", torch.Tensor), ("perturbation", torch.Tensor), ("adv_out", torch.Tensor)]
10
- )
13
+ from birder.adversarial.base import AttackResult
14
+ from birder.adversarial.base import attack_success
15
+ from birder.adversarial.base import clamp_normalized
16
+ from birder.adversarial.base import pixel_eps_to_normalized
17
+ from birder.adversarial.base import predict_labels
18
+ from birder.adversarial.base import validate_target
19
+ from birder.data.transforms.classification import RGBType
11
20
 
12
21
 
13
22
  class FGSM:
14
- def __init__(self, net: nn.Module, eps: float) -> None:
23
+ def __init__(self, net: nn.Module, eps: float, *, rgb_stats: RGBType) -> None:
15
24
  self.net = net.eval()
16
25
  self.eps = eps
26
+ self.rgb_stats = rgb_stats
27
+
28
+ def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
29
+ inputs = input_tensor.detach().clone()
30
+ inputs.requires_grad_(True)
31
+
32
+ logits = self.net(inputs)
33
+ targeted = target is not None
34
+ if targeted is True:
35
+ target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
36
+ else:
37
+ target = predict_labels(logits)
17
38
 
18
- def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> FGSMResponse:
19
- input_tensor.requires_grad = True
20
- out = self.net(input_tensor)
21
- if target is None:
22
- target = torch.argmax(out, dim=1)
39
+ loss = F.cross_entropy(logits, target)
40
+ (grad,) = torch.autograd.grad(loss, inputs, retain_graph=False, create_graph=False)
41
+ eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)
23
42
 
24
- loss = F.nll_loss(out, target)
25
- self.net.zero_grad()
26
- loss.backward()
43
+ # Targeted steps descend toward target, untargeted ascend away from original
44
+ if targeted is True:
45
+ direction = -1.0
46
+ else:
47
+ direction = 1.0
27
48
 
28
- input_grad = input_tensor.grad.data
29
- sign_data_grad = input_grad.sign()
30
- perturbed_image = input_tensor + self.eps * sign_data_grad
49
+ perturbation = direction * eps_norm * grad.sign()
50
+ adv_inputs = clamp_normalized(inputs + perturbation, self.rgb_stats)
51
+ with torch.no_grad():
52
+ adv_logits = self.net(adv_inputs)
31
53
 
32
- adv_out = self.net(perturbed_image)
54
+ success = attack_success(
55
+ logits.detach(),
56
+ adv_logits,
57
+ targeted,
58
+ target=target if targeted else None,
59
+ )
33
60
 
34
- return FGSMResponse(F.softmax(out, dim=1), self.eps * sign_data_grad, F.softmax(adv_out, dim=1))
61
+ return AttackResult(
62
+ adv_inputs=adv_inputs,
63
+ adv_logits=adv_logits,
64
+ perturbation=adv_inputs - inputs,
65
+ logits=logits.detach(),
66
+ success=success,
67
+ )
birder/adversarial/pgd.py CHANGED
@@ -1,54 +1,105 @@
1
1
  """
2
- Projected Gradient Descent, adapted from
3
- https://github.com/Harry24k/adversarial-attacks-pytorch/blob/master/torchattacks/attacks/pgd.py
2
+ Projected Gradient Descent (PGD)
4
3
 
5
- Paper "Towards Deep Learning Models Resistant to Adversarial Attacks",
6
- https://arxiv.org/abs/1706.06083
4
+ Paper "Towards Deep Learning Models Resistant to Adversarial Attacks", https://arxiv.org/abs/1706.06083
7
5
  """
8
6
 
9
7
  # Reference license: MIT
10
8
 
11
- from typing import NamedTuple
12
9
  from typing import Optional
13
10
 
14
11
  import torch
15
12
  import torch.nn.functional as F
16
13
  from torch import nn
17
14
 
18
- PGDResponse = NamedTuple("PGDResponse", [("out", torch.Tensor), ("adv_img", torch.Tensor), ("adv_out", torch.Tensor)])
15
+ from birder.adversarial.base import AttackResult
16
+ from birder.adversarial.base import attack_success
17
+ from birder.adversarial.base import clamp_normalized
18
+ from birder.adversarial.base import pixel_eps_to_normalized
19
+ from birder.adversarial.base import predict_labels
20
+ from birder.adversarial.base import validate_target
21
+ from birder.data.transforms.classification import RGBType
19
22
 
20
23
 
21
24
  class PGD:
22
- def __init__(self, net: nn.Module, eps: float, max_delta: float, steps: int, random_start: bool) -> None:
23
- self.net = net
24
- self.max_delta = max_delta
25
+ def __init__(
26
+ self,
27
+ net: nn.Module,
28
+ eps: float,
29
+ steps: int = 10,
30
+ step_size: Optional[float] = None,
31
+ random_start: bool = False,
32
+ *,
33
+ rgb_stats: RGBType,
34
+ ) -> None:
35
+ if steps <= 0:
36
+ raise ValueError("steps must be a positive integer")
37
+
38
+ self.net = net.eval()
25
39
  self.eps = eps
26
40
  self.steps = steps
41
+ if step_size is not None:
42
+ self.step_size = step_size
43
+ else:
44
+ self.step_size = eps / steps
45
+
27
46
  self.random_start = random_start
47
+ self.rgb_stats = rgb_stats
48
+
49
+ if self.step_size <= 0:
50
+ raise ValueError("step_size must be positive")
51
+
52
+ def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
53
+ inputs = input_tensor.detach()
54
+ with torch.no_grad():
55
+ logits = self.net(inputs)
28
56
 
29
- def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> PGDResponse:
30
- adv_image = input_tensor.clone().detach()
31
- out = self.net(input_tensor)
32
- if target is None:
33
- target = torch.argmax(out, dim=1)
57
+ targeted = target is not None
58
+ if targeted:
59
+ target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
60
+ else:
61
+ target = predict_labels(logits)
34
62
 
63
+ eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)
64
+ step_norm = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)
65
+
66
+ # Targeted steps descend toward target, untargeted ascend away from original
67
+ if targeted is True:
68
+ direction = -1.0
69
+ else:
70
+ direction = 1.0
71
+
72
+ adv_inputs = inputs.clone()
35
73
  if self.random_start is True:
36
- # Starting at a uniformly random point
37
- adv_image = adv_image + torch.empty_like(adv_image).uniform_(-self.max_delta, self.max_delta)
38
- adv_image = torch.clamp(adv_image, min=-4, max=4).detach()
74
+ # Random start inside the epsilon ball
75
+ adv_inputs = adv_inputs + torch.empty_like(adv_inputs).uniform_(-1.0, 1.0) * eps_norm
76
+ adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)
39
77
 
40
78
  for _ in range(self.steps):
41
- adv_image.requires_grad = True
42
- outputs = self.net(adv_image)
43
- loss = F.nll_loss(outputs, target)
44
- self.net.zero_grad()
45
- loss.backward()
79
+ adv_inputs.requires_grad_(True)
80
+ adv_logits = self.net(adv_inputs)
81
+ loss = F.cross_entropy(adv_logits, target)
82
+ (grad,) = torch.autograd.grad(loss, adv_inputs, retain_graph=False, create_graph=False)
83
+ adv_inputs = adv_inputs.detach() + direction * step_norm * grad.sign()
84
+
85
+ # Project back into the epsilon ball around the original input.
86
+ delta = torch.clamp(adv_inputs - inputs, min=-eps_norm, max=eps_norm)
87
+ adv_inputs = clamp_normalized(inputs + delta, self.rgb_stats)
46
88
 
47
- grad = adv_image.grad.data
48
- adv_image = adv_image.detach() + self.eps * grad.sign()
49
- delta = torch.clamp(adv_image - input_tensor, min=-self.max_delta, max=self.max_delta)
50
- adv_image = torch.clamp(input_tensor + delta, min=-4, max=4).detach()
89
+ with torch.no_grad():
90
+ adv_logits = self.net(adv_inputs)
51
91
 
52
- adv_out = self.net(adv_image)
92
+ success = attack_success(
93
+ logits.detach(),
94
+ adv_logits,
95
+ targeted,
96
+ target=target if targeted else None,
97
+ )
53
98
 
54
- return PGDResponse(F.softmax(out, dim=1), adv_image, F.softmax(adv_out, dim=1))
99
+ return AttackResult(
100
+ adv_inputs=adv_inputs,
101
+ adv_logits=adv_logits,
102
+ perturbation=adv_inputs - inputs,
103
+ logits=logits.detach(),
104
+ success=success,
105
+ )
birder/adversarial/simba.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ SimBA (Simple Black-box Attack)
3
+
4
+ Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from birder.adversarial.base import AttackResult
14
+ from birder.adversarial.base import attack_success
15
+ from birder.adversarial.base import clamp_normalized
16
+ from birder.adversarial.base import pixel_eps_to_normalized
17
+ from birder.adversarial.base import predict_labels
18
+ from birder.adversarial.base import validate_target
19
+ from birder.data.transforms.classification import RGBType
20
+
21
+
22
class SimBA:
    """
    Query-only attack that changes one randomly chosen input coordinate at a time

    For every visited coordinate both a positive and a negative step are evaluated and the
    better one is kept only if it lowers the objective. No gradients are used.
    """

    def __init__(self, net: nn.Module, step_size: float, max_iter: int = 1000, *, rgb_stats: RGBType) -> None:
        """
        Parameters
        ----------
        net
            Model under attack, switched to eval mode.
        step_size
            Size of a single coordinate step, converted with pixel_eps_to_normalized.
        max_iter
            Maximum number of coordinates visited per sample.
        rgb_stats
            Normalization statistics used for the step conversion and for clamping.

        Raises
        ------
        ValueError
            If step_size or max_iter is not positive.
        """

        if step_size <= 0:
            raise ValueError("step_size must be positive")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")

        self.net = net.eval()
        self.rgb_stats = rgb_stats
        self.step_size = step_size
        self.max_iter = max_iter

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack every sample of input_tensor, targeted when target is given
        """

        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        labels = predict_labels(logits)
        target_labels = None
        if target is not None:
            target_labels = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)

        targeted = target_labels is not None

        # Samples are attacked one by one, the query counts are summed over the batch
        adv_samples = []
        total_queries = 0
        for idx in range(inputs.size(0)):
            sample = slice(idx, idx + 1)
            sample_target = None if target_labels is None else target_labels[sample]
            (adv_sample, used_queries) = self._attack_single(inputs[sample], labels[sample], sample_target)
            adv_samples.append(adv_sample)
            total_queries += used_queries

        adv_inputs = torch.concat(adv_samples, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(logits, adv_logits, targeted, target=target_labels)

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
            num_queries=total_queries,
        )

    # pylint: disable=too-many-locals
    def _attack_single(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, int]:
        """
        Attack one sample, returns the detached adversarial input and the number of forward passes
        """

        adv_inputs = inputs.clone()
        num_queries = 1  # Baseline forward pass

        with torch.no_grad():
            current_logits = self.net(adv_inputs)

        current_objective = self._compute_objective(current_logits, label, target_label)
        if self._is_successful(current_logits, label, target_label):
            return adv_inputs.detach(), num_queries

        (_, channels, height, width) = adv_inputs.shape
        stride = height * width
        num_dims = channels * stride

        # Per-channel steps
        step_vals = pixel_eps_to_normalized(
            self.step_size, self.rgb_stats, device=adv_inputs.device, dtype=adv_inputs.dtype
        ).view(-1)

        perm = torch.randperm(num_dims, device=adv_inputs.device)
        num_steps = min(self.max_iter, num_dims)

        # Coordinate-wise search in random order
        for flat_idx in perm[:num_steps].tolist():
            (c, rem) = divmod(flat_idx, stride)
            (h, w) = divmod(rem, width)

            candidate = self._best_candidate(adv_inputs, c, h, w, step_vals[c], label, target_label)
            num_queries += 2  # One forward pass per step direction

            if candidate[2] < current_objective:
                (adv_inputs, current_logits, current_objective) = candidate

            if self._is_successful(current_logits, label, target_label):
                break

        return adv_inputs.detach(), num_queries

    def _perturb_pixel(
        self, inputs: torch.Tensor, channel: int, row: int, col: int, step: torch.Tensor
    ) -> torch.Tensor:
        """
        Return a clamped copy of inputs with step added at one coordinate of the first sample
        """

        perturbed = inputs.clone()
        perturbed[0, channel, row, col] += step

        return clamp_normalized(perturbed, self.rgb_stats)

    def _evaluate_candidate(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """
        Run one forward pass without gradients and return the logits with their objective
        """

        with torch.no_grad():
            candidate_logits = self.net(inputs)

        objective = self._compute_objective(candidate_logits, label, target_label)

        return candidate_logits, objective

    def _best_candidate(
        self,
        inputs: torch.Tensor,
        channel: int,
        row: int,
        col: int,
        step: torch.Tensor,
        label: torch.Tensor,
        target_label: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """
        Evaluate the positive and the negative step, the positive one wins ties
        """

        evaluated = []
        for signed_step in (step, -step):
            perturbed = self._perturb_pixel(inputs, channel, row, col, signed_step)
            (candidate_logits, objective) = self._evaluate_candidate(perturbed, label, target_label)
            evaluated.append((perturbed, candidate_logits, objective))

        (plus, minus) = evaluated
        if plus[2] <= minus[2]:
            return plus

        return minus

    @staticmethod
    def _compute_objective(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> float:
        """
        Cross entropy toward the target, or negated cross entropy of the original label
        """

        # Lower objective is better in both modes
        if target_label is None:
            return -float(F.cross_entropy(logits, original_label).item())

        return float(F.cross_entropy(logits, target_label).item())

    @staticmethod
    def _is_successful(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> bool:
        """
        True once the prediction equals the target, or without a target differs from the original label
        """

        pred = predict_labels(logits)
        if target_label is None:
            return bool(pred.ne(original_label).item())

        return bool(pred.eq(target_label).item())