geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -4,6 +4,7 @@ import os
 import platform
 import random
 import time
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -34,7 +35,9 @@ except ImportError:
     SMP_AVAILABLE = False
 
 
-def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
+def get_instance_segmentation_model(
+    num_classes: int = 2, num_channels: int = 3, pretrained: bool = True
+) -> torch.nn.Module:
     """
     Get Mask R-CNN model with custom input channels and output classes.
 
@@ -129,7 +132,13 @@ def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=Tr
 class ObjectDetectionDataset(Dataset):
     """Dataset for object detection from GeoTIFF images and labels."""
 
-    def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
+    def __init__(
+        self,
+        image_paths: List[str],
+        label_paths: List[str],
+        transforms: Optional[Callable] = None,
+        num_channels: Optional[int] = None,
+    ) -> None:
         """
         Initialize dataset.
 
@@ -151,10 +160,10 @@ class ObjectDetectionDataset(Dataset):
         else:
             self.num_channels = num_channels
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.image_paths)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         # Load image
         with rasterio.open(self.image_paths[idx]) as src:
             # Read as [C, H, W] format
@@ -270,7 +279,7 @@ class ObjectDetectionDataset(Dataset):
 class Compose:
     """Custom compose transform that works with image and target."""
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: List[Callable]) -> None:
         """
         Initialize compose transform.
 
@@ -279,7 +288,9 @@ class Compose:
         """
         self.transforms = transforms
 
-    def __call__(self, image, target):
+    def __call__(
+        self, image: torch.Tensor, target: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         for t in self.transforms:
             image, target = t(image, target)
         return image, target
@@ -288,7 +299,9 @@ class Compose:
 class ToTensor:
     """Convert numpy.ndarray to tensor."""
 
-    def __call__(self, image, target):
+    def __call__(
+        self, image: torch.Tensor, target: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Apply transform to image and target.
 
@@ -305,7 +318,7 @@ class ToTensor:
 class RandomHorizontalFlip:
     """Random horizontal flip transform."""
 
-    def __init__(self, prob=0.5):
+    def __init__(self, prob: float = 0.5) -> None:
         """
         Initialize random horizontal flip.
 
@@ -314,7 +327,9 @@ class RandomHorizontalFlip:
         """
         self.prob = prob
 
-    def __call__(self, image, target):
+    def __call__(
+        self, image: torch.Tensor, target: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         if random.random() < self.prob:
             # Flip image
             image = torch.flip(image, dims=[2])  # Flip along width dimension
@@ -333,7 +348,7 @@ class RandomHorizontalFlip:
         return image, target
 
 
-def get_transform(train):
+def get_transform(train: bool) -> torchvision.transforms.Compose:
     """
     Get transforms for data augmentation.
 
@@ -352,7 +367,9 @@ def get_transform(train):
     return Compose(transforms)
 
 
-def collate_fn(batch):
+def collate_fn(
+    batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]],
+) -> Tuple[Tuple[torch.Tensor, ...], Tuple[Dict[str, torch.Tensor], ...]]:
     """
     Custom collate function for batching samples.
 
@@ -366,8 +383,14 @@ def collate_fn(batch):
 
 
 def train_one_epoch(
-    model, optimizer, data_loader, device, epoch, print_freq=10, verbose=True
-):
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    data_loader: DataLoader,
+    device: torch.device,
+    epoch: int,
+    print_freq: int = 10,
+    verbose: bool = True,
+) -> float:
     """
     Train the model for one epoch.
 
@@ -419,7 +442,9 @@ def train_one_epoch(
     return avg_loss
 
 
-def evaluate(model, data_loader, device):
+def evaluate(
+    model: torch.nn.Module, data_loader: DataLoader, device: torch.device
+) -> Dict[str, float]:
     """
     Evaluate the model on the validation set.
 
@@ -495,7 +520,13 @@ def evaluate(model, data_loader, device):
     return {"loss": avg_loss, "IoU": avg_iou}
 
 
-def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None):
+def visualize_predictions(
+    model: torch.nn.Module,
+    dataset: Dataset,
+    device: torch.device,
+    num_samples: int = 5,
+    output_dir: Optional[str] = None,
+) -> None:
     """
     Visualize model predictions.
 
@@ -583,24 +614,25 @@ def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None
 
 
 def train_MaskRCNN_model(
-    images_dir,
-    labels_dir,
-    output_dir,
-    num_channels=3,
-    model=None,
-    pretrained=True,
-    pretrained_model_path=None,
-    batch_size=4,
-    num_epochs=10,
-    learning_rate=0.005,
-    seed=42,
-    val_split=0.2,
-    visualize=False,
-    resume_training=False,
-    print_freq=10,
-    device=None,
-    verbose=True,
-):
+    images_dir: str,
+    labels_dir: str,
+    output_dir: str,
+    num_channels: int = 3,
+    model: Optional[torch.nn.Module] = None,
+    pretrained: bool = True,
+    pretrained_model_path: Optional[str] = None,
+    batch_size: int = 4,
+    num_epochs: int = 10,
+    learning_rate: float = 0.005,
+    seed: int = 42,
+    val_split: float = 0.2,
+    visualize: bool = False,
+    resume_training: bool = False,
+    print_freq: int = 10,
+    device: Optional[torch.device] = None,
+    num_workers: Optional[int] = None,
+    verbose: bool = True,
+) -> torch.nn.Module:
     """Train and evaluate Mask R-CNN model for instance segmentation.
 
     This function trains a Mask R-CNN model for instance segmentation using the
@@ -629,6 +661,7 @@ def train_MaskRCNN_model(
             will try to load optimizer and scheduler states as well. Defaults to False.
         print_freq (int): Frequency of printing training progress. Defaults to 10.
         device (torch.device): Device to train on. If None, uses CUDA if available.
+        num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
         verbose (bool): If True, prints detailed training progress. Defaults to True.
     Returns:
         None: Model weights are saved to output_dir.
@@ -712,7 +745,9 @@ def train_MaskRCNN_model(
     # Create data loaders
     # Use num_workers=0 on macOS and Windows to avoid multiprocessing issues
     # Windows often has issues with multiprocessing in Jupyter notebooks
-    num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 4
+    # Increase num_workers for better data loading performance
+    if num_workers is None:
+        num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 8
 
     train_loader = DataLoader(
         train_dataset,
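
Note: the hunk above turns the previously hard-coded worker count into an overridable parameter, defaulting to 0 on macOS and Windows (to sidestep multiprocessing issues) and 8 elsewhere. A minimal sketch of pinning it explicitly — the paths below are hypothetical, but the num_workers keyword is the one added in this release:

    from geoai.train import train_MaskRCNN_model

    # Hypothetical data/output paths; num_workers=4 bypasses the
    # platform-based default (None -> 0 on macOS/Windows, 8 otherwise).
    train_MaskRCNN_model(
        images_dir="data/images",
        labels_dir="data/labels",
        output_dir="runs/maskrcnn",
        num_workers=4,
    )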
@@ -872,17 +907,17 @@ def train_MaskRCNN_model(
 
 
 def inference_on_geotiff(
-    model,
-    geotiff_path,
-    output_path,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    num_channels=3,
-    device=None,
-    **kwargs,
-):
+    model: torch.nn.Module,
+    geotiff_path: str,
+    output_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Perform inference on a large GeoTIFF using a sliding window approach with improved blending.
 
@@ -1096,17 +1131,17 @@ def inference_on_geotiff(
 
 
 def instance_segmentation_inference_on_geotiff(
-    model,
-    geotiff_path,
-    output_path,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    num_channels=3,
-    device=None,
-    **kwargs,
-):
+    model: torch.nn.Module,
+    geotiff_path: str,
+    output_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> Tuple[str, float]:
     """
     Perform instance segmentation inference on a large GeoTIFF using a sliding window approach.
 
@@ -1327,19 +1362,19 @@ def instance_segmentation_inference_on_geotiff(
 
 
 def object_detection(
-    input_path,
-    output_path,
-    model_path,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    num_channels=3,
-    model=None,
-    pretrained=True,
-    device=None,
-    **kwargs,
-):
+    input_path: str,
+    output_path: str,
+    model_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    model: Optional[torch.nn.Module] = None,
+    pretrained: bool = True,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> None:
     """
     Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
 
@@ -1374,7 +1409,16 @@ def object_detection(
     except Exception as e:
         raise FileNotFoundError(f"Model file not found: {model_path}")
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
     model.eval()
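
Note: the replacement block above exists because torch.nn.DataParallel registers the wrapped model as a child module named "module", so checkpoints saved from a wrapped model carry a "module." key prefix. A self-contained sketch of the effect and of the strip logic the diff applies (runnable even without a GPU):

    import torch

    net = torch.nn.Linear(4, 2)
    wrapped = torch.nn.DataParallel(net)

    print(list(net.state_dict()))      # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

    # The same normalization the diff inserts before load_state_dict():
    state_dict = wrapped.state_dict()
    if any(key.startswith("module.") for key in state_dict):
        state_dict = {
            key.replace("module.", ""): value for key, value in state_dict.items()
        }
    net.load_state_dict(state_dict)  # now loads cleanly into the unwrapped model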
 
@@ -1393,20 +1437,20 @@ def object_detection(
 
 
 def object_detection_batch(
-    input_paths,
-    output_dir,
-    model_path,
-    filenames=None,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    model=None,
-    num_channels=3,
-    pretrained=True,
-    device=None,
-    **kwargs,
-):
+    input_paths: Union[str, List[str]],
+    output_dir: str,
+    model_path: str,
+    filenames: Optional[List[str]] = None,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    model: Optional[torch.nn.Module] = None,
+    num_channels: int = 3,
+    pretrained: bool = True,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> None:
     """
     Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
 
@@ -1449,7 +1493,16 @@ def object_detection_batch(
     except Exception as e:
         raise FileNotFoundError(f"Model file not found: {model_path}")
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
     model.eval()
 
@@ -1489,14 +1542,14 @@ class SemanticSegmentationDataset(Dataset):
 
     def __init__(
         self,
-        image_paths,
-        label_paths,
-        transforms=None,
-        num_channels=None,
-        target_size=None,
-        resize_mode="resize",
-        num_classes=2,
-    ):
+        image_paths: List[str],
+        label_paths: List[str],
+        transforms: Optional[Callable] = None,
+        num_channels: Optional[int] = None,
+        target_size: Optional[Tuple[int, int]] = None,
+        resize_mode: str = "resize",
+        num_classes: int = 2,
+    ) -> None:
         """
         Initialize dataset for semantic segmentation.
 
@@ -1526,11 +1579,11 @@ class SemanticSegmentationDataset(Dataset):
         else:
            self.num_channels = num_channels
 
-    def _is_geotiff(self, file_path):
+    def _is_geotiff(self, file_path: str) -> bool:
         """Check if file is a GeoTIFF based on extension."""
         return file_path.lower().endswith((".tif", ".tiff"))
 
-    def _get_num_channels(self, image_path):
+    def _get_num_channels(self, image_path: str) -> int:
         """Get number of channels from an image file."""
         if self._is_geotiff(image_path):
             with rasterio.open(image_path) as src:
@@ -1548,7 +1601,9 @@ class SemanticSegmentationDataset(Dataset):
             # Convert to RGB and return 3 channels
             return 3
 
-    def _resize_image_and_mask(self, image, mask):
+    def _resize_image_and_mask(
+        self, image: np.ndarray, mask: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """Resize image and mask to target size."""
         if self.target_size is None:
             return image, mask
@@ -1586,7 +1641,9 @@ class SemanticSegmentationDataset(Dataset):
 
         return image, mask
 
-    def _pad_to_size(self, tensor, target_size):
+    def _pad_to_size(
+        self, tensor: torch.Tensor, target_size: Tuple[int, int]
+    ) -> torch.Tensor:
         """Pad tensor to target size with zeros."""
         target_h, target_w = target_size
 
@@ -1618,10 +1675,10 @@ class SemanticSegmentationDataset(Dataset):
 
         return padded
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.image_paths)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # Load image
         image_path = self.image_paths[idx]
         if self._is_geotiff(image_path):
@@ -1706,10 +1763,12 @@ class SemanticSegmentationDataset(Dataset):
 class SemanticTransforms:
     """Custom transforms for semantic segmentation."""
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: List[Callable]) -> None:
         self.transforms = transforms
 
-    def __call__(self, image, mask):
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         for t in self.transforms:
             image, mask = t(image, mask)
         return image, mask
@@ -1718,17 +1777,21 @@ class SemanticTransforms:
 class SemanticToTensor:
     """Convert numpy.ndarray to tensor for semantic segmentation."""
 
-    def __call__(self, image, mask):
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         return image, mask
 
 
 class SemanticRandomHorizontalFlip:
     """Random horizontal flip transform for semantic segmentation."""
 
-    def __init__(self, prob=0.5):
+    def __init__(self, prob: float = 0.5) -> None:
         self.prob = prob
 
-    def __call__(self, image, mask):
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if random.random() < self.prob:
             # Flip image and mask along width dimension
             image = torch.flip(image, dims=[2])
@@ -1736,7 +1799,7 @@ class SemanticRandomHorizontalFlip:
         return image, mask
 
 
-def get_semantic_transform(train):
+def get_semantic_transform(train: bool) -> Any:
     """
     Get transforms for semantic segmentation data augmentation.
 
@@ -1756,14 +1819,14 @@ def get_semantic_transform(train):
 
 
 def get_smp_model(
-    architecture="unet",
-    encoder_name="resnet34",
-    encoder_weights="imagenet",
-    in_channels=3,
-    classes=2,
-    activation=None,
-    **kwargs,
-):
+    architecture: str = "unet",
+    encoder_name: str = "resnet34",
+    encoder_weights: Optional[str] = "imagenet",
+    in_channels: int = 3,
+    classes: int = 2,
+    activation: Optional[str] = None,
+    **kwargs: Any,
+) -> torch.nn.Module:
     """
     Get a segmentation model from segmentation-models-pytorch using the generic create_model function.
 
@@ -1852,7 +1915,12 @@ def get_smp_model(
     )
 
 
-def dice_coefficient(pred, target, smooth=1e-6, num_classes=None):
+def dice_coefficient(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
     """
     Calculate Dice coefficient for segmentation (binary or multi-class).
 
@@ -1894,7 +1962,12 @@ def dice_coefficient(pred, target, smooth=1e-6, num_classes=None):
     return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
 
 
-def iou_coefficient(pred, target, smooth=1e-6, num_classes=None):
+def iou_coefficient(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
     """
     Calculate IoU coefficient for segmentation (binary or multi-class).
 
@@ -1937,8 +2010,15 @@ def iou_coefficient(pred, target, smooth=1e-6, num_classes=None):
 
 
 def train_semantic_one_epoch(
-    model, optimizer, data_loader, device, epoch, criterion, print_freq=10, verbose=True
-):
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    data_loader: DataLoader,
+    device: torch.device,
+    epoch: int,
+    criterion: Any,
+    print_freq: int = 10,
+    verbose: bool = True,
+) -> float:
     """
     Train the semantic segmentation model for one epoch.
 
@@ -1992,7 +2072,13 @@ def train_semantic_one_epoch(
     return avg_loss
 
 
-def evaluate_semantic(model, data_loader, device, criterion, num_classes=2):
+def evaluate_semantic(
+    model: torch.nn.Module,
+    data_loader: DataLoader,
+    device: torch.device,
+    criterion: Any,
+    num_classes: int = 2,
+) -> Dict[str, float]:
     """
     Evaluate the semantic segmentation model on the validation set.
 
@@ -2040,31 +2126,32 @@ def evaluate_semantic(model, data_loader, device, criterion, num_classes=2):
 
 
 def train_segmentation_model(
-    images_dir,
-    labels_dir,
-    output_dir,
-    architecture="unet",
-    encoder_name="resnet34",
-    encoder_weights="imagenet",
-    num_channels=3,
-    num_classes=2,
-    batch_size=8,
-    num_epochs=50,
-    learning_rate=0.001,
-    weight_decay=1e-4,
-    seed=42,
-    val_split=0.2,
-    print_freq=10,
-    verbose=True,
-    save_best_only=True,
-    plot_curves=False,
-    device=None,
-    checkpoint_path=None,
-    resume_training=False,
-    target_size=None,
-    resize_mode="resize",
-    **kwargs,
-):
+    images_dir: str,
+    labels_dir: str,
+    output_dir: str,
+    architecture: str = "unet",
+    encoder_name: str = "resnet34",
+    encoder_weights: Optional[str] = "imagenet",
+    num_channels: int = 3,
+    num_classes: int = 2,
+    batch_size: int = 8,
+    num_epochs: int = 50,
+    learning_rate: float = 0.001,
+    weight_decay: float = 1e-4,
+    seed: int = 42,
+    val_split: float = 0.2,
+    print_freq: int = 10,
+    verbose: bool = True,
+    save_best_only: bool = True,
+    plot_curves: bool = False,
+    device: Optional[torch.device] = None,
+    checkpoint_path: Optional[str] = None,
+    resume_training: bool = False,
+    target_size: Optional[Tuple[int, int]] = None,
+    resize_mode: str = "resize",
+    num_workers: Optional[int] = None,
+    **kwargs: Any,
+) -> torch.nn.Module:
     """
     Train a semantic segmentation model for object detection using segmentation-models-pytorch.
 
@@ -2106,6 +2193,7 @@ def train_segmentation_model(
         resize_mode (str): How to handle size standardization when target_size is specified.
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
+        num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
@@ -2252,7 +2340,9 @@ def train_segmentation_model(
     # Create data loaders
     # Use num_workers=0 on macOS and Windows to avoid multiprocessing issues
     # Windows often has issues with multiprocessing in Jupyter notebooks
-    num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 4
+    # Increase num_workers for better data loading performance
+    if num_workers is None:
+        num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 8
 
     try:
         train_loader = DataLoader(
@@ -2310,6 +2400,11 @@ def train_segmentation_model(
     )
     model.to(device)
 
+    # Enable multi-GPU training if multiple GPUs are available
+    if torch.cuda.device_count() > 1:
+        print(f"Using {torch.cuda.device_count()} GPUs for training")
+        model = torch.nn.DataParallel(model)
+
     # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
     criterion = torch.nn.CrossEntropyLoss()
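
Note: this is the hunk that makes the "module." handling elsewhere in the diff necessary — once the model is wrapped in DataParallel, state_dict() keys gain the prefix, so checkpoints written by multi-GPU runs differ from single-GPU ones. A common alternative (sketched here as an assumption, not something this package does) is to unwrap before saving so checkpoints stay uniform:

    import torch

    def save_checkpoint(model: torch.nn.Module, path: str) -> None:
        # Unwrap DataParallel so saved keys never carry the "module." prefix.
        to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
        torch.save(to_save.state_dict(), path)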
 
@@ -2521,18 +2616,18 @@ def train_segmentation_model(
 
 
 def semantic_inference_on_geotiff(
-    model,
-    geotiff_path,
-    output_path,
-    window_size=512,
-    overlap=256,
-    batch_size=4,
-    num_channels=3,
-    num_classes=2,
-    device=None,
-    quiet=False,
-    **kwargs,
-):
+    model: torch.nn.Module,
+    geotiff_path: str,
+    output_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    num_classes: int = 2,
+    device: Optional[torch.device] = None,
+    quiet: bool = False,
+    **kwargs: Any,
+) -> Tuple[str, float]:
     """
     Perform semantic segmentation inference on a large GeoTIFF using a sliding window approach.
 
@@ -2748,19 +2843,19 @@ def semantic_inference_on_geotiff(
 
 
 def semantic_inference_on_image(
-    model,
-    image_path,
-    output_path,
-    window_size=512,
-    overlap=256,
-    batch_size=4,
-    num_channels=3,
-    num_classes=2,
-    device=None,
-    binary_output=True,
-    quiet=False,
-    **kwargs,
-):
+    model: torch.nn.Module,
+    image_path: str,
+    output_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    num_classes: int = 2,
+    device: Optional[torch.device] = None,
+    binary_output: bool = True,
+    quiet: bool = False,
+    **kwargs: Any,
+) -> Tuple[str, float]:
     """
     Perform semantic segmentation inference on a regular image (JPG, PNG, etc.) using a sliding window approach.
 
@@ -3025,20 +3120,20 @@ def semantic_inference_on_image(
 
 
 def semantic_segmentation(
-    input_path,
-    output_path,
-    model_path,
-    architecture="unet",
-    encoder_name="resnet34",
-    num_channels=3,
-    num_classes=2,
-    window_size=512,
-    overlap=256,
-    batch_size=4,
-    device=None,
-    quiet=False,
-    **kwargs,
-):
+    input_path: str,
+    output_path: str,
+    model_path: str,
+    architecture: str = "unet",
+    encoder_name: str = "resnet34",
+    num_channels: int = 3,
+    num_classes: int = 2,
+    window_size: int = 512,
+    overlap: int = 256,
+    batch_size: int = 4,
+    device: Optional[torch.device] = None,
+    quiet: bool = False,
+    **kwargs: Any,
+) -> None:
     """
     Perform semantic segmentation on an image file using a trained model.
 
@@ -3091,7 +3186,16 @@ def semantic_segmentation(
     except Exception as e:
         raise FileNotFoundError(f"Model file not found: {model_path}")
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
     model.eval()
 
@@ -3131,21 +3235,21 @@ def semantic_segmentation(
 
 
 def semantic_segmentation_batch(
-    input_dir,
-    output_dir,
-    model_path,
-    architecture="unet",
-    encoder_name="resnet34",
-    num_channels=3,
-    num_classes=2,
-    window_size=512,
-    overlap=256,
-    batch_size=4,
-    device=None,
-    filenames=None,
-    quiet=False,
-    **kwargs,
-):
+    input_dir: str,
+    output_dir: str,
+    model_path: str,
+    architecture: str = "unet",
+    encoder_name: str = "resnet34",
+    num_channels: int = 3,
+    num_classes: int = 2,
+    window_size: int = 512,
+    overlap: int = 256,
+    batch_size: int = 4,
+    device: Optional[torch.device] = None,
+    filenames: Optional[List[str]] = None,
+    quiet: bool = False,
+    **kwargs: Any,
+) -> None:
     """
     Perform semantic segmentation on a batch of images from an input directory.
 
@@ -3220,7 +3324,16 @@ def semantic_segmentation_batch(
     except Exception as e:
         raise FileNotFoundError(f"Model file not found: {model_path}")
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
     model.eval()
 
@@ -3295,21 +3408,21 @@ def semantic_segmentation_batch(
 
 
 def train_instance_segmentation_model(
-    images_dir,
-    labels_dir,
-    output_dir,
-    num_classes=2,
-    num_channels=3,
-    batch_size=4,
-    num_epochs=10,
-    learning_rate=0.005,
-    seed=42,
-    val_split=0.2,
-    visualize=False,
-    device=None,
-    verbose=True,
-    **kwargs,
-):
+    images_dir: str,
+    labels_dir: str,
+    output_dir: str,
+    num_classes: int = 2,
+    num_channels: int = 3,
+    batch_size: int = 4,
+    num_epochs: int = 10,
+    learning_rate: float = 0.005,
+    seed: int = 42,
+    val_split: float = 0.2,
+    visualize: bool = False,
+    device: Optional[torch.device] = None,
+    verbose: bool = True,
+    **kwargs: Any,
+) -> torch.nn.Module:
     """
     Train an instance segmentation model using Mask R-CNN.
 
@@ -3358,18 +3471,18 @@ def train_instance_segmentation_model(
 
 
 def instance_segmentation(
-    input_path,
-    output_path,
-    model_path,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    num_channels=3,
-    num_classes=2,
-    device=None,
-    **kwargs,
-):
+    input_path: str,
+    output_path: str,
+    model_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    num_classes: int = 2,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> None:
     """
     Perform instance segmentation on a GeoTIFF using a pre-trained Mask R-CNN model.
 
@@ -3400,7 +3513,16 @@ def instance_segmentation(
     if device is None:
         device = get_device()
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
 
     # Use the proper instance segmentation inference function
@@ -3419,18 +3541,18 @@ def instance_segmentation(
 
 
 def instance_segmentation_batch(
-    input_dir,
-    output_dir,
-    model_path,
-    window_size=512,
-    overlap=256,
-    confidence_threshold=0.5,
-    batch_size=4,
-    num_channels=3,
-    num_classes=2,
-    device=None,
-    **kwargs,
-):
+    input_dir: str,
+    output_dir: str,
+    model_path: str,
+    window_size: int = 512,
+    overlap: int = 256,
+    confidence_threshold: float = 0.5,
+    batch_size: int = 4,
+    num_channels: int = 3,
+    num_classes: int = 2,
+    device: Optional[torch.device] = None,
+    **kwargs: Any,
+) -> None:
     """
     Perform instance segmentation on multiple GeoTIFF files using a pre-trained Mask R-CNN model.
 
@@ -3461,7 +3583,16 @@ def instance_segmentation_batch(
     if device is None:
         device = get_device()
 
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # Load state dict and handle DataParallel module prefix
+    state_dict = torch.load(model_path, map_location=device)
+
+    # Remove 'module.' prefix if present (from DataParallel training)
+    if any(key.startswith("module.") for key in state_dict.keys()):
+        state_dict = {
+            key.replace("module.", ""): value for key, value in state_dict.items()
+        }
+
+    model.load_state_dict(state_dict)
     model.to(device)
 
     # Process all GeoTIFF files in the input directory
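
Note: the identical load-and-strip block now appears in object_detection, object_detection_batch, semantic_segmentation, semantic_segmentation_batch, instance_segmentation, and instance_segmentation_batch. For reference, the repeated pattern as one self-contained helper — a sketch of the shared logic, not a function the package exports:

    from typing import Dict

    import torch

    def load_clean_state_dict(
        model: torch.nn.Module, model_path: str, device: torch.device
    ) -> torch.nn.Module:
        """Load a checkpoint, tolerating DataParallel's "module." key prefix."""
        state_dict: Dict[str, torch.Tensor] = torch.load(model_path, map_location=device)
        if any(key.startswith("module.") for key in state_dict):
            state_dict = {
                key.replace("module.", ""): value for key, value in state_dict.items()
            }
        model.load_state_dict(state_dict)
        return model.to(device)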