adv-lib 0.2.2__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_lib-0.2.2 → adv_lib-0.2.6}/PKG-INFO +5 -2
- {adv_lib-0.2.2 → adv_lib-0.2.6}/README.md +2 -0
- adv_lib-0.2.6/adv_lib/__init__.py +1 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/__init__.py +2 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/carlini_wagner/l2.py +27 -12
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/carlini_wagner/linf.py +21 -12
- adv_lib-0.2.6/adv_lib/attacks/deepfool.py +127 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py +2 -7
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/fast_minimum_norm.py +4 -11
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/segmentation/alma_prox.py +21 -16
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/stochastic_sparse_attacks.py +6 -2
- adv_lib-0.2.6/adv_lib/attacks/superdeepfool.py +105 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/distances/structural_similarity.py +53 -48
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/attack_utils.py +21 -21
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/projections.py +2 -2
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib.egg-info/PKG-INFO +5 -2
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib.egg-info/SOURCES.txt +3 -5
- {adv_lib-0.2.2 → adv_lib-0.2.6}/pyproject.toml +3 -2
- adv_lib-0.2.2/adv_lib/__init__.py +0 -1
- adv_lib-0.2.2/adv_lib/attacks/boundary_projection_tf.py +0 -170
- adv_lib-0.2.2/adv_lib/attacks/self_adaptive_norm_update.py +0 -127
- adv_lib-0.2.2/adv_lib/utils/lagrangian_penalties/scripts/plot_penalties.py +0 -42
- adv_lib-0.2.2/adv_lib/utils/lagrangian_penalties/scripts/plot_univariates.py +0 -32
- {adv_lib-0.2.2 → adv_lib-0.2.6}/LICENSE +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/augmented_lagrangian.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/auto_pgd.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/carlini_wagner/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/decoupled_direction_norm.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/fast_adaptive_boundary/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/fast_adaptive_boundary/projections.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/perceptual_color_attacks/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/perceptual_color_attacks/differential_color_functions.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/perceptual_color_attacks/perceptual_color_distance_al.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/primal_dual_gradient_descent.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/projected_gradient_descent.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/segmentation/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/segmentation/asma.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/segmentation/dense_adversary.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/segmentation/primal_dual_gradient_descent.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/sigma_zero.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/structured_adversarial_attack.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/trust_region.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/distances/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/distances/color_difference.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/distances/lp_norms.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/distances/lpips.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/color_conversions.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/image_selection.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/lagrangian_penalties/__init__.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/lagrangian_penalties/all_penalties.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/lagrangian_penalties/penalty_functions.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/lagrangian_penalties/univariate_functions.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/losses.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/utils.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/utils/visdom_logger.py +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib.egg-info/dependency_links.txt +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib.egg-info/requires.txt +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib.egg-info/top_level.txt +0 -0
- {adv_lib-0.2.2 → adv_lib-0.2.6}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: adv-lib
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.6
|
|
4
4
|
Summary: Library of various adversarial attacks resources in PyTorch
|
|
5
5
|
Author-email: Jerome Rony <jerome.rony@gmail.com>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -49,6 +49,7 @@ Requires-Dist: visdom>=0.1.8
|
|
|
49
49
|
Provides-Extra: test
|
|
50
50
|
Requires-Dist: scikit-image; extra == "test"
|
|
51
51
|
Requires-Dist: pytest; extra == "test"
|
|
52
|
+
Dynamic: license-file
|
|
52
53
|
|
|
53
54
|
|
|
54
55
|
[](https://zenodo.org/badge/latestdoi/315504148)
|
|
@@ -111,6 +112,7 @@ Currently the following classification attacks are implemented in the `adv_lib.a
|
|
|
111
112
|
|
|
112
113
|
| Name | Knowledge | Type | Distance(s) | ArXiv Link |
|
|
113
114
|
|-----------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------------------------------------------------------------|
|
|
115
|
+
| DeepFool (DF) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1511.04599](https://arxiv.org/abs/1511.04599) |
|
|
114
116
|
| Carlini and Wagner (C&W) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1608.04644](https://arxiv.org/abs/1608.04644) |
|
|
115
117
|
| Projected Gradient Descent (PGD) | White-box | Budget | $\ell_\infty$ | [1706.06083](https://arxiv.org/abs/1706.06083) |
|
|
116
118
|
| Structured Adversarial Attack (StrAttack) | White-box | Minimal | $\ell_2$ + group-sparsity | [1808.01664](https://arxiv.org/abs/1808.01664) |
|
|
@@ -123,6 +125,7 @@ Currently the following classification attacks are implemented in the `adv_lib.a
|
|
|
123
125
|
| Folded Gaussian Attack (FGA)<br /> Voting Folded Gaussian Attack (VFGA) | White-box | Minimal | $\ell_0$ | [2011.12423](https://arxiv.org/abs/2011.12423) |
|
|
124
126
|
| Fast Minimum-Norm (FMN) | White-box | Minimal | $\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2102.12827](https://arxiv.org/abs/2102.12827) |
|
|
125
127
|
| Primal-Dual Gradient Descent (PDGD)<br /> Primal-Dual Proximal Gradient Descent (PDPGD) | White-box | Minimal | $\ell_2$<br />$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) |
|
|
128
|
+
| SuperDeepFool (SDF) | White-box | Minimal | $\ell_2$ | [2303.12481](https://arxiv.org/abs/2303.12481) |
|
|
126
129
|
| σ-zero | White-box | Minimal | $\ell_0$ | [2402.01879](https://arxiv.org/abs/2402.01879) |
|
|
127
130
|
|
|
128
131
|
**Bold** means that this repository contains the official implementation.
|
|
@@ -59,6 +59,7 @@ Currently the following classification attacks are implemented in the `adv_lib.a
|
|
|
59
59
|
|
|
60
60
|
| Name | Knowledge | Type | Distance(s) | ArXiv Link |
|
|
61
61
|
|-----------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------------------------------------------------------------|
|
|
62
|
+
| DeepFool (DF) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1511.04599](https://arxiv.org/abs/1511.04599) |
|
|
62
63
|
| Carlini and Wagner (C&W) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1608.04644](https://arxiv.org/abs/1608.04644) |
|
|
63
64
|
| Projected Gradient Descent (PGD) | White-box | Budget | $\ell_\infty$ | [1706.06083](https://arxiv.org/abs/1706.06083) |
|
|
64
65
|
| Structured Adversarial Attack (StrAttack) | White-box | Minimal | $\ell_2$ + group-sparsity | [1808.01664](https://arxiv.org/abs/1808.01664) |
|
|
@@ -71,6 +72,7 @@ Currently the following classification attacks are implemented in the `adv_lib.a
|
|
|
71
72
|
| Folded Gaussian Attack (FGA)<br /> Voting Folded Gaussian Attack (VFGA) | White-box | Minimal | $\ell_0$ | [2011.12423](https://arxiv.org/abs/2011.12423) |
|
|
72
73
|
| Fast Minimum-Norm (FMN) | White-box | Minimal | $\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2102.12827](https://arxiv.org/abs/2102.12827) |
|
|
73
74
|
| Primal-Dual Gradient Descent (PDGD)<br /> Primal-Dual Proximal Gradient Descent (PDPGD) | White-box | Minimal | $\ell_2$<br />$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) |
|
|
75
|
+
| SuperDeepFool (SDF) | White-box | Minimal | $\ell_2$ | [2303.12481](https://arxiv.org/abs/2303.12481) |
|
|
74
76
|
| σ-zero | White-box | Minimal | $\ell_0$ | [2402.01879](https://arxiv.org/abs/2402.01879) |
|
|
75
77
|
|
|
76
78
|
**Bold** means that this repository contains the official implementation.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.6"
|
|
@@ -2,6 +2,7 @@ from .augmented_lagrangian import alma
|
|
|
2
2
|
from .auto_pgd import apgd, apgd_targeted
|
|
3
3
|
from .carlini_wagner import carlini_wagner_l2, carlini_wagner_linf
|
|
4
4
|
from .decoupled_direction_norm import ddn
|
|
5
|
+
from .deepfool import df
|
|
5
6
|
from .fast_adaptive_boundary import fab
|
|
6
7
|
from .fast_minimum_norm import fmn
|
|
7
8
|
from .perceptual_color_attacks import perc_al
|
|
@@ -10,4 +11,5 @@ from .projected_gradient_descent import pgd_linf
|
|
|
10
11
|
from .sigma_zero import sigma_zero
|
|
11
12
|
from .stochastic_sparse_attacks import fga, vfga
|
|
12
13
|
from .structured_adversarial_attack import str_attack
|
|
14
|
+
from .superdeepfool import sdf
|
|
13
15
|
from .trust_region import tr
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# Adapted from https://github.com/carlini/nn_robust_attacks
|
|
2
|
-
|
|
3
|
-
from typing import
|
|
2
|
+
import math
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
|
-
from torch import nn,
|
|
6
|
+
from torch import nn, Tensor
|
|
7
7
|
from torch.autograd import grad
|
|
8
8
|
|
|
9
9
|
from adv_lib.utils.losses import difference_of_logits
|
|
@@ -20,6 +20,8 @@ def carlini_wagner_l2(model: nn.Module,
|
|
|
20
20
|
binary_search_steps: int = 9,
|
|
21
21
|
max_iterations: int = 10000,
|
|
22
22
|
abort_early: bool = True,
|
|
23
|
+
β_1: float = 0.9,
|
|
24
|
+
β_2: float = 0.999,
|
|
23
25
|
callback: Optional[VisdomLogger] = None) -> Tensor:
|
|
24
26
|
"""
|
|
25
27
|
Carlini and Wagner L2 attack from https://arxiv.org/abs/1608.04644.
|
|
@@ -50,6 +52,8 @@ def carlini_wagner_l2(model: nn.Module,
|
|
|
50
52
|
learning rate and will produce poor results.
|
|
51
53
|
abort_early : bool
|
|
52
54
|
If true, allows early aborts if gradient descent gets stuck.
|
|
55
|
+
β_1, β_2: float
|
|
56
|
+
Adam exponential averages smoothing parameters.
|
|
53
57
|
callback : Optional
|
|
54
58
|
|
|
55
59
|
Returns
|
|
@@ -62,25 +66,32 @@ def carlini_wagner_l2(model: nn.Module,
|
|
|
62
66
|
batch_size = len(inputs)
|
|
63
67
|
batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
|
|
64
68
|
t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
|
|
65
|
-
|
|
69
|
+
if not targeted: # gradient descent if untargeted, else ascent
|
|
70
|
+
learning_rate *= -1
|
|
66
71
|
|
|
67
72
|
# set the lower and upper bounds accordingly
|
|
68
73
|
c = torch.full((batch_size,), initial_const, device=device)
|
|
69
74
|
lower_bound = torch.zeros_like(c)
|
|
70
75
|
upper_bound = torch.full_like(c, 1e10)
|
|
71
76
|
|
|
77
|
+
# Adam variables
|
|
78
|
+
modifier = torch.zeros_like(inputs, requires_grad=True)
|
|
79
|
+
exp_avg = torch.zeros_like(inputs)
|
|
80
|
+
exp_avg_sq = torch.zeros_like(inputs)
|
|
81
|
+
|
|
72
82
|
o_best_l2 = torch.full_like(c, float('inf'))
|
|
73
83
|
o_best_adv = inputs.clone()
|
|
74
|
-
o_adv_found = torch.
|
|
84
|
+
o_adv_found = torch.zeros_like(c, dtype=torch.bool)
|
|
75
85
|
|
|
76
86
|
i_total = 0
|
|
77
87
|
for outer_step in range(binary_search_steps):
|
|
78
88
|
|
|
79
89
|
# setup the modifier variable and the optimizer
|
|
80
|
-
|
|
81
|
-
|
|
90
|
+
nn.init.zeros_(modifier)
|
|
91
|
+
nn.init.zeros_(exp_avg)
|
|
92
|
+
nn.init.zeros_(exp_avg_sq)
|
|
82
93
|
best_l2 = torch.full_like(c, float('inf'))
|
|
83
|
-
adv_found = torch.
|
|
94
|
+
adv_found = torch.zeros_like(o_adv_found)
|
|
84
95
|
|
|
85
96
|
# The last iteration (if we run many steps) repeat the search once.
|
|
86
97
|
if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1):
|
|
@@ -115,7 +126,7 @@ def carlini_wagner_l2(model: nn.Module,
|
|
|
115
126
|
o_adv_found.logical_or_(is_both)
|
|
116
127
|
o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv)
|
|
117
128
|
|
|
118
|
-
logit_dists =
|
|
129
|
+
logit_dists = difference_of_logits(logits, labels, labels_infhot=labels_infhot)
|
|
119
130
|
loss = l2_squared + c * (logit_dists + confidence).clamp_(min=0)
|
|
120
131
|
|
|
121
132
|
# check if we should abort search if we're getting nowhere.
|
|
@@ -124,9 +135,13 @@ def carlini_wagner_l2(model: nn.Module,
|
|
|
124
135
|
break
|
|
125
136
|
prev = loss.detach()
|
|
126
137
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
138
|
+
g = grad(loss.sum(), modifier, only_inputs=True)[0]
|
|
139
|
+
exp_avg.mul_(β_1).add_(g, alpha=1 - β_1)
|
|
140
|
+
exp_avg_sq.mul_(β_2).addcmul_(g, g, value=1 - β_2)
|
|
141
|
+
bias_correction1 = 1 - β_1 ** (i + 1)
|
|
142
|
+
bias_correction2 = 1 - β_2 ** (i + 1)
|
|
143
|
+
denom = exp_avg_sq.sqrt().div_(math.sqrt(bias_correction2)).add_(1e-8)
|
|
144
|
+
modifier.data.addcdiv_(exp_avg, denom, value=learning_rate / bias_correction1)
|
|
130
145
|
|
|
131
146
|
if callback:
|
|
132
147
|
i_total += 1
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
# Adapted from https://github.com/carlini/nn_robust_attacks
|
|
2
|
-
|
|
3
|
-
from typing import
|
|
2
|
+
import math
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
|
-
from torch import nn,
|
|
6
|
+
from torch import nn, Tensor
|
|
7
|
+
from torch.autograd import grad
|
|
7
8
|
|
|
8
9
|
from adv_lib.utils.losses import difference_of_logits
|
|
9
10
|
from adv_lib.utils.visdom_logger import VisdomLogger
|
|
@@ -21,6 +22,8 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
21
22
|
reduce_const: bool = False,
|
|
22
23
|
decrease_factor: float = 0.9,
|
|
23
24
|
abort_early: bool = True,
|
|
25
|
+
β_1: float = 0.9,
|
|
26
|
+
β_2: float = 0.999,
|
|
24
27
|
callback: Optional[VisdomLogger] = None) -> Tensor:
|
|
25
28
|
"""
|
|
26
29
|
Carlini and Wagner Linf attack from https://arxiv.org/abs/1608.04644.
|
|
@@ -52,8 +55,8 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
52
55
|
Rate at which τ is decreased. Larger produces better quality results.
|
|
53
56
|
abort_early : bool
|
|
54
57
|
If true, allows early aborts if gradient descent gets stuck.
|
|
55
|
-
|
|
56
|
-
|
|
58
|
+
β_1, β_2: float
|
|
59
|
+
Adam exponential averages smoothing parameters.
|
|
57
60
|
callback : Optional
|
|
58
61
|
|
|
59
62
|
Returns
|
|
@@ -65,12 +68,13 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
65
68
|
device = inputs.device
|
|
66
69
|
batch_size = len(inputs)
|
|
67
70
|
t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
|
|
68
|
-
|
|
71
|
+
if not targeted:
|
|
72
|
+
learning_rate *= -1
|
|
69
73
|
|
|
70
74
|
# set modifier and the parameters used in the optimization
|
|
71
75
|
modifier = torch.zeros_like(inputs)
|
|
72
76
|
c = torch.full((batch_size,), initial_const, device=device, dtype=torch.float)
|
|
73
|
-
τ = torch.
|
|
77
|
+
τ = torch.ones_like(c)
|
|
74
78
|
|
|
75
79
|
o_adv_found = torch.zeros_like(c, dtype=torch.bool)
|
|
76
80
|
o_best_linf = torch.ones_like(c)
|
|
@@ -89,7 +93,8 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
89
93
|
|
|
90
94
|
# setup the optimizer
|
|
91
95
|
modifier_ = modifier[to_optimize].requires_grad_(True)
|
|
92
|
-
|
|
96
|
+
exp_avg = torch.zeros_like(modifier_)
|
|
97
|
+
exp_avg_sq = torch.zeros_like(modifier_)
|
|
93
98
|
c_, τ_ = c[to_optimize], τ[to_optimize]
|
|
94
99
|
|
|
95
100
|
adv_found = torch.zeros(len(modifier_), device=device, dtype=torch.bool)
|
|
@@ -119,7 +124,7 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
119
124
|
best_linf = torch.where(is_both, linf, best_linf)
|
|
120
125
|
best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)
|
|
121
126
|
|
|
122
|
-
logit_dists =
|
|
127
|
+
logit_dists = difference_of_logits(logits, labels_, labels_infhot=labels_infhot)
|
|
123
128
|
linf_loss = (adv_inputs - inputs_).abs_().sub_(batch_view(τ_)).clamp_(min=0).flatten(1).sum(1)
|
|
124
129
|
loss = linf_loss + c_ * logit_dists.clamp_(min=0)
|
|
125
130
|
|
|
@@ -127,9 +132,13 @@ def carlini_wagner_linf(model: nn.Module,
|
|
|
127
132
|
if abort_early and (loss < 0.0001 * c_).all():
|
|
128
133
|
break
|
|
129
134
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
135
|
+
g = grad(loss.sum(), modifier_, only_inputs=True)[0]
|
|
136
|
+
exp_avg.mul_(β_1).add_(g, alpha=1 - β_1)
|
|
137
|
+
exp_avg_sq.mul_(β_2).addcmul_(g, g, value=1 - β_2)
|
|
138
|
+
bias_correction1 = 1 - β_1 ** (i + 1)
|
|
139
|
+
bias_correction2 = 1 - β_2 ** (i + 1)
|
|
140
|
+
denom = exp_avg_sq.sqrt().div_(math.sqrt(bias_correction2)).add_(1e-8)
|
|
141
|
+
modifier_.data.addcdiv_(exp_avg, denom, value=learning_rate / bias_correction1)
|
|
133
142
|
|
|
134
143
|
if callback:
|
|
135
144
|
callback.accumulate_line('logit_dist', total_iters, logit_dists.mean())
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor, nn
|
|
5
|
+
from torch.autograd import grad
|
|
6
|
+
|
|
7
|
+
from adv_lib.utils.attack_utils import get_all_targets
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def df(model: nn.Module,
|
|
11
|
+
inputs: Tensor,
|
|
12
|
+
labels: Tensor,
|
|
13
|
+
targeted: bool = False,
|
|
14
|
+
steps: int = 100,
|
|
15
|
+
overshoot: float = 0.02,
|
|
16
|
+
norm: float = 2,
|
|
17
|
+
return_unsuccessful: bool = False,
|
|
18
|
+
return_targets: bool = False) -> Tensor:
|
|
19
|
+
"""
|
|
20
|
+
DeepFool attack from https://arxiv.org/abs/1511.04599. Properly implements parallel sample-wise early-stopping.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
model : nn.Module
|
|
25
|
+
Model to attack.
|
|
26
|
+
inputs : Tensor
|
|
27
|
+
Inputs to attack. Should be in [0, 1].
|
|
28
|
+
labels : Tensor
|
|
29
|
+
Labels corresponding to the inputs if untargeted, else target labels.
|
|
30
|
+
targeted : bool
|
|
31
|
+
Whether to perform a targeted attack or not.
|
|
32
|
+
steps : int
|
|
33
|
+
Maximum number of attack steps.
|
|
34
|
+
overshoot : float
|
|
35
|
+
Ratio by which to overshoot the boundary estimated from linear model.
|
|
36
|
+
norm : float
|
|
37
|
+
Norm to minimize in {2, float('inf')}.
|
|
38
|
+
return_unsuccessful : bool
|
|
39
|
+
Whether to return unsuccessful adversarial inputs; used by SuperDeepFool.
|
|
40
|
+
return_targets : bool
|
|
41
|
+
Whether to return last target labels; used by SuperDeepFool.
|
|
42
|
+
|
|
43
|
+
Returns
|
|
44
|
+
-------
|
|
45
|
+
adv_inputs : Tensor
|
|
46
|
+
Modified inputs to be adversarial to the model.
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
if targeted:
|
|
50
|
+
warnings.warn('DeepFool attack is untargeted only. Returning inputs.')
|
|
51
|
+
return inputs
|
|
52
|
+
|
|
53
|
+
if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
|
|
54
|
+
device = inputs.device
|
|
55
|
+
batch_size = len(inputs)
|
|
56
|
+
batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
|
|
57
|
+
|
|
58
|
+
# Setup variables
|
|
59
|
+
adv_inputs = inputs.clone()
|
|
60
|
+
adv_inputs.requires_grad_(True)
|
|
61
|
+
|
|
62
|
+
adv_out = inputs.clone()
|
|
63
|
+
adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
|
64
|
+
if return_targets:
|
|
65
|
+
targets = labels.clone()
|
|
66
|
+
|
|
67
|
+
arange = torch.arange(batch_size, device=device)
|
|
68
|
+
for i in range(steps):
|
|
69
|
+
|
|
70
|
+
logits = model(adv_inputs)
|
|
71
|
+
|
|
72
|
+
if i == 0:
|
|
73
|
+
other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1])
|
|
74
|
+
|
|
75
|
+
pred_labels = logits.argmax(dim=1)
|
|
76
|
+
is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
|
|
77
|
+
|
|
78
|
+
if is_adv.any():
|
|
79
|
+
adv_not_found = ~adv_found
|
|
80
|
+
adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs.detach(), adv_out[adv_not_found])
|
|
81
|
+
adv_found.masked_scatter_(adv_not_found, is_adv)
|
|
82
|
+
if is_adv.all():
|
|
83
|
+
break
|
|
84
|
+
|
|
85
|
+
not_adv = ~is_adv
|
|
86
|
+
logits, labels, other_labels = logits[not_adv], labels[not_adv], other_labels[not_adv]
|
|
87
|
+
arange = torch.arange(not_adv.sum(), device=device)
|
|
88
|
+
|
|
89
|
+
f_prime = logits.gather(dim=1, index=other_labels) - logits.gather(dim=1, index=labels.unsqueeze(1))
|
|
90
|
+
w_prime = []
|
|
91
|
+
for j, f_prime_k in enumerate(f_prime.unbind(dim=1)):
|
|
92
|
+
w_prime_k = grad(f_prime_k.sum(), inputs=adv_inputs, retain_graph=(j + 1) < f_prime.shape[1],
|
|
93
|
+
only_inputs=True)[0]
|
|
94
|
+
w_prime.append(w_prime_k)
|
|
95
|
+
w_prime = torch.stack(w_prime, dim=1) # batch_size × num_classes × ...
|
|
96
|
+
|
|
97
|
+
if is_adv.any():
|
|
98
|
+
not_adv = ~is_adv
|
|
99
|
+
adv_inputs, w_prime = adv_inputs[not_adv], w_prime[not_adv]
|
|
100
|
+
|
|
101
|
+
if norm == 2:
|
|
102
|
+
w_prime_norms = w_prime.flatten(2).norm(p=2, dim=2).clamp_(min=1e-6)
|
|
103
|
+
elif norm == float('inf'):
|
|
104
|
+
w_prime_norms = w_prime.flatten(2).norm(p=1, dim=2).clamp_(min=1e-6)
|
|
105
|
+
|
|
106
|
+
distance = f_prime.detach().abs_().div_(w_prime_norms).add_(1e-4)
|
|
107
|
+
l_hat = distance.argmin(dim=1)
|
|
108
|
+
|
|
109
|
+
if return_targets:
|
|
110
|
+
targets[~adv_found] = torch.where(l_hat >= labels, l_hat + 1, l_hat)
|
|
111
|
+
|
|
112
|
+
if norm == 2:
|
|
113
|
+
# 1e-4 added in original implementation
|
|
114
|
+
scale = distance[arange, l_hat] / w_prime_norms[arange, l_hat]
|
|
115
|
+
adv_inputs.data.addcmul_(batch_view(scale), w_prime[arange, l_hat], value=1 + overshoot)
|
|
116
|
+
elif norm == float('inf'):
|
|
117
|
+
adv_inputs.data.addcmul_(batch_view(distance[arange, l_hat]), w_prime[arange, l_hat].sign(),
|
|
118
|
+
value=1 + overshoot)
|
|
119
|
+
adv_inputs.data.clamp_(min=0, max=1)
|
|
120
|
+
|
|
121
|
+
if return_unsuccessful and not adv_found.all():
|
|
122
|
+
adv_out[~adv_found] = adv_inputs.detach()
|
|
123
|
+
|
|
124
|
+
if return_targets:
|
|
125
|
+
return adv_out, targets
|
|
126
|
+
|
|
127
|
+
return adv_out
|
{adv_lib-0.2.2 → adv_lib-0.2.6}/adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py
RENAMED
|
@@ -8,6 +8,7 @@ import torch
|
|
|
8
8
|
from torch import Tensor, nn
|
|
9
9
|
from torch.autograd import grad
|
|
10
10
|
|
|
11
|
+
from adv_lib.utils.attack_utils import get_all_targets
|
|
11
12
|
from .projections import projection_l1, projection_l2, projection_linf
|
|
12
13
|
|
|
13
14
|
|
|
@@ -163,13 +164,7 @@ def _fab(model: nn.Module,
|
|
|
163
164
|
if targets is not None:
|
|
164
165
|
other_labels = targets.unsqueeze(1)
|
|
165
166
|
else:
|
|
166
|
-
|
|
167
|
-
n_classes = logits.size(1)
|
|
168
|
-
other_labels = torch.zeros(len(labels), n_classes - 1, dtype=torch.long, device=device)
|
|
169
|
-
all_classes = set(range(n_classes))
|
|
170
|
-
for i in range(len(labels)):
|
|
171
|
-
diff_labels = list(all_classes.difference({labels[i].item()}))
|
|
172
|
-
other_labels[i] = torch.tensor(diff_labels, device=device)
|
|
167
|
+
other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1])
|
|
173
168
|
|
|
174
169
|
get_df_dg = partial(get_best_diff_logits_grads, model=model, labels=labels, other_labels=other_labels, q=dual_norm)
|
|
175
170
|
|
|
@@ -59,8 +59,7 @@ def l1_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
|
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
def l2_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
|
|
62
|
-
|
|
63
|
-
return x0.flatten(1).mul(1 - ε).addcmul_(ε, x1.flatten(1)).view_as(x0)
|
|
62
|
+
return torch.lerp(x0.flatten(1), x1.flatten(1), weight=ε.unsqueeze(1)).view_as(x0)
|
|
64
63
|
|
|
65
64
|
|
|
66
65
|
def linf_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
|
|
@@ -194,24 +193,18 @@ def fmn(model: nn.Module,
|
|
|
194
193
|
torch.maximum(ε + 1, (ε * (1 + γ)).floor_()))
|
|
195
194
|
ε.clamp_(min=0)
|
|
196
195
|
else:
|
|
197
|
-
distance_to_boundary = loss.detach().
|
|
196
|
+
distance_to_boundary = loss.detach().abs_().div_(δ_grad.flatten(1).norm(p=dual, dim=1).clamp_(min=1e-12))
|
|
198
197
|
ε = torch.where(is_adv,
|
|
199
198
|
torch.minimum(ε * (1 - γ), best_norm),
|
|
200
199
|
torch.where(adv_found, ε * (1 + γ), δ_norm + distance_to_boundary))
|
|
201
200
|
|
|
202
201
|
# clip ε
|
|
203
202
|
ε = torch.minimum(ε, worst_norm)
|
|
204
|
-
|
|
205
|
-
# normalize gradient
|
|
203
|
+
# gradient ascent step with normalized gradient
|
|
206
204
|
grad_l2_norms = δ_grad.flatten(1).norm(p=2, dim=1).clamp_(min=1e-12)
|
|
207
|
-
δ_grad
|
|
208
|
-
|
|
209
|
-
# gradient ascent step
|
|
210
|
-
δ.data.add_(δ_grad, alpha=α)
|
|
211
|
-
|
|
205
|
+
δ.data.addcdiv_(δ_grad, batch_view(grad_l2_norms), value=α)
|
|
212
206
|
# project in place
|
|
213
207
|
projection(δ=δ.data, ε=ε)
|
|
214
|
-
|
|
215
208
|
# clamp
|
|
216
209
|
δ.data.add_(inputs).clamp_(min=0, max=1).sub_(inputs)
|
|
217
210
|
|
|
@@ -10,40 +10,45 @@ from adv_lib.utils.losses import difference_of_logits_ratio
|
|
|
10
10
|
from adv_lib.utils.visdom_logger import VisdomLogger
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
@torch.compile
|
|
13
14
|
def prox_linf_indicator(δ: Tensor, λ: Tensor, lower: Tensor, upper: Tensor, H: Optional[Tensor] = None,
|
|
14
15
|
ε: float = 1e-6, section: float = 1 / 3) -> Tensor:
|
|
15
16
|
"""Proximity operator of λ||·||_∞ + \iota_Λ in the diagonal metric H. The lower and upper tensors correspond to
|
|
16
17
|
the bounds of Λ. The problem is solved using a ternary search with section 1/3 up to an absolute error of ε on the
|
|
17
18
|
prox. Using a section of 1 - 1/φ (with φ the golden ratio) yields the Golden-section search, which is a bit faster,
|
|
18
19
|
but less numerically stable."""
|
|
19
|
-
δ
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
δ_shape = δ.shape
|
|
21
|
+
δ, λ = δ.flatten(1), 2 * λ
|
|
22
|
+
H = H.flatten(1) if H is not None else None
|
|
23
|
+
δ_proj = δ.clamp(min=lower.flatten(1), max=upper.flatten(1))
|
|
24
|
+
right = δ_proj.norm(p=float('inf'), dim=1)
|
|
23
25
|
left = torch.zeros_like(right)
|
|
24
|
-
steps = (ε / right.max()).log_().
|
|
26
|
+
steps = (ε / right.max()).log_().div_(math.log(1 - section)).ceil_().long()
|
|
25
27
|
prox, left_third, right_third, f_left, f_right, cond = (None,) * 6
|
|
26
28
|
for _ in range(steps):
|
|
27
29
|
left_third = torch.lerp(left, right, weight=section, out=left_third)
|
|
28
30
|
right_third = torch.lerp(left, right, weight=1 - section, out=right_third)
|
|
29
31
|
|
|
30
|
-
prox = torch.clamp(δ_proj, min=-left_third, max=left_third, out=prox)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
f_left.
|
|
32
|
+
prox = torch.clamp(δ_proj, min=-left_third.unsqueeze(1), max=left_third.unsqueeze(1), out=prox)
|
|
33
|
+
prox.sub_(δ).square_()
|
|
34
|
+
if H is not None:
|
|
35
|
+
prox.mul_(H)
|
|
36
|
+
f_left = torch.sum(prox, dim=1, out=f_left)
|
|
37
|
+
f_left.addcmul_(left_third, λ)
|
|
35
38
|
|
|
36
|
-
prox = torch.clamp(δ_proj, min=-right_third, max=right_third, out=prox)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
f_right.
|
|
39
|
+
prox = torch.clamp(δ_proj, min=-right_third.unsqueeze(1), max=right_third.unsqueeze(1), out=prox)
|
|
40
|
+
prox.sub_(δ).square_()
|
|
41
|
+
if H is not None:
|
|
42
|
+
prox.mul_(H)
|
|
43
|
+
f_right = torch.sum(prox, dim=1, out=f_right)
|
|
44
|
+
f_right.addcmul_(right_third, λ)
|
|
41
45
|
|
|
42
46
|
cond = torch.ge(f_left, f_right, out=cond)
|
|
43
47
|
left = torch.where(cond, left_third, left, out=left)
|
|
44
48
|
right = torch.where(cond, right, right_third, out=right)
|
|
45
49
|
left.lerp_(right, weight=0.5)
|
|
46
|
-
|
|
50
|
+
prox = δ_proj.clamp_(min=-left.unsqueeze(1), max=left.unsqueeze(1)).view(δ_shape)
|
|
51
|
+
return prox
|
|
47
52
|
|
|
48
53
|
|
|
49
54
|
class P(Function):
|
|
@@ -87,7 +87,9 @@ def fga(model: nn.Module,
|
|
|
87
87
|
|
|
88
88
|
# add perturbation to inputs
|
|
89
89
|
perturbed_inputs = inputs_.flatten(1).unsqueeze(1).repeat(1, n_samples, 1)
|
|
90
|
-
perturbed_inputs.scatter_add_(
|
|
90
|
+
perturbed_inputs.scatter_add_(
|
|
91
|
+
2, i_0.repeat_interleave(n_samples, dim=1, output_size=n_samples).unsqueeze(2), S.unsqueeze(2)
|
|
92
|
+
)
|
|
91
93
|
perturbed_inputs.clamp_(min=0, max=1)
|
|
92
94
|
|
|
93
95
|
# get probabilities for perturbed inputs
|
|
@@ -199,7 +201,9 @@ def vfga(model: nn.Module,
|
|
|
199
201
|
|
|
200
202
|
# add perturbation to inputs
|
|
201
203
|
perturbed_inputs = inputs_.flatten(1).unsqueeze(1).repeat(1, 2 * n_samples, 1)
|
|
202
|
-
i_plus_minus = torch.cat([i_plus, i_minus], dim=1).repeat_interleave(
|
|
204
|
+
i_plus_minus = torch.cat([i_plus, i_minus], dim=1).repeat_interleave(
|
|
205
|
+
n_samples, dim=1, output_size=2 * n_samples
|
|
206
|
+
)
|
|
203
207
|
S_plus_minus = torch.cat([S_plus, S_minus], dim=1)
|
|
204
208
|
perturbed_inputs.scatter_add_(2, i_plus_minus.unsqueeze(2), S_plus_minus.unsqueeze(2))
|
|
205
209
|
perturbed_inputs.clamp_(min=0, max=1)
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor, nn
|
|
5
|
+
from torch.autograd import grad
|
|
6
|
+
|
|
7
|
+
from .deepfool import df
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def sdf(model: nn.Module,
|
|
11
|
+
inputs: Tensor,
|
|
12
|
+
labels: Tensor,
|
|
13
|
+
targeted: bool = False,
|
|
14
|
+
steps: int = 100,
|
|
15
|
+
df_steps: int = 100,
|
|
16
|
+
overshoot: float = 0.02,
|
|
17
|
+
search_iter: int = 10) -> Tensor:
|
|
18
|
+
"""
|
|
19
|
+
SuperDeepFool attack from https://arxiv.org/abs/2303.12481.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
model : nn.Module
|
|
24
|
+
Model to attack.
|
|
25
|
+
inputs : Tensor
|
|
26
|
+
Inputs to attack. Should be in [0, 1].
|
|
27
|
+
labels : Tensor
|
|
28
|
+
Labels corresponding to the inputs if untargeted, else target labels.
|
|
29
|
+
targeted : bool
|
|
30
|
+
Whether to perform a targeted attack or not.
|
|
31
|
+
steps : int
|
|
32
|
+
Number of steps.
|
|
33
|
+
df_steps : int
|
|
34
|
+
Maximum number of steps for DeepFool attack at each iteration of SuperDeepFool.
|
|
35
|
+
overshoot : float
|
|
36
|
+
Overshoot parameter in DeepFool.
|
|
37
|
+
search_iter : int
|
|
38
|
+
Number of binary search steps at the end of the attack.
|
|
39
|
+
|
|
40
|
+
Returns
|
|
41
|
+
-------
|
|
42
|
+
adv_inputs : Tensor
|
|
43
|
+
Modified inputs to be adversarial to the model.
|
|
44
|
+
|
|
45
|
+
"""
|
|
46
|
+
if targeted:
|
|
47
|
+
warnings.warn('DeepFool attack is untargeted only. Returning inputs.')
|
|
48
|
+
return inputs
|
|
49
|
+
|
|
50
|
+
if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
|
|
51
|
+
device = inputs.device
|
|
52
|
+
batch_size = len(inputs)
|
|
53
|
+
batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
|
|
54
|
+
|
|
55
|
+
# Setup variables
|
|
56
|
+
adv_inputs = inputs_ = inputs
|
|
57
|
+
labels_ = labels
|
|
58
|
+
adv_out = inputs.clone()
|
|
59
|
+
adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
|
60
|
+
|
|
61
|
+
for i in range(steps):
|
|
62
|
+
logits = model(adv_inputs)
|
|
63
|
+
pred_labels = logits.argmax(dim=1)
|
|
64
|
+
|
|
65
|
+
is_adv = pred_labels != labels_
|
|
66
|
+
if is_adv.any():
|
|
67
|
+
adv_not_found = ~adv_found
|
|
68
|
+
adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs, adv_out[adv_not_found])
|
|
69
|
+
adv_found.masked_scatter_(adv_not_found, is_adv)
|
|
70
|
+
if is_adv.all():
|
|
71
|
+
break
|
|
72
|
+
|
|
73
|
+
not_adv = ~is_adv
|
|
74
|
+
inputs_, adv_inputs, labels_ = inputs_[not_adv], adv_inputs[not_adv], labels_[not_adv]
|
|
75
|
+
|
|
76
|
+
# start by doing deepfool -> need to return adv_inputs even for unsuccessful attacks
|
|
77
|
+
df_adv_inputs, df_targets = df(model=model, inputs=adv_inputs, labels=labels_, steps=df_steps, norm=2,
|
|
78
|
+
overshoot=overshoot, return_unsuccessful=True, return_targets=True)
|
|
79
|
+
|
|
80
|
+
r_df = df_adv_inputs - inputs_
|
|
81
|
+
df_adv_inputs.requires_grad_(True)
|
|
82
|
+
logits = model(df_adv_inputs)
|
|
83
|
+
pred_labels = logits.argmax(dim=1)
|
|
84
|
+
pred_labels = torch.where(pred_labels != labels_, pred_labels, df_targets)
|
|
85
|
+
|
|
86
|
+
logit_diff = logits.gather(1, pred_labels.unsqueeze(1)) - logits.gather(1, labels_.unsqueeze(1))
|
|
87
|
+
w = grad(logit_diff.sum(), inputs=df_adv_inputs, only_inputs=True)[0]
|
|
88
|
+
w.div_(batch_view(w.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6))) # w / ||w||_2
|
|
89
|
+
scale = torch.linalg.vecdot(r_df.flatten(1), w.flatten(1), dim=1) # (\tilde{x} - x_0)^T w / ||w||_2
|
|
90
|
+
|
|
91
|
+
adv_inputs = adv_inputs.addcmul(batch_view(scale), w)
|
|
92
|
+
adv_inputs.clamp_(min=0, max=1) # added compared to original implementation to produce valid adv
|
|
93
|
+
|
|
94
|
+
if search_iter: # binary search to bring perturbation as close to the decision boundary as possible
|
|
95
|
+
low, high = torch.zeros(batch_size, device=device), torch.ones(batch_size, device=device)
|
|
96
|
+
for i in range(search_iter):
|
|
97
|
+
mid = (low + high) / 2
|
|
98
|
+
logits = torch.lerp(inputs, adv_out, weight=batch_view(mid))
|
|
99
|
+
pred_labels = model(logits).argmax(dim=1)
|
|
100
|
+
is_adv = pred_labels != labels
|
|
101
|
+
high = torch.where(is_adv, mid, high)
|
|
102
|
+
low = torch.where(is_adv, low, mid)
|
|
103
|
+
adv_out = torch.lerp(inputs, adv_out, weight=batch_view(high))
|
|
104
|
+
|
|
105
|
+
return adv_out
|