rgb-to-segmentation 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rgb-to-segmentation might be problematic. Click here for more details.

@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ local/
3
+ .vscode/
4
+ lightning_logs/
@@ -0,0 +1,142 @@
1
+ Metadata-Version: 2.4
2
+ Name: rgb-to-segmentation
3
+ Version: 0.0.2
4
+ Summary: Tools for processing and cleaning segmentation images using palette mapping and neural networks
5
+ Author: Alex Senden
6
+ Maintainer: Alex Senden
7
+ License: MIT
8
+ Keywords: image-processing,machine-learning,pytorch,segmentation
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
18
+ Requires-Python: >=3.8
19
+ Requires-Dist: numpy>=1.21.0
20
+ Requires-Dist: pillow>=8.0.0
21
+ Requires-Dist: pytorch-lightning>=1.5.0
22
+ Requires-Dist: scipy>=1.7.0
23
+ Requires-Dist: torch>=1.9.0
24
+ Requires-Dist: torchvision>=0.10.0
25
+ Description-Content-Type: text/markdown
26
+
27
+ # RGB to Segmentation
28
+
29
+ A Python package for processing and cleaning segmentation images. This package provides tools to convert RGB images to segmentation masks using palette-based color mapping and neural network-based refinement.
30
+
31
+ ## Features
32
+
33
+ - **Palette-based Cleaning**: Clean noisy segmentation images by mapping pixels to the nearest colors in a predefined palette, with optional morphological operations to refine boundaries.
34
+ - **Neural Network Refinement**: Use a trained pixelwise classifier to refine segmentation masks using PyTorch Lightning.
35
+ - **Command-Line Interface**: Unified CLI for cleaning with method selection, plus separate training command.
36
+ - **Programmatic API**: Direct access to cleaning and training functions for integration into other workflows.
37
+
38
+ ## Installation
39
+
40
+ Install from PyPI:
41
+
42
+ ```bash
43
+ pip install rgb-to-segmentation
44
+ ```
45
+
46
+ Or install from source:
47
+
48
+ ```bash
49
+ git clone https://github.com/alexsenden/rgb-to-segmentation.git
50
+ cd rgb-to-segmentation
51
+ pip install .
52
+ ```
53
+
54
+ ## Usage
55
+
56
+ ### Cleaning Noisy Segmentation Images
57
+
58
+ Use the `segment-clean` command to clean segmentation images using various methods:
59
+
60
+ #### Palette-based cleaning:
61
+
62
+ ```bash
63
+ segment-clean --method palette --input_dir /path/to/input --output_dir /path/to/output --colour_map "0,0,0;255,0,0;0,255,0" --output_type rgb
64
+ ```
65
+
66
+ #### Neural network-based cleaning:
67
+
68
+ ```bash
69
+ segment-clean --method nn --input_dir /path/to/input --output_dir /path/to/output --model_path /path/to/model.ckpt --colour_map "0,0,0;255,0,0;0,255,0" --output_type index
70
+ ```
71
+
72
+ You can also provide colours via file with `--colour_map_file /path/to/colours.txt` (one `r,g,b` per line). The CLI parses colours and constructs the palette/colour map internally, mirroring the Python API which accepts parsed structures (NumPy array for palette, dictionary for colour map).
73
+
74
+ Options:
75
+
76
+ - `--method`: Cleaning method ('palette' or 'nn')
77
+ - `--input_dir`: Path to input directory containing images
78
+ - `--output_dir`: Directory where cleaned images will be written
79
+ - `--inplace`: Overwrite input images in place (use instead of `--output_dir`, which is otherwise required)
80
+ - `--exts`: Comma-separated list of allowed image extensions (default: `.png,.jpg,.jpeg,.tiff,.bmp,.gif`)
81
+ - `--name_filter`: Only process files whose name contains this substring
82
+ - `--output_type`: Output format: 'rgb' for a colour image or 'index' for a single-channel class-index mask (default: 'rgb')
83
+
84
+ For palette method:
85
+ - `--colour_map`: Semicolon-separated list of RGB triples
86
+ - `--colour_map_file`: Path to a file listing RGB triples
87
+ - `--morph_kernel_size`: Number of erosion/dilation iterations used for morphological boundary cleaning (default: 3; 0 disables it)
88
+
89
+ For nn method:
90
+ - `--model_path`: Path to trained model file
91
+ - `--colour_map`: Semicolon-separated list of RGB triples
92
+ - `--colour_map_file`: Path to a file listing RGB triples
93
+
94
+ ### Training the Neural Network Model
95
+
96
+ Train a pixelwise classifier to refine segmentation masks:
97
+
98
+ ```bash
99
+ segment-train --image_dir /path/to/noisy_images --label_dir /path/to/labels --output_dir /path/to/model_output --colour_map "0,0,0;255,0,0;0,255,0"
100
+ ```
101
+
102
+ Options:
103
+ - `--image_dir`: Path to directory containing noisy images
104
+ - `--label_dir`: Path to directory containing target RGB labels
105
+ - `--output_dir`: Directory where model weights will be saved
106
+ - `--colour_map`: Semicolon-separated list of RGB triples
107
+ - `--colour_map_file`: Path to a file listing RGB triples
108
+ - `--model_type`: The type of model to train (default: pixelwise)
109
+
110
+ Note that one label image may have multiple corresponding noisy masks. A noisy mask is matched to a label when the mask's filename contains the label's basename (the filename without its extension, e.g. `my_image.png` -> `my_image`).
111
+
112
+ ## API
113
+
114
+ You can also use the package programmatically:
115
+
116
+ ```python
117
+ import numpy as np
118
+ from rgb_to_segmentation import clean, nn, train, utils
119
+
120
+ # Palette cleaning
121
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
122
+ palette = np.asarray(colours, dtype=np.uint8)
123
+ clean.clean_segmentation(input_dir="/path/to/input", output_dir="/path/to/output", palette=palette, output_type="index")
124
+
125
+ # NN inference
126
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
127
+ colour_map = {i: rgb for i, rgb in enumerate(colours)}
128
+ nn.run_inference(input_dir="/path/to/input", output_dir="/path/to/output", model_path="/path/to/model.ckpt", colour_map=colour_map, output_type="rgb")
129
+
130
+ # Train model
131
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
132
+ colour_map = {i: rgb for i, rgb in enumerate(colours)}
133
+ train.train_model(image_dir="/path/to/images", label_dir="/path/to/labels", output_dir="/path/to/output", colour_map=colour_map)
134
+ ```
135
+
136
+ ## Contributing
137
+
138
+ Contributions are welcome! Please feel free to submit a Pull Request.
139
+
140
+ ## License
141
+
142
+ This project is licensed under the MIT License - see the LICENSE file for details.
@@ -0,0 +1,116 @@
1
+ # RGB to Segmentation
2
+
3
+ A Python package for processing and cleaning segmentation images. This package provides tools to convert RGB images to segmentation masks using palette-based color mapping and neural network-based refinement.
4
+
5
+ ## Features
6
+
7
+ - **Palette-based Cleaning**: Clean noisy segmentation images by mapping pixels to the nearest colors in a predefined palette, with optional morphological operations to refine boundaries.
8
+ - **Neural Network Refinement**: Use a trained pixelwise classifier to refine segmentation masks using PyTorch Lightning.
9
+ - **Command-Line Interface**: Unified CLI for cleaning with method selection, plus separate training command.
10
+ - **Programmatic API**: Direct access to cleaning and training functions for integration into other workflows.
11
+
12
+ ## Installation
13
+
14
+ Install from PyPI:
15
+
16
+ ```bash
17
+ pip install rgb-to-segmentation
18
+ ```
19
+
20
+ Or install from source:
21
+
22
+ ```bash
23
+ git clone https://github.com/alexsenden/rgb-to-segmentation.git
24
+ cd rgb-to-segmentation
25
+ pip install .
26
+ ```
27
+
28
+ ## Usage
29
+
30
+ ### Cleaning Noisy Segmentation Images
31
+
32
+ Use the `segment-clean` command to clean segmentation images using various methods:
33
+
34
+ #### Palette-based cleaning:
35
+
36
+ ```bash
37
+ segment-clean --method palette --input_dir /path/to/input --output_dir /path/to/output --colour_map "0,0,0;255,0,0;0,255,0" --output_type rgb
38
+ ```
39
+
40
+ #### Neural network-based cleaning:
41
+
42
+ ```bash
43
+ segment-clean --method nn --input_dir /path/to/input --output_dir /path/to/output --model_path /path/to/model.ckpt --colour_map "0,0,0;255,0,0;0,255,0" --output_type index
44
+ ```
45
+
46
+ You can also provide colours via file with `--colour_map_file /path/to/colours.txt` (one `r,g,b` per line). The CLI parses colours and constructs the palette/colour map internally, mirroring the Python API which accepts parsed structures (NumPy array for palette, dictionary for colour map).
47
+
48
+ Options:
49
+
50
+ - `--method`: Cleaning method ('palette' or 'nn')
51
+ - `--input_dir`: Path to input directory containing images
52
+ - `--output_dir`: Directory where cleaned images will be written
53
+ - `--inplace`: Overwrite input images in place (use instead of `--output_dir`, which is otherwise required)
54
+ - `--exts`: Comma-separated list of allowed image extensions (default: `.png,.jpg,.jpeg,.tiff,.bmp,.gif`)
55
+ - `--name_filter`: Only process files whose name contains this substring
56
+ - `--output_type`: Output format: 'rgb' for a colour image or 'index' for a single-channel class-index mask (default: 'rgb')
57
+
58
+ For palette method:
59
+ - `--colour_map`: Semicolon-separated list of RGB triples
60
+ - `--colour_map_file`: Path to a file listing RGB triples
61
+ - `--morph_kernel_size`: Number of erosion/dilation iterations used for morphological boundary cleaning (default: 3; 0 disables it)
62
+
63
+ For nn method:
64
+ - `--model_path`: Path to trained model file
65
+ - `--colour_map`: Semicolon-separated list of RGB triples
66
+ - `--colour_map_file`: Path to a file listing RGB triples
67
+
68
+ ### Training the Neural Network Model
69
+
70
+ Train a pixelwise classifier to refine segmentation masks:
71
+
72
+ ```bash
73
+ segment-train --image_dir /path/to/noisy_images --label_dir /path/to/labels --output_dir /path/to/model_output --colour_map "0,0,0;255,0,0;0,255,0"
74
+ ```
75
+
76
+ Options:
77
+ - `--image_dir`: Path to directory containing noisy images
78
+ - `--label_dir`: Path to directory containing target RGB labels
79
+ - `--output_dir`: Directory where model weights will be saved
80
+ - `--colour_map`: Semicolon-separated list of RGB triples
81
+ - `--colour_map_file`: Path to a file listing RGB triples
82
+ - `--model_type`: The type of model to train (default: pixelwise)
83
+
84
+ Note that one label image may have multiple corresponding noisy masks. A noisy mask is matched to a label when the mask's filename contains the label's basename (the filename without its extension, e.g. `my_image.png` -> `my_image`).
85
+
86
+ ## API
87
+
88
+ You can also use the package programmatically:
89
+
90
+ ```python
91
+ import numpy as np
92
+ from rgb_to_segmentation import clean, nn, train, utils
93
+
94
+ # Palette cleaning
95
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
96
+ palette = np.asarray(colours, dtype=np.uint8)
97
+ clean.clean_segmentation(input_dir="/path/to/input", output_dir="/path/to/output", palette=palette, output_type="index")
98
+
99
+ # NN inference
100
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
101
+ colour_map = {i: rgb for i, rgb in enumerate(colours)}
102
+ nn.run_inference(input_dir="/path/to/input", output_dir="/path/to/output", model_path="/path/to/model.ckpt", colour_map=colour_map, output_type="rgb")
103
+
104
+ # Train model
105
+ colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
106
+ colour_map = {i: rgb for i, rgb in enumerate(colours)}
107
+ train.train_model(image_dir="/path/to/images", label_dir="/path/to/labels", output_dir="/path/to/output", colour_map=colour_map)
108
+ ```
109
+
110
+ ## Contributing
111
+
112
+ Contributions are welcome! Please feel free to submit a Pull Request.
113
+
114
+ ## License
115
+
116
+ This project is licensed under the MIT License - see the LICENSE file for details.
@@ -0,0 +1,48 @@
1
+ [build-system]
2
+ requires = ["hatchling", "hatch-vcs"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "rgb-to-segmentation"
7
+ description = "Tools for processing and cleaning segmentation images using palette mapping and neural networks"
8
+ readme = "README.md"
9
+ license = { text = "MIT" }
10
+ requires-python = ">=3.8"
11
+ authors = [{ name = "Alex Senden" }]
12
+ maintainers = [{ name = "Alex Senden" }]
13
+ keywords = ["segmentation", "image-processing", "pytorch", "machine-learning"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Developers",
17
+ "License :: OSI Approved :: MIT License",
18
+ "Programming Language :: Python :: 3",
19
+ "Programming Language :: Python :: 3.8",
20
+ "Programming Language :: Python :: 3.9",
21
+ "Programming Language :: Python :: 3.10",
22
+ "Programming Language :: Python :: 3.11",
23
+ "Topic :: Scientific/Engineering :: Image Processing",
24
+ ]
25
+ dependencies = [
26
+ "torch>=1.9.0",
27
+ "torchvision>=0.10.0",
28
+ "pytorch-lightning>=1.5.0",
29
+ "numpy>=1.21.0",
30
+ "scipy>=1.7.0",
31
+ "Pillow>=8.0.0",
32
+ ]
33
+
34
+ # Version is automatically provided by hatch-vcs
35
+ dynamic = ["version"]
36
+
37
+ [project.scripts]
38
+ segment-clean = "rgb_to_segmentation.cli:main_clean"
39
+ segment-train = "rgb_to_segmentation.cli:main_train"
40
+
41
+ [tool.hatch.version]
42
+ source = "vcs"
43
+
44
+ [tool.hatch.build.targets.sdist]
45
+ include = ["rgb_to_segmentation/**", "README.md"]
46
+
47
+ [tool.hatch.build.targets.wheel]
48
+ include = ["rgb_to_segmentation/**"]
@@ -0,0 +1,219 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from scipy import ndimage
8
+
9
+
10
def nearest_palette_image(
    image_array: np.ndarray, palette: np.ndarray, chunk_size: int = 65536
) -> np.ndarray:
    """
    Snap every pixel of `image_array` to its nearest colour in `palette`.

    "Nearest" is squared Euclidean distance in RGB space; ties go to the colour
    that appears first in `palette`.

    Args:
        image_array: Image of shape (H, W, 3).
        palette: Colours of shape (K, 3), K >= 1.
        chunk_size: Maximum number of pixels whose distances are computed at
            once. The full distance tensor has N_pixels * K * 3 entries, so
            working in chunks bounds peak memory on large images.

    Returns:
        Recoloured image with shape (H, W, 3) and dtype uint8.

    Raises:
        ValueError: If `image_array` is not (H, W, 3), `palette` is not a
            non-empty (K, 3) array, or `chunk_size` is not positive.
    """
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        raise ValueError("image_array must have shape (H, W, 3)")
    pal = np.asarray(palette)
    if pal.ndim != 2 or pal.shape[1] != 3 or pal.shape[0] == 0:
        raise ValueError("palette must have shape (K, 3) with K >= 1")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    h, w, _ = image_array.shape
    flat = image_array.reshape(-1, 3).astype(np.int64)
    pal = pal.astype(np.int64)

    idx = np.empty(flat.shape[0], dtype=np.intp)
    for start in range(0, flat.shape[0], chunk_size):
        chunk = flat[start : start + chunk_size]
        # Squared distances, shape (len(chunk), K). int64 cannot overflow here.
        d = np.sum((chunk[:, None, :] - pal[None, :, :]) ** 2, axis=2)
        idx[start : start + chunk_size] = np.argmin(d, axis=1)

    return pal[idx].reshape(h, w, 3).astype(np.uint8)
30
+
31
+
32
def get_palette_for_image(
    image_array: np.ndarray, full_palette: np.ndarray
) -> np.ndarray:
    """
    Return the colours of `full_palette` that are the nearest match of some pixel.

    Each pixel of `image_array` (H, W, 3) is assigned its nearest palette colour
    by squared RGB distance (first colour wins ties). The used colours are
    returned in their original palette order.

    Args:
        image_array: Image of shape (H, W, 3).
        full_palette: Colours of shape (K, 3).

    Returns:
        Rows of `full_palette` that at least one pixel maps to.
    """
    # int64 is required: channel differences reach +-255 and their squares
    # (up to 65025) overflow int16, which previously corrupted the distances.
    flat_img = image_array.reshape(-1, 3).astype(np.int64)
    pal = np.asarray(full_palette).astype(np.int64)

    d = np.sum((flat_img[:, None, :] - pal[None, :, :]) ** 2, axis=2)
    idx = np.argmin(d, axis=1)

    # np.unique returns sorted indices, which preserves palette order.
    return full_palette[np.unique(idx)]
52
+
53
+
54
def apply_morphological_clean(image_array: np.ndarray, kernel_size: int) -> np.ndarray:
    """
    Remove small speckles from a palette image, class by class.

    Each colour's binary mask is morphologically *opened*: eroded `kernel_size`
    times, then dilated `kernel_size` times, with a 3x3 fully connected
    structuring element. Opening only ever shrinks a mask, so some pixels end
    up kept by no class. These are speckles too small to survive the opening,
    plus pixels eroded along image borders. Each of them takes the colour of
    the nearest kept pixel (Euclidean distance in pixel coordinates).

    Args:
        image_array: (H, W, 3) image whose pixels already take palette colours.
        kernel_size: Number of erosion/dilation iterations. Values <= 0 disable
            cleaning and return `image_array` itself.

    Returns:
        Cleaned image with the same shape and dtype. If no pixel of any class
        survives the opening, the input values are returned.
    """
    if kernel_size <= 0:
        return image_array

    structure = ndimage.generate_binary_structure(2, 2)
    unique_colours = np.unique(image_array.reshape(-1, 3), axis=0)

    result = np.zeros_like(image_array)
    # Track coverage explicitly. Testing `result != 0` would mistake pixels
    # assigned the colour (0, 0, 0) for unfilled ones.
    filled = np.zeros(image_array.shape[:2], dtype=bool)

    for colour in unique_colours:
        mask = np.all(image_array == colour, axis=-1)
        mask = ndimage.binary_erosion(mask, structure=structure, iterations=kernel_size)
        mask = ndimage.binary_dilation(mask, structure=structure, iterations=kernel_size)
        # Opened masks are subsets of the disjoint original masks, so they
        # never overlap and assignment order does not matter.
        result[mask] = colour
        filled |= mask

    unfilled = ~filled
    if not unfilled.any():
        return result
    if not filled.any():
        # Nothing survived the opening, so there is no neighbour to copy from.
        result[unfilled] = image_array[unfilled]
        return result

    # For every unfilled pixel, get the coordinates of the nearest filled one.
    # Background (zero) elements of `unfilled` are exactly the filled pixels.
    nearest = ndimage.distance_transform_edt(
        unfilled, return_distances=False, return_indices=True
    )
    rows = nearest[0][unfilled]
    cols = nearest[1][unfilled]
    result[unfilled] = result[rows, cols]
    return result
93
+
94
+
95
def rgb_image_to_index(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Map each RGB pixel of `image_array` to the index of its exact colour in `palette`.

    Vectorized with one boolean mask per palette colour instead of a per-pixel
    Python lookup. If a colour appears more than once in `palette`, the last
    occurrence's index wins.

    Args:
        image_array: Image of shape (H, W, 3) whose pixels take values from `palette`.
        palette: Colours of shape (K, 3).

    Returns:
        uint16 array of shape (H, W) holding palette indices.

    Raises:
        KeyError: If a pixel's colour is not in `palette`. The key is the
            first such pixel's (r, g, b) tuple in row-major order.
    """
    h, w, _ = image_array.shape
    pal = np.asarray(palette).reshape(-1, 3)

    idx = np.zeros((h, w), dtype=np.uint16)
    matched = np.zeros((h, w), dtype=bool)
    for i, colour in enumerate(pal):
        mask = np.all(image_array == colour, axis=-1)
        idx[mask] = i
        matched |= mask

    if not matched.all():
        first_bad = tuple(int(v) for v in image_array[~matched][0])
        raise KeyError(first_bad)
    return idx
106
+
107
+
108
def process_file(
    input_path: str,
    output_path: str,
    palette: np.ndarray,
    kernel_size: int,
    output_type: str = "rgb",
):
    """
    Clean a single segmentation image and write the result to `output_path`.

    Pixels are snapped to their nearest `palette` colour and then optionally
    cleaned morphologically. Unreadable images are reported and skipped.

    Args:
        input_path: Image to read (converted to RGB).
        output_path: Destination file; parent directories are created.
        palette: Colours of shape (K, 3).
        kernel_size: Morphological iterations; 0 disables cleaning.
        output_type: "rgb" for a colour image, or "index" for an 8-bit mask
            whose values are indices into the full `palette`.

    Raises:
        ValueError: If `output_type` is invalid, or it is "index" and the
            palette has more than 256 colours (indices would not fit in 8 bits).
    """
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")
    if output_type == "index" and len(palette) > 256:
        raise ValueError("output_type 'index' supports at most 256 palette colours")

    try:
        # Context manager closes the underlying file handle promptly.
        with Image.open(input_path) as src:
            img = src.convert("RGB")
    except Exception as e:
        # Best effort: unreadable files are reported and skipped, not fatal.
        print(f"Skipping {input_path}: cannot open image ({e})")
        return

    arr = np.array(img, dtype=np.uint8)

    # Snapping to the full palette is equivalent to first restricting it to the
    # colours that are some pixel's nearest match, so no reduction pass is needed.
    cleaned = nearest_palette_image(arr, palette)

    if kernel_size > 0:
        cleaned = apply_morphological_clean(cleaned, kernel_size)

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if output_type == "rgb":
        Image.fromarray(cleaned).save(output_path)
    else:
        # Index against the FULL palette so a class keeps the same id in every
        # image (per-image reduced palettes renumbered classes).
        index_mask = rgb_image_to_index(cleaned, palette)
        # A 2-D uint8 array is saved as a single-channel ("L") image.
        Image.fromarray(index_mask.astype(np.uint8)).save(output_path)
142
+
143
+
144
def process_directory(
    input_dir: str,
    output_dir: str,
    palette: np.ndarray,
    exts: List[str],
    inplace: bool,
    name_filter: str = "",
    kernel_size: int = 0,
    output_type: str = "rgb",
):
    """
    Recursively clean every matching image under `input_dir`.

    The directory structure is mirrored under `output_dir`, or files are
    overwritten in place when `inplace` is True.

    Args:
        input_dir: Root directory to walk.
        output_dir: Root of the mirrored output tree (ignored when `inplace`).
        palette: Colours of shape (K, 3) passed to `process_file`.
        exts: Allowed filename suffixes, compared case-insensitively.
        inplace: Write each result over its input file.
        name_filter: If non-empty, only files whose name contains it are processed.
        kernel_size: Morphological iterations passed to `process_file`.
        output_type: "rgb" or "index", passed to `process_file`.
    """
    suffixes = tuple(e.lower().strip() for e in exts)
    abs_out = os.path.abspath(output_dir)

    for root, dirs, files in os.walk(input_dir):
        if not inplace:
            # If output_dir lives inside input_dir, do not descend into it.
            # Otherwise freshly written outputs would be re-processed into a
            # nested copy.
            dirs[:] = [
                d for d in dirs if os.path.abspath(os.path.join(root, d)) != abs_out
            ]

        rel = os.path.relpath(root, input_dir)
        out_root = root if inplace else os.path.join(output_dir, rel)
        os.makedirs(out_root, exist_ok=True)

        for fname in files:
            if not fname.lower().endswith(suffixes):
                continue
            if name_filter and name_filter not in fname:
                continue
            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_root, fname)
            process_file(in_path, out_path, palette, kernel_size, output_type)
169
+
170
+
171
+ def clean_segmentation(
172
+ input_dir: str,
173
+ output_dir: str = None,
174
+ inplace: bool = False,
175
+ palette: np.ndarray = None,
176
+ exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
177
+ name_filter: str = "",
178
+ morph_kernel_size: int = 3,
179
+ output_type: str = "rgb",
180
+ ):
181
+ """
182
+ Clean segmentation images using palette-based color mapping.
183
+
184
+ Args:
185
+ input_dir (str): Path to input directory containing segmentation images.
186
+ output_dir (str, optional): Directory where cleaned images will be written. Required if not inplace.
187
+ inplace (bool): Overwrite input images in place.
188
+ palette (np.ndarray): Array of RGB triples (K, 3) uint8.
189
+ exts (str): Comma-separated list of allowed image extensions.
190
+ name_filter (str): Only process files whose name contains this substring.
191
+ morph_kernel_size (int): Size of morphological kernel for boundary cleaning.
192
+ """
193
+ if not inplace and output_dir is None:
194
+ raise ValueError("Either output_dir must be provided or inplace must be True")
195
+
196
+ if palette is None:
197
+ raise ValueError("palette must be provided")
198
+
199
+ exts_list = [e if e.startswith(".") else "." + e for e in exts.split(",")]
200
+
201
+ out_dir = output_dir if not inplace else input_dir
202
+
203
+ if not inplace:
204
+ os.makedirs(out_dir, exist_ok=True)
205
+
206
+ print(
207
+ f"Processing: input={input_dir} -> output={out_dir}, colours={len(palette)}, morph_kernel={morph_kernel_size}, output_type={output_type}"
208
+ )
209
+ process_directory(
210
+ input_dir,
211
+ out_dir,
212
+ palette,
213
+ exts_list,
214
+ inplace,
215
+ name_filter,
216
+ morph_kernel_size,
217
+ output_type,
218
+ )
219
+ print("Done.")
@@ -0,0 +1,179 @@
1
+ import argparse
2
+
3
+ import numpy as np
4
+
5
+ from . import clean, nn, train, utils
6
+
7
+
8
def main_clean():
    """
    Entry point for the ``segment-clean`` command.

    Parses command-line arguments and builds the colours from ``--colour_map``
    or ``--colour_map_file``. It then dispatches to
    ``clean.clean_segmentation`` (``--method palette``) or
    ``nn.run_inference`` (``--method nn``). Usage errors exit through
    ``parser.error`` before any colour file is read or model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Clean segmentation images using various methods."
    )

    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=["palette", "nn"],
        help="Cleaning method to use: 'palette' for color palette mapping, 'nn' for neural network.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Path to input directory containing images.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=False,
        help="Directory where cleaned images will be written. Required if not inplace.",
    )
    parser.add_argument(
        "--inplace",
        action="store_true",
        help="Overwrite input images in place.",
    )
    parser.add_argument(
        "--exts",
        type=str,
        default=".png,.jpg,.jpeg,.tiff,.bmp,.gif",
        help="Comma-separated list of allowed image extensions.",
    )
    parser.add_argument(
        "--name_filter",
        type=str,
        default="",
        help="Only process files whose name contains this substring.",
    )

    parser.add_argument(
        "--output_type",
        type=str,
        choices=["rgb", "index"],
        default="rgb",
        help="Output format: 'rgb' colour image or 'index' mask.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--colour_map",
        type=str,
        help="Semicolon-separated list of RGB triples.",
    )
    group.add_argument(
        "--colour_map_file",
        type=str,
        help="Path to a file listing RGB triples.",
    )

    # Palette-specific args
    parser.add_argument(
        "--morph_kernel_size",
        type=int,
        default=3,
        help="Size of morphological kernel for palette method.",
    )

    # NN-specific args
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to trained model for nn method.",
    )

    args = parser.parse_args()

    # Report missing arguments as usage errors (message + exit status 2)
    # instead of a ValueError traceback from inside the cleaning code.
    if not args.inplace and not args.output_dir:
        parser.error("--output_dir is required unless --inplace is given")
    if args.method == "nn" and not args.model_path:
        parser.error("--model_path required for nn method")

    if args.colour_map_file:
        colours = utils.parse_colours_from_file(args.colour_map_file)
    else:
        colours = utils.parse_colours_from_string(args.colour_map)

    if args.method == "palette":
        palette = np.asarray(colours, dtype=np.uint8)

        clean.clean_segmentation(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            palette=palette,
            exts=args.exts,
            name_filter=args.name_filter,
            morph_kernel_size=args.morph_kernel_size,
            output_type=args.output_type,
        )

    elif args.method == "nn":
        colour_map = {i: rgb for i, rgb in enumerate(colours)}
        nn.run_inference(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            model_path=args.model_path,
            colour_map=colour_map,
            exts=args.exts,
            name_filter=args.name_filter,
            output_type=args.output_type,
        )
