adv-lib 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. adv_lib/__init__.py +1 -0
  2. adv_lib/attacks/__init__.py +13 -0
  3. adv_lib/attacks/augmented_lagrangian.py +243 -0
  4. adv_lib/attacks/auto_pgd.py +523 -0
  5. adv_lib/attacks/boundary_projection_tf.py +170 -0
  6. adv_lib/attacks/carlini_wagner/__init__.py +2 -0
  7. adv_lib/attacks/carlini_wagner/l2.py +151 -0
  8. adv_lib/attacks/carlini_wagner/linf.py +158 -0
  9. adv_lib/attacks/decoupled_direction_norm.py +113 -0
  10. adv_lib/attacks/fast_adaptive_boundary/__init__.py +1 -0
  11. adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py +215 -0
  12. adv_lib/attacks/fast_adaptive_boundary/projections.py +164 -0
  13. adv_lib/attacks/fast_minimum_norm.py +218 -0
  14. adv_lib/attacks/perceptual_color_attacks/__init__.py +1 -0
  15. adv_lib/attacks/perceptual_color_attacks/differential_color_functions.py +181 -0
  16. adv_lib/attacks/perceptual_color_attacks/perceptual_color_distance_al.py +128 -0
  17. adv_lib/attacks/primal_dual_gradient_descent.py +379 -0
  18. adv_lib/attacks/projected_gradient_descent.py +109 -0
  19. adv_lib/attacks/segmentation/__init__.py +4 -0
  20. adv_lib/attacks/segmentation/alma_prox.py +283 -0
  21. adv_lib/attacks/segmentation/asma.py +92 -0
  22. adv_lib/attacks/segmentation/dense_adversary.py +83 -0
  23. adv_lib/attacks/segmentation/primal_dual_gradient_descent.py +349 -0
  24. adv_lib/attacks/self_adaptive_norm_update.py +127 -0
  25. adv_lib/attacks/sigma_zero.py +119 -0
  26. adv_lib/attacks/stochastic_sparse_attacks.py +237 -0
  27. adv_lib/attacks/structured_adversarial_attack.py +289 -0
  28. adv_lib/attacks/trust_region.py +153 -0
  29. adv_lib/distances/__init__.py +0 -0
  30. adv_lib/distances/color_difference.py +212 -0
  31. adv_lib/distances/lp_norms.py +18 -0
  32. adv_lib/distances/lpips.py +99 -0
  33. adv_lib/distances/structural_similarity.py +147 -0
  34. adv_lib/utils/__init__.py +1 -0
  35. adv_lib/utils/attack_utils.py +226 -0
  36. adv_lib/utils/color_conversions.py +71 -0
  37. adv_lib/utils/image_selection.py +27 -0
  38. adv_lib/utils/lagrangian_penalties/__init__.py +1 -0
  39. adv_lib/utils/lagrangian_penalties/all_penalties.py +67 -0
  40. adv_lib/utils/lagrangian_penalties/penalty_functions.py +79 -0
  41. adv_lib/utils/lagrangian_penalties/scripts/plot_penalties.py +42 -0
  42. adv_lib/utils/lagrangian_penalties/scripts/plot_univariates.py +32 -0
  43. adv_lib/utils/lagrangian_penalties/univariate_functions.py +299 -0
  44. adv_lib/utils/losses.py +29 -0
  45. adv_lib/utils/projections.py +100 -0
  46. adv_lib/utils/utils.py +58 -0
  47. adv_lib/utils/visdom_logger.py +109 -0
  48. adv_lib-0.2.2.dist-info/LICENSE +29 -0
  49. adv_lib-0.2.2.dist-info/METADATA +170 -0
  50. adv_lib-0.2.2.dist-info/RECORD +52 -0
  51. adv_lib-0.2.2.dist-info/WHEEL +5 -0
  52. adv_lib-0.2.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,523 @@
