adv-lib 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. adv_lib/__init__.py +1 -0
  2. adv_lib/attacks/__init__.py +13 -0
  3. adv_lib/attacks/augmented_lagrangian.py +243 -0
  4. adv_lib/attacks/auto_pgd.py +523 -0
  5. adv_lib/attacks/boundary_projection_tf.py +170 -0
  6. adv_lib/attacks/carlini_wagner/__init__.py +2 -0
  7. adv_lib/attacks/carlini_wagner/l2.py +151 -0
  8. adv_lib/attacks/carlini_wagner/linf.py +158 -0
  9. adv_lib/attacks/decoupled_direction_norm.py +113 -0
  10. adv_lib/attacks/fast_adaptive_boundary/__init__.py +1 -0
  11. adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py +215 -0
  12. adv_lib/attacks/fast_adaptive_boundary/projections.py +164 -0
  13. adv_lib/attacks/fast_minimum_norm.py +218 -0
  14. adv_lib/attacks/perceptual_color_attacks/__init__.py +1 -0
  15. adv_lib/attacks/perceptual_color_attacks/differential_color_functions.py +181 -0
  16. adv_lib/attacks/perceptual_color_attacks/perceptual_color_distance_al.py +128 -0
  17. adv_lib/attacks/primal_dual_gradient_descent.py +379 -0
  18. adv_lib/attacks/projected_gradient_descent.py +109 -0
  19. adv_lib/attacks/segmentation/__init__.py +4 -0
  20. adv_lib/attacks/segmentation/alma_prox.py +283 -0
  21. adv_lib/attacks/segmentation/asma.py +92 -0
  22. adv_lib/attacks/segmentation/dense_adversary.py +83 -0
  23. adv_lib/attacks/segmentation/primal_dual_gradient_descent.py +349 -0
  24. adv_lib/attacks/self_adaptive_norm_update.py +127 -0
  25. adv_lib/attacks/sigma_zero.py +119 -0
  26. adv_lib/attacks/stochastic_sparse_attacks.py +237 -0
  27. adv_lib/attacks/structured_adversarial_attack.py +289 -0
  28. adv_lib/attacks/trust_region.py +153 -0
  29. adv_lib/distances/__init__.py +0 -0
  30. adv_lib/distances/color_difference.py +212 -0
  31. adv_lib/distances/lp_norms.py +18 -0
  32. adv_lib/distances/lpips.py +99 -0
  33. adv_lib/distances/structural_similarity.py +147 -0
  34. adv_lib/utils/__init__.py +1 -0
  35. adv_lib/utils/attack_utils.py +226 -0
  36. adv_lib/utils/color_conversions.py +71 -0
  37. adv_lib/utils/image_selection.py +27 -0
  38. adv_lib/utils/lagrangian_penalties/__init__.py +1 -0
  39. adv_lib/utils/lagrangian_penalties/all_penalties.py +67 -0
  40. adv_lib/utils/lagrangian_penalties/penalty_functions.py +79 -0
  41. adv_lib/utils/lagrangian_penalties/scripts/plot_penalties.py +42 -0
  42. adv_lib/utils/lagrangian_penalties/scripts/plot_univariates.py +32 -0
  43. adv_lib/utils/lagrangian_penalties/univariate_functions.py +299 -0
  44. adv_lib/utils/losses.py +29 -0
  45. adv_lib/utils/projections.py +100 -0
  46. adv_lib/utils/utils.py +58 -0
  47. adv_lib/utils/visdom_logger.py +109 -0
  48. adv_lib-0.2.2.dist-info/LICENSE +29 -0
  49. adv_lib-0.2.2.dist-info/METADATA +170 -0
  50. adv_lib-0.2.2.dist-info/RECORD +52 -0
  51. adv_lib-0.2.2.dist-info/WHEEL +5 -0
  52. adv_lib-0.2.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,151 @@