121
+
122
+
123
def main_train():
    """
    Entry point for the ``segment-train`` command.

    Reads the training arguments and turns the colours given via
    ``--colour_map`` or ``--colour_map_file`` into an index -> RGB mapping
    (index = position in the list). Everything is then handed to
    ``train.train_model``.
    """
    parser = argparse.ArgumentParser(
        description="Train a neural network model for segmentation cleaning."
    )

    # The three directory options are all mandatory strings.
    directory_options = (
        ("--image_dir", "Path to directory containing noisy images."),
        ("--label_dir", "Path to directory containing target RGB labels."),
        ("--output_dir", "Directory where model weights will be saved."),
    )
    for flag, help_text in directory_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--model_type", type=str, default="pixelwise", help="The type of model to train."
    )

    # Exactly one colour source must be given.
    colour_source = parser.add_mutually_exclusive_group(required=True)
    colour_source.add_argument(
        "--colour_map", type=str, help="Semicolon-separated list of RGB triples."
    )
    colour_source.add_argument(
        "--colour_map_file", type=str, help="Path to a file listing RGB triples."
    )

    args = parser.parse_args()

    colours = (
        utils.parse_colours_from_file(args.colour_map_file)
        if args.colour_map_file
        else utils.parse_colours_from_string(args.colour_map)
    )

    train.train_model(
        image_dir=args.image_dir,
        label_dir=args.label_dir,
        output_dir=args.output_dir,
        colour_map=dict(enumerate(colours)),
        model_type=args.model_type,
    )
