scribble-annotation-generator 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,570 @@
1
+ # ============================================================
2
+ # Scribble Annotation Generation
3
+ # - Autoregressive Count Model
4
+ # - Set-Transformer Object Generator
5
+ # - Hungarian Matching Loss
6
+ # - PyTorch Lightning
7
+ # ============================================================
8
+
9
+ import math
10
+ import os
11
+ from typing import Dict, Optional, Tuple
12
+
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import pytorch_lightning as pl
19
+
20
+ from scipy.optimize import linear_sum_assignment
21
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
22
+
23
+ from scribble_annotation_generator.dataset import ScribbleDataset
24
+ from scribble_annotation_generator.utils import (
25
+ generate_multiclass_scribble,
26
+ unpack_feature_vector,
27
+ )
28
+
29
+
30
+ # ============================================================
31
+ # Utility
32
+ # ============================================================
33
+
34
+
35
def masked_mean(x, mask, dim=1, eps=1e-6):
    """Mean of `x` along `dim`, counting only positions where `mask` is nonzero.

    x: (..., N, D) tensor; mask: (..., N) tensor (1 = include, 0 = ignore).
    `eps` guards against division by zero when a row is fully masked.
    """
    weights = mask.float().unsqueeze(-1)
    weighted_sum = (x * weights).sum(dim)
    denom = mask.float().sum(dim, keepdim=True) + eps
    return weighted_sum / denom
38
+
39
+
40
+ # ============================================================
41
+ # Autoregressive Count Model
42
+ # p(n1, n2, ..., nC)
43
+ # ============================================================
44
+
45
+
46
class CountModel(pl.LightningModule):
    """
    Autoregressive model of the joint distribution over per-class object
    counts: p(n_1, ..., n_C) = prod_c p(n_c | n_<c).

    Counts are capped at `max_count`; each conditional is a categorical
    distribution over {0, ..., max_count}.
    """

    def __init__(self, num_classes: int, hidden_dim=128, max_count=20):
        super().__init__()
        self.num_classes = num_classes
        self.max_count = max_count

        # Shared embedding for count values 0..max_count.
        self.embedding = nn.Embedding(max_count + 1, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, (max_count + 1) * num_classes),
        )

    def forward(self, counts):
        """
        counts: (B, C) integer tensor of per-class counts.

        Returns logits of shape (B, C, max_count + 1) where the logits for
        class c depend only on counts[:, :c].

        The previous implementation fed the *full* count vector (including
        position c itself) into the MLP, so during training the model could
        copy the target count while `sample` conditioned only on earlier
        classes — label leakage plus a train/inference mismatch. The causal
        mask below zeroes the embeddings of classes >= c when predicting
        class c, matching the ancestral sampling order.
        """
        B, C = counts.shape
        embeds = self.embedding(counts)  # (B, C, D)

        # causal[c, j] == 1 iff j < c: class c conditions only on earlier classes.
        causal = torch.tril(
            torch.ones(C, C, device=counts.device, dtype=embeds.dtype),
            diagonal=-1,
        )
        # For each target class c, zero out embeddings of classes >= c.
        masked = embeds.unsqueeze(1) * causal.unsqueeze(0).unsqueeze(-1)  # (B, C, C, D)

        flat = masked.reshape(B * C, -1)  # (B*C, C*D)
        out = self.mlp(flat).view(B, C, C, self.max_count + 1)

        # Keep, for each target class c, the head slice that predicts class c.
        idx = torch.arange(C, device=counts.device)
        logits = out[:, idx, idx]  # (B, C, max_count + 1)
        return logits

    def training_step(self, batch, batch_idx):
        """Sum of cross-entropies over every class conditional in the batch."""
        counts = batch["counts"]  # (B, C)
        logits = self(counts)
        loss = F.cross_entropy(
            logits.view(-1, self.max_count + 1),
            counts.view(-1),
        )
        self.log("count_loss", loss)
        return loss

    @torch.no_grad()
    def sample(self, batch_size=1):
        """Ancestral sampling: draw n_c from p(n_c | n_<c), class by class."""
        # Allocate on the module's device so indexing/multinomial stay on-device.
        counts = torch.zeros(
            batch_size, self.num_classes, dtype=torch.long, device=self.device
        )
        for c in range(self.num_classes):
            logits = self(counts)[:, c]
            probs = F.softmax(logits, dim=-1)
            counts[:, c] = torch.multinomial(probs, 1).squeeze(-1)
        return counts

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
96
+
97
+
98
+ # ============================================================
99
+ # Set Transformer Blocks
100
+ # ============================================================
101
+
102
+
103
+ class MAB(nn.Module):
104
+ def __init__(self, dim, num_heads):
105
+ super().__init__()
106
+ self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
107
+ self.ff = nn.Sequential(
108
+ nn.Linear(dim, dim),
109
+ nn.ReLU(),
110
+ nn.Linear(dim, dim),
111
+ )
112
+ self.ln1 = nn.LayerNorm(dim)
113
+ self.ln2 = nn.LayerNorm(dim)
114
+
115
+ def forward(self, Q, K, mask=None):
116
+ attn, _ = self.attn(Q, K, K, key_padding_mask=mask)
117
+ Q = self.ln1(Q + attn)
118
+ Q = self.ln2(Q + self.ff(Q))
119
+ return Q
120
+
121
+
122
class SAB(nn.Module):
    """Self-Attention Block: an MAB where the set attends to itself (Q = K = X)."""

    def __init__(self, dim, num_heads):
        super().__init__()
        self.mab = MAB(dim, num_heads)

    def forward(self, X, mask=None):
        # Feed the set as both query and key/value.
        return self.mab(X, X, mask)