1
+ # Adapted from https://github.com/carlini/nn_robust_attacks
2
+
3
+ from typing import Tuple, Optional
4
+
5
+ import torch
6
+ from torch import nn, optim, Tensor
7
+ from torch.autograd import grad
8
+
9
+ from adv_lib.utils.losses import difference_of_logits
10
+ from adv_lib.utils.visdom_logger import VisdomLogger
11
+
12
+
13
def carlini_wagner_l2(model: nn.Module,
                      inputs: Tensor,
                      labels: Tensor,
                      targeted: bool = False,
                      confidence: float = 0,
                      learning_rate: float = 0.01,
                      initial_const: float = 0.001,
                      binary_search_steps: int = 9,
                      max_iterations: int = 10000,
                      abort_early: bool = True,
                      callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Carlini and Wagner L2 attack from https://arxiv.org/abs/1608.04644.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    confidence : float
        Confidence of adversarial examples: higher produces examples that are farther away, but more strongly classified
        as adversarial.
    learning_rate: float
        The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge.
    initial_const : float
        The initial tradeoff-constant to use to tune the relative importance of distance and confidence. If
        binary_search_steps is large, the initial constant is not important.
    binary_search_steps : int
        The number of times we perform binary search to find the optimal tradeoff-constant between distance and
        confidence.
    max_iterations : int
        The maximum number of iterations. Larger values are more accurate; setting too small will require a large
        learning rate and will produce poor results.
    abort_early : bool
        If true, allows early aborts if gradient descent gets stuck.
    callback : Optional
        Logger object. When given, its `accumulate_line`, `update_lines` and `line` methods are called to record the
        logit distance, the L2 norm, the success rate, the best L2 norm and the constant c.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    # work in tanh-space; the (1 - 1e-6) factor keeps the argument of atanh strictly inside (-1, 1)
    t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
    multiplier = -1 if targeted else 1

    # intervals of the periodic checks, floored at 1: `max_iterations // 10` (resp. `// 20`) is 0 for small
    # `max_iterations`, which would make the modulo below raise ZeroDivisionError
    abort_check_interval = max(max_iterations // 10, 1)
    plot_interval = max(max_iterations // 20, 1)

    # set the lower and upper bounds accordingly
    c = torch.full((batch_size,), initial_const, device=device)
    lower_bound = torch.zeros_like(c)
    upper_bound = torch.full_like(c, 1e10)

    # best results over all binary search steps
    o_best_l2 = torch.full_like(c, float('inf'))
    o_best_adv = inputs.clone()
    o_adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

    i_total = 0
    for outer_step in range(binary_search_steps):

        # setup the modifier variable and the optimizer
        modifier = torch.zeros_like(inputs, requires_grad=True)
        optimizer = optim.Adam([modifier], lr=learning_rate)
        # best results of the current binary search step only
        best_l2 = torch.full_like(c, float('inf'))
        adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

        # The last iteration (if we run many steps) repeat the search once.
        if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1):
            # clone so that the in-place updates of c below do not also overwrite upper_bound
            c = upper_bound.clone()

        prev = float('inf')
        for i in range(max_iterations):

            adv_inputs = (torch.tanh(t_inputs + modifier) + 1) / 2
            l2_squared = (adv_inputs - inputs).flatten(1).square().sum(1)
            l2 = l2_squared.detach().sqrt()
            logits = model(adv_inputs)

            if outer_step == 0 and i == 0:
                # setup the target variable, we need it to be in one-hot form for the loss function
                labels_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
                labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))

            # adjust the best result found so far
            predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \
                (logits + labels_onehot * confidence).argmax(1)

            is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels)
            is_smaller = l2 < best_l2
            o_is_smaller = l2 < o_best_l2
            is_both = is_adv & is_smaller
            o_is_both = is_adv & o_is_smaller

            best_l2 = torch.where(is_both, l2, best_l2)
            adv_found.logical_or_(is_both)
            o_best_l2 = torch.where(o_is_both, l2, o_best_l2)
            o_adv_found.logical_or_(is_both)
            o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv)

            logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)
            loss = l2_squared + c * (logit_dists + confidence).clamp_(min=0)

            # check if we should abort search if we're getting nowhere.
            if abort_early and i % abort_check_interval == 0:
                if (loss > prev * 0.9999).all():
                    break
                prev = loss.detach()

            optimizer.zero_grad(set_to_none=True)
            modifier.grad = grad(loss.sum(), modifier, only_inputs=True)[0]
            optimizer.step()

            if callback:
                i_total += 1
                callback.accumulate_line('logit_dist', i_total, logit_dists.mean())
                callback.accumulate_line('l2_norm', i_total, l2.mean())
                if i_total % plot_interval == 0:
                    callback.update_lines()

        if callback:
            best_l2 = o_best_l2.masked_select(o_adv_found).mean()
            callback.line(['success', 'best_l2', 'c'], outer_step, [o_adv_found.float().mean(), best_l2, c.mean()])

        # adjust the constant as needed
        upper_bound[adv_found] = torch.min(upper_bound[adv_found], c[adv_found])
        adv_not_found = ~adv_found
        lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], c[adv_not_found])
        # bisect once a finite upper bound is known, otherwise multiply c by 10 for the failed samples
        is_smaller = upper_bound < 1e9
        c[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2
        c[(~is_smaller) & adv_not_found] *= 10

    # return the best solution found
    return o_best_adv
@@ -0,0 +1,158 @@
1
+ # Adapted from https://github.com/carlini/nn_robust_attacks
2
+
3
+ from typing import Tuple, Optional
4
+
5
+ import torch
6
+ from torch import nn, optim, Tensor
7
+
8
+ from adv_lib.utils.losses import difference_of_logits
9
+ from adv_lib.utils.visdom_logger import VisdomLogger
10
+
11
+
12
def carlini_wagner_linf(model: nn.Module,
                        inputs: Tensor,
                        labels: Tensor,
                        targeted: bool = False,
                        learning_rate: float = 0.01,
                        max_iterations: int = 1000,
                        initial_const: float = 1e-5,
                        largest_const: float = 2e+1,
                        const_factor: float = 2,
                        reduce_const: bool = False,
                        decrease_factor: float = 0.9,
                        abort_early: bool = True,
                        callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Carlini and Wagner Linf attack from https://arxiv.org/abs/1608.04644.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    learning_rate: float
        The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge.
    max_iterations : int
        The maximum number of iterations. Larger values are more accurate; setting too small will require a large
        learning rate and will produce poor results.
    initial_const : float
        The initial tradeoff-constant to use to tune the relative importance of distance and classification objective.
    largest_const : float
        The maximum tradeoff-constant to use to tune the relative importance of distance and classification objective.
    const_factor : float
        The multiplicative factor by which the constant is increased if the search failed.
    reduce_const : bool
        If true, after each successful attack, make the constant smaller.
    decrease_factor : float
        Rate at which τ is decreased. Larger produces better quality results.
    abort_early : bool
        If true, allows early aborts if gradient descent gets stuck.
    callback : Optional
        Logger object. When given, its `line`, `accumulate_line` and `update_lines` methods are called to record the
        constant, τ, the success rate, the best Linf norm, the logit distance and the current Linf norm.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    device = inputs.device
    batch_size = len(inputs)
    # work in tanh-space; the (1 - 1e-6) factor keeps the argument of atanh strictly inside (-1, 1)
    t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
    multiplier = -1 if targeted else 1

    # plotting interval, floored at 1: `max_iterations // 10` is 0 for max_iterations < 10, which would make the
    # modulo below raise ZeroDivisionError
    plot_interval = max(max_iterations // 10, 1)

    # set modifier and the parameters used in the optimization
    modifier = torch.zeros_like(inputs)
    c = torch.full((batch_size,), initial_const, device=device, dtype=torch.float)
    τ = torch.ones(batch_size, device=device)

    o_adv_found = torch.zeros_like(c, dtype=torch.bool)
    o_best_linf = torch.ones_like(c)
    o_best_adv = inputs.clone()

    outer_loops = 0
    total_iters = 0
    # keep optimizing the samples whose τ is still above 1 / 255 and whose constant is still below largest_const
    while (to_optimize := (τ > 1 / 255) & (c < largest_const)).any():

        inputs_, t_inputs_, labels_ = inputs[to_optimize], t_inputs[to_optimize], labels[to_optimize]
        batch_view = lambda tensor: tensor.view(len(inputs_), *[1] * (inputs_.ndim - 1))

        # setup the optimizer
        modifier_ = modifier[to_optimize].requires_grad_(True)
        optimizer = optim.Adam([modifier_], lr=learning_rate)
        c_, τ_ = c[to_optimize], τ[to_optimize]

        adv_found = torch.zeros(len(modifier_), device=device, dtype=torch.bool)
        best_linf = o_best_linf[to_optimize]
        best_adv = inputs_.clone()

        # logged once per outer loop, after best_linf has been initialized (an earlier duplicate of this block read
        # best_linf before its assignment and failed on the first outer loop)
        if callback:
            callback.line(['const', 'τ'], outer_loops, [c_.mean(), τ_.mean()])
            callback.line(['success', 'best_linf'], outer_loops, [o_adv_found.float().mean(), o_best_linf.mean()])

        for i in range(max_iterations):

            adv_inputs = (torch.tanh(t_inputs_ + modifier_) + 1) / 2
            linf = (adv_inputs.detach() - inputs_).flatten(1).norm(p=float('inf'), dim=1)
            logits = model(adv_inputs)

            if i == 0:
                labels_infhot = torch.zeros_like(logits).scatter_(1, labels_.unsqueeze(1), float('inf'))

            # adjust the best result found so far
            predicted_classes = logits.argmax(1)

            is_adv = (predicted_classes == labels_) if targeted else (predicted_classes != labels_)
            is_smaller = linf < best_linf
            is_both = is_adv & is_smaller
            adv_found.logical_or_(is_both)
            best_linf = torch.where(is_both, linf, best_linf)
            best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

            logit_dists = multiplier * difference_of_logits(logits, labels_, labels_infhot=labels_infhot)
            # only the part of the perturbation exceeding τ is penalized
            linf_loss = (adv_inputs - inputs_).abs_().sub_(batch_view(τ_)).clamp_(min=0).flatten(1).sum(1)
            loss = linf_loss + c_ * logit_dists.clamp_(min=0)

            # check if we should abort search
            if abort_early and (loss < 0.0001 * c_).all():
                break

            optimizer.zero_grad()
            loss.sum().backward()
            optimizer.step()

            if callback:
                callback.accumulate_line('logit_dist', total_iters, logit_dists.mean())
                callback.accumulate_line('linf_norm', total_iters, linf.mean())

                if (i + 1) % plot_interval == 0 or (i + 1) == max_iterations:
                    callback.update_lines()

            total_iters += 1

        # write the results of this outer loop back into the full-batch trackers
        o_adv_found[to_optimize] = adv_found | o_adv_found[to_optimize]
        o_best_linf[to_optimize] = torch.where(adv_found, best_linf, o_best_linf[to_optimize])
        o_best_adv[to_optimize] = torch.where(batch_view(adv_found), best_adv, o_best_adv[to_optimize])
        modifier[to_optimize] = modifier_.detach()

        # on success: lower τ to the best norm found (if smaller) and decay it; on failure: increase the constant
        smaller_τ_ = adv_found & (best_linf < τ_)
        τ_ = torch.where(smaller_τ_, best_linf, τ_)
        τ[to_optimize] = torch.where(adv_found, decrease_factor * τ_, τ_)
        c[to_optimize] = torch.where(~adv_found, const_factor * c_, c_)
        if reduce_const:
            c[to_optimize] = torch.where(adv_found, c[to_optimize] / 2, c[to_optimize])

        outer_loops += 1

    # return the best solution found
    return o_best_adv
@@ -0,0 +1,113 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+ from torch.autograd import grad
7
+ from torch.nn import functional as F
8
+
9
+ from adv_lib.utils.visdom_logger import VisdomLogger
10
+
11
+
12
def ddn(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        targeted: bool = False,
        steps: int = 100,
        γ: float = 0.05,
        init_norm: float = 1.,
        levels: Optional[int] = 256,
        callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Decoupled Direction and Norm attack from https://arxiv.org/abs/1811.09600.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    steps : int
        Number of optimization steps.
    γ : float
        Factor by which the norm will be modified. new_norm = norm * (1 + or - γ).
    init_norm : float
        Initial value for the norm of the attack.
    levels : int
        If not None, the returned adversarials will have quantized values to the specified number of levels.
    callback : Optional
        Logger object. When given, its `accumulate_line` and `update_lines` methods are called to record the
        cross-entropy, ε, the L2 norms, the cosine between gradient and perturbation, α and the success rate.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    Raises
    ------
    ValueError
        If `inputs` contains values outside of [0, 1].

    """
    if inputs.min() < 0 or inputs.max() > 1:
        raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))

    # plotting interval, floored at 1: `steps // 20` is 0 for steps < 20, which would make the modulo below raise
    # ZeroDivisionError
    plot_interval = max(steps // 20, 1)

    # Init variables
    multiplier = -1 if targeted else 1
    δ = torch.zeros_like(inputs, requires_grad=True)
    ε = torch.full((batch_size,), init_norm, device=device, dtype=torch.float)
    # largest L2 norm reachable while staying in [0, 1]
    worst_norm = torch.max(inputs, 1 - inputs).flatten(1).norm(p=2, dim=1)

    # Init trackers
    best_l2 = worst_norm.clone()
    best_adv = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for i in range(steps):
        l2 = δ.data.flatten(1).norm(p=2, dim=1)
        adv_inputs = inputs + δ
        logits = model(adv_inputs)
        pred_labels = logits.argmax(1)
        ce_loss = F.cross_entropy(logits, labels, reduction='none')
        loss = multiplier * ce_loss

        is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
        is_smaller = l2 < best_l2
        is_both = is_adv & is_smaller
        adv_found.logical_or_(is_adv)
        best_l2 = torch.where(is_both, l2, best_l2)
        best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

        δ_grad = grad(loss.sum(), δ, only_inputs=True)[0]
        # renorming gradient
        grad_norms = δ_grad.flatten(1).norm(p=2, dim=1)
        δ_grad.div_(batch_view(grad_norms))
        # avoid nan or inf if gradient is 0
        if (zero_grad := (grad_norms < 1e-12)).any():
            δ_grad[zero_grad] = torch.randn_like(δ_grad[zero_grad])

        # cosine annealing of the step size from 1 down to 0.01
        α = 0.01 + (1 - 0.01) * (1 + math.cos(math.pi * i / steps)) / 2

        if callback is not None:
            cosine = F.cosine_similarity(δ_grad.flatten(1), δ.data.flatten(1), dim=1).mean()
            callback.accumulate_line('ce', i, ce_loss.mean())
            callback_best = best_l2.masked_select(adv_found).mean()
            callback.accumulate_line(['ε', 'l2', 'best_l2'], i, [ε.mean(), l2.mean(), callback_best])
            callback.accumulate_line(['cosine', 'α', 'success'], i,
                                     [cosine, torch.tensor(α, device=device), adv_found.float().mean()])

            if (i + 1) % plot_interval == 0 or (i + 1) == steps:
                callback.update_lines()

        δ.data.add_(δ_grad, alpha=α)

        # shrink the norm for samples that are currently adversarial, grow it for the others
        ε = torch.where(is_adv, (1 - γ) * ε, (1 + γ) * ε)
        ε = torch.minimum(ε, worst_norm)

        # rescale δ to norm ε, then clamp (and quantize) the resulting adversarial inputs
        δ.data.mul_(batch_view(ε / δ.data.flatten(1).norm(p=2, dim=1)))
        δ.data.add_(inputs).clamp_(min=0, max=1)
        if levels is not None:
            δ.data.mul_(levels - 1).round_().div_(levels - 1)
        δ.data.sub_(inputs)

    return best_adv
@@ -0,0 +1 @@
1
+ from .fast_adaptive_boundary import fab
@@ -0,0 +1,215 @@
1
+ # Adapted from https://github.com/fra31/auto-attack
2
+
3
+ import warnings
4
+ from functools import partial
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+ from torch.autograd import grad
10
+
11
+ from .projections import projection_l1, projection_l2, projection_linf
12
+
13
+
14
def fab(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        norm: float,
        n_iter: int = 100,
        ε: Optional[float] = None,
        α_max: float = 0.1,
        β: float = 0.9,
        η: float = 1.05,
        restarts: Optional[int] = None,
        targeted_restarts: bool = False,
        targeted: bool = False) -> Tensor:
    """
    Fast Adaptive Boundary (FAB) attack from https://arxiv.org/abs/1907.02044

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    norm : float
        Norm to minimize in {1, 2 ,float('inf')}.
    n_iter : int
        Number of optimization steps. This does not correspond to the number of forward / backward propagations for this
        attack. For a more comprehensive discussion on complexity, see section 4 of https://arxiv.org/abs/1907.02044 and
        for a comparison of complexities, see https://arxiv.org/abs/2011.11857.
        TL;DR: FAB performs 2 forwards and K - 1 (i.e. number of classes - 1) backwards per step in default mode. If
        `targeted_restarts` is `True`, performs `2 * restarts` forwards and `restarts` or (K - 1) backwards per step.
    ε : float
        Maximum norm of the random initialization for restarts.
    α_max : float
        Maximum weight for the biased gradient step. α = 0 corresponds to taking the projection of the `adv_inputs` on
        the decision hyperplane, while α = 1 corresponds to taking the projection of the `inputs` on the decision
        hyperplane.
    β : float
        Weight for the biased backward step, i.e. a linear interpolation between `inputs` and `adv_inputs` at step i.
        β = 0 corresponds to taking the original `inputs` and β = 1 corresponds to taking the `adv_inputs`.
    η : float
        Extrapolation for the optimization step. η = 1 corresponds to projecting the `adv_inputs` on the decision
        hyperplane. η > 1 corresponds to overshooting to increase the probability of crossing the decision hyperplane.
    restarts : int
        Number of random restarts in default mode; starts from the inputs in the first run and then add random noise for
        the consecutive restarts. Number of classes to attack if `targeted_restarts` is `True`.
    targeted_restarts : bool
        If `True`, performs targeted attack towards the most likely classes for the unperturbed `inputs`. If `restarts`
        is not given, this will attack each class (except the original class). If `restarts` is given, the `restarts`
        most likely classes will be attacked. If `restarts` is larger than K - 1, this will re-attack the most likely
        classes with random noise.
    targeted : bool
        Placeholder argument for library. FAB is only for untargeted attacks, so setting this to True will raise a
        warning and return the inputs.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    if targeted:
        warnings.warn('FAB attack is untargeted only. Returning inputs.')
        return inputs

    # everything that stays constant across runs is bound once
    run_attack = partial(_fab, model=model, inputs=inputs, labels=labels, norm=norm, n_iter=n_iter, ε=ε,
                         α_max=α_max, β=β, η=η)

    overall_adv = inputs.clone()
    overall_norm = torch.full_like(labels, float('inf'), dtype=torch.float)

    if targeted_restarts:
        clean_logits = model(inputs)
        n_target_classes = clean_logits.size(1) - 1
        # subtracting inf at the true class excludes it from the top-k ranking
        true_class_inf = torch.zeros_like(clean_logits).scatter_(1, labels.unsqueeze(1), float('inf'))
        n_candidates = min(restarts or n_target_classes, n_target_classes)
        candidate_targets = (clean_logits - true_class_inf).topk(k=n_candidates, dim=1).indices
        n_runs = restarts or n_target_classes
    else:
        n_runs = restarts or 1

    for run in range(n_runs):

        if targeted_restarts:
            # cycle through the candidate classes; once all were visited, re-attack them from a random start
            run_options = dict(random_start=run >= n_target_classes,
                               targets=candidate_targets[:, run % n_target_classes])
        else:
            # the first run starts from the inputs, the following ones from a random start
            run_options = dict(random_start=run != 0)

        run_adv, run_found, run_norm = run_attack(u=overall_norm, **run_options)

        # keep the result of this run only where it is adversarial with a strictly smaller norm
        improved = run_found & (run_norm < overall_norm)
        overall_norm[improved] = run_norm[improved]
        overall_adv[improved] = run_adv[improved]

    return overall_adv
107
+
108
+
109
def get_best_diff_logits_grads(model: nn.Module,
                               inputs: Tensor,
                               labels: Tensor,
                               other_labels: Tensor,
                               q: float) -> Tuple[Tensor, Tensor]:
    """
    For each sample, pick among the candidate classes in `other_labels` the one minimizing the ratio
    |logit difference| / (q-norm of the gradient of that difference w.r.t. the inputs).

    `inputs` is temporarily set to require gradients and is detached in place before returning.

    Parameters
    ----------
    model : nn.Module
        Model producing the logits.
    inputs : Tensor
        Points at which the logit differences and their gradients are evaluated.
    labels : Tensor
        Reference class of each sample.
    other_labels : Tensor
        Candidate classes of each sample, one candidate per column.
    q : float
        Order of the norm applied to the flattened gradients.

    Returns
    -------
    best_logit_diff, best_grad_diff : Tuple[Tensor, Tensor]
        Logit difference (candidate minus reference class) and its gradient for the selected candidate of each sample.

    """
    def per_sample(flags: Tensor) -> Tensor:
        # reshape a per-sample vector so that it broadcasts against `inputs`
        return flags.view(-1, *[1] * (inputs.ndim - 1))

    smallest_ratio = torch.full_like(labels, float('inf'), dtype=torch.float)
    selected_logit_diff = torch.zeros_like(labels, dtype=torch.float)
    selected_grad_diff = torch.zeros_like(inputs)

    inputs.requires_grad_(True)
    logits = model(inputs)
    reference_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1)

    last_column = other_labels.size(1) - 1
    for column, candidates in enumerate(other_labels.unbind(dim=1)):
        candidate_logits = logits.gather(1, candidates.unsqueeze(1)).squeeze(1)
        logit_gap = candidate_logits - reference_logits
        # the graph is needed again for every candidate but the last one
        keep_graph = column != last_column
        gap_grad = grad(logit_gap.sum(), inputs, only_inputs=True, retain_graph=keep_graph)[0]
        grad_sizes = gap_grad.flatten(1).norm(p=q, dim=1).clamp_(min=1e-12)
        ratio = logit_gap.abs() / grad_sizes

        is_closer = ratio < smallest_ratio
        smallest_ratio = torch.min(ratio, smallest_ratio)
        selected_logit_diff = torch.where(is_closer, logit_gap.detach(), selected_logit_diff)
        selected_grad_diff = torch.where(per_sample(is_closer), gap_grad.detach(), selected_grad_diff)

    inputs.detach_()
    return selected_logit_diff, selected_grad_diff
136
+
137
+
138
def _fab(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         norm: float,
         n_iter: int = 100,
         ε: Optional[float] = None,
         α_max: float = 0.1,
         β: float = 0.9,
         η: float = 1.05,
         random_start: bool = False,
         u: Optional[Tensor] = None,
         targets: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Single run of the FAB attack.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack.
    labels : Tensor
        Labels corresponding to the inputs.
    norm : float
        Norm to minimize; must be one of 1, 2 or float('inf'), any other value raises a KeyError.
    n_iter : int
        Number of optimization steps.
    ε : float
        Maximum norm of the random initialization. If None, a default depending on `norm` is used
        (5 for 1, 1 for 2 and 0.3 for float('inf')).
    α_max : float
        Upper bound on the weight α mixing the step from `adv_inputs` with the step from `inputs`.
    β : float
        Interpolation factor pulling adversarial points back towards `inputs` at the end of each step.
    η : float
        Multiplier applied to the projection steps.
    random_start : bool
        If True, starts from `inputs` plus random noise whose norm is half of the best norm so far (capped at ε).
    u : Optional[Tensor]
        Best norms known so far for each sample. NOTE: this tensor is modified in place (entries of samples that are
        already misclassified without perturbation are set to 0).
    targets : Optional[Tensor]
        If given, only the decision boundary towards this class is considered for each sample; otherwise all classes
        other than the label are considered.

    Returns
    -------
    best_adv, adv_found, best_norm : Tuple[Tensor, Tensor, Tensor]
        Best adversarial inputs found in this run (the original inputs where none improved on `u`), whether an
        adversarial example was encountered for each sample, and the corresponding best norms.

    """
    # norm -> (projection function, dual norm used to rank the candidate classes, default ε)
    _projection_dual_default_ε = {
        1: (projection_l1, float('inf'), 5),
        2: (projection_l2, 2, 1),
        float('inf'): (projection_linf, 1, 0.3)
    }

    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
    projection, dual_norm, default_ε = _projection_dual_default_ε[norm]
    ε = default_ε if ε is None else ε

    logits = model(inputs)
    if targets is not None:
        other_labels = targets.unsqueeze(1)
    else:
        # generate all other labels: every class except the label, in increasing order, without a per-sample
        # Python loop (which needed one `.item()` host synchronization per sample)
        n_classes = logits.size(1)
        class_grid = torch.arange(n_classes, device=device).expand(len(labels), n_classes)
        other_labels = class_grid[class_grid != labels.unsqueeze(1)].view(len(labels), n_classes - 1)

    get_df_dg = partial(get_best_diff_logits_grads, model=model, labels=labels, other_labels=other_labels, q=dual_norm)

    adv_inputs = inputs.clone()
    adv_found = logits.argmax(dim=1) != labels
    best_norm = torch.full((batch_size,), float('inf'), device=device, dtype=torch.float) if u is None else u
    # samples that are already misclassified need no perturbation (this writes into `u` when it is given)
    best_norm[adv_found] = 0
    best_adv = inputs.clone()

    if random_start:
        if norm == float('inf'):
            t = torch.rand_like(inputs).mul_(2).sub_(1)
        elif norm in [1, 2]:
            t = torch.randn_like(inputs)

        # random perturbation rescaled to half of the best norm so far (capped at ε)
        adv_inputs.add_(t.mul_(batch_view(best_norm.clamp(max=ε) / t.flatten(1).norm(p=norm, dim=1).mul_(2))))
        adv_inputs.clamp_(min=0.0, max=1.0)

    for i in range(n_iter):
        # linearization of the selected decision boundary around adv_inputs: w . x = b
        df, dg = get_df_dg(inputs=adv_inputs)
        b = (dg * adv_inputs).flatten(1).sum(dim=1).sub_(df)
        w = dg.flatten(1)

        # project both the current adversarial points and the original inputs in a single batched call
        d3 = projection(torch.cat((adv_inputs.flatten(1), inputs.flatten(1)), 0), w.repeat(2, 1), b.repeat(2))
        d1, d2 = map(lambda t: t.view_as(adv_inputs), torch.chunk(d3, 2, dim=0))

        a0 = batch_view(d3.flatten(1).norm(p=norm, dim=1).clamp_(min=1e-8))
        a1, a2 = torch.chunk(a0, 2, dim=0)

        α = a1.div_(a2.add_(a1)).clamp_(min=0, max=α_max)
        adv_inputs.add_(d1, alpha=η).mul_(1 - α).add_(inputs.add(d2, alpha=η).mul_(α)).clamp_(min=0, max=1)

        is_adv = model(adv_inputs).argmax(1) != labels
        adv_found.logical_or_(is_adv)
        adv_norm = (adv_inputs - inputs).flatten(1).norm(p=norm, dim=1)
        is_smaller = adv_norm < best_norm
        is_both = is_adv & is_smaller
        best_norm = torch.where(is_both, adv_norm, best_norm)
        best_adv = torch.where(batch_view(is_both), adv_inputs, best_adv)

        # backward step: pull the adversarial points towards the original inputs
        adv_inputs = torch.where(batch_view(is_adv), inputs + (adv_inputs - inputs) * β, adv_inputs)

    return best_adv, adv_found, best_norm