@@ -0,0 +1,57 @@
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class PixelwiseClassifier(pl.LightningModule):
    """
    Per-pixel MLP classifier: maps each pixel's channel vector to class scores.

    Images are flattened into rows of pixels with `image_to_batch`, so each
    pixel is classified independently of its neighbours. `forward` returns
    softmax probabilities; training uses the raw logits.
    """

    # Class-level counters kept for backward compatibility; unused in this class.
    step = 0
    val_step = 0

    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Args:
            input_dim: Number of features (channels) per pixel.
            hidden_dim: Width of the single hidden layer.
            output_dim: Number of classes.
        """
        super().__init__()
        self.save_hyperparameters()

        self.output_dim = output_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Return class probabilities (softmax over the last dim) for pixel rows `x`."""
        return self.softmax(self.net(x))

    def _compute_loss(self, batch):
        """Cross-entropy loss for one (sample, target) batch of pixel rows."""
        sample, target = batch
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw
        # logits. Feeding it forward()'s probabilities would softmax twice and
        # flatten the gradients.
        # Flattening puts classes on dim 1 however the pixels were batched,
        # e.g. (B, H*W, K) from a DataLoader.
        logits = self.net(sample).reshape(-1, self.output_dim)
        # Integer class targets (one per pixel) also work on torch versions
        # that lack probability targets.
        labels = target.reshape(-1).long()
        return self.loss_fn(logits, labels)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def image_to_batch(self, x: torch.Tensor):
        """Flatten (C, H, W) or (N, C, H, W) to (N*H*W, C) rows on the model's device."""
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]).to(self.device)

    def batch_to_image(self, batch: torch.Tensor, height: int, width: int):
        """Reshape (height*width, C) rows back to a (1, C, height, width) image."""
        # -1 infers the channel count (previously hard-coded to 3).
        return batch.reshape(1, height, width, -1).permute(0, 3, 1, 2)