129
+
130
+
131
class PMA(nn.Module):
    """Pooling by Multihead Attention: learned seed vectors cross-attend over
    the input set, pooling it to `num_seeds` output vectors."""

    def __init__(self, dim, num_heads, num_seeds):
        super().__init__()
        self.seed = nn.Parameter(torch.randn(1, num_seeds, dim))
        self.mab = MAB(dim, num_heads)

    def forward(self, X, mask=None):
        # Broadcast the learned seeds across the batch, then attend to X.
        batch = X.size(0)
        queries = self.seed.expand(batch, -1, -1)
        return self.mab(queries, X, mask)
141
+
142
+
143
class ObjectEncoder(nn.Module):
    """Stack of SAB layers encoding a (possibly padded) set of object embeddings."""

    def __init__(self, dim, num_heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(SAB(dim, num_heads) for _ in range(num_layers))

    def forward(self, X, mask=None):
        # Apply each self-attention block in sequence; mask marks padded slots.
        for block in self.layers:
            X = block(X, mask)
        return X
152
+
153
+
154
class MDNCrossAttentionDecoder(nn.Module):
    """
    Cross-attention decoder whose single query token is decoded against a set
    memory, producing Mixture Density Network (MDN) parameters: mixing logits,
    means, and log-variances of K diagonal Gaussians over object features.
    """

    def __init__(
        self,
        hidden_dim: int,
        obj_dim: int,
        num_components: int = 5,
        latent_dim: int = 8,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()

        self.K = num_components
        self.obj_dim = obj_dim
        self.latent_dim = latent_dim

        # Maps the stochastic latent into the query embedding space.
        self.latent_proj = nn.Linear(latent_dim, hidden_dim)

        layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)

        # One linear head per mixture parameter group.
        self.pi_head = nn.Linear(hidden_dim, self.K)
        self.mu_head = nn.Linear(hidden_dim, self.K * obj_dim)
        self.log_var_head = nn.Linear(hidden_dim, self.K * obj_dim)

    def forward(
        self,
        memory,       # (B, N, D) encoded object set
        memory_mask,  # (B, N), 1 = present, 0 = padded/masked
        query_embed,  # (B, D) class query embedding
        z=None,       # (B, latent_dim) optional latent; sampled if omitted
    ):
        """Return (pi_logits (B, K), mu (B, K, obj_dim), log_var (B, K, obj_dim))."""
        batch = memory.size(0)

        # Draw a fresh latent when the caller does not supply one.
        if z is None:
            z = torch.randn(batch, self.latent_dim, device=memory.device)

        tgt = (query_embed + self.latent_proj(z)).unsqueeze(1)  # (B, 1, D)

        # memory_mask uses 1 = present; key_padding_mask wants True = ignore.
        decoded = self.decoder(
            tgt=tgt,
            memory=memory,
            memory_key_padding_mask=(memory_mask == 0),
        )
        pooled = decoded[:, 0]  # (B, D): the single query token

        pi_logits = self.pi_head(pooled)  # (B, K)
        mu = self.mu_head(pooled).view(batch, self.K, self.obj_dim)
        log_var = self.log_var_head(pooled).view(batch, self.K, self.obj_dim)

        # Clamp for numerical stability of exp() downstream.
        log_var = log_var.clamp(-6.0, 3.0)

        return pi_logits, mu, log_var
222
+
223
+
224
+ # ============================================================
225
+ # Object Generator Model
226
+ # ============================================================
227
+
228
+
229
class ObjectGenerator(pl.LightningModule):
    """
    Masked-object autoregressive generator with Hungarian matching.

    Encodes the currently visible (unmasked) objects with a set encoder, then
    decodes MDN parameters (mixture of diagonal Gaussians) for one queried
    object class via cross-attention. Training scores candidate targets under
    the predicted mixture and back-propagates the NLL of the matched target.
    """

    def __init__(
        self,
        num_classes: int,
        obj_dim: int,
        hidden_dim=128,
        num_heads=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.obj_dim = obj_dim

        # Class identity is injected additively into each object embedding.
        self.class_embed = nn.Embedding(num_classes, hidden_dim)
        self.obj_embed = nn.Linear(obj_dim, hidden_dim)

        self.encoder = ObjectEncoder(
            dim=hidden_dim, num_heads=num_heads, num_layers=num_encoder_layers
        )

        self.decoder = MDNCrossAttentionDecoder(
            hidden_dim=hidden_dim,
            obj_dim=obj_dim,
            latent_dim=8,
            num_heads=num_heads,
            num_layers=num_decoder_layers,
        )

    # --------------------------------------------------------

    def forward(self, objs, classes, mask, query_class):
        """
        objs: (B, N, obj_dim)
        classes: (B, N)
        mask: (B, N) 1=present, 0=masked
        query_class: (B,)

        Returns MDN parameters (pi_logits, mu, log_var) for one new object of
        `query_class`, conditioned on the unmasked context objects.
        """

        obj_emb = self.obj_embed(objs)
        class_emb = self.class_embed(classes)
        x = obj_emb + class_emb

        # key_padding_mask semantics are True = ignore, hence (mask == 0).
        enc = self.encoder(x, mask=(mask == 0))
        # The query class reuses the same embedding table as the context objects.
        query_emb = self.class_embed(query_class)

        pi_logits, mu, log_var = self.decoder(
            memory=enc,
            memory_mask=mask,
            query_embed=query_emb,
        )

        return pi_logits, mu, log_var

    # --------------------------------------------------------

    def hungarian_mdn_loss(self, pi_logits, mu, log_var, targets):
        """
        pi_logits: (B, K)
        mu: (B, K, D)
        log_var: (B, K, D)
        targets: (B, T, D)

        For each batch element, scores every target under the element's
        predicted mixture and trains on the best-matching one.
        """

        total_loss = 0.0
        B = mu.size(0)

        for i in range(B):
            t = targets[i]  # (T, D)

            # NLL of each candidate target under this element's mixture.
            costs = []
            for obj in t:
                obj = obj.unsqueeze(0)
                cost = self.mdn_nll(
                    pi_logits[i : i + 1],
                    mu[i : i + 1],
                    log_var[i : i + 1],
                    obj,
                )
                costs.append(cost)

            # NOTE(review): the cost matrix is reshaped to (1, T), so the
            # assignment degenerates to an argmin over targets (one prediction
            # matched to its cheapest target) rather than a multi-prediction
            # Hungarian matching — confirm this is the intended behavior.
            cost = torch.stack(costs).detach().cpu().numpy()
            row, col = linear_sum_assignment(cost.reshape(1, -1))
            matched = t[col[0]].unsqueeze(0)

            # Recompute the NLL of the matched target with gradients attached
            # (the cost computation above was detached).
            total_loss += self.mdn_nll(
                pi_logits[i : i + 1],
                mu[i : i + 1],
                log_var[i : i + 1],
                matched,
            )

        return total_loss / B

    # --------------------------------------------------------

    def training_step(self, batch, batch_idx):
        """One optimization step: predict a queried object, match, and score."""
        objs = batch["objects"]  # (B, N, D)
        classes = batch["classes"]  # (B, N)
        mask = batch["mask"]  # (B, N)
        target_objs = batch["targets"]  # (B, K, D)
        query_class = batch["query_cls"]  # (B,)

        pi_logits, mu, log_var = self(objs, classes, mask, query_class)
        loss = self.hungarian_mdn_loss(pi_logits, mu, log_var, target_objs)

        self.log("obj_loss", loss)
        return loss

    # --------------------------------------------------------

    def validation_step(self, batch, batch_idx):
        """Same computation as training_step, logged as `val_loss`."""
        objs = batch["objects"]  # (B, N, D)
        classes = batch["classes"]  # (B, N)
        mask = batch["mask"]  # (B, N)
        target_objs = batch["targets"]  # (B, K, D)
        query_class = batch["query_cls"]  # (B,)

        pi_logits, mu, log_var = self(objs, classes, mask, query_class)
        loss = self.hungarian_mdn_loss(pi_logits, mu, log_var, target_objs)

        self.log("val_loss", loss, prog_bar=True)
        return loss

    # --------------------------------------------------------

    def mdn_nll(self, pi_logits, mu, log_var, target):
        """
        Negative log-likelihood of `target` under a mixture of K diagonal
        Gaussians, averaged over the batch.

        pi_logits: (B, K)
        mu: (B, K, D)
        log_var: (B, K, D)
        target: (B, D)
        """

        B, K, D = mu.shape

        target = target.unsqueeze(1)  # (B, 1, D)

        log_pi = F.log_softmax(pi_logits, dim=-1)  # (B, K)

        # Per-component diagonal-Gaussian log-density, summed over feature dims.
        log_prob = -0.5 * (
            log_var + (target - mu).pow(2) / log_var.exp() + math.log(2 * math.pi)
        ).sum(
            dim=-1
        )  # (B, K)

        # log p(x) = logsumexp_k [ log pi_k + log N_k(x) ]
        log_mix = torch.logsumexp(log_pi + log_prob, dim=-1)

        return -log_mix.mean()

    # --------------------------------------------------------

    @torch.no_grad()
    def sample_from_mdn(self, pi_logits, mu, log_var, temperature=1.0):
        """
        Returns one sampled object per batch element.

        A mixture component is drawn from the temperature-scaled mixing
        weights, then a Gaussian sample is drawn from that component with its
        std also scaled by `temperature`.
        """

        pi = F.softmax(pi_logits / temperature, dim=-1)  # (B, K)
        comp = torch.multinomial(pi, 1).squeeze(-1)  # (B,)

        B = mu.size(0)
        idx = torch.arange(B)

        sel_mu = mu[idx, comp]
        sel_std = log_var[idx, comp].exp().sqrt() * temperature

        # NOTE(review): noise in feature dims 2-3 is tied to dims 0-1,
        # presumably to correlate paired coordinates in the packed feature
        # vector; this assumes obj_dim >= 4 — TODO confirm against
        # unpack_feature_vector's layout.
        random_vector = torch.randn_like(sel_mu)
        random_vector[:, 2] = random_vector[:, 0]
        random_vector[:, 3] = random_vector[:, 1]

        return sel_mu + random_vector * sel_std

    # --------------------------------------------------------

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
411
+
412
+
413
+ # ============================================================
414
+ # Inference
415
+ # ============================================================
416
+
417
+
418
def _save_rgb(image: np.ndarray, path: str) -> None:
    """Write an RGB image to `path` (OpenCV expects BGR channel order)."""
    cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))


