scribble-annotation-generator 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scribble_annotation_generator/__init__.py +0 -0
- scribble_annotation_generator/cli.py +195 -0
- scribble_annotation_generator/crop_field.py +366 -0
- scribble_annotation_generator/dataset.py +96 -0
- scribble_annotation_generator/debug.py +43 -0
- scribble_annotation_generator/nn.py +570 -0
- scribble_annotation_generator/utils.py +495 -0
- scribble_annotation_generator-0.0.1.dist-info/METADATA +108 -0
- scribble_annotation_generator-0.0.1.dist-info/RECORD +11 -0
- scribble_annotation_generator-0.0.1.dist-info/WHEEL +4 -0
- scribble_annotation_generator-0.0.1.dist-info/entry_points.txt +2 -0
scribble_annotation_generator/nn.py

@@ -0,0 +1,570 @@
# ============================================================
# Scribble Annotation Generation
# - Autoregressive Count Model
# - Set-Transformer Object Generator
# - Hungarian Matching Loss
# - PyTorch Lightning
# ============================================================

import math
import os
from typing import Dict, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from scipy.optimize import linear_sum_assignment
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from scribble_annotation_generator.dataset import ScribbleDataset
from scribble_annotation_generator.utils import (
    generate_multiclass_scribble,
    unpack_feature_vector,
)


# ============================================================
# Utility
# ============================================================


def masked_mean(x, mask, dim=1, eps=1e-6):
    mask = mask.float()
    return (x * mask.unsqueeze(-1)).sum(dim) / (mask.sum(dim, keepdim=True) + eps)
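

# --- Illustrative sketch, not part of the released package. ---
# masked_mean pools a padded set by averaging only its real elements; the
# shapes below are assumptions chosen just to show the call.
def _example_masked_mean():
    x = torch.randn(2, 4, 8)             # (B=2, N=4, D=8) padded sets
    mask = torch.tensor([[1, 1, 0, 0],
                         [1, 1, 1, 0]])  # 1 = real element, 0 = padding
    return masked_mean(x, mask)          # (2, 8): mean over real elements only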


# ============================================================
# Autoregressive Count Model
# p(n1, n2, ..., nC)
# ============================================================


class CountModel(pl.LightningModule):
    """
    Models joint distribution over object counts per class.
    Autoregressive over classes.
    """

    def __init__(self, num_classes: int, hidden_dim=128, max_count=20):
        super().__init__()
        self.num_classes = num_classes
        self.max_count = max_count

        self.embedding = nn.Embedding(max_count + 1, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, (max_count + 1) * num_classes),
        )

    def forward(self, counts):
        """
        counts: (B, C) integer tensor
        """
        B, C = counts.shape
        embeds = self.embedding(counts)  # (B, C, D)
        flat = embeds.view(B, -1)  # (B, C*D)
        logits = self.mlp(flat)  # (B, C*(K+1))
        logits = logits.view(B, C, self.max_count + 1)
        return logits

    def training_step(self, batch, batch_idx):
        counts = batch["counts"]  # (B, C)
        logits = self(counts)
        loss = F.cross_entropy(
            logits.view(-1, self.max_count + 1),
            counts.view(-1),
        )
        self.log("count_loss", loss)
        return loss

    def sample(self, batch_size=1):
        # Draw classes one at a time, feeding back the counts sampled so far.
        counts = torch.zeros(
            batch_size, self.num_classes, dtype=torch.long, device=self.device
        )
        for c in range(self.num_classes):
            logits = self(counts)[:, c]
            probs = F.softmax(logits, dim=-1)
            counts[:, c] = torch.multinomial(probs, 1).squeeze(-1)
        return counts

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
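

# --- Illustrative sketch, not part of the released package. ---
# Sampling per-class counts with CountModel: each class is drawn in turn,
# conditioned on the counts already sampled for earlier classes.  The sizes
# below are assumptions.
def _example_sample_counts():
    count_model = CountModel(num_classes=3, hidden_dim=64, max_count=10)
    count_model.eval()
    with torch.no_grad():
        counts = count_model.sample(batch_size=4)  # (4, 3) long tensor of counts
    return counts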


# ============================================================
# Set Transformer Blocks
# ============================================================


class MAB(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, Q, K, mask=None):
        attn, _ = self.attn(Q, K, K, key_padding_mask=mask)
        Q = self.ln1(Q + attn)
        Q = self.ln2(Q + self.ff(Q))
        return Q


class SAB(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.mab = MAB(dim, num_heads)

    def forward(self, X, mask=None):
        return self.mab(X, X, mask)


class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds):
        super().__init__()
        self.seed = nn.Parameter(torch.randn(1, num_seeds, dim))
        self.mab = MAB(dim, num_heads)

    def forward(self, X, mask=None):
        B = X.size(0)
        seed = self.seed.expand(B, -1, -1)
        return self.mab(seed, X, mask)


class ObjectEncoder(nn.Module):
    def __init__(self, dim, num_heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([SAB(dim, num_heads) for _ in range(num_layers)])

    def forward(self, X, mask=None):
        for layer in self.layers:
            X = layer(X, mask)
        return X
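

# --- Illustrative sketch, not part of the released package. ---
# The SAB stack is permutation-equivariant: every element attends to all
# non-padded elements and the input shape is preserved.  The True-means-ignore
# mask convention matches how ObjectGenerator calls the encoder; the sizes are
# assumptions.
def _example_encode_set():
    encoder = ObjectEncoder(dim=32, num_heads=4, num_layers=2)
    x = torch.randn(2, 5, 32)  # (B=2, N=5, D=32)
    pad = torch.tensor([[False, False, False, True, True],
                        [False, False, False, False, True]])  # True = padding slot
    return encoder(x, mask=pad)  # (2, 5, 32), same shape as the input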


class MDNCrossAttentionDecoder(nn.Module):
    """
    Transformer decoder with Mixture Density Network (MDN) output.
    """

    def __init__(
        self,
        hidden_dim: int,
        obj_dim: int,
        num_components: int = 5,
        latent_dim: int = 8,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()

        self.K = num_components
        self.obj_dim = obj_dim
        self.latent_dim = latent_dim

        self.latent_proj = nn.Linear(latent_dim, hidden_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers,
        )

        # Mixture heads
        self.pi_head = nn.Linear(hidden_dim, self.K)
        self.mu_head = nn.Linear(hidden_dim, self.K * obj_dim)
        self.log_var_head = nn.Linear(hidden_dim, self.K * obj_dim)

    def forward(
        self,
        memory,  # (B, N, D)
        memory_mask,  # (B, N)
        query_embed,  # (B, D)
        z=None,  # (B, latent_dim)
    ):
        B = memory.size(0)

        if z is None:
            z = torch.randn(B, self.latent_dim, device=memory.device)

        query = query_embed + self.latent_proj(z)
        query = query.unsqueeze(1)  # (B, 1, D)

        out = self.decoder(
            tgt=query,
            memory=memory,
            memory_key_padding_mask=(memory_mask == 0),
        )[:, 0]  # (B, D)

        pi_logits = self.pi_head(out)  # (B, K)
        mu = self.mu_head(out).view(B, self.K, self.obj_dim)
        log_var = self.log_var_head(out).view(B, self.K, self.obj_dim)

        log_var = log_var.clamp(-6.0, 3.0)

        return pi_logits, mu, log_var
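

# --- Illustrative sketch, not part of the released package. ---
# A single decoder call: one query token (class embedding plus a projected
# latent) cross-attends to the encoded set and yields mixture parameters for
# one new object.  The tensor sizes are assumptions.
def _example_decode_mixture():
    dec = MDNCrossAttentionDecoder(hidden_dim=32, obj_dim=6, num_components=4)
    memory = torch.randn(2, 5, 32)                 # encoded context objects
    memory_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 0]])  # 1 = real object, 0 = padding
    query_embed = torch.randn(2, 32)               # e.g. a class embedding
    pi_logits, mu, log_var = dec(memory, memory_mask, query_embed)
    return pi_logits, mu, log_var  # (2, 4), (2, 4, 6), (2, 4, 6)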


# ============================================================
# Object Generator Model
# ============================================================


class ObjectGenerator(pl.LightningModule):
    """
    Masked-object autoregressive generator with Hungarian matching.
    """

    def __init__(
        self,
        num_classes: int,
        obj_dim: int,
        hidden_dim=128,
        num_heads=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.obj_dim = obj_dim

        self.class_embed = nn.Embedding(num_classes, hidden_dim)
        self.obj_embed = nn.Linear(obj_dim, hidden_dim)

        self.encoder = ObjectEncoder(
            dim=hidden_dim, num_heads=num_heads, num_layers=num_encoder_layers
        )

        self.decoder = MDNCrossAttentionDecoder(
            hidden_dim=hidden_dim,
            obj_dim=obj_dim,
            latent_dim=8,
            num_heads=num_heads,
            num_layers=num_decoder_layers,
        )

    # --------------------------------------------------------

    def forward(self, objs, classes, mask, query_class):
        """
        objs: (B, N, obj_dim)
        classes: (B, N)
        mask: (B, N) 1=present, 0=masked
        query_class: (B,)
        """

        obj_emb = self.obj_embed(objs)
        class_emb = self.class_embed(classes)
        x = obj_emb + class_emb

        enc = self.encoder(x, mask=(mask == 0))
        query_emb = self.class_embed(query_class)

        pi_logits, mu, log_var = self.decoder(
            memory=enc,
            memory_mask=mask,
            query_embed=query_emb,
        )

        return pi_logits, mu, log_var

    # --------------------------------------------------------

    def hungarian_mdn_loss(self, pi_logits, mu, log_var, targets):
        """
        pi_logits: (B, K)
        mu: (B, K, D)
        log_var: (B, K, D)
        targets: (B, T, D)
        """

        total_loss = 0.0
        B = mu.size(0)

        for i in range(B):
            t = targets[i]  # (T, D)

            costs = []
            for obj in t:
                obj = obj.unsqueeze(0)
                cost = self.mdn_nll(
                    pi_logits[i : i + 1],
                    mu[i : i + 1],
                    log_var[i : i + 1],
                    obj,
                )
                costs.append(cost)

            cost = torch.stack(costs).detach().cpu().numpy()
            # Only one mixture is predicted per query, so the cost matrix is
            # 1 x T and the assignment reduces to picking the lowest-NLL target.
            row, col = linear_sum_assignment(cost.reshape(1, -1))
            matched = t[col[0]].unsqueeze(0)

            total_loss += self.mdn_nll(
                pi_logits[i : i + 1],
                mu[i : i + 1],
                log_var[i : i + 1],
                matched,
            )

        return total_loss / B

    # --------------------------------------------------------

    def training_step(self, batch, batch_idx):
        objs = batch["objects"]  # (B, N, D)
        classes = batch["classes"]  # (B, N)
        mask = batch["mask"]  # (B, N)
        target_objs = batch["targets"]  # (B, T, D)
        query_class = batch["query_cls"]  # (B,)

        pi_logits, mu, log_var = self(objs, classes, mask, query_class)
        loss = self.hungarian_mdn_loss(pi_logits, mu, log_var, target_objs)

        self.log("obj_loss", loss)
        return loss

    # --------------------------------------------------------

    def validation_step(self, batch, batch_idx):
        objs = batch["objects"]  # (B, N, D)
        classes = batch["classes"]  # (B, N)
        mask = batch["mask"]  # (B, N)
        target_objs = batch["targets"]  # (B, T, D)
        query_class = batch["query_cls"]  # (B,)

        pi_logits, mu, log_var = self(objs, classes, mask, query_class)
        loss = self.hungarian_mdn_loss(pi_logits, mu, log_var, target_objs)

        self.log("val_loss", loss, prog_bar=True)
        return loss

    # --------------------------------------------------------

    def mdn_nll(self, pi_logits, mu, log_var, target):
        """
        pi_logits: (B, K)
        mu: (B, K, D)
        log_var: (B, K, D)
        target: (B, D)
        """

        B, K, D = mu.shape

        target = target.unsqueeze(1)  # (B, 1, D)

        log_pi = F.log_softmax(pi_logits, dim=-1)  # (B, K)

        log_prob = -0.5 * (
            log_var + (target - mu).pow(2) / log_var.exp() + math.log(2 * math.pi)
        ).sum(dim=-1)  # (B, K)

        log_mix = torch.logsumexp(log_pi + log_prob, dim=-1)

        return -log_mix.mean()
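
    # --- Illustrative check, not part of the released package. ---
    # With a single component (K = 1, so log_pi = 0) the mixture NLL above
    # reduces to the diagonal-Gaussian negative log-likelihood
    #     0.5 * sum_d [ log_var_d + (t_d - mu_d)^2 / exp(log_var_d) + log(2*pi) ].
    # For example, mu = 0, log_var = 0 and target = (1, -1) gives
    #     0.5 * (1 + 1 + 2 * log(2*pi)) = 1 + log(2*pi) ≈ 2.838.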

    # --------------------------------------------------------

    @torch.no_grad()
    def sample_from_mdn(self, pi_logits, mu, log_var, temperature=1.0):
        """
        Returns one sampled object per batch element.
        """

        pi = F.softmax(pi_logits / temperature, dim=-1)  # (B, K)
        comp = torch.multinomial(pi, 1).squeeze(-1)  # (B,)

        B = mu.size(0)
        idx = torch.arange(B)

        sel_mu = mu[idx, comp]
        sel_std = log_var[idx, comp].exp().sqrt() * temperature

        random_vector = torch.randn_like(sel_mu)
        # Reuse the noise of dims 0 and 1 for dims 2 and 3, so those feature
        # pairs share the same perturbation.
        random_vector[:, 2] = random_vector[:, 0]
        random_vector[:, 3] = random_vector[:, 1]

        return sel_mu + random_vector * sel_std

    # --------------------------------------------------------

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
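

# --- Illustrative sketch, not part of the released package. ---
# One conditional generation step with ObjectGenerator: encode a partially
# observed set, ask for an object of a given class, and sample it from the
# predicted mixture.  All sizes and the feature layout are assumptions.
def _example_generate_one_object():
    model = ObjectGenerator(num_classes=4, obj_dim=6, hidden_dim=64)
    model.eval()

    objs = torch.randn(1, 5, 6)                # object feature vectors (padded set)
    classes = torch.tensor([[1, 2, 1, 3, 0]])  # per-slot class ids
    mask = torch.tensor([[1, 1, 1, 0, 0]])     # 1 = known object, 0 = to be generated
    query_class = torch.tensor([3])            # class of the object to generate

    with torch.no_grad():
        pi_logits, mu, log_var = model(objs, classes, mask, query_class)
        new_obj = model.sample_from_mdn(pi_logits, mu, log_var)  # (1, 6)
    return new_obj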


# ============================================================
# Inference
# ============================================================


def generate_scribble(
    model: ObjectGenerator,
    objects: torch.Tensor,
    classes: torch.Tensor,
    mask: torch.Tensor,
    colour_map: Dict[Tuple[int, int, int], int],
    i: int,
    output_dir: str,
) -> np.ndarray:
    os.makedirs(output_dir, exist_ok=True)
    index = mask.argmin().item()  # first masked slot (mask == 0)

    original = generate_multiclass_scribble(
        image_shape=(512, 512),
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes.numpy(),
        colour_map=colour_map,
    )

    # Save the original scribble
    output_path = os.path.join(output_dir, f"{i}_original.png")

    # Convert RGB to BGR for saving with OpenCV
    original_bgr = cv2.cvtColor(original, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(output_path), original_bgr)

    # Zero out the classes from the first masked slot onward before rendering
    # the context-only scribble.
    classes_removed = classes.clone()
    classes_removed[index:] = 0
    synthetic = generate_multiclass_scribble(
        image_shape=(512, 512),
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes_removed.numpy(),
        colour_map=colour_map,
    )

    # Save the synthetic scribble
    output_path = os.path.join(output_dir, f"{i}_synthetic_removed.png")

    # Convert RGB to BGR for saving with OpenCV
    synthetic_bgr = cv2.cvtColor(synthetic, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(output_path), synthetic_bgr)

    # Autoregressively generate an object for every masked slot that still has
    # a non-zero class label.
    while index < len(classes) and classes[index] != 0:
        with torch.no_grad():
            pi_logits, mu, log_var = model(
                objects[None],
                classes[None],
                mask[None],
                classes[index : index + 1],
            )
        sample = model.sample_from_mdn(pi_logits, mu, log_var, temperature=1.0)[0]

        objects[index] = sample
        mask[index] = 1
        index += 1

    print(objects)

    synthetic = generate_multiclass_scribble(
        image_shape=(512, 512),
        objects=[unpack_feature_vector(obj) for obj in objects.numpy()],
        classes=classes.numpy(),
        colour_map=colour_map,
    )

    # Save the synthetic scribble
    output_path = os.path.join(output_dir, f"{i}_synthetic.png")

    # Convert RGB to BGR for saving with OpenCV
    synthetic_bgr = cv2.cvtColor(synthetic, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(output_path), synthetic_bgr)

    # Return the fully generated scribble (matches the declared return type).
    return synthetic


# ============================================================
# Training Entry Point
# ============================================================

def train_and_infer(
    train_dir: str,
    val_dir: str,
    colour_map: Dict[Tuple[int, int, int], int],
    checkpoint_dir: str = "./local/nn-checkpoints",
    inference_dir: str = "./local/nn-inference",
    batch_size: int = 8,
    num_workers: int = 4,
    max_epochs: int = 50,
    num_classes: Optional[int] = None,
):
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(inference_dir, exist_ok=True)

    resolved_num_classes = (
        num_classes if num_classes is not None else len(set(colour_map.values()))
    )

    train_dataset = ScribbleDataset(
        num_classes=resolved_num_classes, data_dir=train_dir, colour_map=colour_map
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        prefetch_factor=2,
    )

    val_dataset = ScribbleDataset(
        num_classes=resolved_num_classes, data_dir=val_dir, colour_map=colour_map
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        prefetch_factor=2,
    )

    obj_dim = train_dataset[0]["objects"].shape[1]

    obj_model = ObjectGenerator(
        num_classes=resolved_num_classes,
        obj_dim=obj_dim,
        hidden_dim=256,
        num_encoder_layers=4,
        num_decoder_layers=4,
    )

    early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
    checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=checkpoint_dir,
        filename="best-checkpoint",
    )

    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[early_stop, checkpoint])
    trainer.fit(obj_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    for i in range(len(val_dataset)):
        datum = val_dataset[i]
        objects = datum["objects"]
        classes = datum["classes"]
        mask = datum["mask"]
        generate_scribble(
            obj_model,
            objects=objects,
            classes=classes,
            mask=mask,
            colour_map=colour_map,
            i=i,
            output_dir=inference_dir,
        )
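

# --- Illustrative sketch, not part of the released package. ---
# train_and_infer is the high-level entry point.  The directories and the
# colour -> class-id mapping below are placeholder assumptions; in practice
# they are presumably supplied via cli.py.
def _example_train_and_infer():
    example_colour_map = {
        (255, 0, 0): 1,  # e.g. red scribbles   -> class 1
        (0, 255, 0): 2,  # e.g. green scribbles -> class 2
    }
    train_and_infer(
        train_dir="./local/data/train",
        val_dir="./local/data/val",
        colour_map=example_colour_map,
        max_epochs=5,
    )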