rgb-to-segmentation 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,10 @@
1
+ from .api import clean_image
2
+ from . import clean, nn, train, utils
3
+
4
+ __all__ = [
5
+ "clean_image",
6
+ "clean",
7
+ "nn",
8
+ "train",
9
+ "utils",
10
+ ]
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from typing import Dict, Tuple, Optional, Union
5
+
6
+ from .clean import clean_image_palette, clean_image_strict_palette
7
+ from .nn import clean_image_nn
8
+
9
+
10
+ ImageArray = Union[np.ndarray, torch.Tensor]
11
+
12
+
13
def _to_numpy(image_array: ImageArray):
    """
    Convert a supported image container into a numpy array.

    Args:
        image_array: A numpy array or a torch tensor.

    Returns:
        A 4-tuple ``(array, is_torch, dtype, device)``. Numpy input is returned
        unchanged together with ``(False, None, None)``. A torch tensor is
        detached, moved to CPU and converted to numpy; its original dtype and
        device are returned so the caller can restore them afterwards.

    Raises:
        TypeError: If the input is neither a numpy array nor a torch tensor.
    """
    if isinstance(image_array, torch.Tensor):
        as_numpy = image_array.detach().cpu().numpy()
        return as_numpy, True, image_array.dtype, image_array.device

    if not isinstance(image_array, np.ndarray):
        raise TypeError("image_array must be a numpy.ndarray or torch.Tensor")

    return image_array, False, None, None