def generate_scribble(
    model: ObjectGenerator,
    objects: torch.Tensor,
    classes: torch.Tensor,
    mask: torch.Tensor,
    colour_map: Dict[Tuple[int, int, int], int],
    i: int,
    output_dir: str,
    image_shape: Tuple[int, int] = (512, 512),
) -> np.ndarray:
    """
    Autoregressively regenerate the masked-out objects of one sample and save
    three scribble images to `output_dir`: the original annotation
    (`{i}_original.png`), the annotation with the masked objects removed
    (`{i}_synthetic_removed.png`), and the regenerated annotation
    (`{i}_synthetic.png`).

    Note: `objects` and `mask` are modified in place as objects are sampled.

    Fixes vs. previous version: the function now actually returns the synthetic
    scribble (the `-> np.ndarray` annotation was previously unfulfilled), the
    leftover debug `print(objects)` is removed, the duplicated RGB->BGR save
    logic is factored into `_save_rgb`, and the hard-coded (512, 512) canvas is
    exposed as the backward-compatible `image_shape` parameter.
    """
    os.makedirs(output_dir, exist_ok=True)

    # First masked-out position (mask is 1=present, 0=masked); generation
    # resumes from here.
    index = mask.argmin().item()

    original = generate_multiclass_scribble(
        image_shape=image_shape,
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes.numpy(),
        colour_map=colour_map,
    )
    _save_rgb(original, os.path.join(output_dir, f"{i}_original.png"))

    # Visualize the context the model actually sees: classes at and after
    # `index` are zeroed out (class 0 renders as absent).
    classes_removed = classes.clone()
    classes_removed[index:] = 0
    synthetic = generate_multiclass_scribble(
        image_shape=image_shape,
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes_removed.numpy(),
        colour_map=colour_map,
    )
    _save_rgb(synthetic, os.path.join(output_dir, f"{i}_synthetic_removed.png"))

    # Sample one object at a time, feeding each sample back in as context for
    # the next prediction.
    while index < len(classes) and classes[index] != 0:
        with torch.no_grad():
            pi_logits, mu, log_var = model(
                objects[None],
                classes[None],
                mask[None],
                classes[index : index + 1],
            )
            sample = model.sample_from_mdn(pi_logits, mu, log_var, temperature=1.0)[0]

        objects[index] = sample
        mask[index] = 1
        index += 1

    synthetic = generate_multiclass_scribble(
        image_shape=image_shape,
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes.numpy(),
        colour_map=colour_map,
    )
    _save_rgb(synthetic, os.path.join(output_dir, f"{i}_synthetic.png"))

    return synthetic
