geoai-py 0.25.0__py2.py3-none-any.whl → 0.27.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/timm_regress.py ADDED
@@ -0,0 +1,1652 @@
1
+ """Module for training pixel-level regression models using timm encoders with PyTorch Lightning.
2
+
3
+ This module provides tools for remote sensing regression tasks like predicting NDVI,
4
+ biomass, temperature, or other continuous values at the pixel level.
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from tqdm import tqdm
15
+
16
+ try:
17
+ import timm
18
+
19
+ TIMM_AVAILABLE = True
20
+ except ImportError:
21
+ TIMM_AVAILABLE = False
22
+
23
+ try:
24
+ import segmentation_models_pytorch as smp
25
+
26
+ SMP_AVAILABLE = True
27
+ except ImportError:
28
+ SMP_AVAILABLE = False
29
+
30
+ try:
31
+ import lightning.pytorch as pl
32
+ from lightning.pytorch.callbacks import (
33
+ ModelCheckpoint,
34
+ EarlyStopping,
35
+ TQDMProgressBar,
36
+ )
37
+ from lightning.pytorch.loggers import CSVLogger
38
+
39
+ LIGHTNING_AVAILABLE = True
40
+ except ImportError:
41
+ LIGHTNING_AVAILABLE = False
42
+
43
+
44
class _CompactProgressBar(TQDMProgressBar if LIGHTNING_AVAILABLE else object):
    """Progress bar that shows key metrics in the postfix, updated in place.

    The base class is chosen at definition time: ``TQDMProgressBar`` when
    PyTorch Lightning imported successfully, ``object`` otherwise.  The
    original code subclassed ``TQDMProgressBar`` unconditionally, which made
    importing this module raise ``NameError`` whenever lightning was not
    installed, defeating the ``LIGHTNING_AVAILABLE`` guard.  The class is
    only instantiated from ``train_pixel_regressor`` (which requires
    lightning), so the ``object`` fallback is never used as a callback.
    """

    def get_metrics(self, trainer, pl_module):
        # Don't let Lightning set the postfix — we control it
        return {}

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        # Clear any stale postfix after each batch; metrics are only shown
        # once per validation epoch (see on_validation_epoch_end).
        if self.train_progress_bar is not None:
            self.train_progress_bar.set_postfix_str("")

    def on_validation_epoch_end(self, trainer, pl_module):
        super().on_validation_epoch_end(trainer, pl_module)
        metrics = trainer.callback_metrics
        if metrics and self.train_progress_bar is not None:
            keys = [
                "train_loss_epoch",
                "train_r2",
                "val_loss",
                "val_rmse",
                "val_r2",
            ]
            parts = []
            for k in keys:
                v = metrics.get(k)
                if v is not None:
                    # callback_metrics values are usually tensors; unwrap them.
                    val = v.item() if hasattr(v, "item") else v
                    if isinstance(val, float):
                        parts.append(f"{k}={val:.4g}")
            if parts:
                self.train_progress_bar.set_postfix_str(", ".join(parts))
76
+
77
+
78
+ def _prepare_normalization_stats(
79
+ mean: List[float], std: List[float], num_channels: int
80
+ ) -> Tuple[np.ndarray, np.ndarray]:
81
+ """Prepare mean/std arrays that match the number of channels."""
82
+ if len(mean) == num_channels and len(std) == num_channels:
83
+ mean_vals = mean
84
+ std_vals = std
85
+ elif len(mean) == 3 and len(std) == 3 and num_channels > 3:
86
+ mean_pad = [float(np.mean(mean))] * (num_channels - 3)
87
+ std_pad = [float(np.mean(std))] * (num_channels - 3)
88
+ mean_vals = list(mean) + mean_pad
89
+ std_vals = list(std) + std_pad
90
+ elif len(mean) == 1 and len(std) == 1:
91
+ mean_vals = list(mean) * num_channels
92
+ std_vals = list(std) * num_channels
93
+ elif len(mean) >= num_channels and len(std) >= num_channels:
94
+ mean_vals = list(mean)[:num_channels]
95
+ std_vals = list(std)[:num_channels]
96
+ else:
97
+ raise ValueError(
98
+ "Normalization stats length must match channels. "
99
+ f"Got mean={len(mean)}, std={len(std)}, channels={num_channels}."
100
+ )
101
+
102
+ mean_arr = np.array(mean_vals, dtype=np.float32)[:, None, None]
103
+ std_arr = np.array(std_vals, dtype=np.float32)[:, None, None]
104
+ return mean_arr, std_arr
105
+
106
+
107
+ def _infer_preprocessing_params(
108
+ encoder_name: str, encoder_weights: Optional[str]
109
+ ) -> Optional[Dict[str, Any]]:
110
+ if encoder_weights is None:
111
+ return None
112
+ if not SMP_AVAILABLE:
113
+ return None
114
+ try:
115
+ return smp.encoders.get_preprocessing_params(
116
+ encoder_name, pretrained=encoder_weights
117
+ )
118
+ except Exception:
119
+ return None
120
+
121
+
122
class PixelRegressionModel(pl.LightningModule if LIGHTNING_AVAILABLE else object):
    """
    PyTorch Lightning module for pixel-level regression using encoder-decoder architectures.

    Uses segmentation-models-pytorch (SMP) with timm encoders but configured for
    regression (single channel output with continuous values).

    Note:
        The base class falls back to ``object`` when PyTorch Lightning is not
        installed so that importing this module never raises ``NameError``
        (the original code referenced ``pl.LightningModule`` unconditionally
        at class-definition time). Instantiating the class without lightning
        now fails with an explicit ``ImportError`` instead.
    """

    def __init__(
        self,
        encoder_name: str = "resnet50",
        architecture: str = "unet",
        in_channels: int = 3,
        encoder_weights: str = "imagenet",
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
        freeze_encoder: bool = False,
        loss_fn: Optional[nn.Module] = None,
        loss_type: str = "mse",
        **decoder_kwargs: Any,
    ):
        """
        Initialize PixelRegressionModel.

        Args:
            encoder_name (str): Name of timm encoder (e.g., 'resnet50', 'efficientnet_b0').
            architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
                'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan').
            in_channels (int): Number of input channels (3 for RGB, 4 for RGBN, etc.).
            encoder_weights (str): Pretrained weights for encoder ('imagenet', None).
            learning_rate (float): Learning rate for optimizer.
            weight_decay (float): Weight decay for optimizer.
            freeze_encoder (bool): Freeze encoder weights during training.
            loss_fn (nn.Module, optional): Custom loss function.
            loss_type (str): Type of loss if loss_fn is None ('mse', 'l1', 'mae', 'huber').
            **decoder_kwargs: Additional arguments for decoder.

        Raises:
            ImportError: If PyTorch Lightning or segmentation-models-pytorch
                is not installed.
            ValueError: If the architecture/encoder combination cannot be
                built, or ``loss_type`` is unknown.
        """
        if not LIGHTNING_AVAILABLE:
            raise ImportError(
                "PyTorch Lightning is required. Install with: pip install lightning"
            )

        super().__init__()

        if not SMP_AVAILABLE:
            raise ImportError(
                "segmentation-models-pytorch is required. "
                "Install it with: pip install segmentation-models-pytorch"
            )

        self.save_hyperparameters()

        # Create segmentation model with 1 output class for regression
        try:
            self.model = smp.create_model(
                arch=architecture,
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=in_channels,
                classes=1,  # Single channel for regression
                **decoder_kwargs,
            )
        except Exception as e:
            available_archs = [
                "unet",
                "unetplusplus",
                "manet",
                "linknet",
                "fpn",
                "pspnet",
                "deeplabv3",
                "deeplabv3plus",
                "pan",
                "upernet",
            ]
            # Chain the original exception so the SMP error is not lost.
            raise ValueError(
                f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
                f"Error: {str(e)}. "
                f"Available architectures: {', '.join(available_archs)}."
            ) from e

        if freeze_encoder:
            self._freeze_encoder()

        # Set up loss function
        if loss_fn is not None:
            self.loss_fn = loss_fn
        else:
            loss_type = loss_type.lower()
            if loss_type == "mse":
                self.loss_fn = nn.MSELoss()
            elif loss_type in ["l1", "mae"]:
                self.loss_fn = nn.L1Loss()
            elif loss_type in ["huber", "smooth_l1"]:
                self.loss_fn = nn.SmoothL1Loss()
            else:
                raise ValueError(f"Unknown loss_type: {loss_type}")

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def _freeze_encoder(self):
        """Freeze encoder weights so only the decoder/head is trained."""
        if hasattr(self.model, "encoder"):
            for param in self.model.encoder.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Predict a continuous value per pixel.

        Args:
            x: Input batch, expected shape (B, C, H, W).

        Returns:
            Tensor of shape (B, H, W): channel dim of the single-class
            segmentation output is squeezed away.
        """
        return self.model(x).squeeze(1)  # (B, 1, H, W) -> (B, H, W)

    def _compute_metrics(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Compute pixel-wise regression metrics (MSE, RMSE, MAE, R²)."""
        # Flatten so metrics pool over the whole batch and all pixels.
        preds_flat = preds.view(-1)
        targets_flat = targets.view(-1)

        # MSE and RMSE
        mse = torch.mean((preds_flat - targets_flat) ** 2)
        rmse = torch.sqrt(mse)

        # MAE
        mae = torch.mean(torch.abs(preds_flat - targets_flat))

        # R² (coefficient of determination); epsilon guards against a
        # constant target batch (ss_tot == 0).
        ss_res = torch.sum((targets_flat - preds_flat) ** 2)
        ss_tot = torch.sum((targets_flat - targets_flat.mean()) ** 2)
        r2 = 1 - ss_res / (ss_tot + 1e-8)

        return {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2}

    def training_step(self, batch, batch_idx):
        """One optimization step: forward, loss, and metric logging."""
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        # _prog_bar_metrics is set to False by train_pixel_regressor when
        # the compact progress bar is active (non-verbose mode).
        pb = getattr(self, "_prog_bar_metrics", True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=pb)
        self.log("train_rmse", metrics["rmse"], on_step=False, on_epoch=True)
        self.log("train_r2", metrics["r2"], on_step=False, on_epoch=True, prog_bar=pb)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: loss plus RMSE/MAE/R² logged per epoch."""
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        pb = getattr(self, "_prog_bar_metrics", True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=pb)
        self.log("val_rmse", metrics["rmse"], on_epoch=True, prog_bar=pb)
        self.log("val_mae", metrics["mae"], on_epoch=True)
        self.log("val_r2", metrics["r2"], on_epoch=True, prog_bar=pb)

        return loss

    def test_step(self, batch, batch_idx):
        """Test step: same metrics as validation, logged under test_*."""
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        self.log("test_loss", loss, on_epoch=True)
        self.log("test_rmse", metrics["rmse"], on_epoch=True)
        self.log("test_mae", metrics["mae"], on_epoch=True)
        self.log("test_r2", metrics["r2"], on_epoch=True)

        return loss

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau scheduling monitored on val_loss."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_step(self, batch, batch_idx):
        """Prediction step; accepts either a bare tensor or an (x, ...) tuple."""
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)
312
+
313
+
314
class PixelRegressionDataset(Dataset):
    """
    Dataset for pixel-level regression from paired image and target rasters.

    Each item is a ``(image, target)`` tensor pair read from the raster files
    at the same index of ``image_paths`` and ``target_paths``.
    """

    def __init__(
        self,
        image_paths: List[str],
        target_paths: List[str],
        input_bands: Optional[List[int]] = None,
        target_band: int = 1,
        transform: Optional[Callable] = None,
        normalize_input: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        target_nodata: Optional[float] = None,
    ):
        """
        Initialize PixelRegressionDataset.

        Args:
            image_paths (List[str]): Paths to input image patches.
            target_paths (List[str]): Paths to target raster patches
                (must be the same length as ``image_paths``).
            input_bands (List[int], optional): Band indices to read (1-indexed).
                If None, all bands are read.
            target_band (int): Band index for the target raster (1-indexed).
            transform (callable, optional): Transform applied to the image tensor.
            normalize_input (bool): Whether to scale input values into [0, 1].
            image_mean (List[float], optional): Per-channel mean for normalization.
            image_std (List[float], optional): Per-channel std for normalization.
            target_nodata (float, optional): NoData value for targets.

        Raises:
            ValueError: If image and target path lists differ in length.
        """
        if len(image_paths) != len(target_paths):
            raise ValueError("Number of images must match number of targets")

        self.image_paths = image_paths
        self.target_paths = target_paths
        self.input_bands = input_bands
        self.target_band = target_band
        self.transform = transform
        self.normalize_input = normalize_input
        self.image_mean = image_mean
        self.image_std = image_std
        # NOTE(review): stored but not consulted in __getitem__ — NaNs are
        # simply replaced with 0; confirm whether nodata masking was intended.
        self.target_nodata = target_nodata
        self._mean_array = None
        self._std_array = None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        import rasterio

        # Read the input patch (optionally a band subset).
        with rasterio.open(self.image_paths[idx]) as img_src:
            if self.input_bands is None:
                img = img_src.read()
            else:
                img = img_src.read(self.input_bands)

        # Read the single-band target patch.
        with rasterio.open(self.target_paths[idx]) as tgt_src:
            tgt = tgt_src.read(self.target_band)

        # NaNs become zeros before any scaling.
        img = np.nan_to_num(img, nan=0.0)
        tgt = np.nan_to_num(tgt, nan=0.0)

        img = img.astype(np.float32)
        if self.normalize_input:
            # Heuristic: values already near [0, 1] are clipped in place;
            # larger values are assumed to be 0–10000 reflectance scale.
            peak = np.abs(img).max()
            if peak <= 1.5:
                img = np.clip(img, 0, 1)
            else:
                img = np.clip(img, 0, 10000) / 10000.0

        if self.image_mean is not None and self.image_std is not None:
            mean_arr, std_arr = _prepare_normalization_stats(
                self.image_mean, self.image_std, img.shape[0]
            )
            img = (img - mean_arr) / std_arr

        tgt = tgt.astype(np.float32)

        img_tensor = torch.from_numpy(img)
        tgt_tensor = torch.from_numpy(tgt)

        if self.transform is not None:
            img_tensor = self.transform(img_tensor)

        return img_tensor, tgt_tensor
