geoai-py 0.5.6__py2.py3-none-any.whl → 0.7.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/classify.py +23 -24
- geoai/extract.py +1 -1
- geoai/geoai.py +10 -4
- geoai/sam.py +832 -0
- geoai/train.py +1041 -9
- geoai/utils.py +342 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.7.0.dist-info}/METADATA +2 -1
- geoai_py-0.7.0.dist-info/RECORD +17 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.7.0.dist-info}/WHEEL +1 -1
- geoai_py-0.5.6.dist-info/RECORD +0 -16
- {geoai_py-0.5.6.dist-info → geoai_py-0.7.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.7.0.dist-info}/top_level.txt +0 -0
geoai/train.py
CHANGED
@@ -24,6 +24,16 @@ from tqdm import tqdm
 from .utils import download_model_from_hf
 
 
+# Additional imports for semantic segmentation
+try:
+    import segmentation_models_pytorch as smp
+    from torch.nn import functional as F
+
+    SMP_AVAILABLE = True
+except ImportError:
+    SMP_AVAILABLE = False
+
+
 def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
     """
     Get Mask R-CNN model with custom input channels and output classes.
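The new import block keeps segmentation-models-pytorch optional: SMP_AVAILABLE is only set to True when the import succeeds, and the SMP-backed helpers added later in this diff raise ImportError otherwise. A minimal sketch of branching on the flag (get_smp_model is defined further down in this same file; the exact call is illustrative):

import geoai.train as train

if train.SMP_AVAILABLE:
    # wraps smp.create_model() under the hood (see get_smp_model below)
    model = train.get_smp_model(architecture="unet", encoder_name="resnet34", classes=2)
else:
    raise SystemExit("pip install segmentation-models-pytorch")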
@@ -577,6 +587,7 @@ def train_MaskRCNN_model(
     labels_dir,
     output_dir,
     num_channels=3,
+    model=None,
     pretrained=True,
     pretrained_model_path=None,
     batch_size=4,
@@ -601,6 +612,7 @@ def train_MaskRCNN_model(
         output_dir (str): Directory to save model checkpoints and results.
         num_channels (int, optional): Number of input channels. If None, auto-detected.
             Defaults to 3.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone. This is ignored if
             pretrained_model_path is provided. Defaults to True.
         pretrained_model_path (str, optional): Path to a .pth file to load as a
@@ -708,9 +720,10 @@ def train_MaskRCNN_model(
     )
 
     # Initialize model (2 classes: background and building)
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
     model.to(device)
 
     # Set up optimizer
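As the hunks above show, train_MaskRCNN_model now accepts a prebuilt network through the new model argument and only constructs one itself when model is None. A hedged sketch of passing a custom Mask R-CNN; labels_dir, output_dir, num_channels and model appear in this diff, while the leading images_dir keyword is assumed from the parallel train_segmentation_model signature added later:

from geoai.train import get_instance_segmentation_model, train_MaskRCNN_model

# build (or load) the network once, e.g. with an extra input band
custom_model = get_instance_segmentation_model(num_classes=2, num_channels=4, pretrained=True)

train_MaskRCNN_model(
    images_dir="tiles/images",   # assumed keyword; not shown in this diff
    labels_dir="tiles/labels",
    output_dir="runs/maskrcnn",
    num_channels=4,
    model=custom_model,          # new in 0.7.0: skips the internal model construction
)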
@@ -1088,6 +1101,7 @@ def object_detection(
     confidence_threshold=0.5,
     batch_size=4,
     num_channels=3,
+    model=None,
     pretrained=True,
     device=None,
     **kwargs,
@@ -1104,6 +1118,7 @@ def object_detection(
         confidence_threshold (float): Confidence threshold for predictions (0-1).
         batch_size (int): Batch size for inference.
         num_channels (int): Number of channels in the input image and model.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone for model loading.
         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
         **kwargs: Additional arguments passed to inference_on_geotiff.
@@ -1116,9 +1131,10 @@
         device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
 
     if not os.path.exists(model_path):
         try:
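object_detection and object_detection_batch gain the same model keyword, so one network can be instantiated once and reused across rasters instead of being rebuilt per call. A hedged sketch; only the keyword arguments shown in this diff are certain, and the order of the leading path arguments is an assumption:

from geoai.train import get_instance_segmentation_model, object_detection

shared_model = get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True)

for image_path, mask_path in [("tile_a.tif", "tile_a_mask.tif"), ("tile_b.tif", "tile_b_mask.tif")]:
    object_detection(
        image_path,           # input raster (assumed positional order)
        mask_path,            # output mask (assumed positional order)
        "best_model.pth",     # model_path, as in 0.5.6
        model=shared_model,   # new in 0.7.0: reuse the same network across calls
        num_channels=3,
    )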
@@ -1153,6 +1169,7 @@ def object_detection_batch(
     overlap=256,
     confidence_threshold=0.5,
     batch_size=4,
+    model=None,
     num_channels=3,
     pretrained=True,
     device=None,
@@ -1174,6 +1191,7 @@ def object_detection_batch(
         confidence_threshold (float): Confidence threshold for predictions (0-1).
         batch_size (int): Batch size for inference.
         num_channels (int): Number of channels in the input image and model.
+        model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
         pretrained (bool): Whether to use pretrained backbone for model loading.
         device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
         **kwargs: Additional arguments passed to inference_on_geotiff.
@@ -1186,9 +1204,10 @@
         device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-    model = get_instance_segmentation_model(
-        num_classes=2, num_channels=num_channels, pretrained=pretrained
-    )
+    if model is None:
+        model = get_instance_segmentation_model(
+            num_classes=2, num_channels=num_channels, pretrained=pretrained
+        )
 
     if not os.path.exists(output_dir):
         os.makedirs(output_dir, exist_ok=True)
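The large hunk that follows adds a full semantic-segmentation pipeline (dataset and transforms, Dice/IoU metrics, a training loop, and sliding-window inference). As a hedged usage sketch of the two entry points it introduces, using only arguments visible in their signatures below:

from geoai.train import train_segmentation_model, semantic_segmentation

# train a U-Net with a ResNet-34 encoder on paired image/label GeoTIFF tiles
train_segmentation_model(
    images_dir="tiles/images",
    labels_dir="tiles/labels",
    output_dir="runs/unet_resnet34",
    architecture="unet",
    encoder_name="resnet34",
    num_channels=3,
    num_classes=2,
    batch_size=8,
    num_epochs=50,
)

# sliding-window inference with the best checkpoint written during training
semantic_segmentation(
    input_path="scene.tif",
    output_path="scene_mask.tif",
    model_path="runs/unet_resnet34/best_model.pth",
    architecture="unet",
    encoder_name="resnet34",
    num_channels=3,
    num_classes=2,
    window_size=512,
    overlap=256,
)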
@@ -1232,3 +1251,1016 @@ def object_detection_batch(
         device=device,
         **kwargs,
     )
+
+
+class SemanticSegmentationDataset(Dataset):
+    """Dataset for semantic segmentation from GeoTIFF images and labels."""
+
+    def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
+        """
+        Initialize dataset for semantic segmentation.
+
+        Args:
+            image_paths (list): List of paths to image GeoTIFF files.
+            label_paths (list): List of paths to label GeoTIFF files.
+            transforms (callable, optional): Transformations to apply to images and masks.
+            num_channels (int, optional): Number of channels to use from images. If None,
+                auto-detected from the first image.
+        """
+        self.image_paths = image_paths
+        self.label_paths = label_paths
+        self.transforms = transforms
+
+        # Auto-detect the number of channels if not specified
+        if num_channels is None:
+            with rasterio.open(self.image_paths[0]) as src:
+                self.num_channels = src.count
+        else:
+            self.num_channels = num_channels
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        # Load image
+        with rasterio.open(self.image_paths[idx]) as src:
+            # Read as [C, H, W] format
+            image = src.read().astype(np.float32)
+
+            # Normalize image to [0, 1] range
+            image = image / 255.0
+
+            # Handle different number of channels
+            if image.shape[0] > self.num_channels:
+                image = image[: self.num_channels]  # Keep only specified bands
+            elif image.shape[0] < self.num_channels:
+                # Pad with zeros if less than specified bands
+                padded = np.zeros(
+                    (self.num_channels, image.shape[1], image.shape[2]),
+                    dtype=np.float32,
+                )
+                padded[: image.shape[0]] = image
+                image = padded
+
+        # Convert to CHW tensor
+        image = torch.as_tensor(image, dtype=torch.float32)
+
+        # Load label mask
+        with rasterio.open(self.label_paths[idx]) as src:
+            label_mask = src.read(1).astype(np.int64)
+            # Keep original class values for multi-class segmentation
+            # No conversion to binary - preserve all class labels
+
+        # Convert to tensor
+        mask = torch.as_tensor(label_mask, dtype=torch.long)
+
+        # Apply transforms if specified
+        if self.transforms is not None:
+            image, mask = self.transforms(image, mask)
+
+        return image, mask
+
+
+class SemanticTransforms:
+    """Custom transforms for semantic segmentation."""
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, mask):
+        for t in self.transforms:
+            image, mask = t(image, mask)
+        return image, mask
+
+
+class SemanticToTensor:
+    """Convert numpy.ndarray to tensor for semantic segmentation."""
+
+    def __call__(self, image, mask):
+        return image, mask
+
+
+class SemanticRandomHorizontalFlip:
+    """Random horizontal flip transform for semantic segmentation."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, mask):
+        if random.random() < self.prob:
+            # Flip image and mask along width dimension
+            image = torch.flip(image, dims=[2])
+            mask = torch.flip(mask, dims=[1])
+        return image, mask
+
+
+def get_semantic_transform(train):
+    """
+    Get transforms for semantic segmentation data augmentation.
+
+    Args:
+        train (bool): Whether to include training-specific transforms.
+
+    Returns:
+        SemanticTransforms: Composed transforms.
+    """
+    transforms = []
+    transforms.append(SemanticToTensor())
+
+    if train:
+        transforms.append(SemanticRandomHorizontalFlip(0.5))
+
+    return SemanticTransforms(transforms)
+
+
+def get_smp_model(
+    architecture="unet",
+    encoder_name="resnet34",
+    encoder_weights="imagenet",
+    in_channels=3,
+    classes=2,
+    activation=None,
+    **kwargs,
+):
+    """
+    Get a segmentation model from segmentation-models-pytorch using the generic create_model function.
+
+    Args:
+        architecture (str): Model architecture (e.g., 'unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
+            'pspnet', 'linknet', 'manet', 'pan', 'upernet', etc.). Case insensitive.
+        encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'efficientnet-b0', 'mit_b0', etc.).
+        encoder_weights (str): Encoder weights ('imagenet' or None).
+        in_channels (int): Number of input channels.
+        classes (int): Number of output classes.
+        activation (str): Activation function for output layer.
+        **kwargs: Additional arguments passed to smp.create_model().
+
+    Returns:
+        torch.nn.Module: Segmentation model.
+
+    Note:
+        This function uses smp.create_model() which supports all architectures available in
+        segmentation-models-pytorch, making it future-proof for new model additions.
+    """
+    if not SMP_AVAILABLE:
+        raise ImportError(
+            "segmentation-models-pytorch is not installed. "
+            "Please install it with: pip install segmentation-models-pytorch"
+        )
+
+    try:
+        # Use the generic create_model function - supports all SMP architectures
+        model = smp.create_model(
+            arch=architecture,  # Case insensitive
+            encoder_name=encoder_name,
+            encoder_weights=encoder_weights,
+            in_channels=in_channels,
+            classes=classes,
+            **kwargs,
+        )
+
+        # Apply activation if specified (note: activation is handled differently in create_model)
+        if activation is not None:
+            import warnings
+
+            warnings.warn(
+                "The 'activation' parameter is deprecated when using smp.create_model(). "
+                "Apply activation manually after model creation if needed.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return model
+
+    except Exception as e:
+        # Provide helpful error message
+        available_archs = []
+        try:
+            # Try to get available architectures from smp
+            if hasattr(smp, "get_available_models"):
+                available_archs = smp.get_available_models()
+            else:
+                available_archs = [
+                    "unet",
+                    "unetplusplus",
+                    "manet",
+                    "linknet",
+                    "fpn",
+                    "pspnet",
+                    "deeplabv3",
+                    "deeplabv3plus",
+                    "pan",
+                    "upernet",
+                ]
+        except:
+            available_archs = [
+                "unet",
+                "fpn",
+                "deeplabv3plus",
+                "pspnet",
+                "linknet",
+                "manet",
+            ]
+
+        raise ValueError(
+            f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
+            f"Error: {str(e)}. "
+            f"Available architectures include: {', '.join(available_archs)}. "
+            f"Please check the segmentation-models-pytorch documentation for supported combinations."
+        )
+
+
+def dice_coefficient(pred, target, smooth=1e-6, num_classes=None):
+    """
+    Calculate Dice coefficient for segmentation (binary or multi-class).
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean Dice coefficient across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate Dice for each class and average
+    dice_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        intersection = (pred_class * target_class).sum()
+        union = pred_class.sum() + target_class.sum()
+
+        if union > 0:
+            dice = (2.0 * intersection + smooth) / (union + smooth)
+            dice_scores.append(dice.item())
+
+    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+
+
+def iou_coefficient(pred, target, smooth=1e-6, num_classes=None):
+    """
+    Calculate IoU coefficient for segmentation (binary or multi-class).
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean IoU coefficient across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate IoU for each class and average
+    iou_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        intersection = (pred_class * target_class).sum()
+        union = pred_class.sum() + target_class.sum() - intersection
+
+        if union > 0:
+            iou = (intersection + smooth) / (union + smooth)
+            iou_scores.append(iou.item())
+
+    return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
+
+
+def train_semantic_one_epoch(
+    model, optimizer, data_loader, device, epoch, criterion, print_freq=10, verbose=True
+):
+    """
+    Train the semantic segmentation model for one epoch.
+
+    Args:
+        model (torch.nn.Module): The model to train.
+        optimizer (torch.optim.Optimizer): The optimizer to use.
+        data_loader (torch.utils.data.DataLoader): DataLoader for training data.
+        device (torch.device): Device to train on.
+        epoch (int): Current epoch number.
+        criterion: Loss function.
+        print_freq (int): How often to print progress.
+        verbose (bool): Whether to print detailed progress.
+
+    Returns:
+        float: Average loss for the epoch.
+    """
+    model.train()
+    total_loss = 0
+    num_batches = len(data_loader)
+
+    start_time = time.time()
+
+    for i, (images, targets) in enumerate(data_loader):
+        # Move images and targets to device
+        images = images.to(device)
+        targets = targets.to(device)
+
+        # Forward pass
+        outputs = model(images)
+        loss = criterion(outputs, targets)
+
+        # Backward pass
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Track loss
+        total_loss += loss.item()
+
+        # Print progress
+        if i % print_freq == 0:
+            elapsed_time = time.time() - start_time
+            if verbose:
+                print(
+                    f"Epoch: {epoch}, Batch: {i}/{num_batches}, Loss: {loss.item():.4f}, Time: {elapsed_time:.2f}s"
+                )
+            start_time = time.time()
+
+    # Calculate average loss
+    avg_loss = total_loss / num_batches
+    return avg_loss
+
+
+def evaluate_semantic(model, data_loader, device, criterion, num_classes=2):
+    """
+    Evaluate the semantic segmentation model on the validation set.
+
+    Args:
+        model (torch.nn.Module): The model to evaluate.
+        data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
+        device (torch.device): Device to evaluate on.
+        criterion: Loss function.
+        num_classes (int): Number of classes for evaluation metrics.
+
+    Returns:
+        dict: Evaluation metrics including loss, IoU, and Dice.
+    """
+    model.eval()
+
+    total_loss = 0
+    dice_scores = []
+    iou_scores = []
+    num_batches = len(data_loader)
+
+    with torch.no_grad():
+        for images, targets in data_loader:
+            # Move to device
+            images = images.to(device)
+            targets = targets.to(device)
+
+            # Forward pass
+            outputs = model(images)
+            loss = criterion(outputs, targets)
+            total_loss += loss.item()
+
+            # Calculate metrics for each sample in the batch
+            for pred, target in zip(outputs, targets):
+                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                iou = iou_coefficient(pred, target, num_classes=num_classes)
+                dice_scores.append(dice)
+                iou_scores.append(iou)
+
+    # Calculate metrics
+    avg_loss = total_loss / num_batches
+    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
+
+    return {"loss": avg_loss, "Dice": avg_dice, "IoU": avg_iou}
+
+
+def train_segmentation_model(
+    images_dir,
+    labels_dir,
+    output_dir,
+    architecture="unet",
+    encoder_name="resnet34",
+    encoder_weights="imagenet",
+    num_channels=3,
+    num_classes=2,
+    batch_size=8,
+    num_epochs=50,
+    learning_rate=0.001,
+    weight_decay=1e-4,
+    seed=42,
+    val_split=0.2,
+    print_freq=10,
+    verbose=True,
+    save_best_only=True,
+    plot_curves=False,
+    **kwargs,
+):
+    """
+    Train a semantic segmentation model for object detection using segmentation-models-pytorch.
+
+    This function trains a semantic segmentation model for object detection (e.g., building detection)
+    using models from the segmentation-models-pytorch library. Unlike instance segmentation (Mask R-CNN),
+    this approach treats the task as pixel-level binary classification.
+
+    Args:
+        images_dir (str): Directory containing image GeoTIFF files.
+        labels_dir (str): Directory containing label GeoTIFF files.
+        output_dir (str): Directory to save model checkpoints and results.
+        architecture (str): Model architecture ('unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
+            'pspnet', 'linknet', 'manet'). Defaults to 'unet'.
+        encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'resnet50', 'efficientnet-b0').
+            Defaults to 'resnet34'.
+        encoder_weights (str): Encoder pretrained weights ('imagenet' or None). Defaults to 'imagenet'.
+        num_channels (int): Number of input channels. Defaults to 3.
+        num_classes (int): Number of output classes (typically 2 for binary segmentation). Defaults to 2.
+        batch_size (int): Batch size for training. Defaults to 8.
+        num_epochs (int): Number of training epochs. Defaults to 50.
+        learning_rate (float): Initial learning rate. Defaults to 0.001.
+        weight_decay (float): Weight decay for optimizer. Defaults to 1e-4.
+        seed (int): Random seed for reproducibility. Defaults to 42.
+        val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
+        print_freq (int): Frequency of printing training progress. Defaults to 10.
+        verbose (bool): If True, prints detailed training progress. Defaults to True.
+        save_best_only (bool): If True, only saves the best model. Otherwise saves all checkpoints.
+            Defaults to True.
+        plot_curves (bool): If True, plots training curves. Defaults to False.
+        **kwargs: Additional arguments passed to smp.create_model().
+    Returns:
+        None: Model weights are saved to output_dir.
+
+    Raises:
+        ImportError: If segmentation-models-pytorch is not installed.
+        FileNotFoundError: If input directories don't exist or contain no matching files.
+    """
+    import datetime
+
+    if not SMP_AVAILABLE:
+        raise ImportError(
+            "segmentation-models-pytorch is not installed. "
+            "Please install it with: pip install segmentation-models-pytorch"
+        )
+
+    # Set random seeds for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Get device
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    print(f"Using device: {device}")
+
+    # Get all image and label files
+    image_files = sorted(
+        [
+            os.path.join(images_dir, f)
+            for f in os.listdir(images_dir)
+            if f.endswith(".tif")
+        ]
+    )
+    label_files = sorted(
+        [
+            os.path.join(labels_dir, f)
+            for f in os.listdir(labels_dir)
+            if f.endswith(".tif")
+        ]
+    )
+
+    print(f"Found {len(image_files)} image files and {len(label_files)} label files")
+
+    # Ensure matching files
+    if len(image_files) != len(label_files):
+        print("Warning: Number of image files and label files don't match!")
+        # Find matching files by basename
+        basenames = [os.path.basename(f) for f in image_files]
+        label_files = [
+            os.path.join(labels_dir, os.path.basename(f))
+            for f in image_files
+            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+        ]
+        image_files = [
+            f
+            for f, b in zip(image_files, basenames)
+            if os.path.exists(os.path.join(labels_dir, b))
+        ]
+        print(f"Using {len(image_files)} matching files")
+
+    if len(image_files) == 0:
+        raise FileNotFoundError("No matching image and label files found")
+
+    # Split data into train and validation sets
+    train_imgs, val_imgs, train_labels, val_labels = train_test_split(
+        image_files, label_files, test_size=val_split, random_state=seed
+    )
+
+    print(f"Training on {len(train_imgs)} images, validating on {len(val_imgs)} images")
+
+    # Create datasets
+    train_dataset = SemanticSegmentationDataset(
+        train_imgs,
+        train_labels,
+        transforms=get_semantic_transform(train=True),
+        num_channels=num_channels,
+    )
+    val_dataset = SemanticSegmentationDataset(
+        val_imgs,
+        val_labels,
+        transforms=get_semantic_transform(train=False),
+        num_channels=num_channels,
+    )
+
+    # Create data loaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=4,
+        pin_memory=True,
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+    )
+
+    # Initialize model
+    model = get_smp_model(
+        architecture=architecture,
+        encoder_name=encoder_name,
+        encoder_weights=encoder_weights,
+        in_channels=num_channels,
+        classes=num_classes,
+        activation=None,  # We'll apply softmax later
+        **kwargs,
+    )
+    model.to(device)
+
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    # Set up optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=learning_rate, weight_decay=weight_decay
+    )
+
+    # Set up learning rate scheduler
+    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5, verbose=True
+    )
+
+    # Initialize tracking variables
+    best_iou = 0
+    train_losses = []
+    val_losses = []
+    val_ious = []
+    val_dices = []
+
+    print(f"Starting training with {architecture} + {encoder_name}")
+    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    # Training loop
+    for epoch in range(num_epochs):
+        # Train one epoch
+        train_loss = train_semantic_one_epoch(
+            model,
+            optimizer,
+            train_loader,
+            device,
+            epoch,
+            criterion,
+            print_freq,
+            verbose,
+        )
+        train_losses.append(train_loss)
+
+        # Evaluate on validation set
+        eval_metrics = evaluate_semantic(
+            model, val_loader, device, criterion, num_classes=num_classes
+        )
+        val_losses.append(eval_metrics["loss"])
+        val_ious.append(eval_metrics["IoU"])
+        val_dices.append(eval_metrics["Dice"])
+
+        # Update learning rate
+        lr_scheduler.step(eval_metrics["loss"])
+
+        # Print metrics
+        print(
+            f"Epoch {epoch+1}/{num_epochs}: "
+            f"Train Loss: {train_loss:.4f}, "
+            f"Val Loss: {eval_metrics['loss']:.4f}, "
+            f"Val IoU: {eval_metrics['IoU']:.4f}, "
+            f"Val Dice: {eval_metrics['Dice']:.4f}"
+        )
+
+        # Save best model
+        if eval_metrics["IoU"] > best_iou:
+            best_iou = eval_metrics["IoU"]
+            print(f"Saving best model with IoU: {best_iou:.4f}")
+            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+
+        # Save checkpoint every 10 epochs (if not save_best_only)
+        if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
+            torch.save(
+                {
+                    "epoch": epoch,
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": lr_scheduler.state_dict(),
+                    "best_iou": best_iou,
+                    "architecture": architecture,
+                    "encoder_name": encoder_name,
+                    "num_channels": num_channels,
+                    "num_classes": num_classes,
+                },
+                os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
+            )
+
+    # Save final model
+    torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
+
+    # Save training history
+    history = {
+        "train_losses": train_losses,
+        "val_losses": val_losses,
+        "val_ious": val_ious,
+        "val_dices": val_dices,
+    }
+    torch.save(history, os.path.join(output_dir, "training_history.pth"))
+
+    # Save training summary
+    with open(os.path.join(output_dir, "training_summary.txt"), "w") as f:
+        f.write(
+            f"Training completed on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
+        )
+        f.write(f"Architecture: {architecture}\n")
+        f.write(f"Encoder: {encoder_name}\n")
+        f.write(f"Total epochs: {num_epochs}\n")
+        f.write(f"Best validation IoU: {best_iou:.4f}\n")
+        f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
+        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
+
+    print(f"Training complete! Best IoU: {best_iou:.4f}")
+    print(f"Models saved to {output_dir}")
+
+    # Plot training curves
+    if plot_curves:
+        try:
+            plt.figure(figsize=(15, 5))
+
+            plt.subplot(1, 3, 1)
+            plt.plot(train_losses, label="Train Loss")
+            plt.plot(val_losses, label="Val Loss")
+            plt.title("Loss")
+            plt.xlabel("Epoch")
+            plt.ylabel("Loss")
+            plt.legend()
+            plt.grid(True)
+
+            plt.subplot(1, 3, 2)
+            plt.plot(val_ious, label="Val IoU")
+            plt.title("IoU Score")
+            plt.xlabel("Epoch")
+            plt.ylabel("IoU")
+            plt.legend()
+            plt.grid(True)
+
+            plt.subplot(1, 3, 3)
+            plt.plot(val_dices, label="Val Dice")
+            plt.title("Dice Score")
+            plt.xlabel("Epoch")
+            plt.ylabel("Dice")
+            plt.legend()
+            plt.grid(True)
+
+            plt.tight_layout()
+            plt.savefig(
+                os.path.join(output_dir, "training_curves.png"),
+                dpi=150,
+                bbox_inches="tight",
+            )
+            print(
+                f"Training curves saved to {os.path.join(output_dir, 'training_curves.png')}"
+            )
+            plt.close()
+        except Exception as e:
+            print(f"Could not save training curves: {e}")
+
+
+def semantic_inference_on_geotiff(
+    model,
+    geotiff_path,
+    output_path,
+    window_size=512,
+    overlap=256,
+    batch_size=4,
+    num_channels=3,
+    num_classes=2,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform semantic segmentation inference on a large GeoTIFF using a sliding window approach.
+
+    Args:
+        model (torch.nn.Module): Trained semantic segmentation model.
+        geotiff_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output mask GeoTIFF.
+        window_size (int): Size of sliding window for inference.
+        overlap (int): Overlap between adjacent windows.
+        batch_size (int): Batch size for inference.
+        num_channels (int): Number of channels to use from the input image.
+        num_classes (int): Number of classes in the model output.
+        device (torch.device, optional): Device to run inference on.
+        **kwargs: Additional arguments.
+
+    Returns:
+        tuple: Tuple containing output path and inference time in seconds.
+    """
+    if device is None:
+        device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+    # Put model in evaluation mode
+    model.to(device)
+    model.eval()
+
+    # Open the GeoTIFF
+    with rasterio.open(geotiff_path) as src:
+        # Read metadata
+        meta = src.meta
+        height = src.height
+        width = src.width
+
+        # Update metadata for output raster
+        out_meta = meta.copy()
+        out_meta.update({"count": 1, "dtype": "uint8"})
+
+        # Initialize accumulator arrays for multi-class probability blending
+        # We'll accumulate probabilities for each class and then take argmax
+        prob_accumulator = np.zeros((num_classes, height, width), dtype=np.float32)
+        count_accumulator = np.zeros((height, width), dtype=np.float32)
+
+        # Calculate steps
+        steps_y = math.ceil((height - overlap) / (window_size - overlap))
+        steps_x = math.ceil((width - overlap) / (window_size - overlap))
+        last_y = height - window_size
+        last_x = width - window_size
+
+        total_windows = steps_y * steps_x
+        print(f"Processing {total_windows} windows...")
+
+        pbar = tqdm(total=total_windows)
+        batch_inputs = []
+        batch_positions = []
+        batch_count = 0
+
+        start_time = time.time()
+
+        for i in range(steps_y + 1):
+            y = min(i * (window_size - overlap), last_y)
+            y = max(0, y)
+
+            if y > last_y and i > 0:
+                continue
+
+            for j in range(steps_x + 1):
+                x = min(j * (window_size - overlap), last_x)
+                x = max(0, x)
+
+                if x > last_x and j > 0:
+                    continue
+
+                # Read window
+                window = src.read(window=Window(x, y, window_size, window_size))
+
+                if window.shape[1] == 0 or window.shape[2] == 0:
+                    continue
+
+                current_height = window.shape[1]
+                current_width = window.shape[2]
+
+                # Normalize and prepare input
+                image = window.astype(np.float32) / 255.0
+
+                # Handle different number of bands
+                if image.shape[0] > num_channels:
+                    image = image[:num_channels]
+                elif image.shape[0] < num_channels:
+                    padded = np.zeros(
+                        (num_channels, current_height, current_width), dtype=np.float32
+                    )
+                    padded[: image.shape[0]] = image
+                    image = padded
+
+                # Convert to tensor
+                image_tensor = torch.tensor(image, device=device)
+
+                # Add to batch
+                batch_inputs.append(image_tensor)
+                batch_positions.append((y, x, current_height, current_width))
+                batch_count += 1
+
+                # Process batch
+                if batch_count == batch_size or (i == steps_y and j == steps_x):
+                    with torch.no_grad():
+                        batch_tensor = torch.stack(batch_inputs)
+                        outputs = model(batch_tensor)
+
+                        # Apply softmax to get class probabilities
+                        probs = torch.softmax(outputs, dim=1)
+
+                        # Process each output in the batch
+                        for idx, prob in enumerate(probs):
+                            y_pos, x_pos, h, w = batch_positions[idx]
+
+                            # Create weight matrix for blending
+                            y_grid, x_grid = np.mgrid[0:h, 0:w]
+                            dist_from_left = x_grid
+                            dist_from_right = w - x_grid - 1
+                            dist_from_top = y_grid
+                            dist_from_bottom = h - y_grid - 1
+
+                            edge_distance = np.minimum.reduce(
+                                [
+                                    dist_from_left,
+                                    dist_from_right,
+                                    dist_from_top,
+                                    dist_from_bottom,
+                                ]
+                            )
+                            edge_distance = np.minimum(edge_distance, overlap / 2)
+
+                            # Avoid zero weights - use minimum weight of 0.1
+                            weight = np.maximum(edge_distance / (overlap / 2), 0.1)
+
+                            # For non-overlapping windows, use uniform weight
+                            if overlap == 0:
+                                weight = np.ones_like(weight)
+
+                            # Convert probabilities to numpy [C, H, W]
+                            prob_np = prob.cpu().numpy()
+
+                            # Accumulate weighted probabilities for each class
+                            y_slice = slice(y_pos, y_pos + h)
+                            x_slice = slice(x_pos, x_pos + w)
+
+                            # Add weighted probabilities for each class
+                            for class_idx in range(num_classes):
+                                prob_accumulator[class_idx, y_slice, x_slice] += (
+                                    prob_np[class_idx] * weight
+                                )
+
+                            # Update weight accumulator
+                            count_accumulator[y_slice, x_slice] += weight
+
+                    # Reset batch
+                    batch_inputs = []
+                    batch_positions = []
+                    batch_count = 0
+                    pbar.update(len(probs))
+
+    pbar.close()
+
+    # Calculate final mask by taking argmax of accumulated probabilities
+    mask = np.zeros((height, width), dtype=np.uint8)
+    valid_pixels = count_accumulator > 0
+
+    if np.any(valid_pixels):
+        # Normalize accumulated probabilities by weights
+        normalized_probs = np.zeros_like(prob_accumulator)
+        for class_idx in range(num_classes):
+            normalized_probs[class_idx, valid_pixels] = (
+                prob_accumulator[class_idx, valid_pixels]
+                / count_accumulator[valid_pixels]
+            )
+
+        # Take argmax to get final class predictions
+        mask[valid_pixels] = np.argmax(
+            normalized_probs[:, valid_pixels], axis=0
+        ).astype(np.uint8)
+
+        # Check class distribution in predictions (summary only)
+        unique_classes, class_counts = np.unique(
+            mask[valid_pixels], return_counts=True
+        )
+        bg_ratio = np.sum(mask == 0) / mask.size
+        print(
+            f"Predicted classes: {len(unique_classes)} classes, Background: {bg_ratio:.1%}"
+        )
+
+    inference_time = time.time() - start_time
+    print(f"Inference completed in {inference_time:.2f} seconds")
+
+    # Save output
+    with rasterio.open(output_path, "w", **out_meta) as dst:
+        dst.write(mask, 1)
+
+    print(f"Saved prediction to {output_path}")
+
+    return output_path, inference_time
+
+
+def semantic_segmentation(
+    input_path,
+    output_path,
+    model_path,
+    architecture="unet",
+    encoder_name="resnet34",
+    num_channels=3,
+    num_classes=2,
+    window_size=512,
+    overlap=256,
+    batch_size=4,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform semantic segmentation on a GeoTIFF using a trained model.
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output mask GeoTIFF.
+        model_path (str): Path to trained model weights.
+        architecture (str): Model architecture used for training.
+        encoder_name (str): Encoder backbone name used for training.
+        num_channels (int): Number of channels in the input image and model.
+        num_classes (int): Number of classes in the model.
+        window_size (int): Size of sliding window for inference.
+        overlap (int): Overlap between adjacent windows.
+        batch_size (int): Batch size for inference.
+        device (torch.device, optional): Device to run inference on.
+        **kwargs: Additional arguments.
+
+    Returns:
+        None: Output mask is saved to output_path.
+    """
+    if device is None:
+        device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+    # Load model
+    model = get_smp_model(
+        architecture=architecture,
+        encoder_name=encoder_name,
+        encoder_weights=None,  # We're loading trained weights
+        in_channels=num_channels,
+        classes=num_classes,
+        activation=None,
+    )
+
+    if not os.path.exists(model_path):
+        try:
+            model_path = download_model_from_hf(model_path)
+        except Exception as e:
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+    model.eval()
+
+    semantic_inference_on_geotiff(
+        model=model,
+        geotiff_path=input_path,
+        output_path=output_path,
+        window_size=window_size,
+        overlap=overlap,
+        batch_size=batch_size,
+        num_channels=num_channels,
+        num_classes=num_classes,
+        device=device,
+        **kwargs,
+    )