@@ -0,0 +1,110 @@
1
+ import os
2
+
3
+ import torch
4
+
5
+ from PIL import Image
6
+ from torchvision.io import read_image
7
+
8
+ from .models.pixelwise_classifier import PixelwiseClassifier
9
+
10
+
11
def load_model(model_path: str, map_location=None):
    """
    Load a trained PixelwiseClassifier checkpoint and switch it to eval mode.

    Args:
        model_path: Path to the Lightning checkpoint file.
        map_location: Optional device remapping forwarded to
            `load_from_checkpoint`, e.g. "cpu" to load a GPU-trained checkpoint
            on a CPU-only machine. None keeps the default behaviour.

    Returns:
        The loaded model in eval mode.
    """
    print(f"Loading model from {model_path}")
    model = PixelwiseClassifier.load_from_checkpoint(
        checkpoint_path=model_path, map_location=map_location
    )
    model.eval()

    return model
17
+
18
+
19
def map_int_to_rgb(indexed_image: torch.Tensor, colour_map: dict):
    """
    Convert an (H, W) class-index tensor to a (3, H, W) uint8 RGB tensor.

    The input is moved to the CPU first. The output is built on the CPU, and
    indexing it with a GPU mask would otherwise fail. Indices missing from
    `colour_map` stay black (0, 0, 0).

    Args:
        indexed_image: Integer tensor of shape (H, W).
        colour_map: Mapping from class index to an (r, g, b) triple.

    Returns:
        CPU uint8 tensor of shape (3, H, W).
    """
    indexed_image = indexed_image.cpu()
    h, w = indexed_image.shape
    rgb_image = torch.zeros((3, h, w), dtype=torch.uint8)

    for idx, rgb in colour_map.items():
        mask = indexed_image == idx
        # Selected pixels form a (3, n) view; the (3, 1) colour broadcasts over it.
        rgb_image[:, mask] = torch.tensor(rgb, dtype=torch.uint8).view(3, 1)

    return rgb_image