409
+
410
+
411
def create_regression_tiles(
    input_raster: str,
    target_raster: str,
    output_dir: str,
    tile_size: int = 256,
    stride: Optional[int] = None,
    input_bands: Optional[List[int]] = None,
    target_band: int = 1,
    min_valid_ratio: float = 0.8,
    target_min: Optional[float] = None,
    target_max: Optional[float] = None,
) -> Tuple[List[str], List[str]]:
    """
    Create paired image and target tiles from input and target rasters.

    Tiles are cut on a regular grid (no partial edge tiles: the loops stop at
    the last position where a full tile fits). A tile is kept only when enough
    of its pixels are valid (non-NaN in both rasters) and, if a target range
    is given, when at most 5% of its valid target pixels fall outside that
    range. Kept tiles are written as GeoTIFFs under
    ``output_dir/images`` and ``output_dir/targets`` with matching
    ``tile_{index:06d}.tif`` names; the index counts every grid position
    (including skipped ones), so file names are grid-stable across runs.

    Args:
        input_raster (str): Path to input raster (e.g., Landsat imagery).
        target_raster (str): Path to target raster (e.g., NDVI). Assumed to be
            pixel-aligned with the input raster — the same window is read
            from both.
        output_dir (str): Directory to save tiles.
        tile_size (int): Size of each tile (tile_size x tile_size pixels).
        stride (int, optional): Stride between tiles. Defaults to tile_size (no overlap).
        input_bands (List[int], optional): Band indices to use (1-indexed).
        target_band (int): Band index for target raster (1-indexed).
        min_valid_ratio (float): Minimum ratio of valid (non-NaN) pixels in a tile.
        target_min (float, optional): Minimum valid target value.
        target_max (float, optional): Maximum valid target value.

    Returns:
        Tuple of (image_paths, target_paths): Lists of tile paths.
    """
    import rasterio

    image_dir = os.path.join(output_dir, "images")
    target_dir = os.path.join(output_dir, "targets")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    if stride is None:
        stride = tile_size

    image_paths = []
    target_paths = []

    with rasterio.open(input_raster) as src_input:
        with rasterio.open(target_raster) as src_target:
            height = src_input.height
            width = src_input.width

            if input_bands is None:
                input_bands = list(range(1, src_input.count + 1))

            # Number of full tiles that fit along each axis.
            n_tiles_y = (height - tile_size) // stride + 1
            n_tiles_x = (width - tile_size) // stride + 1

            print(f"Input raster: {width}x{height}, {src_input.count} bands")
            print(f"Target raster: {src_target.width}x{src_target.height}")
            print(f"Tile size: {tile_size}x{tile_size}, stride: {stride}")
            print(
                f"Expected tiles: {n_tiles_y} x {n_tiles_x} = {n_tiles_y * n_tiles_x}"
            )

            # tile_idx counts every grid position (kept or skipped) so that
            # output file names encode the grid position.
            tile_idx = 0
            valid_tiles = 0
            skipped_nodata = 0
            skipped_range = 0

            for row in tqdm(
                range(0, height - tile_size + 1, stride), desc="Creating tiles"
            ):
                for col in range(0, width - tile_size + 1, stride):
                    window = rasterio.windows.Window(col, row, tile_size, tile_size)

                    # Read tiles (same window from both rasters — assumes
                    # the rasters are pixel-aligned).
                    input_tile = src_input.read(input_bands, window=window)
                    target_tile = src_target.read(target_band, window=window)

                    # A pixel is valid only when no input band and the
                    # target are NaN at that location.
                    valid_mask = ~np.isnan(input_tile).any(axis=0) & ~np.isnan(
                        target_tile
                    )
                    valid_ratio = valid_mask.sum() / (tile_size * tile_size)

                    if valid_ratio < min_valid_ratio:
                        tile_idx += 1
                        skipped_nodata += 1
                        continue

                    # Check target range - skip tiles with >5% out-of-range values
                    # NOTE(review): if min_valid_ratio is 0 a fully-NaN tile can
                    # reach this point with len(valid_target) == 0, making the
                    # ratio division raise ZeroDivisionError — confirm whether
                    # min_valid_ratio == 0 is a supported configuration.
                    valid_target = target_tile[valid_mask]
                    out_of_range_ratio = 0.0
                    if target_min is not None or target_max is not None:
                        out_of_range = np.zeros_like(valid_target, dtype=bool)
                        if target_min is not None:
                            out_of_range |= valid_target < target_min
                        if target_max is not None:
                            out_of_range |= valid_target > target_max
                        out_of_range_ratio = out_of_range.sum() / len(valid_target)

                    # Skip if more than 5% of pixels are out of range
                    if out_of_range_ratio > 0.05:
                        tile_idx += 1
                        skipped_range += 1
                        continue

                    # Replace NaN with 0 (tiles passing the checks may still
                    # contain up to (1 - min_valid_ratio) NaN pixels).
                    input_tile = np.nan_to_num(input_tile, nan=0.0)
                    target_tile = np.nan_to_num(target_tile, nan=0.0)

                    # Clip target values to valid range (important!)
                    if target_min is not None or target_max is not None:
                        target_tile = np.clip(
                            target_tile,
                            target_min if target_min is not None else -np.inf,
                            target_max if target_max is not None else np.inf,
                        )

                    # Save tiles under matching names in images/ and targets/.
                    image_path = os.path.join(image_dir, f"tile_{tile_idx:06d}.tif")
                    target_path = os.path.join(target_dir, f"tile_{tile_idx:06d}.tif")

                    # Save input tile with a georeferenced transform derived
                    # from the window position.
                    profile = src_input.profile.copy()
                    profile.update(
                        width=tile_size,
                        height=tile_size,
                        count=len(input_bands),
                        dtype=input_tile.dtype,
                        transform=rasterio.windows.transform(
                            window, src_input.transform
                        ),
                        tiled=False,  # Disable tiling for small tiles
                    )
                    # Remove block size settings that cause warnings
                    profile.pop("blockxsize", None)
                    profile.pop("blockysize", None)
                    with rasterio.open(image_path, "w", **profile) as dst:
                        dst.write(input_tile)

                    # Save target tile (single band).
                    profile = src_target.profile.copy()
                    profile.update(
                        width=tile_size,
                        height=tile_size,
                        count=1,
                        dtype=target_tile.dtype,
                        transform=rasterio.windows.transform(
                            window, src_target.transform
                        ),
                        tiled=False,
                    )
                    profile.pop("blockxsize", None)
                    profile.pop("blockysize", None)
                    with rasterio.open(target_path, "w", **profile) as dst:
                        # Add a leading band axis: rasterio expects (bands, H, W).
                        dst.write(target_tile[np.newaxis, :, :])

                    image_paths.append(image_path)
                    target_paths.append(target_path)
                    valid_tiles += 1
                    tile_idx += 1

    print(f"\nCreated {valid_tiles} valid tiles out of {tile_idx} total")
    print(f"Skipped due to nodata: {skipped_nodata}")
    print(f"Skipped due to target range: {skipped_range}")

    return image_paths, target_paths
