paraug 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paraug/__init__.py +31 -0
- paraug/geometric.py +380 -0
- paraug/photometric.py +1112 -0
- paraug/pipeline.py +81 -0
- paraug/utils.py +102 -0
- paraug-0.1.1.dist-info/METADATA +201 -0
- paraug-0.1.1.dist-info/RECORD +10 -0
- paraug-0.1.1.dist-info/WHEEL +4 -0
- paraug-0.1.1.dist-info/licenses/LICENSE +201 -0
- paraug-0.1.1.dist-info/licenses/NOTICE +4 -0
paraug/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""paraug — Parity Augmentation library for image + mask.
|
|
2
|
+
|
|
3
|
+
Bit-exact CPU/GPU parity across 7 geometric + 24 photometric primitives.
|
|
4
|
+
Single entry point: build with `AugPipeline(config)`, then call with
|
|
5
|
+
`aug(img, mask=None, seed_base=..., epoch=..., step=...)` returning
|
|
6
|
+
`(img_out, mask_out)`. Per-item RNG is sampled on CPU regardless of tensor
|
|
7
|
+
device, so the same seed produces identical output on CPU and CUDA back-ends
|
|
8
|
+
within a tight tolerance (1e-6 elementwise, 2e-4 for grid_sample-class ops).
|
|
9
|
+
|
|
10
|
+
Quickstart:
|
|
11
|
+
|
|
12
|
+
from paraug import AugPipeline
|
|
13
|
+
aug = AugPipeline({"geometric": {"affine": {"p": 1.0, "rot_deg": 15}}})
|
|
14
|
+
img_out, mask_out = aug(img, mask=mask, seed_base=42)
|
|
15
|
+
"""
|
|
16
|
+
from .pipeline import AugPipeline
|
|
17
|
+
from .utils import set_deterministic, per_item_seed, cpu_generator
|
|
18
|
+
from .geometric import GEOMETRIC_PRIMITIVES
|
|
19
|
+
from .photometric import PHOTOMETRIC_PRIMITIVES
|
|
20
|
+
|
|
21
|
+
# Package version string exposed as ``paraug.__version__``.
__version__ = "0.1.1"

# Public API exported by ``from paraug import *``: the pipeline entry point,
# the helpers imported from ``.utils`` (set_deterministic, per_item_seed,
# cpu_generator), the two primitive registries, and the version string.
__all__ = [
    "AugPipeline",
    "set_deterministic",
    "per_item_seed",
    "cpu_generator",
    "GEOMETRIC_PRIMITIVES",
    "PHOTOMETRIC_PRIMITIVES",
    "__version__",
]
|
paraug/geometric.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
"""Geometric primitives — both image and mask are warped together.
|
|
2
|
+
|
|
3
|
+
Each primitive has signature:
|
|
4
|
+
fn(img, mask, spec, seed_base, epoch, step) -> (img_out, mask_out)
|
|
5
|
+
|
|
6
|
+
Image uses bilinear sampling; mask uses nearest. Padding is zero (background).
|
|
7
|
+
All primitives gate per-item: items with Bernoulli=False pass through bit-identical.
|
|
8
|
+
"""
|
|
9
|
+
import math
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
|
|
13
|
+
from .utils import per_item_seed, cpu_generator, sample_uniform, sample_bool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Upper bound on the number of cached meshgrids. Each entry holds two
# full-resolution (H, W) tensors (possibly on an accelerator), so an unbounded
# cache grows without limit when input sizes vary between calls.
_GRID_CACHE_MAX = 16

# (H, W, str(device), str(dtype)) -> (yy, xx). Dict insertion order is used as
# recency order: the first key is the least recently used entry.
_GRID_CACHE = {}


def _coords(H: int, W: int, device, dtype):
    """Return the cached ``(yy, xx)`` pixel-coordinate meshgrid.

    Args:
        H: number of rows.
        W: number of columns.
        device: device the grids are created on (anything ``torch.arange``
            accepts); its ``str()`` form is part of the cache key.
        dtype: dtype of the grids; its ``str()`` form is part of the cache key.

    Returns:
        Tuple ``(yy, xx)`` of ``(H, W)`` tensors built with ``indexing="ij"``:
        ``yy[r, c] == r`` and ``xx[r, c] == c``. The tensors are shared between
        callers through the cache and must not be modified in place.

    At most ``_GRID_CACHE_MAX`` grids are kept; the least recently used entry
    is evicted first.
    """
    key = (H, W, str(device), str(dtype))
    # Pop + re-insert moves a hit to the "most recent" end of the dict.
    cached = _GRID_CACHE.pop(key, None)
    if cached is None:
        yy, xx = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing="ij",
        )
        cached = (yy, xx)
        # Make room before inserting the new entry.
        while len(_GRID_CACHE) >= _GRID_CACHE_MAX:
            del _GRID_CACHE[next(iter(_GRID_CACHE))]
    _GRID_CACHE[key] = cached
    return cached
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _apply_grid(img, mask, grid, gate_d):
    """Warp ``img`` and ``mask`` through ``grid``; ungated items pass through.

    Args:
        img: 4-D image batch whose first dimension is the batch size B.
        mask: optional 4-D mask batch with the same batch size, or ``None``.
            May have any dtype: non-floating masks (integer class ids, bool)
            and masks whose dtype differs from ``grid`` are sampled in
            ``grid``'s dtype and cast back afterwards.
        grid: (B, H, W, 2) sampling grid in normalised [-1, 1] coords.
        gate_d: (B,) bool, on img.device. ``True`` selects the warped item,
            ``False`` keeps the original item unchanged.

    Returns:
        ``(img_out, mask_out)``; ``mask_out`` is ``None`` when ``mask`` is.

    The image is sampled bilinearly and the mask with nearest-neighbour, both
    with zero padding and ``align_corners=True``.
    """
    B = img.shape[0]
    gate4 = gate_d.view(B, 1, 1, 1)

    img_w = F.grid_sample(img, grid, mode="bilinear",
                          padding_mode="zeros", align_corners=True)
    img_out = torch.where(gate4.expand_as(img), img_w, img)
    if mask is None:
        return img_out, None

    # grid_sample needs a floating input with the same dtype as the grid.
    # Nearest sampling only copies values, so a round trip through the grid
    # dtype keeps label ids intact as long as that dtype represents them
    # exactly (NOTE(review): very large ids would not survive a half-precision
    # grid).
    mask_in = mask if mask.dtype == grid.dtype else mask.to(grid.dtype)
    mask_w = F.grid_sample(mask_in, grid, mode="nearest",
                           padding_mode="zeros", align_corners=True)
    if mask_w.dtype != mask.dtype:
        mask_w = mask_w.to(mask.dtype)
    mask_out = torch.where(gate4.expand_as(mask), mask_w, mask)
    return img_out, mask_out
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# ─── 1: tps ──────────────────────────────────────────────────────────
|
|
52
|
+
def _sample_tps_disp(B, H, W, p, max_disp, nc, seed_base, epoch, step, dtype):
    """Draw per-item control-point offsets and upsample them to (H, W).

    For every batch index a CPU generator is seeded from
    ``per_item_seed(seed_base, epoch, step, index, "tps")``. Items that pass
    the Bernoulli(p) gate get two ``(nc, nc)`` control grids drawn uniformly
    from ``[-max_disp, max_disp]`` (y first, then x); the other items keep
    all-zero control grids.

    Returns CPU tensors (the caller moves them to the target device):
        gate    (B,)       bool
        disp_y  (B, H, W)
        disp_x  (B, H, W)

    Kept as a separate function so other code can reproduce exactly the
    displacement field ``tps()`` uses by passing the same seed arguments.
    """
    gate = torch.zeros(B, dtype=torch.bool)
    ctrl_y = torch.zeros(B, 1, nc, nc, dtype=dtype)
    ctrl_x = torch.zeros(B, 1, nc, nc, dtype=dtype)

    def _draw_ctrl(gen):
        # One (nc, nc) grid of uniform offsets taken from the item's generator.
        return torch.empty(nc, nc).uniform_(-max_disp, max_disp, generator=gen)

    def _upsample(ctrl):
        # (B, 1, nc, nc) -> (B, H, W)
        return F.interpolate(ctrl, size=(H, W), mode="bilinear",
                             align_corners=False)[:, 0]

    for idx in range(B):
        gen = cpu_generator(per_item_seed(seed_base, epoch, step, idx, "tps"))
        if not sample_bool(p, gen):
            continue
        gate[idx] = True
        # Draw order (y, then x) is part of the determinism contract.
        ctrl_y[idx, 0] = _draw_ctrl(gen)
        ctrl_x[idx, 0] = _draw_ctrl(gen)

    field_y = _upsample(ctrl_y)
    field_x = _upsample(ctrl_x)
    return gate, field_y, field_x
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def tps(img, mask, spec, seed_base, epoch, step):
    """Thin-plate-spline-ish warp via low-res control grid → bilinear upsample.

    spec = {"p": prob, "max_disp": pixels, "n_ctrl": int}

    Defaults: p=1.0, max_disp=18.0, n_ctrl=5. The displacement field comes
    from ``_sample_tps_disp`` and is subtracted from the pixel coordinates to
    obtain the source location of every output pixel.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    max_disp = float(spec.get("max_disp", 18.0))
    nc = int(spec.get("n_ctrl", 5))

    gate, disp_y, disp_x = _sample_tps_disp(
        B, H, W, p, max_disp, nc, seed_base, epoch, step, img.dtype)
    disp_y_d = disp_y.to(img.device)
    disp_x_d = disp_x.to(img.device)
    gate_d = gate.to(img.device)

    yy, xx = _coords(H, W, img.device, img.dtype)
    src_x = xx[None] - disp_x_d
    src_y = yy[None] - disp_y_d

    # Pixel -> normalised [-1, 1] coordinates. A 1-pixel dimension would divide
    # by zero and fill the grid with NaN/inf, so the span is clamped to 1
    # (unchanged for every H, W > 1).
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)
    grid = torch.stack([2.0 * src_x / x_span - 1.0,
                        2.0 * src_y / y_span - 1.0], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# ─── 2: affine ───────────────────────────────────────────────────────
|
|
101
|
+
def affine(img, mask, spec, seed_base, epoch, step):
    """Random rotation / scale / translation.

    spec = {"p": prob, "rot_deg": float, "scale_range": (lo, hi),
            "translate_frac": float,
            "preserve_aspect": bool (optional, default False)}

    Defaults: p=1.0, rot_deg=30.0, scale_range=(0.85, 1.15),
    translate_frac=0.05.

    Per gated item the draw order is: angle in [-rot_deg, rot_deg], scale in
    scale_range, tx and ty in [-translate_frac, translate_frac]. The offsets
    are written directly into the affine matrix, i.e. they are expressed in
    the normalised [-1, 1] coordinates used by ``F.affine_grid``.

    The rotation is composed in normalised coordinates. On non-square inputs
    this is not a rigid rotation in pixel space (content is sheared by the
    H/W ratio). Set ``"preserve_aspect": True`` to rescale the off-diagonal
    terms so that the rotation is rigid in pixel space; the default ``False``
    keeps the previous output unchanged.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    rot_deg = float(spec.get("rot_deg", 30.0))
    scale_lo, scale_hi = spec.get("scale_range", (0.85, 1.15))
    trans_frac = float(spec.get("translate_frac", 0.05))
    preserve_aspect = bool(spec.get("preserve_aspect", False))

    # With align_corners=True a normalised unit is (W-1)/2 px along x and
    # (H-1)/2 px along y, so a pixel-space rotation needs this ratio on the
    # off-diagonal terms. 1.0 leaves the matrix exactly as before.
    if preserve_aspect and H > 1 and W > 1:
        aspect = (H - 1) / (W - 1)
    else:
        aspect = 1.0

    gate = torch.zeros(B, dtype=torch.bool)
    # (B, 2, 3) affine matrix per item; identity for skipped items
    theta = torch.zeros(B, 2, 3, dtype=img.dtype)
    theta[:, 0, 0] = 1.0
    theta[:, 1, 1] = 1.0
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "affine"))
        if sample_bool(p, g):
            gate[i] = True
            ang = math.radians(sample_uniform(-rot_deg, rot_deg, g))
            s = sample_uniform(scale_lo, scale_hi, g)
            tx = sample_uniform(-trans_frac, trans_frac, g)
            ty = sample_uniform(-trans_frac, trans_frac, g)
            # Back-warp matrix: dividing by s zooms the output by factor s.
            ca, sa = math.cos(ang) / s, math.sin(ang) / s
            theta[i] = torch.tensor([[ca, sa * aspect, tx],
                                     [-sa / aspect, ca, ty]])
    theta_d = theta.to(img.device)
    gate_d = gate.to(img.device)
    grid = F.affine_grid(theta_d, [B, C, H, W], align_corners=True)
    return _apply_grid(img, mask, grid, gate_d)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ─── 3: perspective ──────────────────────────────────────────────────
|
|
134
|
+
def perspective(img, mask, spec, seed_base, epoch, step):
    """4-point perspective via homography H built from corner jitters.

    spec = {"p": prob, "max_disp_frac": float (fraction of H/W)}
    Each corner of the image is jittered independently by up to ±max_disp_frac.

    Defaults: p=1.0, max_disp_frac=0.05. For every gated item the homography
    mapping the reference corners to the jittered corners is solved on CPU and
    applied to the normalised output grid to obtain the sampling locations.
    Ungated items get the plain identity grid; their output is the untouched
    input anyway because ``_apply_grid`` discards the warp for them.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    max_disp = float(spec.get("max_disp_frac", 0.05))

    # Reference corners in normalised coords (-1, +1)
    src = torch.tensor([[-1.0, -1.0], [1.0, -1.0],
                        [1.0, 1.0], [-1.0, 1.0]], dtype=img.dtype)
    gate = torch.zeros(B, dtype=torch.bool)
    dst_all = src.unsqueeze(0).repeat(B, 1, 1)
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "perspective"))
        if sample_bool(p, g):
            gate[i] = True
            dst = src.clone()
            for c in range(4):
                dst[c, 0] += sample_uniform(-max_disp, max_disp, g) * 2  # *2 because normalised
                dst[c, 1] += sample_uniform(-max_disp, max_disp, g) * 2
            dst_all[i] = dst

    # Normalised identity grid on CPU. The span is clamped so that a 1-pixel
    # dimension does not divide by zero (no change for H, W > 1).
    yy, xx = _coords(H, W, "cpu", img.dtype)
    base_grid = torch.stack([2.0 * xx / max(W - 1, 1) - 1.0,
                             2.0 * yy / max(H - 1, 1) - 1.0], dim=-1)  # (H, W, 2)

    # Homogeneous grid points are identical for every item: build them once.
    bg = base_grid.reshape(-1, 2)                                   # (HW, 2)
    ones = torch.ones(bg.shape[0], 1, dtype=img.dtype)
    bgh = torch.cat([bg, ones], dim=1)                              # (HW, 3)

    # Start from the identity grid so skipped items hold finite values.
    grids = base_grid.unsqueeze(0).repeat(B, 1, 1, 1)
    for i, selected in enumerate(gate.tolist()):
        if not selected:
            continue  # no linear solve needed for items the gate discards
        Hmat = _solve_homography(src, dst_all[i])
        # Apply homography to base_grid points (back-warp)
        warped = bgh @ Hmat.T
        # Keep the homogeneous divisor strictly positive before dividing.
        warped = warped[:, :2] / warped[:, 2:].clamp(min=1e-8)
        grids[i] = warped.reshape(H, W, 2)
    grid = grids.to(img.device)
    gate_d = gate.to(img.device)
    return _apply_grid(img, mask, grid, gate_d)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _solve_homography(src, dst):
    """Solve for the 3×3 homography that maps four src points onto dst.

    src, dst: (4, 2) point sets in normalised coords.

    The bottom-right entry is fixed to 1, which leaves eight unknowns. Each
    correspondence (x, y) -> (u, v) contributes the two equations

        x*h0 + y*h1 + h2 - u*x*h6 - u*y*h7 = u
        x*h3 + y*h4 + h5 - v*x*h6 - v*y*h7 = v

    Returns the (3, 3) matrix in ``src``'s dtype.
    """
    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]
    one = torch.ones_like(x)
    zero = torch.zeros_like(x)

    # One "u" equation and one "v" equation per correspondence: (4, 8) each.
    rows_u = torch.stack([x, y, one, zero, zero, zero, -u * x, -u * y], dim=1)
    rows_v = torch.stack([zero, zero, zero, x, y, one, -v * x, -v * y], dim=1)

    # Interleave into u0, v0, u1, v1, ... order for both the system matrix and
    # the right-hand side.
    system = torch.stack([rows_u, rows_v], dim=1).reshape(8, 8).to(src.dtype)
    rhs = torch.stack([u, v], dim=1).reshape(8).to(src.dtype)

    coeffs = torch.linalg.solve(system, rhs)
    # Append the fixed h8 = 1 and fold the nine values into a 3×3 matrix.
    return torch.cat([coeffs, coeffs.new_ones(1)]).reshape(3, 3)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# ─── 4: random_crop_pad ──────────────────────────────────────────────
