geoai-py 0.5.6__py2.py3-none-any.whl → 0.7.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -24,6 +24,16 @@ from tqdm import tqdm
 from .utils import download_model_from_hf
 
 
+# Additional imports for semantic segmentation
+try:
+    import segmentation_models_pytorch as smp
+    from torch.nn import functional as F
+
+    SMP_AVAILABLE = True
+except ImportError:
+    SMP_AVAILABLE = False
+
+
 def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
     """
     Get Mask R-CNN model with custom input channels and output classes.
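
The try/except guard above is the usual optional-dependency pattern: `SMP_AVAILABLE` is checked before any segmentation-models-pytorch call is made. A minimal caller-side sketch of the same idea (the `build_segmenter` helper is hypothetical, not part of geoai):

```python
# Hypothetical helper illustrating the SMP_AVAILABLE guard; not part of geoai.
from geoai.train import SMP_AVAILABLE, get_smp_model

def build_segmenter(classes=2):
    if not SMP_AVAILABLE:
        # Same remedy the library suggests in its ImportError message.
        raise RuntimeError("pip install segmentation-models-pytorch")
    return get_smp_model(architecture="unet", classes=classes)
```
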
@@ -577,6 +587,7 @@ def train_MaskRCNN_model(
     labels_dir,
     output_dir,
     num_channels=3,
+    model=None,
     pretrained=True,
     pretrained_model_path=None,
     batch_size=4,
@@ -601,6 +612,7 @@ def train_MaskRCNN_model(
         output_dir (str): Directory to save model checkpoints and results.
         num_channels (int, optional): Number of input channels. If None, auto-detected.
             Defaults to 3.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone. This is ignored if
             pretrained_model_path is provided. Defaults to True.
         pretrained_model_path (str, optional): Path to a .pth file to load as a
@@ -708,9 +720,10 @@ def train_MaskRCNN_model(
     )
 
     # Initialize model (2 classes: background and building)
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
     model.to(device)
 
     # Set up optimizer
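
The effect of the new `model` parameter: callers can inject a pre-built Mask R-CNN (for instance with a different class count), and the function only constructs its default two-class model as a fallback. A hedged usage sketch, with placeholder paths:

```python
from geoai.train import get_instance_segmentation_model, train_MaskRCNN_model

# Build a 4-class variant up front; with model=None the trainer would
# create the default 2-class (background + building) model instead.
custom_model = get_instance_segmentation_model(num_classes=4, num_channels=3)
train_MaskRCNN_model(
    images_dir="tiles/images",   # placeholder directories
    labels_dir="tiles/labels",
    output_dir="runs/maskrcnn",
    model=custom_model,
)
```
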
@@ -1088,6 +1101,7 @@ def object_detection(
     confidence_threshold=0.5,
     batch_size=4,
     num_channels=3,
+    model=None,
     pretrained=True,
     device=None,
     **kwargs,
@@ -1104,6 +1118,7 @@ def object_detection(
         confidence_threshold (float): Confidence threshold for predictions (0-1).
         batch_size (int): Batch size for inference.
         num_channels (int): Number of channels in the input image and model.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone for model loading.
         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
         **kwargs: Additional arguments passed to inference_on_geotiff.
@@ -1116,9 +1131,10 @@ def object_detection(
         device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
 
     if not os.path.exists(model_path):
         try:
@@ -1153,6 +1169,7 @@ def object_detection_batch(
     overlap=256,
     confidence_threshold=0.5,
     batch_size=4,
+    model=None,
     num_channels=3,
     pretrained=True,
     device=None,
@@ -1174,6 +1191,7 @@ def object_detection_batch(
         confidence_threshold (float): Confidence threshold for predictions (0-1).
         batch_size (int): Batch size for inference.
         num_channels (int): Number of channels in the input image and model.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone for model loading.
         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
         **kwargs: Additional arguments passed to inference_on_geotiff.
@@ -1186,9 +1204,10 @@ def object_detection_batch(
         device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
 
     if not os.path.exists(output_dir):
         os.makedirs(output_dir, exist_ok=True)
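
The same injection pattern applies at inference time. These hunks do not show the leading path parameters of `object_detection`, so the keyword names before `num_channels` in this sketch are assumptions; only `num_channels`, `model`, and the other keywords visible in the diff are confirmed:

```python
from geoai.train import get_instance_segmentation_model, object_detection

detector = get_instance_segmentation_model(num_classes=2, num_channels=4)
object_detection(
    input_path="scene.tif",         # assumed parameter name: input raster
    output_path="scene_masks.tif",  # assumed parameter name: output raster
    model_path="weights.pth",       # checkpoint path (referenced in the body)
    num_channels=4,
    model=detector,                 # bypasses the built-in constructor
)
```
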
@@ -1232,3 +1251,1016 @@ def object_detection_batch(
         device=device,
         **kwargs,
     )
+
+
+class SemanticSegmentationDataset(Dataset):
+    """Dataset for semantic segmentation from GeoTIFF images and labels."""
+
+    def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
+        """
+        Initialize dataset for semantic segmentation.
+
+        Args:
+            image_paths (list): List of paths to image GeoTIFF files.
+            label_paths (list): List of paths to label GeoTIFF files.
+            transforms (callable, optional): Transformations to apply to images and masks.
+            num_channels (int, optional): Number of channels to use from images. If None,
+                auto-detected from the first image.
+        """
+        self.image_paths = image_paths
+        self.label_paths = label_paths
+        self.transforms = transforms
+
+        # Auto-detect the number of channels if not specified
+        if num_channels is None:
+            with rasterio.open(self.image_paths[0]) as src:
+                self.num_channels = src.count
+        else:
+            self.num_channels = num_channels
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        # Load image
+        with rasterio.open(self.image_paths[idx]) as src:
+            # Read as [C, H, W] format
+            image = src.read().astype(np.float32)
+
+            # Normalize image to [0, 1] range
+            image = image / 255.0
+
+            # Handle different number of channels
+            if image.shape[0] > self.num_channels:
+                image = image[: self.num_channels]  # Keep only specified bands
+            elif image.shape[0] < self.num_channels:
+                # Pad with zeros if less than specified bands
+                padded = np.zeros(
+                    (self.num_channels, image.shape[1], image.shape[2]),
+                    dtype=np.float32,
+                )
+                padded[: image.shape[0]] = image
+                image = padded
+
+        # Convert to CHW tensor
+        image = torch.as_tensor(image, dtype=torch.float32)
+
+        # Load label mask
+        with rasterio.open(self.label_paths[idx]) as src:
+            label_mask = src.read(1).astype(np.int64)
+            # Keep original class values for multi-class segmentation
+            # No conversion to binary - preserve all class labels
+
+        # Convert to tensor
+        mask = torch.as_tensor(label_mask, dtype=torch.long)
+
+        # Apply transforms if specified
+        if self.transforms is not None:
+            image, mask = self.transforms(image, mask)
+
+        return image, mask
+
+
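
A short sketch of how this dataset might be wired into a PyTorch `DataLoader` (the glob paths are placeholders; the sorted-listing mirrors what `train_segmentation_model` does further down):

```python
import glob
from torch.utils.data import DataLoader
from geoai.train import SemanticSegmentationDataset, get_semantic_transform

images = sorted(glob.glob("tiles/images/*.tif"))  # placeholder globs
labels = sorted(glob.glob("tiles/labels/*.tif"))
ds = SemanticSegmentationDataset(
    images, labels, transforms=get_semantic_transform(train=True)
)
loader = DataLoader(ds, batch_size=8, shuffle=True)
image, mask = ds[0]  # image: float32 [C, H, W] in [0, 1]; mask: int64 [H, W]
```
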
+class SemanticTransforms:
+    """Custom transforms for semantic segmentation."""
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, mask):
+        for t in self.transforms:
+            image, mask = t(image, mask)
+        return image, mask
+
+
+class SemanticToTensor:
+    """Convert numpy.ndarray to tensor for semantic segmentation."""
+
+    def __call__(self, image, mask):
+        return image, mask
+
+
+class SemanticRandomHorizontalFlip:
+    """Random horizontal flip transform for semantic segmentation."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, mask):
+        if random.random() < self.prob:
+            # Flip image and mask along width dimension
+            image = torch.flip(image, dims=[2])
+            mask = torch.flip(mask, dims=[1])
+        return image, mask
+
+
+def get_semantic_transform(train):
+    """
+    Get transforms for semantic segmentation data augmentation.
+
+    Args:
+        train (bool): Whether to include training-specific transforms.
+
+    Returns:
+        SemanticTransforms: Composed transforms.
+    """
+    transforms = []
+    transforms.append(SemanticToTensor())
+
+    if train:
+        transforms.append(SemanticRandomHorizontalFlip(0.5))
+
+    return SemanticTransforms(transforms)
+
+
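
Because every transform is just a callable taking `(image, mask)` and returning both, the pipeline is easy to extend. A sketch of a companion vertical-flip transform following the same convention (not part of the package):

```python
import random
import torch

class SemanticRandomVerticalFlip:
    """Hypothetical companion to SemanticRandomHorizontalFlip."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, mask):
        if random.random() < self.prob:
            image = torch.flip(image, dims=[1])  # flip H axis of [C, H, W]
            mask = torch.flip(mask, dims=[0])    # flip H axis of [H, W]
        return image, mask
```
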
+def get_smp_model(
+    architecture="unet",
+    encoder_name="resnet34",
+    encoder_weights="imagenet",
+    in_channels=3,
+    classes=2,
+    activation=None,
+    **kwargs,
+):
+    """
+    Get a segmentation model from segmentation-models-pytorch using the generic create_model function.
+
+    Args:
+        architecture (str): Model architecture (e.g., 'unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
+            'pspnet', 'linknet', 'manet', 'pan', 'upernet', etc.). Case insensitive.
+        encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'efficientnet-b0', 'mit_b0', etc.).
+        encoder_weights (str): Encoder weights ('imagenet' or None).
+        in_channels (int): Number of input channels.
+        classes (int): Number of output classes.
+        activation (str): Activation function for output layer.
+        **kwargs: Additional arguments passed to smp.create_model().
+
+    Returns:
+        torch.nn.Module: Segmentation model.
+
+    Note:
+        This function uses smp.create_model() which supports all architectures available in
+        segmentation-models-pytorch, making it future-proof for new model additions.
+    """
+    if not SMP_AVAILABLE:
+        raise ImportError(
+            "segmentation-models-pytorch is not installed. "
+            "Please install it with: pip install segmentation-models-pytorch"
+        )
+
+    try:
+        # Use the generic create_model function - supports all SMP architectures
+        model = smp.create_model(
+            arch=architecture,  # Case insensitive
+            encoder_name=encoder_name,
+            encoder_weights=encoder_weights,
+            in_channels=in_channels,
+            classes=classes,
+            **kwargs,
+        )
+
+        # Apply activation if specified (note: activation is handled differently in create_model)
+        if activation is not None:
+            import warnings
+
+            warnings.warn(
+                "The 'activation' parameter is deprecated when using smp.create_model(). "
+                "Apply activation manually after model creation if needed.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return model
+
+    except Exception as e:
+        # Provide helpful error message
+        available_archs = []
+        try:
+            # Try to get available architectures from smp
+            if hasattr(smp, "get_available_models"):
+                available_archs = smp.get_available_models()
+            else:
+                available_archs = [
+                    "unet",
+                    "unetplusplus",
+                    "manet",
+                    "linknet",
+                    "fpn",
+                    "pspnet",
+                    "deeplabv3",
+                    "deeplabv3plus",
+                    "pan",
+                    "upernet",
+                ]
+        except:
+            available_archs = [
+                "unet",
+                "fpn",
+                "deeplabv3plus",
+                "pspnet",
+                "linknet",
+                "manet",
+            ]
+
+        raise ValueError(
+            f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
+            f"Error: {str(e)}. "
+            f"Available architectures include: {', '.join(available_archs)}. "
+            f"Please check the segmentation-models-pytorch documentation for supported combinations."
+        )
+
+
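
A usage sketch, assuming segmentation-models-pytorch is installed (the 4-band input is an arbitrary example, e.g. RGB plus NIR):

```python
from geoai.train import get_smp_model

model = get_smp_model(
    architecture="deeplabv3plus",
    encoder_name="resnet50",
    in_channels=4,  # e.g. RGB + NIR
    classes=3,
)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```
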
+def dice_coefficient(pred, target, smooth=1e-6, num_classes=None):
+    """
+    Calculate Dice coefficient for segmentation (binary or multi-class).
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean Dice coefficient across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate Dice for each class and average
+    dice_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        intersection = (pred_class * target_class).sum()
+        union = pred_class.sum() + target_class.sum()
+
+        if union > 0:
+            dice = (2.0 * intersection + smooth) / (union + smooth)
+            dice_scores.append(dice.item())
+
+    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+
+
+def iou_coefficient(pred, target, smooth=1e-6, num_classes=None):
+    """
+    Calculate IoU coefficient for segmentation (binary or multi-class).
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean IoU coefficient across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate IoU for each class and average
+    iou_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        intersection = (pred_class * target_class).sum()
+        union = pred_class.sum() + target_class.sum() - intersection
+
+        if union > 0:
+            iou = (intersection + smooth) / (union + smooth)
+            iou_scores.append(iou.item())
+
+    return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
+
+
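
Both metrics reduce to the familiar per-class formulas, Dice = 2|A∩B| / (|A| + |B|) and IoU = |A∩B| / |A∪B|, averaged over the classes that actually occur (the `union > 0` check skips absent classes). A toy sanity check on a 2x2 mask:

```python
import torch
from geoai.train import dice_coefficient, iou_coefficient

target = torch.tensor([[0, 1], [0, 1]])
pred = torch.tensor([[0, 1], [1, 1]])  # one pixel wrong
# Class 0: overlap 1, sizes 1 and 2 -> Dice 2/3, IoU 1/2
# Class 1: overlap 2, sizes 3 and 2 -> Dice 4/5, IoU 2/3
print(dice_coefficient(pred, target, num_classes=2))  # ~0.733
print(iou_coefficient(pred, target, num_classes=2))   # ~0.583
```
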
+def train_semantic_one_epoch(
+    model, optimizer, data_loader, device, epoch, criterion, print_freq=10, verbose=True
+):
+    """
+    Train the semantic segmentation model for one epoch.
+
+    Args:
+        model (torch.nn.Module): The model to train.
+        optimizer (torch.optim.Optimizer): The optimizer to use.
+        data_loader (torch.utils.data.DataLoader): DataLoader for training data.
+        device (torch.device): Device to train on.
+        epoch (int): Current epoch number.
+        criterion: Loss function.
+        print_freq (int): How often to print progress.
+        verbose (bool): Whether to print detailed progress.
+
+    Returns:
+        float: Average loss for the epoch.
+    """
+    model.train()
+    total_loss = 0
+    num_batches = len(data_loader)
+
+    start_time = time.time()
+
+    for i, (images, targets) in enumerate(data_loader):
+        # Move images and targets to device
+        images = images.to(device)
+        targets = targets.to(device)
+
+        # Forward pass
+        outputs = model(images)
+        loss = criterion(outputs, targets)
+
+        # Backward pass
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Track loss
+        total_loss += loss.item()
+
+        # Print progress
+        if i % print_freq == 0:
+            elapsed_time = time.time() - start_time
+            if verbose:
+                print(
+                    f"Epoch: {epoch}, Batch: {i}/{num_batches}, Loss: {loss.item():.4f}, Time: {elapsed_time:.2f}s"
+                )
+            start_time = time.time()
+
+    # Calculate average loss
+    avg_loss = total_loss / num_batches
+    return avg_loss
+
+
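
Worth noting for anyone re-implementing this loop: the criterion used downstream (`CrossEntropyLoss` in `train_segmentation_model`) expects raw logits of shape [B, C, H, W] and integer targets of shape [B, H, W], which is exactly what the dataset and the SMP models produce. A shape-only sketch:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
logits = torch.randn(8, 2, 256, 256)          # [B, C, H, W] raw model outputs
targets = torch.randint(0, 2, (8, 256, 256))  # [B, H, W] class indices (int64)
loss = criterion(logits, targets)             # scalar; no softmax needed beforehand
```
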
+def evaluate_semantic(model, data_loader, device, criterion, num_classes=2):
+    """
+    Evaluate the semantic segmentation model on the validation set.
+
+    Args:
+        model (torch.nn.Module): The model to evaluate.
+        data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
+        device (torch.device): Device to evaluate on.
+        criterion: Loss function.
+        num_classes (int): Number of classes for evaluation metrics.
+
+    Returns:
+        dict: Evaluation metrics including loss, IoU, and Dice.
+    """
+    model.eval()
+
+    total_loss = 0
+    dice_scores = []
+    iou_scores = []
+    num_batches = len(data_loader)
+
+    with torch.no_grad():
+        for images, targets in data_loader:
+            # Move to device
+            images = images.to(device)
+            targets = targets.to(device)
+
+            # Forward pass
+            outputs = model(images)
+            loss = criterion(outputs, targets)
+            total_loss += loss.item()
+
+            # Calculate metrics for each sample in the batch
+            for pred, target in zip(outputs, targets):
+                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                iou = iou_coefficient(pred, target, num_classes=num_classes)
+                dice_scores.append(dice)
+                iou_scores.append(iou)
+
+    # Calculate metrics
+    avg_loss = total_loss / num_batches
+    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
+
+    return {"loss": avg_loss, "Dice": avg_dice, "IoU": avg_iou}
+
+
+def train_segmentation_model(
+    images_dir,
+    labels_dir,
+    output_dir,
+    architecture="unet",
+    encoder_name="resnet34",
+    encoder_weights="imagenet",
+    num_channels=3,
+    num_classes=2,
+    batch_size=8,
+    num_epochs=50,
+    learning_rate=0.001,
+    weight_decay=1e-4,
+    seed=42,
+    val_split=0.2,
+    print_freq=10,
+    verbose=True,
+    save_best_only=True,
+    plot_curves=False,
+    **kwargs,
+):
+    """
+    Train a semantic segmentation model for object detection using segmentation-models-pytorch.
+
+    This function trains a semantic segmentation model for object detection (e.g., building detection)
+    using models from the segmentation-models-pytorch library. Unlike instance segmentation (Mask R-CNN),
+    this approach treats the task as pixel-level binary classification.
+
+    Args:
+        images_dir (str): Directory containing image GeoTIFF files.
+        labels_dir (str): Directory containing label GeoTIFF files.
+        output_dir (str): Directory to save model checkpoints and results.
+        architecture (str): Model architecture ('unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
+            'pspnet', 'linknet', 'manet'). Defaults to 'unet'.
+        encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'resnet50', 'efficientnet-b0').
+            Defaults to 'resnet34'.
+        encoder_weights (str): Encoder pretrained weights ('imagenet' or None). Defaults to 'imagenet'.
+        num_channels (int): Number of input channels. Defaults to 3.
+        num_classes (int): Number of output classes (typically 2 for binary segmentation). Defaults to 2.
+        batch_size (int): Batch size for training. Defaults to 8.
+        num_epochs (int): Number of training epochs. Defaults to 50.
+        learning_rate (float): Initial learning rate. Defaults to 0.001.
+        weight_decay (float): Weight decay for optimizer. Defaults to 1e-4.
+        seed (int): Random seed for reproducibility. Defaults to 42.
+        val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
+        print_freq (int): Frequency of printing training progress. Defaults to 10.
+        verbose (bool): If True, prints detailed training progress. Defaults to True.
+        save_best_only (bool): If True, only saves the best model. Otherwise saves all checkpoints.
+            Defaults to True.
+        plot_curves (bool): If True, plots training curves. Defaults to False.
+        **kwargs: Additional arguments passed to smp.create_model().
+
+    Returns:
+        None: Model weights are saved to output_dir.
+
+    Raises:
+        ImportError: If segmentation-models-pytorch is not installed.
+        FileNotFoundError: If input directories don't exist or contain no matching files.
+    """
+    import datetime
+
+    if not SMP_AVAILABLE:
+        raise ImportError(
+            "segmentation-models-pytorch is not installed. "
+            "Please install it with: pip install segmentation-models-pytorch"
+        )
+
+    # Set random seeds for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Get device
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    print(f"Using device: {device}")
+
+    # Get all image and label files
+    image_files = sorted(
+        [
+            os.path.join(images_dir, f)
+            for f in os.listdir(images_dir)
+            if f.endswith(".tif")
+        ]
+    )
+    label_files = sorted(
+        [
+            os.path.join(labels_dir, f)
+            for f in os.listdir(labels_dir)
+            if f.endswith(".tif")
+        ]
+    )
+
+    print(f"Found {len(image_files)} image files and {len(label_files)} label files")
+
+    # Ensure matching files
+    if len(image_files) != len(label_files):
+        print("Warning: Number of image files and label files don't match!")
+        # Find matching files by basename
+        basenames = [os.path.basename(f) for f in image_files]
+        label_files = [
+            os.path.join(labels_dir, os.path.basename(f))
+            for f in image_files
+            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+        ]
+        image_files = [
+            f
+            for f, b in zip(image_files, basenames)
+            if os.path.exists(os.path.join(labels_dir, b))
+        ]
+        print(f"Using {len(image_files)} matching files")
+
+    if len(image_files) == 0:
+        raise FileNotFoundError("No matching image and label files found")
+
+    # Split data into train and validation sets
+    train_imgs, val_imgs, train_labels, val_labels = train_test_split(
+        image_files, label_files, test_size=val_split, random_state=seed
+    )
+
+    print(f"Training on {len(train_imgs)} images, validating on {len(val_imgs)} images")
+
+    # Create datasets
+    train_dataset = SemanticSegmentationDataset(
+        train_imgs,
+        train_labels,
+        transforms=get_semantic_transform(train=True),
+        num_channels=num_channels,
+    )
+    val_dataset = SemanticSegmentationDataset(
+        val_imgs,
+        val_labels,
+        transforms=get_semantic_transform(train=False),
+        num_channels=num_channels,
+    )
+
+    # Create data loaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=4,
+        pin_memory=True,
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+    )
+
+    # Initialize model
+    model = get_smp_model(
+        architecture=architecture,
+        encoder_name=encoder_name,
+        encoder_weights=encoder_weights,
+        in_channels=num_channels,
+        classes=num_classes,
+        activation=None,  # We'll apply softmax later
+        **kwargs,
+    )
+    model.to(device)
+
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    # Set up optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=learning_rate, weight_decay=weight_decay
+    )
+
+    # Set up learning rate scheduler
+    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5, verbose=True
+    )
+
+    # Initialize tracking variables
+    best_iou = 0
+    train_losses = []
+    val_losses = []
+    val_ious = []
+    val_dices = []
+
+    print(f"Starting training with {architecture} + {encoder_name}")
+    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    # Training loop
+    for epoch in range(num_epochs):
+        # Train one epoch
+        train_loss = train_semantic_one_epoch(
+            model,
+            optimizer,
+            train_loader,
+            device,
+            epoch,
+            criterion,
+            print_freq,
+            verbose,
+        )
+        train_losses.append(train_loss)
+
+        # Evaluate on validation set
+        eval_metrics = evaluate_semantic(
+            model, val_loader, device, criterion, num_classes=num_classes
+        )
+        val_losses.append(eval_metrics["loss"])
+        val_ious.append(eval_metrics["IoU"])
+        val_dices.append(eval_metrics["Dice"])
+
+        # Update learning rate
+        lr_scheduler.step(eval_metrics["loss"])
+
+        # Print metrics
+        print(
+            f"Epoch {epoch+1}/{num_epochs}: "
+            f"Train Loss: {train_loss:.4f}, "
+            f"Val Loss: {eval_metrics['loss']:.4f}, "
+            f"Val IoU: {eval_metrics['IoU']:.4f}, "
+            f"Val Dice: {eval_metrics['Dice']:.4f}"
+        )
+
+        # Save best model
+        if eval_metrics["IoU"] > best_iou:
+            best_iou = eval_metrics["IoU"]
+            print(f"Saving best model with IoU: {best_iou:.4f}")
+            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+
+        # Save checkpoint every 10 epochs (if not save_best_only)
+        if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
+            torch.save(
+                {
+                    "epoch": epoch,
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": lr_scheduler.state_dict(),
+                    "best_iou": best_iou,
+                    "architecture": architecture,
+                    "encoder_name": encoder_name,
+                    "num_channels": num_channels,
+                    "num_classes": num_classes,
+                },
+                os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
+            )
+
+    # Save final model
+    torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
+
+    # Save training history
+    history = {
+        "train_losses": train_losses,
+        "val_losses": val_losses,
+        "val_ious": val_ious,
+        "val_dices": val_dices,
+    }
+    torch.save(history, os.path.join(output_dir, "training_history.pth"))
+
+    # Save training summary
+    with open(os.path.join(output_dir, "training_summary.txt"), "w") as f:
+        f.write(
+            f"Training completed on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
+        )
+        f.write(f"Architecture: {architecture}\n")
+        f.write(f"Encoder: {encoder_name}\n")
+        f.write(f"Total epochs: {num_epochs}\n")
+        f.write(f"Best validation IoU: {best_iou:.4f}\n")
+        f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
+        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
+
+    print(f"Training complete! Best IoU: {best_iou:.4f}")
+    print(f"Models saved to {output_dir}")
+
+    # Plot training curves
+    if plot_curves:
+        try:
+            plt.figure(figsize=(15, 5))
+
+            plt.subplot(1, 3, 1)
+            plt.plot(train_losses, label="Train Loss")
+            plt.plot(val_losses, label="Val Loss")
+            plt.title("Loss")
+            plt.xlabel("Epoch")
+            plt.ylabel("Loss")
+            plt.legend()
+            plt.grid(True)
+
+            plt.subplot(1, 3, 2)
+            plt.plot(val_ious, label="Val IoU")
+            plt.title("IoU Score")
+            plt.xlabel("Epoch")
+            plt.ylabel("IoU")
+            plt.legend()
+            plt.grid(True)
+
+            plt.subplot(1, 3, 3)
+            plt.plot(val_dices, label="Val Dice")
+            plt.title("Dice Score")
+            plt.xlabel("Epoch")
+            plt.ylabel("Dice")
+            plt.legend()
+            plt.grid(True)
+
+            plt.tight_layout()
+            plt.savefig(
+                os.path.join(output_dir, "training_curves.png"),
+                dpi=150,
+                bbox_inches="tight",
+            )
+            print(
+                f"Training curves saved to {os.path.join(output_dir, 'training_curves.png')}"
+            )
+            plt.close()
+        except Exception as e:
+            print(f"Could not save training curves: {e}")
+
+
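
A minimal end-to-end training call, with placeholder directories; everything else falls back to the documented defaults:

```python
from geoai.train import train_segmentation_model

train_segmentation_model(
    images_dir="tiles/images",   # placeholder directories
    labels_dir="tiles/labels",
    output_dir="runs/unet_resnet34",
    architecture="unet",
    encoder_name="resnet34",
    num_epochs=10,
)
# Writes best_model.pth, final_model.pth, training_history.pth,
# and training_summary.txt into output_dir.
```
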
+def semantic_inference_on_geotiff(
+    model,
+    geotiff_path,
+    output_path,
+    window_size=512,
+    overlap=256,
+    batch_size=4,
+    num_channels=3,
+    num_classes=2,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform semantic segmentation inference on a large GeoTIFF using a sliding window approach.
+
+    Args:
+        model (torch.nn.Module): Trained semantic segmentation model.
+        geotiff_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output mask GeoTIFF.
+        window_size (int): Size of sliding window for inference.
+        overlap (int): Overlap between adjacent windows.
+        batch_size (int): Batch size for inference.
+        num_channels (int): Number of channels to use from the input image.
+        num_classes (int): Number of classes in the model output.
+        device (torch.device, optional): Device to run inference on.
+        **kwargs: Additional arguments.
+
+    Returns:
+        tuple: Tuple containing output path and inference time in seconds.
+    """
+    if device is None:
+        device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+    # Put model in evaluation mode
+    model.to(device)
+    model.eval()
+
+    # Open the GeoTIFF
+    with rasterio.open(geotiff_path) as src:
+        # Read metadata
+        meta = src.meta
+        height = src.height
+        width = src.width
+
+        # Update metadata for output raster
+        out_meta = meta.copy()
+        out_meta.update({"count": 1, "dtype": "uint8"})
+
+        # Initialize accumulator arrays for multi-class probability blending
+        # We'll accumulate probabilities for each class and then take argmax
+        prob_accumulator = np.zeros((num_classes, height, width), dtype=np.float32)
+        count_accumulator = np.zeros((height, width), dtype=np.float32)
+
+        # Calculate steps
+        steps_y = math.ceil((height - overlap) / (window_size - overlap))
+        steps_x = math.ceil((width - overlap) / (window_size - overlap))
+        last_y = height - window_size
+        last_x = width - window_size
+
+        total_windows = steps_y * steps_x
+        print(f"Processing {total_windows} windows...")
+
+        pbar = tqdm(total=total_windows)
+        batch_inputs = []
+        batch_positions = []
+        batch_count = 0
+
+        start_time = time.time()
+
+        for i in range(steps_y + 1):
+            y = min(i * (window_size - overlap), last_y)
+            y = max(0, y)
+
+            if y > last_y and i > 0:
+                continue
+
+            for j in range(steps_x + 1):
+                x = min(j * (window_size - overlap), last_x)
+                x = max(0, x)
+
+                if x > last_x and j > 0:
+                    continue
+
+                # Read window
+                window = src.read(window=Window(x, y, window_size, window_size))
+
+                if window.shape[1] == 0 or window.shape[2] == 0:
+                    continue
+
+                current_height = window.shape[1]
+                current_width = window.shape[2]
+
+                # Normalize and prepare input
+                image = window.astype(np.float32) / 255.0
+
+                # Handle different number of bands
+                if image.shape[0] > num_channels:
+                    image = image[:num_channels]
+                elif image.shape[0] < num_channels:
+                    padded = np.zeros(
+                        (num_channels, current_height, current_width), dtype=np.float32
+                    )
+                    padded[: image.shape[0]] = image
+                    image = padded
+
+                # Convert to tensor
+                image_tensor = torch.tensor(image, device=device)
+
+                # Add to batch
+                batch_inputs.append(image_tensor)
+                batch_positions.append((y, x, current_height, current_width))
+                batch_count += 1
+
+                # Process batch
+                if batch_count == batch_size or (i == steps_y and j == steps_x):
+                    with torch.no_grad():
+                        batch_tensor = torch.stack(batch_inputs)
+                        outputs = model(batch_tensor)
+
+                        # Apply softmax to get class probabilities
+                        probs = torch.softmax(outputs, dim=1)
+
+                        # Process each output in the batch
+                        for idx, prob in enumerate(probs):
+                            y_pos, x_pos, h, w = batch_positions[idx]
+
+                            # Create weight matrix for blending
+                            y_grid, x_grid = np.mgrid[0:h, 0:w]
+                            dist_from_left = x_grid
+                            dist_from_right = w - x_grid - 1
+                            dist_from_top = y_grid
+                            dist_from_bottom = h - y_grid - 1
+
+                            edge_distance = np.minimum.reduce(
+                                [
+                                    dist_from_left,
+                                    dist_from_right,
+                                    dist_from_top,
+                                    dist_from_bottom,
+                                ]
+                            )
+                            edge_distance = np.minimum(edge_distance, overlap / 2)
+
+                            # Avoid zero weights - use minimum weight of 0.1
+                            weight = np.maximum(edge_distance / (overlap / 2), 0.1)
+
+                            # For non-overlapping windows, use uniform weight
+                            if overlap == 0:
+                                weight = np.ones_like(weight)
+
+                            # Convert probabilities to numpy [C, H, W]
+                            prob_np = prob.cpu().numpy()
+
+                            # Accumulate weighted probabilities for each class
+                            y_slice = slice(y_pos, y_pos + h)
+                            x_slice = slice(x_pos, x_pos + w)
+
+                            # Add weighted probabilities for each class
+                            for class_idx in range(num_classes):
+                                prob_accumulator[class_idx, y_slice, x_slice] += (
+                                    prob_np[class_idx] * weight
+                                )
+
+                            # Update weight accumulator
+                            count_accumulator[y_slice, x_slice] += weight
+
+                    # Reset batch
+                    batch_inputs = []
+                    batch_positions = []
+                    batch_count = 0
+                    pbar.update(len(probs))
+
+        pbar.close()
+
+    # Calculate final mask by taking argmax of accumulated probabilities
+    mask = np.zeros((height, width), dtype=np.uint8)
+    valid_pixels = count_accumulator > 0
+
+    if np.any(valid_pixels):
+        # Normalize accumulated probabilities by weights
+        normalized_probs = np.zeros_like(prob_accumulator)
+        for class_idx in range(num_classes):
+            normalized_probs[class_idx, valid_pixels] = (
+                prob_accumulator[class_idx, valid_pixels]
+                / count_accumulator[valid_pixels]
+            )
+
+        # Take argmax to get final class predictions
+        mask[valid_pixels] = np.argmax(
+            normalized_probs[:, valid_pixels], axis=0
+        ).astype(np.uint8)
+
+        # Check class distribution in predictions (summary only)
+        unique_classes, class_counts = np.unique(
+            mask[valid_pixels], return_counts=True
+        )
+        bg_ratio = np.sum(mask == 0) / mask.size
+        print(
+            f"Predicted classes: {len(unique_classes)} classes, Background: {bg_ratio:.1%}"
+        )
+
+    inference_time = time.time() - start_time
+    print(f"Inference completed in {inference_time:.2f} seconds")
+
+    # Save output
+    with rasterio.open(output_path, "w", **out_meta) as dst:
+        dst.write(mask, 1)
+
+    print(f"Saved prediction to {output_path}")
+
+    return output_path, inference_time
+
+
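
The seam-free mosaicking comes from the edge-distance weighting inside the loop: pixels near a window border get weight as low as 0.1, pixels at least overlap/2 from every border get weight 1.0, and overlapping predictions are averaged by these weights before the argmax. A standalone numpy sketch of the same weight matrix, on a toy 8x8 window:

```python
import numpy as np

h = w = 8
overlap = 4
y_grid, x_grid = np.mgrid[0:h, 0:w]
edge = np.minimum.reduce([x_grid, w - x_grid - 1, y_grid, h - y_grid - 1])
edge = np.minimum(edge, overlap / 2)
weight = np.maximum(edge / (overlap / 2), 0.1)
print(weight[0])  # border row: all 0.1
print(weight[2])  # interior row ramps up from 0.1 to 1.0
```
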
+def semantic_segmentation(
+    input_path,
+    output_path,
+    model_path,
+    architecture="unet",
+    encoder_name="resnet34",
+    num_channels=3,
+    num_classes=2,
+    window_size=512,
+    overlap=256,
+    batch_size=4,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform semantic segmentation on a GeoTIFF using a trained model.
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output mask GeoTIFF.
+        model_path (str): Path to trained model weights.
+        architecture (str): Model architecture used for training.
+        encoder_name (str): Encoder backbone name used for training.
+        num_channels (int): Number of channels in the input image and model.
+        num_classes (int): Number of classes in the model.
+        window_size (int): Size of sliding window for inference.
+        overlap (int): Overlap between adjacent windows.
+        batch_size (int): Batch size for inference.
+        device (torch.device, optional): Device to run inference on.
+        **kwargs: Additional arguments.
+
+    Returns:
+        None: Output mask is saved to output_path.
+    """
+    if device is None:
+        device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+    # Load model
+    model = get_smp_model(
+        architecture=architecture,
+        encoder_name=encoder_name,
+        encoder_weights=None,  # We're loading trained weights
+        in_channels=num_channels,
+        classes=num_classes,
+        activation=None,
+    )
+
+    if not os.path.exists(model_path):
+        try:
+            model_path = download_model_from_hf(model_path)
+        except Exception as e:
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+    model.eval()
+
+    semantic_inference_on_geotiff(
+        model=model,
+        geotiff_path=input_path,
+        output_path=output_path,
+        window_size=window_size,
+        overlap=overlap,
+        batch_size=batch_size,
+        num_channels=num_channels,
+        num_classes=num_classes,
+        device=device,
+        **kwargs,
+    )
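
Finally, a hedged inference sketch tying the new API together. Since only a state dict is stored, `architecture` and `encoder_name` must match whatever was used at training time (file names below are placeholders):

```python
from geoai.train import semantic_segmentation

semantic_segmentation(
    input_path="scene.tif",           # placeholder input GeoTIFF
    output_path="scene_classes.tif",  # placeholder output mask
    model_path="runs/unet_resnet34/best_model.pth",
    architecture="unet",      # must match training
    encoder_name="resnet34",  # must match training
    num_classes=2,
    window_size=512,
    overlap=256,
)
```
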