576
+
577
+
578
def train_pixel_regressor(
    train_image_paths: List[str],
    train_target_paths: List[str],
    val_image_paths: Optional[List[str]] = None,
    val_target_paths: Optional[List[str]] = None,
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    in_channels: int = 3,
    encoder_weights: str = "imagenet",
    output_dir: str = "output",
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-4,
    num_workers: int = 0,
    freeze_encoder: bool = False,
    loss_type: str = "mse",
    normalize_input: bool = True,
    accelerator: str = "auto",
    devices: int = 1,
    monitor_metric: str = "val_loss",
    mode: str = "min",
    patience: int = 10,
    save_top_k: int = 1,
    checkpoint_path: Optional[str] = None,
    input_bands: Optional[List[int]] = None,
    verbose: bool = True,
    **kwargs: Any,
) -> PixelRegressionModel:
    """
    Train a pixel-level regression model.

    Builds a :class:`PixelRegressionModel`, wraps the tile paths in
    :class:`PixelRegressionDataset` loaders, and runs a Lightning ``Trainer``
    with checkpointing and early stopping. If the encoder's pretrained
    preprocessing statistics match ``in_channels``, they are applied to the
    datasets; otherwise only the ``normalize_input`` scaling is used. After
    training, the best checkpoint (by ``monitor_metric``) is reloaded and
    returned; if no best checkpoint exists, the last-epoch model is returned.

    Args:
        train_image_paths: List of training image paths.
        train_target_paths: List of training target paths.
        val_image_paths: List of validation image paths.
        val_target_paths: List of validation target paths.
        encoder_name: Name of timm encoder.
        architecture: Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3plus', etc.).
        in_channels: Number of input channels.
        encoder_weights: Pretrained weights for encoder.
        output_dir: Directory to save outputs (checkpoints go to output_dir/models).
        batch_size: Batch size for training.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate.
        weight_decay: Weight decay.
        num_workers: Number of data loading workers.
        freeze_encoder: Freeze encoder during training.
        loss_type: Loss function type ('mse', 'l1', 'huber').
        normalize_input: Normalize input tiles to expected range.
        accelerator: Accelerator type.
        devices: Number of devices.
        monitor_metric: Metric to monitor for checkpointing.
        mode: 'min' or 'max' for monitor_metric.
        patience: Early stopping patience.
        save_top_k: Number of best models to save.
        checkpoint_path: Path to checkpoint to resume from.
        input_bands: Band indices to use (1-indexed).
        verbose: Whether to show detailed training logs, progress bars,
            and callback messages. Defaults to True.
        **kwargs: Additional arguments for Trainer.

    Returns:
        PixelRegressionModel: Trained model. The attribute
        ``best_model_path`` is set when a best checkpoint was saved.

    Raises:
        ImportError: If PyTorch Lightning is not installed.
    """
    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install with: pip install lightning"
        )

    os.makedirs(output_dir, exist_ok=True)
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    # Create model
    model = PixelRegressionModel(
        encoder_name=encoder_name,
        architecture=architecture,
        in_channels=in_channels,
        encoder_weights=encoder_weights,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_encoder=freeze_encoder,
        loss_type=loss_type,
    )

    preprocessing = _infer_preprocessing_params(encoder_name, encoder_weights)
    image_mean = None
    image_std = None
    if preprocessing is not None:
        pp_mean = preprocessing.get("mean")
        pp_std = preprocessing.get("std")
        # Only apply encoder preprocessing when channel count matches
        # (e.g. 3-band RGB with ImageNet weights). For multi-spectral
        # inputs the ImageNet statistics are inappropriate; the
        # normalize_input flag already scales values to [0, 1].
        if pp_mean is not None and pp_std is not None and len(pp_mean) == in_channels:
            image_mean = pp_mean
            image_std = pp_std

    # Create datasets
    train_dataset = PixelRegressionDataset(
        train_image_paths,
        train_target_paths,
        input_bands=input_bands,
        normalize_input=normalize_input,
        image_mean=image_mean,
        image_std=image_std,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Validation is optional: both path lists must be supplied.
    val_loader = None
    if val_image_paths is not None and val_target_paths is not None:
        val_dataset = PixelRegressionDataset(
            val_image_paths,
            val_target_paths,
            input_bands=input_bands,
            normalize_input=normalize_input,
            image_mean=image_mean,
            image_std=image_std,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Callbacks
    callbacks = []

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        filename=f"{encoder_name}_{architecture}_{{epoch:02d}}_{{val_loss:.4f}}",
        monitor=monitor_metric,
        mode=mode,
        save_top_k=save_top_k,
        save_last=True,
        verbose=verbose,
    )
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=patience,
        mode=mode,
        verbose=verbose,
    )
    callbacks.append(early_stop_callback)

    # Non-verbose mode: silence Lightning's logger, use the compact
    # progress bar, and tell the model to keep metrics off the bar
    # (the compact bar writes its own postfix instead).
    if not verbose:
        import logging

        logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
        callbacks.append(_CompactProgressBar())
        model._prog_bar_metrics = False

    # CSV logger writes metrics under model_dir/lightning_logs.
    logger = CSVLogger(model_dir, name="lightning_logs")

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        enable_model_summary=verbose,
        **kwargs,
    )

    if verbose:
        print(
            f"Training {architecture} with {encoder_name} encoder"
            f" for {num_epochs} epochs..."
        )
        print(f"Loss function: {loss_type.upper()}")

    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=checkpoint_path,
    )

    # Prefer the best checkpoint over the in-memory (last-epoch) weights.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        if verbose:
            print(f"\nBest model saved at: {best_model_path}")
        model = PixelRegressionModel.load_from_checkpoint(best_model_path)
        model.best_model_path = best_model_path
    else:
        if verbose:
            print("\nBest model path not found; returning last epoch model.")

    return model