1
+ # Adapted from https://github.com/fra31/auto-attack
2
+ import math
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Tuple, Optional, Union
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+ from torch.nn import functional as F
10
+
11
+ from adv_lib.utils.losses import difference_of_logits_ratio
12
+
13
+
14
def apgd(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         eps: Union[float, Tensor],
         norm: float,
         targeted: bool = False,
         n_iter: int = 100,
         n_restarts: int = 1,
         loss_function: str = 'dlr',
         eot_iter: int = 1,
         rho: float = 0.75,
         use_large_reps: bool = False,
         use_rs: bool = True,
         best_loss: bool = False) -> Tensor:
    """
    Auto-PGD (APGD) attack from https://arxiv.org/abs/2003.01690 with L1 variant from https://arxiv.org/abs/2103.01208.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    eps : float or Tensor
        Maximum norm for the adversarial perturbation. Can be a float used for all samples or a Tensor containing the
        distance for each corresponding sample.
    norm : float
        Norm corresponding to eps in {1, 2, float('inf')}.
    targeted : bool
        Whether to perform a targeted attack or not.
    n_iter : int
        Number of optimization steps. Forwarded to every run, including the runs made when `best_loss` is True.
    n_restarts : int
        Number of random restarts for the attack.
    loss_function : str
        Loss to optimize in ['ce', 'dlr'].
    eot_iter : int
        Number of iterations for expectation over transformation.
    rho : float
        Parameters for decreasing the step size.
    use_large_reps : bool
        Split iterations in three phases starting with larger eps (see section 3.2 of https://arxiv.org/abs/2103.01208).
        Only read when `best_loss` is False; requires norm == 1.
    use_rs : bool
        Use a random start when using large reps.
    best_loss : bool
        If True, search for the strongest adversarial perturbation within the distance budget instead of stopping as
        soon as it finds one.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    assert norm in [1, 2, float('inf')]
    device = inputs.device
    batch_size = len(inputs)

    adv_inputs = inputs.clone()
    adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)
    if isinstance(eps, numbers.Real):
        # Broadcast a scalar budget to one value per sample so that it can be indexed with boolean masks below.
        eps = torch.full_like(adv_found, eps, dtype=torch.float)

    if use_large_reps:
        # Three phases with budgets 3*eps, 2*eps, eps; the last phase takes whatever iterations remain so that the
        # phases sum to n_iter.
        epss = [3 * eps, 2 * eps, eps]
        iters = [0.3 * n_iter, 0.3 * n_iter, 0.4 * n_iter]
        iters = [math.ceil(i) for i in iters]
        iters[-1] = n_iter - sum(iters[:-1])

    apgd_attack = partial(_apgd, model=model, norm=norm, targeted=targeted, loss_function=loss_function,
                          eot_iter=eot_iter, rho=rho)

    if best_loss:
        loss = torch.full_like(adv_found, -float('inf'), dtype=torch.float)

        for _ in range(n_restarts):
            # n_iter must be passed explicitly: _apgd otherwise falls back to its own default of 100 steps.
            adv_inputs_run, adv_found_run, loss_run, _ = apgd_attack(inputs=inputs, labels=labels, eps=eps,
                                                                     n_iter=n_iter)

            # Keep, per sample, the point with the highest loss over all restarts.
            better_loss = loss_run > loss
            adv_inputs[better_loss] = adv_inputs_run[better_loss]
            loss[better_loss] = loss_run[better_loss]

    else:
        for _ in range(n_restarts):
            if adv_found.all():
                break
            # Only samples without an adversarial example yet are attacked again.
            to_attack = ~adv_found

            inputs_to_attack = inputs[to_attack]
            labels_to_attack = labels[to_attack]

            if use_large_reps:
                assert norm == 1
                if use_rs:
                    x_init = inputs_to_attack + torch.randn_like(inputs_to_attack)
                    x_init += l1_projection(inputs_to_attack, x_init - inputs_to_attack, epss[0][to_attack])
                else:
                    x_init = None

                for eps_, n_iter_phase in zip(epss, iters):
                    eps_to_attack = eps_[to_attack]
                    if x_init is not None:
                        # Bring the previous phase's result back inside the (smaller) budget of this phase.
                        x_init += l1_projection(inputs_to_attack, x_init - inputs_to_attack, eps_to_attack)

                    x_init, adv_found_run, _, adv_inputs_run = apgd_attack(
                        inputs=inputs_to_attack, labels=labels_to_attack, eps=eps_to_attack, x_init=x_init,
                        n_iter=n_iter_phase)

            else:
                _, adv_found_run, _, adv_inputs_run = apgd_attack(inputs=inputs_to_attack, labels=labels_to_attack,
                                                                  eps=eps[to_attack], n_iter=n_iter)
            adv_inputs[to_attack] = adv_inputs_run
            adv_found[to_attack] = adv_found_run

    return adv_inputs
130
+
131
+
132
def apgd_targeted(model: nn.Module,
                  inputs: Tensor,
                  labels: Tensor,
                  eps: Union[float, Tensor],
                  norm: float,
                  targeted: bool = False,
                  n_iter: int = 100,
                  n_restarts: int = 1,
                  loss_function: str = 'dlr',
                  eot_iter: int = 1,
                  rho: float = 0.75,
                  use_large_reps: bool = False,
                  use_rs: bool = True,
                  num_targets: Optional[int] = None) -> Tensor:
    """
    Targeted variant of the Auto-PGD (APGD) attack from https://arxiv.org/abs/2003.01690 with L1 variant from
    https://arxiv.org/abs/2103.01208. This attack is not a targeted one: it tries to find an adversarial perturbation by
    attacking each class, starting with the most likely one (different from the original class).

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Kept for interface compatibility with the other attacks. This implementation does not read it: the class
        ranked first by the model on `inputs` is the one skipped, and the remaining classes are used as targets.
    eps : float or Tensor
        Maximum norm for the adversarial perturbation. Can be a float used for all samples or a Tensor containing the
        distance for each corresponding sample.
    norm : float
        Norm corresponding to eps in {1, 2, float('inf')}.
    targeted : bool
        Required argument for the library. Will raise an assertion error if True (will be ignored if the -O flag is
        used).
    n_iter : int
        Number of optimization steps.
    n_restarts : int
        Number of random restarts for the attack for each class attacked.
    loss_function : str
        Loss to optimize in ['ce', 'dlr'].
    eot_iter : int
        Number of iterations for expectation over transformation.
    rho : float
        Parameters for decreasing the step size.
    use_large_reps : bool
        Split iterations in three phases starting with larger eps (see section 3.2 of https://arxiv.org/abs/2103.01208).
    use_rs : bool
        Use a random start when using large reps.
    num_targets : int or None
        Number of classes to attack. If None, it will attack every class (except the top-ranked one). Values larger
        than the number of available classes are capped to that number.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    assert not targeted
    device = inputs.device
    batch_size = len(inputs)

    adv_inputs = inputs.clone()
    adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)
    if isinstance(eps, numbers.Real):
        # Broadcast a scalar budget to one value per sample so that it can be indexed with boolean masks below.
        eps = torch.full_like(adv_found, eps, dtype=torch.float)

    if use_large_reps:
        # Three phases with budgets 3*eps, 2*eps, eps; the last phase takes whatever iterations remain so that the
        # phases sum to n_iter.
        epss = [3 * eps, 2 * eps, eps]
        iters = [0.3 * n_iter, 0.3 * n_iter, 0.4 * n_iter]
        iters = [math.ceil(i) for i in iters]
        iters[-1] = n_iter - sum(iters[:-1])

    apgd_attack = partial(_apgd, model=model, norm=norm, targeted=True, loss_function=loss_function,
                          eot_iter=eot_iter, rho=rho)

    # Rank classes by decreasing logit and drop the first column (the model's top prediction). Only the indices are
    # needed, so no autograd graph is recorded for this forward pass.
    with torch.no_grad():
        most_likely_classes = model(inputs).argsort(dim=1, descending=True)[:, 1:]
    num_available = most_likely_classes.size(1)
    # Cap the request so that indexing most_likely_classes[:, i] can never run past the last column.
    num_classes_to_attack = num_available if num_targets is None else min(num_targets, num_available)

    for i in range(num_classes_to_attack):
        if adv_found.all():
            break
        targets = most_likely_classes[:, i]

        for _ in range(n_restarts):
            if adv_found.all():
                break
            # Only samples without an adversarial example yet are attacked again.
            to_attack = ~adv_found

            inputs_to_attack = inputs[to_attack]
            targets_to_attack = targets[to_attack]

            if use_large_reps:
                assert norm == 1
                if use_rs:
                    x_init = inputs_to_attack + torch.randn_like(inputs_to_attack)
                    x_init += l1_projection(inputs_to_attack, x_init - inputs_to_attack, epss[0][to_attack])
                else:
                    x_init = None

                for eps_, n_iter_phase in zip(epss, iters):
                    eps_to_attack = eps_[to_attack]
                    if x_init is not None:
                        # Bring the previous phase's result back inside the (smaller) budget of this phase.
                        x_init += l1_projection(inputs_to_attack, x_init - inputs_to_attack, eps_to_attack)

                    x_init, adv_found_run, _, adv_inputs_run = apgd_attack(
                        inputs=inputs_to_attack, labels=targets_to_attack, eps=eps_to_attack, x_init=x_init,
                        n_iter=n_iter_phase)

            else:
                _, adv_found_run, _, adv_inputs_run = apgd_attack(inputs=inputs_to_attack, labels=targets_to_attack,
                                                                  eps=eps[to_attack], n_iter=n_iter)

            adv_inputs[to_attack] = adv_inputs_run
            adv_found[to_attack] = adv_found_run

    return adv_inputs