29
+
30
+
31
def run_inference(
    input_dir: str,
    output_dir: str,
    inplace: bool = False,
    model_path: str = None,
    colour_map: dict = None,
    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
    name_filter: str = "",
    output_type: str = "rgb",
):
    """
    Run neural network inference on segmentation images.

    Args:
        input_dir (str): Path to input directory containing images (searched recursively).
        output_dir (str, optional): Directory where output images will be written.
        inplace (bool): Overwrite input images in place.
        model_path (str): Path to the trained model file.
        colour_map (dict): Mapping from class indices to RGB tuples.
        exts (str): Comma-separated list of allowed image extensions; a leading
            dot is optional and matching is case-insensitive.
        name_filter (str): Only process files whose name contains this substring.
        output_type (str): "rgb" for colour images or "index" for 8-bit index masks.

    Raises:
        ValueError: On missing output_dir (when not inplace), model_path or
            colour_map, or an unknown output_type. Per-file failures are
            printed and the file is skipped.
    """
    if not inplace and output_dir is None:
        raise ValueError("Either output_dir must be provided or inplace must be True")

    if model_path is None:
        raise ValueError("model_path must be provided")

    if colour_map is None:
        raise ValueError("colour_map must be provided")

    # Validate up front. Inside the per-file try/except below, a bad value
    # would only print a "Skipping" line for every image.
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    # ImageReadMode works on every torchvision release that supports `mode`.
    # Plain strings such as "RGB" are only accepted by recent releases.
    from torchvision.io import ImageReadMode

    model = load_model(model_path)

    suffixes = []
    for raw in exts.split(","):
        ext = raw.strip().lower()
        if ext:
            suffixes.append(ext if ext.startswith(".") else "." + ext)
    suffixes = tuple(suffixes)

    out_dir = output_dir if not inplace else input_dir

    if not inplace:
        os.makedirs(out_dir, exist_ok=True)
    abs_out = os.path.abspath(out_dir)

    print(f"Running inference: input={input_dir} -> output={out_dir}, output_type={output_type}")

    for root, dirs, files in os.walk(input_dir):
        if not inplace:
            # Never descend into the output tree if it lives inside input_dir.
            dirs[:] = [
                d for d in dirs if os.path.abspath(os.path.join(root, d)) != abs_out
            ]

        rel = os.path.relpath(root, input_dir)
        out_root = os.path.join(out_dir, rel) if not inplace else root

        os.makedirs(out_root, exist_ok=True)

        for fname in files:
            if not fname.lower().endswith(suffixes):
                continue
            if name_filter and name_filter not in fname:
                continue

            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_root, fname)

            try:
                # Scale uint8 [0, 255] to [-1, 1], the same scaling the training
                # dataset applies.
                img = read_image(in_path, mode=ImageReadMode.RGB).float() / 127.5 - 1.0
                h, w = img.shape[1], img.shape[2]

                batch = model.image_to_batch(img)

                with torch.no_grad():
                    probs = model(batch)
                # Bring predictions to the CPU so they can become NumPy/PIL images.
                predicted = torch.argmax(probs, dim=-1).reshape(h, w).cpu()

                if output_type == "rgb":
                    rgb_image = map_int_to_rgb(predicted, colour_map)
                    pil_image = Image.fromarray(rgb_image.permute(1, 2, 0).cpu().numpy())
                    pil_image.save(out_path)
                else:
                    if int(predicted.max()) > 255:
                        raise ValueError("class index exceeds 255; cannot store as an 8-bit mask")
                    # A 2-D uint8 array is saved as a single-channel ("L") image.
                    pil_image = Image.fromarray(predicted.numpy().astype("uint8"))
                    pil_image.save(out_path)

            except Exception as e:
                # Best effort: report the failing file and continue with the rest.
                print(f"Skipping {in_path}: {e}")

    print("Inference done.")