781
+
782
+
783
def predict_raster(
    model: Union[PixelRegressionModel, nn.Module],
    input_raster: str,
    output_raster: str,
    tile_size: int = 256,
    overlap: int = 64,
    input_bands: Optional[List[int]] = None,
    batch_size: int = 4,
    device: Optional[str] = None,
    output_nodata: float = -9999.0,
    clip_range: Optional[Tuple[float, float]] = None,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
    use_model_preprocessing: bool = True,
) -> str:
    """
    Run pixel-level inference on a raster.

    Uses sliding window with overlap and blending for smooth predictions:
    each tile's prediction is weighted by a blend mask and accumulated, then
    the accumulated sum is divided by the accumulated weights. Edge tiles are
    shifted inward so every tile is full-size, and output dimensions match
    input dimensions exactly. Pixels that are NaN in any input band are set
    to ``output_nodata``.

    Args:
        model: Trained pixel regression model.
        input_raster: Path to input raster.
        output_raster: Path to save output raster.
        tile_size: Size of tiles for inference.
        overlap: Overlap between tiles for blending.
        input_bands: Band indices to use (1-indexed).
        batch_size: Batch size for inference.
        device: Device to use; auto-selects CUDA when available if None.
        output_nodata: NoData value for output.
        clip_range: Optional tuple (min, max) to clip output values.
        image_mean: Optional per-channel mean for normalization.
        image_std: Optional per-channel std for normalization.
        use_model_preprocessing: Use encoder preprocessing params if available
            (only when image_mean/image_std are not given and the model's
            hparams identify a pretrained encoder whose stats match the
            channel count).

    Returns:
        str: Path to output raster.
    """
    import rasterio

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model = model.to(device)

    stride = tile_size - overlap

    with rasterio.open(input_raster) as src:
        height = src.height
        width = src.width

        if input_bands is None:
            input_bands = list(range(1, src.count + 1))

        print(f"Input raster: {width}x{height}")
        print(f"Tile size: {tile_size}, overlap: {overlap}, stride: {stride}")

        # Initialize output arrays (float64 accumulators for numeric stability
        # when many overlapping tiles are summed).
        output_sum = np.zeros((height, width), dtype=np.float64)
        weight_sum = np.zeros((height, width), dtype=np.float64)

        # Create weight mask for blending (higher weight in center)
        weight_mask = _create_weight_mask(tile_size, overlap)

        # Read full input for nodata mask.
        # NOTE(review): this loads the whole raster into memory — may be
        # prohibitive for very large inputs.
        full_input = src.read(input_bands)
        nodata_mask = np.any(np.isnan(full_input), axis=0)

        # Recover encoder preprocessing stats from the model's saved
        # hyperparameters, mirroring what train_pixel_regressor applied.
        if use_model_preprocessing and image_mean is None and image_std is None:
            encoder_name = getattr(
                getattr(model, "hparams", None), "encoder_name", None
            )
            encoder_weights = getattr(
                getattr(model, "hparams", None), "encoder_weights", None
            )
            model_in_channels = getattr(
                getattr(model, "hparams", None),
                "in_channels",
                len(input_bands),
            )
            if encoder_name and encoder_weights:
                preprocessing = _infer_preprocessing_params(
                    encoder_name, encoder_weights
                )
                if preprocessing is not None:
                    pp_mean = preprocessing.get("mean")
                    pp_std = preprocessing.get("std")
                    if (
                        pp_mean is not None
                        and pp_std is not None
                        and len(pp_mean) == model_in_channels
                    ):
                        image_mean = pp_mean
                        image_std = pp_std

        # Collect tile windows. Tiles near the right/bottom edge are shifted
        # inward (start = end - tile_size) so every window is full-size.
        tiles = []
        positions = []

        for row in range(0, height, stride):
            for col in range(0, width, stride):
                # Calculate tile bounds
                row_end = min(row + tile_size, height)
                col_end = min(col + tile_size, width)
                row_start = row_end - tile_size
                col_start = col_end - tile_size

                # Clamp to valid range (rasters smaller than tile_size).
                row_start = max(0, row_start)
                col_start = max(0, col_start)

                tiles.append((row_start, col_start, row_end, col_end))
                positions.append((row_start, col_start))

        print(f"Total tiles: {len(tiles)}")

        # Process in batches
        for batch_start in tqdm(
            range(0, len(tiles), batch_size), desc="Running inference"
        ):
            batch_end = min(batch_start + batch_size, len(tiles))
            batch_tiles = tiles[batch_start:batch_end]

            # Load batch
            batch_images = []
            for row_start, col_start, row_end, col_end in batch_tiles:
                window = rasterio.windows.Window(
                    col_start, row_start, col_end - col_start, row_end - row_start
                )
                tile = src.read(input_bands, window=window).astype(np.float32)

                # Handle non-square tiles at edges (rasters smaller than
                # tile_size) by zero-padding to the full tile shape.
                if tile.shape[1] != tile_size or tile.shape[2] != tile_size:
                    padded = np.zeros(
                        (len(input_bands), tile_size, tile_size), dtype=np.float32
                    )
                    padded[:, : tile.shape[1], : tile.shape[2]] = tile
                    tile = padded

                # Normalize (same heuristic as PixelRegressionDataset: values
                # near [0, 1] are clipped; larger values are assumed to be
                # 0–10000 reflectance scale).
                tile = np.nan_to_num(tile, nan=0.0)
                data_max = np.abs(tile).max()
                if data_max <= 1.5:
                    tile = np.clip(tile, 0, 1)
                else:
                    tile = np.clip(tile, 0, 10000) / 10000.0

                if image_mean is not None and image_std is not None:
                    mean, std = _prepare_normalization_stats(
                        image_mean, image_std, tile.shape[0]
                    )
                    tile = (tile - mean) / std

                batch_images.append(tile)

            batch_tensor = torch.from_numpy(np.stack(batch_images)).to(device)

            # Inference
            with torch.no_grad():
                preds = model(batch_tensor).cpu().numpy()

            # Apply predictions with blending
            for i, (row_start, col_start, row_end, col_end) in enumerate(batch_tiles):
                pred = preds[i]
                h = row_end - row_start
                w = col_end - col_start

                # Get the relevant portion of prediction and weight
                pred_crop = pred[:h, :w]
                weight_crop = weight_mask[:h, :w]

                # Accumulate weighted predictions; final value is the
                # weight-normalized average over all covering tiles.
                output_sum[row_start:row_end, col_start:col_end] += (
                    pred_crop * weight_crop
                )
                weight_sum[row_start:row_end, col_start:col_end] += weight_crop

        # Normalize by weights; pixels never covered keep output_nodata.
        valid_weights = weight_sum > 0
        output_array = np.full((height, width), output_nodata, dtype=np.float32)
        output_array[valid_weights] = (
            output_sum[valid_weights] / weight_sum[valid_weights]
        )

        # Apply nodata mask
        output_array[nodata_mask] = output_nodata

        # Clip output to valid range if specified
        if clip_range is not None:
            valid_data_mask = ~nodata_mask & valid_weights
            output_array[valid_data_mask] = np.clip(
                output_array[valid_data_mask], clip_range[0], clip_range[1]
            )

        # Save output (single float32 band, same georeferencing as input)
        profile = src.profile.copy()
        profile.update(
            count=1,
            dtype=np.float32,
            nodata=output_nodata,
        )

        output_dir = os.path.dirname(os.path.abspath(output_raster))
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with rasterio.open(output_raster, "w", **profile) as dst:
            dst.write(output_array, 1)

        valid_data = output_array[~nodata_mask & valid_weights]
        print(f"\nOutput saved to: {output_raster}")
        print(f"Output dimensions: {width}x{height} (same as input)")
        if len(valid_data) > 0:
            print(f"Prediction range: [{valid_data.min():.4f}, {valid_data.max():.4f}]")

    return output_raster
