geoai-py 0.3.5__py2.py3-none-any.whl → 0.3.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +96 -17
- geoai/geoai.py +1 -0
- geoai/segment.py +305 -0
- geoai/utils.py +859 -18
- {geoai_py-0.3.5.dist-info → geoai_py-0.3.6.dist-info}/METADATA +5 -1
- geoai_py-0.3.6.dist-info/RECORD +13 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.3.6.dist-info}/WHEEL +1 -1
- geoai/preprocess.py +0 -3021
- geoai_py-0.3.5.dist-info/RECORD +0 -13
- {geoai_py-0.3.5.dist-info → geoai_py-0.3.6.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.3.6.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.3.6.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/extract.py
CHANGED
|
@@ -1,21 +1,29 @@
|
|
|
1
|
+
"""This module provides a dataset class for object extraction from raster data"""
|
|
2
|
+
|
|
3
|
+
# Standard Library
|
|
1
4
|
import os
|
|
5
|
+
|
|
6
|
+
# Third-Party Libraries
|
|
7
|
+
import cv2
|
|
8
|
+
import geopandas as gpd
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
2
10
|
import numpy as np
|
|
11
|
+
import rasterio
|
|
12
|
+
import scipy.ndimage as ndimage
|
|
3
13
|
import torch
|
|
4
|
-
|
|
14
|
+
from huggingface_hub import hf_hub_download
|
|
15
|
+
from rasterio.windows import Window
|
|
5
16
|
from shapely.geometry import Polygon, box
|
|
6
|
-
import geopandas as gpd
|
|
7
17
|
from tqdm import tqdm
|
|
18
|
+
from torchvision.models.detection import (
|
|
19
|
+
maskrcnn_resnet50_fpn,
|
|
20
|
+
fasterrcnn_resnet50_fpn_v2,
|
|
21
|
+
)
|
|
8
22
|
|
|
9
|
-
|
|
10
|
-
from torchvision.models.detection import maskrcnn_resnet50_fpn
|
|
11
|
-
import torchvision.transforms as T
|
|
12
|
-
import rasterio
|
|
13
|
-
from rasterio.windows import Window
|
|
14
|
-
from rasterio.features import shapes
|
|
15
|
-
from huggingface_hub import hf_hub_download
|
|
16
|
-
import scipy.ndimage as ndimage
|
|
23
|
+
# Local Imports
|
|
17
24
|
from .utils import get_raster_stats
|
|
18
25
|
|
|
26
|
+
|
|
19
27
|
try:
|
|
20
28
|
from torchgeo.datasets import NonGeoDataset
|
|
21
29
|
except ImportError as e:
|
|
@@ -60,6 +68,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
60
68
|
chip_size=(512, 512),
|
|
61
69
|
overlap=0.5,
|
|
62
70
|
transforms=None,
|
|
71
|
+
band_indexes=None,
|
|
63
72
|
verbose=False,
|
|
64
73
|
):
|
|
65
74
|
"""
|
|
@@ -70,6 +79,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
70
79
|
chip_size: Size of image chips to extract (height, width). Default is (512, 512).
|
|
71
80
|
overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
|
|
72
81
|
transforms: Transforms to apply to the image. Default is None.
|
|
82
|
+
band_indexes: List of band indexes to use. Default is None (use all bands).
|
|
73
83
|
verbose: Whether to print detailed processing information. Default is False.
|
|
74
84
|
|
|
75
85
|
Raises:
|
|
@@ -82,6 +92,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
82
92
|
self.chip_size = chip_size
|
|
83
93
|
self.overlap = overlap
|
|
84
94
|
self.transforms = transforms
|
|
95
|
+
self.band_indexes = band_indexes
|
|
85
96
|
self.verbose = verbose
|
|
86
97
|
self.warned_about_bands = False
|
|
87
98
|
|
|
@@ -191,7 +202,10 @@ class CustomDataset(NonGeoDataset):
|
|
|
191
202
|
if not self.warned_about_bands and self.verbose:
|
|
192
203
|
print(f"Image has {image.shape[0]} bands, using first 3 bands only")
|
|
193
204
|
self.warned_about_bands = True
|
|
194
|
-
|
|
205
|
+
if self.band_indexes is not None:
|
|
206
|
+
image = image[self.band_indexes]
|
|
207
|
+
else:
|
|
208
|
+
image = image[:3]
|
|
195
209
|
elif image.shape[0] < 3:
|
|
196
210
|
# If image has fewer than 3 bands, duplicate the last band to make 3
|
|
197
211
|
if not self.warned_about_bands and self.verbose:
|
|
@@ -594,7 +608,7 @@ class ObjectDetector:
|
|
|
594
608
|
|
|
595
609
|
Args:
|
|
596
610
|
mask_path: Path to the object masks GeoTIFF
|
|
597
|
-
output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
|
|
611
|
+
output_path: Path to save the output GeoJSON or Parquet file (default: mask_path with .geojson extension)
|
|
598
612
|
simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
|
|
599
613
|
mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
|
|
600
614
|
min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
|
|
@@ -779,7 +793,10 @@ class ObjectDetector:
|
|
|
779
793
|
|
|
780
794
|
# Save to file
|
|
781
795
|
if output_path:
|
|
782
|
-
|
|
796
|
+
if output_path.endswith(".parquet"):
|
|
797
|
+
gdf.to_parquet(output_path)
|
|
798
|
+
else:
|
|
799
|
+
gdf.to_file(output_path)
|
|
783
800
|
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
784
801
|
|
|
785
802
|
return gdf
|
|
@@ -792,6 +809,7 @@ class ObjectDetector:
|
|
|
792
809
|
batch_size=4,
|
|
793
810
|
filter_edges=True,
|
|
794
811
|
edge_buffer=20,
|
|
812
|
+
band_indexes=None,
|
|
795
813
|
**kwargs,
|
|
796
814
|
):
|
|
797
815
|
"""
|
|
@@ -799,10 +817,11 @@ class ObjectDetector:
|
|
|
799
817
|
|
|
800
818
|
Args:
|
|
801
819
|
raster_path: Path to input raster file
|
|
802
|
-
output_path: Path to output GeoJSON file (optional)
|
|
820
|
+
output_path: Path to output GeoJSON or Parquet file (optional)
|
|
803
821
|
batch_size: Batch size for processing
|
|
804
822
|
filter_edges: Whether to filter out objects at the edges of the image
|
|
805
823
|
edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
|
|
824
|
+
band_indexes: List of band indexes to use (if None, use all bands)
|
|
806
825
|
**kwargs: Additional parameters:
|
|
807
826
|
confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
|
|
808
827
|
overlap: Overlap between adjacent tiles (0.0-1.0)
|
|
@@ -843,7 +862,10 @@ class ObjectDetector:
|
|
|
843
862
|
|
|
844
863
|
# Create dataset
|
|
845
864
|
dataset = CustomDataset(
|
|
846
|
-
raster_path=raster_path,
|
|
865
|
+
raster_path=raster_path,
|
|
866
|
+
chip_size=chip_size,
|
|
867
|
+
overlap=overlap,
|
|
868
|
+
band_indexes=band_indexes,
|
|
847
869
|
)
|
|
848
870
|
self.raster_stats = dataset.raster_stats
|
|
849
871
|
|
|
@@ -1021,7 +1043,10 @@ class ObjectDetector:
|
|
|
1021
1043
|
|
|
1022
1044
|
# Save to file if requested
|
|
1023
1045
|
if output_path:
|
|
1024
|
-
|
|
1046
|
+
if output_path.endswith(".parquet"):
|
|
1047
|
+
gdf.to_parquet(output_path)
|
|
1048
|
+
else:
|
|
1049
|
+
gdf.to_file(output_path, driver="GeoJSON")
|
|
1025
1050
|
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
1026
1051
|
|
|
1027
1052
|
return gdf
|
|
@@ -1887,21 +1912,30 @@ class ObjectDetector:
|
|
|
1887
1912
|
output_path=None,
|
|
1888
1913
|
confidence_threshold=None,
|
|
1889
1914
|
mask_threshold=None,
|
|
1915
|
+
min_object_area=10,
|
|
1916
|
+
max_object_area=float("inf"),
|
|
1890
1917
|
overlap=0.25,
|
|
1891
1918
|
batch_size=4,
|
|
1919
|
+
band_indexes=None,
|
|
1892
1920
|
verbose=False,
|
|
1893
1921
|
**kwargs,
|
|
1894
1922
|
):
|
|
1895
1923
|
"""
|
|
1896
1924
|
Save masks with confidence values as a multi-band GeoTIFF.
|
|
1897
1925
|
|
|
1926
|
+
Objects with area smaller than min_object_area or larger than max_object_area
|
|
1927
|
+
will be filtered out.
|
|
1928
|
+
|
|
1898
1929
|
Args:
|
|
1899
1930
|
raster_path: Path to input raster
|
|
1900
1931
|
output_path: Path for output GeoTIFF
|
|
1901
1932
|
confidence_threshold: Minimum confidence score (0.0-1.0)
|
|
1902
1933
|
mask_threshold: Threshold for mask binarization (0.0-1.0)
|
|
1934
|
+
min_object_area: Minimum area (in pixels) for an object to be included
|
|
1935
|
+
max_object_area: Maximum area (in pixels) for an object to be included
|
|
1903
1936
|
overlap: Overlap between tiles (0.0-1.0)
|
|
1904
1937
|
batch_size: Batch size for processing
|
|
1938
|
+
band_indexes: List of band indexes to use (default: all bands)
|
|
1905
1939
|
verbose: Whether to print detailed processing information
|
|
1906
1940
|
|
|
1907
1941
|
Returns:
|
|
@@ -1926,6 +1960,7 @@ class ObjectDetector:
|
|
|
1926
1960
|
raster_path=raster_path,
|
|
1927
1961
|
chip_size=chip_size,
|
|
1928
1962
|
overlap=overlap,
|
|
1963
|
+
band_indexes=band_indexes,
|
|
1929
1964
|
verbose=verbose,
|
|
1930
1965
|
)
|
|
1931
1966
|
|
|
@@ -2012,6 +2047,21 @@ class ObjectDetector:
|
|
|
2012
2047
|
for mask_idx, mask in enumerate(masks):
|
|
2013
2048
|
# Convert to binary mask
|
|
2014
2049
|
binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
|
|
2050
|
+
|
|
2051
|
+
# Check object area - calculate number of pixels in the mask
|
|
2052
|
+
object_area = np.sum(binary_mask > 0)
|
|
2053
|
+
|
|
2054
|
+
# Skip objects that don't meet area criteria
|
|
2055
|
+
if (
|
|
2056
|
+
object_area < min_object_area
|
|
2057
|
+
or object_area > max_object_area
|
|
2058
|
+
):
|
|
2059
|
+
if verbose:
|
|
2060
|
+
print(
|
|
2061
|
+
f"Filtering out object with area {object_area} pixels"
|
|
2062
|
+
)
|
|
2063
|
+
continue
|
|
2064
|
+
|
|
2015
2065
|
conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
|
|
2016
2066
|
|
|
2017
2067
|
# Update the mask and confidence arrays
|
|
@@ -2164,7 +2214,7 @@ class ObjectDetector:
|
|
|
2164
2214
|
# Save to file if requested
|
|
2165
2215
|
if output_path:
|
|
2166
2216
|
gdf.to_file(output_path, driver="GeoJSON")
|
|
2167
|
-
print(f"Saved {len(gdf)}
|
|
2217
|
+
print(f"Saved {len(gdf)} objects with confidence to {output_path}")
|
|
2168
2218
|
|
|
2169
2219
|
return gdf
|
|
2170
2220
|
else:
|
|
@@ -2277,3 +2327,32 @@ class ShipDetector(ObjectDetector):
|
|
|
2277
2327
|
super().__init__(
|
|
2278
2328
|
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
2279
2329
|
)
|
|
2330
|
+
|
|
2331
|
+
|
|
2332
|
+
class SolarPanelDetector(ObjectDetector):
    """
    Solar panel detection using a pre-trained Mask R-CNN model.

    This class extends the `ObjectDetector` class, preconfigured with a
    default checkpoint filename for solar panel detection. All detection,
    vectorization, and post-processing behavior is inherited unchanged.
    """

    def __init__(
        self,
        model_path="solar_panel_detection.pth",
        repo_id=None,
        model=None,
        device=None,
    ):
        """
        Initialize the solar panel detector.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )
geoai/geoai.py
CHANGED
geoai/segment.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
from PIL import Image
|
|
8
|
+
import rasterio
|
|
9
|
+
from rasterio.windows import Window
|
|
10
|
+
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CLIPSegmentation:
    """
    A class for segmenting high-resolution satellite imagery using text prompts with CLIP-based models.

    This segmenter utilizes the CLIP-Seg model to perform semantic segmentation based on text prompts.
    It can process large GeoTIFF files by tiling them and handles proper georeferencing in the output.

    Args:
        model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
        device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
        tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
        overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.

    Attributes:
        processor (CLIPSegProcessor): The processor for the CLIP-Seg model.
        model (CLIPSegForImageSegmentation): The CLIP-Seg model for segmentation.
        device (str): The device being used ('cuda' or 'cpu').
        tile_size (int): Size of tiles for processing.
        overlap (int): Overlap between tiles.
    """

    def __init__(
        self,
        model_name="CIDAS/clipseg-rd64-refined",
        device=None,
        tile_size=512,
        overlap=32,
    ):
        """
        Initialize the CLIPSegmentation segmenter with the specified model and settings.

        Args:
            model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
            device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
            tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
            overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
        """
        self.tile_size = tile_size
        self.overlap = overlap

        # Set device: prefer CUDA when available unless the caller pins one.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load model and processor from the Hugging Face Hub (or local cache).
        self.processor = CLIPSegProcessor.from_pretrained(model_name)
        self.model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(
            self.device
        )

        print(f"Model loaded on {self.device}")

    @staticmethod
    def _to_rgb_uint8(tile_data):
        """Convert a (bands, H, W) raster tile to an (H, W, 3) uint8 RGB array.

        Uses the first three bands when three or more are present; repeats the
        first band three times otherwise (grayscale or 2-band input). Values
        are min-max scaled to 0-255. A constant tile (zero value range) is
        returned as all zeros instead of triggering a division by zero.

        Args:
            tile_data (np.ndarray): Raster tile with shape (bands, H, W).

        Returns:
            np.ndarray: RGB image with shape (H, W, 3) and dtype uint8.
        """
        if tile_data.shape[0] >= 3:
            rgb = tile_data[:3].transpose(1, 2, 0)
        else:
            rgb = np.repeat(tile_data[0][:, :, np.newaxis], 3, axis=2)

        lo = rgb.min()
        value_range = rgb.max() - lo
        if rgb.max() > 0 and value_range > 0:
            # Min-max stretch to the full 8-bit range.
            rgb = ((rgb - lo) / value_range * 255).astype(np.uint8)
        else:
            # Constant or all-zero tile: avoid divide-by-zero; emit black.
            rgb = np.zeros(rgb.shape, dtype=np.uint8)
        return rgb

    def segment_image(
        self, input_path, output_path, text_prompt, threshold=0.5, smoothing_sigma=1.0
    ):
        """
        Segment a GeoTIFF image using the provided text prompt.

        The function processes the image in tiles and saves the result as a GeoTIFF with two bands:
        - Band 1: Binary segmentation mask (0 or 1)
        - Band 2: Probability scores (0.0 to 1.0)

        Args:
            input_path (str): Path to the input GeoTIFF file.
            output_path (str): Path where the output GeoTIFF will be saved.
            text_prompt (str): Text description of what to segment (e.g., "water", "buildings").
            threshold (float): Threshold for binary segmentation (0.0 to 1.0). Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.

        Returns:
            str: Path to the saved output file.
        """
        # Open the input GeoTIFF
        with rasterio.open(input_path) as src:
            # Get metadata
            meta = src.meta
            height = src.height
            width = src.width

            # Output keeps georeferencing but stores two float32 bands.
            out_meta = meta.copy()
            out_meta.update({"count": 2, "dtype": "float32", "nodata": None})

            # Full-size accumulator for per-pixel probabilities.
            probabilities = np.zeros((height, width), dtype=np.float32)

            # Effective stride between tiles, accounting for overlap on both sides.
            effective_tile_size = self.tile_size - 2 * self.overlap

            n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
            n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
            total_tiles = n_tiles_x * n_tiles_y

            # Process tiles with tqdm progress bar
            with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
                for y in range(n_tiles_y):
                    for x in range(n_tiles_x):
                        # Tile coordinates, padded by the overlap and clamped
                        # to the raster bounds.
                        x_start = max(0, x * effective_tile_size - self.overlap)
                        y_start = max(0, y * effective_tile_size - self.overlap)
                        x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
                        y_end = min(
                            height, (y + 1) * effective_tile_size + self.overlap
                        )

                        tile_width = x_end - x_start
                        tile_height = y_end - y_start

                        # Read the tile
                        window = Window(x_start, y_start, tile_width, tile_height)
                        tile_data = src.read(window=window)

                        # Best-effort per-tile processing: a failed tile is
                        # logged and left as zero probability.
                        try:
                            rgb_tile = self._to_rgb_uint8(tile_data)

                            # Convert to PIL Image
                            pil_image = Image.fromarray(rgb_tile)

                            # Downscale (keeping aspect ratio) if the tile
                            # exceeds the model's working size.
                            if (
                                pil_image.width > self.tile_size
                                or pil_image.height > self.tile_size
                            ):
                                pil_image.thumbnail(
                                    (self.tile_size, self.tile_size), Image.LANCZOS
                                )

                            # Process with CLIP-Seg
                            inputs = self.processor(
                                text=text_prompt, images=pil_image, return_tensors="pt"
                            ).to(self.device)

                            # Forward pass
                            with torch.no_grad():
                                outputs = self.model(**inputs)

                            logits = outputs.logits[0]

                            # Convert logits to probabilities with sigmoid
                            probs = torch.sigmoid(logits).cpu().numpy()

                            # Resize back to the original tile size if needed,
                            # using bicubic interpolation for smoother results.
                            if probs.shape != (tile_height, tile_width):
                                probs_resized = np.array(
                                    Image.fromarray(probs).resize(
                                        (tile_width, tile_height), Image.BICUBIC
                                    )
                                )
                            else:
                                probs_resized = probs

                            # Apply gaussian blur to reduce blockiness
                            try:
                                from scipy.ndimage import gaussian_filter

                                probs_resized = gaussian_filter(
                                    probs_resized, sigma=smoothing_sigma
                                )
                            except ImportError:
                                pass  # Continue without smoothing if scipy is not available

                            # Keep only the non-overlapping interior of the
                            # tile (except at raster edges) when writing back.
                            valid_x_start = self.overlap if x > 0 else 0
                            valid_y_start = self.overlap if y > 0 else 0
                            valid_x_end = (
                                tile_width - self.overlap
                                if x < n_tiles_x - 1
                                else tile_width
                            )
                            valid_y_end = (
                                tile_height - self.overlap
                                if y < n_tiles_y - 1
                                else tile_height
                            )

                            dest_x_start = x_start + valid_x_start
                            dest_y_start = y_start + valid_y_start
                            dest_x_end = x_start + valid_x_end
                            dest_y_end = y_start + valid_y_end

                            # Store probabilities
                            probabilities[
                                dest_y_start:dest_y_end, dest_x_start:dest_x_end
                            ] = probs_resized[
                                valid_y_start:valid_y_end, valid_x_start:valid_x_end
                            ]

                        except Exception as e:
                            print(f"Error processing tile at ({x}, {y}): {str(e)}")
                            # Continue with next tile

                        # Update progress bar
                        pbar.update(1)

            # Create binary segmentation from probabilities
            segmentation = (probabilities >= threshold).astype(np.float32)

        # Write the output GeoTIFF
        with rasterio.open(output_path, "w", **out_meta) as dst:
            dst.write(segmentation, 1)
            dst.write(probabilities, 2)

            # Add descriptions to bands
            dst.set_band_description(1, "Binary Segmentation")
            dst.set_band_description(2, "Probability Scores")

        print(f"Segmentation saved to {output_path}")
        return output_path

    def segment_image_batch(
        self,
        input_paths,
        output_dir,
        text_prompt,
        threshold=0.5,
        smoothing_sigma=1.0,
        suffix="_segmented",
    ):
        """
        Segment multiple GeoTIFF images using the provided text prompt.

        Args:
            input_paths (list): List of paths to input GeoTIFF files.
            output_dir (str): Directory where output GeoTIFFs will be saved.
            text_prompt (str): Text description of what to segment.
            threshold (float): Threshold for binary segmentation. Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
            suffix (str): Suffix to add to output filenames. Defaults to "_segmented".

        Returns:
            list: Paths to all saved output files.
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        output_paths = []

        # Process each input file
        for input_path in tqdm(input_paths, desc="Processing files"):
            # Generate output path
            filename = os.path.basename(input_path)
            base_name, ext = os.path.splitext(filename)
            output_path = os.path.join(output_dir, f"{base_name}{suffix}{ext}")

            # Segment the image
            result_path = self.segment_image(
                input_path, output_path, text_prompt, threshold, smoothing_sigma
            )
            output_paths.append(result_path)

        return output_paths