paraug 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
paraug/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """paraug — Parity Augmentation library for image + mask.
2
+
3
+ Bit-exact CPU/GPU parity across 7 geometric + 24 photometric primitives.
4
+ Single entry point: build with `AugPipeline(config)`, then call with
5
+ `aug(img, mask=None, seed_base=..., epoch=..., step=...)` returning
6
+ `(img_out, mask_out)`. Per-item RNG is sampled on CPU regardless of tensor
7
+ device, so the same seed produces identical output on CPU and CUDA back-ends
8
+ within a tight tolerance (1e-6 elementwise, 2e-4 for grid_sample-class ops).
9
+
10
+ Quickstart:
11
+
12
+ from paraug import AugPipeline
13
+ aug = AugPipeline({"geometric": {"affine": {"p": 1.0, "rot_deg": 15}}})
14
+ img_out, mask_out = aug(img, mask=mask, seed_base=42)
15
+ """
16
+ from .pipeline import AugPipeline
17
+ from .utils import set_deterministic, per_item_seed, cpu_generator
18
+ from .geometric import GEOMETRIC_PRIMITIVES
19
+ from .photometric import PHOTOMETRIC_PRIMITIVES
20
+
21
# Package version string.
__version__ = "0.1.1"

# Public names of the package; this is what `from paraug import *` exposes.
__all__ = [
    "AugPipeline",
    "set_deterministic",
    "per_item_seed",
    "cpu_generator",
    "GEOMETRIC_PRIMITIVES",
    "PHOTOMETRIC_PRIMITIVES",
    "__version__",
]
paraug/geometric.py ADDED
@@ -0,0 +1,380 @@
1
+ """Geometric primitives — both image and mask are warped together.
2
+
3
+ Each primitive has signature:
4
+ fn(img, mask, spec, seed_base, epoch, step) -> (img_out, mask_out)
5
+
6
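+ Image uses bilinear sampling; mask uses nearest-neighbour sampling. Out-of-range samples are zero-filled (background). random_shadow is the exception: it only shades the image and returns the mask untouched.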
7
+ All primitives gate per-item: items with Bernoulli=False pass through bit-identical.
8
+ """
9
+ import math
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from .utils import per_item_seed, cpu_generator, sample_uniform, sample_bool
14
+
15
+
16
# Memo of pixel-coordinate meshgrids, keyed by (H, W, str(device), str(dtype)).
_GRID_CACHE = {}


def _coords(H: int, W: int, device, dtype):
    """Return the memoised ``(yy, xx)`` pixel-coordinate meshgrid.

    Both tensors have shape ``(H, W)``: ``yy[r, c] == r`` and ``xx[r, c] == c``.
    The pair is built once per ``(H, W, device, dtype)`` combination and the
    very same tensor objects are handed back on later calls, so callers must
    not modify them in place.
    """
    cache_key = (H, W, str(device), str(dtype))
    cached = _GRID_CACHE.get(cache_key)
    if cached is None:
        row_idx = torch.arange(H, device=device, dtype=dtype)
        col_idx = torch.arange(W, device=device, dtype=dtype)
        yy, xx = torch.meshgrid(row_idx, col_idx, indexing="ij")
        cached = (yy, xx)
        _GRID_CACHE[cache_key] = cached
    return cached
30
+
31
+
32
def _apply_grid(img, mask, grid, gate_d):
    """Warp img + mask via grid; where gate is False, leave the original.

    img:    (B, C, H, W) floating tensor, sampled bilinearly.
    mask:   (B, Cm, H, W) tensor or None, sampled with nearest-neighbour.
            May be floating, integer or bool; the input dtype is preserved.
    grid:   (B, H, W, 2) in normalised [-1, 1] coords (align_corners=True).
    gate_d: (B,) bool, on img.device.

    Returns (img_out, mask_out); mask_out is None when mask is None.
    Out-of-range samples are zero-filled.
    """
    B = img.shape[0]
    keep = gate_d.view(B, 1, 1, 1)  # broadcasts over channels and pixels

    img_w = F.grid_sample(img, grid, mode="bilinear",
                          padding_mode="zeros", align_corners=True)
    img_out = torch.where(keep, img_w, img)
    if mask is None:
        return img_out, None

    if mask.is_floating_point():
        # grid_sample needs input and grid in the same dtype; match the grid
        # to the mask so the mask values are never rounded.
        mask_w = F.grid_sample(mask, grid.to(mask.dtype), mode="nearest",
                               padding_mode="zeros", align_corners=True)
    else:
        # Integer / bool label masks are not accepted by grid_sample. Nearest
        # sampling only copies values (or writes zero padding), so a round
        # trip through the grid's float dtype is exact for label ids.
        mask_w = F.grid_sample(mask.to(grid.dtype), grid, mode="nearest",
                               padding_mode="zeros",
                               align_corners=True).to(mask.dtype)
    mask_out = torch.where(keep, mask_w, mask)
    return img_out, mask_out
49
+
50
+
51
+ # ─── 1: tps ──────────────────────────────────────────────────────────
52
def _sample_tps_disp(B, H, W, p, max_disp, nc, seed_base, epoch, step, dtype):
    """Draw per-item control-point offsets and upsample them to (H, W).

    Each item gets its own generator seeded from
    (seed_base, epoch, step, item index, "tps"). Per item the draw order is:
    gate Bernoulli(p), then the (nc, nc) y-offsets, then the (nc, nc)
    x-offsets, both uniform in [-max_disp, max_disp]. Items whose gate is
    False keep all-zero control points and therefore a zero displacement.

    Returns tensors created without an explicit device (the caller moves them
    with .to(device) as needed):
        gate   (B,) bool
        disp_y (B, H, W)
        disp_x (B, H, W)

    Kept separate from tps() so other code can reproduce exactly the same
    displacement field by passing the same seed arguments.
    """
    gate = torch.zeros(B, dtype=torch.bool)
    ctrl_y = torch.zeros(B, 1, nc, nc, dtype=dtype)
    ctrl_x = torch.zeros(B, 1, nc, nc, dtype=dtype)

    for item in range(B):
        gen = cpu_generator(per_item_seed(seed_base, epoch, step, item, "tps"))
        if not sample_bool(p, gen):
            continue
        gate[item] = True
        # Draw order (y before x) is part of the seeded contract.
        ctrl_y[item, 0] = torch.empty(nc, nc).uniform_(-max_disp, max_disp, generator=gen)
        ctrl_x[item, 0] = torch.empty(nc, nc).uniform_(-max_disp, max_disp, generator=gen)

    def _upsample(ctrl):
        # (B, 1, nc, nc) -> (B, H, W)
        full = F.interpolate(ctrl, size=(H, W), mode="bilinear", align_corners=False)
        return full[:, 0]

    return gate, _upsample(ctrl_y), _upsample(ctrl_x)
76
+
77
+
78
def tps(img, mask, spec, seed_base, epoch, step):
    """Thin-plate-spline-ish warp via low-res control grid → bilinear upsample.

    spec = {"p": prob, "max_disp": pixels, "n_ctrl": int}

    img is (B, C, H, W); the displacement field comes from _sample_tps_disp
    and is subtracted from the pixel coordinates to obtain the sampling
    location (back-warp). Returns (img_out, mask_out) from _apply_grid.
    """
    B, _, H, W = img.shape
    p = float(spec.get("p", 1.0))
    max_disp = float(spec.get("max_disp", 18.0))
    nc = int(spec.get("n_ctrl", 5))

    gate, disp_y, disp_x = _sample_tps_disp(
        B, H, W, p, max_disp, nc, seed_base, epoch, step, img.dtype)
    disp_y_d = disp_y.to(img.device)
    disp_x_d = disp_x.to(img.device)
    gate_d = gate.to(img.device)

    yy, xx = _coords(H, W, img.device, img.dtype)
    src_x = xx[None] - disp_x_d
    src_y = yy[None] - disp_y_d

    # Pixel -> normalised [-1, 1] (align_corners=True convention). A 1-pixel
    # axis has extent 0; clamp the divisor so the grid stays finite instead
    # of turning into inf/NaN.
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)
    grid = torch.stack([2.0 * src_x / x_span - 1.0,
                        2.0 * src_y / y_span - 1.0], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
98
+
99
+
100
+ # ─── 2: affine ───────────────────────────────────────────────────────
101
def affine(img, mask, spec, seed_base, epoch, step):
    """Random rotation / scale / translation.

    spec = {"p": prob, "rot_deg": float, "scale_range": (lo, hi),
            "translate_frac": float}

    Per item (own generator, tag "affine") the draw order is: gate, angle in
    [-rot_deg, rot_deg] degrees, scale in [lo, hi], tx, ty in
    [-translate_frac, translate_frac]. Non-gated items keep the identity
    matrix and are passed through unchanged by _apply_grid.

    NOTE(review): tx / ty are written straight into the normalised affine
    matrix, whose coordinates span [-1, 1], so a value f shifts the sampling
    location by f/2 of the image extent. perspective() multiplies its
    fractions by 2 for this reason; confirm which convention is intended.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    rot_deg = float(spec.get("rot_deg", 30.0))
    scale_lo, scale_hi = spec.get("scale_range", (0.85, 1.15))
    trans_frac = float(spec.get("translate_frac", 0.05))

    gate = torch.zeros(B, dtype=torch.bool)
    # One 2x3 identity per item; gated items overwrite their slot below.
    theta = torch.eye(2, 3).to(img.dtype).expand(B, 2, 3).clone()

    for item in range(B):
        gen = cpu_generator(per_item_seed(seed_base, epoch, step, item, "affine"))
        if not sample_bool(p, gen):
            continue
        gate[item] = True
        angle = math.radians(sample_uniform(-rot_deg, rot_deg, gen))
        zoom = sample_uniform(scale_lo, scale_hi, gen)
        shift_x = sample_uniform(-trans_frac, trans_frac, gen)
        shift_y = sample_uniform(-trans_frac, trans_frac, gen)
        # Dividing by the zoom makes zoom > 1 magnify (back-warp matrix).
        cos_s = math.cos(angle) / zoom
        sin_s = math.sin(angle) / zoom
        theta[item] = torch.tensor([[cos_s, sin_s, shift_x],
                                    [-sin_s, cos_s, shift_y]])

    grid = F.affine_grid(theta.to(img.device), [B, C, H, W], align_corners=True)
    return _apply_grid(img, mask, grid, gate.to(img.device))
131
+
132
+
133
+ # ─── 3: perspective ──────────────────────────────────────────────────
134
def perspective(img, mask, spec, seed_base, epoch, step):
    """4-point perspective via a homography built from corner jitters.

    spec = {"p": prob, "max_disp_frac": float (fraction of H/W)}
    Each corner of the image is jittered independently by up to ±max_disp_frac.

    Per item (own generator, tag "perspective") the draw order is: gate, then
    for each of the four corners its x jitter followed by its y jitter. The
    homography is solved on CPU and applied to the normalised output grid to
    get the sampling locations; the finished grid is then moved to img.device.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    max_disp = float(spec.get("max_disp_frac", 0.05))

    # Reference corners in normalised coords (-1, +1)
    src = torch.tensor([[-1.0, -1.0], [1.0, -1.0],
                        [1.0, 1.0], [-1.0, 1.0]], dtype=img.dtype)
    gate = torch.zeros(B, dtype=torch.bool)
    dst_all = src.unsqueeze(0).repeat(B, 1, 1)
    gated_items = []
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "perspective"))
        if sample_bool(p, g):
            gate[i] = True
            gated_items.append(i)
            dst = src.clone()
            for c in range(4):
                dst[c, 0] += sample_uniform(-max_disp, max_disp, g) * 2  # *2 because normalised
                dst[c, 1] += sample_uniform(-max_disp, max_disp, g) * 2
            dst_all[i] = dst

    # Normalised identity grid. A 1-pixel axis has extent 0, so clamp the
    # divisor to keep the coordinates finite.
    yy, xx = _coords(H, W, "cpu", img.dtype)
    base_grid = torch.stack([2.0 * xx / max(W - 1, 1) - 1.0,
                             2.0 * yy / max(H - 1, 1) - 1.0], dim=-1)  # (H, W, 2)

    # Homogeneous grid points are identical for every item: build them once.
    flat = base_grid.reshape(-1, 2)                                   # (HW, 2)
    ones = torch.ones(flat.shape[0], 1, dtype=img.dtype)
    homog = torch.cat([flat, ones], dim=1)                            # (HW, 3)

    # Non-gated items are discarded by the gate in _apply_grid, so they keep
    # the identity grid and skip the linear solve entirely.
    grids = base_grid.unsqueeze(0).repeat(B, 1, 1, 1)
    for i in gated_items:
        Hmat = _solve_homography(src, dst_all[i])
        warped = homog @ Hmat.T
        warped = warped[:, :2] / warped[:, 2:].clamp(min=1e-8)
        grids[i] = warped.reshape(H, W, 2)

    grid = grids.to(img.device)
    gate_d = gate.to(img.device)
    return _apply_grid(img, mask, grid, gate_d)
176
+
177
+
178
def _solve_homography(src, dst):
    """Compute the 3×3 homography mapping src corners to dst corners.

    src, dst: (4, 2) corner coordinates (normalised coords in this module).
    Returns a (3, 3) tensor in src.dtype with the bottom-right entry fixed
    to 1.

    The 8×8 system is solved in src.dtype when that is float32 or float64.
    Any other dtype (e.g. float16 / bfloat16, which torch.linalg.solve does
    not support) is solved in float32 and the result cast back.
    Raises whatever torch.linalg.solve raises for a singular system.
    """
    out_dtype = src.dtype
    work_dtype = out_dtype if out_dtype in (torch.float32, torch.float64) else torch.float32
    s = src.to(work_dtype)
    d = dst.to(work_dtype)

    x, y = s[:, 0], s[:, 1]
    u, v = d[:, 0], d[:, 1]
    one = torch.ones_like(x)
    zero = torch.zeros_like(x)

    # Two equations per correspondence:
    #   x*h0 + y*h1 + h2 - u*x*h6 - u*y*h7 = u
    #   x*h3 + y*h4 + h5 - v*x*h6 - v*y*h7 = v
    rows_u = torch.stack([x, y, one, zero, zero, zero, -u * x, -u * y], dim=1)  # (4, 8)
    rows_v = torch.stack([zero, zero, zero, x, y, one, -v * x, -v * y], dim=1)  # (4, 8)
    # Interleave so row 2k is the u-equation and row 2k+1 the v-equation.
    A = torch.stack([rows_u, rows_v], dim=1).reshape(8, 8)
    b = d.reshape(8)  # [u0, v0, u1, v1, ...] matches the row interleaving

    h = torch.linalg.solve(A, b)
    H = torch.cat([h, h.new_ones(1)]).reshape(3, 3)
    return H.to(out_dtype)
195
+
196
+
197
+ # ─── 4: random_crop_pad ──────────────────────────────────────────────
198
def random_crop_pad(img, mask, spec, seed_base, epoch, step):
    """Random crop window, resampled back to the original (H, W).

    spec = {"p": prob, "min_scale": float}

    Per item (own generator, tag "random_crop_pad") the draw order is: gate,
    scale s in [min_scale, 1], then tx and ty in [-(1 - s), 1 - s]. The
    sampling matrix is [[s, 0, tx], [0, s, ty]] in normalised coords, so the
    window covers a fraction s of each side (s² of the area) and the
    translation bound keeps it inside the image. With the window fully
    inside, the zero padding of the sampler is not reached for s <= 1.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    min_scale = float(spec.get("min_scale", 0.7))

    gate = torch.zeros(B, dtype=torch.bool)
    # Identity sampling matrix for every item; gated items replace theirs.
    theta = torch.eye(2, 3).to(img.dtype).expand(B, 2, 3).clone()

    for item in range(B):
        gen = cpu_generator(per_item_seed(seed_base, epoch, step, item, "random_crop_pad"))
        if not sample_bool(p, gen):
            continue
        gate[item] = True
        zoom = sample_uniform(min_scale, 1.0, gen)
        slack = 1 - zoom  # room left for the window to slide, per axis
        off_x = sample_uniform(-slack, slack, gen)
        off_y = sample_uniform(-slack, slack, gen)
        theta[item] = torch.tensor([[zoom, 0, off_x], [0, zoom, off_y]])

    grid = F.affine_grid(theta.to(img.device), [B, C, H, W], align_corners=True)
    return _apply_grid(img, mask, grid, gate.to(img.device))
223
+
224
+
225
+ # ─── 5: elastic_transform ────────────────────────────────────────────
226
def elastic_transform(img, mask, spec, seed_base, epoch, step):
    """Elastic displacement-field warp.

    spec = {"p": prob, "alpha": pixels (max disp)}

    Implementation note (vs albumentations' ElasticTransform):
      The displacement field is bilinear-upsampled from a low-resolution
      random grid of shape (max(8, H // 8), max(8, W // 8)) with values
      uniform in [-1, 1], then scaled by alpha. This differs from
      albumentations' dense Gaussian-smoothed noise; no Gaussian smoothing is
      applied here and there is no `sigma` knob — smoothness comes only from
      the upsampling factor.

    Per item (own generator, tag "elastic_transform") the draw order is:
    gate, low-res y field, low-res x field. Non-gated items keep a zero
    displacement and are passed through unchanged by _apply_grid.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    alpha = float(spec.get("alpha", 8.0))

    gate = torch.zeros(B, dtype=torch.bool)
    disp_y = torch.zeros(B, H, W, dtype=img.dtype)
    disp_x = torch.zeros(B, H, W, dtype=img.dtype)
    # Resolution of the random control field (never below 8 per axis).
    hl = max(8, int(H / 8))
    wl = max(8, int(W / 8))
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "elastic_transform"))
        if sample_bool(p, g):
            gate[i] = True
            dy_low = torch.empty(1, 1, hl, wl).uniform_(-1.0, 1.0, generator=g)
            dx_low = torch.empty(1, 1, hl, wl).uniform_(-1.0, 1.0, generator=g)
            # Bilinear upsample to full resolution, then scale to pixels.
            disp_y[i] = F.interpolate(dy_low, size=(H, W), mode="bilinear",
                                      align_corners=False)[0, 0] * alpha
            disp_x[i] = F.interpolate(dx_low, size=(H, W), mode="bilinear",
                                      align_corners=False)[0, 0] * alpha

    disp_y_d = disp_y.to(img.device)
    disp_x_d = disp_x.to(img.device)
    gate_d = gate.to(img.device)
    yy, xx = _coords(H, W, img.device, img.dtype)
    src_x = xx[None] - disp_x_d
    src_y = yy[None] - disp_y_d

    # Pixel -> normalised [-1, 1] (align_corners=True convention). A 1-pixel
    # axis has extent 0; clamp the divisor so the grid stays finite.
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)
    grid = torch.stack([2.0 * src_x / x_span - 1.0,
                        2.0 * src_y / y_span - 1.0], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
268
+
269
+
270
+ # ─── 6: optical_distortion ───────────────────────────────────────────
271
def optical_distortion(img, mask, spec, seed_base, epoch, step):
    """Barrel (k>0) / pincushion (k<0) lens distortion.

    spec = {"p": prob, "k": float coefficient (max abs value)}
    Maps (x, y) → (x', y') = (x·(1 + k·r²), y·(1 + k·r²)), r = normalised radius.
    Per-item k sampled uniformly from [-k_max, k_max].

    Implementation note:
      This uses a first-order **forward** approximation `scale = 1 + k·r²` for
      the back-warp lookup (computing src = nx · scale using output r). For
      moderate k ∈ [-0.5, 0.5] the visual distortion is geometrically correct.
      For larger |k| the strictly-correct inverse requires Newton iteration on
      a cubic; not implemented since the small-k regime covers typical
      photographic distortion.

    Per item (own generator, tag "optical_distortion") the draw order is:
    gate, then k. Non-gated items keep k = 0 and are passed through unchanged
    by _apply_grid.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    kmax = float(spec.get("k", 0.3))

    gate = torch.zeros(B, dtype=torch.bool)
    ks = torch.zeros(B, dtype=img.dtype)
    for i in range(B):
        g = cpu_generator(per_item_seed(seed_base, epoch, step, i, "optical_distortion"))
        if sample_bool(p, g):
            gate[i] = True
            ks[i] = sample_uniform(-kmax, kmax, g)
    ks_d = ks.to(img.device)
    gate_d = gate.to(img.device)

    yy, xx = _coords(H, W, img.device, img.dtype)
    # Normalised coords in [-1, 1]. A 1-pixel axis would divide 0 by 0 (NaN);
    # treat it as centred at 0 so it adds nothing to the radius and still
    # samples its single pixel.
    nx = 2.0 * xx / (W - 1) - 1.0 if W > 1 else torch.zeros_like(xx)
    ny = 2.0 * yy / (H - 1) - 1.0 if H > 1 else torch.zeros_like(yy)
    r2 = nx ** 2 + ny ** 2                                    # (H, W)
    # Per-item radial scale: 1 + k_i * r²
    scale = 1.0 + ks_d.view(B, 1, 1) * r2[None]
    src_nx = nx[None] * scale   # back-warp coords in normalised space
    src_ny = ny[None] * scale
    grid = torch.stack([src_nx, src_ny], dim=-1)
    return _apply_grid(img, mask, grid, gate_d)
309
+
310
+
311
+ # ─── 7: random_shadow (vectorized) ────────────────────────────────────
312
def random_shadow(img, mask, spec, seed_base, epoch, step):
    """Darken the image inside a random, soft-edged triangle.

    spec = {"p": prob, "strength": float in [0,1], "softness_px": float}

    Per item (own generator, tag "random_shadow") the draw order is: gate,
    then for each of the three vertices its x in [0, W-1] followed by its y
    in [0, H-1]. The triangle membership test and the blur run on the whole
    batch at once; only the RNG draws are looped per item. The hard triangle
    mask is softened with three passes of a box blur whose odd kernel size is
    max(3, int(2 * softness_px) | 1), and the image is multiplied by
    1 - strength * soft_mask. The mask argument is returned untouched.
    """
    B, C, H, W = img.shape
    p = float(spec.get("p", 1.0))
    strength = float(spec.get("strength", 0.4))
    softness = float(spec.get("softness_px", 30.0))

    # --- per-item random draws ---------------------------------------
    gate = torch.zeros(B, dtype=torch.bool)
    verts = torch.zeros(B, 3, 2, dtype=img.dtype)
    for item in range(B):
        gen = cpu_generator(per_item_seed(seed_base, epoch, step, item, "random_shadow"))
        if not sample_bool(p, gen):
            continue
        gate[item] = True
        for corner in range(3):
            verts[item, corner, 0] = sample_uniform(0, W - 1, gen)
            verts[item, corner, 1] = sample_uniform(0, H - 1, gen)
    verts_d = verts.to(img.device)
    gate_d = gate.to(img.device)

    # --- barycentric point-in-triangle test, broadcast to (B, H, W) ---
    yy, xx = _coords(H, W, img.device, img.dtype)
    col = xx.unsqueeze(0)
    row = yy.unsqueeze(0)

    def _vertex(k):
        # (x, y) of vertex k for every item, shaped (B, 1, 1) for broadcasting.
        return verts_d[:, k, 0].view(B, 1, 1), verts_d[:, k, 1].view(B, 1, 1)

    (x0, y0), (x1, y1), (x2, y2) = _vertex(0), _vertex(1), _vertex(2)
    det = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2)          # (B, 1, 1)
    # Degenerate triangles (including the all-zero vertices of non-gated
    # items) get a dummy divisor; they are excluded again in `inside`.
    degenerate = det.abs() < 1e-8
    det_safe = torch.where(degenerate, torch.ones_like(det), det)
    w0 = ((y1 - y2) * (col - x2) + (x2 - x1) * (row - y2)) / det_safe
    w1 = ((y2 - y0) * (col - x2) + (x0 - x2) * (row - y2)) / det_safe
    w2 = 1.0 - w0 - w1
    inside = (w0 >= 0) & (w1 >= 0) & (w2 >= 0) & (det.abs() >= 1e-8)

    # --- soften the hard mask: 3 box-blur passes over the batch -------
    soft = inside.to(img.dtype).unsqueeze(1)                     # (B, 1, H, W)
    kbox = max(3, int(2 * softness) | 1)
    for _ in range(3):
        soft = F.avg_pool2d(soft, kernel_size=kbox, stride=1, padding=kbox // 2)

    img_dim = img * (1.0 - strength * soft.clamp(0, 1))
    img_out = torch.where(gate_d.view(B, 1, 1, 1).expand_as(img), img_dim, img)
    return img_out, mask
370
+
371
+
372
# Registry mapping each primitive's name to its implementation. Every entry
# shares the signature
#   fn(img, mask, spec, seed_base, epoch, step) -> (img_out, mask_out)
# random_shadow is listed here although it only modifies the image.
GEOMETRIC_PRIMITIVES = {
    "tps": tps,
    "affine": affine,
    "perspective": perspective,
    "random_crop_pad": random_crop_pad,
    "elastic_transform": elastic_transform,
    "optical_distortion": optical_distortion,
    "random_shadow": random_shadow,
}