geoai-py 0.14.0__py2.py3-none-any.whl → 0.15.0__py2.py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
geoai/timm_segment.py ADDED
@@ -0,0 +1,1097 @@
+ """Module for training semantic segmentation models using timm encoders with PyTorch Lightning."""
+
+ import os
+ from typing import Any, Callable, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from tqdm import tqdm
+
+ try:
+     import timm
+
+     TIMM_AVAILABLE = True
+ except ImportError:
+     TIMM_AVAILABLE = False
+
+ try:
+     import segmentation_models_pytorch as smp
+
+     SMP_AVAILABLE = True
+ except ImportError:
+     SMP_AVAILABLE = False
+
+ try:
+     import lightning.pytorch as pl
+     from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
+     from lightning.pytorch.loggers import CSVLogger
+
+     LIGHTNING_AVAILABLE = True
+ except ImportError:
+     LIGHTNING_AVAILABLE = False
+
+     class _LightningModuleStub:
+         """Placeholder base class so this module can be imported without lightning.
+
+         Without it, the ``class TimmSegmentationModel(pl.LightningModule)``
+         statement below would raise a NameError at import time, defeating the
+         LIGHTNING_AVAILABLE flag. Instantiating any subclass raises a helpful
+         ImportError instead.
+         """
+
+         def __init__(self, *args: Any, **kwargs: Any) -> None:
+             raise ImportError(
+                 "PyTorch Lightning is required. Install it with: pip install lightning"
+             )
+
+     class pl:  # minimal namespace standing in for lightning.pytorch
+         LightningModule = _LightningModuleStub
+
+
+ class TimmSegmentationModel(pl.LightningModule):
+     """
+     PyTorch Lightning module for semantic segmentation using timm encoders with SMP decoders,
+     or pure timm models from the Hugging Face Hub.
+     """
+
+     def __init__(
+         self,
+         encoder_name: str = "resnet50",
+         architecture: str = "unet",
+         num_classes: int = 2,
+         in_channels: int = 3,
+         encoder_weights: str = "imagenet",
+         learning_rate: float = 1e-3,
+         weight_decay: float = 1e-4,
+         freeze_encoder: bool = False,
+         loss_fn: Optional[nn.Module] = None,
+         class_weights: Optional[torch.Tensor] = None,
+         use_timm_model: bool = False,
+         timm_model_name: Optional[str] = None,
+         **decoder_kwargs: Any,
+     ):
+         """
+         Initialize TimmSegmentationModel.
+
+         Args:
+             encoder_name (str): Name of encoder (e.g., 'resnet50', 'efficientnet_b0').
+             architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
+                 'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan').
+                 Ignored if use_timm_model=True.
+             num_classes (int): Number of output classes.
+             in_channels (int): Number of input channels.
+             encoder_weights (str): Pretrained weights for encoder ('imagenet', 'ssl', 'swsl', None).
+             learning_rate (float): Learning rate for optimizer.
+             weight_decay (float): Weight decay for optimizer.
+             freeze_encoder (bool): Freeze encoder weights during training.
+             loss_fn (nn.Module, optional): Custom loss function. Defaults to CrossEntropyLoss.
+             class_weights (torch.Tensor, optional): Class weights for loss function.
+             use_timm_model (bool): If True, load a complete segmentation model from timm/HF Hub
+                 instead of using an SMP architecture. Defaults to False.
+             timm_model_name (str, optional): Name or path of timm model from HF Hub
+                 (e.g., 'hf-hub:timm/segformer_b0.ade_512x512' or 'nvidia/mit-b0').
+                 Only used if use_timm_model=True.
+             **decoder_kwargs: Additional arguments for decoder (only used with SMP).
+         """
+         super().__init__()
+
+         if not TIMM_AVAILABLE:
+             raise ImportError("timm is required. Install it with: pip install timm")
+
+         self.save_hyperparameters()
+
+         # Check if using a pure timm model from HF Hub
+         if use_timm_model:
+             if timm_model_name is None:
+                 timm_model_name = encoder_name
+
+             # Load model from timm (supports HF Hub with 'hf-hub:' prefix)
+             try:
+                 self.model = timm.create_model(
+                     timm_model_name,
+                     pretrained=bool(encoder_weights),
+                     num_classes=num_classes,
+                     in_chans=in_channels,
+                 )
+                 print(f"Loaded timm model: {timm_model_name}")
+             except Exception as e:
+                 raise ValueError(
+                     f"Failed to load timm model '{timm_model_name}'. "
+                     f"Error: {str(e)}. "
+                     f"For HF Hub models, use the format 'hf-hub:username/model-name' or 'hf_hub:username/model-name'."
+                 ) from e
+         else:
+             # Use SMP architecture with timm encoder
+             if not SMP_AVAILABLE:
+                 raise ImportError(
+                     "segmentation-models-pytorch is required. "
+                     "Install it with: pip install segmentation-models-pytorch"
+                 )
+
+             # Create segmentation model with timm encoder using smp.create_model
+             try:
+                 self.model = smp.create_model(
+                     arch=architecture,
+                     encoder_name=encoder_name,
+                     encoder_weights=encoder_weights,
+                     in_channels=in_channels,
+                     classes=num_classes,
+                     **decoder_kwargs,
+                 )
+             except Exception as e:
+                 # Provide a helpful error message
+                 available_archs = [
+                     "unet",
+                     "unetplusplus",
+                     "manet",
+                     "linknet",
+                     "fpn",
+                     "pspnet",
+                     "deeplabv3",
+                     "deeplabv3plus",
+                     "pan",
+                     "upernet",
+                 ]
+                 raise ValueError(
+                     f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
+                     f"Error: {str(e)}. "
+                     f"Available architectures include: {', '.join(available_archs)}. "
+                     f"Please check the segmentation-models-pytorch documentation for supported combinations."
+                 ) from e
+
+         if freeze_encoder:
+             self._freeze_encoder()
+
+         # Set up loss function
+         if loss_fn is not None:
+             self.loss_fn = loss_fn
+         elif class_weights is not None:
+             self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
+         else:
+             self.loss_fn = nn.CrossEntropyLoss()
+
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+
+     def _freeze_encoder(self):
+         """Freeze encoder weights."""
+         if hasattr(self.model, "encoder"):
+             for param in self.model.encoder.parameters():
+                 param.requires_grad = False
+         else:
+             # For pure timm models without a separate encoder
+             if not self.hparams.use_timm_model:
+                 raise ValueError("Model does not have an encoder attribute to freeze")
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         # Calculate IoU
+         pred = torch.argmax(logits, dim=1)
+         iou = self._compute_iou(pred, y)
+
+         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
+         self.log("train_iou", iou, on_step=True, on_epoch=True, prog_bar=True)
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         # Calculate IoU
+         pred = torch.argmax(logits, dim=1)
+         iou = self._compute_iou(pred, y)
+
+         self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+         self.log("val_iou", iou, on_epoch=True, prog_bar=True)
+
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         # Calculate IoU
+         pred = torch.argmax(logits, dim=1)
+         iou = self._compute_iou(pred, y)
+
+         self.log("test_loss", loss, on_epoch=True)
+         self.log("test_iou", iou, on_epoch=True)
+
+         return loss
+
+     def _compute_iou(self, pred, target, smooth=1e-6):
+         """Compute mean IoU over the classes present in pred or target (absent classes are skipped)."""
+         num_classes = self.hparams.num_classes
+         ious = []
+
+         for cls in range(num_classes):
+             pred_cls = pred == cls
+             target_cls = target == cls
+
+             intersection = (pred_cls & target_cls).float().sum()
+             union = (pred_cls | target_cls).float().sum()
+
+             if union == 0:
+                 continue
+
+             iou = (intersection + smooth) / (union + smooth)
+             ious.append(iou)
+
+         return (
+             torch.stack(ious).mean() if ious else torch.tensor(0.0, device=pred.device)
+         )
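A quick editorial sanity check on the IoU arithmetic (worked example, not from the source): for a 2×2 prediction [[0, 1], [1, 1]] against target [[0, 1], [0, 1]], class 0 has intersection 1 and union 2 (IoU ≈ 0.5) while class 1 has intersection 2 and union 3 (IoU ≈ 0.667), giving a mean IoU of ≈ 0.583. Classes absent from both tensors are skipped rather than scored as 1.0, which keeps empty classes from inflating the mean.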
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(
+             self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+         )
+
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+             optimizer, mode="min", factor=0.5, patience=5, verbose=True
+         )
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "monitor": "val_loss",
+             },
+         }
+
+     def predict_step(self, batch, batch_idx):
+         x = batch[0] if isinstance(batch, (list, tuple)) else batch
+         logits = self(x)
+         probs = torch.softmax(logits, dim=1)
+         preds = torch.argmax(probs, dim=1)
+         return {"predictions": preds, "probabilities": probs}
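A minimal construction sketch (editorial illustration; the argument values are examples, and the Hub id below is the one the docstring itself cites, not a tested checkpoint):

    # SMP decoder on top of a timm encoder
    model = TimmSegmentationModel(
        encoder_name="resnet50",
        architecture="unet",
        num_classes=2,
        in_channels=4,  # e.g., RGB + NIR
    )

    # Complete timm model pulled from the Hugging Face Hub
    model = TimmSegmentationModel(
        use_timm_model=True,
        timm_model_name="hf-hub:timm/segformer_b0.ade_512x512",
        num_classes=2,
    )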
+
+
+ class SegmentationDataset(Dataset):
+     """
+     Dataset for semantic segmentation with remote sensing imagery.
+     """
+
+     def __init__(
+         self,
+         image_paths: List[str],
+         mask_paths: List[str],
+         transform: Optional[Callable] = None,
+         num_channels: Optional[int] = None,
+     ):
+         """
+         Initialize SegmentationDataset.
+
+         Args:
+             image_paths (List[str]): List of paths to image files.
+             mask_paths (List[str]): List of paths to mask files.
+             transform (callable, optional): Transform to apply to images and masks.
+             num_channels (int, optional): Number of channels to use. If None, uses all.
+         """
+         self.image_paths = image_paths
+         self.mask_paths = mask_paths
+         self.transform = transform
+         self.num_channels = num_channels
+
+         if len(image_paths) != len(mask_paths):
+             raise ValueError("Number of images must match number of masks")
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         import rasterio
+
+         # Load image
+         with rasterio.open(self.image_paths[idx]) as src:
+             image = src.read()  # Shape: (C, H, W)
+
+         # Handle channel selection
+         if self.num_channels is not None and image.shape[0] != self.num_channels:
+             if image.shape[0] > self.num_channels:
+                 image = image[: self.num_channels]
+             else:
+                 # Pad with zeros if needed
+                 padded = np.zeros(
+                     (self.num_channels, image.shape[1], image.shape[2])
+                 )
+                 padded[: image.shape[0]] = image
+                 image = padded
+
+         # Scale 8-bit pixel values to [0, 1]
+         if image.max() > 1.0:
+             image = image / 255.0
+
+         image = image.astype(np.float32)
+
+         # Load mask
+         with rasterio.open(self.mask_paths[idx]) as src:
+             mask = src.read(1)  # Shape: (H, W)
+             mask = mask.astype(np.int64)
+
+         # Convert to tensors
+         image = torch.from_numpy(image)
+         mask = torch.from_numpy(mask)
+
+         # Apply transforms if provided
+         if self.transform is not None:
+             image, mask = self.transform(image, mask)
+
+         return image, mask
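For illustration, the dataset is typically built from matching, sorted tile lists; the directory names here are hypothetical:

    import glob
    import os

    image_paths = sorted(glob.glob(os.path.join("tiles/images", "*.tif")))
    mask_paths = sorted(glob.glob(os.path.join("tiles/labels", "*.tif")))

    train_ds = SegmentationDataset(image_paths, mask_paths, num_channels=3)
    image, mask = train_ds[0]  # image: float32 (C, H, W) in [0, 1]; mask: int64 (H, W)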
+
+
+ def train_timm_segmentation(
+     train_dataset: Dataset,
+     val_dataset: Optional[Dataset] = None,
+     test_dataset: Optional[Dataset] = None,
+     encoder_name: str = "resnet50",
+     architecture: str = "unet",
+     num_classes: int = 2,
+     in_channels: int = 3,
+     encoder_weights: str = "imagenet",
+     output_dir: str = "output",
+     batch_size: int = 8,
+     num_epochs: int = 50,
+     learning_rate: float = 1e-3,
+     weight_decay: float = 1e-4,
+     num_workers: int = 4,
+     freeze_encoder: bool = False,
+     class_weights: Optional[List[float]] = None,
+     accelerator: str = "auto",
+     devices: str = "auto",
+     monitor_metric: str = "val_loss",
+     mode: str = "min",
+     patience: int = 10,
+     save_top_k: int = 1,
+     checkpoint_path: Optional[str] = None,
+     use_timm_model: bool = False,
+     timm_model_name: Optional[str] = None,
+     **kwargs: Any,
+ ) -> TimmSegmentationModel:
+     """
+     Train a semantic segmentation model using a timm encoder.
+
+     Args:
+         train_dataset (Dataset): Training dataset.
+         val_dataset (Dataset, optional): Validation dataset.
+         test_dataset (Dataset, optional): Test dataset.
+         encoder_name (str): Name of timm encoder.
+         architecture (str): Segmentation architecture.
+         num_classes (int): Number of output classes.
+         in_channels (int): Number of input channels.
+         encoder_weights (str): Pretrained weights for encoder.
+         output_dir (str): Directory to save outputs.
+         batch_size (int): Batch size for training.
+         num_epochs (int): Number of training epochs.
+         learning_rate (float): Learning rate.
+         weight_decay (float): Weight decay for optimizer.
+         num_workers (int): Number of data loading workers.
+         freeze_encoder (bool): Freeze encoder during training.
+         class_weights (List[float], optional): Class weights for loss.
+         accelerator (str): Accelerator type ('auto', 'gpu', 'cpu').
+         devices (str): Devices to use.
+         monitor_metric (str): Metric to monitor for checkpointing.
+         mode (str): 'min' or 'max' for monitor_metric.
+         patience (int): Early stopping patience.
+         save_top_k (int): Number of best models to save.
+         checkpoint_path (str, optional): Path to checkpoint to resume from.
+         use_timm_model (bool): Load complete segmentation model from timm/HF Hub.
+         timm_model_name (str, optional): Model name from HF Hub (e.g., 'hf-hub:nvidia/mit-b0').
+         **kwargs: Additional arguments for PyTorch Lightning Trainer.
+
+     Returns:
+         TimmSegmentationModel: Trained model.
+     """
+     if not LIGHTNING_AVAILABLE:
+         raise ImportError(
+             "PyTorch Lightning is required. Install it with: pip install lightning"
+         )
+
+     # Create output directories
+     os.makedirs(output_dir, exist_ok=True)
+     model_dir = os.path.join(output_dir, "models")
+     os.makedirs(model_dir, exist_ok=True)
+
+     # Convert class weights to a tensor if provided
+     weight_tensor = None
+     if class_weights is not None:
+         weight_tensor = torch.tensor(class_weights, dtype=torch.float32)
+
+     # Create model
+     model = TimmSegmentationModel(
+         encoder_name=encoder_name,
+         architecture=architecture,
+         num_classes=num_classes,
+         in_channels=in_channels,
+         encoder_weights=encoder_weights,
+         learning_rate=learning_rate,
+         weight_decay=weight_decay,
+         freeze_encoder=freeze_encoder,
+         class_weights=weight_tensor,
+         use_timm_model=use_timm_model,
+         timm_model_name=timm_model_name,
+     )
+
+     # Create data loaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+         pin_memory=True,
+     )
+
+     val_loader = None
+     if val_dataset is not None:
+         val_loader = DataLoader(
+             val_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+             pin_memory=True,
+         )
+
+     # Set up callbacks
+     callbacks = []
+
+     # Model checkpoint
+     checkpoint_callback = ModelCheckpoint(
+         dirpath=model_dir,
+         filename=f"{encoder_name}_{architecture}_{{epoch:02d}}_{{val_loss:.4f}}",
+         monitor=monitor_metric,
+         mode=mode,
+         save_top_k=save_top_k,
+         save_last=True,
+         verbose=True,
+     )
+     callbacks.append(checkpoint_callback)
+
+     # Early stopping
+     early_stop_callback = EarlyStopping(
+         monitor=monitor_metric,
+         patience=patience,
+         mode=mode,
+         verbose=True,
+     )
+     callbacks.append(early_stop_callback)
+
+     # Set up logger
+     logger = CSVLogger(model_dir, name="lightning_logs")
+
+     # Create trainer
+     trainer = pl.Trainer(
+         max_epochs=num_epochs,
+         accelerator=accelerator,
+         devices=devices,
+         callbacks=callbacks,
+         logger=logger,
+         log_every_n_steps=10,
+         **kwargs,
+     )
+
+     # Train model
+     print(f"Training {encoder_name} {architecture} for {num_epochs} epochs...")
+     trainer.fit(
+         model,
+         train_dataloaders=train_loader,
+         val_dataloaders=val_loader,
+         ckpt_path=checkpoint_path,
+     )
+
+     # Test if a test dataset is provided
+     if test_dataset is not None:
+         test_loader = DataLoader(
+             test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+             pin_memory=True,
+         )
+         print("\nTesting model on test set...")
+         trainer.test(model, dataloaders=test_loader)
+
+     print(f"\nBest model saved at: {checkpoint_callback.best_model_path}")
+
+     # Save training history in a compatible format
+     history = {
+         "train_loss": [],
+         "val_loss": [],
+         "val_iou": [],
+         "epochs": [],
+     }
+
+     # Extract epoch-level metrics from the CSV logger
+     import pandas as pd
+     import glob
+
+     csv_files = glob.glob(
+         os.path.join(model_dir, "lightning_logs", "**", "metrics.csv"), recursive=True
+     )
+     if csv_files:
+         df = pd.read_csv(csv_files[0])
+
+         # Group by epoch to get epoch-level metrics
+         epoch_data = df.groupby("epoch").last().reset_index()
+
+         if "train_loss_epoch" in epoch_data.columns:
+             history["train_loss"] = epoch_data["train_loss_epoch"].dropna().tolist()
+         if "val_loss" in epoch_data.columns:
+             history["val_loss"] = epoch_data["val_loss"].dropna().tolist()
+         if "val_iou" in epoch_data.columns:
+             history["val_iou"] = epoch_data["val_iou"].dropna().tolist()
+         if "epoch" in epoch_data.columns:
+             history["epochs"] = epoch_data["epoch"].dropna().tolist()
+
+     # Save history
+     history_path = os.path.join(model_dir, "training_history.pth")
+     torch.save(history, history_path)
+     print(f"Training history saved to: {history_path}")
+
+     return model
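A hedged end-to-end sketch of this entry point (paths and hyperparameters are illustrative; train_ds and val_ds are SegmentationDataset instances like those built above):

    model = train_timm_segmentation(
        train_dataset=train_ds,
        val_dataset=val_ds,
        encoder_name="efficientnet_b0",
        architecture="fpn",
        num_classes=2,
        in_channels=3,
        output_dir="runs/fpn_effb0",
        batch_size=8,
        num_epochs=30,
    )

Checkpoints land in the output directory's models/ subfolder, and per-epoch metrics are written to the CSV logger's metrics.csv.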
+
+
+ def predict_segmentation(
+     model: Union[TimmSegmentationModel, nn.Module],
+     image_paths: List[str],
+     batch_size: int = 8,
+     num_workers: int = 4,
+     device: Optional[str] = None,
+     return_probabilities: bool = False,
+ ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+     """
+     Make predictions on images using a trained segmentation model.
+
+     Args:
+         model: Trained model.
+         image_paths: List of paths to images.
+         batch_size: Batch size for inference.
+         num_workers: Number of data loading workers.
+         device: Device to use ('cuda', 'cpu', etc.). Auto-detected if None.
+         return_probabilities: If True, return both predictions and probabilities.
+
+     Returns:
+         predictions: Array of predicted segmentation masks.
+         probabilities (optional): Array of class probabilities if return_probabilities=True.
+     """
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Reuse the image paths as placeholder masks: SegmentationDataset reads band 1
+     # of each "mask", and the resulting labels are simply discarded during inference.
+     dummy_masks = image_paths
+     dataset = SegmentationDataset(image_paths, dummy_masks)
+
+     loader = DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=True,
+     )
+
+     model.eval()
+     model = model.to(device)
+
+     all_preds = []
+     all_probs = []
+
+     with torch.no_grad():
+         for images, _ in tqdm(loader, desc="Making predictions"):
+             images = images.to(device)
+             logits = model(images)
+
+             probs = torch.softmax(logits, dim=1)
+             preds = torch.argmax(probs, dim=1)
+
+             all_preds.append(preds.cpu().numpy())
+             if return_probabilities:
+                 all_probs.append(probs.cpu().numpy())
+
+     predictions = np.concatenate(all_preds)
+
+     if return_probabilities:
+         probabilities = np.concatenate(all_probs)
+         return predictions, probabilities
+
+     return predictions
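An illustrative call (file names hypothetical); note that the final np.concatenate requires all tiles to share the same height and width:

    preds, probs = predict_segmentation(
        model,
        image_paths=["scene_001.tif", "scene_002.tif"],
        batch_size=4,
        return_probabilities=True,
    )
    # preds: (N, H, W) class indices; probs: (N, num_classes, H, W)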
+
+
+ def train_timm_segmentation_model(
+     images_dir: str,
+     labels_dir: str,
+     output_dir: str,
+     input_format: str = "directory",
+     encoder_name: str = "resnet50",
+     architecture: str = "unet",
+     encoder_weights: str = "imagenet",
+     num_channels: int = 3,
+     num_classes: int = 2,
+     batch_size: int = 8,
+     num_epochs: int = 50,
+     learning_rate: float = 0.001,
+     weight_decay: float = 1e-4,
+     val_split: float = 0.2,
+     seed: int = 42,
+     num_workers: int = 4,
+     freeze_encoder: bool = False,
+     monitor_metric: str = "val_iou",
+     mode: str = "max",
+     patience: int = 10,
+     save_top_k: int = 1,
+     verbose: bool = True,
+     device: Optional[str] = None,
+     use_timm_model: bool = False,
+     timm_model_name: Optional[str] = None,
+     **kwargs: Any,
+ ) -> torch.nn.Module:
+     """
+     Train a semantic segmentation model using a timm encoder (simplified interface).
+
+     This is a simplified function that takes image and label directories and handles
+     dataset creation automatically, similar to train_segmentation_model.
+
+     Args:
+         images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+             or root directory containing an images/ subdirectory (for 'yolo' format),
+             or directory containing images (for 'coco' format).
+         labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+             or path to COCO annotations JSON file (for 'coco' format),
+             or not used (for 'yolo' format - labels are in images_dir/labels/).
+         output_dir (str): Directory to save model checkpoints and results.
+         input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+             - 'directory': Standard directory structure with separate images_dir and labels_dir
+             - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+             - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
+         encoder_name (str): Name of timm encoder (e.g., 'resnet50', 'efficientnet_b3').
+         architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
+             'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan').
+         encoder_weights (str): Pretrained weights ('imagenet', 'ssl', 'swsl', None).
+         num_channels (int): Number of input channels.
+         num_classes (int): Number of output classes.
+         batch_size (int): Batch size for training.
+         num_epochs (int): Number of training epochs.
+         learning_rate (float): Learning rate.
+         weight_decay (float): Weight decay for optimizer.
+         val_split (float): Validation split ratio (0-1).
+         seed (int): Random seed for reproducibility.
+         num_workers (int): Number of data loading workers.
+         freeze_encoder (bool): Freeze encoder during training.
+         monitor_metric (str): Metric to monitor ('val_loss' or 'val_iou').
+         mode (str): 'min' for loss, 'max' for metrics.
+         patience (int): Early stopping patience.
+         save_top_k (int): Number of best models to save.
+         verbose (bool): Print training progress.
+         device (str, optional): Device to use. Auto-detected if None.
+         use_timm_model (bool): Load complete segmentation model from timm/HF Hub.
+         timm_model_name (str, optional): Model name from HF Hub (e.g., 'hf-hub:nvidia/mit-b0').
+         **kwargs: Additional arguments for training.
+
+     Returns:
+         torch.nn.Module: Trained model.
+     """
+     import glob
+     from sklearn.model_selection import train_test_split
+     from .train import parse_coco_annotations, parse_yolo_annotations
+
+     if not LIGHTNING_AVAILABLE:
+         raise ImportError(
+             "PyTorch Lightning is required. Install it with: pip install lightning"
+         )
+
+     # Set random seeds
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+     # Get image and label paths based on the input format
+     if input_format.lower() == "coco":
+         # Parse COCO format annotations
+         if verbose:
+             print(f"Loading COCO format annotations from {labels_dir}")
+         # For COCO format, labels_dir is the path to instances.json.
+         # Labels are typically in a "labels" directory parallel to "annotations".
+         coco_root = os.path.dirname(os.path.dirname(labels_dir))  # Go up two levels
+         labels_directory = os.path.join(coco_root, "labels")
+         image_paths, label_paths = parse_coco_annotations(
+             labels_dir, images_dir, labels_directory
+         )
+     elif input_format.lower() == "yolo":
+         # Parse YOLO format annotations
+         if verbose:
+             print(f"Loading YOLO format data from {images_dir}")
+         image_paths, label_paths = parse_yolo_annotations(images_dir)
+     else:
+         # Default: directory format
+         image_paths = sorted(
+             glob.glob(os.path.join(images_dir, "*.tif"))
+             + glob.glob(os.path.join(images_dir, "*.tiff"))
+         )
+         label_paths = sorted(
+             glob.glob(os.path.join(labels_dir, "*.tif"))
+             + glob.glob(os.path.join(labels_dir, "*.tiff"))
+         )
+
+     if len(image_paths) == 0:
+         raise ValueError(f"No images found in {images_dir}")
+     if len(label_paths) == 0:
+         raise ValueError(f"No labels found in {labels_dir}")
+     if len(image_paths) != len(label_paths):
+         raise ValueError(
+             f"Number of images ({len(image_paths)}) doesn't match "
+             f"number of labels ({len(label_paths)})"
+         )
+
+     if verbose:
+         print(f"Found {len(image_paths)} image-label pairs")
+
+     # Split into train and validation sets
+     train_images, val_images, train_labels, val_labels = train_test_split(
+         image_paths, label_paths, test_size=val_split, random_state=seed
+     )
+
+     if verbose:
+         print(f"Training samples: {len(train_images)}")
+         print(f"Validation samples: {len(val_images)}")
+
+     # Create datasets
+     train_dataset = SegmentationDataset(
+         image_paths=train_images,
+         mask_paths=train_labels,
+         num_channels=num_channels,
+     )
+
+     val_dataset = SegmentationDataset(
+         image_paths=val_images,
+         mask_paths=val_labels,
+         num_channels=num_channels,
+     )
+
+     # Train the model
+     model = train_timm_segmentation(
+         train_dataset=train_dataset,
+         val_dataset=val_dataset,
+         test_dataset=None,
+         encoder_name=encoder_name,
+         architecture=architecture,
+         num_classes=num_classes,
+         in_channels=num_channels,
+         encoder_weights=encoder_weights,
+         output_dir=output_dir,
+         batch_size=batch_size,
+         num_epochs=num_epochs,
+         learning_rate=learning_rate,
+         weight_decay=weight_decay,
+         num_workers=num_workers,
+         freeze_encoder=freeze_encoder,
+         accelerator="auto" if device is None else device,
+         monitor_metric=monitor_metric,
+         mode=mode,
+         patience=patience,
+         save_top_k=save_top_k,
+         use_timm_model=use_timm_model,
+         timm_model_name=timm_model_name,
+         **kwargs,
+     )
+
+     if verbose:
+         print(f"\nTraining completed. Model saved to {output_dir}")
+
+     return model.model  # Return the underlying model
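A sketch of the simplified interface under the default 'directory' format (directory names hypothetical):

    model = train_timm_segmentation_model(
        images_dir="data/images",
        labels_dir="data/labels",
        output_dir="output",
        encoder_name="resnet50",
        architecture="unet",
        num_channels=3,
        num_classes=2,
        val_split=0.2,
    )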
+
+
+ def timm_semantic_segmentation(
+     input_path: str,
+     output_path: str,
+     model_path: str,
+     encoder_name: str = "resnet50",
+     architecture: str = "unet",
+     num_channels: int = 3,
+     num_classes: int = 2,
+     window_size: int = 512,
+     overlap: int = 256,
+     batch_size: int = 4,
+     device: Optional[str] = None,
+     quiet: bool = False,
+     use_timm_model: bool = False,
+     timm_model_name: Optional[str] = None,
+     **kwargs: Any,
+ ) -> None:
+     """
+     Perform semantic segmentation on a raster using a trained timm model.
+
+     This function performs inference on a GeoTIFF using a sliding-window approach
+     and saves the result as a georeferenced raster.
+
+     Args:
+         input_path (str): Path to input GeoTIFF file.
+         output_path (str): Path to save output mask.
+         model_path (str): Path to trained model checkpoint (.ckpt or .pth).
+         encoder_name (str): Name of timm encoder used in training.
+         architecture (str): Segmentation architecture used in training.
+         num_channels (int): Number of input channels.
+         num_classes (int): Number of output classes.
+         window_size (int): Size of sliding window for inference.
+         overlap (int): Overlap between adjacent windows.
+         batch_size (int): Batch size for inference.
+         device (str, optional): Device to use. Auto-detected if None.
+         quiet (bool): If True, suppress progress messages.
+         use_timm_model (bool): If True, the model was trained with a timm model from HF Hub.
+         timm_model_name (str, optional): Model name from HF Hub used during training.
+         **kwargs: Additional arguments.
+     """
+     import rasterio
+     from rasterio.windows import Window
+
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Load model
+     if model_path.endswith(".ckpt"):
+         model = TimmSegmentationModel.load_from_checkpoint(
+             model_path,
+             encoder_name=encoder_name,
+             architecture=architecture,
+             num_classes=num_classes,
+             in_channels=num_channels,
+             use_timm_model=use_timm_model,
+             timm_model_name=timm_model_name,
+         )
+         model = model.model  # Get the underlying model
+     else:
+         # Load a bare state dict
+         if use_timm_model:
+             # Load pure timm model
+             if timm_model_name is None:
+                 timm_model_name = encoder_name
+
+             model = timm.create_model(
+                 timm_model_name,
+                 pretrained=False,
+                 num_classes=num_classes,
+                 in_chans=num_channels,
+             )
+         else:
+             # Load SMP model
+             import segmentation_models_pytorch as smp
+
+             try:
+                 model = smp.create_model(
+                     arch=architecture,
+                     encoder_name=encoder_name,
+                     encoder_weights=None,
+                     in_channels=num_channels,
+                     classes=num_classes,
+                 )
+             except Exception as e:
+                 raise ValueError(
+                     f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
+                     f"Error: {str(e)}"
+                 ) from e
+
+         model.load_state_dict(torch.load(model_path, map_location=device))
+
+     model.eval()
+     model = model.to(device)
+
+     # Read input raster
+     with rasterio.open(input_path) as src:
+         meta = src.meta.copy()
+         height, width = src.shape
+
+         # Calculate the number of windows (at least one per axis, so rasters
+         # smaller than the window are still processed)
+         stride = window_size - overlap
+         n_rows = max(1, int(np.ceil((height - overlap) / stride)))
+         n_cols = max(1, int(np.ceil((width - overlap) / stride)))
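As a worked example of the window arithmetic (editorial, not from the source): a 1024 × 1024 raster with window_size=512 and overlap=256 gives stride = 256 and ceil((1024 − 256) / 256) = 3 windows per axis, i.e. 9 windows in total, each sharing a 256-pixel border with its neighbors.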
+
+         if not quiet:
+             print(f"Processing {n_rows} x {n_cols} = {n_rows * n_cols} windows")
+
+         # Initialize output arrays (int32 avoids overflow during accumulation)
+         output = np.zeros((height, width), dtype=np.int32)
+         count = np.zeros((height, width), dtype=np.int32)
+
+         # Process windows
+         with torch.no_grad():
+             for i in tqdm(range(n_rows), disable=quiet, desc="Processing rows"):
+                 for j in range(n_cols):
+                     # Calculate window bounds
+                     row_start = i * stride
+                     col_start = j * stride
+                     row_end = min(row_start + window_size, height)
+                     col_end = min(col_start + window_size, width)
+
+                     # Read window
+                     window = Window(
+                         col_start, row_start, col_end - col_start, row_end - row_start
+                     )
+                     img = src.read(window=window)
+
+                     # Handle channel selection
+                     if img.shape[0] > num_channels:
+                         img = img[:num_channels]
+                     elif img.shape[0] < num_channels:
+                         padded = np.zeros((num_channels, img.shape[1], img.shape[2]))
+                         padded[: img.shape[0]] = img
+                         img = padded
+
+                     # Scale 8-bit pixel values to [0, 1]
+                     if img.max() > 1.0:
+                         img = img / 255.0
+                     img = img.astype(np.float32)
+
+                     # Pad if necessary
+                     h, w = img.shape[1], img.shape[2]
+                     if h < window_size or w < window_size:
+                         padded = np.zeros(
+                             (num_channels, window_size, window_size), dtype=np.float32
+                         )
+                         padded[:, :h, :w] = img
+                         img = padded
+
+                     # Predict
+                     img_tensor = torch.from_numpy(img).unsqueeze(0).to(device)
+                     logits = model(img_tensor)
+                     pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
+
+                     # Crop to the actual window size
+                     pred = pred[:h, :w]
+
+                     # Accumulate into the output
+                     output[row_start:row_end, col_start:col_end] += pred
+                     count[row_start:row_end, col_start:col_end] += 1
+
+     # Resolve overlaps by averaging the accumulated class indices
+     # (the float average is truncated back to integer class labels)
+     output = (output / np.maximum(count, 1)).astype(np.uint8)
+
+     # Save output
+     meta.update({"count": 1, "dtype": "uint8", "compress": "lzw"})
+
+     with rasterio.open(output_path, "w", **meta) as dst:
+         dst.write(output, 1)
+
+     if not quiet:
+         print(f"Segmentation saved to {output_path}")
+
+
+ def push_timm_model_to_hub(
+     model_path: str,
+     repo_id: str,
+     encoder_name: str = "resnet50",
+     architecture: str = "unet",
+     num_channels: int = 3,
+     num_classes: int = 2,
+     use_timm_model: bool = False,
+     timm_model_name: Optional[str] = None,
+     commit_message: Optional[str] = None,
+     private: bool = False,
+     token: Optional[str] = None,
+     **kwargs: Any,
+ ) -> str:
+     """
+     Push a trained timm segmentation model to the Hugging Face Hub.
+
+     Args:
+         model_path (str): Path to trained model checkpoint (.ckpt or .pth).
+         repo_id (str): Repository ID on HF Hub (e.g., 'username/model-name').
+         encoder_name (str): Name of timm encoder used in training.
+         architecture (str): Segmentation architecture used in training.
+         num_channels (int): Number of input channels.
+         num_classes (int): Number of output classes.
+         use_timm_model (bool): If True, the model was trained as a pure timm model.
+         timm_model_name (str, optional): Model name from HF Hub used during training.
+         commit_message (str, optional): Commit message for the upload.
+         private (bool): Whether to make the repository private.
+         token (str, optional): Hugging Face API token. If None, uses the logged-in token.
+         **kwargs: Additional arguments for upload_folder.
+
+     Returns:
+         str: URL of the uploaded model on HF Hub.
+     """
+     try:
+         from huggingface_hub import HfApi, create_repo
+     except ImportError:
+         raise ImportError(
+             "huggingface_hub is required to push models. "
+             "Install it with: pip install huggingface-hub"
+         )
+
+     # Load model
+     if model_path.endswith(".ckpt"):
+         lightning_model = TimmSegmentationModel.load_from_checkpoint(
+             model_path,
+             encoder_name=encoder_name,
+             architecture=architecture,
+             num_classes=num_classes,
+             in_channels=num_channels,
+             use_timm_model=use_timm_model,
+             timm_model_name=timm_model_name,
+         )
+         model = lightning_model.model
+     else:
+         # Load a bare state dict
+         if use_timm_model:
+             if timm_model_name is None:
+                 timm_model_name = encoder_name
+
+             model = timm.create_model(
+                 timm_model_name,
+                 pretrained=False,
+                 num_classes=num_classes,
+                 in_chans=num_channels,
+             )
+         else:
+             import segmentation_models_pytorch as smp
+
+             model = smp.create_model(
+                 arch=architecture,
+                 encoder_name=encoder_name,
+                 encoder_weights=None,
+                 in_channels=num_channels,
+                 classes=num_classes,
+             )
+
+         model.load_state_dict(torch.load(model_path, map_location="cpu"))
+
+     # Create the repository if it doesn't exist
+     api = HfApi(token=token)
+     try:
+         create_repo(repo_id, private=private, token=token, exist_ok=True)
+     except Exception as e:
+         print(f"Repository creation note: {e}")
+
+     # Save model configuration
+     config = {
+         "encoder_name": encoder_name,
+         "architecture": architecture,
+         "num_channels": num_channels,
+         "num_classes": num_classes,
+         "use_timm_model": use_timm_model,
+         "timm_model_name": timm_model_name,
+         "model_type": "timm_segmentation",
+     }
+
+     # Save the state dict and config to a temporary directory, then upload
+     import tempfile
+     import json
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         # Save model state dict
+         model_save_path = os.path.join(tmpdir, "model.pth")
+         torch.save(model.state_dict(), model_save_path)
+
+         # Save config
+         config_path = os.path.join(tmpdir, "config.json")
+         with open(config_path, "w") as f:
+             json.dump(config, f, indent=2)
+
+         # Upload files
+         if commit_message is None:
+             commit_message = f"Upload {architecture} with {encoder_name} encoder"
+
+         api.upload_folder(
+             folder_path=tmpdir,
+             repo_id=repo_id,
+             commit_message=commit_message,
+             token=token,
+             **kwargs,
+         )
+
+     url = f"https://huggingface.co/{repo_id}"
+     print(f"Model successfully pushed to: {url}")
+     return url
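Finally, a hedged upload sketch (repo id and checkpoint path hypothetical; requires a Hugging Face account and, if no cached login, an explicit token):

    url = push_timm_model_to_hub(
        model_path="runs/unet_resnet50/models/last.ckpt",
        repo_id="username/landcover-unet-resnet50",
        encoder_name="resnet50",
        architecture="unet",
        num_channels=3,
        num_classes=2,
    )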