@@ -0,0 +1,183 @@
1
import os
from typing import Dict, List, Tuple

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.io import ImageReadMode, read_image

from .nn import PixelwiseClassifier
11
+
12
+
13
# Share of the *label* images that get_dataloaders holds out for validation;
# the remaining labels (and their noisy images) are used for training.
VALIDATION_FRACTION: float = 0.2
14
+
15
+
16
def map_colour_to_int(sample, colour_map):
    """Convert an RGB label image into a single-channel class-index map.

    Args:
        sample: Tensor of shape (3, H, W) holding RGB values.
        colour_map (dict): Mapping from class index to an (R, G, B) triple.

    Returns:
        Float tensor of shape (1, H, W), on the same device as ``sample``,
        where each pixel holds the index whose colour it matches exactly.
        Pixels that match no colour keep the value 0. If two indices share a
        colour, the one visited last in ``colour_map`` wins.
    """
    _, height, width = sample.shape
    # Allocate on the input's device so GPU tensors work as well as CPU ones.
    class_map = torch.zeros((1, height, width), device=sample.device)

    for idx, rgb in colour_map.items():
        colour = torch.tensor(rgb, device=sample.device).view(3, 1, 1)
        # A pixel matches only when all three channels are equal.
        mask = torch.all(sample == colour, dim=0)
        class_map[0][mask] = idx

    return class_map