248
+
249
+
250
def minimal_apgd(model: nn.Module,
                 inputs: Tensor,
                 labels: Tensor,
                 norm: float,
                 max_eps: float,
                 targeted: bool = False,
                 binary_search_steps: int = 20,
                 targeted_version: bool = False,
                 n_iter: int = 100,
                 n_restarts: int = 1,
                 loss_function: str = 'dlr',
                 eot_iter: int = 1,
                 rho: float = 0.75,
                 use_large_reps: bool = False,
                 use_rs: bool = True,
                 num_targets: Optional[int] = None) -> Tensor:
    """
    Binary search over the budget eps, running APGD at each step, to look for a small successful perturbation.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack.
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    norm : float
        Norm forwarded to the underlying attack.
    max_eps : float
        The search interval starts as [0, 2 * max_eps] for every sample, so the first budget tried is max_eps.
    targeted : bool
        Whether the attack is targeted. Only forwarded to `apgd` (i.e. when `targeted_version` is False); in that
        case success means the prediction equals `labels`.
    binary_search_steps : int
        Number of bisection steps, each one running a full attack.
    targeted_version : bool
        If True, use `apgd_targeted` instead of `apgd`. Success is then always measured as prediction != labels.
    n_iter, n_restarts, loss_function, eot_iter, rho, use_large_reps, use_rs :
        Forwarded unchanged to the underlying attack.
    num_targets : int or None
        Forwarded to `apgd_targeted` when `targeted_version` is True.

    Returns
    -------
    adv_inputs : Tensor
        For each sample, the result of the successful run with the smallest budget; the unmodified input if no run
        succeeded.

    """
    device = inputs.device
    batch_size = len(inputs)

    adv_inputs = inputs.clone()
    best_eps = torch.full((batch_size,), 2 * max_eps, dtype=torch.float, device=device)
    eps_low = torch.zeros_like(best_eps)

    if targeted_version:
        attack = partial(apgd_targeted, model=model, norm=norm, n_iter=n_iter, n_restarts=n_restarts,
                         loss_function=loss_function, eot_iter=eot_iter, rho=rho, use_large_reps=use_large_reps,
                         use_rs=use_rs, num_targets=num_targets)
    else:
        attack = partial(apgd, model=model, norm=norm, targeted=targeted, n_iter=n_iter, n_restarts=n_restarts,
                         loss_function=loss_function, eot_iter=eot_iter, rho=rho, use_large_reps=use_large_reps,
                         use_rs=use_rs)

    # apgd_targeted always runs with targeted=False from the caller's point of view, so the "prediction must match
    # the label" criterion only applies when the plain attack is run in targeted mode.
    success_is_match = targeted and not targeted_version

    for _ in range(binary_search_steps):
        eps = (eps_low + best_eps) / 2

        adv_inputs_run = attack(inputs=inputs, labels=labels, eps=eps)
        with torch.no_grad():  # evaluation only, no gradient needed
            predictions = model(adv_inputs_run).argmax(1)
        adv_found_run = (predictions == labels) if success_is_match else (predictions != labels)

        better_adv = adv_found_run & (eps < best_eps)
        adv_inputs[better_adv] = adv_inputs_run[better_adv]

        # Success: shrink the upper bound. Failure: raise the lower bound.
        eps_low = torch.where(better_adv, eps_low, eps)
        best_eps = torch.where(better_adv, eps, best_eps)

    return adv_inputs