1001
+
1002
+
1003
+ def _create_weight_mask(tile_size: int, overlap: int) -> np.ndarray:
1004
+ """Create a weight mask for blending overlapping tiles."""
1005
+ if overlap == 0:
1006
+ return np.ones((tile_size, tile_size), dtype=np.float32)
1007
+
1008
+ # Create 1D ramp
1009
+ ramp = np.ones(tile_size, dtype=np.float32)
1010
+ ramp[:overlap] = np.linspace(0, 1, overlap)
1011
+ ramp[-overlap:] = np.linspace(1, 0, overlap)
1012
+
1013
+ # Create 2D weight mask
1014
+ weight_mask = np.outer(ramp, ramp)
1015
+ return weight_mask
1016
+
1017
+
1018
+ # ============================================================================
1019
+ # Evaluation and Visualization Functions
1020
+ # ============================================================================
1021
+
1022
+
1023
def evaluate_regression(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    mask: Optional[np.ndarray] = None,
    print_results: bool = True,
) -> Dict[str, float]:
    """
    Evaluate regression predictions with multiple metrics.

    Args:
        y_true: Ground truth values.
        y_pred: Predicted values.
        mask: Optional mask of valid pixels. Interpreted as a boolean mask
            (non-zero entries select pixels); integer 0/1 masks are coerced
            to bool.
        print_results: Whether to print results.

    Returns:
        Dictionary of metrics: MSE, RMSE, MAE, R².
    """
    # Compute with numpy directly instead of a function-level sklearn
    # import: sklearn is not a declared dependency of this module, and the
    # formulas are trivial.
    y_true = np.asarray(y_true, dtype=np.float64).flatten()
    y_pred = np.asarray(y_pred, dtype=np.float64).flatten()

    if mask is not None:
        # Bug fix: coerce to bool. An integer 0/1 mask previously triggered
        # fancy (index) selection instead of boolean masking, silently
        # evaluating the wrong pixels.
        mask = np.asarray(mask).flatten().astype(bool)
        y_true = y_true[mask]
        y_pred = y_pred[mask]

    diff = y_pred - y_true
    mse = float(np.mean(diff**2))
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(diff)))

    # R² = 1 - SS_res / SS_tot; follow sklearn's convention for a constant
    # target: 1.0 for a perfect fit, 0.0 otherwise.
    ss_res = float(np.sum(diff**2))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
    if ss_tot > 0:
        r2 = 1.0 - ss_res / ss_tot
    else:
        r2 = 1.0 if ss_res == 0 else 0.0

    metrics = {
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
    }

    if print_results:
        print("=" * 50)
        print("Regression Evaluation Metrics")
        print("=" * 50)
        print(f"MSE:  {mse:.6f}")
        print(f"RMSE: {rmse:.6f}")
        print(f"MAE:  {mae:.6f}")
        print(f"R²:   {r2:.4f}")
        print("=" * 50)

    return metrics
