rgb-to-segmentation 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rgb-to-segmentation might be problematic. Click here for more details.
- rgb_to_segmentation-0.0.2/.gitignore +4 -0
- rgb_to_segmentation-0.0.2/PKG-INFO +142 -0
- rgb_to_segmentation-0.0.2/README.md +116 -0
- rgb_to_segmentation-0.0.2/pyproject.toml +48 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/__init__.py +0 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/clean.py +219 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/cli.py +179 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/models/__init__.py +0 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/models/pixelwise_classifier.py +57 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/nn.py +110 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/train.py +183 -0
- rgb_to_segmentation-0.0.2/rgb_to_segmentation/utils.py +40 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rgb-to-segmentation
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Tools for processing and cleaning segmentation images using palette mapping and neural networks
|
|
5
|
+
Author: Alex Senden
|
|
6
|
+
Maintainer: Alex Senden
|
|
7
|
+
License: MIT
|
|
8
|
+
Keywords: image-processing,machine-learning,pytorch,segmentation
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Image Processing
|
|
18
|
+
Requires-Python: >=3.8
|
|
19
|
+
Requires-Dist: numpy>=1.21.0
|
|
20
|
+
Requires-Dist: pillow>=8.0.0
|
|
21
|
+
Requires-Dist: pytorch-lightning>=1.5.0
|
|
22
|
+
Requires-Dist: scipy>=1.7.0
|
|
23
|
+
Requires-Dist: torch>=1.9.0
|
|
24
|
+
Requires-Dist: torchvision>=0.10.0
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# RGB to Segmentation
|
|
28
|
+
|
|
29
|
+
A Python package for processing and cleaning segmentation images. This package provides tools to convert RGB images to segmentation masks using palette-based color mapping and neural network-based refinement.
|
|
30
|
+
|
|
31
|
+
## Features
|
|
32
|
+
|
|
33
|
+
- **Palette-based Cleaning**: Clean noisy segmentation images by mapping pixels to the nearest colors in a predefined palette, with optional morphological operations to refine boundaries.
|
|
34
|
+
- **Neural Network Refinement**: Use a trained pixelwise classifier to refine segmentation masks using PyTorch Lightning.
|
|
35
|
+
- **Command-Line Interface**: Unified CLI for cleaning with method selection, plus separate training command.
|
|
36
|
+
- **Programmatic API**: Direct access to cleaning and training functions for integration into other workflows.
|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
Install from PyPI:
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
pip install rgb-to-segmentation
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
Or install from source:
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
git clone https://github.com/alexsenden/rgb-to-segmentation.git
|
|
50
|
+
cd rgb-to-segmentation
|
|
51
|
+
pip install .
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Usage
|
|
55
|
+
|
|
56
|
+
### Cleaning Noisy Segmentation Images
|
|
57
|
+
|
|
58
|
+
Use the `segment-clean` command to clean segmentation images using various methods:
|
|
59
|
+
|
|
60
|
+
#### Palette-based cleaning:
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
segment-clean --method palette --input_dir /path/to/input --output_dir /path/to/output --colour_map "0,0,0;255,0,0;0,255,0" --output_type rgb
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
#### Neural network-based cleaning:
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
segment-clean --method nn --input_dir /path/to/input --output_dir /path/to/output --model_path /path/to/model.ckpt --colour_map "0,0,0;255,0,0;0,255,0" --output_type index
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
You can also provide colours via file with `--colour_map_file /path/to/colours.txt` (one `r,g,b` per line). The CLI parses colours and constructs the palette/colour map internally, mirroring the Python API which accepts parsed structures (NumPy array for palette, dictionary for colour map).
|
|
73
|
+
|
|
74
|
+
Options:
|
|
75
|
+
|
|
76
|
+
- `--method`: Cleaning method ('palette' or 'nn')
|
|
77
|
+
- `--input_dir`: Path to input directory containing images
|
|
78
|
+
- `--output_dir`: Directory where cleaned images will be written
|
|
79
|
+
- `--inplace`: Overwrite input images in place
|
|
80
|
+
- `--exts`: Comma-separated list of allowed image extensions
|
|
81
|
+
- `--name_filter`: Only process files whose name contains this substring
|
|
82
|
+
- `--output_type`: Output format ('rgb' or 'index')
|
|
83
|
+
|
|
84
|
+
For palette method:
|
|
85
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
86
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
87
|
+
- `--morph_kernel_size`: Size of morphological kernel for boundary cleaning
|
|
88
|
+
|
|
89
|
+
For nn method:
|
|
90
|
+
- `--model_path`: Path to trained model file
|
|
91
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
92
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
93
|
+
|
|
94
|
+
### Training the Neural Network Model
|
|
95
|
+
|
|
96
|
+
Train a pixelwise classifier to refine segmentation masks:
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
segment-train --image_dir /path/to/noisy_images --label_dir /path/to/labels --output_dir /path/to/model_output --colour_map "0,0,0;255,0,0;0,255,0"
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
Options:
|
|
103
|
+
- `--image_dir`: Path to directory containing noisy images
|
|
104
|
+
- `--label_dir`: Path to directory containing target RGB labels
|
|
105
|
+
- `--output_dir`: Directory where model weights will be saved
|
|
106
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
107
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
108
|
+
- `--model_type`: The type of model to train (default: pixelwise)
|
|
109
|
+
|
|
110
|
+
Note that one label image may have multiple corresponding noisy masks. A noisy mask is paired with a label when the mask's filename contains the label file's basename (the filename without its extension, e.g. `my_image.png` -> `my_image`).
|
|
111
|
+
|
|
112
|
+
## API
|
|
113
|
+
|
|
114
|
+
You can also use the package programmatically:
|
|
115
|
+
|
|
116
|
+
```python
|
|
117
|
+
import numpy as np
|
|
118
|
+
from rgb_to_segmentation import clean, nn, train, utils
|
|
119
|
+
|
|
120
|
+
# Palette cleaning
|
|
121
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
122
|
+
palette = np.asarray(colours, dtype=np.uint8)
|
|
123
|
+
clean.clean_segmentation(input_dir="/path/to/input", output_dir="/path/to/output", palette=palette, output_type="index")
|
|
124
|
+
|
|
125
|
+
# NN inference
|
|
126
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
127
|
+
colour_map = {i: rgb for i, rgb in enumerate(colours)}
|
|
128
|
+
nn.run_inference(input_dir="/path/to/input", output_dir="/path/to/output", model_path="/path/to/model.ckpt", colour_map=colour_map, output_type="rgb")
|
|
129
|
+
|
|
130
|
+
# Train model
|
|
131
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
132
|
+
colour_map = {i: rgb for i, rgb in enumerate(colours)}
|
|
133
|
+
train.train_model(image_dir="/path/to/images", label_dir="/path/to/labels", output_dir="/path/to/output", colour_map=colour_map)
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
## Contributing
|
|
137
|
+
|
|
138
|
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
139
|
+
|
|
140
|
+
## License
|
|
141
|
+
|
|
142
|
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# RGB to Segmentation
|
|
2
|
+
|
|
3
|
+
A Python package for processing and cleaning segmentation images. This package provides tools to convert RGB images to segmentation masks using palette-based color mapping and neural network-based refinement.
|
|
4
|
+
|
|
5
|
+
## Features
|
|
6
|
+
|
|
7
|
+
- **Palette-based Cleaning**: Clean noisy segmentation images by mapping pixels to the nearest colors in a predefined palette, with optional morphological operations to refine boundaries.
|
|
8
|
+
- **Neural Network Refinement**: Use a trained pixelwise classifier to refine segmentation masks using PyTorch Lightning.
|
|
9
|
+
- **Command-Line Interface**: Unified CLI for cleaning with method selection, plus separate training command.
|
|
10
|
+
- **Programmatic API**: Direct access to cleaning and training functions for integration into other workflows.
|
|
11
|
+
|
|
12
|
+
## Installation
|
|
13
|
+
|
|
14
|
+
Install from PyPI:
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
pip install rgb-to-segmentation
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
Or install from source:
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
git clone https://github.com/alexsenden/rgb-to-segmentation.git
|
|
24
|
+
cd rgb-to-segmentation
|
|
25
|
+
pip install .
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Usage
|
|
29
|
+
|
|
30
|
+
### Cleaning Noisy Segmentation Images
|
|
31
|
+
|
|
32
|
+
Use the `segment-clean` command to clean segmentation images using various methods:
|
|
33
|
+
|
|
34
|
+
#### Palette-based cleaning:
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
segment-clean --method palette --input_dir /path/to/input --output_dir /path/to/output --colour_map "0,0,0;255,0,0;0,255,0" --output_type rgb
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
#### Neural network-based cleaning:
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
segment-clean --method nn --input_dir /path/to/input --output_dir /path/to/output --model_path /path/to/model.ckpt --colour_map "0,0,0;255,0,0;0,255,0" --output_type index
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
You can also provide colours via file with `--colour_map_file /path/to/colours.txt` (one `r,g,b` per line). The CLI parses colours and constructs the palette/colour map internally, mirroring the Python API which accepts parsed structures (NumPy array for palette, dictionary for colour map).
|
|
47
|
+
|
|
48
|
+
Options:
|
|
49
|
+
|
|
50
|
+
- `--method`: Cleaning method ('palette' or 'nn')
|
|
51
|
+
- `--input_dir`: Path to input directory containing images
|
|
52
|
+
- `--output_dir`: Directory where cleaned images will be written
|
|
53
|
+
- `--inplace`: Overwrite input images in place
|
|
54
|
+
- `--exts`: Comma-separated list of allowed image extensions
|
|
55
|
+
- `--name_filter`: Only process files whose name contains this substring
|
|
56
|
+
- `--output_type`: Output format ('rgb' or 'index')
|
|
57
|
+
|
|
58
|
+
For palette method:
|
|
59
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
60
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
61
|
+
- `--morph_kernel_size`: Size of morphological kernel for boundary cleaning
|
|
62
|
+
|
|
63
|
+
For nn method:
|
|
64
|
+
- `--model_path`: Path to trained model file
|
|
65
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
66
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
67
|
+
|
|
68
|
+
### Training the Neural Network Model
|
|
69
|
+
|
|
70
|
+
Train a pixelwise classifier to refine segmentation masks:
|
|
71
|
+
|
|
72
|
+
```bash
|
|
73
|
+
segment-train --image_dir /path/to/noisy_images --label_dir /path/to/labels --output_dir /path/to/model_output --colour_map "0,0,0;255,0,0;0,255,0"
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Options:
|
|
77
|
+
- `--image_dir`: Path to directory containing noisy images
|
|
78
|
+
- `--label_dir`: Path to directory containing target RGB labels
|
|
79
|
+
- `--output_dir`: Directory where model weights will be saved
|
|
80
|
+
- `--colour_map`: Semicolon-separated list of RGB triples
|
|
81
|
+
- `--colour_map_file`: Path to a file listing RGB triples
|
|
82
|
+
- `--model_type`: The type of model to train (default: pixelwise)
|
|
83
|
+
|
|
84
|
+
Note that one label image may have multiple corresponding noisy masks. A noisy mask is paired with a label when the mask's filename contains the label file's basename (the filename without its extension, e.g. `my_image.png` -> `my_image`).
|
|
85
|
+
|
|
86
|
+
## API
|
|
87
|
+
|
|
88
|
+
You can also use the package programmatically:
|
|
89
|
+
|
|
90
|
+
```python
|
|
91
|
+
import numpy as np
|
|
92
|
+
from rgb_to_segmentation import clean, nn, train, utils
|
|
93
|
+
|
|
94
|
+
# Palette cleaning
|
|
95
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
96
|
+
palette = np.asarray(colours, dtype=np.uint8)
|
|
97
|
+
clean.clean_segmentation(input_dir="/path/to/input", output_dir="/path/to/output", palette=palette, output_type="index")
|
|
98
|
+
|
|
99
|
+
# NN inference
|
|
100
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
101
|
+
colour_map = {i: rgb for i, rgb in enumerate(colours)}
|
|
102
|
+
nn.run_inference(input_dir="/path/to/input", output_dir="/path/to/output", model_path="/path/to/model.ckpt", colour_map=colour_map, output_type="rgb")
|
|
103
|
+
|
|
104
|
+
# Train model
|
|
105
|
+
colours = utils.parse_colours_from_string("0,0,0;255,0,0;0,255,0")
|
|
106
|
+
colour_map = {i: rgb for i, rgb in enumerate(colours)}
|
|
107
|
+
train.train_model(image_dir="/path/to/images", label_dir="/path/to/labels", output_dir="/path/to/output", colour_map=colour_map)
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
## Contributing
|
|
111
|
+
|
|
112
|
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
113
|
+
|
|
114
|
+
## License
|
|
115
|
+
|
|
116
|
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "rgb-to-segmentation"
|
|
7
|
+
description = "Tools for processing and cleaning segmentation images using palette mapping and neural networks"
|
|
8
|
+
readme = "README.md"
|
|
9
|
+
license = { text = "MIT" }
|
|
10
|
+
requires-python = ">=3.8"
|
|
11
|
+
authors = [{ name = "Alex Senden" }]
|
|
12
|
+
maintainers = [{ name = "Alex Senden" }]
|
|
13
|
+
keywords = ["segmentation", "image-processing", "pytorch", "machine-learning"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 3 - Alpha",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"License :: OSI Approved :: MIT License",
|
|
18
|
+
"Programming Language :: Python :: 3",
|
|
19
|
+
"Programming Language :: Python :: 3.8",
|
|
20
|
+
"Programming Language :: Python :: 3.9",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Topic :: Scientific/Engineering :: Image Processing",
|
|
24
|
+
]
|
|
25
|
+
dependencies = [
|
|
26
|
+
"torch>=1.9.0",
|
|
27
|
+
"torchvision>=0.10.0",
|
|
28
|
+
"pytorch-lightning>=1.5.0",
|
|
29
|
+
"numpy>=1.21.0",
|
|
30
|
+
"scipy>=1.7.0",
|
|
31
|
+
"Pillow>=8.0.0",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
# Version is automatically provided by hatch-vcs
|
|
35
|
+
dynamic = ["version"]
|
|
36
|
+
|
|
37
|
+
[project.scripts]
|
|
38
|
+
segment-clean = "rgb_to_segmentation.cli:main_clean"
|
|
39
|
+
segment-train = "rgb_to_segmentation.cli:main_train"
|
|
40
|
+
|
|
41
|
+
[tool.hatch.version]
|
|
42
|
+
source = "vcs"
|
|
43
|
+
|
|
44
|
+
[tool.hatch.build.targets.sdist]
|
|
45
|
+
include = ["rgb_to_segmentation/**", "README.md"]
|
|
46
|
+
|
|
47
|
+
[tool.hatch.build.targets.wheel]
|
|
48
|
+
include = ["rgb_to_segmentation/**"]
|
|
File without changes
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from scipy import ndimage
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def nearest_palette_image(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Recolour an image so that every pixel takes its nearest palette colour.

    "Nearest" is the smallest squared Euclidean distance in RGB space; on a
    tie the palette entry with the lowest index wins (``argmin`` semantics).

    Args:
        image_array: Image of shape (H, W, 3).
        palette: Palette of shape (K, 3) with K >= 1.

    Returns:
        Array of shape (H, W, 3) and dtype uint8 containing only palette colours.

    Raises:
        ValueError: If `image_array` is not (H, W, 3) or `palette` is not a
            non-empty (K, 3) array.
    """
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        raise ValueError("image_array must have shape (H, W, 3)")

    pal = np.asarray(palette).astype(np.int64)
    if pal.ndim != 2 or pal.shape[1] != 3 or pal.shape[0] == 0:
        raise ValueError("palette must have shape (K, 3) with K >= 1")

    h, w, _ = image_array.shape
    flat = image_array.reshape(-1, 3)
    n_pixels = flat.shape[0]

    # The (pixels, K, 3) difference tensor is built chunk by chunk so memory
    # stays bounded for large images instead of growing with H * W * K.
    max_entries = 1 << 20
    chunk = max(1, max_entries // pal.shape[0])

    idx = np.empty(n_pixels, dtype=np.intp)
    for start in range(0, n_pixels, chunk):
        stop = start + chunk
        block = flat[start:stop].astype(np.int64)
        d = np.sum((block[:, None, :] - pal[None, :, :]) ** 2, axis=2)
        idx[start:stop] = np.argmin(d, axis=1)

    return pal[idx].reshape(h, w, 3).astype(np.uint8)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_palette_for_image(
    image_array: np.ndarray, full_palette: np.ndarray
) -> np.ndarray:
    """
    Return the subset of `full_palette` that is actually used by the image.

    Each pixel is assigned to its nearest palette colour (squared Euclidean
    distance in RGB space); the palette rows selected by at least one pixel
    are returned in ascending palette-index order.

    Args:
        image_array: Image of shape (H, W, 3).
        full_palette: Palette of shape (K, 3).

    Returns:
        Array of shape (K', 3) holding the used rows of `full_palette`.
    """
    full_palette = np.asarray(full_palette)

    # int32 is required here: a single squared channel difference can reach
    # 255 ** 2 = 65025, which silently wraps around in int16 and makes far
    # colours look "nearest".
    flat_img = image_array.reshape(-1, 3).astype(np.int32)
    pal = full_palette.astype(np.int32)

    # Squared distance from every pixel to every palette colour: (N_pixels, K).
    d = np.sum((flat_img[:, None, :] - pal[None, :, :]) ** 2, axis=2)
    idx = np.argmin(d, axis=1)

    # np.unique returns the used indices sorted, so palette order is preserved.
    unique_idx = np.unique(idx)
    return full_palette[unique_idx]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def apply_morphological_clean(image_array: np.ndarray, kernel_size: int) -> np.ndarray:
    """
    Remove small specks and ragged boundaries with a per-class morphological opening.

    Every distinct colour in the image is treated as one class. Its binary
    mask is eroded `kernel_size` times and then dilated `kernel_size` times
    with a 3x3 (8-connected) structuring element, i.e. a morphological
    opening. Pixels that no class keeps after the opening are filled with the
    colour of the nearest pixel that did survive.

    Args:
        image_array: Image of shape (H, W, 3) whose pixels are class colours.
        kernel_size: Number of erosion/dilation iterations. Values <= 0
            disable the cleaning.

    Returns:
        Cleaned image with the same shape and dtype as `image_array`. The
        input is returned unchanged when `kernel_size <= 0` or when no class
        survives the erosion.
    """
    if kernel_size <= 0:
        return image_array

    # 3x3 structuring element with full (8-neighbour) connectivity.
    structure = ndimage.generate_binary_structure(2, 2)
    unique_colours = np.unique(image_array.reshape(-1, 3), axis=0)

    result = np.zeros_like(image_array)
    # Explicit "assigned" mask: a zero-valued result pixel cannot be used as
    # the marker for "unfilled" because black (0, 0, 0) is a valid class colour.
    filled = np.zeros(image_array.shape[:2], dtype=bool)

    for colour in unique_colours:
        mask = np.all(image_array == colour, axis=-1)
        mask = ndimage.binary_erosion(mask, structure=structure, iterations=kernel_size)
        mask = ndimage.binary_dilation(mask, structure=structure, iterations=kernel_size)

        # An opened mask is always a subset of the original class mask, so the
        # masks of different classes never overlap.
        result[mask] = colour
        filled |= mask

    if not filled.any():
        # Every region was smaller than the kernel; there is nothing to
        # propagate from, so leave the image as it is.
        return image_array
    if filled.all():
        return result

    # For each unfilled pixel, look up the coordinates of the closest filled
    # pixel and copy its colour. Filled pixels map to themselves.
    nearest = ndimage.distance_transform_edt(
        ~filled, return_distances=False, return_indices=True
    )
    return result[nearest[0], nearest[1]]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def rgb_image_to_index(image_array: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Convert an RGB image into a mask of palette indices.

    Every pixel must be exactly equal to one of the palette colours. If the
    palette contains duplicate colours, the last matching index is used.

    Args:
        image_array: Image of shape (H, W, 3).
        palette: Palette of shape (K, 3).

    Returns:
        Array of shape (H, W) and dtype uint16 holding, for each pixel, the
        index of its colour in `palette`.

    Raises:
        KeyError: If a pixel colour does not occur in `palette`.
    """
    h, w, _ = image_array.shape
    lookup = {
        tuple(map(int, c)): i for i, c in enumerate(np.asarray(palette).tolist())
    }

    flat = image_array.reshape(-1, 3)
    if flat.shape[0] == 0:
        return np.zeros((h, w), dtype=np.uint16)

    # Resolve each distinct colour once instead of once per pixel; the Python
    # loop then runs over the handful of colours rather than H * W pixels.
    colours, inverse = np.unique(flat, axis=0, return_inverse=True)
    per_colour = np.array(
        [lookup[tuple(map(int, c))] for c in colours.tolist()], dtype=np.uint16
    )
    return per_colour[inverse.reshape(-1)].reshape(h, w)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def process_file(
    input_path: str,
    output_path: str,
    palette: np.ndarray,
    kernel_size: int,
    output_type: str = "rgb",
):
    """
    Clean one segmentation image and write the result to `output_path`.

    The image is snapped to the palette colours, optionally cleaned
    morphologically, and saved either as an RGB image or as an 8-bit mask of
    palette indices. Files that cannot be opened as images are reported on
    stdout and skipped.

    Args:
        input_path: Path of the image to read.
        output_path: Path the cleaned image is written to; parent directories
            are created when missing.
        palette: Palette of shape (K, 3), uint8.
        kernel_size: Iterations for the morphological clean; <= 0 disables it.
        output_type: "rgb" for a colour image, "index" for an index mask.

    Raises:
        ValueError: If `output_type` is not "rgb" or "index", or if an index
            mask is requested for a palette with more than 256 colours.
    """
    # Validate before doing any work so a bad argument fails fast.
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")
    if output_type == "index" and len(palette) > 256:
        # The mask is written as 8-bit; larger indices would silently wrap.
        raise ValueError("index output supports at most 256 palette colours")

    try:
        # The context manager closes the file handle, which matters when the
        # output path is the same file (in-place processing).
        with Image.open(input_path) as img:
            arr = np.array(img.convert("RGB"), dtype=np.uint8)
    except Exception as e:
        print(f"Skipping {input_path}: cannot open image ({e})")
        return

    # Reduce palette to only colours present in this image
    reduced_palette = get_palette_for_image(arr, palette)

    cleaned = nearest_palette_image(arr, reduced_palette)

    # Apply morphological transformations if kernel_size > 0
    if kernel_size > 0:
        cleaned = apply_morphological_clean(cleaned, kernel_size)

    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if output_type == "rgb":
        Image.fromarray(cleaned).save(output_path)
    else:
        # Index against the FULL palette so that a class keeps the same id in
        # every image, regardless of which colours a given image contains.
        index_mask = rgb_image_to_index(cleaned, palette)
        Image.fromarray(index_mask.astype(np.uint8)).save(output_path)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def process_directory(
    input_dir: str,
    output_dir: str,
    palette: np.ndarray,
    exts: List[str],
    inplace: bool,
    name_filter: str = "",
    kernel_size: int = 0,
    output_type: str = "rgb",
):
    """
    Recursively clean every matching image below `input_dir`.

    The directory layout of `input_dir` is mirrored under `output_dir`
    (or reused as-is when `inplace` is true) and each selected file is handed
    to `process_file`.

    Args:
        input_dir: Root directory that is walked recursively.
        output_dir: Root directory for the results; ignored when `inplace`.
        palette: Palette of shape (K, 3) passed through to `process_file`.
        exts: File extensions to accept; compared case-insensitively against
            the end of each file name.
        inplace: Write each result over its input file.
        name_filter: When non-empty, only files whose name contains this
            substring are processed.
        kernel_size: Morphological clean iterations passed to `process_file`.
        output_type: "rgb" or "index", passed to `process_file`.
    """
    suffixes = tuple(e.lower().strip() for e in exts)

    # When the output directory lives inside the input tree, walking into it
    # would feed freshly written results back in as inputs.
    skip_dir = None if inplace else os.path.abspath(output_dir)

    for root, dirs, files in os.walk(input_dir):
        if skip_dir is not None:
            # In-place pruning of `dirs` stops os.walk from descending there.
            dirs[:] = [
                d for d in dirs
                if os.path.abspath(os.path.join(root, d)) != skip_dir
            ]

        # Determine the corresponding output root
        rel = os.path.relpath(root, input_dir)
        out_root = root if inplace else os.path.join(output_dir, rel)
        os.makedirs(out_root, exist_ok=True)

        for fname in files:
            if not fname.lower().endswith(suffixes):
                continue
            if name_filter and name_filter not in fname:
                continue
            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_root, fname)
            process_file(in_path, out_path, palette, kernel_size, output_type)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def clean_segmentation(
    input_dir: str,
    output_dir: str = None,
    inplace: bool = False,
    palette: np.ndarray = None,
    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
    name_filter: str = "",
    morph_kernel_size: int = 3,
    output_type: str = "rgb",
):
    """
    Clean segmentation images using palette-based color mapping.

    All arguments are validated before any directory is created or any file
    is touched.

    Args:
        input_dir (str): Path to input directory containing segmentation images.
        output_dir (str, optional): Directory where cleaned images will be written. Required if not inplace.
        inplace (bool): Overwrite input images in place.
        palette (np.ndarray): Array of RGB triples (K, 3) uint8.
        exts (str): Comma-separated list of allowed image extensions. A leading
            dot is optional and surrounding whitespace is ignored.
        name_filter (str): Only process files whose name contains this substring.
        morph_kernel_size (int): Size of morphological kernel for boundary cleaning.
        output_type (str): "rgb" to write colour images, "index" to write
            masks of palette indices.

    Raises:
        ValueError: If neither `output_dir` nor `inplace` is given, if
            `palette` is missing or not a non-empty (K, 3) array, or if
            `output_type` is not "rgb" or "index".
    """
    if not inplace and output_dir is None:
        raise ValueError("Either output_dir must be provided or inplace must be True")

    if palette is None:
        raise ValueError("palette must be provided")

    palette = np.asarray(palette)
    if palette.ndim != 2 or palette.shape[1] != 3 or palette.shape[0] == 0:
        raise ValueError("palette must be a non-empty array of shape (K, 3)")

    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    # Strip BEFORE adding the dot: " jpg" must become ".jpg", not ". jpg".
    # Empty entries (e.g. from a trailing comma) are dropped.
    exts_list = []
    for raw in exts.split(","):
        ext = raw.strip()
        if not ext:
            continue
        exts_list.append(ext if ext.startswith(".") else "." + ext)

    out_dir = output_dir if not inplace else input_dir

    if not inplace:
        os.makedirs(out_dir, exist_ok=True)

    print(
        f"Processing: input={input_dir} -> output={out_dir}, colours={len(palette)}, morph_kernel={morph_kernel_size}, output_type={output_type}"
    )
    process_directory(
        input_dir,
        out_dir,
        palette,
        exts_list,
        inplace,
        name_filter,
        morph_kernel_size,
        output_type,
    )
    print("Done.")
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from . import clean, nn, train, utils
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def main_clean(argv=None):
    """
    Entry point of the `segment-clean` command.

    Parses the command line, builds the palette / colour map from
    `--colour_map` or `--colour_map_file`, and dispatches to the palette or
    neural-network cleaning routine.

    Args:
        argv: Optional list of argument strings. When None (the default, and
            what the console script uses) the arguments are taken from
            `sys.argv`.

    Invalid argument combinations are reported through `parser.error`, which
    prints a usage message and exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Clean segmentation images using various methods."
    )

    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=["palette", "nn"],
        help="Cleaning method to use: 'palette' for color palette mapping, 'nn' for neural network.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Path to input directory containing images.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=False,
        help="Directory where cleaned images will be written. Required if not inplace.",
    )
    parser.add_argument(
        "--inplace",
        action="store_true",
        help="Overwrite input images in place.",
    )
    parser.add_argument(
        "--exts",
        type=str,
        default=".png,.jpg,.jpeg,.tiff,.bmp,.gif",
        help="Comma-separated list of allowed image extensions.",
    )
    parser.add_argument(
        "--name_filter",
        type=str,
        default="",
        help="Only process files whose name contains this substring.",
    )

    parser.add_argument(
        "--output_type",
        type=str,
        choices=["rgb", "index"],
        default="rgb",
        help="Output format: 'rgb' colour image or 'index' mask.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--colour_map",
        type=str,
        help="Semicolon-separated list of RGB triples.",
    )
    group.add_argument(
        "--colour_map_file",
        type=str,
        help="Path to a file listing RGB triples.",
    )

    # Palette-specific args
    parser.add_argument(
        "--morph_kernel_size",
        type=int,
        default=3,
        help="Size of morphological kernel for palette method.",
    )

    # NN-specific args
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to trained model for nn method.",
    )

    args = parser.parse_args(argv)

    # Reject inconsistent argument combinations up front, as usage errors,
    # before any colour file is read or any model is loaded.
    if not args.inplace and not args.output_dir:
        parser.error("--output_dir is required unless --inplace is given")
    if args.method == "nn" and not args.model_path:
        parser.error("--model_path required for nn method")

    if args.colour_map_file:
        colours = utils.parse_colours_from_file(args.colour_map_file)
    else:
        colours = utils.parse_colours_from_string(args.colour_map)

    if args.method == "palette":
        palette = np.asarray(colours, dtype=np.uint8)

        clean.clean_segmentation(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            palette=palette,
            exts=args.exts,
            name_filter=args.name_filter,
            morph_kernel_size=args.morph_kernel_size,
            output_type=args.output_type,
        )

    elif args.method == "nn":
        colour_map = {i: rgb for i, rgb in enumerate(colours)}
        nn.run_inference(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            inplace=args.inplace,
            model_path=args.model_path,
            colour_map=colour_map,
            exts=args.exts,
            name_filter=args.name_filter,
            output_type=args.output_type,
        )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def main_train(argv=None):
    """
    Entry point of the `segment-train` command.

    Parses the command line, builds the index -> RGB colour map from
    `--colour_map` or `--colour_map_file`, and starts training.

    Args:
        argv: Optional list of argument strings. When None (the default, and
            what the console script uses) the arguments are taken from
            `sys.argv`.
    """
    parser = argparse.ArgumentParser(
        description="Train a neural network model for segmentation cleaning."
    )

    parser.add_argument(
        "--image_dir",
        type=str,
        required=True,
        help="Path to directory containing noisy images.",
    )
    parser.add_argument(
        "--label_dir",
        type=str,
        required=True,
        help="Path to directory containing target RGB labels.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where model weights will be saved.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="pixelwise",
        help="The type of model to train.",
    )

    # Exactly one of the two colour sources must be supplied.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--colour_map",
        type=str,
        help="Semicolon-separated list of RGB triples.",
    )
    group.add_argument(
        "--colour_map_file",
        type=str,
        help="Path to a file listing RGB triples.",
    )

    args = parser.parse_args(argv)

    if args.colour_map_file:
        colours = utils.parse_colours_from_file(args.colour_map_file)
    else:
        colours = utils.parse_colours_from_string(args.colour_map)

    # Class index is the position of the colour in the supplied list.
    colour_map = {i: rgb for i, rgb in enumerate(colours)}

    train.train_model(
        image_dir=args.image_dir,
        label_dir=args.label_dir,
        output_dir=args.output_dir,
        colour_map=colour_map,
        model_type=args.model_type,
    )
|
|
File without changes
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import pytorch_lightning as pl
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PixelwiseClassifier(pl.LightningModule):
    """
    Two-layer perceptron that classifies individual pixels.

    The network maps a feature vector of size `input_dim` through one hidden
    ReLU layer to `output_dim` class scores. `forward` returns softmax
    probabilities; the training and validation losses are computed on the raw
    scores (logits).
    """

    # Class-level counters; nothing in this class reads or updates them.
    step = 0
    val_step = 0

    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Args:
            input_dim: Number of input features per pixel.
            hidden_dim: Width of the hidden layer.
            output_dim: Number of classes.
        """
        super().__init__()
        # Stores the constructor arguments in the checkpoint so that
        # load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()

        self.output_dim = output_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Return per-class probabilities (softmax over the last dimension)."""
        x = self.net(x)
        return self.softmax(x)

    def _compute_loss(self, batch):
        """Cross-entropy between the network's logits and the integer targets."""
        sample, target = batch
        # CrossEntropyLoss applies log-softmax itself, so it must be given the
        # raw logits; feeding it the softmax output of forward() would apply
        # softmax twice and flatten the gradients.
        logits = self.net(sample).reshape(-1, self.output_dim)
        # Targets are class indices; flatten so a trailing singleton dimension
        # (or extra leading dimensions) lines up with the flattened logits.
        labels = target.reshape(-1).long()
        return self.loss_fn(logits, labels)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one `(sample, target)` batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one `(sample, target)` batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Use Adam with its default hyper-parameters."""
        return torch.optim.Adam(self.parameters())

    def image_to_batch(self, x: torch.Tensor):
        """
        Flatten a channels-first image tensor into one row per pixel.

        A 3-D input is treated as a single image and given a leading batch
        dimension. The channel dimension is moved last and the result is
        reshaped to (num_pixels, channels) on the module's device.
        """
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]).to(self.device)

    def batch_to_image(self, batch: torch.Tensor, height: int, width: int):
        """
        Reshape per-pixel rows back into a (1, channels, height, width) image.

        The channel count is inferred from the number of elements, so this
        works for 3-channel data as well as for per-class outputs with a
        different number of channels.
        """
        return batch.reshape(1, height, width, -1).permute(0, 3, 1, 2)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from PIL import Image
|
|
6
|
+
from torchvision.io import read_image
|
|
7
|
+
|
|
8
|
+
from .models.pixelwise_classifier import PixelwiseClassifier
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_model(model_path: str):
    """Load a ``PixelwiseClassifier`` checkpoint and switch it to eval mode.

    Args:
        model_path: Path of the checkpoint file to restore.

    Returns:
        The restored classifier, already put in evaluation mode.
    """
    # Echo the checkpoint path that is about to be loaded.
    print(model_path)

    classifier = PixelwiseClassifier.load_from_checkpoint(checkpoint_path=model_path)
    classifier.eval()
    return classifier
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def map_int_to_rgb(indexed_image: torch.Tensor, colour_map: dict):
    """Convert a 2-D image of class indices into a ``(3, H, W)`` uint8 RGB tensor.

    Args:
        indexed_image: 2-D tensor of class indices, on any device.
        colour_map: Mapping from class index to an ``(r, g, b)`` sequence.

    Returns:
        A CPU ``torch.uint8`` tensor of shape ``(3, H, W)``. Pixels whose
        index is not a key of ``colour_map`` stay black (all zeros).
    """
    # The output is always built on the CPU, so the boolean masks used to
    # index it must live there too; a CUDA prediction would otherwise fail.
    indexed_image = indexed_image.cpu()

    h, w = indexed_image.shape
    rgb_image = torch.zeros((3, h, w), dtype=torch.uint8)

    for idx, rgb in colour_map.items():
        mask = indexed_image == idx
        for c in range(3):
            rgb_image[c][mask] = rgb[c]

    return rgb_image
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run_inference(
    input_dir: str,
    output_dir: str,
    inplace: bool = False,
    model_path: str = None,
    colour_map: dict = None,
    exts: str = ".png,.jpg,.jpeg,.tiff,.bmp,.gif",
    name_filter: str = "",
    output_type: str = "rgb",
):
    """
    Run neural network inference on segmentation images.

    The directory tree under ``input_dir`` is walked recursively and mirrored
    under the output directory. Files that fail to process are reported on
    stdout and skipped; they do not abort the run.

    Args:
        input_dir (str): Path to input directory containing images.
        output_dir (str, optional): Directory where output images will be written.
        inplace (bool): Overwrite input images in place.
        model_path (str): Path to the trained model file.
        colour_map (dict): Mapping from class indices to RGB tuples.
        exts (str): Comma-separated list of allowed image extensions.
        name_filter (str): Only process files whose name contains this substring.
        output_type (str): ``"rgb"`` writes the predicted classes coloured via
            ``colour_map``; ``"index"`` writes the class indices as a
            single-channel 8-bit image.

    Raises:
        ValueError: If neither ``output_dir`` nor ``inplace`` is given, if
            ``model_path`` or ``colour_map`` is missing, or if ``output_type``
            is not ``"rgb"`` or ``"index"``.
    """
    if not inplace and output_dir is None:
        raise ValueError("Either output_dir must be provided or inplace must be True")

    if model_path is None:
        raise ValueError("model_path must be provided")

    if colour_map is None:
        raise ValueError("colour_map must be provided")

    # Validate once, before any work: raising inside the per-file try block
    # below would be swallowed and reported as a "skipped" file for every image.
    if output_type not in ("rgb", "index"):
        raise ValueError("output_type must be 'rgb' or 'index'")

    model = load_model(model_path)

    exts_tuple = tuple(e.lower().strip() for e in exts.split(","))
    out_dir = input_dir if inplace else output_dir

    if not inplace:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Running inference: input={input_dir} -> output={out_dir}, output_type={output_type}")

    for root, _dirs, files in os.walk(input_dir):
        rel = os.path.relpath(root, input_dir)
        out_root = root if inplace else os.path.join(out_dir, rel)

        os.makedirs(out_root, exist_ok=True)

        for fname in files:
            if not fname.lower().endswith(exts_tuple):
                continue
            if name_filter and name_filter not in fname:
                continue

            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_root, fname)

            # Best effort per file: one unreadable image must not stop the batch.
            try:
                img = read_image(in_path, mode="RGB").float() / 127.5 - 1.0  # Normalize
                h, w = img.shape[1], img.shape[2]

                batch = model.image_to_batch(img)

                with torch.no_grad():
                    probs = model(batch)

                # Bring the prediction to the CPU so both output branches can
                # convert it to NumPy regardless of the model's device.
                predicted = torch.argmax(probs, dim=-1).reshape(h, w).cpu()

                if output_type == "rgb":
                    rgb_image = map_int_to_rgb(predicted, colour_map)
                    pil_image = Image.fromarray(rgb_image.permute(1, 2, 0).numpy())
                else:
                    pil_image = Image.fromarray(predicted.numpy().astype("uint8"), mode="L")

                pil_image.save(out_path)

            except Exception as e:
                print(f"Skipping {in_path}: {e}")

    print("Inference done.")
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from pytorch_lightning import Trainer
|
|
6
|
+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
|
7
|
+
from torch.utils.data import DataLoader, Dataset, random_split
|
|
8
|
+
from torchvision.io import read_image
|
|
9
|
+
|
|
10
|
+
from .nn import PixelwiseClassifier
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Fraction of the label files held out for validation; the remainder is used for training.
VALIDATION_FRACTION = 0.2
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def map_colour_to_int(sample, colour_map):
    """Turn a ``(3, H, W)`` colour image into a ``(1, H, W)`` map of class indices.

    Args:
        sample: Channel-first image tensor with three dimensions.
        colour_map: Mapping from class index to an ``(r, g, b)`` triple.

    Returns:
        A float tensor of shape ``(1, H, W)``. Pixels that match none of the
        colours keep the initial value 0.
    """
    _channels, height, width = sample.shape
    index_plane = torch.zeros((1, height, width))

    for class_index, colour in colour_map.items():
        reference = torch.tensor(colour).view(3, 1, 1)
        # A pixel belongs to the class only if all three channels agree.
        matches = torch.all(sample == reference, dim=0)
        index_plane[0][matches] = class_index

    return index_plane
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SegMaskDataset(Dataset):
    """Dataset of (noisy image, label image) file pairs, flattened per pixel.

    Args:
        paired_filenames: Sequence of dicts with ``"sample"`` and ``"target"``
            file paths.
        colour_map: Mapping from class index to RGB triple, used to turn the
            label image into class indices.
        model: Object whose ``image_to_batch`` method flattens an image.
    """

    def __init__(self, paired_filenames, colour_map, model):
        self.paired_filenames = paired_filenames
        self.colour_map = colour_map
        self.to_batch = model.image_to_batch

    def __len__(self):
        return len(self.paired_filenames)

    def __getitem__(self, index):
        """Load one pair and return ``(sample, target)`` in flattened form."""
        noisy_path = self.paired_filenames[index]["sample"]
        label_path = self.paired_filenames[index]["target"]

        # Normalize to [-1, 1]
        pixels = read_image(noisy_path, mode="RGB") / 127.5
        pixels = pixels - 1.0

        labels = read_image(label_path, mode="RGB")
        labels = map_colour_to_int(labels, self.colour_map)

        flat_pixels = self.to_batch(pixels)
        flat_labels = self.to_batch(labels).to(torch.long)
        return (flat_pixels, flat_labels)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_png_basenames(directory: str) -> list[str]:
    """List the names, without extension, of the ``.png`` files in ``directory``.

    Only entries whose name ends with the lowercase suffix ``.png`` are kept;
    the order is whatever ``os.listdir`` yields.
    """
    basenames = []
    for entry in os.listdir(directory):
        if not entry.endswith(".png"):
            continue
        stem, _extension = os.path.splitext(entry)
        basenames.append(stem)
    return basenames
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_paired_filenames(
    image_dir: str,
    label_dir: str,
    noisy_basenames: list[str],
    label_basenames: list[str],
    training_label_basenames: list[str],
) -> tuple[list[dict], list[dict]]:
    """Pair every noisy image with its label and split the pairs into train/val.

    A label belongs to a noisy image when the label's basename is a substring
    of the noisy image's basename. A pair goes to the training list when its
    label is in ``training_label_basenames`` and to the validation list
    otherwise.

    Args:
        image_dir: Directory holding the noisy ``.png`` images.
        label_dir: Directory holding the label ``.png`` images.
        noisy_basenames: Basenames (no extension) of the noisy images.
        label_basenames: Basenames (no extension) of all label images.
        training_label_basenames: Iterable of the label basenames assigned to
            the training split.

    Returns:
        ``(training_pairs, validation_pairs)``; each pair is a dict with the
        keys ``"sample"`` and ``"target"`` holding file paths.

    Raises:
        Exception: If more than one label matches a noisy image.
    """
    # Materialise the split once so each membership test is O(1) instead of a
    # linear scan over the (possibly lazily indexed) training subset.
    training_labels = set(training_label_basenames)

    training_paired_filenames = []
    val_paired_filenames = []

    for noisy_image_name in noisy_basenames:
        targets = [
            target_file
            for target_file in label_basenames
            if target_file in noisy_image_name
        ]

        if len(targets) > 1:
            raise Exception(
                f"Multiple target files exist for noisy file {noisy_image_name}: {targets}."
            )
        if not targets:
            print(
                f"WARNING: No target found for noisy file {noisy_image_name}. Discarding."
            )
            continue

        target_filename = targets[0]
        pair = {
            "sample": f"{image_dir}/{noisy_image_name}.png",
            "target": f"{label_dir}/{target_filename}.png",
        }

        if target_filename in training_labels:
            training_paired_filenames.append(pair)
        else:
            val_paired_filenames.append(pair)

    return training_paired_filenames, val_paired_filenames
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_dataloaders(
    image_dir, label_dir, colour_map, model
) -> tuple[DataLoader, DataLoader]:
    """Build the training and validation dataloaders for a pair of directories.

    The label files are split at random into a training and a validation
    part (``VALIDATION_FRACTION`` of them go to validation); every noisy
    image then follows the split of the label it is paired with.

    Args:
        image_dir: Directory holding the noisy ``.png`` images.
        label_dir: Directory holding the label ``.png`` images.
        colour_map: Mapping from class index to RGB triple.
        model: Model whose ``image_to_batch`` the datasets use to flatten images.

    Returns:
        ``(train_dataloader, val_dataloader)``, both with ``batch_size=1``.
    """
    noisy_basenames = get_png_basenames(image_dir)
    label_basenames = get_png_basenames(label_dir)

    # Only the training part is needed: anything not in it is validation.
    training_label_basenames, _val_label_basenames = random_split(
        label_basenames, [1 - VALIDATION_FRACTION, VALIDATION_FRACTION]
    )

    training_paired_filenames, val_paired_filenames = get_paired_filenames(
        image_dir,
        label_dir,
        noisy_basenames,
        label_basenames,
        training_label_basenames,
    )

    train_dataset = SegMaskDataset(training_paired_filenames, colour_map, model)
    val_dataset = SegMaskDataset(val_paired_filenames, colour_map, model)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=2,
    )
    # Validation data is evaluated in a fixed order; shuffling it brings no
    # benefit and makes Lightning emit a warning.
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )

    return train_dataloader, val_dataloader
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_model(model_type, colour_map):
    """Instantiate the model selected by ``model_type``.

    The number of output classes equals the number of entries in
    ``colour_map``. ``"pixelwise"`` is the only supported type.

    Raises:
        ValueError: If ``model_type`` is not a known model type.
    """
    class_count = len(colour_map.keys())

    if model_type != "pixelwise":
        raise ValueError(f"Invalid model type: {model_type}")

    return PixelwiseClassifier(input_dim=3, hidden_dim=32, output_dim=class_count)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def train_model(
    image_dir: str,
    label_dir: str,
    output_dir: str,
    colour_map: dict,
    model_type: str = "pixelwise",
):
    """
    Fit a segmentation-cleaning model and checkpoint the best weights.

    Training runs for at most 100 epochs, stops early on ``val_loss`` and
    keeps the single best checkpoint (lowest ``val_loss``) in ``output_dir``.

    Args:
        image_dir (str): Path to directory containing noisy images.
        label_dir (str): Path to directory containing target RGB labels.
        output_dir (str): Directory where model weights will be saved.
        colour_map (dict): Mapping from class indices to RGB tuples.
        model_type (str): The type of model to train.
    """
    network = get_model(model_type, colour_map)
    train_loader, val_loader = get_dataloaders(
        image_dir, label_dir, colour_map, network
    )

    os.makedirs(output_dir, exist_ok=True)

    callbacks = [
        EarlyStopping(monitor="val_loss", mode="min", verbose=True),
        ModelCheckpoint(
            monitor="val_loss", mode="min", save_top_k=1, dirpath=output_dir
        ),
    ]
    trainer = Trainer(max_epochs=100, callbacks=callbacks)

    trainer.fit(network, train_dataloaders=train_loader, val_dataloaders=val_loader)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _parse_rgb_triple(text: str) -> Tuple[int, int, int]:
    """Parse ``"r,g,b"`` into a tuple of three integers in the range 0-255.

    Raises:
        ValueError: If there are not exactly three comma-separated components,
            a component is not an integer, or a component is outside 0-255.
    """
    components = tuple(int(x) for x in text.split(","))

    if len(components) != 3:
        raise ValueError(f"expected 3 components, got {len(components)}")
    if not all(0 <= value <= 255 for value in components):
        raise ValueError("colour components must be in the range 0-255")

    return components


def parse_colours_from_string(colours_str: str) -> List[Tuple[int, int, int]]:
    """Parse a semicolon-separated string of RGB triples into a list of tuples.

    Empty segments (for example a trailing ``;``) are ignored, so an empty
    string yields an empty list.

    Raises:
        ValueError: If a segment is not a valid ``r,g,b`` triple of integers
            in the range 0-255.
    """

    parts = [p.strip() for p in colours_str.split(";") if p.strip()]
    colours = []

    for p in parts:
        try:
            colours.append(_parse_rgb_triple(p))
        except ValueError as err:
            # Re-raise with the offending segment so the user can locate it.
            raise ValueError(f"Invalid colour triple: {p}") from err

    return colours


def parse_colours_from_file(path: str) -> List[Tuple[int, int, int]]:
    """Parse a file with one RGB triple per line into a list of tuples.

    Blank lines and lines starting with ``#`` are skipped.

    Raises:
        ValueError: If a line is not a valid ``r,g,b`` triple of integers in
            the range 0-255, or if the file contains no colours at all.
    """

    colours = []

    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            try:
                colours.append(_parse_rgb_triple(line))
            except ValueError as err:
                raise ValueError(f"Invalid colour triple in file {path}: {line}") from err

    if not colours:
        raise ValueError(f"No colours found in file: {path}")

    return colours
|