geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/change_detection.py +1568 -0
- geoai/classify.py +58 -57
- geoai/detectron2.py +466 -0
- geoai/download.py +74 -68
- geoai/extract.py +186 -141
- geoai/geoai.py +13 -11
- geoai/hf.py +14 -12
- geoai/segment.py +44 -39
- geoai/segmentation.py +10 -9
- geoai/train.py +372 -241
- geoai/utils.py +198 -123
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/METADATA +5 -1
- geoai_py-0.9.1.dist-info/RECORD +19 -0
- geoai_py-0.8.3.dist-info/RECORD +0 -17
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/top_level.txt +0 -0
geoai/train.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4
4
|
import platform
|
5
5
|
import random
|
6
6
|
import time
|
7
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
7
8
|
|
8
9
|
import matplotlib.pyplot as plt
|
9
10
|
import numpy as np
|
@@ -34,7 +35,9 @@ except ImportError:
|
|
34
35
|
SMP_AVAILABLE = False
|
35
36
|
|
36
37
|
|
37
|
-
def get_instance_segmentation_model(
|
38
|
+
def get_instance_segmentation_model(
|
39
|
+
num_classes: int = 2, num_channels: int = 3, pretrained: bool = True
|
40
|
+
) -> torch.nn.Module:
|
38
41
|
"""
|
39
42
|
Get Mask R-CNN model with custom input channels and output classes.
|
40
43
|
|
@@ -129,7 +132,13 @@ def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=Tr
|
|
129
132
|
class ObjectDetectionDataset(Dataset):
|
130
133
|
"""Dataset for object detection from GeoTIFF images and labels."""
|
131
134
|
|
132
|
-
def __init__(
|
135
|
+
def __init__(
|
136
|
+
self,
|
137
|
+
image_paths: List[str],
|
138
|
+
label_paths: List[str],
|
139
|
+
transforms: Optional[Callable] = None,
|
140
|
+
num_channels: Optional[int] = None,
|
141
|
+
) -> None:
|
133
142
|
"""
|
134
143
|
Initialize dataset.
|
135
144
|
|
@@ -151,10 +160,10 @@ class ObjectDetectionDataset(Dataset):
|
|
151
160
|
else:
|
152
161
|
self.num_channels = num_channels
|
153
162
|
|
154
|
-
def __len__(self):
|
163
|
+
def __len__(self) -> int:
|
155
164
|
return len(self.image_paths)
|
156
165
|
|
157
|
-
def __getitem__(self, idx):
|
166
|
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
158
167
|
# Load image
|
159
168
|
with rasterio.open(self.image_paths[idx]) as src:
|
160
169
|
# Read as [C, H, W] format
|
@@ -270,7 +279,7 @@ class ObjectDetectionDataset(Dataset):
|
|
270
279
|
class Compose:
|
271
280
|
"""Custom compose transform that works with image and target."""
|
272
281
|
|
273
|
-
def __init__(self, transforms):
|
282
|
+
def __init__(self, transforms: List[Callable]) -> None:
|
274
283
|
"""
|
275
284
|
Initialize compose transform.
|
276
285
|
|
@@ -279,7 +288,9 @@ class Compose:
|
|
279
288
|
"""
|
280
289
|
self.transforms = transforms
|
281
290
|
|
282
|
-
def __call__(
|
291
|
+
def __call__(
|
292
|
+
self, image: torch.Tensor, target: Dict[str, torch.Tensor]
|
293
|
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
283
294
|
for t in self.transforms:
|
284
295
|
image, target = t(image, target)
|
285
296
|
return image, target
|
@@ -288,7 +299,9 @@ class Compose:
|
|
288
299
|
class ToTensor:
|
289
300
|
"""Convert numpy.ndarray to tensor."""
|
290
301
|
|
291
|
-
def __call__(
|
302
|
+
def __call__(
|
303
|
+
self, image: torch.Tensor, target: Dict[str, torch.Tensor]
|
304
|
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
292
305
|
"""
|
293
306
|
Apply transform to image and target.
|
294
307
|
|
@@ -305,7 +318,7 @@ class ToTensor:
|
|
305
318
|
class RandomHorizontalFlip:
|
306
319
|
"""Random horizontal flip transform."""
|
307
320
|
|
308
|
-
def __init__(self, prob=0.5):
|
321
|
+
def __init__(self, prob: float = 0.5) -> None:
|
309
322
|
"""
|
310
323
|
Initialize random horizontal flip.
|
311
324
|
|
@@ -314,7 +327,9 @@ class RandomHorizontalFlip:
|
|
314
327
|
"""
|
315
328
|
self.prob = prob
|
316
329
|
|
317
|
-
def __call__(
|
330
|
+
def __call__(
|
331
|
+
self, image: torch.Tensor, target: Dict[str, torch.Tensor]
|
332
|
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
318
333
|
if random.random() < self.prob:
|
319
334
|
# Flip image
|
320
335
|
image = torch.flip(image, dims=[2]) # Flip along width dimension
|
@@ -333,7 +348,7 @@ class RandomHorizontalFlip:
|
|
333
348
|
return image, target
|
334
349
|
|
335
350
|
|
336
|
-
def get_transform(train):
|
351
|
+
def get_transform(train: bool) -> torchvision.transforms.Compose:
|
337
352
|
"""
|
338
353
|
Get transforms for data augmentation.
|
339
354
|
|
@@ -352,7 +367,9 @@ def get_transform(train):
|
|
352
367
|
return Compose(transforms)
|
353
368
|
|
354
369
|
|
355
|
-
def collate_fn(
|
370
|
+
def collate_fn(
|
371
|
+
batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]],
|
372
|
+
) -> Tuple[Tuple[torch.Tensor, ...], Tuple[Dict[str, torch.Tensor], ...]]:
|
356
373
|
"""
|
357
374
|
Custom collate function for batching samples.
|
358
375
|
|
@@ -366,8 +383,14 @@ def collate_fn(batch):
|
|
366
383
|
|
367
384
|
|
368
385
|
def train_one_epoch(
|
369
|
-
model
|
370
|
-
|
386
|
+
model: torch.nn.Module,
|
387
|
+
optimizer: torch.optim.Optimizer,
|
388
|
+
data_loader: DataLoader,
|
389
|
+
device: torch.device,
|
390
|
+
epoch: int,
|
391
|
+
print_freq: int = 10,
|
392
|
+
verbose: bool = True,
|
393
|
+
) -> float:
|
371
394
|
"""
|
372
395
|
Train the model for one epoch.
|
373
396
|
|
@@ -419,7 +442,9 @@ def train_one_epoch(
|
|
419
442
|
return avg_loss
|
420
443
|
|
421
444
|
|
422
|
-
def evaluate(
|
445
|
+
def evaluate(
|
446
|
+
model: torch.nn.Module, data_loader: DataLoader, device: torch.device
|
447
|
+
) -> Dict[str, float]:
|
423
448
|
"""
|
424
449
|
Evaluate the model on the validation set.
|
425
450
|
|
@@ -495,7 +520,13 @@ def evaluate(model, data_loader, device):
|
|
495
520
|
return {"loss": avg_loss, "IoU": avg_iou}
|
496
521
|
|
497
522
|
|
498
|
-
def visualize_predictions(
|
523
|
+
def visualize_predictions(
|
524
|
+
model: torch.nn.Module,
|
525
|
+
dataset: Dataset,
|
526
|
+
device: torch.device,
|
527
|
+
num_samples: int = 5,
|
528
|
+
output_dir: Optional[str] = None,
|
529
|
+
) -> None:
|
499
530
|
"""
|
500
531
|
Visualize model predictions.
|
501
532
|
|
@@ -583,24 +614,25 @@ def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None
|
|
583
614
|
|
584
615
|
|
585
616
|
def train_MaskRCNN_model(
|
586
|
-
images_dir,
|
587
|
-
labels_dir,
|
588
|
-
output_dir,
|
589
|
-
num_channels=3,
|
590
|
-
model=None,
|
591
|
-
pretrained=True,
|
592
|
-
pretrained_model_path=None,
|
593
|
-
batch_size=4,
|
594
|
-
num_epochs=10,
|
595
|
-
learning_rate=0.005,
|
596
|
-
seed=42,
|
597
|
-
val_split=0.2,
|
598
|
-
visualize=False,
|
599
|
-
resume_training=False,
|
600
|
-
print_freq=10,
|
601
|
-
device=None,
|
602
|
-
|
603
|
-
|
617
|
+
images_dir: str,
|
618
|
+
labels_dir: str,
|
619
|
+
output_dir: str,
|
620
|
+
num_channels: int = 3,
|
621
|
+
model: Optional[torch.nn.Module] = None,
|
622
|
+
pretrained: bool = True,
|
623
|
+
pretrained_model_path: Optional[str] = None,
|
624
|
+
batch_size: int = 4,
|
625
|
+
num_epochs: int = 10,
|
626
|
+
learning_rate: float = 0.005,
|
627
|
+
seed: int = 42,
|
628
|
+
val_split: float = 0.2,
|
629
|
+
visualize: bool = False,
|
630
|
+
resume_training: bool = False,
|
631
|
+
print_freq: int = 10,
|
632
|
+
device: Optional[torch.device] = None,
|
633
|
+
num_workers: Optional[int] = None,
|
634
|
+
verbose: bool = True,
|
635
|
+
) -> torch.nn.Module:
|
604
636
|
"""Train and evaluate Mask R-CNN model for instance segmentation.
|
605
637
|
|
606
638
|
This function trains a Mask R-CNN model for instance segmentation using the
|
@@ -629,6 +661,7 @@ def train_MaskRCNN_model(
|
|
629
661
|
will try to load optimizer and scheduler states as well. Defaults to False.
|
630
662
|
print_freq (int): Frequency of printing training progress. Defaults to 10.
|
631
663
|
device (torch.device): Device to train on. If None, uses CUDA if available.
|
664
|
+
num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
|
632
665
|
verbose (bool): If True, prints detailed training progress. Defaults to True.
|
633
666
|
Returns:
|
634
667
|
None: Model weights are saved to output_dir.
|
@@ -712,7 +745,9 @@ def train_MaskRCNN_model(
|
|
712
745
|
# Create data loaders
|
713
746
|
# Use num_workers=0 on macOS and Windows to avoid multiprocessing issues
|
714
747
|
# Windows often has issues with multiprocessing in Jupyter notebooks
|
715
|
-
|
748
|
+
# Increase num_workers for better data loading performance
|
749
|
+
if num_workers is None:
|
750
|
+
num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 8
|
716
751
|
|
717
752
|
train_loader = DataLoader(
|
718
753
|
train_dataset,
|
@@ -872,17 +907,17 @@ def train_MaskRCNN_model(
|
|
872
907
|
|
873
908
|
|
874
909
|
def inference_on_geotiff(
|
875
|
-
model,
|
876
|
-
geotiff_path,
|
877
|
-
output_path,
|
878
|
-
window_size=512,
|
879
|
-
overlap=256,
|
880
|
-
confidence_threshold=0.5,
|
881
|
-
batch_size=4,
|
882
|
-
num_channels=3,
|
883
|
-
device=None,
|
884
|
-
**kwargs,
|
885
|
-
):
|
910
|
+
model: torch.nn.Module,
|
911
|
+
geotiff_path: str,
|
912
|
+
output_path: str,
|
913
|
+
window_size: int = 512,
|
914
|
+
overlap: int = 256,
|
915
|
+
confidence_threshold: float = 0.5,
|
916
|
+
batch_size: int = 4,
|
917
|
+
num_channels: int = 3,
|
918
|
+
device: Optional[torch.device] = None,
|
919
|
+
**kwargs: Any,
|
920
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
886
921
|
"""
|
887
922
|
Perform inference on a large GeoTIFF using a sliding window approach with improved blending.
|
888
923
|
|
@@ -1096,17 +1131,17 @@ def inference_on_geotiff(
|
|
1096
1131
|
|
1097
1132
|
|
1098
1133
|
def instance_segmentation_inference_on_geotiff(
|
1099
|
-
model,
|
1100
|
-
geotiff_path,
|
1101
|
-
output_path,
|
1102
|
-
window_size=512,
|
1103
|
-
overlap=256,
|
1104
|
-
confidence_threshold=0.5,
|
1105
|
-
batch_size=4,
|
1106
|
-
num_channels=3,
|
1107
|
-
device=None,
|
1108
|
-
**kwargs,
|
1109
|
-
):
|
1134
|
+
model: torch.nn.Module,
|
1135
|
+
geotiff_path: str,
|
1136
|
+
output_path: str,
|
1137
|
+
window_size: int = 512,
|
1138
|
+
overlap: int = 256,
|
1139
|
+
confidence_threshold: float = 0.5,
|
1140
|
+
batch_size: int = 4,
|
1141
|
+
num_channels: int = 3,
|
1142
|
+
device: Optional[torch.device] = None,
|
1143
|
+
**kwargs: Any,
|
1144
|
+
) -> Tuple[str, float]:
|
1110
1145
|
"""
|
1111
1146
|
Perform instance segmentation inference on a large GeoTIFF using a sliding window approach.
|
1112
1147
|
|
@@ -1327,19 +1362,19 @@ def instance_segmentation_inference_on_geotiff(
|
|
1327
1362
|
|
1328
1363
|
|
1329
1364
|
def object_detection(
|
1330
|
-
input_path,
|
1331
|
-
output_path,
|
1332
|
-
model_path,
|
1333
|
-
window_size=512,
|
1334
|
-
overlap=256,
|
1335
|
-
confidence_threshold=0.5,
|
1336
|
-
batch_size=4,
|
1337
|
-
num_channels=3,
|
1338
|
-
model=None,
|
1339
|
-
pretrained=True,
|
1340
|
-
device=None,
|
1341
|
-
**kwargs,
|
1342
|
-
):
|
1365
|
+
input_path: str,
|
1366
|
+
output_path: str,
|
1367
|
+
model_path: str,
|
1368
|
+
window_size: int = 512,
|
1369
|
+
overlap: int = 256,
|
1370
|
+
confidence_threshold: float = 0.5,
|
1371
|
+
batch_size: int = 4,
|
1372
|
+
num_channels: int = 3,
|
1373
|
+
model: Optional[torch.nn.Module] = None,
|
1374
|
+
pretrained: bool = True,
|
1375
|
+
device: Optional[torch.device] = None,
|
1376
|
+
**kwargs: Any,
|
1377
|
+
) -> None:
|
1343
1378
|
"""
|
1344
1379
|
Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
|
1345
1380
|
|
@@ -1374,7 +1409,16 @@ def object_detection(
|
|
1374
1409
|
except Exception as e:
|
1375
1410
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
1376
1411
|
|
1377
|
-
|
1412
|
+
# Load state dict and handle DataParallel module prefix
|
1413
|
+
state_dict = torch.load(model_path, map_location=device)
|
1414
|
+
|
1415
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
1416
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
1417
|
+
state_dict = {
|
1418
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
1419
|
+
}
|
1420
|
+
|
1421
|
+
model.load_state_dict(state_dict)
|
1378
1422
|
model.to(device)
|
1379
1423
|
model.eval()
|
1380
1424
|
|
@@ -1393,20 +1437,20 @@ def object_detection(
|
|
1393
1437
|
|
1394
1438
|
|
1395
1439
|
def object_detection_batch(
|
1396
|
-
input_paths,
|
1397
|
-
output_dir,
|
1398
|
-
model_path,
|
1399
|
-
filenames=None,
|
1400
|
-
window_size=512,
|
1401
|
-
overlap=256,
|
1402
|
-
confidence_threshold=0.5,
|
1403
|
-
batch_size=4,
|
1404
|
-
model=None,
|
1405
|
-
num_channels=3,
|
1406
|
-
pretrained=True,
|
1407
|
-
device=None,
|
1408
|
-
**kwargs,
|
1409
|
-
):
|
1440
|
+
input_paths: Union[str, List[str]],
|
1441
|
+
output_dir: str,
|
1442
|
+
model_path: str,
|
1443
|
+
filenames: Optional[List[str]] = None,
|
1444
|
+
window_size: int = 512,
|
1445
|
+
overlap: int = 256,
|
1446
|
+
confidence_threshold: float = 0.5,
|
1447
|
+
batch_size: int = 4,
|
1448
|
+
model: Optional[torch.nn.Module] = None,
|
1449
|
+
num_channels: int = 3,
|
1450
|
+
pretrained: bool = True,
|
1451
|
+
device: Optional[torch.device] = None,
|
1452
|
+
**kwargs: Any,
|
1453
|
+
) -> None:
|
1410
1454
|
"""
|
1411
1455
|
Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
|
1412
1456
|
|
@@ -1449,7 +1493,16 @@ def object_detection_batch(
|
|
1449
1493
|
except Exception as e:
|
1450
1494
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
1451
1495
|
|
1452
|
-
|
1496
|
+
# Load state dict and handle DataParallel module prefix
|
1497
|
+
state_dict = torch.load(model_path, map_location=device)
|
1498
|
+
|
1499
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
1500
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
1501
|
+
state_dict = {
|
1502
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
1503
|
+
}
|
1504
|
+
|
1505
|
+
model.load_state_dict(state_dict)
|
1453
1506
|
model.to(device)
|
1454
1507
|
model.eval()
|
1455
1508
|
|
@@ -1489,14 +1542,14 @@ class SemanticSegmentationDataset(Dataset):
|
|
1489
1542
|
|
1490
1543
|
def __init__(
|
1491
1544
|
self,
|
1492
|
-
image_paths,
|
1493
|
-
label_paths,
|
1494
|
-
transforms=None,
|
1495
|
-
num_channels=None,
|
1496
|
-
target_size=None,
|
1497
|
-
resize_mode="resize",
|
1498
|
-
num_classes=2,
|
1499
|
-
):
|
1545
|
+
image_paths: List[str],
|
1546
|
+
label_paths: List[str],
|
1547
|
+
transforms: Optional[Callable] = None,
|
1548
|
+
num_channels: Optional[int] = None,
|
1549
|
+
target_size: Optional[Tuple[int, int]] = None,
|
1550
|
+
resize_mode: str = "resize",
|
1551
|
+
num_classes: int = 2,
|
1552
|
+
) -> None:
|
1500
1553
|
"""
|
1501
1554
|
Initialize dataset for semantic segmentation.
|
1502
1555
|
|
@@ -1526,11 +1579,11 @@ class SemanticSegmentationDataset(Dataset):
|
|
1526
1579
|
else:
|
1527
1580
|
self.num_channels = num_channels
|
1528
1581
|
|
1529
|
-
def _is_geotiff(self, file_path):
|
1582
|
+
def _is_geotiff(self, file_path: str) -> bool:
|
1530
1583
|
"""Check if file is a GeoTIFF based on extension."""
|
1531
1584
|
return file_path.lower().endswith((".tif", ".tiff"))
|
1532
1585
|
|
1533
|
-
def _get_num_channels(self, image_path):
|
1586
|
+
def _get_num_channels(self, image_path: str) -> int:
|
1534
1587
|
"""Get number of channels from an image file."""
|
1535
1588
|
if self._is_geotiff(image_path):
|
1536
1589
|
with rasterio.open(image_path) as src:
|
@@ -1548,7 +1601,9 @@ class SemanticSegmentationDataset(Dataset):
|
|
1548
1601
|
# Convert to RGB and return 3 channels
|
1549
1602
|
return 3
|
1550
1603
|
|
1551
|
-
def _resize_image_and_mask(
|
1604
|
+
def _resize_image_and_mask(
|
1605
|
+
self, image: np.ndarray, mask: np.ndarray
|
1606
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
1552
1607
|
"""Resize image and mask to target size."""
|
1553
1608
|
if self.target_size is None:
|
1554
1609
|
return image, mask
|
@@ -1586,7 +1641,9 @@ class SemanticSegmentationDataset(Dataset):
|
|
1586
1641
|
|
1587
1642
|
return image, mask
|
1588
1643
|
|
1589
|
-
def _pad_to_size(
|
1644
|
+
def _pad_to_size(
|
1645
|
+
self, tensor: torch.Tensor, target_size: Tuple[int, int]
|
1646
|
+
) -> torch.Tensor:
|
1590
1647
|
"""Pad tensor to target size with zeros."""
|
1591
1648
|
target_h, target_w = target_size
|
1592
1649
|
|
@@ -1618,10 +1675,10 @@ class SemanticSegmentationDataset(Dataset):
|
|
1618
1675
|
|
1619
1676
|
return padded
|
1620
1677
|
|
1621
|
-
def __len__(self):
|
1678
|
+
def __len__(self) -> int:
|
1622
1679
|
return len(self.image_paths)
|
1623
1680
|
|
1624
|
-
def __getitem__(self, idx):
|
1681
|
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
1625
1682
|
# Load image
|
1626
1683
|
image_path = self.image_paths[idx]
|
1627
1684
|
if self._is_geotiff(image_path):
|
@@ -1706,10 +1763,12 @@ class SemanticSegmentationDataset(Dataset):
|
|
1706
1763
|
class SemanticTransforms:
|
1707
1764
|
"""Custom transforms for semantic segmentation."""
|
1708
1765
|
|
1709
|
-
def __init__(self, transforms):
|
1766
|
+
def __init__(self, transforms: List[Callable]) -> None:
|
1710
1767
|
self.transforms = transforms
|
1711
1768
|
|
1712
|
-
def __call__(
|
1769
|
+
def __call__(
|
1770
|
+
self, image: torch.Tensor, mask: torch.Tensor
|
1771
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1713
1772
|
for t in self.transforms:
|
1714
1773
|
image, mask = t(image, mask)
|
1715
1774
|
return image, mask
|
@@ -1718,17 +1777,21 @@ class SemanticTransforms:
|
|
1718
1777
|
class SemanticToTensor:
|
1719
1778
|
"""Convert numpy.ndarray to tensor for semantic segmentation."""
|
1720
1779
|
|
1721
|
-
def __call__(
|
1780
|
+
def __call__(
|
1781
|
+
self, image: torch.Tensor, mask: torch.Tensor
|
1782
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1722
1783
|
return image, mask
|
1723
1784
|
|
1724
1785
|
|
1725
1786
|
class SemanticRandomHorizontalFlip:
|
1726
1787
|
"""Random horizontal flip transform for semantic segmentation."""
|
1727
1788
|
|
1728
|
-
def __init__(self, prob=0.5):
|
1789
|
+
def __init__(self, prob: float = 0.5) -> None:
|
1729
1790
|
self.prob = prob
|
1730
1791
|
|
1731
|
-
def __call__(
|
1792
|
+
def __call__(
|
1793
|
+
self, image: torch.Tensor, mask: torch.Tensor
|
1794
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1732
1795
|
if random.random() < self.prob:
|
1733
1796
|
# Flip image and mask along width dimension
|
1734
1797
|
image = torch.flip(image, dims=[2])
|
@@ -1736,7 +1799,7 @@ class SemanticRandomHorizontalFlip:
|
|
1736
1799
|
return image, mask
|
1737
1800
|
|
1738
1801
|
|
1739
|
-
def get_semantic_transform(train):
|
1802
|
+
def get_semantic_transform(train: bool) -> Any:
|
1740
1803
|
"""
|
1741
1804
|
Get transforms for semantic segmentation data augmentation.
|
1742
1805
|
|
@@ -1756,14 +1819,14 @@ def get_semantic_transform(train):
|
|
1756
1819
|
|
1757
1820
|
|
1758
1821
|
def get_smp_model(
|
1759
|
-
architecture="unet",
|
1760
|
-
encoder_name="resnet34",
|
1761
|
-
encoder_weights="imagenet",
|
1762
|
-
in_channels=3,
|
1763
|
-
classes=2,
|
1764
|
-
activation=None,
|
1765
|
-
**kwargs,
|
1766
|
-
):
|
1822
|
+
architecture: str = "unet",
|
1823
|
+
encoder_name: str = "resnet34",
|
1824
|
+
encoder_weights: Optional[str] = "imagenet",
|
1825
|
+
in_channels: int = 3,
|
1826
|
+
classes: int = 2,
|
1827
|
+
activation: Optional[str] = None,
|
1828
|
+
**kwargs: Any,
|
1829
|
+
) -> torch.nn.Module:
|
1767
1830
|
"""
|
1768
1831
|
Get a segmentation model from segmentation-models-pytorch using the generic create_model function.
|
1769
1832
|
|
@@ -1852,7 +1915,12 @@ def get_smp_model(
|
|
1852
1915
|
)
|
1853
1916
|
|
1854
1917
|
|
1855
|
-
def dice_coefficient(
|
1918
|
+
def dice_coefficient(
|
1919
|
+
pred: torch.Tensor,
|
1920
|
+
target: torch.Tensor,
|
1921
|
+
smooth: float = 1e-6,
|
1922
|
+
num_classes: Optional[int] = None,
|
1923
|
+
) -> float:
|
1856
1924
|
"""
|
1857
1925
|
Calculate Dice coefficient for segmentation (binary or multi-class).
|
1858
1926
|
|
@@ -1894,7 +1962,12 @@ def dice_coefficient(pred, target, smooth=1e-6, num_classes=None):
|
|
1894
1962
|
return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
|
1895
1963
|
|
1896
1964
|
|
1897
|
-
def iou_coefficient(
|
1965
|
+
def iou_coefficient(
|
1966
|
+
pred: torch.Tensor,
|
1967
|
+
target: torch.Tensor,
|
1968
|
+
smooth: float = 1e-6,
|
1969
|
+
num_classes: Optional[int] = None,
|
1970
|
+
) -> float:
|
1898
1971
|
"""
|
1899
1972
|
Calculate IoU coefficient for segmentation (binary or multi-class).
|
1900
1973
|
|
@@ -1937,8 +2010,15 @@ def iou_coefficient(pred, target, smooth=1e-6, num_classes=None):
|
|
1937
2010
|
|
1938
2011
|
|
1939
2012
|
def train_semantic_one_epoch(
|
1940
|
-
model
|
1941
|
-
|
2013
|
+
model: torch.nn.Module,
|
2014
|
+
optimizer: torch.optim.Optimizer,
|
2015
|
+
data_loader: DataLoader,
|
2016
|
+
device: torch.device,
|
2017
|
+
epoch: int,
|
2018
|
+
criterion: Any,
|
2019
|
+
print_freq: int = 10,
|
2020
|
+
verbose: bool = True,
|
2021
|
+
) -> float:
|
1942
2022
|
"""
|
1943
2023
|
Train the semantic segmentation model for one epoch.
|
1944
2024
|
|
@@ -1992,7 +2072,13 @@ def train_semantic_one_epoch(
|
|
1992
2072
|
return avg_loss
|
1993
2073
|
|
1994
2074
|
|
1995
|
-
def evaluate_semantic(
|
2075
|
+
def evaluate_semantic(
|
2076
|
+
model: torch.nn.Module,
|
2077
|
+
data_loader: DataLoader,
|
2078
|
+
device: torch.device,
|
2079
|
+
criterion: Any,
|
2080
|
+
num_classes: int = 2,
|
2081
|
+
) -> Dict[str, float]:
|
1996
2082
|
"""
|
1997
2083
|
Evaluate the semantic segmentation model on the validation set.
|
1998
2084
|
|
@@ -2040,31 +2126,32 @@ def evaluate_semantic(model, data_loader, device, criterion, num_classes=2):
|
|
2040
2126
|
|
2041
2127
|
|
2042
2128
|
def train_segmentation_model(
|
2043
|
-
images_dir,
|
2044
|
-
labels_dir,
|
2045
|
-
output_dir,
|
2046
|
-
architecture="unet",
|
2047
|
-
encoder_name="resnet34",
|
2048
|
-
encoder_weights="imagenet",
|
2049
|
-
num_channels=3,
|
2050
|
-
num_classes=2,
|
2051
|
-
batch_size=8,
|
2052
|
-
num_epochs=50,
|
2053
|
-
learning_rate=0.001,
|
2054
|
-
weight_decay=1e-4,
|
2055
|
-
seed=42,
|
2056
|
-
val_split=0.2,
|
2057
|
-
print_freq=10,
|
2058
|
-
verbose=True,
|
2059
|
-
save_best_only=True,
|
2060
|
-
plot_curves=False,
|
2061
|
-
device=None,
|
2062
|
-
checkpoint_path=None,
|
2063
|
-
resume_training=False,
|
2064
|
-
target_size=None,
|
2065
|
-
resize_mode="resize",
|
2066
|
-
|
2067
|
-
|
2129
|
+
images_dir: str,
|
2130
|
+
labels_dir: str,
|
2131
|
+
output_dir: str,
|
2132
|
+
architecture: str = "unet",
|
2133
|
+
encoder_name: str = "resnet34",
|
2134
|
+
encoder_weights: Optional[str] = "imagenet",
|
2135
|
+
num_channels: int = 3,
|
2136
|
+
num_classes: int = 2,
|
2137
|
+
batch_size: int = 8,
|
2138
|
+
num_epochs: int = 50,
|
2139
|
+
learning_rate: float = 0.001,
|
2140
|
+
weight_decay: float = 1e-4,
|
2141
|
+
seed: int = 42,
|
2142
|
+
val_split: float = 0.2,
|
2143
|
+
print_freq: int = 10,
|
2144
|
+
verbose: bool = True,
|
2145
|
+
save_best_only: bool = True,
|
2146
|
+
plot_curves: bool = False,
|
2147
|
+
device: Optional[torch.device] = None,
|
2148
|
+
checkpoint_path: Optional[str] = None,
|
2149
|
+
resume_training: bool = False,
|
2150
|
+
target_size: Optional[Tuple[int, int]] = None,
|
2151
|
+
resize_mode: str = "resize",
|
2152
|
+
num_workers: Optional[int] = None,
|
2153
|
+
**kwargs: Any,
|
2154
|
+
) -> torch.nn.Module:
|
2068
2155
|
"""
|
2069
2156
|
Train a semantic segmentation model for object detection using segmentation-models-pytorch.
|
2070
2157
|
|
@@ -2106,6 +2193,7 @@ def train_segmentation_model(
|
|
2106
2193
|
resize_mode (str): How to handle size standardization when target_size is specified.
|
2107
2194
|
'resize' - Resize images to target_size (may change aspect ratio)
|
2108
2195
|
'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
|
2196
|
+
num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
|
2109
2197
|
**kwargs: Additional arguments passed to smp.create_model().
|
2110
2198
|
Returns:
|
2111
2199
|
None: Model weights are saved to output_dir.
|
@@ -2252,7 +2340,9 @@ def train_segmentation_model(
|
|
2252
2340
|
# Create data loaders
|
2253
2341
|
# Use num_workers=0 on macOS and Windows to avoid multiprocessing issues
|
2254
2342
|
# Windows often has issues with multiprocessing in Jupyter notebooks
|
2255
|
-
|
2343
|
+
# Increase num_workers for better data loading performance
|
2344
|
+
if num_workers is None:
|
2345
|
+
num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 8
|
2256
2346
|
|
2257
2347
|
try:
|
2258
2348
|
train_loader = DataLoader(
|
@@ -2310,6 +2400,11 @@ def train_segmentation_model(
|
|
2310
2400
|
)
|
2311
2401
|
model.to(device)
|
2312
2402
|
|
2403
|
+
# Enable multi-GPU training if multiple GPUs are available
|
2404
|
+
if torch.cuda.device_count() > 1:
|
2405
|
+
print(f"Using {torch.cuda.device_count()} GPUs for training")
|
2406
|
+
model = torch.nn.DataParallel(model)
|
2407
|
+
|
2313
2408
|
# Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
|
2314
2409
|
criterion = torch.nn.CrossEntropyLoss()
|
2315
2410
|
|
@@ -2521,18 +2616,18 @@ def train_segmentation_model(
|
|
2521
2616
|
|
2522
2617
|
|
2523
2618
|
def semantic_inference_on_geotiff(
|
2524
|
-
model,
|
2525
|
-
geotiff_path,
|
2526
|
-
output_path,
|
2527
|
-
window_size=512,
|
2528
|
-
overlap=256,
|
2529
|
-
batch_size=4,
|
2530
|
-
num_channels=3,
|
2531
|
-
num_classes=2,
|
2532
|
-
device=None,
|
2533
|
-
quiet=False,
|
2534
|
-
**kwargs,
|
2535
|
-
):
|
2619
|
+
model: torch.nn.Module,
|
2620
|
+
geotiff_path: str,
|
2621
|
+
output_path: str,
|
2622
|
+
window_size: int = 512,
|
2623
|
+
overlap: int = 256,
|
2624
|
+
batch_size: int = 4,
|
2625
|
+
num_channels: int = 3,
|
2626
|
+
num_classes: int = 2,
|
2627
|
+
device: Optional[torch.device] = None,
|
2628
|
+
quiet: bool = False,
|
2629
|
+
**kwargs: Any,
|
2630
|
+
) -> Tuple[str, float]:
|
2536
2631
|
"""
|
2537
2632
|
Perform semantic segmentation inference on a large GeoTIFF using a sliding window approach.
|
2538
2633
|
|
@@ -2748,19 +2843,19 @@ def semantic_inference_on_geotiff(
|
|
2748
2843
|
|
2749
2844
|
|
2750
2845
|
def semantic_inference_on_image(
|
2751
|
-
model,
|
2752
|
-
image_path,
|
2753
|
-
output_path,
|
2754
|
-
window_size=512,
|
2755
|
-
overlap=256,
|
2756
|
-
batch_size=4,
|
2757
|
-
num_channels=3,
|
2758
|
-
num_classes=2,
|
2759
|
-
device=None,
|
2760
|
-
binary_output=True,
|
2761
|
-
quiet=False,
|
2762
|
-
**kwargs,
|
2763
|
-
):
|
2846
|
+
model: torch.nn.Module,
|
2847
|
+
image_path: str,
|
2848
|
+
output_path: str,
|
2849
|
+
window_size: int = 512,
|
2850
|
+
overlap: int = 256,
|
2851
|
+
batch_size: int = 4,
|
2852
|
+
num_channels: int = 3,
|
2853
|
+
num_classes: int = 2,
|
2854
|
+
device: Optional[torch.device] = None,
|
2855
|
+
binary_output: bool = True,
|
2856
|
+
quiet: bool = False,
|
2857
|
+
**kwargs: Any,
|
2858
|
+
) -> Tuple[str, float]:
|
2764
2859
|
"""
|
2765
2860
|
Perform semantic segmentation inference on a regular image (JPG, PNG, etc.) using a sliding window approach.
|
2766
2861
|
|
@@ -3025,20 +3120,20 @@ def semantic_inference_on_image(
|
|
3025
3120
|
|
3026
3121
|
|
3027
3122
|
def semantic_segmentation(
|
3028
|
-
input_path,
|
3029
|
-
output_path,
|
3030
|
-
model_path,
|
3031
|
-
architecture="unet",
|
3032
|
-
encoder_name="resnet34",
|
3033
|
-
num_channels=3,
|
3034
|
-
num_classes=2,
|
3035
|
-
window_size=512,
|
3036
|
-
overlap=256,
|
3037
|
-
batch_size=4,
|
3038
|
-
device=None,
|
3039
|
-
quiet=False,
|
3040
|
-
**kwargs,
|
3041
|
-
):
|
3123
|
+
input_path: str,
|
3124
|
+
output_path: str,
|
3125
|
+
model_path: str,
|
3126
|
+
architecture: str = "unet",
|
3127
|
+
encoder_name: str = "resnet34",
|
3128
|
+
num_channels: int = 3,
|
3129
|
+
num_classes: int = 2,
|
3130
|
+
window_size: int = 512,
|
3131
|
+
overlap: int = 256,
|
3132
|
+
batch_size: int = 4,
|
3133
|
+
device: Optional[torch.device] = None,
|
3134
|
+
quiet: bool = False,
|
3135
|
+
**kwargs: Any,
|
3136
|
+
) -> None:
|
3042
3137
|
"""
|
3043
3138
|
Perform semantic segmentation on an image file using a trained model.
|
3044
3139
|
|
@@ -3091,7 +3186,16 @@ def semantic_segmentation(
|
|
3091
3186
|
except Exception as e:
|
3092
3187
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
3093
3188
|
|
3094
|
-
|
3189
|
+
# Load state dict and handle DataParallel module prefix
|
3190
|
+
state_dict = torch.load(model_path, map_location=device)
|
3191
|
+
|
3192
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
3193
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
3194
|
+
state_dict = {
|
3195
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
3196
|
+
}
|
3197
|
+
|
3198
|
+
model.load_state_dict(state_dict)
|
3095
3199
|
model.to(device)
|
3096
3200
|
model.eval()
|
3097
3201
|
|
@@ -3131,21 +3235,21 @@ def semantic_segmentation(
|
|
3131
3235
|
|
3132
3236
|
|
3133
3237
|
def semantic_segmentation_batch(
|
3134
|
-
input_dir,
|
3135
|
-
output_dir,
|
3136
|
-
model_path,
|
3137
|
-
architecture="unet",
|
3138
|
-
encoder_name="resnet34",
|
3139
|
-
num_channels=3,
|
3140
|
-
num_classes=2,
|
3141
|
-
window_size=512,
|
3142
|
-
overlap=256,
|
3143
|
-
batch_size=4,
|
3144
|
-
device=None,
|
3145
|
-
filenames=None,
|
3146
|
-
quiet=False,
|
3147
|
-
**kwargs,
|
3148
|
-
):
|
3238
|
+
input_dir: str,
|
3239
|
+
output_dir: str,
|
3240
|
+
model_path: str,
|
3241
|
+
architecture: str = "unet",
|
3242
|
+
encoder_name: str = "resnet34",
|
3243
|
+
num_channels: int = 3,
|
3244
|
+
num_classes: int = 2,
|
3245
|
+
window_size: int = 512,
|
3246
|
+
overlap: int = 256,
|
3247
|
+
batch_size: int = 4,
|
3248
|
+
device: Optional[torch.device] = None,
|
3249
|
+
filenames: Optional[List[str]] = None,
|
3250
|
+
quiet: bool = False,
|
3251
|
+
**kwargs: Any,
|
3252
|
+
) -> None:
|
3149
3253
|
"""
|
3150
3254
|
Perform semantic segmentation on a batch of images from an input directory.
|
3151
3255
|
|
@@ -3220,7 +3324,16 @@ def semantic_segmentation_batch(
|
|
3220
3324
|
except Exception as e:
|
3221
3325
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
3222
3326
|
|
3223
|
-
|
3327
|
+
# Load state dict and handle DataParallel module prefix
|
3328
|
+
state_dict = torch.load(model_path, map_location=device)
|
3329
|
+
|
3330
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
3331
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
3332
|
+
state_dict = {
|
3333
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
3334
|
+
}
|
3335
|
+
|
3336
|
+
model.load_state_dict(state_dict)
|
3224
3337
|
model.to(device)
|
3225
3338
|
model.eval()
|
3226
3339
|
|
@@ -3295,21 +3408,21 @@ def semantic_segmentation_batch(
|
|
3295
3408
|
|
3296
3409
|
|
3297
3410
|
def train_instance_segmentation_model(
|
3298
|
-
images_dir,
|
3299
|
-
labels_dir,
|
3300
|
-
output_dir,
|
3301
|
-
num_classes=2,
|
3302
|
-
num_channels=3,
|
3303
|
-
batch_size=4,
|
3304
|
-
num_epochs=10,
|
3305
|
-
learning_rate=0.005,
|
3306
|
-
seed=42,
|
3307
|
-
val_split=0.2,
|
3308
|
-
visualize=False,
|
3309
|
-
device=None,
|
3310
|
-
verbose=True,
|
3311
|
-
**kwargs,
|
3312
|
-
):
|
3411
|
+
images_dir: str,
|
3412
|
+
labels_dir: str,
|
3413
|
+
output_dir: str,
|
3414
|
+
num_classes: int = 2,
|
3415
|
+
num_channels: int = 3,
|
3416
|
+
batch_size: int = 4,
|
3417
|
+
num_epochs: int = 10,
|
3418
|
+
learning_rate: float = 0.005,
|
3419
|
+
seed: int = 42,
|
3420
|
+
val_split: float = 0.2,
|
3421
|
+
visualize: bool = False,
|
3422
|
+
device: Optional[torch.device] = None,
|
3423
|
+
verbose: bool = True,
|
3424
|
+
**kwargs: Any,
|
3425
|
+
) -> torch.nn.Module:
|
3313
3426
|
"""
|
3314
3427
|
Train an instance segmentation model using Mask R-CNN.
|
3315
3428
|
|
@@ -3358,18 +3471,18 @@ def train_instance_segmentation_model(
|
|
3358
3471
|
|
3359
3472
|
|
3360
3473
|
def instance_segmentation(
|
3361
|
-
input_path,
|
3362
|
-
output_path,
|
3363
|
-
model_path,
|
3364
|
-
window_size=512,
|
3365
|
-
overlap=256,
|
3366
|
-
confidence_threshold=0.5,
|
3367
|
-
batch_size=4,
|
3368
|
-
num_channels=3,
|
3369
|
-
num_classes=2,
|
3370
|
-
device=None,
|
3371
|
-
**kwargs,
|
3372
|
-
):
|
3474
|
+
input_path: str,
|
3475
|
+
output_path: str,
|
3476
|
+
model_path: str,
|
3477
|
+
window_size: int = 512,
|
3478
|
+
overlap: int = 256,
|
3479
|
+
confidence_threshold: float = 0.5,
|
3480
|
+
batch_size: int = 4,
|
3481
|
+
num_channels: int = 3,
|
3482
|
+
num_classes: int = 2,
|
3483
|
+
device: Optional[torch.device] = None,
|
3484
|
+
**kwargs: Any,
|
3485
|
+
) -> None:
|
3373
3486
|
"""
|
3374
3487
|
Perform instance segmentation on a GeoTIFF using a pre-trained Mask R-CNN model.
|
3375
3488
|
|
@@ -3400,7 +3513,16 @@ def instance_segmentation(
|
|
3400
3513
|
if device is None:
|
3401
3514
|
device = get_device()
|
3402
3515
|
|
3403
|
-
|
3516
|
+
# Load state dict and handle DataParallel module prefix
|
3517
|
+
state_dict = torch.load(model_path, map_location=device)
|
3518
|
+
|
3519
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
3520
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
3521
|
+
state_dict = {
|
3522
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
3523
|
+
}
|
3524
|
+
|
3525
|
+
model.load_state_dict(state_dict)
|
3404
3526
|
model.to(device)
|
3405
3527
|
|
3406
3528
|
# Use the proper instance segmentation inference function
|
@@ -3419,18 +3541,18 @@ def instance_segmentation(
|
|
3419
3541
|
|
3420
3542
|
|
3421
3543
|
def instance_segmentation_batch(
|
3422
|
-
input_dir,
|
3423
|
-
output_dir,
|
3424
|
-
model_path,
|
3425
|
-
window_size=512,
|
3426
|
-
overlap=256,
|
3427
|
-
confidence_threshold=0.5,
|
3428
|
-
batch_size=4,
|
3429
|
-
num_channels=3,
|
3430
|
-
num_classes=2,
|
3431
|
-
device=None,
|
3432
|
-
**kwargs,
|
3433
|
-
):
|
3544
|
+
input_dir: str,
|
3545
|
+
output_dir: str,
|
3546
|
+
model_path: str,
|
3547
|
+
window_size: int = 512,
|
3548
|
+
overlap: int = 256,
|
3549
|
+
confidence_threshold: float = 0.5,
|
3550
|
+
batch_size: int = 4,
|
3551
|
+
num_channels: int = 3,
|
3552
|
+
num_classes: int = 2,
|
3553
|
+
device: Optional[torch.device] = None,
|
3554
|
+
**kwargs: Any,
|
3555
|
+
) -> None:
|
3434
3556
|
"""
|
3435
3557
|
Perform instance segmentation on multiple GeoTIFF files using a pre-trained Mask R-CNN model.
|
3436
3558
|
|
@@ -3461,7 +3583,16 @@ def instance_segmentation_batch(
|
|
3461
3583
|
if device is None:
|
3462
3584
|
device = get_device()
|
3463
3585
|
|
3464
|
-
|
3586
|
+
# Load state dict and handle DataParallel module prefix
|
3587
|
+
state_dict = torch.load(model_path, map_location=device)
|
3588
|
+
|
3589
|
+
# Remove 'module.' prefix if present (from DataParallel training)
|
3590
|
+
if any(key.startswith("module.") for key in state_dict.keys()):
|
3591
|
+
state_dict = {
|
3592
|
+
key.replace("module.", ""): value for key, value in state_dict.items()
|
3593
|
+
}
|
3594
|
+
|
3595
|
+
model.load_state_dict(state_dict)
|
3465
3596
|
model.to(device)
|
3466
3597
|
|
3467
3598
|
# Process all GeoTIFF files in the input directory
|