1074
+
1075
+
1076
def plot_regression_comparison(
    true_raster: str,
    pred_raster: str,
    title: str = "Regression Results",
    cmap: str = "RdYlGn",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    valid_range: Optional[Tuple[float, float]] = None,
    figsize: Tuple[int, int] = (18, 6),
    save_path: Optional[str] = None,
):
    """
    Plot comparison of ground truth, prediction, and difference.

    Args:
        true_raster: Path to ground truth raster.
        pred_raster: Path to prediction raster.
        title: Title for the plot.
        cmap: Colormap for visualization.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        valid_range: Tuple of (min, max) valid values for filtering outliers.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Tuple of (figure, metrics_dict).
    """
    import matplotlib.pyplot as plt
    import rasterio

    with rasterio.open(true_raster) as ds:
        true_data = ds.read(1)
        true_nodata = ds.nodata

    with rasterio.open(pred_raster) as ds:
        pred_data = ds.read(1)
        pred_nodata = ds.nodata

    # A pixel is valid when it is finite and not flagged nodata in either raster.
    valid_mask = np.ones_like(true_data, dtype=bool)
    if true_nodata is not None:
        valid_mask &= true_data != true_nodata
    if pred_nodata is not None:
        valid_mask &= pred_data != pred_nodata
    valid_mask &= ~np.isnan(true_data) & ~np.isnan(pred_data)

    # Drop out-of-range outliers (important for NDVI, which should be [-1, 1]).
    if valid_range is not None:
        lo, hi = valid_range
        for arr in (true_data, pred_data):
            valid_mask &= (arr >= lo) & (arr <= hi)

    metrics = evaluate_regression(
        true_data[valid_mask], pred_data[valid_mask], print_results=False
    )

    # Fall back to a 2–98 percentile stretch of the ground truth.
    if vmin is None:
        vmin = np.percentile(true_data[valid_mask], 2)
    if vmax is None:
        vmax = np.percentile(true_data[valid_mask], 98)

    diff = pred_data - true_data
    # Symmetric color limits for the difference panel from the 5/95 percentiles.
    diff_range = max(
        abs(np.percentile(diff[valid_mask], 5)),
        abs(np.percentile(diff[valid_mask], 95)),
    )

    hide = ~valid_mask
    panels = [
        (np.ma.masked_where(hide, true_data), "Ground Truth", cmap, vmin, vmax),
        (
            np.ma.masked_where(hide, pred_data),
            f"Prediction (R²={metrics['r2']:.4f})",
            cmap,
            vmin,
            vmax,
        ),
        (
            np.ma.masked_where(hide, diff),
            f"Difference (RMSE={metrics['rmse']:.4f})",
            "RdBu_r",
            -diff_range,
            diff_range,
        ),
    ]

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (image, panel_title, panel_cmap, lo, hi) in zip(axes, panels):
        im = ax.imshow(image, cmap=panel_cmap, vmin=lo, vmax=hi)
        ax.set_title(panel_title, fontsize=14)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, metrics