295
+
296
+
297
def l1_projection(x: Tensor, y: Tensor, eps: Tensor) -> Tensor:
    """
    Compute the correction to add to the perturbation `y` of `x`, given one L1 budget per sample.

    Callers in this file use the result as `x + y + l1_projection(x, y, eps)`. When the budget is not exceeded
    (after accounting for the part of `x + y` lying outside [0, 1]), the correction only removes that out-of-range
    part; otherwise a per-sample threshold is found by bisection over the sorted breakpoints and the perturbation is
    shrunk accordingly. NOTE(review): presumably the L1-ball / box projection of https://arxiv.org/abs/2103.01208 --
    confirm against the paper.

    Parameters
    ----------
    x : Tensor
        Reference points; flattened from dim 1 onwards.
    y : Tensor
        Perturbations, same shape as `x`.
    eps : Tensor
        One budget per sample (added to a per-sample sum, so it must broadcast against the batch dimension).

    Returns
    -------
    Tensor
        Correction with the same shape as `x`.

    """
    original_shape = x.shape
    device = x.device
    x, y = x.flatten(1), y.flatten(1)
    n_features = y.shape[1]

    # Non-positive: how far x + y lies outside [0, 1], per element (0 when inside).
    box_excess = torch.min(1 - x - y, x + y).clamp_(max=0)
    neg_abs = y.abs().neg_()
    correction = box_excess.clone()

    # Sorted candidate thresholds (negated entries of both tensors) and, for each, which tensor it came from.
    breakpoints, order = torch.sort(torch.cat((box_excess, neg_abs), dim=1).neg_(), dim=1)
    next_breakpoints = F.pad(breakpoints[:, 1:], (0, 1))

    # +1 for entries coming from box_excess, -1 for entries coming from neg_abs, accumulated along the sort order.
    slope = (order < n_features).float().mul_(2).sub_(1).cumsum_(dim=1)

    excess_total = box_excess.sum(dim=1)
    budget = neg_abs.sum(dim=1).add_(eps)
    needs_shrink = budget < excess_total
    cumulative = (next_breakpoints - breakpoints).mul_(slope).cumsum_(dim=1).sub_(excess_total.unsqueeze(-1))

    if not needs_shrink.any():
        return correction.mul_(y.sign()).view(original_shape)

    # Bisection (on the selected samples only) for the last breakpoint index where cumulative + budget < 0.
    low = torch.zeros(needs_shrink.sum(), device=device)
    high = torch.full_like(low, breakpoints.shape[1] - 1)
    n_bisections = math.ceil(math.log2(breakpoints.shape[1]))

    for _ in range(n_bisections):
        middle = low.lerp(high, weight=0.5).floor()
        middle_index = middle.long()

        is_below = cumulative[needs_shrink, middle_index] + budget[needs_shrink] < 0
        low[is_below] = middle[is_below]
        high[~is_below] = middle[~is_below]

    index = low.long()
    residual = cumulative[needs_shrink, index] + budget[needs_shrink]
    # threshold = next_breakpoint - residual / slope
    threshold = next_breakpoints[needs_shrink, index].addcdiv(residual, slope[needs_shrink, index + 1], value=-1)
    clipped = torch.max(-box_excess[needs_shrink], threshold.unsqueeze(-1))
    correction[needs_shrink] = -torch.min(clipped, -neg_abs[needs_shrink])

    return correction.mul_(y.sign()).view(original_shape)
