rgb_to_segmentation-0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rgb_to_segmentation/__init__.py +10 -0
- rgb_to_segmentation/api.py +123 -0
- rgb_to_segmentation/clean.py +315 -0
- rgb_to_segmentation/cli.py +180 -0
- rgb_to_segmentation/models/__init__.py +5 -0
- rgb_to_segmentation/models/base_classifier.py +77 -0
- rgb_to_segmentation/models/cnn_decoder.py +103 -0
- rgb_to_segmentation/models/pixelwise_classifier.py +31 -0
- rgb_to_segmentation/nn.py +164 -0
- rgb_to_segmentation/train.py +212 -0
- rgb_to_segmentation/utils.py +40 -0
- rgb_to_segmentation-0.1.6.dist-info/METADATA +213 -0
- rgb_to_segmentation-0.1.6.dist-info/RECORD +15 -0
- rgb_to_segmentation-0.1.6.dist-info/WHEEL +4 -0
- rgb_to_segmentation-0.1.6.dist-info/entry_points.txt +3 -0

+++ rgb_to_segmentation/models/base_classifier.py
@@ -0,0 +1,77 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+
+class PixelClassifier(pl.LightningModule):
+    """
+    Base class for pixel-wise classifiers.
+    Provides common interface for different model architectures.
+    """
+
+    def __init__(self, output_dim):
+        super().__init__()
+        self.output_dim = output_dim
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass - to be implemented by subclasses."""
+        raise NotImplementedError
+
+    def image_to_batch(self, x: torch.Tensor):
+        """Convert image tensor to batch for processing.
+
+        Args:
+            x: Input tensor of shape (C, H, W) or (B, C, H, W)
+
+        Returns:
+            Batch tensor ready for model input
+        """
+        raise NotImplementedError
+
+    def _align_logits_and_target(
+        self, logits: torch.Tensor, target: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Reshape predictions/targets so CrossEntropyLoss receives the expected shapes."""
+
+        # Handle channel-first 4D logits from CNN decoders: (B, C, H, W)
+        if logits.dim() == 4:
+            # Targets may come in as (B, 1, H, W) – squeeze the class channel
+            if target.dim() == 4 and target.size(1) == 1:
+                target = target.squeeze(1)
+            return logits, target
+
+        # Everything else (e.g. flattened pixel classifiers): collapse batch/pixel dims
+        logits = logits.view(-1, logits.size(-1))
+        target = target.view(-1)
+        return logits, target
+
+    def training_step(self, batch, batch_idx):
+        """Training step."""
+        sample, target = batch
+        # Apply model-specific batching on GPU
+        sample = self.image_to_batch(sample)
+        target = self.image_to_batch(target).long()
+
+        logits = self(sample)
+        logits, target = self._align_logits_and_target(logits, target)
+        loss = self.loss_fn(logits, target)
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """Validation step."""
+        sample, target = batch
+        # Apply model-specific batching on GPU
+        sample = self.image_to_batch(sample)
+        target = self.image_to_batch(target).long()
+
+        logits = self(sample)
+        logits, target = self._align_logits_and_target(logits, target)
+        loss = self.loss_fn(logits, target)
+        self.log("val_loss", loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        """Configure optimizer."""
+        return torch.optim.Adam(self.parameters())
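
For orientation, a minimal sketch (not part of the wheel; it assumes only `torch` is installed) of the two shape paths `_align_logits_and_target` normalises before handing tensors to `nn.CrossEntropyLoss`:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# CNN path: 4D logits (B, C, H, W); a (B, 1, H, W) target is squeezed to (B, H, W).
logits_4d = torch.randn(2, 5, 8, 8)                # B=2, 5 classes, 8x8 pixels
target_4d = torch.randint(0, 5, (2, 1, 8, 8))      # class indices with a stray channel dim
print(loss_fn(logits_4d, target_4d.squeeze(1)))    # CrossEntropyLoss accepts (B, C, H, W) vs (B, H, W)

# Flattened path: (N, C) logits from a per-pixel classifier; both sides are collapsed.
logits_2d = torch.randn(64, 5)
target_2d = torch.randint(0, 5, (64, 1))
print(loss_fn(logits_2d.view(-1, 5), target_2d.view(-1)))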

+++ rgb_to_segmentation/models/cnn_decoder.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_classifier import PixelClassifier
+
+
+class CNNDecoder(PixelClassifier):
+    """
+    CNN-based pixel classifier for image segmentation.
+
+    Uses a convolutional encoder-decoder architecture to process images
+    and produce per-pixel classifications.
+    """
+
+    def __init__(self, input_channels=3, hidden_dim=64, output_dim=2):
+        """
+        Initialize CNN decoder.
+
+        Args:
+            input_channels: Number of input channels (e.g., 3 for RGB)
+            hidden_dim: Number of hidden channels in conv layers
+            output_dim: Number of output classes
+        """
+        super().__init__(output_dim=output_dim)
+        self.save_hyperparameters()
+
+        self.input_channels = input_channels
+        self.hidden_dim = hidden_dim
+
+        # Encoder: downsampling with convolutions
+        self.encoder = nn.Sequential(
+            nn.Conv2d(input_channels, hidden_dim, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(
+                hidden_dim * 2, hidden_dim * 4, kernel_size=3, stride=2, padding=1
+            ),
+            nn.ReLU(),
+        )
+
+        # Decoder: upsampling with transpose convolutions
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(
+                hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1
+            ),
+            nn.ReLU(),
+            nn.ConvTranspose2d(
+                hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1
+            ),
+            nn.ReLU(),
+            nn.Conv2d(hidden_dim, output_dim, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor):
+        """
+        Forward pass.
+
+        Args:
+            x: Input tensor of shape (B, C, H, W)
+
+        Returns:
+            Logits of shape (B, num_classes, H, W)
+        """
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
+
+    def image_to_batch(self, x: torch.Tensor):
+        """
+        Convert image tensor to batch format for CNN processing.
+
+        Args:
+            x: Input tensor of shape (C, H, W)
+
+        Returns:
+            Batch tensor of shape (1, C, H, W)
+        """
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        return x.to(self.device)
+
+    def training_step(self, batch, batch_idx):
+        """Training step with softmax cross entropy loss."""
+        sample, target = batch
+        # sample shape: (B, C, H, W)
+        # target shape: (B, H, W) with class indices
+
+        logits = self(sample)  # (B, num_classes, H, W)
+
+        # Compute loss
+        loss = self.loss_fn(logits, target)
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """Validation step with softmax cross entropy loss."""
+        sample, target = batch
+        logits = self(sample)  # (B, num_classes, H, W)
+        loss = self.loss_fn(logits, target)
+        self.log("val_loss", loss, prog_bar=True)
+        return loss
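
As a quick sanity check (illustrative only; it assumes the wheel is installed under the import path shown in RECORD), the encoder halves the spatial resolution twice and the decoder restores it, so inputs whose height and width are divisible by 4 come back at full resolution:

import torch
from rgb_to_segmentation.models.cnn_decoder import CNNDecoder

model = CNNDecoder(input_channels=3, hidden_dim=64, output_dim=4)
x = torch.randn(1, 3, 64, 64)   # H and W divisible by 4
logits = model(x)
print(logits.shape)             # torch.Size([1, 4, 64, 64])

For sizes not divisible by 4 the decoder output will not match the input resolution, since the stride-2 convolutions round the intermediate sizes.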

+++ rgb_to_segmentation/models/pixelwise_classifier.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_classifier import PixelClassifier
+
+
+class PixelwiseClassifier(PixelClassifier):
+    step = 0
+    val_step = 0
+
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super().__init__(output_dim=output_dim)
+        self.save_hyperparameters()
+
+        self.net = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, output_dim),
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.net(x)
+
+    def image_to_batch(self, x: torch.Tensor):
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]).to(self.device)
+
+    def batch_to_image(self, batch: torch.Tensor, height: int, width: int):
+        return batch.reshape(1, height, width, 3).permute(0, 3, 1, 2)
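
The pixelwise model scores each pixel independently from its channel values; a shape walk-through (illustrative, not shipped with the package):

import torch
from rgb_to_segmentation.models.pixelwise_classifier import PixelwiseClassifier

model = PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=4)
img = torch.randn(3, 10, 20)        # (C, H, W)
batch = model.image_to_batch(img)   # flattened to (H*W, C) = (200, 3)
logits = model(batch)               # (200, 4): one class-score vector per pixel
classes = logits.argmax(dim=-1)     # (200,): predicted class index per pixel
print(batch.shape, logits.shape, classes.shape)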

+++ rgb_to_segmentation/nn.py
@@ -0,0 +1,164 @@
+import os
+
+import numpy as np
+import torch
+
+from PIL import Image
+from torchvision.io import read_image
+
+from .models.pixelwise_classifier import PixelwiseClassifier
+from .models.cnn_decoder import CNNDecoder
+
+
+def clean_image_nn(
+    image_array: np.ndarray,
+    model: object,
+    colour_map: dict,
+    output_type: str = "rgb",
+) -> np.ndarray:
+    """
+    Clean a single image using the neural network model.
+
+    Args:
+        image_array: Numpy array (H, W, 3) uint8
+        model: Trained model with `image_to_batch` and `forward`
+        colour_map: Dict[int, (r,g,b)] used for RGB output mapping
+        output_type: 'rgb' for colour image, 'index' for class-index mask
+
+    Returns:
+        np.ndarray: (H,W,3) uint8 if output_type='rgb'; else (H,W) uint8
+    """
+    if output_type not in ("rgb", "index"):
+        raise ValueError("output_type must be 'rgb' or 'index'")
+
+    import torch
+
+    img_t = torch.from_numpy(image_array).permute(2, 0, 1).float() / 127.5 - 1.0
+    h, w = img_t.shape[1], img_t.shape[2]
+    batch = model.image_to_batch(img_t)
+
+    with torch.no_grad():
+        probs = model(batch)
+        predicted = torch.argmax(probs, dim=-1).reshape(h, w).cpu()  # move off GPU before CPU-side colour mapping
+
+    if output_type == "rgb":
+        rgb_image_t = map_int_to_rgb(predicted, colour_map)
+        return rgb_image_t.permute(1, 2, 0).numpy()
+    else:
+        return predicted.cpu().numpy().astype(np.uint8)
+
+
+def load_model(model_path: str):
+    if os.path.isdir(model_path):
+        checkpoint_files = [f for f in os.listdir(model_path) if f.endswith(".ckpt")]
+
+        if len(checkpoint_files) == 0:
+            raise ValueError(f"No checkpoint files found in directory: {model_path}")
+
+        model_path = os.path.join(model_path, checkpoint_files[0])
+
+    if "pixel_decoder" in model_path:
+        model = PixelwiseClassifier.load_from_checkpoint(checkpoint_path=model_path)
+    elif "cnn_decoder" in model_path:
+        model = CNNDecoder.load_from_checkpoint(checkpoint_path=model_path)
+    else:
+        raise ValueError(
+            "Model path must contain 'pixel_decoder' or 'cnn_decoder' to identify model type."
+        )
+
+    model.eval()
+    return model
+
+
+def map_int_to_rgb(indexed_image: torch.Tensor, colour_map: dict):
+    h, w = indexed_image.shape
+    rgb_image = torch.zeros((3, h, w), dtype=torch.uint8)
+
+    for idx, rgb in colour_map.items():
+        mask = indexed_image == idx
+        for c in range(3):
+            rgb_image[c][mask] = rgb[c]
+
+    return rgb_image
+
+
+def run_inference(
+    input_dir: str,
+    output_dir: str,
+    inplace: bool = False,
+    model_path: str = None,
+    colour_map: dict = None,
+    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
+    name_filter: str = "",
+    output_type: str = "rgb",
+):
+    """
+    Run neural network inference on segmentation images.
+
+    Args:
+        input_dir (str): Path to input directory containing images.
+        output_dir (str, optional): Directory where output images will be written.
+        inplace (bool): Overwrite input images in place.
+        model_path (str): Path to the trained model file.
+        colour_map (dict): Mapping from class indices to RGB tuples.
+        exts (str): Comma-separated list of allowed image extensions.
+        name_filter (str): Only process files whose name contains this substring.
+    """
+    if not inplace and output_dir is None:
+        raise ValueError("Either output_dir must be provided or inplace must be True")
+
+    if model_path is None:
+        raise ValueError("model_path must be provided")
+
+    if colour_map is None:
+        raise ValueError("colour_map must be provided")
+
+    model = load_model(model_path)
+
+    exts_list = [e.lower().strip() for e in exts.split(",")]
+    out_dir = output_dir if not inplace else input_dir
+
+    if not inplace:
+        os.makedirs(out_dir, exist_ok=True)
+
+    print(
+        f"Running inference: input={input_dir} -> output={out_dir}, output_type={output_type}"
+    )
+
+    for root, dirs, files in os.walk(input_dir):
+        rel = os.path.relpath(root, input_dir)
+        out_root = os.path.join(out_dir, rel) if not inplace else root
+
+        os.makedirs(out_root, exist_ok=True)
+
+        for fname in files:
+            if not any(fname.lower().endswith(e) for e in exts_list):
+                continue
+            if name_filter and name_filter not in fname:
+                continue
+
+            in_path = os.path.join(root, fname)
+            out_path = os.path.join(out_root, fname)
+
+            try:
+                # Load as numpy for the shared single-image cleaner
+                pil_img = Image.open(in_path).convert("RGB")
+                np_img = np.array(pil_img, dtype=np.uint8)
+
+                cleaned = clean_image_nn(
+                    np_img, model=model, colour_map=colour_map, output_type=output_type
+                )
+
+                if output_type == "rgb":
+                    pil_image = Image.fromarray(cleaned)
+                    pil_image.save(out_path)
+                elif output_type == "index":
+                    pil_image = Image.fromarray(cleaned.astype("uint8"), mode="L")
+                    pil_image.save(out_path)
+                else:
+                    raise ValueError("output_type must be 'rgb' or 'index'")
+
+            except Exception as e:
+                print(f"Skipping {in_path}: {e}")
+
+    print("Inference done.")
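
A hypothetical invocation (the paths and colours below are placeholders, not package defaults). Note that `load_model` infers the architecture from the checkpoint path, so the filename must contain 'pixel_decoder' or 'cnn_decoder'; also, `clean_image_nn` takes the argmax over the last dimension, which matches the flattened (H*W, num_classes) output of the pixelwise model:

from rgb_to_segmentation.nn import run_inference

colour_map = {0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0)}  # class index -> RGB

run_inference(
    input_dir="raw_masks",                    # hypothetical input directory
    output_dir="clean_masks",
    model_path="weights/pixel_decoder.ckpt",  # name encodes the model type
    colour_map=colour_map,
    output_type="index",                      # write single-channel class masks
)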

+++ rgb_to_segmentation/train.py
@@ -0,0 +1,212 @@
+import os
+
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchvision.io import read_image
+
+from .nn import PixelwiseClassifier
+
+
+VALIDATION_FRACTION = 0.2
+
+
+def map_colour_to_int(sample, colour_map):
+    """Vectorized color to class index mapping."""
+    _, H, W = sample.shape
+
+    # Create lookup tensor: shape (num_classes, 3)
+    num_classes = len(colour_map)
+    color_array = torch.zeros((num_classes, 3), dtype=sample.dtype)
+    for idx, rgb in colour_map.items():
+        color_array[idx] = torch.tensor(rgb, dtype=sample.dtype)
+
+    # Reshape for broadcasting: (3, H, W) -> (H, W, 3) -> (H*W, 3)
+    sample_flat = sample.permute(1, 2, 0).reshape(-1, 3)
+
+    # Compute distances to all colors: (H*W, num_classes)
+    distances = torch.cdist(sample_flat.float(), color_array.float(), p=2)
+
+    # Assign to nearest color
+    image_array = distances.argmin(dim=1).reshape(1, H, W)
+
+    return image_array
+
+
+def collate_fn(batch):
+    """Custom collate to apply image_to_batch on GPU."""
+    samples, targets = zip(*batch)
+    # Stack into batch tensors
+    samples = torch.stack(samples, dim=0)
+    targets = torch.stack(targets, dim=0)
+    # Apply model-specific batching (will happen on GPU in training loop)
+    return samples, targets
+
+
+class SegMaskDataset(Dataset):
+    def __init__(self, paired_filenames, colour_map, model):
+        self.paired_filenames = paired_filenames
+        self.colour_map = colour_map
+        self.to_batch = model.image_to_batch
+
+    def __len__(self):
+        return len(self.paired_filenames)
+
+    def __getitem__(self, index):
+        sample_path = self.paired_filenames[index]["sample"]
+        target_path = self.paired_filenames[index]["target"]
+
+        sample = (
+            read_image(sample_path, mode="RGB") / 127.5
+        ) - 1.0  # Normalize to [-1, 1]
+        target = read_image(target_path, mode="RGB")
+        target = map_colour_to_int(target, self.colour_map)
+        sample = self.to_batch(sample)
+        target = self.to_batch(target).to(torch.long)
+
+        return (sample, target)
+
+
+def get_png_basenames(directory: str) -> list[str]:
+    return [
+        os.path.splitext(filename)[0]
+        for filename in os.listdir(directory)
+        if filename.endswith(".png")
+    ]
+
+
+def get_paired_filenames(
+    image_dir: str,
+    label_dir: str,
+    noisy_basenames: list[str],
+    label_basenames: list[str],
+    training_label_basenames: list[str],
+) -> tuple[list[dict], list[dict]]:
+    training_paired_filenames = []
+    val_paired_filenames = []
+
+    for noisy_image_name in noisy_basenames:
+        targets = [
+            target_file
+            for target_file in label_basenames
+            if target_file in noisy_image_name
+        ]
+
+        if len(targets) > 1:
+            raise Exception(
+                f"Multiple target files exist for noisy file {noisy_image_name}: {targets}."
+            )
+        elif len(targets) < 1:
+            print(
+                f"WARNING: No target found for noisy file {noisy_image_name}. Discarding."
+            )
+        else:
+            target_filename = targets[0]
+
+            if target_filename in training_label_basenames:
+                training_paired_filenames.append(
+                    {
+                        "sample": f"{image_dir}/{noisy_image_name}.png",
+                        "target": f"{label_dir}/{target_filename}.png",
+                    }
+                )
+            else:
+                val_paired_filenames.append(
+                    {
+                        "sample": f"{image_dir}/{noisy_image_name}.png",
+                        "target": f"{label_dir}/{target_filename}.png",
+                    }
+                )
+
+    return training_paired_filenames, val_paired_filenames
+
+
+def get_dataloaders(
+    image_dir, label_dir, colour_map, model
+) -> tuple[DataLoader, DataLoader]:
+    noisy_basenames = get_png_basenames(image_dir)
+    label_basenames = get_png_basenames(label_dir)
+
+    training_label_basenames, val_label_basenames = random_split(
+        label_basenames, [1 - VALIDATION_FRACTION, VALIDATION_FRACTION]
+    )
+
+    training_paired_filenames, val_paired_filenames = get_paired_filenames(
+        image_dir,
+        label_dir,
+        noisy_basenames,
+        label_basenames,
+        training_label_basenames,
+    )
+
+    train_dataset = SegMaskDataset(training_paired_filenames, colour_map, model)
+    val_dataset = SegMaskDataset(val_paired_filenames, colour_map, model)
+
+    train_dataloader = DataLoader(
+        dataset=train_dataset,
+        batch_size=1,
+        shuffle=True,
+        num_workers=4,
+        prefetch_factor=4,
+        collate_fn=collate_fn,
+    )
+    val_dataloader = DataLoader(
+        dataset=val_dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=4,
+        prefetch_factor=4,
+        collate_fn=collate_fn,
+    )
+
+    return train_dataloader, val_dataloader
+
+
+def get_model(model_type, colour_map):
+    num_classes = len(colour_map.keys())
+
+    if model_type == "pixel_decoder":
+        return PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=num_classes)
+
+    raise ValueError(f"Invalid model type: {model_type}")
+
+
+def train_model(
+    image_dir: str,
+    label_dir: str,
+    output_dir: str,
+    colour_map: dict,
+    model_type: str = "pixel_decoder",
+):
+    """
+    Train a neural network model for segmentation cleaning.
+
+    Args:
+        image_dir (str): Path to directory containing noisy images.
+        label_dir (str): Path to directory containing target RGB labels.
+        output_dir (str): Directory where model weights will be saved.
+        colour_map (dict): Mapping from class indices to RGB tuples.
+        model_type (str): The type of model to train.
+    """
+    model = get_model(model_type, colour_map)
+    train_dataloader, val_dataloader = get_dataloaders(
+        image_dir, label_dir, colour_map, model
+    )
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
+    checkpoint = ModelCheckpoint(
+        monitor="val_loss",
+        mode="min",
+        save_top_k=1,
+        dirpath=output_dir,
+        filename=f"{model_type}.ckpt",
+    )
+    trainer = Trainer(max_epochs=100, callbacks=[early_stop, checkpoint])
+
+    trainer.fit(
+        model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
+    )
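
A hypothetical training run (directories and colour map are placeholders). Pairing relies on each label basename being a substring of the matching noisy image's basename, as `get_paired_filenames` above shows:

from rgb_to_segmentation.train import train_model

colour_map = {0: (0, 0, 0), 1: (255, 255, 255)}  # class index -> RGB

train_model(
    image_dir="data/noisy",      # noisy RGB segmentations (*.png)
    label_dir="data/labels",     # clean RGB labels (*.png)
    output_dir="weights",        # best checkpoint lands here
    colour_map=colour_map,
    model_type="pixel_decoder",  # the only type get_model currently accepts
)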

+++ rgb_to_segmentation/utils.py
@@ -0,0 +1,40 @@
+from typing import List, Tuple
+
+
+def parse_colours_from_string(colours_str: str) -> List[Tuple[int, int, int]]:
+    """Parse a semicolon-separated string of RGB triples into a list of tuples."""
+
+    parts = [p.strip() for p in colours_str.split(";") if p.strip()]
+    colours = []
+
+    for p in parts:
+        rgb = tuple(int(x) for x in p.split(","))
+
+        if len(rgb) != 3:
+            raise ValueError(f"Invalid colour triple: {p}")
+        colours.append(rgb)
+
+    return colours
+
+
+def parse_colours_from_file(path: str) -> List[Tuple[int, int, int]]:
+    """Parse a file with one RGB triple per line into a list of tuples."""
+
+    colours = []
+
+    with open(path, "r") as f:
+        for line in f:
+            line = line.strip()
+            if not line or line.startswith("#"):
+                continue
+
+            rgb = tuple(int(x) for x in line.split(","))
+            if len(rgb) != 3:
+                raise ValueError(f"Invalid colour triple in file {path}: {line}")
+
+            colours.append(rgb)
+
+    if not colours:
+        raise ValueError(f"No colours found in file: {path}")
+
+    return colours
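
Example use (illustrative); `dict(enumerate(...))` is one plausible way to build the index-to-RGB colour_map consumed by the training and inference entry points:

from rgb_to_segmentation.utils import parse_colours_from_string

colours = parse_colours_from_string("0,0,0; 255,0,0; 0,255,0")
print(colours)                         # [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
colour_map = dict(enumerate(colours))  # {0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0)}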