25
+
26
+
27
class SegMaskDataset(Dataset):
    """Map-style dataset of (noisy RGB image, class-index mask) training pairs.

    Images are read from disk lazily in ``__getitem__``; nothing is cached.
    """

    def __init__(self, paired_filenames, colour_map, model):
        """
        Args:
            paired_filenames (list[dict]): One dict per item with a ``"sample"``
                key (noisy image path) and a ``"target"`` key (RGB label path).
            colour_map (dict): Mapping from class index to RGB tuple, used to
                convert the RGB label into class indices.
            model: Model whose ``image_to_batch`` method reformats a
                (C, H, W) tensor into the layout the model consumes
                (presumably one row per pixel - confirm in the model class).
        """
        self.paired_filenames = paired_filenames
        self.colour_map = colour_map
        # NOTE(review): a bound method keeps a reference to the whole model, so
        # the model is pickled into each DataLoader worker on spawn platforms.
        self.to_batch = model.image_to_batch

    def __len__(self):
        return len(self.paired_filenames)

    def __getitem__(self, index):
        """Return ``(sample, target)`` for item ``index``.

        ``sample`` is the noisy image rescaled from [0, 255] to [-1, 1];
        ``target`` holds ``torch.long`` class indices. Both have been passed
        through ``image_to_batch``.
        """
        pair = self.paired_filenames[index]

        # Use the ImageReadMode enum instead of the string "RGB": older
        # torchvision releases (the package allows >= 0.10) only accept the enum.
        sample = read_image(pair["sample"], mode=ImageReadMode.RGB) / 127.5 - 1.0
        target = read_image(pair["target"], mode=ImageReadMode.RGB)
        target = map_colour_to_int(target, self.colour_map)

        return self.to_batch(sample), self.to_batch(target).to(torch.long)
49
+
50
+
51
def get_png_basenames(directory: str) -> List[str]:
    """Return the names, without extension, of the ``.png`` files in ``directory``.

    Only the top level of ``directory`` is listed and the ``.png`` suffix match
    is case-sensitive. The result is sorted: ``os.listdir`` order is arbitrary,
    and sorting makes a later random split (e.g. in ``get_dataloaders``)
    depend only on the RNG state.

    ``List[str]`` is used instead of ``list[str]`` because the annotation is
    evaluated when the function is defined, and ``list[...]`` is not
    subscriptable on Python 3.8.
    """
    return sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(directory)
        if name.endswith(".png")
    )
57
+
58
+
59
def get_paired_filenames(
    image_dir: str,
    label_dir: str,
    noisy_basenames: List[str],
    label_basenames: List[str],
    training_label_basenames: List[str],
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Pair each noisy image with its label and assign it to a split.

    A label matches a noisy image when the label's basename is a substring of
    the noisy basename. If several labels match and one is contained in
    another (e.g. "img_1" and "img_10" both match "img_10_noisy"), the
    contained ones are ignored so the most specific label wins. Noisy images
    with no match are reported and discarded.

    Args:
        image_dir: Directory of the noisy ``.png`` images.
        label_dir: Directory of the label ``.png`` images.
        noisy_basenames: Noisy image names without extension.
        label_basenames: Label image names without extension.
        training_label_basenames: Labels whose pairs go to the training split;
            all other matched pairs go to validation. Any iterable of names
            (e.g. a ``random_split`` Subset) is accepted.

    Returns:
        ``(training_pairs, val_pairs)``: two lists of dicts with ``"sample"``
        (noisy image path) and ``"target"`` (label path) keys.

    Raises:
        ValueError: If more than one unrelated label matches a noisy image.
    """
    # A Subset has no __contains__, so ``in`` would scan it linearly for
    # every noisy image; materialise it once as a set instead.
    training_labels = set(training_label_basenames)

    training_paired_filenames: List[Dict[str, str]] = []
    val_paired_filenames: List[Dict[str, str]] = []

    for noisy_image_name in noisy_basenames:
        candidates = [
            label for label in label_basenames if label in noisy_image_name
        ]
        # Keep only maximal matches: drop labels that are substrings of
        # another matching label.
        targets = [
            label
            for label in candidates
            if not any(label != other and label in other for other in candidates)
        ]

        if len(targets) > 1:
            raise ValueError(
                f"Multiple target files exist for noisy file {noisy_image_name}: {targets}."
            )
        if not targets:
            print(
                f"WARNING: No target found for noisy file {noisy_image_name}. Discarding."
            )
            continue

        target_filename = targets[0]
        pair = {
            "sample": os.path.join(image_dir, f"{noisy_image_name}.png"),
            "target": os.path.join(label_dir, f"{target_filename}.png"),
        }

        if target_filename in training_labels:
            training_paired_filenames.append(pair)
        else:
            val_paired_filenames.append(pair)

    return training_paired_filenames, val_paired_filenames
103
+
104
+
105
def get_dataloaders(
    image_dir, label_dir, colour_map, model
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders.

    The random train/validation split is made over the *label* images, and
    every noisy image follows its label, so all noisy variants of one label
    land on the same side of the split.

    Args:
        image_dir (str): Directory containing the noisy ``.png`` images.
        label_dir (str): Directory containing the target RGB ``.png`` labels.
        colour_map (dict): Mapping from class index to RGB tuple.
        model: Model whose ``image_to_batch`` formats each dataset item.

    Returns:
        ``(train_dataloader, val_dataloader)``.

    Raises:
        ValueError: If the training or validation split ends up with no pairs.
    """
    noisy_basenames = get_png_basenames(image_dir)
    label_basenames = get_png_basenames(label_dir)

    # Integer split lengths instead of fractions: fractional lengths in
    # random_split need torch >= 1.13, while the package allows torch >= 1.9.
    # floor() for validation with the remainder going to training yields the
    # same lengths random_split derives from [1 - f, f].
    num_val = int(len(label_basenames) * VALIDATION_FRACTION)
    num_train = len(label_basenames) - num_val
    training_label_basenames, _ = random_split(
        label_basenames, [num_train, num_val]
    )

    training_paired_filenames, val_paired_filenames = get_paired_filenames(
        image_dir,
        label_dir,
        noisy_basenames,
        label_basenames,
        training_label_basenames,
    )

    # Fail with an actionable message rather than RandomSampler's opaque
    # "num_samples should be a positive integer value" error.
    if not training_paired_filenames or not val_paired_filenames:
        raise ValueError(
            f"Need at least one training and one validation pair, got "
            f"{len(training_paired_filenames)} training and "
            f"{len(val_paired_filenames)} validation pairs "
            f"(image_dir={image_dir}, label_dir={label_dir})."
        )

    train_dataloader = DataLoader(
        dataset=SegMaskDataset(training_paired_filenames, colour_map, model),
        batch_size=1,
        shuffle=True,
        num_workers=2,
    )
    # Validation is not shuffled: the order does not change which samples are
    # evaluated, and Lightning recommends turning shuffling off here.
    val_dataloader = DataLoader(
        dataset=SegMaskDataset(val_paired_filenames, colour_map, model),
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )

    return train_dataloader, val_dataloader
140
+
141
+
142
def get_model(model_type, colour_map):
    """Instantiate an untrained segmentation model.

    The output layer needs one logit for every class index that can appear in
    a target. ``map_colour_to_int`` writes the ``colour_map`` keys directly as
    class indices (and 0 for unmatched pixels). The head is therefore sized to
    cover the largest key, not just ``len(colour_map)``, which is too small
    when the indices are not contiguous from 0. For the usual 0..N-1 keys both
    sizes are identical.

    Args:
        model_type (str): Model architecture; only ``"pixelwise"`` is supported.
        colour_map (dict): Mapping from integer class index to RGB tuple.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``colour_map`` is empty or ``model_type`` is unknown.
    """
    if not colour_map:
        raise ValueError("colour_map must contain at least one class")

    num_classes = max(len(colour_map), int(max(colour_map)) + 1)

    if model_type == "pixelwise":
        return PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=num_classes)

    raise ValueError(f"Invalid model type: {model_type}")