337
+
338
+
339
def check_oscillation(loss_steps: Tensor, j: int, k: int, k3: float = 0.75) -> Tensor:
    """
    Flag the samples whose loss increased in at most `k * k3` of the `k` steps ending at step `j`.

    Parameters
    ----------
    loss_steps : Tensor
        Loss history indexed by step along dim 0.
    j : int
        Index of the most recent step to look at.
    k : int
        Number of consecutive step-to-step comparisons, going backwards from `j`.
    k3 : float
        Fraction of `k` used as the threshold.

    Returns
    -------
    Tensor
        Boolean tensor shaped like `loss_steps[0]`.

    """
    n_increases = torch.zeros_like(loss_steps[0])
    # Visits steps j, j - 1, ..., j - k + 1, comparing each one with the step just before it.
    for step in range(j, j - k, -1):
        increased = loss_steps[step] > loss_steps[step - 1]
        n_increases.add_(increased)
    threshold = k * k3
    return n_increases <= threshold
344
+
345
+
346
def _apgd(model: nn.Module,
          inputs: Tensor,
          labels: Tensor,
          eps: Tensor,
          norm: float,
          x_init: Optional[Tensor] = None,
          targeted: bool = False,
          n_iter: int = 100,
          loss_function: str = 'dlr',
          eot_iter: int = 1,
          rho: float = 0.75) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Single run of Auto-PGD: gradient ascent on a per-sample loss with an adaptive step size.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Iterates are clamped to [0, 1].
    labels : Tensor
        Labels if untargeted, else target labels.
    eps : Tensor
        One budget per sample.
    norm : float
        One of 1, 2 or float('inf').
    x_init : Tensor or None
        Starting point. If None, a random start depending on `norm` is drawn.
    targeted : bool
        Whether success means predicting `labels` (True) or anything else (False).
    n_iter : int
        Number of optimization steps.
    loss_function : str
        'ce' or 'dlr' (case-insensitive).
    eot_iter : int
        Number of forward/backward passes averaged per gradient evaluation.
    rho : float
        Threshold passed to `check_oscillation` (norm 2 and inf only).

    Returns
    -------
    x_best : Tensor
        Per sample, the iterate with the highest loss.
    adv_found : Tensor
        Boolean per sample, True if any evaluated iterate was adversarial.
    loss_best : Tensor
        Loss value at `x_best`.
    x_best_adv : Tensor
        Per sample, the most recent adversarial iterate (the original input if none was found).

    """
    # The multiplier orients each criterion so that the attack always maximizes `loss_indiv`.
    _loss_functions = {
        'ce': (nn.CrossEntropyLoss(reduction='none'), -1 if targeted else 1),
        'dlr': (partial(difference_of_logits_ratio, targeted=targeted), 1 if targeted else -1),
    }

    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
    criterion_indiv, multiplier = _loss_functions[loss_function.lower()]

    # Element-wise bounds used by the Linf branch.
    lower, upper = (inputs - batch_view(eps)).clamp_(min=0, max=1), (inputs + batch_view(eps)).clamp_(min=0, max=1)

    # Checkpoint schedule: first period, minimum period, and decrement of the period after each checkpoint.
    n_iter_2, n_iter_min, size_decr = max(int(0.22 * n_iter), 1), max(int(0.06 * n_iter), 1), max(int(0.03 * n_iter), 1)

    if x_init is not None:
        x_adv = x_init.clone()
    elif norm == float('inf'):
        t = 2 * torch.rand_like(inputs) - 1
        x_adv = inputs + t * batch_view(eps / t.flatten(1).norm(p=float('inf'), dim=1))
    elif norm == 2:
        t = torch.randn_like(inputs)
        x_adv = inputs + t * batch_view(eps / t.flatten(1).norm(p=2, dim=1))
    elif norm == 1:
        t = torch.randn_like(inputs)
        delta = l1_projection(inputs, t, eps)
        x_adv = inputs + t + delta

    x_adv.clamp_(min=0, max=1)
    x_best = x_adv.clone()
    x_best_adv = inputs.clone()
    loss_steps = torch.zeros(n_iter, batch_size, device=device)
    loss_best_steps = torch.zeros(n_iter + 1, batch_size, device=device)
    adv_found_steps = torch.zeros_like(loss_best_steps)

    # Loss and gradient at the starting point.
    x_adv.requires_grad_()
    grad = torch.zeros_like(inputs)
    for _ in range(eot_iter):
        logits = model(x_adv)
        loss_indiv = multiplier * criterion_indiv(logits, labels)
        grad.add_(torch.autograd.grad(loss_indiv.sum(), x_adv, only_inputs=True)[0])

    grad.div_(eot_iter)
    grad_best = grad.clone()
    x_adv.detach_()

    adv_found = (logits.argmax(1) == labels) if targeted else (logits.argmax(1) != labels)
    adv_found_steps[0] = adv_found
    loss_best = loss_indiv.detach().clone()

    alpha = 2 if norm in [2, float('inf')] else 1 if norm == 1 else 2e-2
    step_size = alpha * eps
    x_adv_old = x_adv.clone()
    k = n_iter_2
    counter3 = 0

    if norm == 1:
        k = max(int(0.04 * n_iter), 1)
        n_fts = inputs[0].numel()
        if x_init is None:
            topk = torch.ones(len(inputs), device=device).mul_(0.2)
            sp_old = torch.full_like(topk, n_fts, dtype=torch.float)
        else:
            sp_old = (x_adv - inputs).flatten(1).norm(p=0, dim=1)
            topk = sp_old / n_fts / 1.5
        adasp_redstep = 1.5
        adasp_minstep = 10.

    loss_best_last_check = loss_best.clone()
    reduced_last_check = torch.zeros_like(loss_best, dtype=torch.bool)

    for i in range(n_iter):
        ### gradient step
        grad2 = x_adv - x_adv_old
        x_adv_old = x_adv

        # No momentum on the very first step.
        a = 0.75 if i else 1.0

        if norm == float('inf'):
            x_adv_1 = x_adv.addcmul(batch_view(step_size), torch.sign(grad))
            torch.minimum(torch.maximum(x_adv_1, lower, out=x_adv_1), upper, out=x_adv_1)

            # momentum
            x_adv_1.lerp_(x_adv, weight=1 - a).add_(grad2, alpha=1 - a)
            torch.minimum(torch.maximum(x_adv_1, lower, out=x_adv_1), upper, out=x_adv_1)

        elif norm == 2:
            delta = x_adv.addcmul(grad, batch_view(step_size / grad.flatten(1).norm(p=2, dim=1).add_(1e-12)))
            delta.sub_(inputs)
            delta_norm = delta.flatten(1).norm(p=2, dim=1).add_(1e-12)
            x_adv_1 = delta.mul_(batch_view(torch.min(delta_norm, eps).div_(delta_norm))).add_(inputs).clamp_(min=0,
                                                                                                               max=1)

            # momentum
            delta = x_adv_1.lerp_(x_adv, weight=1 - a).add_(grad2, alpha=1 - a)
            delta.sub_(inputs)
            delta_norm = delta.flatten(1).norm(p=2, dim=1).add_(1e-12)
            x_adv_1 = delta.mul_(batch_view(torch.min(delta_norm, eps).div_(delta_norm))).add_(inputs).clamp_(min=0,
                                                                                                               max=1)

        elif norm == 1:
            # Keep only the gradient entries whose magnitude reaches the per-sample threshold picked through `topk`.
            grad_abs = grad.abs()
            grad_topk = grad_abs.flatten(1).sort(dim=1).values
            topk_curr = (1 - topk).mul_(n_fts).clamp_(min=0, max=n_fts - 1).long()
            grad_topk = grad_topk.gather(1, topk_curr.unsqueeze(1))
            grad.mul_(grad_abs >= batch_view(grad_topk))
            grad_sign = grad.sign()

            x_adv.addcmul_(grad_sign, batch_view(step_size / grad_sign.flatten(1).norm(p=1, dim=1).add_(1e-10)))
            x_adv.sub_(inputs)
            delta_p = l1_projection(inputs, x_adv, eps)
            x_adv_1 = x_adv.add_(inputs).add_(delta_p)

        x_adv = x_adv_1

        ### get gradient
        x_adv.requires_grad_(True)
        grad.zero_()
        for _ in range(eot_iter):
            logits = model(x_adv)
            loss_indiv = multiplier * criterion_indiv(logits, labels)
            grad.add_(torch.autograd.grad(loss_indiv.sum(), x_adv, only_inputs=True)[0])

        grad.div_(eot_iter)
        x_adv.detach_()
        loss_indiv.detach_()

        is_adv = (logits.argmax(1) == labels) if targeted else (logits.argmax(1) != labels)
        adv_found.logical_or_(is_adv)
        adv_found_steps[i + 1] = adv_found
        x_best_adv[is_adv] = x_adv[is_adv]

        ### check step size
        loss_steps[i] = loss_indiv
        ind = loss_indiv > loss_best
        x_best[ind] = x_adv[ind]
        grad_best[ind] = grad[ind]
        loss_best[ind] = loss_indiv[ind]
        loss_best_steps[i + 1] = loss_best

        counter3 += 1

        if counter3 == k:
            if norm in [2, float('inf')]:
                fl_reduce_no_impr = (~reduced_last_check) & (loss_best_last_check >= loss_best)
                reduced_last_check = check_oscillation(loss_steps, i, k, k3=rho) | fl_reduce_no_impr
                # Snapshot, not alias: loss_best is updated in place above, so keeping a reference would make the
                # ">= loss_best" comparison at the next checkpoint trivially true for every sample.
                loss_best_last_check = loss_best.clone()

                if reduced_last_check.any():
                    # Halve the step and restart from the best point for the flagged samples.
                    step_size[reduced_last_check] /= 2.0
                    x_adv[reduced_last_check] = x_best[reduced_last_check]
                    grad[reduced_last_check] = grad_best[reduced_last_check]

                k = max(k - size_decr, n_iter_min)

            elif norm == 1:
                # Adapt the step size and the top-k fraction from the sparsity (L0 norm) of the best perturbation.
                sp_curr = (x_best - inputs).flatten(1).norm(p=0, dim=1)
                fl_redtopk = (sp_curr / sp_old) < 0.95
                topk = sp_curr / n_fts / 1.5
                step_size = torch.where(fl_redtopk, alpha * eps, step_size / adasp_redstep)
                step_size = torch.min(torch.max(step_size, alpha * eps / adasp_minstep), alpha * eps)
                sp_old = sp_curr

                x_adv[fl_redtopk] = x_best[fl_redtopk]
                grad[fl_redtopk] = grad_best[fl_redtopk]

            counter3 = 0

    return x_best, adv_found, loss_best, x_best_adv
@@ -0,0 +1,170 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from torch.autograd import grad
6
+ from torch.nn import functional as F
7
+
8
+ from adv_lib.utils.visdom_logger import VisdomLogger
9
+
10
+
11
def bp(model: nn.Module,
       inputs: Tensor,
       labels: Tensor,
       targeted: bool = False,
       num_steps: int = 100,
       γ: float = 0.7,
       α: float = 2,
       levels: Optional[int] = 256,
       callback: Optional['VisdomLogger'] = None) -> Tensor:
    """
    Boundary Projection (BP) attack from https://arxiv.org/abs/1912.02153.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    num_steps : int
        Number of optimization steps.
    γ : float
        Factor by which the norm will be modified. new_norm = norm * (1 + or - γ).
    α : float
        Scale of the search step taken along the normalized gradient while no adversarial has been found.
    levels : int
        If not None, the returned adversarials will have quantized values to the specified number of levels. If None,
        no quantization is applied.
    callback : Optional
        Logger receiving the loss, L2 norms and success rate at each step through `accumulate_line`, and flushed
        through `update_lines`.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    Raises
    ------
    ValueError
        If `inputs` has values outside [0, 1].

    """
    if inputs.min() < 0 or inputs.max() > 1:
        raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))

    def quantize_(tensor: Tensor) -> Tensor:
        # In-place rounding to `levels` evenly spaced values; no-op when quantization is disabled.
        if levels is None:
            return tensor
        return tensor.mul_(levels - 1).round_().div_(levels - 1)

    # Flush the callback roughly 20 times per run; at least every step so the modulo below never divides by zero
    # when num_steps < 20.
    update_every = max(num_steps // 20, 1)

    # Init variables
    multiplier = 1 if targeted else -1
    δ = torch.zeros_like(inputs, requires_grad=True)

    # Init trackers
    best_l2 = torch.full((batch_size,), float('inf'), device=device)
    best_adv = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for i in range(num_steps):
        adv_inputs = inputs + δ
        logits = model(adv_inputs)

        if i == 0:
            num_classes = logits.shape[1]
            one_hot_labels = F.one_hot(labels, num_classes=num_classes)

        # "softmax_cross_entropy_better" loss
        tmp = one_hot_labels * logits
        logits_1 = logits - tmp
        j_best = logits_1.amax(dim=1)
        logits_2 = logits_1 - j_best.unsqueeze(1) + one_hot_labels * j_best.unsqueeze(1)
        tmp_s = tmp.amax(dim=1)
        up = tmp_s - j_best
        down = logits_2.exp().add(1).sum(dim=1).log()
        loss = up - down

        g = grad(multiplier * loss.sum(), δ, only_inputs=True)[0]
        d = inputs - adv_inputs.detach()  # N x C x H x W

        snd = d.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6)  # N
        nd = d / batch_view(snd)  # N x C x H x W

        sng = g.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6)  # N
        ng = g / batch_view(sng)  # \hat{g} - shape: N x C x H x W

        cos_ψ = (nd * ng).flatten(1).sum(dim=1)  # r - shape: N
        sin_ψ = (1 - cos_ψ ** 2).sqrt()  # N

        # Track the smallest-L2 adversarial seen so far for each sample.
        pred_labels = logits.argmax(1)
        is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
        is_smaller = snd < best_l2
        is_both = is_adv & is_smaller
        adv_found.logical_or_(is_adv)
        best_l2 = torch.where(is_both, snd, best_l2)
        best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

        if callback is not None:
            callback.accumulate_line('loss', i, loss.mean(), title='BP - Loss')
            callback_best = best_l2.masked_select(adv_found).mean()
            callback.accumulate_line(['l2', 'best_l2'], i, [snd.mean(), callback_best])
            callback.accumulate_line(['success'], i, [adv_found.float().mean()], title='BP - ASR')

            if (i + 1) % update_every == 0 or (i + 1) == num_steps:
                callback.update_lines()

        # step-size decay
        ε = γ + i * (1 - γ) / (num_steps + 1)

        # search
        p_search = α * ε * ng

        # refine step
        # out
        λ = (d * ng).flatten(1).sum(dim=1)
        g_ort = d - batch_view(λ) * ng

        # estimate β out LS
        dis = snd * ε  # N
        ngo = g_ort.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6)  # N
        g_ort_ = g_ort / batch_view(ngo)  # N x C x H x W
        tmp = (d * g_ort_).flatten(1).sum(dim=1)  # N
        an_tmp = batch_view(tmp.square()) - d.square() + batch_view(dis)  # N x C x H x W
        min_β = batch_view(tmp) - an_tmp.sqrt()  # N x C x H x W
        max_β = batch_view(tmp)  # N x C x H x W

        p_min = min_β * g_ort_  # N x C x H x W
        p_max = max_β * g_ort_  # N x C x H x W
        DMin = quantize_(d - p_min).flatten(1).norm(p=2, dim=1)
        DMax = quantize_(d - p_max).flatten(1).norm(p=2, dim=1)
        # Seven bisection steps on β so that the quantized distance gets close to the target `dis`.
        for _ in range(7):
            β = (min_β + max_β) / 2
            p = β * g_ort_
            D = quantize_(d - p).flatten(1).norm(p=2, dim=1)
            flag = D < dis
            DMin = torch.where(flag, D, DMin)
            DMax = torch.where(flag, DMax, D)
            min_β = torch.where(batch_view(flag), β, min_β)
            max_β = torch.where(batch_view(flag), max_β, β)
        # Keep whichever end of the bracket lands closer to the target distance.
        dMin = (DMin - dis).abs_()
        dMax = (DMax - dis).abs_()
        flag = dMax < dMin
        β = torch.where(batch_view(flag), max_β, min_β)

        # out_p
        μ = (batch_view(sin_ψ * snd) / β - 1).clamp_(min=0)
        p_out = g_ort / (1 + μ)

        # estimate β in simple
        dis = snd / ε
        p_o = (d * ng).flatten(1).sum(dim=1)  # N
        bac = dis.square() - snd.square() + p_o.square()  # N
        β = p_o + bac.sqrt()
        β_min = torch.full_like(dis, 0.1)
        β.clamp_(min=β_min)

        p_in = ng * batch_view(β)

        # Step choice per sample: p_out if currently adversarial, else p_in if an adversarial was found before,
        # else the plain search step.
        delta = torch.where(batch_view(adv_found), p_in, p_search)
        delta = torch.where(batch_view(is_adv), p_out, delta)

        adv_x = (adv_inputs + delta).clamp_(min=0, max=1)
        quantize_(adv_x)

        δ.data = adv_x - inputs

    return best_adv
@@ -0,0 +1,2 @@
1
+ from .l2 import carlini_wagner_l2
2
+ from .linf import carlini_wagner_linf