geoai-py 0.3.2__py2.py3-none-any.whl → 0.3.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +420 -45
- geoai/geoai.py +0 -1
- geoai/preprocess.py +20 -6
- geoai/utils.py +3830 -30
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.4.dist-info}/METADATA +11 -4
- geoai_py-0.3.4.dist-info/RECORD +13 -0
- geoai_py-0.3.2.dist-info/RECORD +0 -13
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.4.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.4.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.4.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.4.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/extract.py
CHANGED
@@ -13,7 +13,8 @@ import rasterio
 from rasterio.windows import Window
 from rasterio.features import shapes
 from huggingface_hub import hf_hub_download
-
+import scipy.ndimage as ndimage
+from .utils import get_raster_stats
 
 try:
     from torchgeo.datasets import NonGeoDataset
@@ -25,34 +26,74 @@ except ImportError as e:
 
 class CustomDataset(NonGeoDataset):
     """
-    A TorchGeo dataset for object extraction.
-
+    A TorchGeo dataset for object extraction with overlapping tiles support.
+
+    This dataset class creates overlapping image tiles for object detection,
+    ensuring complete coverage of the input raster including right and bottom edges.
+    It inherits from NonGeoDataset to avoid spatial indexing issues.
+
+    Attributes:
+        raster_path: Path to the input raster file.
+        chip_size: Size of image chips to extract (height, width).
+        overlap: Amount of overlap between adjacent tiles (0.0-1.0).
+        transforms: Transforms to apply to the image.
+        verbose: Whether to print detailed processing information.
+        stride_x: Horizontal stride between tiles based on overlap.
+        stride_y: Vertical stride between tiles based on overlap.
+        row_starts: Starting Y positions for each row of tiles.
+        col_starts: Starting X positions for each column of tiles.
+        crs: Coordinate reference system of the raster.
+        transform: Affine transform of the raster.
+        height: Height of the raster in pixels.
+        width: Width of the raster in pixels.
+        count: Number of bands in the raster.
+        bounds: Geographic bounds of the raster (west, south, east, north).
+        roi: Shapely box representing the region of interest.
+        rows: Number of rows of tiles.
+        cols: Number of columns of tiles.
+        raster_stats: Statistics of the raster.
     """
 
     def __init__(
-        self,
+        self,
+        raster_path,
+        chip_size=(512, 512),
+        overlap=0.5,
+        transforms=None,
+        verbose=False,
     ):
         """
-        Initialize the dataset.
+        Initialize the dataset with overlapping tiles.
 
         Args:
-            raster_path: Path to the input raster file
-            chip_size: Size of image chips to extract (height, width)
-
-
+            raster_path: Path to the input raster file.
+            chip_size: Size of image chips to extract (height, width). Default is (512, 512).
+            overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
+            transforms: Transforms to apply to the image. Default is None.
+            verbose: Whether to print detailed processing information. Default is False.
+
+        Raises:
+            ValueError: If overlap is too high resulting in non-positive stride.
         """
         super().__init__()
 
         # Initialize parameters
         self.raster_path = raster_path
         self.chip_size = chip_size
+        self.overlap = overlap
         self.transforms = transforms
         self.verbose = verbose
-
-        # For tracking warnings about multi-band images
         self.warned_about_bands = False
 
-        #
+        # Calculate stride based on overlap
+        self.stride_x = int(chip_size[1] * (1 - overlap))
+        self.stride_y = int(chip_size[0] * (1 - overlap))
+
+        if self.stride_x <= 0 or self.stride_y <= 0:
+            raise ValueError(
+                f"Overlap {overlap} is too high, resulting in non-positive stride"
+            )
+
         with rasterio.open(self.raster_path) as src:
             self.crs = src.crs
             self.transform = src.transform
@@ -63,43 +104,78 @@ class CustomDataset(NonGeoDataset):
             # Define the bounds of the dataset
             west, south, east, north = src.bounds
             self.bounds = (west, south, east, north)
-
-            # Define the ROI for the dataset
             self.roi = box(*self.bounds)
 
-            # Calculate
-
-            self.
-
+            # Calculate starting positions for each tile
+            self.row_starts = []
+            self.col_starts = []
+
+            # Normal row starts using stride
+            for r in range((self.height - 1) // self.stride_y):
+                self.row_starts.append(r * self.stride_y)
+
+            # Add a special last row that ensures we reach the bottom edge
+            if self.height > self.chip_size[0]:
+                self.row_starts.append(max(0, self.height - self.chip_size[0]))
+            else:
+                # If the image is smaller than chip size, just start at 0
+                if not self.row_starts:
+                    self.row_starts.append(0)
+
+            # Normal column starts using stride
+            for c in range((self.width - 1) // self.stride_x):
+                self.col_starts.append(c * self.stride_x)
+
+            # Add a special last column that ensures we reach the right edge
+            if self.width > self.chip_size[1]:
+                self.col_starts.append(max(0, self.width - self.chip_size[1]))
+            else:
+                # If the image is smaller than chip size, just start at 0
+                if not self.col_starts:
+                    self.col_starts.append(0)
+
+            # Update rows and cols based on actual starting positions
+            self.rows = len(self.row_starts)
+            self.cols = len(self.col_starts)
 
             print(
                 f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
             )
             print(f"Image dimensions: {self.width} x {self.height} pixels")
             print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")
+            print(
+                f"Overlap: {overlap*100}% (stride_x={self.stride_x}, stride_y={self.stride_y})"
+            )
            if src.crs:
                 print(f"CRS: {src.crs}")
 
-        #
+        # Get raster stats
         self.raster_stats = get_raster_stats(raster_path, divide_by=255)
 
     def __getitem__(self, idx):
         """
         Get an image chip from the dataset by index.
 
+        Retrieves an image tile with the specified overlap pattern, ensuring
+        proper coverage of the entire raster including edges.
+
         Args:
-            idx: Index of the chip
+            idx: Index of the chip to retrieve.
 
         Returns:
-
+            dict: Dictionary containing:
+                - image: Image tensor.
+                - bbox: Geographic bounding box for the window.
+                - coords: Pixel coordinates as tensor [i, j].
+                - window_size: Window size as tensor [width, height].
         """
         # Convert flat index to grid position
         row = idx // self.cols
        col = idx % self.cols
 
-        #
-
-
+        # Get pre-calculated starting positions
+        j = self.row_starts[row]
+        i = self.col_starts[col]
 
         # Read window from raster
         with rasterio.open(self.raster_path) as src:
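The tiling logic added above is easy to reproduce in isolation: the stride is the chip size scaled by (1 - overlap), interior tiles start at multiples of the stride, and one extra start is pinned to the bottom/right edge so the raster is fully covered. A standalone sketch of that arithmetic, assuming a hypothetical 1000 x 800 pixel raster (not part of the package):

```python
# Standalone sketch of the tile-start arithmetic used by CustomDataset.__init__
# (illustrative only; the 1000 x 800 raster size here is hypothetical).
chip_h, chip_w = 512, 512
overlap = 0.5
height, width = 800, 1000

stride_y = int(chip_h * (1 - overlap))  # 256
stride_x = int(chip_w * (1 - overlap))  # 256

# Interior tiles start at multiples of the stride...
row_starts = [r * stride_y for r in range((height - 1) // stride_y)]
col_starts = [c * stride_x for c in range((width - 1) // stride_x)]

# ...plus one start pinned to the far edge so the last tile reaches it.
if height > chip_h:
    row_starts.append(max(0, height - chip_h))
elif not row_starts:
    row_starts.append(0)
if width > chip_w:
    col_starts.append(max(0, width - chip_w))
elif not col_starts:
    col_starts.append(0)

print(row_starts)  # [0, 256, 512, 288]
print(col_starts)  # [0, 256, 512, 488]
```

The edge-pinned starts (288 and 488 here) overlap their neighbours by more than the nominal 50%, which is what guarantees coverage of the bottom and right edges.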
@@ -166,7 +242,12 @@ class CustomDataset(NonGeoDataset):
         }
 
     def __len__(self):
-        """
+        """
+        Return the number of samples in the dataset.
+
+        Returns:
+            int: Total number of tiles in the dataset.
+        """
         return self.rows * self.cols
 
 
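With __len__ and __getitem__ in place, the dataset can be driven directly or through a DataLoader. A minimal usage sketch (the GeoTIFF path is hypothetical, and the passthrough collate_fn is only one way to sidestep collating the Shapely bbox objects):

```python
# Minimal usage sketch for the overlapping-tile dataset; the raster path is hypothetical.
from torch.utils.data import DataLoader

from geoai.extract import CustomDataset

dataset = CustomDataset(
    raster_path="naip_scene.tif",  # hypothetical input raster
    chip_size=(512, 512),
    overlap=0.5,
)
print(len(dataset))  # rows * cols overlapping tiles

sample = dataset[0]
print(sample["image"].shape)   # image tensor for the first tile
print(sample["coords"])        # pixel offsets [i, j] of the window
print(sample["window_size"])   # [width, height] of the window

# The "bbox" entry is a Shapely geometry, which the default collate cannot stack,
# so a passthrough collate_fn is one simple way to batch samples for inference.
loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: batch)
```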
@@ -196,7 +277,8 @@ class ObjectDetector:
         self.overlap = 0.25 # Default overlap between tiles
         self.confidence_threshold = 0.5 # Default confidence threshold
         self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
-        self.
+        self.min_object_area = 100 # Minimum area in pixels to keep an object
+        self.max_object_area = None # Maximum area in pixels to keep an object
         self.mask_threshold = 0.5 # Threshold for mask binarization
         self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
 
@@ -326,7 +408,8 @@ class ObjectDetector:
             **kwargs: Optional parameters:
                 simplify_tolerance: Tolerance for polygon simplification
                 mask_threshold: Threshold for mask binarization
-
+                min_object_area: Minimum area in pixels to keep an object
+                max_object_area: Maximum area in pixels to keep an object
 
         Returns:
             List of polygons as lists of (x, y) coordinates
@@ -335,7 +418,8 @@ class ObjectDetector:
         # Get parameters from kwargs or use instance defaults
         simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
         mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
-
+        min_object_area = kwargs.get("min_object_area", self.min_object_area)
+        max_object_area = kwargs.get("max_object_area", self.max_object_area)
 
         # Ensure binary mask
         mask = (mask > mask_threshold).astype(np.uint8)
@@ -351,7 +435,14 @@ class ObjectDetector:
         polygons = []
         for contour in contours:
             # Filter out too small contours
-            if contour.shape[0] < 3 or cv2.contourArea(contour) <
+            if contour.shape[0] < 3 or cv2.contourArea(contour) < min_object_area:
+                continue
+
+            # Filter out too large contours
+            if (
+                max_object_area is not None
+                and cv2.contourArea(contour) > max_object_area
+            ):
                 continue
 
             # Simplify contour if it has many points
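The area filter itself is straightforward: contours with fewer than three points or an area below min_object_area are dropped, and, when max_object_area is set, overly large contours are dropped as well. A self-contained toy illustration of that filter (not package code):

```python
# Toy illustration of the min/max object area filter on contours (not package code).
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[5:10, 5:10] = 1    # small blob (contour area ~16 px)
mask[20:60, 20:60] = 1  # large blob (contour area ~1521 px)

min_object_area = 100
max_object_area = 1000  # None would disable the upper bound

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
kept = [
    c
    for c in contours
    if c.shape[0] >= 3
    and cv2.contourArea(c) >= min_object_area
    and (max_object_area is None or cv2.contourArea(c) <= max_object_area)
]
print(len(contours), len(kept))  # 2 contours found, 0 kept with these bounds
```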
@@ -491,7 +582,8 @@ class ObjectDetector:
         output_path=None,
         simplify_tolerance=None,
         mask_threshold=None,
-
+        min_object_area=None,
+        max_object_area=None,
         nms_iou_threshold=None,
         regularize=True,
         angle_threshold=15,
@@ -505,7 +597,8 @@ class ObjectDetector:
             output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
             simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
             mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
-
+            min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
+            max_object_area: Maximum area in pixels to keep an object (default: self.max_object_area)
             nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
             regularize: Whether to regularize objects to right angles (default: True)
             angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
@@ -523,10 +616,11 @@ class ObjectDetector:
         mask_threshold = (
             mask_threshold if mask_threshold is not None else self.mask_threshold
         )
-
-
-
-
+        min_object_area = (
+            min_object_area if min_object_area is not None else self.min_object_area
+        )
+        max_object_area = (
+            max_object_area if max_object_area is not None else self.max_object_area
         )
         nms_iou_threshold = (
             nms_iou_threshold
@@ -540,7 +634,8 @@ class ObjectDetector:
 
         print(f"Converting mask to GeoJSON with parameters:")
         print(f"- Mask threshold: {mask_threshold}")
-        print(f"- Min object area: {
+        print(f"- Min object area: {min_object_area}")
+        print(f"- Max object area: {max_object_area}")
         print(f"- Simplify tolerance: {simplify_tolerance}")
         print(f"- NMS IoU threshold: {nms_iou_threshold}")
         print(f"- Regularize objects: {regularize}")
@@ -586,7 +681,11 @@ class ObjectDetector:
             area = stats[i, cv2.CC_STAT_AREA]
 
             # Skip if too small
-            if area <
+            if area < min_object_area:
+                continue
+
+            # Skip if too large
+            if max_object_area is not None and area > max_object_area:
                 continue
 
             # Create a mask for this object
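The connected-component branch applies the same bounds to the pixel areas reported by cv2.connectedComponentsWithStats. A toy sketch of that filtering step (illustrative only):

```python
# Toy sketch of area filtering on connected-component stats (illustrative only).
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[2:6, 2:6] = 1       # 16-pixel component
mask[10:40, 10:40] = 1   # 900-pixel component

min_object_area = 100
max_object_area = None   # no upper bound in this example

num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
kept_labels = []
for i in range(1, num_labels):  # label 0 is the background
    area = stats[i, cv2.CC_STAT_AREA]
    if area < min_object_area:
        continue
    if max_object_area is not None and area > max_object_area:
        continue
    kept_labels.append(i)

print(kept_labels)  # only the 900-pixel component survives
```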
@@ -710,7 +809,7 @@ class ObjectDetector:
             chip_size: Size of image chips for processing (height, width)
             nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
             mask_threshold: Threshold for mask binarization (0.0-1.0)
-
+            min_object_area: Minimum area in pixels to keep an object
             simplify_tolerance: Tolerance for polygon simplification
 
         Returns:
@@ -724,7 +823,8 @@ class ObjectDetector:
         chip_size = kwargs.get("chip_size", self.chip_size)
         nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
         mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
-
+        min_object_area = kwargs.get("min_object_area", self.min_object_area)
+        max_object_area = kwargs.get("max_object_area", self.max_object_area)
         simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
 
         # Print parameters being used
@@ -734,14 +834,17 @@ class ObjectDetector:
         print(f"- Chip size: {chip_size}")
         print(f"- NMS IoU threshold: {nms_iou_threshold}")
         print(f"- Mask threshold: {mask_threshold}")
-        print(f"- Min object area: {
+        print(f"- Min object area: {min_object_area}")
+        print(f"- Max object area: {max_object_area}")
         print(f"- Simplify tolerance: {simplify_tolerance}")
         print(f"- Filter edge objects: {filter_edges}")
         if filter_edges:
             print(f"- Edge buffer size: {edge_buffer} pixels")
 
         # Create dataset
-        dataset = CustomDataset(
+        dataset = CustomDataset(
+            raster_path=raster_path, chip_size=chip_size, overlap=overlap
+        )
         self.raster_stats = dataset.raster_stats
 
         # Custom collate function to handle Shapely objects
@@ -865,7 +968,8 @@ class ObjectDetector:
                         binary_mask,
                         simplify_tolerance=simplify_tolerance,
                         mask_threshold=mask_threshold,
-
+                        min_object_area=min_object_area,
+                        max_object_area=max_object_area,
                     )
 
                     # Skip if no valid polygons
@@ -948,6 +1052,7 @@ class ObjectDetector:
         )
         chip_size = kwargs.get("chip_size", self.chip_size)
         mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
+        overlap = kwargs.get("overlap", self.overlap)
 
         # Set default output path if not provided
         if output_path is None:
@@ -961,7 +1066,10 @@ class ObjectDetector:
 
         # Create dataset
         dataset = CustomDataset(
-            raster_path=raster_path,
+            raster_path=raster_path,
+            chip_size=chip_size,
+            overlap=overlap,
+            verbose=verbose,
         )
 
         # Store a flag to avoid repetitive messages
@@ -1836,8 +1944,7 @@ class CarDetector(ObjectDetector):
     """
     Car detection using a pre-trained Mask R-CNN model.
 
-    This class extends the
-    `ObjectDetector` class with additional methods for car detection."
+    This class extends the `ObjectDetector` class with additional methods for car detection.
     """
 
     def __init__(
@@ -1856,6 +1963,274 @@ class CarDetector(ObjectDetector):
             model_path=model_path, repo_id=repo_id, model=model, device=device
         )
 
+    def generate_masks(
+        self,
+        raster_path,
+        output_path=None,
+        confidence_threshold=None,
+        mask_threshold=None,
+        overlap=0.25,
+        batch_size=4,
+        verbose=False,
+    ):
+        """
+        Save masks with confidence values as a multi-band GeoTIFF.
+
+        Args:
+            raster_path: Path to input raster
+            output_path: Path for output GeoTIFF
+            confidence_threshold: Minimum confidence score (0.0-1.0)
+            mask_threshold: Threshold for mask binarization (0.0-1.0)
+            overlap: Overlap between tiles (0.0-1.0)
+            batch_size: Batch size for processing
+            verbose: Whether to print detailed processing information
+
+        Returns:
+            Path to the saved GeoTIFF
+        """
+        # Use provided thresholds or fall back to instance defaults
+        if confidence_threshold is None:
+            confidence_threshold = self.confidence_threshold
+        if mask_threshold is None:
+            mask_threshold = self.mask_threshold
+
+        # Default output path
+        if output_path is None:
+            output_path = os.path.splitext(raster_path)[0] + "_masks_conf.tif"
+
+        # Process the raster to get individual masks with confidence
+        with rasterio.open(raster_path) as src:
+            # Create dataset with the specified overlap
+            dataset = CustomDataset(
+                raster_path=raster_path,
+                chip_size=self.chip_size,
+                overlap=overlap,
+                verbose=verbose,
+            )
+
+            # Create output profile
+            output_profile = src.profile.copy()
+            output_profile.update(
+                dtype=rasterio.uint8,
+                count=2, # Two bands: mask and confidence
+                compress="lzw",
+                nodata=0,
+            )
+
+            # Initialize mask and confidence arrays
+            mask_array = np.zeros((src.height, src.width), dtype=np.uint8)
+            conf_array = np.zeros((src.height, src.width), dtype=np.uint8)
+
+            # Define custom collate function to handle Shapely objects
+            def custom_collate(batch):
+                """
+                Custom collate function that handles Shapely geometries
+                by keeping them as Python objects rather than trying to collate them.
+                """
+                elem = batch[0]
+                if isinstance(elem, dict):
+                    result = {}
+                    for key in elem:
+                        if key == "bbox":
+                            # Don't collate shapely objects, keep as list
+                            result[key] = [d[key] for d in batch]
+                        else:
+                            # For tensors and other collatable types
+                            try:
+                                result[key] = (
+                                    torch.utils.data._utils.collate.default_collate(
+                                        [d[key] for d in batch]
+                                    )
+                                )
+                            except TypeError:
+                                # Fall back to list for non-collatable types
+                                result[key] = [d[key] for d in batch]
+                    return result
+                else:
+                    # Default collate for non-dict types
+                    return torch.utils.data._utils.collate.default_collate(batch)
+
+            # Create dataloader with custom collate function
+            dataloader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=0,
+                collate_fn=custom_collate,
+            )
+
+            # Process batches
+            print(f"Processing raster with {len(dataloader)} batches")
+            for batch in tqdm(dataloader):
+                # Move images to device
+                images = batch["image"].to(self.device)
+                coords = batch["coords"] # Tensor of shape [batch_size, 2]
+
+                # Run inference
+                with torch.no_grad():
+                    predictions = self.model(images)
+
+                # Process predictions
+                for idx, prediction in enumerate(predictions):
+                    masks = prediction["masks"].cpu().numpy()
+                    scores = prediction["scores"].cpu().numpy()
+
+                    # Filter by confidence threshold
+                    valid_indices = scores >= confidence_threshold
+                    masks = masks[valid_indices]
+                    scores = scores[valid_indices]
+
+                    # Skip if no valid predictions
+                    if len(masks) == 0:
+                        continue
+
+                    # Get window coordinates
+                    i, j = coords[idx].cpu().numpy()
+
+                    # Process each mask
+                    for mask_idx, mask in enumerate(masks):
+                        # Convert to binary mask
+                        binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
+                        conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
+
+                        # Update the mask and confidence arrays
+                        h, w = binary_mask.shape
+                        valid_h = min(h, src.height - j)
+                        valid_w = min(w, src.width - i)
+
+                        if valid_h > 0 and valid_w > 0:
+                            # Use maximum for overlapping regions in the mask
+                            mask_array[j : j + valid_h, i : i + valid_w] = np.maximum(
+                                mask_array[j : j + valid_h, i : i + valid_w],
+                                binary_mask[:valid_h, :valid_w],
+                            )
+
+                            # For confidence, only update where mask is positive
+                            # and confidence is higher than existing
+                            mask_region = binary_mask[:valid_h, :valid_w] > 0
+                            if np.any(mask_region):
+                                # Only update where mask is positive and new confidence is higher
+                                current_conf = conf_array[
+                                    j : j + valid_h, i : i + valid_w
+                                ]
+
+                                # Where to update confidence (mask positive & higher confidence)
+                                update_mask = np.logical_and(
+                                    mask_region,
+                                    np.logical_or(
+                                        current_conf == 0, current_conf < conf_value
+                                    ),
+                                )
+
+                                if np.any(update_mask):
+                                    conf_array[j : j + valid_h, i : i + valid_w][
+                                        update_mask
+                                    ] = conf_value
+
+            # Write to GeoTIFF
+            with rasterio.open(output_path, "w", **output_profile) as dst:
+                dst.write(mask_array, 1)
+                dst.write(conf_array, 2)
+
+        print(f"Masks with confidence values saved to {output_path}")
+        return output_path
+
+    def vectorize_masks(self, masks_path, output_path=None, **kwargs):
+        """
+        Convert masks with confidence to vector polygons.
+
+        Args:
+            masks_path: Path to masks GeoTIFF with confidence band
+            output_path: Path for output GeoJSON
+            **kwargs: Additional parameters
+
+        Returns:
+            GeoDataFrame with car detections and confidence values
+        """
+
+        print(f"Processing masks from: {masks_path}")
+
+        with rasterio.open(masks_path) as src:
+            # Read mask and confidence bands
+            mask_data = src.read(1)
+            conf_data = src.read(2)
+            transform = src.transform
+            crs = src.crs
+
+        # Convert to binary mask
+        binary_mask = mask_data > 0
+
+        # Find connected components
+        labeled_mask, num_features = ndimage.label(binary_mask)
+        print(f"Found {num_features} connected components")
+
+        # Process each component
+        car_polygons = []
+        car_confidences = []
+
+        # Add progress bar
+        for label in tqdm(range(1, num_features + 1), desc="Processing components"):
+            # Create mask for this component
+            component_mask = (labeled_mask == label).astype(np.uint8)
+
+            # Get confidence value (mean of non-zero values in this region)
+            conf_region = conf_data[component_mask > 0]
+            if len(conf_region) > 0:
+                confidence = (
+                    np.mean(conf_region) / 255.0
+                ) # Convert back to 0-1 range
+            else:
+                confidence = 0.0
+
+            # Find contours
+            contours, _ = cv2.findContours(
+                component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            )
+
+            for contour in contours:
+                # Filter by size
+                area = cv2.contourArea(contour)
+                if area < kwargs.get("min_object_area", 100):
+                    continue
+
+                # Get minimum area rectangle
+                rect = cv2.minAreaRect(contour)
+                box_points = cv2.boxPoints(rect)
+
+                # Convert to geographic coordinates
+                geo_points = []
+                for x, y in box_points:
+                    gx, gy = transform * (x, y)
+                    geo_points.append((gx, gy))
+
+                # Create polygon
+                poly = Polygon(geo_points)
+
+                # Add to lists
+                car_polygons.append(poly)
+                car_confidences.append(confidence)
+
+        # Create GeoDataFrame
+        if car_polygons:
+            gdf = gpd.GeoDataFrame(
+                {
+                    "geometry": car_polygons,
+                    "confidence": car_confidences,
+                    "class": [1] * len(car_polygons),
+                },
+                crs=crs,
+            )
+
+            # Save to file if requested
+            if output_path:
+                gdf.to_file(output_path, driver="GeoJSON")
+                print(f"Saved {len(gdf)} cars with confidence to {output_path}")
+
+            return gdf
+        else:
+            print("No valid car polygons found")
+            return None
+
 
 class ShipDetector(ObjectDetector):
     """
|