149
+
150
+
151
def train_model(
    image_dir: str,
    label_dir: str,
    output_dir: str,
    colour_map: dict,
    model_type: str = "pixelwise",
    max_epochs: int = 100,
) -> str:
    """
    Train a neural network model for segmentation cleaning.

    Args:
        image_dir (str): Path to directory containing noisy images.
        label_dir (str): Path to directory containing target RGB labels.
        output_dir (str): Directory where model weights will be saved
            (created if missing).
        colour_map (dict): Mapping from class indices to RGB tuples.
        model_type (str): The type of model to train.
        max_epochs (int): Upper bound on training epochs; early stopping on
            ``val_loss`` may end training sooner. Defaults to 100.

    Returns:
        str: Path of the best checkpoint saved in ``output_dir``
        (``ModelCheckpoint.best_model_path``; empty if nothing was saved).
    """
    model = get_model(model_type, colour_map)
    train_dataloader, val_dataloader = get_dataloaders(
        image_dir, label_dir, colour_map, model
    )

    os.makedirs(output_dir, exist_ok=True)

    # Both callbacks track "val_loss". NOTE(review): assumes the model logs a
    # metric under exactly that name during validation - confirm in the model.
    early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
    checkpoint = ModelCheckpoint(
        monitor="val_loss", mode="min", save_top_k=1, dirpath=output_dir
    )
    trainer = Trainer(max_epochs=max_epochs, callbacks=[early_stop, checkpoint])

    trainer.fit(
        model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
    )

    return checkpoint.best_model_path
@@ -0,0 +1,40 @@
1
+ from typing import List, Tuple
2
+
3
+
4
def parse_colours_from_string(colours_str: str) -> List[Tuple[int, int, int]]:
    """Parse a semicolon-separated string of RGB triples into a list of tuples.

    Example: ``"255,0,0; 0,255,0"`` -> ``[(255, 0, 0), (0, 255, 0)]``.
    Whitespace around components is allowed and empty segments (such as a
    trailing ``;``) are skipped, so an empty string yields ``[]``.

    Raises:
        ValueError: If a segment does not consist of exactly three
            comma-separated integers in the 8-bit range 0-255.
    """
    colours = []

    for segment in colours_str.split(";"):
        segment = segment.strip()
        if not segment:
            continue

        try:
            rgb = tuple(int(x) for x in segment.split(","))
        except ValueError as err:
            # Report the offending triple instead of int()'s generic message.
            raise ValueError(f"Invalid colour triple: {segment}") from err

        # Values outside 0-255 can never match an 8-bit RGB pixel.
        if len(rgb) != 3 or not all(0 <= c <= 255 for c in rgb):
            raise ValueError(f"Invalid colour triple: {segment}")
        colours.append(rgb)

    return colours
18
+
19
+
20
def parse_colours_from_file(path: str) -> List[Tuple[int, int, int]]:
    """Parse a file with one RGB triple per line into a list of tuples.

    Each non-empty line holds ``R,G,B`` (whitespace allowed). Blank lines and
    lines starting with ``#`` are skipped.

    Raises:
        ValueError: If a line is not three comma-separated integers in the
            range 0-255, or if the file contains no colours at all.
        OSError: If the file cannot be opened.
    """
    colours = []

    # Explicit encoding instead of the platform locale; "utf-8-sig" also
    # strips a byte-order mark that some Windows editors prepend, which would
    # otherwise make the first triple fail to parse.
    with open(path, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue

            try:
                rgb = tuple(int(x) for x in line.split(","))
            except ValueError as err:
                raise ValueError(f"Invalid colour triple in file {path}: {line}") from err

            # Values outside 0-255 can never match an 8-bit RGB pixel.
            if len(rgb) != 3 or not all(0 <= c <= 255 for c in rgb):
                raise ValueError(f"Invalid colour triple in file {path}: {line}")

            colours.append(rgb)

    if not colours:
        raise ValueError(f"No colours found in file: {path}")

    return colours