geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl

geoai/train.py ADDED
@@ -0,0 +1,1039 @@
+ import math
+ import os
+ import random
+ import time
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import rasterio
+ import torch
+ import torch.utils.data
+ import torchvision
+
+ from rasterio.windows import Window
+ from skimage import measure
+ from sklearn.model_selection import train_test_split
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+ from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+ from tqdm import tqdm
+
+
+ def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
+     """
+     Get a Mask R-CNN model with custom input channels and output classes.
+
+     Args:
+         num_classes (int): Number of output classes (including background).
+         num_channels (int): Number of input channels (3 for RGB, 4 for RGBN).
+         pretrained (bool): Whether to use a pretrained backbone.
+
+     Returns:
+         torch.nn.Module: Mask R-CNN model with the specified input channels and output classes.
+
+     Raises:
+         ValueError: If num_channels is less than 3.
+     """
+     # Validate num_channels
+     if num_channels < 3:
+         raise ValueError("num_channels must be at least 3")
+
+     # Load pre-trained model; `weights` supersedes the deprecated `pretrained`
+     # flag in recent torchvision, and the two must not be passed together
+     model = maskrcnn_resnet50_fpn(
+         progress=True,
+         weights=(
+             torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
+             if pretrained
+             else None
+         ),
+     )
+
+     # Modify the transform if num_channels is different from 3
+     if num_channels != 3:
+         # Get the transform
+         transform = model.transform
+
+         # Default ImageNet values are [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225]
+         rgb_mean = [0.485, 0.456, 0.406]
+         rgb_std = [0.229, 0.224, 0.225]
+
+         # Extend them to num_channels (use the mean value for additional channels)
+         mean_of_means = sum(rgb_mean) / len(rgb_mean)
+         mean_of_stds = sum(rgb_std) / len(rgb_std)
+
+         # Create new lists with the appropriate length
+         transform.image_mean = rgb_mean + [mean_of_means] * (num_channels - 3)
+         transform.image_std = rgb_std + [mean_of_stds] * (num_channels - 3)
+
+     # Get the number of input features for the classifier
+     in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+     # Replace the pre-trained box head with a new one
+     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+
+     # Get the number of input features for the mask classifier
+     in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
+     hidden_layer = 256
+
+     # Replace the mask predictor with a new one
+     model.roi_heads.mask_predictor = MaskRCNNPredictor(
+         in_features_mask, hidden_layer, num_classes
+     )
+
+     # Modify the first conv layer if num_channels is different from 3
+     if num_channels != 3:
+         original_layer = model.backbone.body.conv1
+         model.backbone.body.conv1 = torch.nn.Conv2d(
+             num_channels,
+             original_layer.out_channels,
+             kernel_size=original_layer.kernel_size,
+             stride=original_layer.stride,
+             padding=original_layer.padding,
+             bias=original_layer.bias is not None,
+         )
+
+         # Copy weights from the original 3 channels to the new layer
+         with torch.no_grad():
+             # Copy the weights for the first 3 channels
+             model.backbone.body.conv1.weight[:, :3, :, :] = original_layer.weight
+
+             # Initialize additional channels with the mean of the first 3 channels
+             mean_weight = original_layer.weight.mean(dim=1, keepdim=True)
+             for i in range(3, num_channels):
+                 model.backbone.body.conv1.weight[:, i : i + 1, :, :] = mean_weight
+
+             # Copy the bias if it exists
+             if original_layer.bias is not None:
+                 model.backbone.body.conv1.bias = original_layer.bias
+
+     return model
+
+
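+ # Illustrative usage (not part of the released module; tensor sizes are
+ # arbitrary): adapt Mask R-CNN to 4-band RGBN imagery and run a dummy
+ # forward pass.
+ #
+ #     model = get_instance_segmentation_model(num_classes=2, num_channels=4)
+ #     model.eval()
+ #     preds = model([torch.rand(4, 512, 512)])  # dicts with boxes/labels/scores/masks
+
+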
+ class ObjectDetectionDataset(Dataset):
+     """Dataset for object detection from GeoTIFF images and labels."""
+
+     def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
+         """
+         Initialize dataset.
+
+         Args:
+             image_paths (list): List of paths to image GeoTIFF files.
+             label_paths (list): List of paths to label GeoTIFF files.
+             transforms (callable, optional): Transformations to apply to images and masks.
+             num_channels (int, optional): Number of channels to use from images. If None,
+                 auto-detected from the first image.
+         """
+         self.image_paths = image_paths
+         self.label_paths = label_paths
+         self.transforms = transforms
+
+         # Auto-detect the number of channels if not specified
+         if num_channels is None:
+             with rasterio.open(self.image_paths[0]) as src:
+                 self.num_channels = src.count
+         else:
+             self.num_channels = num_channels
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         # Load image
+         with rasterio.open(self.image_paths[idx]) as src:
+             # Read as [C, H, W] format
+             image = src.read().astype(np.float32)
+
+             # Normalize image to [0, 1] range
+             image = image / 255.0
+
+             # Handle a different number of channels
+             if image.shape[0] > self.num_channels:
+                 # Keep only the first num_channels bands if more exist
+                 image = image[: self.num_channels]
+             elif image.shape[0] < self.num_channels:
+                 # Pad with zeros if fewer than num_channels bands exist
+                 padded = np.zeros(
+                     (self.num_channels, image.shape[1], image.shape[2]),
+                     dtype=np.float32,
+                 )
+                 padded[: image.shape[0]] = image
+                 image = padded
+
+         # Convert to CHW tensor
+         image = torch.as_tensor(image, dtype=torch.float32)
+
+         # Load label mask
+         with rasterio.open(self.label_paths[idx]) as src:
+             label_mask = src.read(1)
+             binary_mask = (label_mask > 0).astype(np.uint8)
+
+         # Find all building instances using connected components
+         labeled_mask, num_instances = measure.label(
+             binary_mask, return_num=True, connectivity=2
+         )
+
+         # Collect masks, boxes, and labels for each building instance
+         masks = []
+         boxes = []
+         labels = []
+
+         for i in range(1, num_instances + 1):
+             # Create mask for this instance
+             instance_mask = (labeled_mask == i).astype(np.uint8)
+
+             # Calculate area and filter out tiny instances (noise)
+             area = instance_mask.sum()
+             if area < 10:  # Minimum area threshold
+                 continue
+
+             # Find bounding box coordinates
+             pos = np.where(instance_mask)
+             if len(pos[0]) == 0:  # Skip if mask is empty
+                 continue
+
+             xmin = np.min(pos[1])
+             xmax = np.max(pos[1])
+             ymin = np.min(pos[0])
+             ymax = np.max(pos[0])
+
+             # Skip invalid boxes
+             if xmax <= xmin or ymax <= ymin:
+                 continue
+
+             # Add small padding to ensure the mask is within the box
+             xmin = max(0, xmin - 1)
+             ymin = max(0, ymin - 1)
+             xmax = min(binary_mask.shape[1] - 1, xmax + 1)
+             ymax = min(binary_mask.shape[0] - 1, ymax + 1)
+
+             boxes.append([xmin, ymin, xmax, ymax])
+             masks.append(instance_mask)
+             labels.append(1)  # 1 for the building class
+
+         # Handle the case with no valid instances
+         if len(boxes) == 0:
+             # Create a dummy target with the minimal required fields
+             target = {
+                 "boxes": torch.zeros((0, 4), dtype=torch.float32),
+                 "labels": torch.zeros((0), dtype=torch.int64),
+                 "masks": torch.zeros(
+                     (0, binary_mask.shape[0], binary_mask.shape[1]), dtype=torch.uint8
+                 ),
+                 "image_id": torch.tensor([idx]),
+                 "area": torch.zeros((0), dtype=torch.float32),
+                 "iscrowd": torch.zeros((0), dtype=torch.int64),
+             }
+         else:
+             # Convert to tensors
+             boxes = torch.as_tensor(boxes, dtype=torch.float32)
+             labels = torch.as_tensor(labels, dtype=torch.int64)
+             masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)
+
+             # Calculate the area of the boxes
+             area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+
+             # Prepare the target dictionary
+             target = {
+                 "boxes": boxes,
+                 "labels": labels,
+                 "masks": masks,
+                 "image_id": torch.tensor([idx]),
+                 "area": area,
+                 "iscrowd": torch.zeros_like(labels),  # Assume no crowd instances
+             }
+
+         # Apply transforms if specified
+         if self.transforms is not None:
+             image, target = self.transforms(image, target)
+
+         return image, target
+
+
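+ # Illustrative usage (file names are placeholders, not part of this module):
+ # pair image and label tiles, then inspect one sample.
+ #
+ #     dataset = ObjectDetectionDataset(["img_001.tif"], ["lbl_001.tif"])
+ #     image, target = dataset[0]
+ #     print(image.shape, target["boxes"].shape, target["masks"].shape)
+
+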
+ class Compose:
+     """Custom compose transform that works with image and target."""
+
+     def __init__(self, transforms):
+         """
+         Initialize compose transform.
+
+         Args:
+             transforms (list): List of transforms to apply.
+         """
+         self.transforms = transforms
+
+     def __call__(self, image, target):
+         for t in self.transforms:
+             image, target = t(image, target)
+         return image, target
+
+
+ class ToTensor:
+     """Identity transform: images are already converted to tensors by the dataset."""
+
+     def __call__(self, image, target):
+         """
+         Apply transform to image and target.
+
+         Args:
+             image (torch.Tensor): Input image.
+             target (dict): Target annotations.
+
+         Returns:
+             tuple: Image and target, unchanged.
+         """
+         return image, target
+
+
+ class RandomHorizontalFlip:
+     """Random horizontal flip transform."""
+
+     def __init__(self, prob=0.5):
+         """
+         Initialize random horizontal flip.
+
+         Args:
+             prob (float): Probability of applying the flip.
+         """
+         self.prob = prob
+
+     def __call__(self, image, target):
+         if random.random() < self.prob:
+             # Flip image along the width dimension
+             image = torch.flip(image, dims=[2])
+
+             # Flip masks
+             if "masks" in target and len(target["masks"]) > 0:
+                 target["masks"] = torch.flip(target["masks"], dims=[2])
+
+             # Update boxes: mirror the x coordinates and swap xmin/xmax
+             if "boxes" in target and len(target["boxes"]) > 0:
+                 boxes = target["boxes"]
+                 width = image.shape[2]
+                 boxes[:, 0], boxes[:, 2] = width - boxes[:, 2], width - boxes[:, 0]
+                 target["boxes"] = boxes
+
+         return image, target
+
+
+ def get_transform(train):
+     """
+     Get transforms for data augmentation.
+
+     Args:
+         train (bool): Whether to include training-specific transforms.
+
+     Returns:
+         Compose: Composed transforms.
+     """
+     transforms = []
+     transforms.append(ToTensor())
+
+     if train:
+         transforms.append(RandomHorizontalFlip(0.5))
+
+     return Compose(transforms)
+
+
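+ # Illustrative usage: the training pipeline adds the random flip, while the
+ # validation pipeline is effectively an identity transform.
+ #
+ #     train_tfms = get_transform(train=True)   # ToTensor + RandomHorizontalFlip
+ #     val_tfms = get_transform(train=False)    # ToTensor only
+
+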
+ def collate_fn(batch):
+     """
+     Custom collate function for batching samples.
+
+     Args:
+         batch (list): List of (image, target) tuples.
+
+     Returns:
+         tuple: Tuple of images and tuple of targets.
+     """
+     return tuple(zip(*batch))
+
+
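+ # Why a custom collate: detection targets differ in size per image (variable
+ # numbers of boxes and masks), so the default stacking collate cannot be used.
+ # Illustrative usage:
+ #
+ #     loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
+ #     images, targets = next(iter(loader))  # tuple of images, tuple of dicts
+
+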
+ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
+     """
+     Train the model for one epoch.
+
+     Args:
+         model (torch.nn.Module): The model to train.
+         optimizer (torch.optim.Optimizer): The optimizer to use.
+         data_loader (torch.utils.data.DataLoader): DataLoader for training data.
+         device (torch.device): Device to train on.
+         epoch (int): Current epoch number.
+         print_freq (int): How often to print progress.
+
+     Returns:
+         float: Average loss for the epoch.
+     """
+     model.train()
+     total_loss = 0
+
+     start_time = time.time()
+
+     for i, (images, targets) in enumerate(data_loader):
+         # Move images and targets to device
+         images = [image.to(device) for image in images]
+         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+         # Forward pass
+         loss_dict = model(images, targets)
+         losses = sum(loss for loss in loss_dict.values())
+
+         # Backward pass
+         optimizer.zero_grad()
+         losses.backward()
+         optimizer.step()
+
+         # Track loss
+         total_loss += losses.item()
+
+         # Print progress
+         if i % print_freq == 0:
+             elapsed_time = time.time() - start_time
+             print(
+                 f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, "
+                 f"Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
+             )
+             start_time = time.time()
+
+     # Calculate average loss
+     avg_loss = total_loss / len(data_loader)
+     return avg_loss
+
+
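+ # Note: in training mode, torchvision's Mask R-CNN returns a dict of losses
+ # (loss_classifier, loss_box_reg, loss_mask, loss_objectness, loss_rpn_box_reg);
+ # train_one_epoch sums them into a single scalar before backpropagation.
+
+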
+ def evaluate(model, data_loader, device):
+     """
+     Evaluate the model on the validation set.
+
+     Args:
+         model (torch.nn.Module): The model to evaluate.
+         data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
+         device (torch.device): Device to evaluate on.
+
+     Returns:
+         dict: Evaluation metrics including loss and IoU.
+     """
+     model.eval()
+
+     # Initialize metrics
+     total_loss = 0
+     iou_scores = []
+
+     with torch.no_grad():
+         for images, targets in data_loader:
+             # Move to device
+             images = [image.to(device) for image in images]
+             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+             # In eval mode, Mask R-CNN returns predictions rather than losses,
+             # so a loss dict is only available in implementations that accept
+             # explicit targets here
+             if len(targets) > 0:
+                 try:
+                     loss_dict = model(images, targets)
+                     if isinstance(loss_dict, dict):
+                         losses = sum(loss for loss in loss_dict.values())
+                         total_loss += losses.item()
+                 except Exception as e:
+                     print(f"Warning: Could not compute loss during evaluation: {e}")
+                     # Fall back to reporting IoU only
+
+             # Get predictions
+             outputs = model(images)
+
+             # Calculate IoU for each image
+             for i, output in enumerate(outputs):
+                 if len(output["masks"]) == 0 or len(targets[i]["masks"]) == 0:
+                     continue
+
+                 # Convert predicted masks to binary (threshold at 0.5)
+                 pred_masks = (output["masks"].squeeze(1) > 0.5).float()
+
+                 # Combine all instance masks into a single binary mask
+                 pred_combined = (
+                     torch.max(pred_masks, dim=0)[0]
+                     if pred_masks.shape[0] > 0
+                     else torch.zeros_like(targets[i]["masks"][0])
+                 )
+                 target_combined = (
+                     torch.max(targets[i]["masks"], dim=0)[0]
+                     if targets[i]["masks"].shape[0] > 0
+                     else torch.zeros_like(pred_combined)
+                 )
+
+                 # Calculate IoU
+                 intersection = (pred_combined * target_combined).sum().item()
+                 union = ((pred_combined + target_combined) > 0).sum().item()
+
+                 if union > 0:
+                     iou = intersection / union
+                     iou_scores.append(iou)
+
+     # Calculate metrics
+     avg_loss = total_loss / len(data_loader) if total_loss > 0 else float("inf")
+     avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
+
+     return {"loss": avg_loss, "IoU": avg_iou}
+
+
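+ # Worked example of the IoU used above, computed on the combined binary masks:
+ #
+ #     pred = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
+ #     gt = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
+ #     intersection = (pred * gt).sum()   # 1
+ #     union = ((pred + gt) > 0).sum()    # 2
+ #     iou = intersection / union         # 0.5
+
+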
+ def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None):
+     """
+     Visualize model predictions.
+
+     Args:
+         model (torch.nn.Module): Trained model.
+         dataset (torch.utils.data.Dataset): Dataset to visualize.
+         device (torch.device): Device to run inference on.
+         num_samples (int): Number of samples to visualize.
+         output_dir (str, optional): Directory to save visualizations. If None,
+             visualizations are displayed but not saved.
+     """
+     model.eval()
+
+     # Create output directory if needed
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     # Select random samples
+     indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
+
+     for idx in indices:
+         # Get image and target
+         image, target = dataset[idx]
+
+         # Move to device and add batch dimension
+         image = image.to(device)
+         image_batch = [image]
+
+         # Get prediction
+         with torch.no_grad():
+             output = model(image_batch)[0]
+
+         # Convert image from CHW to HWC for display (first 3 bands as RGB)
+         rgb_image = image[:3].cpu().numpy()
+         rgb_image = np.transpose(rgb_image, (1, 2, 0))
+         rgb_image = np.clip(rgb_image, 0, 1)  # Ensure values are in [0, 1]
+
+         # Create binary ground truth mask (combine all instances)
+         gt_masks = target["masks"].cpu().numpy()
+         gt_combined = (
+             np.max(gt_masks, axis=0)
+             if len(gt_masks) > 0
+             else np.zeros((image.shape[1], image.shape[2]), dtype=np.uint8)
+         )
+
+         # Create binary prediction mask (combine all instances with score > 0.5)
+         pred_masks = output["masks"].cpu().numpy()
+         pred_scores = output["scores"].cpu().numpy()
+         high_conf_indices = pred_scores > 0.5
+
+         pred_combined = np.zeros((image.shape[1], image.shape[2]), dtype=np.float32)
+         if np.any(high_conf_indices):
+             for mask in pred_masks[high_conf_indices]:
+                 # Apply threshold to each predicted mask
+                 binary_mask = (mask[0] > 0.5).astype(np.float32)
+                 # Combine with existing masks
+                 pred_combined = np.maximum(pred_combined, binary_mask)
+
+         # Create figure
+         fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+         # Show RGB image
+         axs[0].imshow(rgb_image)
+         axs[0].set_title("RGB Image")
+         axs[0].axis("off")
+
+         # Show prediction
+         axs[1].imshow(pred_combined, cmap="viridis")
+         axs[1].set_title(f"Predicted Buildings: {np.sum(high_conf_indices)} instances")
+         axs[1].axis("off")
+
+         # Show ground truth
+         axs[2].imshow(gt_combined, cmap="viridis")
+         axs[2].set_title(f"Ground Truth: {len(gt_masks)} instances")
+         axs[2].axis("off")
+
+         plt.tight_layout()
+
+         # Save or show
+         if output_dir:
+             plt.savefig(os.path.join(output_dir, f"prediction_{idx}.png"))
+             plt.close()
+         else:
+             plt.show()
+
+
+ def train_MaskRCNN_model(
+     images_dir,
+     labels_dir,
+     output_dir,
+     num_channels=3,
+     pretrained=True,
+     batch_size=4,
+     num_epochs=10,
+     learning_rate=0.005,
+     seed=42,
+     val_split=0.2,
+     visualize=False,
+ ):
+     """
+     Train and evaluate a Mask R-CNN model for instance segmentation.
+
+     Args:
+         images_dir (str): Directory containing image GeoTIFF files.
+         labels_dir (str): Directory containing label GeoTIFF files.
+         output_dir (str): Directory to save model checkpoints and results.
+         num_channels (int): Number of input channels (3 for RGB, 4 for RGBN).
+         pretrained (bool): Whether to use a pretrained backbone.
+         batch_size (int): Batch size for training.
+         num_epochs (int): Number of training epochs.
+         learning_rate (float): Initial learning rate.
+         seed (int): Random seed for reproducibility.
+         val_split (float): Fraction of data to use for validation (0-1).
+         visualize (bool): Whether to generate visualizations of model predictions.
+
+     Returns:
+         None: Model weights are saved to output_dir.
+     """
+
+     # Set random seeds for reproducibility
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     # Create output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Get device
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     print(f"Using device: {device}")
+
+     # Get all image and label files
+     image_files = sorted(
+         [
+             os.path.join(images_dir, f)
+             for f in os.listdir(images_dir)
+             if f.endswith(".tif")
+         ]
+     )
+     label_files = sorted(
+         [
+             os.path.join(labels_dir, f)
+             for f in os.listdir(labels_dir)
+             if f.endswith(".tif")
+         ]
+     )
+
+     print(f"Found {len(image_files)} image files and {len(label_files)} label files")
+
+     # Ensure matching files
+     if len(image_files) != len(label_files):
+         print("Warning: Number of image files and label files don't match!")
+         # Keep only images that have a label with the same basename
+         basenames = [os.path.basename(f) for f in image_files]
+         label_files = [
+             os.path.join(labels_dir, os.path.basename(f))
+             for f in image_files
+             if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+         ]
+         image_files = [
+             f
+             for f, b in zip(image_files, basenames)
+             if os.path.exists(os.path.join(labels_dir, b))
+         ]
+         print(f"Using {len(image_files)} matching files")
+
+     # Split data into train and validation sets
+     train_imgs, val_imgs, train_labels, val_labels = train_test_split(
+         image_files, label_files, test_size=val_split, random_state=seed
+     )
+
+     print(f"Training on {len(train_imgs)} images, validating on {len(val_imgs)} images")
+
+     # Create datasets
+     train_dataset = ObjectDetectionDataset(
+         train_imgs, train_labels, transforms=get_transform(train=True)
+     )
+     val_dataset = ObjectDetectionDataset(
+         val_imgs, val_labels, transforms=get_transform(train=False)
+     )
+
+     # Create data loaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         collate_fn=collate_fn,
+         num_workers=4,
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         collate_fn=collate_fn,
+         num_workers=4,
+     )
+
+     # Initialize model (2 classes: background and building)
+     model = get_instance_segmentation_model(
+         num_classes=2, num_channels=num_channels, pretrained=pretrained
+     )
+     model.to(device)
+
+     # Set up optimizer
+     params = [p for p in model.parameters() if p.requires_grad]
+     optimizer = torch.optim.SGD(
+         params, lr=learning_rate, momentum=0.9, weight_decay=0.0005
+     )
+
+     # Set up learning rate scheduler
+     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)
+
+     # Training loop
+     best_iou = 0
+     for epoch in range(num_epochs):
+         # Train one epoch
+         train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
+
+         # Update learning rate
+         lr_scheduler.step()
+
+         # Evaluate
+         eval_metrics = evaluate(model, val_loader, device)
+
+         # Print metrics
+         print(
+             f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, "
+             f"Val Loss: {eval_metrics['loss']:.4f}, Val IoU: {eval_metrics['IoU']:.4f}"
+         )
+
+         # Save best model
+         if eval_metrics["IoU"] > best_iou:
+             best_iou = eval_metrics["IoU"]
+             print(f"Saving best model with IoU: {best_iou:.4f}")
+             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+
+         # Save a full checkpoint every 10 epochs
+         if (epoch + 1) % 10 == 0:
+             torch.save(
+                 {
+                     "epoch": epoch,
+                     "model_state_dict": model.state_dict(),
+                     "optimizer_state_dict": optimizer.state_dict(),
+                     "scheduler_state_dict": lr_scheduler.state_dict(),
+                     "best_iou": best_iou,
+                 },
+                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
+             )
+
+     # Save final model
+     torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
+
+     # Load best model for evaluation and visualization
+     model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
+
+     # Final evaluation
+     final_metrics = evaluate(model, val_loader, device)
+     print(
+         f"Final Evaluation - Loss: {final_metrics['loss']:.4f}, "
+         f"IoU: {final_metrics['IoU']:.4f}"
+     )
+
+     # Visualize results
+     if visualize:
+         print("Generating visualizations...")
+         visualize_predictions(
+             model,
+             val_dataset,
+             device,
+             num_samples=5,
+             output_dir=os.path.join(output_dir, "visualizations"),
+         )
+     print(f"Training complete! Trained model saved to {output_dir}")
+
+
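+ # Illustrative usage (directory names are placeholders):
+ #
+ #     train_MaskRCNN_model(
+ #         images_dir="data/images",
+ #         labels_dir="data/labels",
+ #         output_dir="runs/buildings",
+ #         num_channels=4,
+ #         num_epochs=10,
+ #         visualize=True,
+ #     )
+
+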
+ def inference_on_geotiff(
+     model,
+     geotiff_path,
+     output_path,
+     window_size=512,
+     overlap=256,
+     confidence_threshold=0.5,
+     batch_size=4,
+     num_channels=3,
+     device=None,
+     **kwargs,
+ ):
+     """
+     Perform inference on a large GeoTIFF using a sliding window approach with improved blending.
+
+     Args:
+         model (torch.nn.Module): Trained model for inference.
+         geotiff_path (str): Path to input GeoTIFF file.
+         output_path (str): Path to save output mask GeoTIFF.
+         window_size (int): Size of sliding window for inference.
+         overlap (int): Overlap between adjacent windows.
+         confidence_threshold (float): Confidence threshold for predictions (0-1).
+         batch_size (int): Batch size for inference.
+         num_channels (int): Number of channels to use from the input image.
+         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
+         **kwargs: Additional arguments.
+
+     Returns:
+         tuple: Tuple containing the output path and the inference time in seconds.
+     """
+     if device is None:
+         device = (
+             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         )
+
+     # Put model in evaluation mode
+     model.to(device)
+     model.eval()
+
+     # Open the GeoTIFF
+     with rasterio.open(geotiff_path) as src:
+         # Read metadata
+         meta = src.meta
+         height = src.height
+         width = src.width
+
+         # Update metadata for the output raster: a single band holding the binary mask
+         out_meta = meta.copy()
+         out_meta.update({"count": 1, "dtype": "uint8"})
+
+         # We'll use two arrays:
+         # 1. For accumulating predictions
+         pred_accumulator = np.zeros((height, width), dtype=np.float32)
+         # 2. For tracking how many predictions contribute to each pixel
+         count_accumulator = np.zeros((height, width), dtype=np.float32)
+
+         # Calculate the number of windows needed to cover the entire image
+         steps_y = math.ceil((height - overlap) / (window_size - overlap))
+         steps_x = math.ceil((width - overlap) / (window_size - overlap))
+
+         # Ensure we cover the entire image
+         last_y = height - window_size
+         last_x = width - window_size
+
+         total_windows = steps_y * steps_x
+         print(
+             f"Processing {total_windows} windows with size {window_size}x{window_size} and overlap {overlap}..."
+         )
+
+         # Create progress bar
+         pbar = tqdm(total=total_windows)
+
+         # Process in batches
+         batch_inputs = []
+         batch_positions = []
+         batch_count = 0
+
+         start_time = time.time()
+
+         def flush_batch():
+             """Run the model on the accumulated windows and blend the predictions."""
+             nonlocal batch_inputs, batch_positions, batch_count
+             if not batch_inputs:
+                 return
+
+             # Forward pass
+             with torch.no_grad():
+                 outputs = model(batch_inputs)
+
+             # Process each output in the batch
+             for idx, output in enumerate(outputs):
+                 y_pos, x_pos, h, w = batch_positions[idx]
+
+                 # Create a weight matrix that gives higher weight to center pixels;
+                 # this helps with smooth blending at window boundaries
+                 y_grid, x_grid = np.mgrid[0:h, 0:w]
+
+                 # Calculate distance from each edge
+                 dist_from_left = x_grid
+                 dist_from_right = w - x_grid - 1
+                 dist_from_top = y_grid
+                 dist_from_bottom = h - y_grid - 1
+
+                 # Combine distances (minimum distance to any edge)
+                 edge_distance = np.minimum.reduce(
+                     [
+                         dist_from_left,
+                         dist_from_right,
+                         dist_from_top,
+                         dist_from_bottom,
+                     ]
+                 )
+
+                 # Convert to weight (higher weight for center pixels),
+                 # normalized to [0, 1]
+                 edge_distance = np.minimum(edge_distance, overlap / 2)
+                 weight = edge_distance / (overlap / 2)
+
+                 # Get masks for predictions above the threshold
+                 if len(output["scores"]) > 0:
+                     # Get all instances that meet the confidence threshold
+                     keep = output["scores"] > confidence_threshold
+                     masks = output["masks"][keep].squeeze(1)
+
+                     # Combine all instances into one mask
+                     if len(masks) > 0:
+                         combined_mask = torch.max(masks, dim=0)[0] > 0.5
+                         combined_mask = combined_mask.cpu().numpy().astype(np.float32)
+
+                         # Apply weight to the prediction
+                         weighted_pred = combined_mask * weight
+
+                         # Add to accumulators
+                         pred_accumulator[
+                             y_pos : y_pos + h, x_pos : x_pos + w
+                         ] += weighted_pred
+                         count_accumulator[
+                             y_pos : y_pos + h, x_pos : x_pos + w
+                         ] += weight
+
+             # Update progress bar and reset the batch
+             pbar.update(len(outputs))
+             batch_inputs = []
+             batch_positions = []
+             batch_count = 0
+
+         # Slide window over the image - make sure we cover the entire image
+         for i in range(steps_y + 1):  # +1 to ensure we reach the edge
+             y = min(i * (window_size - overlap), last_y)
+             y = max(0, y)  # Prevent negative indices
+
+             if y > last_y and i > 0:  # Skip if we've already covered the entire height
+                 continue
+
+             for j in range(steps_x + 1):  # +1 to ensure we reach the edge
+                 x = min(j * (window_size - overlap), last_x)
+                 x = max(0, x)  # Prevent negative indices
+
+                 if x > last_x and j > 0:  # Skip if we've already covered the entire width
+                     continue
+
+                 # Read window
+                 window = src.read(window=Window(x, y, window_size, window_size))
+
+                 # Check if window is valid
+                 if window.shape[1] != window_size or window.shape[2] != window_size:
+                     # This can happen at image edges - adjust window size
+                     current_height = window.shape[1]
+                     current_width = window.shape[2]
+                     if current_height == 0 or current_width == 0:
+                         continue  # Skip empty windows
+                 else:
+                     current_height = window_size
+                     current_width = window_size
+
+                 # Normalize and prepare input
+                 image = window.astype(np.float32) / 255.0
+
+                 # Handle a different number of bands
+                 if image.shape[0] > num_channels:
+                     image = image[:num_channels]
+                 elif image.shape[0] < num_channels:
+                     padded = np.zeros(
+                         (num_channels, current_height, current_width), dtype=np.float32
+                     )
+                     padded[: image.shape[0]] = image
+                     image = padded
+
+                 # Convert to tensor
+                 image_tensor = torch.tensor(image, device=device)
+
+                 # Add to batch
+                 batch_inputs.append(image_tensor)
+                 batch_positions.append((y, x, current_height, current_width))
+                 batch_count += 1
+
+                 # Process the batch when it reaches the batch size
+                 if batch_count == batch_size:
+                     flush_batch()
+
+         # Process any remaining windows that did not fill a full batch
+         flush_batch()
+
+         # Close progress bar
+         pbar.close()
+
+         # Calculate the final mask by dividing accumulated predictions by counts,
+         # guarding against division by zero
+         mask = np.zeros((height, width), dtype=np.uint8)
+         valid_pixels = count_accumulator > 0
+         if np.any(valid_pixels):
+             # Average predictions where we have data
+             mask[valid_pixels] = (
+                 pred_accumulator[valid_pixels] / count_accumulator[valid_pixels] > 0.5
+             ).astype(np.uint8)
+
+         # Record time
+         inference_time = time.time() - start_time
+         print(f"Inference completed in {inference_time:.2f} seconds")
+
+         # Save output
+         with rasterio.open(output_path, "w", **out_meta) as dst:
+             dst.write(mask, 1)
+
+         print(f"Saved prediction to {output_path}")
+
+         return output_path, inference_time
+
+
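+ # Worked example of the window arithmetic above: with window_size=512 and
+ # overlap=256 the stride is 256, so a 1280-pixel-wide image yields
+ # steps_x = ceil((1280 - 256) / 256) = 4, and interior pixels receive
+ # weighted contributions from up to two windows per axis.
+
+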
+ def object_detection(
+     input_path,
+     output_path,
+     model_path,
+     window_size=512,
+     overlap=256,
+     confidence_threshold=0.5,
+     batch_size=4,
+     num_channels=3,
+     pretrained=True,
+     device=None,
+     **kwargs,
+ ):
+     """
+     Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
+
+     Args:
+         input_path (str): Path to input GeoTIFF file.
+         output_path (str): Path to save output mask GeoTIFF.
+         model_path (str): Path to trained model weights.
+         window_size (int): Size of sliding window for inference.
+         overlap (int): Overlap between adjacent windows.
+         confidence_threshold (float): Confidence threshold for predictions (0-1).
+         batch_size (int): Batch size for inference.
+         num_channels (int): Number of channels in the input image and model.
+         pretrained (bool): Whether to use a pretrained backbone for model loading.
+         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
+         **kwargs: Additional arguments passed to inference_on_geotiff.
+
+     Returns:
+         None: The output mask is saved to output_path.
+     """
+     # Load the trained model
+     if device is None:
+         device = (
+             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         )
+     model = get_instance_segmentation_model(
+         num_classes=2, num_channels=num_channels, pretrained=pretrained
+     )
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.to(device)
+     model.eval()
+
+     inference_on_geotiff(
+         model=model,
+         geotiff_path=input_path,
+         output_path=output_path,
+         window_size=window_size,  # Adjust based on your model and memory
+         overlap=overlap,  # Overlap to avoid edge artifacts
+         confidence_threshold=confidence_threshold,
+         batch_size=batch_size,  # Adjust based on your GPU memory
+         num_channels=num_channels,
+         device=device,
+         **kwargs,
+     )
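+
+
+ # Illustrative usage (paths are placeholders):
+ #
+ #     object_detection(
+ #         input_path="scene.tif",
+ #         output_path="scene_mask.tif",
+ #         model_path="runs/buildings/best_model.pth",
+ #         num_channels=4,
+ #     )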