rgb_to_segmentation-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,77 @@
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+
+
+ class PixelClassifier(pl.LightningModule):
+     """
+     Base class for pixel-wise classifiers.
+     Provides a common interface for different model architectures.
+     """
+
+     def __init__(self, output_dim):
+         super().__init__()
+         self.output_dim = output_dim
+         self.loss_fn = nn.CrossEntropyLoss()
+
+     def forward(self, x: torch.Tensor):
+         """Forward pass - to be implemented by subclasses."""
+         raise NotImplementedError
+
+     def image_to_batch(self, x: torch.Tensor):
+         """Convert an image tensor to a batch for processing.
+
+         Args:
+             x: Input tensor of shape (C, H, W) or (B, C, H, W)
+
+         Returns:
+             Batch tensor ready for model input
+         """
+         raise NotImplementedError
+
+     def _align_logits_and_target(
+         self, logits: torch.Tensor, target: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Reshape predictions/targets so CrossEntropyLoss receives the expected shapes."""
+
+         # Handle channel-first 4D logits from CNN decoders: (B, C, H, W)
+         if logits.dim() == 4:
+             # Targets may come in as (B, 1, H, W) - squeeze the class channel
+             if target.dim() == 4 and target.size(1) == 1:
+                 target = target.squeeze(1)
+             return logits, target
+
+         # Everything else (e.g. flattened pixel classifiers): collapse batch/pixel dims
+         logits = logits.view(-1, logits.size(-1))
+         target = target.view(-1)
+         return logits, target
+
+     def training_step(self, batch, batch_idx):
+         """Training step."""
+         sample, target = batch
+         # Apply model-specific batching on GPU
+         sample = self.image_to_batch(sample)
+         target = self.image_to_batch(target).long()
+
+         logits = self(sample)
+         logits, target = self._align_logits_and_target(logits, target)
+         loss = self.loss_fn(logits, target)
+         self.log("train_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """Validation step."""
+         sample, target = batch
+         # Apply model-specific batching on GPU
+         sample = self.image_to_batch(sample)
+         target = self.image_to_batch(target).long()
+
+         logits = self(sample)
+         logits, target = self._align_logits_and_target(logits, target)
+         loss = self.loss_fn(logits, target)
+         self.log("val_loss", loss, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         """Configure the optimizer."""
+         return torch.optim.Adam(self.parameters())
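The reshaping contract of _align_logits_and_target is easiest to see with toy tensors. A minimal sketch of the two branches it handles (editorial example, not part of the package):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# CNN-style branch: (B, C, H, W) logits with (B, 1, H, W) targets.
# The helper squeezes the singleton channel, giving the (B, C, H, W) vs
# (B, H, W) pairing that CrossEntropyLoss expects.
logits = torch.randn(2, 5, 8, 8)
target = torch.randint(0, 5, (2, 1, 8, 8))
loss = loss_fn(logits, target.squeeze(1))

# Flattened branch: (N, C) logits with (N, 1) targets from a pixel classifier.
# The helper collapses both to (N, C) vs (N,).
logits = torch.randn(128, 5)
target = torch.randint(0, 5, (128, 1))
loss = loss_fn(logits.view(-1, 5), target.view(-1))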
@@ -0,0 +1,103 @@
+ import torch
+ import torch.nn as nn
+
+ from .base_classifier import PixelClassifier
+
+
+ class CNNDecoder(PixelClassifier):
+     """
+     CNN-based pixel classifier for image segmentation.
+
+     Uses a convolutional encoder-decoder architecture to process images
+     and produce per-pixel classifications.
+     """
+
+     def __init__(self, input_channels=3, hidden_dim=64, output_dim=2):
+         """
+         Initialize the CNN decoder.
+
+         Args:
+             input_channels: Number of input channels (e.g., 3 for RGB)
+             hidden_dim: Number of hidden channels in conv layers
+             output_dim: Number of output classes
+         """
+         super().__init__(output_dim=output_dim)
+         self.save_hyperparameters()
+
+         self.input_channels = input_channels
+         self.hidden_dim = hidden_dim
+
+         # Encoder: downsampling with strided convolutions (overall stride 4)
+         self.encoder = nn.Sequential(
+             nn.Conv2d(input_channels, hidden_dim, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=3, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(
+                 hidden_dim * 2, hidden_dim * 4, kernel_size=3, stride=2, padding=1
+             ),
+             nn.ReLU(),
+         )
+
+         # Decoder: upsampling with transpose convolutions (overall stride 4)
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(
+                 hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1
+             ),
+             nn.ReLU(),
+             nn.ConvTranspose2d(
+                 hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1
+             ),
+             nn.ReLU(),
+             nn.Conv2d(hidden_dim, output_dim, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x: torch.Tensor):
+         """
+         Forward pass.
+
+         Args:
+             x: Input tensor of shape (B, C, H, W); H and W should be divisible
+                 by 4 so the decoder restores the original spatial size.
+
+         Returns:
+             Logits of shape (B, num_classes, H, W)
+         """
+         encoded = self.encoder(x)
+         decoded = self.decoder(encoded)
+         return decoded
+
+     def image_to_batch(self, x: torch.Tensor):
+         """
+         Convert an image tensor to batch format for CNN processing.
+
+         Args:
+             x: Input tensor of shape (C, H, W)
+
+         Returns:
+             Batch tensor of shape (1, C, H, W)
+         """
+         if len(x.shape) == 3:
+             x = x.unsqueeze(0)
+         return x.to(self.device)
+
+     def training_step(self, batch, batch_idx):
+         """Training step with softmax cross-entropy loss."""
+         sample, target = batch
+         # sample shape: (B, C, H, W)
+         # target shape: (B, H, W) with class indices
+
+         logits = self(sample)  # (B, num_classes, H, W)
+
+         # Compute loss
+         loss = self.loss_fn(logits, target)
+         self.log("train_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """Validation step with softmax cross-entropy loss."""
+         sample, target = batch
+         logits = self(sample)  # (B, num_classes, H, W)
+         loss = self.loss_fn(logits, target)
+         self.log("val_loss", loss, prog_bar=True)
+         return loss
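Because the encoder downsamples twice with stride 2 and the decoder upsamples twice, the output only matches the input resolution when H and W are multiples of 4. A quick smoke test (editorial sketch, assuming CNNDecoder is imported from the package):

import torch

model = CNNDecoder(input_channels=3, hidden_dim=64, output_dim=2)
x = torch.randn(1, 3, 64, 64)          # 64 is a multiple of 4
logits = model(x)
assert logits.shape == (1, 2, 64, 64)  # per-pixel class logits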
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+
+ from .base_classifier import PixelClassifier
+
+
+ class PixelwiseClassifier(PixelClassifier):
+     """MLP that classifies each pixel independently from its RGB value."""
+
+     step = 0
+     val_step = 0
+
+     def __init__(self, input_dim, hidden_dim, output_dim):
+         super().__init__(output_dim=output_dim)
+         self.save_hyperparameters()
+
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, output_dim),
+         )
+
+     def forward(self, x: torch.Tensor):
+         return self.net(x)
+
+     def image_to_batch(self, x: torch.Tensor):
+         """Flatten (B, C, H, W) or (C, H, W) into a (B*H*W, C) pixel batch."""
+         if len(x.shape) == 3:
+             x = x.unsqueeze(0)
+         return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]).to(self.device)
+
+     def batch_to_image(self, batch: torch.Tensor, height: int, width: int):
+         """Reassemble a flat pixel batch into a (1, 3, H, W) image (assumes 3 channels)."""
+         return batch.reshape(1, height, width, 3).permute(0, 3, 1, 2)
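PixelwiseClassifier treats every pixel as an independent 3-vector, so a full image round trip looks like this (editorial sketch, assuming PixelwiseClassifier is imported from the package):

import torch

model = PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=4)
image = torch.rand(3, 16, 16) * 2 - 1         # RGB image normalized to [-1, 1]
pixels = model.image_to_batch(image)          # (16*16, 3) pixel batch
logits = model(pixels)                        # (16*16, 4) class logits
mask = logits.argmax(dim=-1).reshape(16, 16)  # per-pixel class indices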
@@ -0,0 +1,173 @@
+ import os
+
+ import numpy as np
+ import torch
+
+ from PIL import Image
+
+ from .models.pixelwise_classifier import PixelwiseClassifier
+ from .models.cnn_decoder import CNNDecoder
+
+
+ def clean_image_nn(
+     image_array: np.ndarray,
+     model: object,
+     colour_map: dict,
+     output_type: str = "rgb",
+ ) -> np.ndarray:
+     """
+     Clean a single image using the neural network model.
+
+     Args:
+         image_array: Numpy array (H, W, 3) uint8
+         model: Trained model with `image_to_batch` and `forward`
+         colour_map: Dict[int, (r, g, b)] used for RGB output mapping
+         output_type: 'rgb' for a colour image, 'index' for a class-index mask
+
+     Returns:
+         np.ndarray: (H, W, 3) uint8 if output_type='rgb'; else (H, W) uint8
+     """
+     if output_type not in ("rgb", "index"):
+         raise ValueError("output_type must be 'rgb' or 'index'")
+
+     # Normalize uint8 RGB to [-1, 1], matching the training preprocessing
+     img_t = torch.from_numpy(image_array).permute(2, 0, 1).float() / 127.5 - 1.0
+     h, w = img_t.shape[1], img_t.shape[2]
+     batch = model.image_to_batch(img_t)
+
+     with torch.no_grad():
+         logits = model(batch)
+
+     if logits.dim() == 4:
+         # CNN decoder output: (1, num_classes, H, W)
+         predicted = logits.argmax(dim=1).squeeze(0)
+     else:
+         # Flat pixel-classifier output: (H*W, num_classes)
+         predicted = logits.argmax(dim=-1).reshape(h, w)
+     predicted = predicted.cpu()
+
+     if output_type == "rgb":
+         rgb_image_t = map_int_to_rgb(predicted, colour_map)
+         return rgb_image_t.permute(1, 2, 0).numpy()
+     else:
+         return predicted.numpy().astype(np.uint8)
+
+
+ def load_model(model_path: str):
+     if os.path.isdir(model_path):
+         # Sort for a deterministic choice when several checkpoints exist
+         checkpoint_files = sorted(
+             f for f in os.listdir(model_path) if f.endswith(".ckpt")
+         )
+
+         if len(checkpoint_files) == 0:
+             raise ValueError(f"No checkpoint files found in directory: {model_path}")
+
+         model_path = os.path.join(model_path, checkpoint_files[0])
+
+     if "pixel_decoder" in model_path:
+         model = PixelwiseClassifier.load_from_checkpoint(checkpoint_path=model_path)
+     elif "cnn_decoder" in model_path:
+         model = CNNDecoder.load_from_checkpoint(checkpoint_path=model_path)
+     else:
+         raise ValueError(
+             "Model path must contain 'pixel_decoder' or 'cnn_decoder' to identify model type."
+         )
+
+     model.eval()
+     return model
+
+
+ def map_int_to_rgb(indexed_image: torch.Tensor, colour_map: dict):
+     h, w = indexed_image.shape
+     rgb_image = torch.zeros((3, h, w), dtype=torch.uint8)
+
+     for idx, rgb in colour_map.items():
+         mask = indexed_image == idx
+         for c in range(3):
+             rgb_image[c][mask] = rgb[c]
+
+     return rgb_image
+
+
+ def run_inference(
+     input_dir: str,
+     output_dir: str = None,
+     inplace: bool = False,
+     model_path: str = None,
+     colour_map: dict = None,
+     exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
+     name_filter: str = "",
+     output_type: str = "rgb",
+ ):
+     """
+     Run neural network inference on segmentation images.
+
+     Args:
+         input_dir (str): Path to input directory containing images.
+         output_dir (str, optional): Directory where output images will be written.
+         inplace (bool): Overwrite input images in place.
+         model_path (str): Path to the trained model file.
+         colour_map (dict): Mapping from class indices to RGB tuples.
+         exts (str): Comma-separated list of allowed image extensions.
+         name_filter (str): Only process files whose name contains this substring.
+         output_type (str): 'rgb' for colour images, 'index' for class-index masks.
+     """
+     if output_type not in ("rgb", "index"):
+         raise ValueError("output_type must be 'rgb' or 'index'")
+
+     if not inplace and output_dir is None:
+         raise ValueError("Either output_dir must be provided or inplace must be True")
+
+     if model_path is None:
+         raise ValueError("model_path must be provided")
+
+     if colour_map is None:
+         raise ValueError("colour_map must be provided")
+
+     model = load_model(model_path)
+
+     exts_list = [e.lower().strip() for e in exts.split(",")]
+     out_dir = output_dir if not inplace else input_dir
+
+     if not inplace:
+         os.makedirs(out_dir, exist_ok=True)
+
+     print(
+         f"Running inference: input={input_dir} -> output={out_dir}, output_type={output_type}"
+     )
+
+     for root, dirs, files in os.walk(input_dir):
+         rel = os.path.relpath(root, input_dir)
+         out_root = os.path.join(out_dir, rel) if not inplace else root
+
+         os.makedirs(out_root, exist_ok=True)
+
+         for fname in files:
+             if not any(fname.lower().endswith(e) for e in exts_list):
+                 continue
+             if name_filter and name_filter not in fname:
+                 continue
+
+             in_path = os.path.join(root, fname)
+             out_path = os.path.join(out_root, fname)
+
+             try:
+                 # Load as numpy for the shared single-image cleaner
+                 pil_img = Image.open(in_path).convert("RGB")
+                 np_img = np.array(pil_img, dtype=np.uint8)
+
+                 cleaned = clean_image_nn(
+                     np_img, model=model, colour_map=colour_map, output_type=output_type
+                 )
+
+                 if output_type == "rgb":
+                     pil_image = Image.fromarray(cleaned)
+                 else:
+                     pil_image = Image.fromarray(cleaned, mode="L")
+                 pil_image.save(out_path)
+
+             except Exception as e:
+                 print(f"Skipping {in_path}: {e}")
+
+     print("Inference done.")
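A hypothetical end-to-end call (all paths and colours below are placeholders, assuming run_inference is imported from the package):

colour_map = {0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0)}

run_inference(
    input_dir="raw_masks/",
    output_dir="clean_masks/",
    model_path="weights/pixel_decoder.ckpt",
    colour_map=colour_map,
    output_type="rgb",
)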
@@ -0,0 +1,208 @@
+ import os
+
+ import torch
+
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from torchvision.io import read_image
+
+ from .nn import PixelwiseClassifier
+
+
+ VALIDATION_FRACTION = 0.2
+
+
+ def map_colour_to_int(sample, colour_map):
+     """Vectorized colour-to-class-index mapping (nearest colour wins)."""
+     _, H, W = sample.shape
+
+     # Create a lookup tensor of shape (num_classes, 3)
+     num_classes = len(colour_map)
+     color_array = torch.zeros((num_classes, 3), dtype=sample.dtype)
+     for idx, rgb in colour_map.items():
+         color_array[idx] = torch.tensor(rgb, dtype=sample.dtype)
+
+     # Reshape for broadcasting: (3, H, W) -> (H, W, 3) -> (H*W, 3)
+     sample_flat = sample.permute(1, 2, 0).reshape(-1, 3)
+
+     # Compute distances to all colours: (H*W, num_classes)
+     distances = torch.cdist(sample_flat.float(), color_array.float(), p=2)
+
+     # Assign each pixel to its nearest colour
+     image_array = distances.argmin(dim=1).reshape(1, H, W)
+
+     return image_array
+
+
+ def collate_fn(batch):
+     """Custom collate; model-specific batching is applied later, on the GPU."""
+     samples, targets = zip(*batch)
+     # Stack into batch tensors
+     samples = torch.stack(samples, dim=0)
+     targets = torch.stack(targets, dim=0)
+     return samples, targets
+
+
+ class SegMaskDataset(Dataset):
+     def __init__(self, paired_filenames, colour_map):
+         self.paired_filenames = paired_filenames
+         self.colour_map = colour_map
+
+     def __len__(self):
+         return len(self.paired_filenames)
+
+     def __getitem__(self, index):
+         sample_path = self.paired_filenames[index]["sample"]
+         target_path = self.paired_filenames[index]["target"]
+
+         sample = (
+             read_image(sample_path, mode="RGB") / 127.5
+         ) - 1.0  # Normalize to [-1, 1]
+         target = read_image(target_path, mode="RGB")
+         target = map_colour_to_int(target, self.colour_map)
+
+         # Model-specific batching (image_to_batch) happens in the model's
+         # training/validation steps, so it runs on the GPU.
+         return (sample, target)
+
+
+ def get_png_basenames(directory: str) -> list[str]:
+     return [
+         os.path.splitext(filename)[0]
+         for filename in os.listdir(directory)
+         if filename.endswith(".png")
+     ]
+
+
+ def get_paired_filenames(
+     image_dir: str,
+     label_dir: str,
+     noisy_basenames: list[str],
+     label_basenames: list[str],
+     training_label_basenames: list[str],
+ ) -> tuple[list[dict], list[dict]]:
+     training_paired_filenames = []
+     val_paired_filenames = []
+
+     for noisy_image_name in noisy_basenames:
+         # A label matches when its basename is a substring of the noisy name
+         targets = [
+             target_file
+             for target_file in label_basenames
+             if target_file in noisy_image_name
+         ]
+
+         if len(targets) > 1:
+             raise ValueError(
+                 f"Multiple target files exist for noisy file {noisy_image_name}: {targets}."
+             )
+         elif len(targets) < 1:
+             print(
+                 f"WARNING: No target found for noisy file {noisy_image_name}. Discarding."
+             )
+         else:
+             target_filename = targets[0]
+
+             pair = {
+                 "sample": f"{image_dir}/{noisy_image_name}.png",
+                 "target": f"{label_dir}/{target_filename}.png",
+             }
+             if target_filename in training_label_basenames:
+                 training_paired_filenames.append(pair)
+             else:
+                 val_paired_filenames.append(pair)
+
+     return training_paired_filenames, val_paired_filenames
+
+
+ def get_dataloaders(
+     image_dir, label_dir, colour_map
+ ) -> tuple[DataLoader, DataLoader]:
+     noisy_basenames = get_png_basenames(image_dir)
+     label_basenames = get_png_basenames(label_dir)
+
+     # random_split returns Subsets; materialize the training split as a list
+     # so membership checks in get_paired_filenames stay simple.
+     training_label_basenames, _ = random_split(
+         label_basenames, [1 - VALIDATION_FRACTION, VALIDATION_FRACTION]
+     )
+     training_label_basenames = list(training_label_basenames)
+
+     training_paired_filenames, val_paired_filenames = get_paired_filenames(
+         image_dir,
+         label_dir,
+         noisy_basenames,
+         label_basenames,
+         training_label_basenames,
+     )
+
+     train_dataset = SegMaskDataset(training_paired_filenames, colour_map)
+     val_dataset = SegMaskDataset(val_paired_filenames, colour_map)
+
+     train_dataloader = DataLoader(
+         dataset=train_dataset,
+         batch_size=1,
+         shuffle=True,
+         num_workers=4,
+         prefetch_factor=4,
+         collate_fn=collate_fn,
+     )
+     val_dataloader = DataLoader(
+         dataset=val_dataset,
+         batch_size=1,
+         shuffle=False,
+         num_workers=4,
+         prefetch_factor=4,
+         collate_fn=collate_fn,
+     )
+
+     return train_dataloader, val_dataloader
+
+
+ def get_model(model_type, colour_map):
+     num_classes = len(colour_map)
+
+     if model_type == "pixel_decoder":
+         return PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=num_classes)
+
+     raise ValueError(f"Invalid model type: {model_type}")
+
+
+ def train_model(
+     image_dir: str,
+     label_dir: str,
+     output_dir: str,
+     colour_map: dict,
+     model_type: str = "pixel_decoder",
+ ):
+     """
+     Train a neural network model for segmentation cleaning.
+
+     Args:
+         image_dir (str): Path to directory containing noisy images.
+         label_dir (str): Path to directory containing target RGB labels.
+         output_dir (str): Directory where model weights will be saved.
+         colour_map (dict): Mapping from class indices to RGB tuples.
+         model_type (str): The type of model to train.
+     """
+     model = get_model(model_type, colour_map)
+     train_dataloader, val_dataloader = get_dataloaders(
+         image_dir, label_dir, colour_map
+     )
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
+     checkpoint = ModelCheckpoint(
+         monitor="val_loss",
+         mode="min",
+         save_top_k=1,
+         dirpath=output_dir,
+         filename=model_type,  # Lightning appends the .ckpt extension itself
+     )
+     trainer = Trainer(max_epochs=100, callbacks=[early_stop, checkpoint])
+
+     trainer.fit(
+         model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
+     )
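map_colour_to_int quantizes each RGB pixel to the nearest colour in the map, which makes the labels robust to slight compression artefacts in the target PNGs. A tiny check (editorial example with placeholder colours, assuming map_colour_to_int is imported from the package):

import torch

colour_map = {0: (0, 0, 0), 1: (255, 255, 255)}
img = torch.tensor(
    [[[0, 250], [5, 255]],
     [[0, 250], [5, 255]],
     [[0, 250], [5, 255]]],
    dtype=torch.uint8,
)  # (3, 2, 2): near-black and near-white pixels
mask = map_colour_to_int(img, colour_map)
# tensor([[[0, 1], [0, 1]]]) -- shape (1, H, W) of class indices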
@@ -0,0 +1,40 @@
+ from typing import List, Tuple
+
+
+ def parse_colours_from_string(colours_str: str) -> List[Tuple[int, int, int]]:
+     """Parse a semicolon-separated string of RGB triples into a list of tuples."""
+
+     parts = [p.strip() for p in colours_str.split(";") if p.strip()]
+     colours = []
+
+     for p in parts:
+         rgb = tuple(int(x) for x in p.split(","))
+
+         if len(rgb) != 3:
+             raise ValueError(f"Invalid colour triple: {p}")
+         colours.append(rgb)
+
+     return colours
+
+
+ def parse_colours_from_file(path: str) -> List[Tuple[int, int, int]]:
+     """Parse a file with one RGB triple per line into a list of tuples."""
+
+     colours = []
+
+     with open(path, "r") as f:
+         for line in f:
+             line = line.strip()
+             if not line or line.startswith("#"):
+                 continue
+
+             rgb = tuple(int(x) for x in line.split(","))
+             if len(rgb) != 3:
+                 raise ValueError(f"Invalid colour triple in file {path}: {line}")
+
+             colours.append(rgb)
+
+     if not colours:
+         raise ValueError(f"No colours found in file: {path}")
+
+     return colours
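The parsers return an ordered list, so the index-to-RGB colour_map used by the training and inference helpers falls out of enumerate (editorial example with placeholder values):

colours = parse_colours_from_string("0,0,0; 255,0,0; 0,255,0")
colour_map = {idx: rgb for idx, rgb in enumerate(colours)}
# {0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0)}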