rgb-to-segmentation 0.1.6 (rgb_to_segmentation-0.1.6-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rgb_to_segmentation/__init__.py +10 -0
- rgb_to_segmentation/api.py +123 -0
- rgb_to_segmentation/clean.py +315 -0
- rgb_to_segmentation/cli.py +180 -0
- rgb_to_segmentation/models/__init__.py +5 -0
- rgb_to_segmentation/models/base_classifier.py +77 -0
- rgb_to_segmentation/models/cnn_decoder.py +103 -0
- rgb_to_segmentation/models/pixelwise_classifier.py +31 -0
- rgb_to_segmentation/nn.py +164 -0
- rgb_to_segmentation/train.py +212 -0
- rgb_to_segmentation/utils.py +40 -0
- rgb_to_segmentation-0.1.6.dist-info/METADATA +213 -0
- rgb_to_segmentation-0.1.6.dist-info/RECORD +15 -0
- rgb_to_segmentation-0.1.6.dist-info/WHEEL +4 -0
- rgb_to_segmentation-0.1.6.dist-info/entry_points.txt +3 -0

rgb_to_segmentation/api.py
@@ -0,0 +1,123 @@
import numpy as np
import torch

from typing import Dict, Tuple, Optional, Union

from .clean import clean_image_palette, clean_image_strict_palette
from .nn import clean_image_nn


ImageArray = Union[np.ndarray, torch.Tensor]


def _to_numpy(image_array: ImageArray):
    if isinstance(image_array, np.ndarray):
        return image_array, False, None, None
    if isinstance(image_array, torch.Tensor):
        return (
            image_array.detach().cpu().numpy(),
            True,
            image_array.dtype,
            image_array.device,
        )
    raise TypeError("image_array must be a numpy.ndarray or torch.Tensor")


def _to_original_type(
    np_array: np.ndarray,
    is_torch: bool,
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> ImageArray:
    if not is_torch:
        return np_array

    tensor = torch.from_numpy(np_array)

    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype=dtype)
    if device is not None and tensor.device != device:
        tensor = tensor.to(device=device)

    return tensor


def clean_image(
    image_array: ImageArray,
    method: str,
    colour_map: Dict[int, Tuple[int, int, int]],
    *,
    model: Optional[object] = None,
    morph_kernel_size: int = 0,
    output_type: str = "rgb",
) -> ImageArray:
    """
    Clean a single image (numpy array or torch tensor) using the specified method.

    Args:
        image_array: Array/Tensor of shape (H, W, 3), dtype uint8.
        method: "palette", "strict_palette", "pixel_decoder", or "cnn_decoder" to choose the cleaning approach.
        colour_map: Required for all methods. Dict mapping class index -> (r, g, b).
        model: Required when method="pixel_decoder" or "cnn_decoder". A trained model with forward(batch) returning class probabilities.
        morph_kernel_size: Optional morphological clean kernel size (palette method only).
        output_type: "rgb" to return a colour image, "index" to return an integer mask.

    Returns:
        Cleaned image with the same container type as the input (np.ndarray or torch.Tensor):
        (H, W, 3) uint8 when output_type="rgb", otherwise (H, W) uint8.
    """
    np_image, is_torch, orig_dtype, orig_device = _to_numpy(image_array)

    if np_image.ndim != 3 or np_image.shape[2] != 3:
        raise ValueError("image_array must have shape (H, W, 3)")

    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    if method == "palette":
        # Build palette ndarray from colour_map in index order and delegate to core function
        keys = sorted(colour_map.keys())
        palette = np.asarray([colour_map[k] for k in keys], dtype=np.uint8)

        cleaned = clean_image_palette(
            np_image,
            palette=palette,
            morph_kernel_size=morph_kernel_size,
            output_type=output_type,
        )

    elif method == "strict_palette":
        cleaned = clean_image_strict_palette(
            np_image,
            colour_map=colour_map,
            output_type=output_type,
        )

    elif method == "pixel_decoder":
        if model is None:
            raise ValueError("model must be provided for method='pixel_decoder'")

        cleaned = clean_image_nn(
            np_image,
            model=model,
            colour_map=colour_map,
            output_type=output_type,
        )

    elif method == "cnn_decoder":
        if model is None:
            raise ValueError("model must be provided for method='cnn_decoder'")

        cleaned = clean_image_nn(
            np_image,
            model=model,
            colour_map=colour_map,
            output_type=output_type,
        )

    else:
        raise ValueError(
            "method must be 'palette', 'strict_palette', 'pixel_decoder', or 'cnn_decoder'"
        )

    return _to_original_type(cleaned, is_torch, orig_dtype, orig_device)
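
For orientation, a minimal usage sketch of clean_image with the palette method (not part of the wheel contents); the three-class colour map and the random test image are made up for illustration:

    # Illustrative sketch only, not part of the package source above.
    import numpy as np
    from rgb_to_segmentation.api import clean_image

    # Hypothetical three-class colour map: background, road, building.
    colour_map = {0: (0, 0, 0), 1: (128, 64, 128), 2: (70, 70, 70)}

    # A noisy RGB mask, e.g. the output of a generative model.
    noisy = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

    rgb_clean = clean_image(noisy, method="palette", colour_map=colour_map)
    index_mask = clean_image(noisy, method="palette", colour_map=colour_map,
                             output_type="index")

    print(rgb_clean.shape, rgb_clean.dtype)    # (64, 64, 3) uint8
    print(index_mask.shape, index_mask.dtype)  # (64, 64) uint8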

rgb_to_segmentation/clean.py
@@ -0,0 +1,315 @@
import os
from typing import List, Optional

import numpy as np

from PIL import Image
from scipy import ndimage
import torch


def clean_image_palette(
    image_array: np.ndarray,
    palette: np.ndarray,
    morph_kernel_size: int = 0,
    output_type: str = "rgb",
) -> np.ndarray:
    """
    Clean a single RGB image using palette-based nearest-colour mapping.

    Args:
        image_array: (H, W, 3) uint8
        palette: (K, 3) uint8
        morph_kernel_size: kernel size for morphological cleaning (0 disables)
        output_type: 'rgb' to return colour image; 'index' to return integer mask

    Returns:
        np.ndarray: (H, W, 3) uint8 if output_type='rgb'; else (H, W) uint8
    """
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    # Reduce palette to colours present in the image for efficiency
    reduced_palette = get_palette_for_image(image_array, palette)

    # Map to nearest colours
    cleaned_rgb = nearest_palette_image(image_array, reduced_palette)

    # Optional morphological clean
    if morph_kernel_size > 0:
        cleaned_rgb = apply_morphological_clean(cleaned_rgb, morph_kernel_size)

    if output_type == "rgb":
        return cleaned_rgb
    else:
        return rgb_image_to_index(cleaned_rgb, reduced_palette).astype(np.uint8)
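
A small worked sketch of the snapping behaviour (not part of the package), using a made-up two-colour palette chosen so the nearest colours are unambiguous:

    # Illustrative sketch only, not part of the package source above.
    import numpy as np
    from rgb_to_segmentation.clean import clean_image_palette

    palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)

    # A 1x2 image: one dark grey pixel and one light grey pixel.
    img = np.array([[[10, 12, 9], [240, 250, 245]]], dtype=np.uint8)

    rgb = clean_image_palette(img, palette)                      # dark grey -> black, light grey -> white
    idx = clean_image_palette(img, palette, output_type="index") # array([[0, 1]], dtype=uint8)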

def clean_image_strict_palette(
    image_array: np.ndarray,
    colour_map: dict,
    output_type: str = "rgb",
) -> np.ndarray:
    """
    Strictly map RGB values to indices based on colour_map.
    Raises an error if any RGB value is not found in the colour_map.

    Args:
        image_array: (H, W, 3) uint8 numpy array
        colour_map: Dict mapping class index -> (r, g, b)
        output_type: 'rgb' to return colour image; 'index' to return integer mask

    Returns:
        np.ndarray: (H, W, 3) uint8 if output_type='rgb'; else (H, W) uint8
    """
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    h, w, _ = image_array.shape
    flat_img = image_array.reshape(-1, 3)

    # Build reverse lookup: RGB tuple -> class index
    rgb_to_idx = {tuple(map(int, rgb)): idx for idx, rgb in colour_map.items()}

    # Find all unique RGB values in the image
    unique_colours = np.unique(flat_img, axis=0)

    # Check if all colours are in the colour_map
    unmapped_colours = []
    for colour in unique_colours:
        colour_tuple = tuple(map(int, colour))
        if colour_tuple not in rgb_to_idx:
            unmapped_colours.append(colour_tuple)

    if unmapped_colours:
        # Format error message with unmapped colours
        colour_strs = [f"RGB{c}" for c in unmapped_colours[:10]]  # Show first 10
        if len(unmapped_colours) > 10:
            colour_strs.append(f"... and {len(unmapped_colours) - 10} more")
        raise ValueError(
            f"Image contains {len(unmapped_colours)} RGB value(s) not in colour_map: "
            f"{', '.join(colour_strs)}. All pixel values must exactly match a colour in the map."
        )

    # Map each pixel to its class index
    flat_indices = np.array(
        [rgb_to_idx[tuple(map(int, px))] for px in flat_img], dtype=np.uint16
    )
    index_image = flat_indices.reshape(h, w).astype(np.uint8)

    if output_type == "index":
        return index_image
    else:
        # Convert back to RGB using colour_map
        rgb_output = np.zeros((h, w, 3), dtype=np.uint8)
        for idx, rgb in colour_map.items():
            mask = index_image == idx
            rgb_output[mask] = rgb
        return rgb_output

def nearest_palette_image(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Assign each pixel in `image_array` (H, W, 3 uint8) to the nearest colour in `palette` (K, 3 uint8).
    Returns the recoloured image array with the same shape and dtype uint8.
    """
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        raise ValueError("image_array must have shape (H, W, 3)")

    h, w, _ = image_array.shape
    flat = image_array.reshape(-1, 3).astype(np.int64)
    pal = palette.astype(np.int64)

    # Compute squared distances between each pixel and each palette colour.
    # distances shape: (N_pixels, K)
    d = np.sum((flat[:, None, :] - pal[None, :, :]) ** 2, axis=2)

    idx = np.argmin(d, axis=1)
    new_flat = pal[idx]
    new = new_flat.reshape(h, w, 3).astype(np.uint8)
    return new
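
The broadcasted distance matrix above has (H*W, K) int64 entries, which can be large for big images and palettes. Below is a hedged sketch of a chunked variant that produces the same assignment in bounded memory; the function name and chunk size are illustrative, not part of the package:

    # Illustrative sketch only, not part of the package source above.
    import numpy as np

    def nearest_palette_image_chunked(image_array: np.ndarray,
                                      palette: np.ndarray,
                                      chunk: int = 262_144) -> np.ndarray:
        """Same result as nearest_palette_image, but processes pixels in chunks."""
        h, w, _ = image_array.shape
        flat = image_array.reshape(-1, 3).astype(np.int64)
        pal = palette.astype(np.int64)
        out = np.empty(flat.shape[0], dtype=np.int64)
        for start in range(0, flat.shape[0], chunk):
            block = flat[start:start + chunk]
            # Distance matrix only for this block of pixels
            d = np.sum((block[:, None, :] - pal[None, :, :]) ** 2, axis=2)
            out[start:start + chunk] = np.argmin(d, axis=1)
        return palette[out].reshape(h, w, 3).astype(np.uint8)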

def get_palette_for_image(
    image_array: np.ndarray, full_palette: np.ndarray
) -> np.ndarray:
    """
    Identify which colours from the full palette are present in the image,
    and return only those colours.
    """
    h, w, _ = image_array.shape
    # Use a wide integer dtype so the squared differences below cannot overflow
    flat_img = image_array.reshape(-1, 3).astype(np.int64)
    pal = full_palette.astype(np.int64)

    # For each pixel, find the nearest palette colour
    d = np.sum((flat_img[:, None, :] - pal[None, :, :]) ** 2, axis=2)
    idx = np.argmin(d, axis=1)

    # Get unique indices that are actually used
    unique_idx = np.unique(idx)

    # Return only the palette colours that are used
    return full_palette[unique_idx]


def apply_morphological_clean(image_array: np.ndarray, kernel_size: int) -> np.ndarray:
    """
    Apply morphological opening (erosion followed by dilation) per class to clean up
    class boundaries and remove noise.
    """
    if kernel_size <= 0:
        return image_array

    # Create morphological kernel
    kernel = ndimage.generate_binary_structure(2, 2)

    # Get unique colours that actually appear in the image
    h, w, _ = image_array.shape
    flat_img = image_array.reshape(-1, 3)
    unique_colours = np.unique(flat_img, axis=0)

    # Process each class separately to avoid blending
    result = np.zeros_like(image_array)

    for colour in unique_colours:
        # Create binary mask for this class
        mask = np.all(image_array == colour, axis=-1)

        # Apply opening: erosion then dilation
        for _ in range(kernel_size):
            mask = ndimage.binary_erosion(mask, structure=kernel)
        for _ in range(kernel_size):
            mask = ndimage.binary_dilation(mask, structure=kernel)

        # Assign pixels back
        result[mask] = colour

    # Fill any remaining pixels (left empty by the erosions) from the input image
    unfilled = ~np.any(result != 0, axis=-1)
    if np.any(unfilled):
        # Unfilled pixels keep the colour they already had in the input
        result[unfilled] = image_array[unfilled]

    return result
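
A minimal sketch of running the per-class opening on a synthetic two-colour mask (illustrative, not part of the package); which pixels actually change depends on the mask geometry and kernel_size, so no particular output is asserted here:

    # Illustrative sketch only, not part of the package source above.
    import numpy as np
    from rgb_to_segmentation.clean import apply_morphological_clean

    mask = np.zeros((32, 32, 3), dtype=np.uint8)
    mask[...] = (255, 0, 0)           # mostly one class
    mask[10:12, 10:12] = (0, 255, 0)  # small speck of a second class

    opened = apply_morphological_clean(mask, kernel_size=1)
    print(np.unique(opened.reshape(-1, 3), axis=0))  # colours present after opening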

def rgb_image_to_index(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Map each RGB pixel in `image_array` to the index of the matching colour in `palette`.
    Assumes pixels take values from `palette`.
    """
    h, w, _ = image_array.shape
    palette_list = [tuple(map(int, c)) for c in palette.tolist()]
    lookup = {c: i for i, c in enumerate(palette_list)}
    flat = image_array.reshape(-1, 3)
    idx = np.array([lookup[tuple(map(int, px))] for px in flat], dtype=np.uint16)
    return idx.reshape(h, w)
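
The dictionary lookup above loops over every pixel in Python. Below is a hedged sketch of an equivalent vectorised lookup that packs each RGB triple into a single 24-bit code; the helper name is illustrative and not part of the package:

    # Illustrative sketch only, not part of the package source above.
    import numpy as np

    def rgb_image_to_index_vectorised(image_array: np.ndarray,
                                      palette: np.ndarray) -> np.ndarray:
        """Same contract as rgb_image_to_index: every pixel must be a palette colour."""
        h, w, _ = image_array.shape

        def pack(a: np.ndarray) -> np.ndarray:
            # Pack (r, g, b) into a single integer: r*65536 + g*256 + b
            a = a.astype(np.int64)
            return (a[..., 0] << 16) | (a[..., 1] << 8) | a[..., 2]

        pal_codes = pack(palette)
        order = np.argsort(pal_codes)
        # Locate each pixel code in the sorted palette codes, then map back
        pos = np.searchsorted(pal_codes[order], pack(image_array).reshape(-1))
        return order[pos].astype(np.uint16).reshape(h, w)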

def process_file(
    input_path: str,
    output_path: str,
    palette: np.ndarray,
    kernel_size: int,
    output_type: str = "rgb",
):
    try:
        img = Image.open(input_path).convert("RGB")
    except Exception as e:
        print(f"Skipping {input_path}: cannot open image ({e})")
        return

    arr = np.array(img, dtype=np.uint8)

    # Clean image using core function
    cleaned = clean_image_palette(
        arr, palette=palette, morph_kernel_size=kernel_size, output_type=output_type
    )

    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if output_type == "rgb":
        Image.fromarray(cleaned).save(output_path)
    elif output_type == "index":
        Image.fromarray(cleaned.astype(np.uint8), mode="L").save(output_path)
    else:
        raise ValueError("output_type must be 'rgb' or 'index'")


def process_directory(
    input_dir: str,
    output_dir: str,
    palette: np.ndarray,
    exts: List[str],
    inplace: bool,
    name_filter: str = "",
    kernel_size: int = 0,
    output_type: str = "rgb",
):
    exts = [e.lower().strip() for e in exts]
    for root, dirs, files in os.walk(input_dir):
        # Determine the corresponding output root
        rel = os.path.relpath(root, input_dir)
        out_root = os.path.join(output_dir, rel) if not inplace else root
        os.makedirs(out_root, exist_ok=True)

        for fname in files:
            if not any(fname.lower().endswith(e) for e in exts):
                continue
            if name_filter and name_filter not in fname:
                continue
            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_root, fname)
            process_file(in_path, out_path, palette, kernel_size, output_type)


def clean_segmentation(
    input_dir: str,
    output_dir: Optional[str] = None,
    inplace: bool = False,
    palette: Optional[np.ndarray] = None,
    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
    name_filter: str = "",
    morph_kernel_size: int = 3,
    output_type: str = "rgb",
):
    """
    Clean segmentation images using palette-based colour mapping.

    Args:
        input_dir (str): Path to input directory containing segmentation images.
        output_dir (str, optional): Directory where cleaned images will be written. Required if not inplace.
        inplace (bool): Overwrite input images in place.
        palette (np.ndarray): Array of RGB triples (K, 3) uint8.
        exts (str): Comma-separated list of allowed image extensions.
        name_filter (str): Only process files whose name contains this substring.
        morph_kernel_size (int): Size of morphological kernel for boundary cleaning.
        output_type (str): 'rgb' to write colour images; 'index' to write integer masks.
    """
    if not inplace and output_dir is None:
        raise ValueError("Either output_dir must be provided or inplace must be True")

    if palette is None:
        raise ValueError("palette must be provided")

    exts_list = [e if e.startswith(".") else "." + e for e in exts.split(",")]

    out_dir = output_dir if not inplace else input_dir

    if not inplace:
        os.makedirs(out_dir, exist_ok=True)

    print(
        f"Processing: input={input_dir} -> output={out_dir}, colours={len(palette)}, morph_kernel={morph_kernel_size}, output_type={output_type}"
    )
    process_directory(
        input_dir,
        out_dir,
        palette,
        exts_list,
        inplace,
        name_filter,
        morph_kernel_size,
        output_type,
    )
    print("Done.")

rgb_to_segmentation/cli.py
@@ -0,0 +1,180 @@
import argparse

import numpy as np

from . import clean, nn, train, utils


def main_clean():
    parser = argparse.ArgumentParser(
        description="Clean segmentation images using various methods."
    )

    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=["palette", "nn"],
        help="Cleaning method to use: 'palette' for color palette mapping, 'nn' for neural network (pixel_decoder or CNN).",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Path to input directory containing images.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=False,
        help="Directory where cleaned images will be written. Required if not inplace.",
    )
    parser.add_argument(
        "--inplace",
        action="store_true",
        help="Overwrite input images in place.",
    )
    parser.add_argument(
        "--exts",
        type=str,
        default=".png,.jpg,.jpeg,.tiff,.bmp,.gif",
        help="Comma-separated list of allowed image extensions.",
    )
    parser.add_argument(
        "--name_filter",
        type=str,
        default="",
        help="Only process files whose name contains this substring.",
    )

    parser.add_argument(
        "--output_type",
        type=str,
        choices=["rgb", "index"],
        default="rgb",
        help="Output format: 'rgb' colour image or 'index' mask.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--colour_map",
        type=str,
        help="Semicolon-separated list of RGB triples.",
    )
    group.add_argument(
        "--colour_map_file",
        type=str,
        help="Path to a file listing RGB triples.",
    )

    # Palette-specific args
    parser.add_argument(
        "--morph_kernel_size",
        type=int,
        default=3,
        help="Size of morphological kernel for palette method.",
    )

    # NN-specific args
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to trained model for nn method.",
    )

    args = parser.parse_args()

    if args.colour_map_file:
        colours = utils.parse_colours_from_file(args.colour_map_file)
    else:
        colours = utils.parse_colours_from_string(args.colour_map)

    if args.method == "palette":
        palette = np.asarray(colours, dtype=np.uint8)

        clean.clean_segmentation(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            palette=palette,
            exts=args.exts,
            name_filter=args.name_filter,
            morph_kernel_size=args.morph_kernel_size,
            output_type=args.output_type,
        )

    elif args.method == "nn":
        if not args.model_path:
            parser.error("--model_path required for nn method")

        colour_map = {i: rgb for i, rgb in enumerate(colours)}
        nn.run_inference(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            model_path=args.model_path,
            colour_map=colour_map,
            exts=args.exts,
            name_filter=args.name_filter,
            output_type=args.output_type,
        )


def main_train():
    parser = argparse.ArgumentParser(
        description="Train a neural network model for segmentation cleaning."
    )

    parser.add_argument(
        "--image_dir",
        type=str,
        required=True,
        help="Path to directory containing noisy images.",
    )
    parser.add_argument(
        "--label_dir",
        type=str,
        required=True,
        help="Path to directory containing target RGB labels.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where model weights will be saved.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["pixel_decoder", "cnn_decoder"],
        default="pixel_decoder",
        help="The type of model to train: 'pixel_decoder' for MLP or 'cnn_decoder' for CNN-based decoder.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--colour_map",
        type=str,
        help="Semicolon-separated list of RGB triples.",
    )
    group.add_argument(
        "--colour_map_file",
        type=str,
        help="Path to a file listing RGB triples.",
    )

    args = parser.parse_args()

    if args.colour_map_file:
        colours = utils.parse_colours_from_file(args.colour_map_file)
    else:
        colours = utils.parse_colours_from_string(args.colour_map)
    colour_map = {i: rgb for i, rgb in enumerate(colours)}

    train.train_model(
        image_dir=args.image_dir,
        label_dir=args.label_dir,
        output_dir=args.output_dir,
        colour_map=colour_map,
        model_type=args.model_type,
    )
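
The console-script names live in entry_points.txt, which is not expanded in this diff, so the sketch below drives main_clean() directly with a patched argv rather than guessing the installed command name. The paths and the comma-separated triple format assumed for --colour_map are illustrative; the exact string format is defined by utils.parse_colours_from_string, which is not shown here.

    # Illustrative sketch only, not part of the package source above.
    import sys
    from rgb_to_segmentation import cli

    sys.argv = [
        "clean",                        # placeholder program name
        "--method", "palette",
        "--input_dir", "masks/raw",     # hypothetical paths
        "--output_dir", "masks/clean",
        "--colour_map", "0,0,0;128,64,128;70,70,70",  # assumed triple format
        "--morph_kernel_size", "3",
        "--output_type", "index",
    ]
    cli.main_clean()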