1177
+
1178
+
1179
def plot_scatter(
    true_raster: str,
    pred_raster: str,
    sample_size: int = 10000,
    title: str = "Predicted vs Actual",
    valid_range: Optional[Tuple[float, float]] = None,
    fit_line: bool = True,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None,
):
    """
    Plot scatter plot of predicted vs actual values with optional trend line.

    Args:
        true_raster: Path to ground truth raster.
        pred_raster: Path to prediction raster.
        sample_size: Number of points to plot (sampled if needed).
        title: Title for the plot.
        valid_range: Tuple of (min, max) valid values for filtering outliers.
        fit_line: Whether to show a linear regression trend line.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Tuple of (figure, metrics_dict). When ``fit_line`` is True, the
        metrics dict also contains ``slope`` and ``intercept``.
    """
    import matplotlib.pyplot as plt
    import rasterio

    # NOTE: removed an unused `from sklearn.metrics import r2_score` import;
    # metrics come from evaluate_regression below.

    with rasterio.open(true_raster) as src:
        true_data = src.read(1)
        true_nodata = src.nodata

    with rasterio.open(pred_raster) as src:
        pred_data = src.read(1)
        pred_nodata = src.nodata

    # Create valid mask: exclude nodata and NaN pixels in either raster.
    valid_mask = np.ones_like(true_data, dtype=bool)
    if true_nodata is not None:
        valid_mask &= true_data != true_nodata
    if pred_nodata is not None:
        valid_mask &= pred_data != pred_nodata
    valid_mask &= ~np.isnan(true_data) & ~np.isnan(pred_data)

    # Filter by valid range (drops physically impossible outliers).
    if valid_range is not None:
        valid_mask &= (true_data >= valid_range[0]) & (true_data <= valid_range[1])
        valid_mask &= (pred_data >= valid_range[0]) & (pred_data <= valid_range[1])

    y_true = true_data[valid_mask]
    y_pred = pred_data[valid_mask]

    # Subsample only for plotting; metrics and the fit use all valid pixels.
    if len(y_true) > sample_size:
        idx = np.random.choice(len(y_true), sample_size, replace=False)
        y_true_plot = y_true[idx]
        y_pred_plot = y_pred[idx]
    else:
        y_true_plot = y_true
        y_pred_plot = y_pred

    # Calculate metrics on full data
    metrics = evaluate_regression(y_true, y_pred, print_results=False)

    # Plot
    fig, ax = plt.subplots(figsize=figsize)

    ax.scatter(y_true_plot, y_pred_plot, alpha=0.3, s=5, edgecolors="none")

    # Add 1:1 line
    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    ax.plot([min_val, max_val], [min_val, max_val], "r--", lw=2, label="1:1 Line")

    # Add linear regression trend line
    if fit_line:
        coeffs = np.polyfit(y_true, y_pred, 1)
        slope, intercept = coeffs
        fit_x = np.array([min_val, max_val])
        fit_y = slope * fit_x + intercept
        ax.plot(
            fit_x,
            fit_y,
            "b-",
            lw=2,
            label=f"Fit: y = {slope:.3f}x + {intercept:.3f}",
        )
        metrics["slope"] = float(slope)
        metrics["intercept"] = float(intercept)

    ax.set_xlabel("Actual Values", fontsize=12)
    ax.set_ylabel("Predicted Values", fontsize=12)
    ax.set_title(
        f"{title}\nR² = {metrics['r2']:.4f}, RMSE = {metrics['rmse']:.4f}", fontsize=14
    )
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, metrics
1288
+
1289
+
1290
def plot_training_history(
    log_dir: str,
    metrics: Optional[List[str]] = None,
    figsize: Optional[Tuple[int, int]] = None,
    tail: Optional[int] = None,
    save_path: Optional[str] = None,
):
    """
    Plot training history curves from Lightning CSV logs.

    Reads the ``metrics.csv`` file produced by :class:`CSVLogger` and plots
    the requested training and validation metrics over epochs.

    Args:
        log_dir: Path to the model output directory (the same ``output_dir``
            passed to :func:`train_pixel_regressor`). The function searches
            for ``lightning_logs/version_*/metrics.csv`` inside a ``models``
            sub-directory (or directly under *log_dir*).
        metrics: List of metric names to plot. Each name is matched against
            the CSV columns; both the ``train_`` and ``val_`` variants are
            plotted when available. Defaults to ``["loss", "r2"]``.
        figsize: Figure size as ``(width, height)``. Defaults to
            ``(6 * n_metrics, 5)``.
        tail: If given, only plot the last *tail* epochs. Useful for
            skipping early warm-up instability. By default the function
            automatically skips early epochs when extreme outliers would
            compress the y-axis (more than 10× the stable range).
        save_path: If given, save the figure to this path.

    Returns:
        Tuple of (figure, pandas.DataFrame of the loaded metrics).

    Raises:
        ImportError: If pandas is not installed.
        FileNotFoundError: If no ``metrics.csv`` is found under *log_dir*.
    """
    import glob

    import matplotlib.pyplot as plt

    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required for plot_training_history")

    if metrics is None:
        metrics = ["loss", "r2"]

    # Locate metrics.csv: try the most specific layout first, then fall back
    # to shallower directory structures.
    search_paths = [
        os.path.join(log_dir, "models", "lightning_logs", "version_*", "metrics.csv"),
        os.path.join(log_dir, "lightning_logs", "version_*", "metrics.csv"),
        os.path.join(log_dir, "version_*", "metrics.csv"),
    ]

    csv_path = None
    for pattern in search_paths:
        matches = sorted(glob.glob(pattern))
        if matches:
            csv_path = matches[-1]  # latest version
            break

    if csv_path is None:
        raise FileNotFoundError(
            f"No metrics.csv found under '{log_dir}'. "
            "Looked for lightning_logs/version_*/metrics.csv"
        )

    df = pd.read_csv(csv_path)
    _n_epochs = df["epoch"].nunique() if "epoch" in df.columns else len(df)
    print(f"Reading logs: {csv_path} ({_n_epochs} epochs)")

    # Group rows by epoch – Lightning logs multiple rows per epoch (one per
    # step plus validation). Use ``last()`` with ``skipna`` so we keep the
    # last non-null value for every column within each epoch.
    if "epoch" in df.columns:
        try:
            df_epoch = df.groupby("epoch").last(skipna=True).reset_index()
        except TypeError:
            # older pandas without skipna
            df_epoch = df.groupby("epoch").last().reset_index()
    else:
        df_epoch = df

    # Apply tail filter
    if tail is not None:
        df_epoch = df_epoch.tail(tail).reset_index(drop=True)

    n_metrics = len(metrics)
    if figsize is None:
        figsize = (6 * n_metrics, 5)

    fig, axes = plt.subplots(1, n_metrics, figsize=figsize)
    if n_metrics == 1:
        # plt.subplots returns a bare Axes for a single panel; normalize to
        # a list so the zip below works uniformly.
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        # Lightning may log the epoch-aggregated column as
        # ``train_<metric>_epoch``; prefer it over the per-step column.
        train_col = (
            f"train_{metric}_epoch"
            if f"train_{metric}_epoch" in df_epoch.columns
            else f"train_{metric}"
        )
        val_col = f"val_{metric}"

        has_train = train_col in df_epoch.columns
        has_val = val_col in df_epoch.columns

        if not has_train and not has_val:
            ax.set_title(f"{metric} (no data)")
            continue

        x = df_epoch["epoch"] if "epoch" in df_epoch.columns else df_epoch.index

        if has_train:
            # dropna keeps the positional index aligned with x for plotting.
            train_data = df_epoch[train_col].dropna()
            ax.plot(
                x[train_data.index],
                train_data.values,
                label=f"Train {metric}",
                linewidth=2,
            )
        if has_val:
            val_data = df_epoch[val_col].dropna()
            ax.plot(
                x[val_data.index],
                val_data.values,
                label=f"Val {metric}",
                linewidth=2,
            )

        # Auto-zoom: if early outliers compress the view, clip y-axis to
        # the range of the stable second half of training.
        if tail is None:
            n_epochs = len(df_epoch)
            if n_epochs >= 10:
                half = n_epochs // 2
                second_half_vals = []
                all_vals = []
                if has_train:
                    col_data = df_epoch[train_col].dropna()
                    all_vals.extend(col_data.values)
                    second_half_vals.extend(
                        df_epoch[train_col].iloc[half:].dropna().values
                    )
                if has_val:
                    col_data = df_epoch[val_col].dropna()
                    all_vals.extend(col_data.values)
                    second_half_vals.extend(
                        df_epoch[val_col].iloc[half:].dropna().values
                    )
                if second_half_vals and all_vals:
                    sh_arr = np.array(second_half_vals)
                    all_arr = np.array(all_vals)
                    sh_min, sh_max = sh_arr.min(), sh_arr.max()
                    # Guard against zero ranges so the ratio test below is
                    # well-defined.
                    sh_range = sh_max - sh_min if sh_max != sh_min else 1.0
                    full_range = all_arr.max() - all_arr.min()
                    if full_range == 0:
                        full_range = 1.0
                    # If full range is >5× the stable range, zoom in
                    if full_range > 5 * sh_range:
                        margin = sh_range * 0.3
                        ax.set_ylim(sh_min - margin, sh_max + margin)

        # Short names like "r2" read better upper-cased; longer ones as
        # title case with underscores expanded.
        label = metric.upper() if len(metric) <= 4 else metric.replace("_", " ").title()
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(f"Training & Validation {label}", fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, df_epoch
1465
+
1466
+
1467
def visualize_prediction(
    input_raster: str,
    pred_raster: str,
    rgb_bands: List[int] = [1, 2, 3],
    cmap: str = "RdYlGn",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    figsize: Tuple[int, int] = (14, 6),
    save_path: Optional[str] = None,
):
    """
    Visualize input RGB and prediction side by side.

    Args:
        input_raster: Path to input raster.
        pred_raster: Path to prediction raster.
        rgb_bands: Band indices for RGB display (1-indexed). The default
            list is never mutated.
        cmap: Colormap for prediction.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Figure object (closed before returning).
    """
    import matplotlib.pyplot as plt
    import rasterio

    with rasterio.open(input_raster) as src:
        rgb = src.read(rgb_bands).astype(np.float64)
        input_nodata = src.nodata

    # Per-band 2–98 percentile stretch for proper RGB display.
    for i in range(rgb.shape[0]):
        band = rgb[i]
        finite_mask = np.isfinite(band)
        # Bug fix: the previous check used `if src.nodata`, which treats a
        # nodata value of 0 as "no nodata defined". Compare against None
        # explicitly so 0-valued nodata pixels are excluded from the stretch.
        if input_nodata is not None:
            finite_mask &= band != input_nodata
        valid = band[finite_mask]
        if valid.size > 0:
            p2, p98 = np.percentile(valid, [2, 98])
            if p98 > p2:
                rgb[i] = (band - p2) / (p98 - p2)
            else:
                # Degenerate stretch (flat band): scale by p98 if possible.
                rgb[i] = band / p98 if p98 > 0 else band
    rgb = np.clip(rgb, 0, 1)
    rgb = np.transpose(rgb, (1, 2, 0))

    with rasterio.open(pred_raster) as src:
        pred = src.read(1)
        pred_nodata = src.nodata

    # Mask nodata/NaN pixels so they render transparent.
    valid_mask = np.ones_like(pred, dtype=bool)
    if pred_nodata is not None:
        valid_mask &= pred != pred_nodata
    valid_mask &= ~np.isnan(pred)
    pred_masked = np.ma.masked_where(~valid_mask, pred)

    if vmin is None:
        vmin = np.percentile(pred[valid_mask], 2)
    if vmax is None:
        vmax = np.percentile(pred[valid_mask], 98)

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    axes[0].imshow(rgb)
    axes[0].set_title("Input RGB", fontsize=14)
    axes[0].axis("off")

    im = axes[1].imshow(pred_masked, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[1].set_title("Prediction", fontsize=14)
    axes[1].axis("off")
    plt.colorbar(im, ax=axes[1], shrink=0.8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")

    # Close (not show): callers embed the returned figure themselves.
    plt.close(fig)
    return fig
1547
+
1548
+
1549
+ # ============================================================================
1550
+ # Backward compatibility aliases
1551
+ # ============================================================================
1552
+
1553
# Aliases for backward compatibility: earlier releases exposed these names,
# so they are re-bound to the current implementations to keep old imports
# and call sites working.
TimmRegressor = PixelRegressionModel
RegressionDataset = PixelRegressionDataset
create_regression_patches = create_regression_tiles
train_timm_regressor = train_pixel_regressor
1558
+
1559
+
1560
def plot_regression_results(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Regression Results",
    fit_line: bool = True,
    figsize: Tuple[int, int] = (12, 5),
    save_path: Optional[str] = None,
):
    """
    Plot regression results: scatter plot with trend line and residual plot.

    Args:
        y_true: Ground truth values.
        y_pred: Predicted values.
        title: Title for the plots.
        fit_line: Whether to show a linear regression trend line.
        figsize: Figure size.
        save_path: Path to save the figure.

    Returns:
        The matplotlib figure containing both panels.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import r2_score

    actual = np.array(y_true).flatten()
    predicted = np.array(y_pred).flatten()

    fig, (scatter_ax, resid_ax) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: predicted vs actual with a 1:1 reference line.
    scatter_ax.scatter(actual, predicted, alpha=0.5, edgecolors="none", s=20)
    lo = min(actual.min(), predicted.min())
    hi = max(actual.max(), predicted.max())
    scatter_ax.plot([lo, hi], [lo, hi], "r--", lw=2, label="1:1 Line")

    # Optional least-squares trend line over the full data extent.
    if fit_line:
        slope, intercept = np.polyfit(actual, predicted, 1)
        endpoints = np.array([lo, hi])
        scatter_ax.plot(
            endpoints,
            slope * endpoints + intercept,
            "b-",
            lw=2,
            label=f"Fit: y = {slope:.3f}x + {intercept:.3f}",
        )

    r2 = r2_score(actual, predicted)
    scatter_ax.set_xlabel("Actual Values", fontsize=12)
    scatter_ax.set_ylabel("Predicted Values", fontsize=12)
    scatter_ax.set_title(f"Predicted vs Actual (R² = {r2:.4f})", fontsize=14)
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)

    # Right panel: residuals (pred - actual) against predictions.
    residuals = predicted - actual
    resid_ax.scatter(predicted, residuals, alpha=0.5, edgecolors="none", s=20)
    resid_ax.axhline(y=0, color="r", linestyle="--", lw=2)
    resid_ax.set_xlabel("Predicted Values", fontsize=12)
    resid_ax.set_ylabel("Residuals", fontsize=12)
    resid_ax.set_title("Residual Plot", fontsize=14)
    resid_ax.grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=16, y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig
1637
+
1638
+
1639
def predict_with_timm_regressor(*args, **kwargs):
    """Deprecated: Use predict_raster instead."""
    # Kept only so old call sites fail loudly with a pointer to the new API.
    message = (
        "predict_with_timm_regressor is deprecated. "
        "Use predict_raster for pixel-level predictions."
    )
    raise NotImplementedError(message)
1645
+
1646
+
1647
def get_timm_regression_model(*args, **kwargs):
    """Deprecated: Use PixelRegressionModel instead."""
    # Kept only so old call sites fail loudly with a pointer to the new API.
    message = (
        "get_timm_regression_model is deprecated. "
        "Use PixelRegressionModel for pixel-level regression."
    )
    raise NotImplementedError(message)