24
+
25
+
26
def _to_original_type(
    np_array: np.ndarray,
    is_torch: bool,
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> ImageArray:
    """
    Convert a numpy result back into the container type of the original input.

    Args:
        np_array: The numpy array to convert.
        is_torch: Whether the original input was a torch tensor.
        dtype: Tensor dtype to restore, or None to keep the converted dtype.
        device: Tensor device to restore, or None to keep the converted device.

    Returns:
        ``np_array`` itself when ``is_torch`` is false; otherwise a tensor built
        from it, cast/moved only when its dtype/device differ from the request.
    """
    if is_torch:
        restored = torch.from_numpy(np_array)

        needs_cast = dtype is not None and restored.dtype != dtype
        if needs_cast:
            restored = restored.to(dtype=dtype)

        needs_move = device is not None and restored.device != device
        if needs_move:
            restored = restored.to(device=device)

        return restored

    return np_array
43
+
44
+
45
def clean_image(
    image_array: ImageArray,
    method: str,
    colour_map: Dict[int, Tuple[int, int, int]],
    *,
    model: Optional[object] = None,
    morph_kernel_size: int = 0,
    output_type: str = "rgb",
) -> ImageArray:
    """
    Clean one image with the chosen method and return it in the input's container type.

    Args:
        image_array: numpy array or torch tensor of shape (H, W, 3).
        method: One of "palette", "strict_palette", "pixel_decoder" or "cnn_decoder".
        colour_map: Dict mapping class index -> (r, g, b); used by every method.
        model: Trained model; must be given for "pixel_decoder" and "cnn_decoder".
        morph_kernel_size: Morphological clean size, forwarded to the "palette"
            method only.
        output_type: "rgb" for a colour image, "index" for an integer mask.

    Returns:
        The cleaned image, converted back to a numpy array or torch tensor
        (with the original dtype/device) to match the input.

    Raises:
        TypeError: If ``image_array`` is not a numpy array or torch tensor.
        ValueError: If the shape is not (H, W, 3), ``output_type`` or ``method``
            is unknown, or a decoder method is requested without ``model``.
    """
    np_image, is_torch, orig_dtype, orig_device = _to_numpy(image_array)

    has_rgb_layout = np_image.ndim == 3 and np_image.shape[2] == 3
    if not has_rgb_layout:
        raise ValueError("image_array must have shape (H, W, 3)")

    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    if method == "palette":
        # Palette rows follow ascending class-index order of colour_map.
        ordered_colours = [colour_map[key] for key in sorted(colour_map.keys())]
        palette = np.asarray(ordered_colours, dtype=np.uint8)
        cleaned = clean_image_palette(
            np_image,
            palette=palette,
            morph_kernel_size=morph_kernel_size,
            output_type=output_type,
        )

    elif method == "strict_palette":
        cleaned = clean_image_strict_palette(
            np_image,
            colour_map=colour_map,
            output_type=output_type,
        )

    elif method in ("pixel_decoder", "cnn_decoder"):
        # Both decoder flavours share the same inference entry point.
        if model is None:
            if method == "pixel_decoder":
                raise ValueError("model must be provided for method='pixel_decoder'")
            raise ValueError("model must be provided for method='cnn_decoder'")

        cleaned = clean_image_nn(
            np_image,
            model=model,
            colour_map=colour_map,
            output_type=output_type,
        )

    else:
        raise ValueError(
            "method must be 'palette', 'strict_palette', 'pixel_decoder', or 'cnn_decoder'"
        )

    return _to_original_type(cleaned, is_torch, orig_dtype, orig_device)
@@ -0,0 +1,315 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from scipy import ndimage
8
+ import torch
9
+
10
+
11
def clean_image_palette(
    image_array: np.ndarray,
    palette: np.ndarray,
    morph_kernel_size: int = 0,
    output_type: str = "rgb",
) -> np.ndarray:
    """
    Clean a single RGB image using palette-based nearest-colour mapping.

    Args:
        image_array: (H, W, 3) uint8
        palette: (K, 3) uint8
        morph_kernel_size: kernel size for morphological cleaning (0 disables)
        output_type: 'rgb' to return colour image; 'index' to return integer mask

    Returns:
        np.ndarray: (H,W,3) uint8 if output_type='rgb'; else (H,W) uint8 where
        each value is the row index of the pixel's colour in ``palette``.

    Raises:
        ValueError: If output_type is neither 'rgb' nor 'index'.
    """
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    # Restrict the palette to the colours this image actually maps to
    reduced_palette = get_palette_for_image(image_array, palette)

    # Map to nearest colours
    cleaned_rgb = nearest_palette_image(image_array, reduced_palette)

    # Optional morphological clean
    if morph_kernel_size > 0:
        cleaned_rgb = apply_morphological_clean(cleaned_rgb, morph_kernel_size)

    if output_type == "rgb":
        return cleaned_rgb

    # Index against the caller's full palette, not the per-image reduced one:
    # positions in the reduced palette change from image to image, so the same
    # class would otherwise receive different indices in different masks.
    return rgb_image_to_index(cleaned_rgb, palette).astype(np.uint8)
46
+
47
+
48
def clean_image_strict_palette(
    image_array: np.ndarray,
    colour_map: dict,
    output_type: str = "rgb",
) -> np.ndarray:
    """
    Strictly map RGB values to indices based on colour_map.
    Raises an error if any RGB value is not found in the colour_map.

    Args:
        image_array: (H, W, 3) uint8 numpy array
        colour_map: Dict mapping class index -> (r,g,b)
        output_type: 'rgb' to return colour image; 'index' to return integer mask

    Returns:
        np.ndarray: (H,W,3) uint8 if output_type='rgb'; else (H,W) uint8

    Raises:
        ValueError: If output_type is invalid, if the image contains a colour
            that is not in colour_map, or if output_type='index' and a class
            index used by the image does not fit in uint8 (0..255).
    """
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    h, w, _ = image_array.shape
    flat_img = image_array.reshape(-1, 3)

    # Build reverse lookup: RGB tuple -> class index
    rgb_to_idx = {tuple(map(int, rgb)): idx for idx, rgb in colour_map.items()}

    # Unique colours plus, for every pixel, the position of its colour in that
    # list. All per-colour work below is done once per unique colour instead of
    # once per pixel.
    unique_colours, inverse = np.unique(flat_img, axis=0, return_inverse=True)
    # Some numpy releases return a non-flat inverse when axis is given.
    inverse = inverse.reshape(-1)
    unique_tuples = [tuple(map(int, colour)) for colour in unique_colours]

    # Check if all colours are in the colour_map
    unmapped_colours = [c for c in unique_tuples if c not in rgb_to_idx]
    if unmapped_colours:
        # Format error message with unmapped colours
        colour_strs = [f"RGB{c}" for c in unmapped_colours[:10]]  # Show first 10
        if len(unmapped_colours) > 10:
            colour_strs.append(f"... and {len(unmapped_colours) - 10} more")
        raise ValueError(
            f"Image contains {len(unmapped_colours)} RGB value(s) not in colour_map: "
            f"{', '.join(colour_strs)}. All pixel values must exactly match a colour in the map."
        )

    if output_type == "rgb":
        # Every pixel was validated against the map, so the RGB result is the
        # validated colour itself; no round-trip through class indices (which
        # could wrap for indices above 255) is needed.
        colour_table = np.array(unique_tuples, dtype=np.uint8).reshape(-1, 3)
        return colour_table[inverse].reshape(h, w, 3)

    class_ids = [rgb_to_idx[c] for c in unique_tuples]
    too_large = [i for i in class_ids if not 0 <= i <= 255]
    if too_large:
        # A uint8 mask cannot represent these; wrapping silently would merge classes.
        raise ValueError(
            f"Class index values {too_large[:10]} cannot be stored in a uint8 index mask"
        )

    index_table = np.array(class_ids, dtype=np.uint8)
    return index_table[inverse].reshape(h, w)
109
+
110
+
111
def nearest_palette_image(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Recolour every pixel of `image_array` (H,W,3) with the closest colour of `palette` (K,3).

    Closeness is the squared Euclidean distance in RGB space; on ties the colour
    that comes first in `palette` is chosen. The result has the input's (H,W,3)
    shape and dtype uint8.

    Raises:
        ValueError: If `image_array` does not have shape (H, W, 3).
    """
    if not (image_array.ndim == 3 and image_array.shape[2] == 3):
        raise ValueError("image_array must have shape (H, W, 3)")

    height, width = image_array.shape[:2]

    # Widen before subtracting so differences and squares cannot wrap.
    pixels = image_array.astype(np.int64).reshape(-1, 3)
    colours = palette.astype(np.int64)

    # offsets: (N_pixels, K, 3) -> squared distances: (N_pixels, K)
    offsets = pixels[:, np.newaxis, :] - colours[np.newaxis, :, :]
    squared_distances = np.square(offsets).sum(axis=2)

    nearest = squared_distances.argmin(axis=1)
    recoloured = colours[nearest].astype(np.uint8)
    return recoloured.reshape(height, width, 3)
131
+
132
+
133
def get_palette_for_image(
    image_array: np.ndarray, full_palette: np.ndarray
) -> np.ndarray:
    """
    Identify which colours from the full palette are the nearest match for at
    least one pixel of the image, and return only those colours.

    Args:
        image_array: (H, W, 3) image.
        full_palette: (K, 3) array of candidate colours.

    Returns:
        The rows of ``full_palette`` that are used, in their original palette order.

    Raises:
        ValueError: If ``image_array`` does not have shape (H, W, 3).
    """
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        raise ValueError("image_array must have shape (H, W, 3)")

    # int32 is required here: a channel difference can reach 255 and its square
    # (65025) does not fit in int16, where it wraps to a negative number and
    # makes far-away colours look like the nearest ones.
    flat_img = image_array.reshape(-1, 3).astype(np.int32)
    pal = full_palette.astype(np.int32)

    # For each pixel, find the nearest palette colour
    d = np.sum((flat_img[:, None, :] - pal[None, :, :]) ** 2, axis=2)
    idx = np.argmin(d, axis=1)

    # Return only the palette colours that are actually used
    return full_palette[np.unique(idx)]
153
+
154
+
155
def apply_morphological_clean(image_array: np.ndarray, kernel_size: int) -> np.ndarray:
    """
    Apply morphological opening (erosion followed by dilation) per class to clean up
    class boundaries and remove noise.

    Each distinct colour is treated as one class. Its mask is eroded and then
    dilated ``kernel_size`` times with a 3x3 fully-connected structuring element,
    which removes regions too small to survive the erosion. Pixels that end up
    belonging to no class are given the colour of the nearest surviving pixel.

    Args:
        image_array: (H, W, 3) image whose pixels take a small set of colours.
        kernel_size: Number of erosion/dilation iterations; <= 0 disables cleaning.

    Returns:
        The input itself when ``kernel_size <= 0``; otherwise a new array with
        the same shape and dtype.
    """
    if kernel_size <= 0:
        return image_array

    # 3x3 structuring element including diagonals
    kernel = ndimage.generate_binary_structure(2, 2)

    # Get unique colours that actually appear in the image
    h, w, _ = image_array.shape
    unique_colours = np.unique(image_array.reshape(-1, 3), axis=0)

    # Process each class separately to avoid blending
    result = np.zeros_like(image_array)
    # Track coverage explicitly; testing "result != 0" would mistake a genuine
    # black class for unfilled pixels.
    filled = np.zeros((h, w), dtype=bool)

    for colour in unique_colours:
        # Create binary mask for this class
        mask = np.all(image_array == colour, axis=-1)

        # Opening: erosion then dilation
        mask = ndimage.binary_erosion(mask, structure=kernel, iterations=kernel_size)
        mask = ndimage.binary_dilation(mask, structure=kernel, iterations=kernel_size)

        # Opening never grows a mask, so class masks cannot overlap here.
        result[mask] = colour
        filled |= mask

    if filled.all():
        return result

    if not filled.any():
        # Every region was too small to survive; nothing to borrow colours from.
        return image_array.copy()

    # For each unfilled pixel look up the coordinates of the closest filled
    # pixel and copy its colour. Copying the original pixel back instead would
    # undo the opening and turn this function into a no-op.
    nearest = ndimage.distance_transform_edt(
        ~filled, return_distances=False, return_indices=True
    )
    return result[nearest[0], nearest[1]]
194
+
195
+
196
def rgb_image_to_index(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Map each RGB pixel in `image_array` to the index of the matching colour in `palette`.
    Assumes pixels take values from `palette`.

    Args:
        image_array: (H, W, 3) image.
        palette: (K, 3) array of colours; if a colour occurs more than once the
            last occurrence's index is used.

    Returns:
        (H, W) uint16 array of palette row indices.

    Raises:
        KeyError: If a pixel colour is not present in `palette`.
    """
    h, w, _ = image_array.shape
    lookup = {tuple(map(int, c)): i for i, c in enumerate(palette.tolist())}

    # Resolve each distinct colour once and broadcast the answer to all pixels
    # through the inverse mapping, instead of one dict lookup per pixel.
    flat = image_array.reshape(-1, 3)
    unique_colours, inverse = np.unique(flat, axis=0, return_inverse=True)
    index_table = np.array(
        [lookup[tuple(map(int, colour))] for colour in unique_colours],
        dtype=np.uint16,
    )
    # Some numpy releases return a non-flat inverse when axis is given.
    return index_table[inverse.reshape(-1)].reshape(h, w)
207
+
208
+
209
def process_file(
    input_path: str,
    output_path: str,
    palette: np.ndarray,
    kernel_size: int,
    output_type: str = "rgb",
):
    """
    Clean one image file with the palette method and write the result.

    Args:
        input_path: Path of the image to read.
        output_path: Path to write the cleaned image to; its parent directory
            is created if missing.
        palette: (K, 3) array of palette colours passed to clean_image_palette.
        kernel_size: Morphological clean size passed to clean_image_palette.
        output_type: 'rgb' to save a colour image; 'index' to save a
            single-channel ("L" mode) index mask.

    Files that cannot be opened are reported on stdout and skipped.

    Raises:
        ValueError: If output_type is neither 'rgb' nor 'index'.
    """
    try:
        # Close the source file as soon as the pixels are converted; leaving it
        # open leaks a handle per file and can block overwriting the same path
        # when input and output coincide.
        with Image.open(input_path) as source:
            img = source.convert("RGB")
    except Exception as e:
        # Best effort: one unreadable file must not abort a directory run.
        print(f"Skipping {input_path}: cannot open image ({e})")
        return

    arr = np.array(img, dtype=np.uint8)

    # Clean image using core function
    cleaned = clean_image_palette(
        arr, palette=palette, morph_kernel_size=kernel_size, output_type=output_type
    )

    # Ensure output directory exists ("." when output_path has no directory part)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if output_type == "rgb":
        Image.fromarray(cleaned).save(output_path)
    elif output_type == "index":
        Image.fromarray(cleaned.astype(np.uint8), mode="L").save(output_path)
    else:
        raise ValueError("output_type must be 'rgb' or 'index'")
238
+
239
+
240
def process_directory(
    input_dir: str,
    output_dir: str,
    palette: np.ndarray,
    exts: List[str],
    inplace: bool,
    name_filter: str = "",
    kernel_size: int = 0,
    output_type: str = "rgb",
):
    """
    Walk `input_dir` recursively and run process_file on every matching image.

    A file matches when its lower-cased name ends with one of `exts` (compared
    after lower-casing and stripping each extension) and, if `name_filter` is
    non-empty, its name contains that substring. Outputs keep the same file name
    and are written next to the input when `inplace` is true, otherwise under
    the mirrored sub-directory of `output_dir`. Every visited directory gets its
    output counterpart created, even if no file in it matches.
    """
    suffixes = tuple(ext.lower().strip() for ext in exts)

    for current_dir, _subdirs, filenames in os.walk(input_dir):
        # Mirror the position of current_dir below output_dir unless writing in place
        if inplace:
            target_dir = current_dir
        else:
            relative = os.path.relpath(current_dir, input_dir)
            target_dir = os.path.join(output_dir, relative)
        os.makedirs(target_dir, exist_ok=True)

        for filename in filenames:
            wanted_extension = filename.lower().endswith(suffixes)
            if not wanted_extension:
                continue
            if name_filter and name_filter not in filename:
                continue

            source_path = os.path.join(current_dir, filename)
            destination_path = os.path.join(target_dir, filename)
            process_file(
                source_path, destination_path, palette, kernel_size, output_type
            )
265
+
266
+
267
def clean_segmentation(
    input_dir: str,
    output_dir: str = None,
    inplace: bool = False,
    palette: np.ndarray = None,
    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
    name_filter: str = "",
    morph_kernel_size: int = 3,
    output_type: str = "rgb",
):
    """
    Clean segmentation images using palette-based color mapping.

    Args:
        input_dir (str): Path to input directory containing segmentation images.
        output_dir (str, optional): Directory where cleaned images will be written. Required if not inplace.
        inplace (bool): Overwrite input images in place.
        palette (np.ndarray): Array of RGB triples (K, 3) uint8.
        exts (str): Comma-separated list of allowed image extensions. Surrounding
            whitespace and a missing leading dot are tolerated; empty items are ignored.
        name_filter (str): Only process files whose name contains this substring.
        morph_kernel_size (int): Size of morphological kernel for boundary cleaning.
        output_type (str): 'rgb' to write colour images; 'index' to write index masks.

    Raises:
        ValueError: If neither output_dir nor inplace is given, or palette is None.
    """
    if not inplace and output_dir is None:
        raise ValueError("Either output_dir must be provided or inplace must be True")

    if palette is None:
        raise ValueError("palette must be provided")

    # Strip before testing for the leading dot: "png, jpg" must yield ".jpg",
    # not ". jpg". Empty items (e.g. from a trailing comma) are dropped because
    # they would otherwise become "." and match unrelated file names.
    exts_list = []
    for raw_ext in exts.split(","):
        ext = raw_ext.strip()
        if not ext:
            continue
        exts_list.append(ext if ext.startswith(".") else "." + ext)

    out_dir = output_dir if not inplace else input_dir

    if not inplace:
        os.makedirs(out_dir, exist_ok=True)

    print(
        f"Processing: input={input_dir} -> output={out_dir}, colours={len(palette)}, morph_kernel={morph_kernel_size}, output_type={output_type}"
    )
    process_directory(
        input_dir,
        out_dir,
        palette,
        exts_list,
        inplace,
        name_filter,
        morph_kernel_size,
        output_type,
    )
    print("Done.")
@@ -0,0 +1,180 @@
1
+ import argparse
2
+
3
+ import numpy as np
4
+
5
+ from . import clean, nn, train, utils
6
+
7
+
8
def main_clean():
    """
    Command-line entry point for cleaning a directory of segmentation images.

    Parses the arguments, loads the colour list from either --colour_map or
    --colour_map_file, and then dispatches on --method: "palette" runs
    clean.clean_segmentation with the colours as a uint8 palette, "nn" runs
    nn.run_inference with the colours enumerated into an index -> colour map
    (and requires --model_path).
    """
    parser = argparse.ArgumentParser(
        description="Clean segmentation images using various methods."
    )
    add = parser.add_argument

    # General options
    add(
        "--method",
        type=str,
        required=True,
        choices=["palette", "nn"],
        help="Cleaning method to use: 'palette' for color palette mapping, 'nn' for neural network (pixel_decoder or CNN).",
    )
    add(
        "--input_dir",
        type=str,
        required=True,
        help="Path to input directory containing images.",
    )
    add(
        "--output_dir",
        type=str,
        required=False,
        help="Directory where cleaned images will be written. Required if not inplace.",
    )
    add("--inplace", action="store_true", help="Overwrite input images in place.")
    add(
        "--exts",
        type=str,
        default=".png,.jpg,.jpeg,.tiff,.bmp,.gif",
        help="Comma-separated list of allowed image extensions.",
    )
    add(
        "--name_filter",
        type=str,
        default="",
        help="Only process files whose name contains this substring.",
    )
    add(
        "--output_type",
        type=str,
        choices=["rgb", "index"],
        default="rgb",
        help="Output format: 'rgb' colour image or 'index' mask.",
    )

    # Exactly one source for the colour list must be given
    colour_source = parser.add_mutually_exclusive_group(required=True)
    colour_source.add_argument(
        "--colour_map", type=str, help="Semicolon-separated list of RGB triples."
    )
    colour_source.add_argument(
        "--colour_map_file", type=str, help="Path to a file listing RGB triples."
    )

    # Option used by the palette method
    add(
        "--morph_kernel_size",
        type=int,
        default=3,
        help="Size of morphological kernel for palette method.",
    )

    # Option used by the nn method
    add("--model_path", type=str, help="Path to trained model for nn method.")

    args = parser.parse_args()

    colours = (
        utils.parse_colours_from_file(args.colour_map_file)
        if args.colour_map_file
        else utils.parse_colours_from_string(args.colour_map)
    )

    if args.method == "nn":
        if not args.model_path:
            parser.error("--model_path required for nn method")

        nn.run_inference(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            model_path=args.model_path,
            colour_map=dict(enumerate(colours)),
            exts=args.exts,
            name_filter=args.name_filter,
            output_type=args.output_type,
        )

    elif args.method == "palette":
        clean.clean_segmentation(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            palette=np.asarray(colours, dtype=np.uint8),
            exts=args.exts,
            name_filter=args.name_filter,
            morph_kernel_size=args.morph_kernel_size,
            output_type=args.output_type,
        )
121
+
122
+
123
def main_train():
    """
    Command-line entry point for training a segmentation-cleaning model.

    Parses the arguments, loads the colour list from either --colour_map or
    --colour_map_file, enumerates it into an index -> colour map and hands
    everything to train.train_model.
    """
    parser = argparse.ArgumentParser(
        description="Train a neural network model for segmentation cleaning."
    )
    add = parser.add_argument

    add(
        "--image_dir",
        type=str,
        required=True,
        help="Path to directory containing noisy images.",
    )
    add(
        "--label_dir",
        type=str,
        required=True,
        help="Path to directory containing target RGB labels.",
    )
    add(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where model weights will be saved.",
    )
    add(
        "--model_type",
        type=str,
        choices=["pixel_decoder", "cnn_decoder"],
        default="pixel_decoder",
        help="The type of model to train: 'pixel_decoder' for MLP or 'cnn_decoder' for CNN-based decoder.",
    )

    # Exactly one source for the colour list must be given
    colour_source = parser.add_mutually_exclusive_group(required=True)
    colour_source.add_argument(
        "--colour_map", type=str, help="Semicolon-separated list of RGB triples."
    )
    colour_source.add_argument(
        "--colour_map_file", type=str, help="Path to a file listing RGB triples."
    )

    args = parser.parse_args()

    colours = (
        utils.parse_colours_from_file(args.colour_map_file)
        if args.colour_map_file
        else utils.parse_colours_from_string(args.colour_map)
    )

    train.train_model(
        image_dir=args.image_dir,
        label_dir=args.label_dir,
        output_dir=args.output_dir,
        colour_map=dict(enumerate(colours)),
        model_type=args.model_type,
    )
@@ -0,0 +1,5 @@
1
+ from .base_classifier import PixelClassifier
2
+ from .pixelwise_classifier import PixelwiseClassifier
3
+ from .cnn_decoder import CNNDecoder
4
+
5
+ __all__ = ["PixelClassifier", "PixelwiseClassifier", "CNNDecoder"]