geoai-py 0.26.0__py2.py3-none-any.whl → 0.28.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +41 -1
- geoai/auto.py +4 -1
- geoai/change_detection.py +1 -1
- geoai/detectron2.py +4 -1
- geoai/extract.py +10 -7
- geoai/hf.py +3 -3
- geoai/moondream.py +2 -2
- geoai/onnx.py +1155 -0
- geoai/prithvi.py +92 -7
- geoai/sam.py +2 -1
- geoai/segment.py +10 -1
- geoai/timm_regress.py +1652 -0
- geoai/train.py +1 -1
- geoai/utils.py +550 -1
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/METADATA +9 -7
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/RECORD +20 -18
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/WHEEL +1 -1
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/top_level.txt +0 -0
geoai/timm_regress.py
ADDED
@@ -0,0 +1,1652 @@
"""Module for training pixel-level regression models using timm encoders with PyTorch Lightning.

This module provides tools for remote sensing regression tasks like predicting NDVI,
biomass, temperature, or other continuous values at the pixel level.
"""

import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

try:
    import timm

    TIMM_AVAILABLE = True
except ImportError:
    TIMM_AVAILABLE = False

try:
    import segmentation_models_pytorch as smp

    SMP_AVAILABLE = True
except ImportError:
    SMP_AVAILABLE = False

try:
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import (
        ModelCheckpoint,
        EarlyStopping,
        TQDMProgressBar,
    )
    from lightning.pytorch.loggers import CSVLogger

    LIGHTNING_AVAILABLE = True
except ImportError:
    LIGHTNING_AVAILABLE = False

class _CompactProgressBar(TQDMProgressBar):
    """Progress bar that shows key metrics in the postfix, updated in place."""

    def get_metrics(self, trainer, pl_module):
        # Don't let Lightning set the postfix — we control it
        return {}

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        if self.train_progress_bar is not None:
            self.train_progress_bar.set_postfix_str("")

    def on_validation_epoch_end(self, trainer, pl_module):
        super().on_validation_epoch_end(trainer, pl_module)
        metrics = trainer.callback_metrics
        if metrics and self.train_progress_bar is not None:
            keys = [
                "train_loss_epoch",
                "train_r2",
                "val_loss",
                "val_r2",
                "val_rmse",
            ]
            parts = []
            for k in keys:
                v = metrics.get(k)
                if v is not None:
                    val = v.item() if hasattr(v, "item") else v
                    if isinstance(val, float):
                        parts.append(f"{k}={val:.4g}")
            if parts:
                self.train_progress_bar.set_postfix_str(", ".join(parts))

def _prepare_normalization_stats(
    mean: List[float], std: List[float], num_channels: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Prepare mean/std arrays that match the number of channels."""
    if len(mean) == num_channels and len(std) == num_channels:
        mean_vals = mean
        std_vals = std
    elif len(mean) == 3 and len(std) == 3 and num_channels > 3:
        mean_pad = [float(np.mean(mean))] * (num_channels - 3)
        std_pad = [float(np.mean(std))] * (num_channels - 3)
        mean_vals = list(mean) + mean_pad
        std_vals = list(std) + std_pad
    elif len(mean) == 1 and len(std) == 1:
        mean_vals = list(mean) * num_channels
        std_vals = list(std) * num_channels
    elif len(mean) >= num_channels and len(std) >= num_channels:
        mean_vals = list(mean)[:num_channels]
        std_vals = list(std)[:num_channels]
    else:
        raise ValueError(
            "Normalization stats length must match channels. "
            f"Got mean={len(mean)}, std={len(std)}, channels={num_channels}."
        )

    mean_arr = np.array(mean_vals, dtype=np.float32)[:, None, None]
    std_arr = np.array(std_vals, dtype=np.float32)[:, None, None]
    return mean_arr, std_arr


def _infer_preprocessing_params(
    encoder_name: str, encoder_weights: Optional[str]
) -> Optional[Dict[str, Any]]:
    if encoder_weights is None:
        return None
    if not SMP_AVAILABLE:
        return None
    try:
        return smp.encoders.get_preprocessing_params(
            encoder_name, pretrained=encoder_weights
        )
    except Exception:
        return None

class PixelRegressionModel(pl.LightningModule):
    """
    PyTorch Lightning module for pixel-level regression using encoder-decoder architectures.

    Uses segmentation-models-pytorch (SMP) with timm encoders but configured for
    regression (single channel output with continuous values).
    """

    def __init__(
        self,
        encoder_name: str = "resnet50",
        architecture: str = "unet",
        in_channels: int = 3,
        encoder_weights: str = "imagenet",
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
        freeze_encoder: bool = False,
        loss_fn: Optional[nn.Module] = None,
        loss_type: str = "mse",
        **decoder_kwargs: Any,
    ):
        """
        Initialize PixelRegressionModel.

        Args:
            encoder_name (str): Name of timm encoder (e.g., 'resnet50', 'efficientnet_b0').
            architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
                'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan').
            in_channels (int): Number of input channels (3 for RGB, 4 for RGBN, etc.).
            encoder_weights (str): Pretrained weights for encoder ('imagenet', None).
            learning_rate (float): Learning rate for optimizer.
            weight_decay (float): Weight decay for optimizer.
            freeze_encoder (bool): Freeze encoder weights during training.
            loss_fn (nn.Module, optional): Custom loss function.
            loss_type (str): Type of loss if loss_fn is None ('mse', 'l1', 'mae', 'huber').
            **decoder_kwargs: Additional arguments for decoder.
        """
        super().__init__()

        if not SMP_AVAILABLE:
            raise ImportError(
                "segmentation-models-pytorch is required. "
                "Install it with: pip install segmentation-models-pytorch"
            )

        self.save_hyperparameters()

        # Create segmentation model with 1 output class for regression
        try:
            self.model = smp.create_model(
                arch=architecture,
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=in_channels,
                classes=1,  # Single channel for regression
                **decoder_kwargs,
            )
        except Exception as e:
            available_archs = [
                "unet",
                "unetplusplus",
                "manet",
                "linknet",
                "fpn",
                "pspnet",
                "deeplabv3",
                "deeplabv3plus",
                "pan",
                "upernet",
            ]
            raise ValueError(
                f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
                f"Error: {str(e)}. "
                f"Available architectures: {', '.join(available_archs)}."
            )

        if freeze_encoder:
            self._freeze_encoder()

        # Set up loss function
        if loss_fn is not None:
            self.loss_fn = loss_fn
        else:
            loss_type = loss_type.lower()
            if loss_type == "mse":
                self.loss_fn = nn.MSELoss()
            elif loss_type in ["l1", "mae"]:
                self.loss_fn = nn.L1Loss()
            elif loss_type in ["huber", "smooth_l1"]:
                self.loss_fn = nn.SmoothL1Loss()
            else:
                raise ValueError(f"Unknown loss_type: {loss_type}")

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def _freeze_encoder(self):
        """Freeze encoder weights."""
        if hasattr(self.model, "encoder"):
            for param in self.model.encoder.parameters():
                param.requires_grad = False

    def forward(self, x):
        return self.model(x).squeeze(1)  # Remove channel dim: (B, 1, H, W) -> (B, H, W)

    def _compute_metrics(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Compute pixel-wise regression metrics."""
        # Flatten for metrics computation
        preds_flat = preds.view(-1)
        targets_flat = targets.view(-1)

        # MSE and RMSE
        mse = torch.mean((preds_flat - targets_flat) ** 2)
        rmse = torch.sqrt(mse)

        # MAE
        mae = torch.mean(torch.abs(preds_flat - targets_flat))

        # R² (coefficient of determination)
        ss_res = torch.sum((targets_flat - preds_flat) ** 2)
        ss_tot = torch.sum((targets_flat - targets_flat.mean()) ** 2)
        r2 = 1 - ss_res / (ss_tot + 1e-8)

        return {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2}

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        pb = getattr(self, "_prog_bar_metrics", True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=pb)
        self.log("train_rmse", metrics["rmse"], on_step=False, on_epoch=True)
        self.log("train_r2", metrics["r2"], on_step=False, on_epoch=True, prog_bar=pb)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        pb = getattr(self, "_prog_bar_metrics", True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=pb)
        self.log("val_rmse", metrics["rmse"], on_epoch=True, prog_bar=pb)
        self.log("val_mae", metrics["mae"], on_epoch=True)
        self.log("val_r2", metrics["r2"], on_epoch=True, prog_bar=pb)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        metrics = self._compute_metrics(preds, y)

        self.log("test_loss", loss, on_epoch=True)
        self.log("test_rmse", metrics["rmse"], on_epoch=True)
        self.log("test_mae", metrics["mae"], on_epoch=True)
        self.log("test_r2", metrics["r2"], on_epoch=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_step(self, batch, batch_idx):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

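# Illustrative sketch (not part of the released module): a quick smoke test
# of PixelRegressionModel's output shape. The 4-channel input and
# encoder_weights=None (to avoid a weight download) are assumptions.
#
#   model = PixelRegressionModel(
#       encoder_name="resnet50",
#       architecture="unet",
#       in_channels=4,
#       encoder_weights=None,
#   )
#   x = torch.randn(2, 4, 256, 256)  # (B, C, H, W)
#   with torch.no_grad():
#       y = model(x)
#   assert y.shape == (2, 256, 256)  # forward() squeezes the channel dim
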
class PixelRegressionDataset(Dataset):
    """
    Dataset for pixel-level regression from paired image and target rasters.

    Loads image patches and corresponding target patches for training
    pixel-wise regression models.
    """

    def __init__(
        self,
        image_paths: List[str],
        target_paths: List[str],
        input_bands: Optional[List[int]] = None,
        target_band: int = 1,
        transform: Optional[Callable] = None,
        normalize_input: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        target_nodata: Optional[float] = None,
    ):
        """
        Initialize PixelRegressionDataset.

        Args:
            image_paths (List[str]): List of paths to input image patches.
            target_paths (List[str]): List of paths to target raster patches.
            input_bands (List[int], optional): Band indices to use (1-indexed). If None, uses all.
            target_band (int): Band index for target raster (1-indexed).
            transform (callable, optional): Transform to apply to images.
            normalize_input (bool): Whether to normalize input to [0, 1].
            image_mean (List[float], optional): Per-channel mean for normalization.
            image_std (List[float], optional): Per-channel std for normalization.
            target_nodata (float, optional): NoData value for targets.
        """
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.input_bands = input_bands
        self.target_band = target_band
        self.transform = transform
        self.normalize_input = normalize_input
        self.image_mean = image_mean
        self.image_std = image_std
        self.target_nodata = target_nodata
        self._mean_array = None
        self._std_array = None

        if len(image_paths) != len(target_paths):
            raise ValueError("Number of images must match number of targets")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        import rasterio

        # Load input image
        with rasterio.open(self.image_paths[idx]) as src:
            if self.input_bands is not None:
                image = src.read(self.input_bands)
            else:
                image = src.read()

        # Load target
        with rasterio.open(self.target_paths[idx]) as src:
            target = src.read(self.target_band)

        # Handle NaN
        image = np.nan_to_num(image, nan=0.0)
        target = np.nan_to_num(target, nan=0.0)

        # Normalize input
        image = image.astype(np.float32)
        if self.normalize_input:
            data_max = np.abs(image).max()
            if data_max <= 1.5:
                image = np.clip(image, 0, 1)
            else:
                image = np.clip(image, 0, 10000) / 10000.0

        if self.image_mean is not None and self.image_std is not None:
            mean, std = _prepare_normalization_stats(
                self.image_mean, self.image_std, image.shape[0]
            )
            image = (image - mean) / std

        target = target.astype(np.float32)

        # Convert to tensor
        image = torch.from_numpy(image)
        target = torch.from_numpy(target)

        if self.transform is not None:
            image = self.transform(image)

        return image, target

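# Illustrative sketch (not part of the released module): wiring tile lists
# into the dataset and a DataLoader. The glob patterns are hypothetical.
#
#   import glob
#   images = sorted(glob.glob("tiles/images/*.tif"))
#   targets = sorted(glob.glob("tiles/targets/*.tif"))
#   ds = PixelRegressionDataset(images, targets, normalize_input=True)
#   image, target = ds[0]  # image: (C, H, W) float32; target: (H, W) float32
#   loader = DataLoader(ds, batch_size=8, shuffle=True)
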
def create_regression_tiles(
    input_raster: str,
    target_raster: str,
    output_dir: str,
    tile_size: int = 256,
    stride: Optional[int] = None,
    input_bands: Optional[List[int]] = None,
    target_band: int = 1,
    min_valid_ratio: float = 0.8,
    target_min: Optional[float] = None,
    target_max: Optional[float] = None,
) -> Tuple[List[str], List[str]]:
    """
    Create paired image and target tiles from input and target rasters.

    Args:
        input_raster (str): Path to input raster (e.g., Landsat imagery).
        target_raster (str): Path to target raster (e.g., NDVI).
        output_dir (str): Directory to save tiles.
        tile_size (int): Size of each tile (tile_size x tile_size pixels).
        stride (int, optional): Stride between tiles. Defaults to tile_size (no overlap).
        input_bands (List[int], optional): Band indices to use (1-indexed).
        target_band (int): Band index for target raster (1-indexed).
        min_valid_ratio (float): Minimum ratio of valid pixels in tile.
        target_min (float, optional): Minimum valid target value.
        target_max (float, optional): Maximum valid target value.

    Returns:
        Tuple of (image_paths, target_paths): Lists of tile paths.
    """
    import rasterio

    image_dir = os.path.join(output_dir, "images")
    target_dir = os.path.join(output_dir, "targets")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    if stride is None:
        stride = tile_size

    image_paths = []
    target_paths = []

    with rasterio.open(input_raster) as src_input:
        with rasterio.open(target_raster) as src_target:
            height = src_input.height
            width = src_input.width

            if input_bands is None:
                input_bands = list(range(1, src_input.count + 1))

            n_tiles_y = (height - tile_size) // stride + 1
            n_tiles_x = (width - tile_size) // stride + 1

            print(f"Input raster: {width}x{height}, {src_input.count} bands")
            print(f"Target raster: {src_target.width}x{src_target.height}")
            print(f"Tile size: {tile_size}x{tile_size}, stride: {stride}")
            print(
                f"Expected tiles: {n_tiles_y} x {n_tiles_x} = {n_tiles_y * n_tiles_x}"
            )

            tile_idx = 0
            valid_tiles = 0
            skipped_nodata = 0
            skipped_range = 0

            for row in tqdm(
                range(0, height - tile_size + 1, stride), desc="Creating tiles"
            ):
                for col in range(0, width - tile_size + 1, stride):
                    window = rasterio.windows.Window(col, row, tile_size, tile_size)

                    # Read tiles
                    input_tile = src_input.read(input_bands, window=window)
                    target_tile = src_target.read(target_band, window=window)

                    # Check for valid pixels
                    valid_mask = ~np.isnan(input_tile).any(axis=0) & ~np.isnan(
                        target_tile
                    )
                    valid_ratio = valid_mask.sum() / (tile_size * tile_size)

                    if valid_ratio < min_valid_ratio:
                        tile_idx += 1
                        skipped_nodata += 1
                        continue

                    # Check target range - skip tiles with >5% out-of-range values
                    valid_target = target_tile[valid_mask]
                    out_of_range_ratio = 0.0
                    if target_min is not None or target_max is not None:
                        out_of_range = np.zeros_like(valid_target, dtype=bool)
                        if target_min is not None:
                            out_of_range |= valid_target < target_min
                        if target_max is not None:
                            out_of_range |= valid_target > target_max
                        out_of_range_ratio = out_of_range.sum() / len(valid_target)

                    # Skip if more than 5% of pixels are out of range
                    if out_of_range_ratio > 0.05:
                        tile_idx += 1
                        skipped_range += 1
                        continue

                    # Replace NaN with 0
                    input_tile = np.nan_to_num(input_tile, nan=0.0)
                    target_tile = np.nan_to_num(target_tile, nan=0.0)

                    # Clip target values to valid range (important!)
                    if target_min is not None or target_max is not None:
                        target_tile = np.clip(
                            target_tile,
                            target_min if target_min is not None else -np.inf,
                            target_max if target_max is not None else np.inf,
                        )

                    # Save tiles
                    image_path = os.path.join(image_dir, f"tile_{tile_idx:06d}.tif")
                    target_path = os.path.join(target_dir, f"tile_{tile_idx:06d}.tif")

                    # Save input tile
                    profile = src_input.profile.copy()
                    profile.update(
                        width=tile_size,
                        height=tile_size,
                        count=len(input_bands),
                        dtype=input_tile.dtype,
                        transform=rasterio.windows.transform(
                            window, src_input.transform
                        ),
                        tiled=False,  # Disable tiling for small tiles
                    )
                    # Remove block size settings that cause warnings
                    profile.pop("blockxsize", None)
                    profile.pop("blockysize", None)
                    with rasterio.open(image_path, "w", **profile) as dst:
                        dst.write(input_tile)

                    # Save target tile
                    profile = src_target.profile.copy()
                    profile.update(
                        width=tile_size,
                        height=tile_size,
                        count=1,
                        dtype=target_tile.dtype,
                        transform=rasterio.windows.transform(
                            window, src_target.transform
                        ),
                        tiled=False,
                    )
                    profile.pop("blockxsize", None)
                    profile.pop("blockysize", None)
                    with rasterio.open(target_path, "w", **profile) as dst:
                        dst.write(target_tile[np.newaxis, :, :])

                    image_paths.append(image_path)
                    target_paths.append(target_path)
                    valid_tiles += 1
                    tile_idx += 1

    print(f"\nCreated {valid_tiles} valid tiles out of {tile_idx} total")
    print(f"Skipped due to nodata: {skipped_nodata}")
    print(f"Skipped due to target range: {skipped_range}")

    return image_paths, target_paths

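# Illustrative sketch (not part of the released module): tiling a raster
# pair for an NDVI-style target and holding out a simple validation split.
# The file names and the [-1, 1] range are assumptions.
#
#   image_paths, target_paths = create_regression_tiles(
#       "landsat.tif",
#       "ndvi.tif",
#       "tiles",
#       tile_size=256,
#       target_min=-1.0,
#       target_max=1.0,
#   )
#   n_val = max(1, len(image_paths) // 5)  # rough 80/20 split
#   train_imgs, train_tgts = image_paths[n_val:], target_paths[n_val:]
#   val_imgs, val_tgts = image_paths[:n_val], target_paths[:n_val]
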
def train_pixel_regressor(
    train_image_paths: List[str],
    train_target_paths: List[str],
    val_image_paths: Optional[List[str]] = None,
    val_target_paths: Optional[List[str]] = None,
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    in_channels: int = 3,
    encoder_weights: str = "imagenet",
    output_dir: str = "output",
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-4,
    num_workers: int = 0,
    freeze_encoder: bool = False,
    loss_type: str = "mse",
    normalize_input: bool = True,
    accelerator: str = "auto",
    devices: int = 1,
    monitor_metric: str = "val_loss",
    mode: str = "min",
    patience: int = 10,
    save_top_k: int = 1,
    checkpoint_path: Optional[str] = None,
    input_bands: Optional[List[int]] = None,
    verbose: bool = True,
    **kwargs: Any,
) -> PixelRegressionModel:
    """
    Train a pixel-level regression model.

    Args:
        train_image_paths: List of training image paths.
        train_target_paths: List of training target paths.
        val_image_paths: List of validation image paths.
        val_target_paths: List of validation target paths.
        encoder_name: Name of timm encoder.
        architecture: Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3plus', etc.).
        in_channels: Number of input channels.
        encoder_weights: Pretrained weights for encoder.
        output_dir: Directory to save outputs.
        batch_size: Batch size for training.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate.
        weight_decay: Weight decay.
        num_workers: Number of data loading workers.
        freeze_encoder: Freeze encoder during training.
        loss_type: Loss function type ('mse', 'l1', 'huber').
        normalize_input: Normalize input tiles to expected range.
        accelerator: Accelerator type.
        devices: Number of devices.
        monitor_metric: Metric to monitor for checkpointing.
        mode: 'min' or 'max' for monitor_metric.
        patience: Early stopping patience.
        save_top_k: Number of best models to save.
        checkpoint_path: Path to checkpoint to resume from.
        input_bands: Band indices to use (1-indexed).
        verbose: Whether to show detailed training logs, progress bars,
            and callback messages. Defaults to True.
        **kwargs: Additional arguments for Trainer.

    Returns:
        PixelRegressionModel: Trained model.
    """
    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install with: pip install lightning"
        )

    os.makedirs(output_dir, exist_ok=True)
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    # Create model
    model = PixelRegressionModel(
        encoder_name=encoder_name,
        architecture=architecture,
        in_channels=in_channels,
        encoder_weights=encoder_weights,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_encoder=freeze_encoder,
        loss_type=loss_type,
    )

    preprocessing = _infer_preprocessing_params(encoder_name, encoder_weights)
    image_mean = None
    image_std = None
    if preprocessing is not None:
        pp_mean = preprocessing.get("mean")
        pp_std = preprocessing.get("std")
        # Only apply encoder preprocessing when channel count matches
        # (e.g. 3-band RGB with ImageNet weights). For multi-spectral
        # inputs the ImageNet statistics are inappropriate; the
        # normalize_input flag already scales values to [0, 1].
        if pp_mean is not None and pp_std is not None and len(pp_mean) == in_channels:
            image_mean = pp_mean
            image_std = pp_std

    # Create datasets
    train_dataset = PixelRegressionDataset(
        train_image_paths,
        train_target_paths,
        input_bands=input_bands,
        normalize_input=normalize_input,
        image_mean=image_mean,
        image_std=image_std,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = None
    if val_image_paths is not None and val_target_paths is not None:
        val_dataset = PixelRegressionDataset(
            val_image_paths,
            val_target_paths,
            input_bands=input_bands,
            normalize_input=normalize_input,
            image_mean=image_mean,
            image_std=image_std,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Callbacks
    callbacks = []

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        filename=f"{encoder_name}_{architecture}_{{epoch:02d}}_{{val_loss:.4f}}",
        monitor=monitor_metric,
        mode=mode,
        save_top_k=save_top_k,
        save_last=True,
        verbose=verbose,
    )
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=patience,
        mode=mode,
        verbose=verbose,
    )
    callbacks.append(early_stop_callback)

    if not verbose:
        import logging

        logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
        callbacks.append(_CompactProgressBar())
        model._prog_bar_metrics = False

    logger = CSVLogger(model_dir, name="lightning_logs")

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        enable_model_summary=verbose,
        **kwargs,
    )

    if verbose:
        print(
            f"Training {architecture} with {encoder_name} encoder"
            f" for {num_epochs} epochs..."
        )
        print(f"Loss function: {loss_type.upper()}")

    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=checkpoint_path,
    )

    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        if verbose:
            print(f"\nBest model saved at: {best_model_path}")
        model = PixelRegressionModel.load_from_checkpoint(best_model_path)
        model.best_model_path = best_model_path
    else:
        if verbose:
            print("\nBest model path not found; returning last epoch model.")

    return model

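# Illustrative sketch (not part of the released module): a minimal training
# call using the splits from the tiling sketch above. in_channels must match
# the tiles' band count; 6 is an assumption here.
#
#   model = train_pixel_regressor(
#       train_imgs,
#       train_tgts,
#       val_image_paths=val_imgs,
#       val_target_paths=val_tgts,
#       encoder_name="resnet50",
#       architecture="unet",
#       in_channels=6,
#       output_dir="output",
#       num_epochs=50,
#       loss_type="huber",
#   )
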
def predict_raster(
    model: Union[PixelRegressionModel, nn.Module],
    input_raster: str,
    output_raster: str,
    tile_size: int = 256,
    overlap: int = 64,
    input_bands: Optional[List[int]] = None,
    batch_size: int = 4,
    device: Optional[str] = None,
    output_nodata: float = -9999.0,
    clip_range: Optional[Tuple[float, float]] = None,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
    use_model_preprocessing: bool = True,
) -> str:
    """
    Run pixel-level inference on a raster.

    Uses sliding window with overlap and blending for smooth predictions.
    Output dimensions match input dimensions exactly.

    Args:
        model: Trained pixel regression model.
        input_raster: Path to input raster.
        output_raster: Path to save output raster.
        tile_size: Size of tiles for inference.
        overlap: Overlap between tiles for blending.
        input_bands: Band indices to use (1-indexed).
        batch_size: Batch size for inference.
        device: Device to use.
        output_nodata: NoData value for output.
        clip_range: Optional tuple (min, max) to clip output values.
        image_mean: Optional per-channel mean for normalization.
        image_std: Optional per-channel std for normalization.
        use_model_preprocessing: Use encoder preprocessing params if available.

    Returns:
        str: Path to output raster.
    """
    import rasterio

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model = model.to(device)

    stride = tile_size - overlap

    with rasterio.open(input_raster) as src:
        height = src.height
        width = src.width

        if input_bands is None:
            input_bands = list(range(1, src.count + 1))

        print(f"Input raster: {width}x{height}")
        print(f"Tile size: {tile_size}, overlap: {overlap}, stride: {stride}")

        # Initialize output arrays
        output_sum = np.zeros((height, width), dtype=np.float64)
        weight_sum = np.zeros((height, width), dtype=np.float64)

        # Create weight mask for blending (higher weight in center)
        weight_mask = _create_weight_mask(tile_size, overlap)

        # Read full input for nodata mask
        full_input = src.read(input_bands)
        nodata_mask = np.any(np.isnan(full_input), axis=0)

        if use_model_preprocessing and image_mean is None and image_std is None:
            encoder_name = getattr(
                getattr(model, "hparams", None), "encoder_name", None
            )
            encoder_weights = getattr(
                getattr(model, "hparams", None), "encoder_weights", None
            )
            model_in_channels = getattr(
                getattr(model, "hparams", None),
                "in_channels",
                len(input_bands),
            )
            if encoder_name and encoder_weights:
                preprocessing = _infer_preprocessing_params(
                    encoder_name, encoder_weights
                )
                if preprocessing is not None:
                    pp_mean = preprocessing.get("mean")
                    pp_std = preprocessing.get("std")
                    if (
                        pp_mean is not None
                        and pp_std is not None
                        and len(pp_mean) == model_in_channels
                    ):
                        image_mean = pp_mean
                        image_std = pp_std

        # Collect tiles
        tiles = []
        positions = []

        for row in range(0, height, stride):
            for col in range(0, width, stride):
                # Calculate tile bounds
                row_end = min(row + tile_size, height)
                col_end = min(col + tile_size, width)
                row_start = row_end - tile_size
                col_start = col_end - tile_size

                # Clamp to valid range
                row_start = max(0, row_start)
                col_start = max(0, col_start)

                tiles.append((row_start, col_start, row_end, col_end))
                positions.append((row_start, col_start))

        print(f"Total tiles: {len(tiles)}")

        # Process in batches
        for batch_start in tqdm(
            range(0, len(tiles), batch_size), desc="Running inference"
        ):
            batch_end = min(batch_start + batch_size, len(tiles))
            batch_tiles = tiles[batch_start:batch_end]

            # Load batch
            batch_images = []
            for row_start, col_start, row_end, col_end in batch_tiles:
                window = rasterio.windows.Window(
                    col_start, row_start, col_end - col_start, row_end - row_start
                )
                tile = src.read(input_bands, window=window).astype(np.float32)

                # Handle non-square tiles at edges
                if tile.shape[1] != tile_size or tile.shape[2] != tile_size:
                    padded = np.zeros(
                        (len(input_bands), tile_size, tile_size), dtype=np.float32
                    )
                    padded[:, : tile.shape[1], : tile.shape[2]] = tile
                    tile = padded

                # Normalize
                tile = np.nan_to_num(tile, nan=0.0)
                data_max = np.abs(tile).max()
                if data_max <= 1.5:
                    tile = np.clip(tile, 0, 1)
                else:
                    tile = np.clip(tile, 0, 10000) / 10000.0

                if image_mean is not None and image_std is not None:
                    mean, std = _prepare_normalization_stats(
                        image_mean, image_std, tile.shape[0]
                    )
                    tile = (tile - mean) / std

                batch_images.append(tile)

            batch_tensor = torch.from_numpy(np.stack(batch_images)).to(device)

            # Inference
            with torch.no_grad():
                preds = model(batch_tensor).cpu().numpy()

            # Apply predictions with blending
            for i, (row_start, col_start, row_end, col_end) in enumerate(batch_tiles):
                pred = preds[i]
                h = row_end - row_start
                w = col_end - col_start

                # Get the relevant portion of prediction and weight
                pred_crop = pred[:h, :w]
                weight_crop = weight_mask[:h, :w]

                # Accumulate
                output_sum[row_start:row_end, col_start:col_end] += (
                    pred_crop * weight_crop
                )
                weight_sum[row_start:row_end, col_start:col_end] += weight_crop

        # Normalize by weights
        valid_weights = weight_sum > 0
        output_array = np.full((height, width), output_nodata, dtype=np.float32)
        output_array[valid_weights] = (
            output_sum[valid_weights] / weight_sum[valid_weights]
        )

        # Apply nodata mask
        output_array[nodata_mask] = output_nodata

        # Clip output to valid range if specified
        if clip_range is not None:
            valid_data_mask = ~nodata_mask & valid_weights
            output_array[valid_data_mask] = np.clip(
                output_array[valid_data_mask], clip_range[0], clip_range[1]
            )

        # Save output
        profile = src.profile.copy()
        profile.update(
            count=1,
            dtype=np.float32,
            nodata=output_nodata,
        )

    output_dir = os.path.dirname(os.path.abspath(output_raster))
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with rasterio.open(output_raster, "w", **profile) as dst:
        dst.write(output_array, 1)

    valid_data = output_array[~nodata_mask & valid_weights]
    print(f"\nOutput saved to: {output_raster}")
    print(f"Output dimensions: {width}x{height} (same as input)")
    if len(valid_data) > 0:
        print(f"Prediction range: [{valid_data.min():.4f}, {valid_data.max():.4f}]")

    return output_raster

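# Illustrative sketch (not part of the released module): full-scene inference
# with the blended sliding window, clipping to an NDVI-style range. The scene
# and output paths are hypothetical.
#
#   predict_raster(
#       model,
#       "landsat_scene.tif",
#       "ndvi_pred.tif",
#       tile_size=256,
#       overlap=64,
#       clip_range=(-1.0, 1.0),
#   )
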
def _create_weight_mask(tile_size: int, overlap: int) -> np.ndarray:
    """Create a weight mask for blending overlapping tiles."""
    if overlap == 0:
        return np.ones((tile_size, tile_size), dtype=np.float32)

    # Create 1D ramp
    ramp = np.ones(tile_size, dtype=np.float32)
    ramp[:overlap] = np.linspace(0, 1, overlap)
    ramp[-overlap:] = np.linspace(1, 0, overlap)

    # Create 2D weight mask
    weight_mask = np.outer(ramp, ramp)
    return weight_mask

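# Illustrative note (not part of the released module): with this mask, the
# blended output in predict_raster is a weighted average over every tile
# covering a pixel,
#
#   output[y, x] = sum_i(w_i[y, x] * pred_i[y, x]) / sum_i(w_i[y, x])
#
# so values ramp linearly across the overlap band instead of forming seams.
# A tiny check of the ramp (tile_size=8, overlap=4 are arbitrary):
#
#   m = _create_weight_mask(8, 4)
#   print(m[4])  # ~[0.0, 0.33, 0.67, 1.0, 1.0, 0.67, 0.33, 0.0]
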
# ============================================================================
# Evaluation and Visualization Functions
# ============================================================================

def evaluate_regression(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    mask: Optional[np.ndarray] = None,
    print_results: bool = True,
) -> Dict[str, float]:
    """
    Evaluate regression predictions with multiple metrics.

    Args:
        y_true: Ground truth values.
        y_pred: Predicted values.
        mask: Optional mask of valid pixels.
        print_results: Whether to print results.

    Returns:
        Dictionary of metrics: MSE, RMSE, MAE, R².
    """
    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

    y_true = np.array(y_true).flatten()
    y_pred = np.array(y_pred).flatten()

    if mask is not None:
        mask = np.array(mask).flatten()
        y_true = y_true[mask]
        y_pred = y_pred[mask]

    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    metrics = {
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
    }

    if print_results:
        print("=" * 50)
        print("Regression Evaluation Metrics")
        print("=" * 50)
        print(f"MSE: {mse:.6f}")
        print(f"RMSE: {rmse:.6f}")
        print(f"MAE: {mae:.6f}")
        print(f"R²: {r2:.4f}")
        print("=" * 50)

    return metrics

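# Illustrative sketch (not part of the released module): evaluating on
# synthetic arrays; the 0.05 noise level is arbitrary.
#
#   y_true = np.random.rand(100, 100)
#   y_pred = y_true + np.random.normal(0, 0.05, size=(100, 100))
#   mask = np.isfinite(y_true) & np.isfinite(y_pred)
#   metrics = evaluate_regression(y_true, y_pred, mask=mask)
#   # expect metrics["rmse"] close to 0.05 and metrics["r2"] near 1
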
def plot_regression_comparison(
    true_raster: str,
    pred_raster: str,
    title: str = "Regression Results",
    cmap: str = "RdYlGn",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    valid_range: Optional[Tuple[float, float]] = None,
    figsize: Tuple[int, int] = (18, 6),
    save_path: Optional[str] = None,
):
    """
    Plot comparison of ground truth, prediction, and difference.

    Args:
        true_raster: Path to ground truth raster.
        pred_raster: Path to prediction raster.
        title: Title for the plot.
        cmap: Colormap for visualization.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        valid_range: Tuple of (min, max) valid values for filtering outliers.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Tuple of (figure, metrics_dict).
    """
    import matplotlib.pyplot as plt
    import rasterio

    with rasterio.open(true_raster) as src:
        true_data = src.read(1)
        true_nodata = src.nodata

    with rasterio.open(pred_raster) as src:
        pred_data = src.read(1)
        pred_nodata = src.nodata

    # Create valid mask
    valid_mask = np.ones_like(true_data, dtype=bool)
    if true_nodata is not None:
        valid_mask &= true_data != true_nodata
    if pred_nodata is not None:
        valid_mask &= pred_data != pred_nodata
    valid_mask &= ~np.isnan(true_data) & ~np.isnan(pred_data)

    # Filter by valid range (important for NDVI which should be [-1, 1])
    if valid_range is not None:
        valid_mask &= (true_data >= valid_range[0]) & (true_data <= valid_range[1])
        valid_mask &= (pred_data >= valid_range[0]) & (pred_data <= valid_range[1])

    # Calculate metrics
    metrics = evaluate_regression(
        true_data[valid_mask], pred_data[valid_mask], print_results=False
    )

    # Auto-determine vmin/vmax if not specified
    if vmin is None:
        vmin = np.percentile(true_data[valid_mask], 2)
    if vmax is None:
        vmax = np.percentile(true_data[valid_mask], 98)

    # Create masked arrays for display
    true_masked = np.ma.masked_where(~valid_mask, true_data)
    pred_masked = np.ma.masked_where(~valid_mask, pred_data)
    diff = pred_data - true_data
    diff_masked = np.ma.masked_where(~valid_mask, diff)

    # Plot
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    im1 = axes[0].imshow(true_masked, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[0].set_title("Ground Truth", fontsize=14)
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_masked, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[1].set_title(f"Prediction (R²={metrics['r2']:.4f})", fontsize=14)
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    diff_range = max(
        abs(np.percentile(diff[valid_mask], 5)),
        abs(np.percentile(diff[valid_mask], 95)),
    )
    im3 = axes[2].imshow(diff_masked, cmap="RdBu_r", vmin=-diff_range, vmax=diff_range)
    axes[2].set_title(f"Difference (RMSE={metrics['rmse']:.4f})", fontsize=14)
    axes[2].axis("off")
    plt.colorbar(im3, ax=axes[2], shrink=0.8)

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, metrics

def plot_scatter(
    true_raster: str,
    pred_raster: str,
    sample_size: int = 10000,
    title: str = "Predicted vs Actual",
    valid_range: Optional[Tuple[float, float]] = None,
    fit_line: bool = True,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None,
):
    """
    Plot scatter plot of predicted vs actual values with optional trend line.

    Args:
        true_raster: Path to ground truth raster.
        pred_raster: Path to prediction raster.
        sample_size: Number of points to plot (sampled if needed).
        title: Title for the plot.
        valid_range: Tuple of (min, max) valid values for filtering outliers.
        fit_line: Whether to show a linear regression trend line.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Tuple of (figure, metrics_dict).
    """
    import matplotlib.pyplot as plt
    import rasterio
    from sklearn.metrics import r2_score

    with rasterio.open(true_raster) as src:
        true_data = src.read(1)
        true_nodata = src.nodata

    with rasterio.open(pred_raster) as src:
        pred_data = src.read(1)
        pred_nodata = src.nodata

    # Create valid mask
    valid_mask = np.ones_like(true_data, dtype=bool)
    if true_nodata is not None:
        valid_mask &= true_data != true_nodata
    if pred_nodata is not None:
        valid_mask &= pred_data != pred_nodata
    valid_mask &= ~np.isnan(true_data) & ~np.isnan(pred_data)

    # Filter by valid range
    if valid_range is not None:
        valid_mask &= (true_data >= valid_range[0]) & (true_data <= valid_range[1])
        valid_mask &= (pred_data >= valid_range[0]) & (pred_data <= valid_range[1])

    y_true = true_data[valid_mask]
    y_pred = pred_data[valid_mask]

    # Sample if too many points
    if len(y_true) > sample_size:
        idx = np.random.choice(len(y_true), sample_size, replace=False)
        y_true_plot = y_true[idx]
        y_pred_plot = y_pred[idx]
    else:
        y_true_plot = y_true
        y_pred_plot = y_pred

    # Calculate metrics on full data
    metrics = evaluate_regression(y_true, y_pred, print_results=False)

    # Plot
    fig, ax = plt.subplots(figsize=figsize)

    ax.scatter(y_true_plot, y_pred_plot, alpha=0.3, s=5, edgecolors="none")

    # Add 1:1 line
    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    ax.plot([min_val, max_val], [min_val, max_val], "r--", lw=2, label="1:1 Line")

    # Add linear regression trend line
    if fit_line:
        coeffs = np.polyfit(y_true, y_pred, 1)
        slope, intercept = coeffs
        fit_x = np.array([min_val, max_val])
        fit_y = slope * fit_x + intercept
        ax.plot(
            fit_x,
            fit_y,
            "b-",
            lw=2,
            label=f"Fit: y = {slope:.3f}x + {intercept:.3f}",
        )
        metrics["slope"] = float(slope)
        metrics["intercept"] = float(intercept)

    ax.set_xlabel("Actual Values", fontsize=12)
    ax.set_ylabel("Predicted Values", fontsize=12)
    ax.set_title(
        f"{title}\nR² = {metrics['r2']:.4f}, RMSE = {metrics['rmse']:.4f}", fontsize=14
    )
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, metrics

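# Illustrative sketch (not part of the released module): the two plotting
# helpers above take the same raster pair. File names are hypothetical.
#
#   fig, m = plot_regression_comparison(
#       "ndvi.tif", "ndvi_pred.tif", valid_range=(-1.0, 1.0)
#   )
#   fig, m = plot_scatter(
#       "ndvi.tif", "ndvi_pred.tif", valid_range=(-1.0, 1.0), sample_size=20000
#   )
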
def plot_training_history(
|
|
1291
|
+
log_dir: str,
|
|
1292
|
+
metrics: Optional[List[str]] = None,
|
|
1293
|
+
figsize: Optional[Tuple[int, int]] = None,
|
|
1294
|
+
tail: Optional[int] = None,
|
|
1295
|
+
save_path: Optional[str] = None,
|
|
1296
|
+
):
|
|
1297
|
+
"""
|
|
1298
|
+
Plot training history curves from Lightning CSV logs.
|
|
1299
|
+
|
|
1300
|
+
Reads the ``metrics.csv`` file produced by :class:`CSVLogger` and plots
|
|
1301
|
+
the requested training and validation metrics over epochs.
|
|
1302
|
+
|
|
1303
|
+
Args:
|
|
1304
|
+
log_dir: Path to the model output directory (the same ``output_dir``
|
|
1305
|
+
passed to :func:`train_pixel_regressor`). The function searches
|
|
1306
|
+
for ``lightning_logs/version_*/metrics.csv`` inside a ``models``
|
|
1307
|
+
sub-directory (or directly under *log_dir*).
|
|
1308
|
+
metrics: List of metric names to plot. Each name is matched against
|
|
1309
|
+
the CSV columns; both the ``train_`` and ``val_`` variants are
|
|
1310
|
+
plotted when available. Defaults to ``["loss", "r2"]``.
|
|
1311
|
+
figsize: Figure size as ``(width, height)``. Defaults to
|
|
1312
|
+
``(6 * n_metrics, 5)``.
|
|
1313
|
+
tail: If given, only plot the last *tail* epochs. Useful for
|
|
1314
|
+
skipping early warm-up instability. By default the function
|
|
1315
|
+
automatically skips early epochs when extreme outliers would
|
|
1316
|
+
compress the y-axis (more than 10× the stable range).
|
|
1317
|
+
save_path: If given, save the figure to this path.
|
|
1318
|
+
|
|
1319
|
+
Returns:
|
|
1320
|
+
        Tuple of (figure, pandas.DataFrame of the loaded metrics).
    """
    import glob

    import matplotlib.pyplot as plt

    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required for plot_training_history")

    if metrics is None:
        metrics = ["loss", "r2"]

    # Locate metrics.csv
    search_paths = [
        os.path.join(log_dir, "models", "lightning_logs", "version_*", "metrics.csv"),
        os.path.join(log_dir, "lightning_logs", "version_*", "metrics.csv"),
        os.path.join(log_dir, "version_*", "metrics.csv"),
    ]

    csv_path = None
    for pattern in search_paths:
        matches = sorted(glob.glob(pattern))
        if matches:
            csv_path = matches[-1]  # latest version
            break

    if csv_path is None:
        raise FileNotFoundError(
            f"No metrics.csv found under '{log_dir}'. "
            "Looked for lightning_logs/version_*/metrics.csv"
        )

    df = pd.read_csv(csv_path)
    _n_epochs = df["epoch"].nunique() if "epoch" in df.columns else len(df)
    print(f"Reading logs: {csv_path} ({_n_epochs} epochs)")

    # Group rows by epoch – Lightning logs multiple rows per epoch (one per
    # step plus validation). Use ``last()`` with ``skipna`` so we keep the
    # last non-null value for every column within each epoch.
    if "epoch" in df.columns:
        try:
            df_epoch = df.groupby("epoch").last(skipna=True).reset_index()
        except TypeError:
            # older pandas without skipna
            df_epoch = df.groupby("epoch").last().reset_index()
    else:
        df_epoch = df

    # Apply tail filter
    if tail is not None:
        df_epoch = df_epoch.tail(tail).reset_index(drop=True)

    n_metrics = len(metrics)
    if figsize is None:
        figsize = (6 * n_metrics, 5)

    fig, axes = plt.subplots(1, n_metrics, figsize=figsize)
    if n_metrics == 1:
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        train_col = (
            f"train_{metric}_epoch"
            if f"train_{metric}_epoch" in df_epoch.columns
            else f"train_{metric}"
        )
        val_col = f"val_{metric}"

        has_train = train_col in df_epoch.columns
        has_val = val_col in df_epoch.columns

        if not has_train and not has_val:
            ax.set_title(f"{metric} (no data)")
            continue

        x = df_epoch["epoch"] if "epoch" in df_epoch.columns else df_epoch.index

        if has_train:
            train_data = df_epoch[train_col].dropna()
            ax.plot(
                x[train_data.index],
                train_data.values,
                label=f"Train {metric}",
                linewidth=2,
            )
        if has_val:
            val_data = df_epoch[val_col].dropna()
            ax.plot(
                x[val_data.index],
                val_data.values,
                label=f"Val {metric}",
                linewidth=2,
            )

        # Auto-zoom: if early outliers compress the view, clip y-axis to
        # the range of the stable second half of training.
        if tail is None:
            n_epochs = len(df_epoch)
            if n_epochs >= 10:
                half = n_epochs // 2
                second_half_vals = []
                all_vals = []
                if has_train:
                    col_data = df_epoch[train_col].dropna()
                    all_vals.extend(col_data.values)
                    second_half_vals.extend(
                        df_epoch[train_col].iloc[half:].dropna().values
                    )
                if has_val:
                    col_data = df_epoch[val_col].dropna()
                    all_vals.extend(col_data.values)
                    second_half_vals.extend(
                        df_epoch[val_col].iloc[half:].dropna().values
                    )
                if second_half_vals and all_vals:
                    sh_arr = np.array(second_half_vals)
                    all_arr = np.array(all_vals)
                    sh_min, sh_max = sh_arr.min(), sh_arr.max()
                    sh_range = sh_max - sh_min if sh_max != sh_min else 1.0
                    full_range = all_arr.max() - all_arr.min()
                    if full_range == 0:
                        full_range = 1.0
                    # If full range is >5× the stable range, zoom in
                    if full_range > 5 * sh_range:
                        margin = sh_range * 0.3
                        ax.set_ylim(sh_min - margin, sh_max + margin)

        label = metric.upper() if len(metric) <= 4 else metric.replace("_", " ").title()
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(f"Training & Validation {label}", fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig, df_epoch


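# A minimal usage sketch for plot_training_history; the log directory and
# metric names below are placeholders -- pass the log_dir your Lightning
# trainer wrote metrics.csv to, and metrics that were logged as
# train_*/val_* columns:
#
#     fig, history = plot_training_history(
#         log_dir="output",
#         metrics=["loss", "r2"],
#         tail=20,                      # optional: only the last 20 epochs
#         save_path="training_history.png",
#     )
#     print(history[["epoch", "val_loss"]].tail())

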
def visualize_prediction(
    input_raster: str,
    pred_raster: str,
    rgb_bands: List[int] = [1, 2, 3],
    cmap: str = "RdYlGn",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    figsize: Tuple[int, int] = (14, 6),
    save_path: Optional[str] = None,
):
    """
    Visualize input RGB and prediction side by side.

    Args:
        input_raster: Path to input raster.
        pred_raster: Path to prediction raster.
        rgb_bands: Band indices for RGB display (1-indexed).
        cmap: Colormap for prediction.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        figsize: Figure size.
        save_path: Path to save figure.

    Returns:
        Figure object.
    """
    import matplotlib.pyplot as plt
    import rasterio

    with rasterio.open(input_raster) as src:
        rgb = src.read(rgb_bands).astype(np.float64)
        # Per-band 2–98 percentile stretch for proper RGB display
        for i in range(rgb.shape[0]):
            band = rgb[i]
            # Exclude non-finite pixels and, when a nodata value is defined,
            # nodata pixels. Compare against None rather than relying on
            # truthiness: a nodata value of 0 is falsy but still valid.
            mask = np.isfinite(band)
            if src.nodata is not None:
                mask &= band != src.nodata
            valid = band[mask]
            if valid.size > 0:
                p2, p98 = np.percentile(valid, [2, 98])
                if p98 > p2:
                    rgb[i] = (band - p2) / (p98 - p2)
                else:
                    rgb[i] = band / p98 if p98 > 0 else band
        rgb = np.clip(rgb, 0, 1)
        rgb = np.transpose(rgb, (1, 2, 0))

    with rasterio.open(pred_raster) as src:
        pred = src.read(1)
        pred_nodata = src.nodata

    # Mask
    valid_mask = np.ones_like(pred, dtype=bool)
    if pred_nodata is not None:
        valid_mask &= pred != pred_nodata
    valid_mask &= ~np.isnan(pred)
    pred_masked = np.ma.masked_where(~valid_mask, pred)

    if vmin is None:
        vmin = np.percentile(pred[valid_mask], 2)
    if vmax is None:
        vmax = np.percentile(pred[valid_mask], 98)

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    axes[0].imshow(rgb)
    axes[0].set_title("Input RGB", fontsize=14)
    axes[0].axis("off")

    im = axes[1].imshow(pred_masked, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[1].set_title("Prediction", fontsize=14)
    axes[1].axis("off")
    plt.colorbar(im, ax=axes[1], shrink=0.8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")

    plt.close(fig)
    return fig


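# A minimal usage sketch for visualize_prediction; the file paths are
# placeholders (rgb_bands is 1-indexed, following rasterio's band convention):
#
#     fig = visualize_prediction(
#         input_raster="naip_tile.tif",
#         pred_raster="ndvi_prediction.tif",
#         rgb_bands=[1, 2, 3],
#         cmap="RdYlGn",
#         save_path="comparison.png",
#     )

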
# ============================================================================
# Backward compatibility aliases
# ============================================================================

TimmRegressor = PixelRegressionModel
RegressionDataset = PixelRegressionDataset
create_regression_patches = create_regression_tiles
train_timm_regressor = train_pixel_regressor


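# Code written against the old names keeps working through the aliases above,
# e.g. (argument lists elided):
#
#     model = TimmRegressor(...)     # resolves to PixelRegressionModel
#     train_timm_regressor(...)      # resolves to train_pixel_regressor

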
def plot_regression_results(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Regression Results",
    fit_line: bool = True,
    figsize: Tuple[int, int] = (12, 5),
    save_path: Optional[str] = None,
):
    """
    Plot regression results: scatter plot with trend line and residual plot.

    Args:
        y_true: Ground truth values.
        y_pred: Predicted values.
        title: Title for the plots.
        fit_line: Whether to show a linear regression trend line.
        figsize: Figure size.
        save_path: Path to save the figure.

    Returns:
        Figure object.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import r2_score

    y_true = np.array(y_true).flatten()
    y_pred = np.array(y_pred).flatten()

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Scatter plot
    ax1 = axes[0]
    ax1.scatter(y_true, y_pred, alpha=0.5, edgecolors="none", s=20)

    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    ax1.plot([min_val, max_val], [min_val, max_val], "r--", lw=2, label="1:1 Line")

    # Add linear regression trend line
    if fit_line:
        coeffs = np.polyfit(y_true, y_pred, 1)
        slope, intercept = coeffs
        fit_x = np.array([min_val, max_val])
        fit_y = slope * fit_x + intercept
        ax1.plot(
            fit_x,
            fit_y,
            "b-",
            lw=2,
            label=f"Fit: y = {slope:.3f}x + {intercept:.3f}",
        )

    r2 = r2_score(y_true, y_pred)
    ax1.set_xlabel("Actual Values", fontsize=12)
    ax1.set_ylabel("Predicted Values", fontsize=12)
    ax1.set_title(f"Predicted vs Actual (R² = {r2:.4f})", fontsize=14)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Residual plot
    ax2 = axes[1]
    residuals = y_pred - y_true
    ax2.scatter(y_pred, residuals, alpha=0.5, edgecolors="none", s=20)
    ax2.axhline(y=0, color="r", linestyle="--", lw=2)

    ax2.set_xlabel("Predicted Values", fontsize=12)
    ax2.set_ylabel("Residuals", fontsize=12)
    ax2.set_title("Residual Plot", fontsize=14)
    ax2.grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=16, y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to: {save_path}")

    plt.show()

    return fig


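# A minimal usage sketch for plot_regression_results with synthetic data
# (any pair of equal-length arrays of observed/predicted values works):
#
#     rng = np.random.default_rng(0)
#     y_true = rng.uniform(0.0, 1.0, 500)              # e.g. observed NDVI
#     y_pred = y_true + rng.normal(0.0, 0.05, 500)     # noisy predictions
#     plot_regression_results(y_true, y_pred, title="NDVI Regression")

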
def predict_with_timm_regressor(*args, **kwargs):
    """Deprecated: Use predict_raster instead."""
    raise NotImplementedError(
        "predict_with_timm_regressor is deprecated. "
        "Use predict_raster for pixel-level predictions."
    )


def get_timm_regression_model(*args, **kwargs):
    """Deprecated: Use PixelRegressionModel instead."""
    raise NotImplementedError(
        "get_timm_regression_model is deprecated. "
        "Use PixelRegressionModel for pixel-level regression."
    )