geoai-py 0.3.6__py2.py3-none-any.whl → 0.4.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +76 -14
- geoai/download.py +9 -8
- geoai/extract.py +65 -24
- geoai/geoai.py +3 -1
- geoai/hf.py +447 -0
- geoai/segment.py +4 -3
- geoai/segmentation.py +8 -7
- geoai/train.py +1039 -0
- geoai/utils.py +32 -28
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/METADATA +3 -8
- geoai_py-0.4.1.dist-info/RECORD +15 -0
- geoai_py-0.3.6.dist-info/RECORD +0 -13
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/top_level.txt +0 -0
geoai/train.py
ADDED
@@ -0,0 +1,1039 @@
import math
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
import torch.utils.data
import torchvision

# import torchvision.transforms as transforms
from rasterio.windows import Window
from skimage import measure
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from tqdm import tqdm


def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
    """
    Get Mask R-CNN model with custom input channels and output classes.

    Args:
        num_classes (int): Number of output classes (including background).
        num_channels (int): Number of input channels (3 for RGB, 4 for RGBN).
        pretrained (bool): Whether to use pretrained backbone.

    Returns:
        torch.nn.Module: Mask R-CNN model with specified input channels and output classes.

    Raises:
        ValueError: If num_channels is less than 3.
    """
    # Validate num_channels
    if num_channels < 3:
        raise ValueError("num_channels must be at least 3")

    # Load pre-trained model
    model = maskrcnn_resnet50_fpn(
        pretrained=pretrained,
        progress=True,
        weights=(
            torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
            if pretrained
            else None
        ),
    )

    # Modify transform if num_channels is different from 3
    if num_channels != 3:
        # Get the transform
        transform = model.transform

        # Default values are [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225]
        # Calculate means and stds for additional channels
        rgb_mean = [0.485, 0.456, 0.406]
        rgb_std = [0.229, 0.224, 0.225]

        # Extend them to num_channels (use the mean value for additional channels)
        mean_of_means = sum(rgb_mean) / len(rgb_mean)
        mean_of_stds = sum(rgb_std) / len(rgb_std)

        # Create new lists with appropriate length
        transform.image_mean = rgb_mean + [mean_of_means] * (num_channels - 3)
        transform.image_std = rgb_std + [mean_of_stds] * (num_channels - 3)

    # Get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Get number of input features for mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    # Replace mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    # Modify the first layer if num_channels is different from 3
    if num_channels != 3:
        original_layer = model.backbone.body.conv1
        model.backbone.body.conv1 = torch.nn.Conv2d(
            num_channels,
            original_layer.out_channels,
            kernel_size=original_layer.kernel_size,
            stride=original_layer.stride,
            padding=original_layer.padding,
            bias=original_layer.bias is not None,
        )

        # Copy weights from the original 3 channels to the new layer
        with torch.no_grad():
            # Copy the weights for the first 3 channels
            model.backbone.body.conv1.weight[:, :3, :, :] = original_layer.weight

            # Initialize additional channels with the mean of the first 3 channels
            mean_weight = original_layer.weight.mean(dim=1, keepdim=True)
            for i in range(3, num_channels):
                model.backbone.body.conv1.weight[:, i : i + 1, :, :] = mean_weight

            # Copy bias if it exists
            if original_layer.bias is not None:
                model.backbone.body.conv1.bias = original_layer.bias

    return model
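
# Usage sketch (illustrative): for 4-band RGBN tiles, the model can be built with
#   model = get_instance_segmentation_model(num_classes=2, num_channels=4, pretrained=True)
# which extends the normalization statistics and the conv1 weights to the extra band.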


class ObjectDetectionDataset(Dataset):
    """Dataset for object detection from GeoTIFF images and labels."""

    def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
        """
        Initialize dataset.

        Args:
            image_paths (list): List of paths to image GeoTIFF files.
            label_paths (list): List of paths to label GeoTIFF files.
            transforms (callable, optional): Transformations to apply to images and masks.
            num_channels (int, optional): Number of channels to use from images. If None,
                auto-detected from the first image.
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transforms = transforms

        # Auto-detect the number of channels if not specified
        if num_channels is None:
            with rasterio.open(self.image_paths[0]) as src:
                self.num_channels = src.count
        else:
            self.num_channels = num_channels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image
        with rasterio.open(self.image_paths[idx]) as src:
            # Read as [C, H, W] format
            image = src.read().astype(np.float32)

        # Normalize image to [0, 1] range
        image = image / 255.0

        # Handle different number of channels
        if image.shape[0] > self.num_channels:
            # Keep only the first num_channels bands if more exist
            image = image[: self.num_channels]
        elif image.shape[0] < self.num_channels:
            # Pad with zeros if fewer than num_channels bands
            padded = np.zeros(
                (self.num_channels, image.shape[1], image.shape[2]),
                dtype=np.float32,
            )
            padded[: image.shape[0]] = image
            image = padded

        # Convert to CHW tensor
        image = torch.as_tensor(image, dtype=torch.float32)

        # Load label mask
        with rasterio.open(self.label_paths[idx]) as src:
            label_mask = src.read(1)
            binary_mask = (label_mask > 0).astype(np.uint8)

        # Find all building instances using connected components
        labeled_mask, num_instances = measure.label(
            binary_mask, return_num=True, connectivity=2
        )

        # Create list to hold masks for each building instance
        masks = []
        boxes = []
        labels = []

        for i in range(1, num_instances + 1):
            # Create mask for this instance
            instance_mask = (labeled_mask == i).astype(np.uint8)

            # Calculate area and filter out tiny instances (noise)
            area = instance_mask.sum()
            if area < 10:  # Minimum area threshold
                continue

            # Find bounding box coordinates
            pos = np.where(instance_mask)
            if len(pos[0]) == 0:  # Skip if mask is empty
                continue

            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])

            # Skip invalid boxes
            if xmax <= xmin or ymax <= ymin:
                continue

            # Add small padding to ensure the mask is within the box
            xmin = max(0, xmin - 1)
            ymin = max(0, ymin - 1)
            xmax = min(binary_mask.shape[1] - 1, xmax + 1)
            ymax = min(binary_mask.shape[0] - 1, ymax + 1)

            boxes.append([xmin, ymin, xmax, ymax])
            masks.append(instance_mask)
            labels.append(1)  # 1 for building class

        # Handle case with no valid instances
        if len(boxes) == 0:
            # Create a dummy target with minimal required fields
            target = {
                "boxes": torch.zeros((0, 4), dtype=torch.float32),
                "labels": torch.zeros((0), dtype=torch.int64),
                "masks": torch.zeros(
                    (0, binary_mask.shape[0], binary_mask.shape[1]), dtype=torch.uint8
                ),
                "image_id": torch.tensor([idx]),
                "area": torch.zeros((0), dtype=torch.float32),
                "iscrowd": torch.zeros((0), dtype=torch.int64),
            }
        else:
            # Convert to tensors
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)

            # Calculate area of boxes
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

            # Prepare target dictionary
            target = {
                "boxes": boxes,
                "labels": labels,
                "masks": masks,
                "image_id": torch.tensor([idx]),
                "area": area,
                "iscrowd": torch.zeros_like(labels),  # Assume no crowd instances
            }

        # Apply transforms if specified
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target
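
# Example (sketch): pairing tiled images with matching label rasters, e.g.
#   dataset = ObjectDetectionDataset(image_files, label_files, transforms=get_transform(train=True))
# Each label GeoTIFF is binarized and split into per-building instances via connected components.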


class Compose:
    """Custom compose transform that works with image and target."""

    def __init__(self, transforms):
        """
        Initialize compose transform.

        Args:
            transforms (list): List of transforms to apply.
        """
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor:
    """Convert numpy.ndarray to tensor."""

    def __call__(self, image, target):
        """
        Apply transform to image and target.

        Args:
            image (torch.Tensor): Input image.
            target (dict): Target annotations.

        Returns:
            tuple: Transformed image and target.
        """
        return image, target


class RandomHorizontalFlip:
    """Random horizontal flip transform."""

    def __init__(self, prob=0.5):
        """
        Initialize random horizontal flip.

        Args:
            prob (float): Probability of applying the flip.
        """
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            # Flip image
            image = torch.flip(image, dims=[2])  # Flip along width dimension

            # Flip masks
            if "masks" in target and len(target["masks"]) > 0:
                target["masks"] = torch.flip(target["masks"], dims=[2])

            # Update boxes
            if "boxes" in target and len(target["boxes"]) > 0:
                boxes = target["boxes"]
                width = image.shape[2]
                boxes[:, 0], boxes[:, 2] = width - boxes[:, 2], width - boxes[:, 0]
                target["boxes"] = boxes

        return image, target


def get_transform(train):
    """
    Get transforms for data augmentation.

    Args:
        train (bool): Whether to include training-specific transforms.

    Returns:
        Compose: Composed transforms.
    """
    transforms = []
    transforms.append(ToTensor())

    if train:
        transforms.append(RandomHorizontalFlip(0.5))

    return Compose(transforms)


def collate_fn(batch):
    """
    Custom collate function for batching samples.

    Args:
        batch (list): List of (image, target) tuples.

    Returns:
        tuple: Tuple of images and targets.
    """
    return tuple(zip(*batch))
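
# Sketch: detection targets vary in size per image, so the loaders below pass
# collate_fn to keep each batch as a tuple of lists, e.g.
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)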


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    """
    Train the model for one epoch.

    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        data_loader (torch.utils.data.DataLoader): DataLoader for training data.
        device (torch.device): Device to train on.
        epoch (int): Current epoch number.
        print_freq (int): How often to print progress.

    Returns:
        float: Average loss for the epoch.
    """
    model.train()
    total_loss = 0

    start_time = time.time()

    for i, (images, targets) in enumerate(data_loader):
        # Move images and targets to device
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Track loss
        total_loss += losses.item()

        # Print progress
        if i % print_freq == 0:
            elapsed_time = time.time() - start_time
            print(
                f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
            )
            start_time = time.time()

    # Calculate average loss
    avg_loss = total_loss / len(data_loader)
    return avg_loss


def evaluate(model, data_loader, device):
    """
    Evaluate the model on the validation set.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
        device (torch.device): Device to evaluate on.

    Returns:
        dict: Evaluation metrics including loss and IoU.
    """
    model.eval()

    # Initialize metrics
    total_loss = 0
    iou_scores = []

    with torch.no_grad():
        for images, targets in data_loader:
            # Move to device
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # During evaluation, Mask R-CNN directly returns predictions, not losses
            # So we'll only get loss when we provide targets explicitly
            if len(targets) > 0:
                try:
                    # Try to get loss dict (this works in some implementations)
                    loss_dict = model(images, targets)
                    if isinstance(loss_dict, dict):
                        losses = sum(loss for loss in loss_dict.values())
                        total_loss += losses.item()
                except Exception as e:
                    print(f"Warning: Could not compute loss during evaluation: {e}")
                    # If we can't compute loss, we'll just focus on IoU
                    pass

            # Get predictions
            outputs = model(images)

            # Calculate IoU for each image
            for i, output in enumerate(outputs):
                if len(output["masks"]) == 0 or len(targets[i]["masks"]) == 0:
                    continue

                # Convert predicted masks to binary (threshold at 0.5)
                pred_masks = (output["masks"].squeeze(1) > 0.5).float()

                # Combine all instance masks into a single binary mask
                pred_combined = (
                    torch.max(pred_masks, dim=0)[0]
                    if pred_masks.shape[0] > 0
                    else torch.zeros_like(targets[i]["masks"][0])
                )
                target_combined = (
                    torch.max(targets[i]["masks"], dim=0)[0]
                    if targets[i]["masks"].shape[0] > 0
                    else torch.zeros_like(pred_combined)
                )

                # Calculate IoU
                intersection = (pred_combined * target_combined).sum().item()
                union = ((pred_combined + target_combined) > 0).sum().item()

                if union > 0:
                    iou = intersection / union
                    iou_scores.append(iou)

    # Calculate metrics
    avg_loss = total_loss / len(data_loader) if total_loss > 0 else float("inf")
    avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0

    return {"loss": avg_loss, "IoU": avg_iou}


def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None):
    """
    Visualize model predictions.

    Args:
        model (torch.nn.Module): Trained model.
        dataset (torch.utils.data.Dataset): Dataset to visualize.
        device (torch.device): Device to run inference on.
        num_samples (int): Number of samples to visualize.
        output_dir (str, optional): Directory to save visualizations. If None,
            visualizations are displayed but not saved.
    """
    model.eval()

    # Create output directory if needed
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Select random samples
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        # Get image and target
        image, target = dataset[idx]

        # Convert to device and add batch dimension
        image = image.to(device)
        image_batch = [image]

        # Get prediction
        with torch.no_grad():
            output = model(image_batch)[0]

        # Convert image from CHW to HWC for display (first 3 bands as RGB)
        rgb_image = image[:3].cpu().numpy()
        rgb_image = np.transpose(rgb_image, (1, 2, 0))
        rgb_image = np.clip(rgb_image, 0, 1)  # Ensure values are in [0,1]

        # Create binary ground truth mask (combine all instances)
        gt_masks = target["masks"].cpu().numpy()
        gt_combined = (
            np.max(gt_masks, axis=0)
            if len(gt_masks) > 0
            else np.zeros((image.shape[1], image.shape[2]), dtype=np.uint8)
        )

        # Create binary prediction mask (combine all instances with score > 0.5)
        pred_masks = output["masks"].cpu().numpy()
        pred_scores = output["scores"].cpu().numpy()
        high_conf_indices = pred_scores > 0.5

        pred_combined = np.zeros((image.shape[1], image.shape[2]), dtype=np.float32)
        if np.any(high_conf_indices):
            for mask in pred_masks[high_conf_indices]:
                # Apply threshold to each predicted mask
                binary_mask = (mask[0] > 0.5).astype(np.float32)
                # Combine with existing masks
                pred_combined = np.maximum(pred_combined, binary_mask)

        # Create figure
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # Show RGB image
        axs[0].imshow(rgb_image)
        axs[0].set_title("RGB Image")
        axs[0].axis("off")

        # Show prediction
        axs[1].imshow(pred_combined, cmap="viridis")
        axs[1].set_title(f"Predicted Buildings: {np.sum(high_conf_indices)} instances")
        axs[1].axis("off")

        # Show ground truth
        axs[2].imshow(gt_combined, cmap="viridis")
        axs[2].set_title(f"Ground Truth: {len(gt_masks)} instances")
        axs[2].axis("off")

        plt.tight_layout()

        # Save or show
        if output_dir:
            plt.savefig(os.path.join(output_dir, f"prediction_{idx}.png"))
            plt.close()
        else:
            plt.show()


def train_MaskRCNN_model(
    images_dir,
    labels_dir,
    output_dir,
    num_channels=3,
    pretrained=True,
    batch_size=4,
    num_epochs=10,
    learning_rate=0.005,
    seed=42,
    val_split=0.2,
    visualize=False,
):
    """
    Train and evaluate Mask R-CNN model for instance segmentation.

    Args:
        images_dir (str): Directory containing image GeoTIFF files.
        labels_dir (str): Directory containing label GeoTIFF files.
        output_dir (str): Directory to save model checkpoints and results.
        num_channels (int): Number of input channels (e.g., 3 for RGB, 4 for RGBN).
        pretrained (bool): Whether to use pretrained backbone.
        batch_size (int): Batch size for training.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Initial learning rate.
        seed (int): Random seed for reproducibility.
        val_split (float): Fraction of data to use for validation (0-1).
        visualize (bool): Whether to generate visualizations of model predictions.

    Returns:
        None: Model weights are saved to output_dir.
    """

    # Set random seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Get device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    # Get all image and label files
    image_files = sorted(
        [
            os.path.join(images_dir, f)
            for f in os.listdir(images_dir)
            if f.endswith(".tif")
        ]
    )
    label_files = sorted(
        [
            os.path.join(labels_dir, f)
            for f in os.listdir(labels_dir)
            if f.endswith(".tif")
        ]
    )

    print(f"Found {len(image_files)} image files and {len(label_files)} label files")

    # Ensure matching files
    if len(image_files) != len(label_files):
        print("Warning: Number of image files and label files don't match!")
        # Find matching files by basename
        basenames = [os.path.basename(f) for f in image_files]
        label_files = [
            os.path.join(labels_dir, os.path.basename(f))
            for f in image_files
            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
        ]
        image_files = [
            f
            for f, b in zip(image_files, basenames)
            if os.path.exists(os.path.join(labels_dir, b))
        ]
        print(f"Using {len(image_files)} matching files")

    # Split data into train and validation sets
    train_imgs, val_imgs, train_labels, val_labels = train_test_split(
        image_files, label_files, test_size=val_split, random_state=seed
    )

    print(f"Training on {len(train_imgs)} images, validating on {len(val_imgs)} images")

    # Create datasets
    train_dataset = ObjectDetectionDataset(
        train_imgs, train_labels, transforms=get_transform(train=True)
    )
    val_dataset = ObjectDetectionDataset(
        val_imgs, val_labels, transforms=get_transform(train=False)
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=4,
    )

    # Initialize model (2 classes: background and building)
    model = get_instance_segmentation_model(
        num_classes=2, num_channels=num_channels, pretrained=pretrained
    )
    model.to(device)

    # Set up optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=learning_rate, momentum=0.9, weight_decay=0.0005
    )

    # Set up learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

    # Training loop
    best_iou = 0
    for epoch in range(num_epochs):
        # Train one epoch
        train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)

        # Update learning rate
        lr_scheduler.step()

        # Evaluate
        eval_metrics = evaluate(model, val_loader, device)

        # Print metrics
        print(
            f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {eval_metrics['loss']:.4f}, Val IoU: {eval_metrics['IoU']:.4f}"
        )

        # Save best model
        if eval_metrics["IoU"] > best_iou:
            best_iou = eval_metrics["IoU"]
            print(f"Saving best model with IoU: {best_iou:.4f}")
            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": lr_scheduler.state_dict(),
                    "best_iou": best_iou,
                },
                os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
            )

    # Save final model
    torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))

    # Load best model for evaluation and visualization
    model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))

    # Final evaluation
    final_metrics = evaluate(model, val_loader, device)
    print(
        f"Final Evaluation - Loss: {final_metrics['loss']:.4f}, IoU: {final_metrics['IoU']:.4f}"
    )

    # Visualize results
    if visualize:
        print("Generating visualizations...")
        visualize_predictions(
            model,
            val_dataset,
            device,
            num_samples=5,
            output_dir=os.path.join(output_dir, "visualizations"),
        )
    print(f"Training complete! Trained model saved to {output_dir}")
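
# Training sketch (paths and parameters are illustrative):
#   train_MaskRCNN_model(
#       images_dir="tiles/images", labels_dir="tiles/labels", output_dir="models/buildings",
#       num_channels=4, batch_size=4, num_epochs=10,
#   )
# Checkpoints land in output_dir as best_model.pth, final_model.pth, and periodic
# checkpoint_epoch_*.pth files.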


def inference_on_geotiff(
    model,
    geotiff_path,
    output_path,
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=3,
    device=None,
    **kwargs,
):
    """
    Perform inference on a large GeoTIFF using a sliding window approach with improved blending.

    Args:
        model (torch.nn.Module): Trained model for inference.
        geotiff_path (str): Path to input GeoTIFF file.
        output_path (str): Path to save output mask GeoTIFF.
        window_size (int): Size of sliding window for inference.
        overlap (int): Overlap between adjacent windows.
        confidence_threshold (float): Confidence threshold for predictions (0-1).
        batch_size (int): Batch size for inference.
        num_channels (int): Number of channels to use from the input image.
        device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
        **kwargs: Additional arguments.

    Returns:
        tuple: Tuple containing output path and inference time in seconds.
    """
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    # Put model in evaluation mode
    model.to(device)
    model.eval()

    # Open the GeoTIFF
    with rasterio.open(geotiff_path) as src:
        # Read metadata
        meta = src.meta
        height = src.height
        width = src.width

        # Update metadata for output raster
        out_meta = meta.copy()
        out_meta.update(
            {"count": 1, "dtype": "uint8"}  # Single band for the binary mask
        )

        # We'll use two arrays:
        # 1. For accumulating predictions
        pred_accumulator = np.zeros((height, width), dtype=np.float32)
        # 2. For tracking how many predictions contribute to each pixel
        count_accumulator = np.zeros((height, width), dtype=np.float32)

        # Calculate the number of windows needed to cover the entire image
        steps_y = math.ceil((height - overlap) / (window_size - overlap))
        steps_x = math.ceil((width - overlap) / (window_size - overlap))

        # Ensure we cover the entire image
        last_y = height - window_size
        last_x = width - window_size

        total_windows = steps_y * steps_x
        print(
            f"Processing {total_windows} windows with size {window_size}x{window_size} and overlap {overlap}..."
        )

        # Create progress bar
        pbar = tqdm(total=total_windows)

        # Process in batches
        batch_inputs = []
        batch_positions = []
        batch_count = 0

        start_time = time.time()

        # Slide window over the image - make sure we cover the entire image
        for i in range(steps_y + 1):  # +1 to ensure we reach the edge
            y = min(i * (window_size - overlap), last_y)
            y = max(0, y)  # Prevent negative indices

            if y > last_y and i > 0:  # Skip if we've already covered the entire height
                continue

            for j in range(steps_x + 1):  # +1 to ensure we reach the edge
                x = min(j * (window_size - overlap), last_x)
                x = max(0, x)  # Prevent negative indices

                if (
                    x > last_x and j > 0
                ):  # Skip if we've already covered the entire width
                    continue

                # Read window
                window = src.read(window=Window(x, y, window_size, window_size))

                # Check if window is valid
                if window.shape[1] != window_size or window.shape[2] != window_size:
                    # This can happen at image edges - adjust window size
                    current_height = window.shape[1]
                    current_width = window.shape[2]
                    if current_height == 0 or current_width == 0:
                        continue  # Skip empty windows
                else:
                    current_height = window_size
                    current_width = window_size

                # Normalize and prepare input
                image = window.astype(np.float32) / 255.0

                # Handle different number of bands
                if image.shape[0] > num_channels:
                    image = image[:num_channels]
                elif image.shape[0] < num_channels:
                    padded = np.zeros(
                        (num_channels, current_height, current_width), dtype=np.float32
                    )
                    padded[: image.shape[0]] = image
                    image = padded

                # Convert to tensor
                image_tensor = torch.tensor(image, device=device)

                # Add to batch
                batch_inputs.append(image_tensor)
                batch_positions.append((y, x, current_height, current_width))
                batch_count += 1

                # Process batch when it reaches the batch size or at the end
                if batch_count == batch_size or (i == steps_y and j == steps_x):
                    # Forward pass
                    with torch.no_grad():
                        outputs = model(batch_inputs)

                    # Process each output in the batch
                    for idx, output in enumerate(outputs):
                        y_pos, x_pos, h, w = batch_positions[idx]

                        # Create weight matrix that gives higher weight to center pixels
                        # This helps with smooth blending at boundaries
                        y_grid, x_grid = np.mgrid[0:h, 0:w]

                        # Calculate distance from each edge
                        dist_from_left = x_grid
                        dist_from_right = w - x_grid - 1
                        dist_from_top = y_grid
                        dist_from_bottom = h - y_grid - 1

                        # Combine distances (minimum distance to any edge)
                        edge_distance = np.minimum.reduce(
                            [
                                dist_from_left,
                                dist_from_right,
                                dist_from_top,
                                dist_from_bottom,
                            ]
                        )

                        # Convert to weight (higher weight for center pixels)
                        # Normalize to [0, 1]
                        edge_distance = np.minimum(edge_distance, overlap / 2)
                        weight = edge_distance / (overlap / 2)

                        # Get masks for predictions above threshold
                        if len(output["scores"]) > 0:
                            # Get all instances that meet confidence threshold
                            keep = output["scores"] > confidence_threshold
                            masks = output["masks"][keep].squeeze(1)

                            # Combine all instances into one mask
                            if len(masks) > 0:
                                combined_mask = torch.max(masks, dim=0)[0] > 0.5
                                combined_mask = (
                                    combined_mask.cpu().numpy().astype(np.float32)
                                )

                                # Apply weight to prediction
                                weighted_pred = combined_mask * weight

                                # Add to accumulators
                                pred_accumulator[
                                    y_pos : y_pos + h, x_pos : x_pos + w
                                ] += weighted_pred
                                count_accumulator[
                                    y_pos : y_pos + h, x_pos : x_pos + w
                                ] += weight

                    # Reset batch
                    batch_inputs = []
                    batch_positions = []
                    batch_count = 0

                    # Update progress bar
                    pbar.update(len(outputs))

        # Close progress bar
        pbar.close()

        # Calculate final mask by dividing accumulated predictions by counts
        # Handle division by zero
        mask = np.zeros((height, width), dtype=np.uint8)
        valid_pixels = count_accumulator > 0
        if np.any(valid_pixels):
            # Average predictions where we have data
            mask[valid_pixels] = (
                pred_accumulator[valid_pixels] / count_accumulator[valid_pixels] > 0.5
            ).astype(np.uint8)

        # Record time
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.2f} seconds")

    # Save output
    with rasterio.open(output_path, "w", **out_meta) as dst:
        dst.write(mask, 1)

    print(f"Saved prediction to {output_path}")

    return output_path, inference_time
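
# Blending note (summary of the weighting above): each window's prediction is weighted by
#   weight = min(distance_to_nearest_window_edge, overlap / 2) / (overlap / 2)
# so center pixels dominate where neighboring windows overlap, which reduces seam artifacts.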


def object_detection(
    input_path,
    output_path,
    model_path,
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=3,
    pretrained=True,
    device=None,
    **kwargs,
):
    """
    Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.

    Args:
        input_path (str): Path to input GeoTIFF file.
        output_path (str): Path to save output mask GeoTIFF.
        model_path (str): Path to trained model weights.
        window_size (int): Size of sliding window for inference.
        overlap (int): Overlap between adjacent windows.
        confidence_threshold (float): Confidence threshold for predictions (0-1).
        batch_size (int): Batch size for inference.
        num_channels (int): Number of channels in the input image and model.
        pretrained (bool): Whether to use pretrained backbone for model loading.
        device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
        **kwargs: Additional arguments passed to inference_on_geotiff.

    Returns:
        None: Output mask is saved to output_path.
    """
    # Load your trained model
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
    model = get_instance_segmentation_model(
        num_classes=2, num_channels=num_channels, pretrained=pretrained
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    inference_on_geotiff(
        model=model,
        geotiff_path=input_path,
        output_path=output_path,
        window_size=window_size,  # Adjust based on your model and memory
        overlap=overlap,  # Overlap to avoid edge artifacts
        confidence_threshold=confidence_threshold,
        batch_size=batch_size,  # Adjust based on your GPU memory
        num_channels=num_channels,
        device=device,
        **kwargs,
    )