489
+
490
+
491
+ # ============================================================
492
+ # Training Entry Point
493
+ # ============================================================
494
+
495
def train_and_infer(
    train_dir: str,
    val_dir: str,
    colour_map: Dict[Tuple[int, int, int], int],
    checkpoint_dir: str = "./local/nn-checkpoints",
    inference_dir: str = "./local/nn-inference",
    batch_size: int = 8,
    num_workers: int = 4,
    max_epochs: int = 50,
    num_classes: Optional[int] = None,
):
    """
    Train an ObjectGenerator on scribble data, then run inference over every
    validation sample, writing generated scribbles to `inference_dir`.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(inference_dir, exist_ok=True)

    # Derive the class count from the colour map unless given explicitly.
    if num_classes is None:
        resolved_num_classes = len(set(colour_map.values()))
    else:
        resolved_num_classes = num_classes

    train_dataset = ScribbleDataset(
        num_classes=resolved_num_classes, data_dir=train_dir, colour_map=colour_map
    )
    val_dataset = ScribbleDataset(
        num_classes=resolved_num_classes, data_dir=val_dir, colour_map=colour_map
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        prefetch_factor=2,
    )
    # Validation runs one sample at a time, in order.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        prefetch_factor=2,
    )

    # Object feature dimensionality comes straight from the data.
    obj_dim = train_dataset[0]["objects"].shape[1]

    obj_model = ObjectGenerator(
        num_classes=resolved_num_classes,
        obj_dim=obj_dim,
        hidden_dim=256,
        num_encoder_layers=4,
        num_decoder_layers=4,
    )

    callbacks = [
        EarlyStopping(monitor="val_loss", mode="min", verbose=True),
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            dirpath=checkpoint_dir,
            filename="best-checkpoint",
        ),
    ]

    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=callbacks)
    trainer.fit(obj_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Inference pass: regenerate each validation sample's masked objects.
    for sample_idx in range(len(val_dataset)):
        datum = val_dataset[sample_idx]
        generate_scribble(
            obj_model,
            objects=datum["objects"],
            classes=datum["classes"],
            mask=datum["mask"],
            colour_map=colour_map,
            i=sample_idx,
            output_dir=inference_dir,
        )