|
|
198
|
+
def random_crop_pad(img, mask, spec, seed_base, epoch, step):
    """Random crop window resampled back to the original (H, W) size.

    spec = {"p": prob, "min_scale": float (lower bound of the crop scale)}

    Defaults: p=1.0, min_scale=0.7. For each gated item a scale ``s`` is drawn
    from [min_scale, 1.0] and offsets ``tx``, ``ty`` from [-(1 - s), (1 - s)].
    The sampling grid is the identity grid scaled by ``s`` on both axes and
    shifted by (tx, ty) in normalised coordinates, so ``s`` is the side-length
    fraction of the window and, for 0 < s <= 1, the window stays inside the
    image. Ungated items keep the identity transform.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, C, H, W = img.shape
    prob = float(spec.get("p", 1.0))
    lo_scale = float(spec.get("min_scale", 0.7))

    selected = torch.zeros(B, dtype=torch.bool)
    # One 2×3 identity matrix per item; gated items overwrite theirs below.
    mats = torch.eye(2, 3, dtype=img.dtype).unsqueeze(0).repeat(B, 1, 1)

    for idx in range(B):
        gen = cpu_generator(
            per_item_seed(seed_base, epoch, step, idx, "random_crop_pad"))
        if not sample_bool(prob, gen):
            continue
        selected[idx] = True
        s = sample_uniform(lo_scale, 1.0, gen)
        slack = 1 - s  # room left for shifting the window without leaving the image
        tx = sample_uniform(-slack, slack, gen)
        ty = sample_uniform(-slack, slack, gen)
        mats[idx] = torch.tensor([[s, 0.0, tx],
                                  [0.0, s, ty]])

    grid = F.affine_grid(mats.to(img.device), [B, C, H, W], align_corners=True)
    return _apply_grid(img, mask, grid, selected.to(img.device))
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# ─── 5: elastic_transform ────────────────────────────────────────────
|
|
226
|
+
def elastic_transform(img, mask, spec, seed_base, epoch, step):
    """Elastic displacement-field warp.

    spec = {"p": prob, "alpha": pixels (max disp)}

    Defaults: p=1.0, alpha=8.0.

    Implementation note (vs albumentations' ElasticTransform):
    The displacement field is bilinear-upsampled from a low-resolution random
    grid of shape (max(8, H // 8), max(8, W // 8)) with values in [-1, 1] and
    then multiplied by ``alpha``. This differs from albumentations' dense
    Gaussian-smoothed noise; ours produces smoother but more grid-aligned
    distortion. Smoothness is controlled by the fixed 1/8 downsample factor
    (no `sigma` knob — by-design); no Gaussian smoothing pass is applied.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    alpha = float(spec.get("alpha", 8.0))

    gate = torch.zeros(B, dtype=torch.bool)
    # Full-resolution displacement in pixels; stays zero for skipped items.
    disp_y = torch.zeros(B, H, W, dtype=img.dtype)
    disp_x = torch.zeros(B, H, W, dtype=img.dtype)
    # Resolution of the coarse random field (never smaller than 8×8).
    hl = max(8, H // 8)
    wl = max(8, W // 8)
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "elastic_transform"))
        if sample_bool(p, g):
            gate[i] = True
            # Draw order (y, then x) is part of the determinism contract.
            dy_low = torch.empty(1, 1, hl, wl).uniform_(-1.0, 1.0, generator=g)
            dx_low = torch.empty(1, 1, hl, wl).uniform_(-1.0, 1.0, generator=g)
            # Upsample bilinear, scale by alpha
            disp_y[i] = F.interpolate(dy_low, size=(H, W), mode="bilinear",
                                      align_corners=False)[0, 0] * alpha
            disp_x[i] = F.interpolate(dx_low, size=(H, W), mode="bilinear",
                                      align_corners=False)[0, 0] * alpha
    disp_y_d = disp_y.to(img.device)
    disp_x_d = disp_x.to(img.device)
    gate_d = gate.to(img.device)

    yy, xx = _coords(H, W, img.device, img.dtype)
    src_x = xx[None] - disp_x_d
    src_y = yy[None] - disp_y_d
    # Pixel -> normalised [-1, 1] coordinates. The span is clamped so that a
    # 1-pixel dimension does not divide by zero (no change for H, W > 1).
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)
    grid = torch.stack([2.0 * src_x / x_span - 1.0,
                        2.0 * src_y / y_span - 1.0], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# ─── 6: optical_distortion ───────────────────────────────────────────
|
|
271
|
+
def optical_distortion(img, mask, spec, seed_base, epoch, step):
    """Barrel (k>0) / pincushion (k<0) lens distortion.

    spec = {"p": prob, "k": float coefficient (max abs value)}
    Maps (x, y) → (x', y') = (x·(1 + k·r²), y·(1 + k·r²)), r = normalised radius.
    Per-item k sampled uniformly from [-k_max, k_max].

    Defaults: p=1.0, k=0.3.

    Implementation note:
    This uses a first-order **forward** approximation `scale = 1 + k·r²` for
    the back-warp lookup (computing src = nx · scale using output r). For
    moderate k ∈ [-0.5, 0.5] the visual distortion is geometrically correct.
    For larger |k| the strictly-correct inverse requires Newton iteration on
    a cubic; not implemented since the small-k regime covers typical
    photographic distortion.

    Returns ``(img_out, mask_out)`` from ``_apply_grid``.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    kmax = float(spec.get("k", 0.3))

    gate = torch.zeros(B, dtype=torch.bool)
    # k stays 0 for skipped items, which makes their scale exactly 1.
    ks = torch.zeros(B, dtype=img.dtype)
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "optical_distortion"))
        if sample_bool(p, g):
            gate[i] = True
            ks[i] = sample_uniform(-kmax, kmax, g)
    ks_d = ks.to(img.device)
    gate_d = gate.to(img.device)

    yy, xx = _coords(H, W, img.device, img.dtype)
    # Pixel -> normalised [-1, 1] coordinates. The span is clamped so that a
    # 1-pixel dimension does not divide by zero (no change for H, W > 1).
    nx = 2.0 * xx / max(W - 1, 1) - 1.0
    ny = 2.0 * yy / max(H - 1, 1) - 1.0
    r2 = nx ** 2 + ny ** 2                       # (H, W)
    # Per-item radial scale: 1 + k_i * r²
    scale = 1.0 + ks_d.view(B, 1, 1) * r2[None]
    src_nx = nx[None] * scale                    # back-warp coords in normalised space
    src_ny = ny[None] * scale
    grid = torch.stack([src_nx, src_ny], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# ─── 7: random_shadow (vectorized) ────────────────────────────────────
|
|
312
|
+
def random_shadow(img, mask, spec, seed_base, epoch, step):
    """Random triangular shadow region — multiplies image by a soft falloff
    factor inside a randomly placed triangle. Mask is unchanged.

    spec = {"p": prob, "strength": float in [0,1], "softness_px": float}

    Defaults: p=1.0, strength=0.4, softness_px=30.0.

    The triangle of every item is rasterised with broadcast barycentric math
    and softened with three box-blur passes that run on the whole batch at
    once. RNG sampling stays per-item to preserve bit-exact determinism.
    When no item passes the gate, the rasterisation and blur are skipped
    entirely and a copy of the input is returned.

    Returns ``(img_out, mask)``; ``mask`` is passed through as-is.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    strength = float(spec.get("strength", 0.4))
    softness = float(spec.get("softness_px", 30.0))

    # Per-item RNG sampling (small loop, cheap)
    gate = torch.zeros(B, dtype=torch.bool)
    pts_all = torch.zeros(B, 3, 2, dtype=img.dtype)
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "random_shadow"))
        if sample_bool(p, g):
            gate[i] = True
            for k in range(3):
                pts_all[i, k, 0] = sample_uniform(0, W - 1, g)
                pts_all[i, k, 1] = sample_uniform(0, H - 1, g)

    if not bool(gate.any()):
        # Every item would be restored from the input by the gate below, so
        # the blur work is wasted. A copy keeps the "fresh output tensor"
        # behaviour of the regular path.
        return img.clone(), mask

    pts_d = pts_all.to(img.device)
    gate_d = gate.to(img.device)

    # Vectorized point-in-triangle barycentric, broadcast (B, H, W)
    yy, xx = _coords(H, W, img.device, img.dtype)
    p0x = pts_d[:, 0, 0].view(B, 1, 1)
    p0y = pts_d[:, 0, 1].view(B, 1, 1)
    p1x = pts_d[:, 1, 0].view(B, 1, 1)
    p1y = pts_d[:, 1, 1].view(B, 1, 1)
    p2x = pts_d[:, 2, 0].view(B, 1, 1)
    p2y = pts_d[:, 2, 1].view(B, 1, 1)
    denom = (p1y - p2y) * (p0x - p2x) + (p2x - p1x) * (p0y - p2y)   # (B, 1, 1)
    # Avoid div by zero for degenerate (e.g. non-gated, all-zero) triangles;
    # such items are excluded from `inside` and handled by the gate below.
    denom_safe = torch.where(denom.abs() < 1e-8, torch.ones_like(denom), denom)
    xx_b = xx.unsqueeze(0)
    yy_b = yy.unsqueeze(0)
    a = ((p1y - p2y) * (xx_b - p2x) + (p2x - p1x) * (yy_b - p2y)) / denom_safe
    b = ((p2y - p0y) * (xx_b - p2x) + (p0x - p2x) * (yy_b - p2y)) / denom_safe
    c = 1.0 - a - b
    inside = (a >= 0) & (b >= 0) & (c >= 0) & (denom.abs() >= 1e-8)
    mask_in = inside.to(img.dtype)                                  # (B, H, W)

    # 3-pass iterated box blur on the whole batch in one shot (GPU friendly).
    # `| 1` forces an odd kernel so that padding = k // 2 keeps the size.
    kbox = max(3, int(2 * softness) | 1)
    padbox = kbox // 2
    soft = mask_in.unsqueeze(1)                                     # (B, 1, H, W)
    for _ in range(3):
        soft = F.avg_pool2d(soft, kernel_size=kbox, stride=1, padding=padbox)
    shadow_factor = 1.0 - strength * soft.clamp(0, 1)               # (B, 1, H, W)

    img_dim = img * shadow_factor
    img_out = torch.where(gate_d.view(B, 1, 1, 1).expand_as(img), img_dim, img)
    return img_out, mask
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
# Registry of the primitives defined in this module, keyed by name. Every
# entry is a callable with the signature
#     fn(img, mask, spec, seed_base, epoch, step) -> (img_out, mask_out)
# Each key equals the tag string its function passes to per_item_seed().
GEOMETRIC_PRIMITIVES = {
    "tps": tps,
    "affine": affine,
    "perspective": perspective,
    "random_crop_pad": random_crop_pad,
    "elastic_transform": elastic_transform,
    "optical_distortion": optical_distortion,
    "random_shadow": random_shadow,
}
|