geoai-py 0.18.1__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +23 -1
- geoai/agents/__init__.py +1 -0
- geoai/agents/geo_agents.py +74 -29
- geoai/geoai.py +2 -0
- geoai/landcover_train.py +685 -0
- geoai/landcover_utils.py +383 -0
- geoai/map_widgets.py +556 -0
- geoai/moondream.py +990 -0
- geoai/tools/__init__.py +11 -0
- geoai/tools/sr.py +194 -0
- geoai/train.py +22 -0
- geoai/utils.py +304 -1654
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/METADATA +3 -1
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/RECORD +18 -14
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/top_level.txt +0 -0
geoai/landcover_train.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Landcover Classification Training Module
|
|
3
|
+
|
|
4
|
+
This module extends the base geoai training functionality with specialized
|
|
5
|
+
components for discrete landcover classification, including:
|
|
6
|
+
- Enhanced loss functions with boundary weighting
|
|
7
|
+
- Per-class frequency weighting for imbalanced datasets
|
|
8
|
+
- Configurable ignore_index handling
|
|
9
|
+
- Additional validation metrics
|
|
10
|
+
|
|
11
|
+
Key Features:
|
|
12
|
+
- Maintains full compatibility with base geoai workflow
|
|
13
|
+
- Adds optional advanced loss computation modes
|
|
14
|
+
- Provides flexible ignore_index configuration
|
|
15
|
+
- Optimized for multi-class landcover segmentation
|
|
16
|
+
|
|
17
|
+
Author: ValHab Project
|
|
18
|
+
Date: November 2025
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
import torch.nn as nn
|
|
25
|
+
import torch.nn.functional as F
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance in segmentation.

    Reference: Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
    Focal loss for dense object detection. ICCV.

    Args:
        alpha: Weighting factor in range (0,1) to balance positive/negative examples
        gamma: Exponent of the modulating factor (1 - p_t)^gamma
        ignore_index: Specifies a target value that is ignored
        reduction: Specifies the reduction to apply to the output
        weight: Manual rescaling weight given to each class
    """

    def __init__(
        self, alpha=1.0, gamma=2.0, ignore_index=-100, reduction="mean", weight=None
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.weight = weight

    def forward(self, inputs, targets):
        """
        Forward pass of focal loss.

        Args:
            inputs: Predictions (N, C, H, W) where C = number of classes
            targets: Ground truth (N, H, W) with class indices

        Returns:
            Loss value
        """
        # Per-pixel cross entropy; ignored targets yield a contribution of 0.
        ce_loss = F.cross_entropy(
            inputs,
            targets,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction="none",
        )

        # Probability assigned to the true class (p_t = exp(-CE)).
        p_t = torch.exp(-ce_loss)

        # Focal modulation down-weights easy (high-confidence) pixels.
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss

        # Apply reduction
        if self.reduction == "mean":
            # FIX: average over valid pixels only. Averaging over all pixels
            # would include the zeroed ignored pixels in the denominator,
            # diluting the loss whenever ignore_index is present in targets.
            valid = targets != self.ignore_index
            n_valid = valid.sum()
            if n_valid == 0:
                # No valid pixels: return a zero that keeps the graph intact.
                return focal_loss.sum() * 0.0
            return focal_loss[valid].sum() / n_valid
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class LandcoverCrossEntropyLoss(nn.Module):
    """
    Cross-entropy loss with landcover-friendly ignore_index handling.

    A thin wrapper around ``F.cross_entropy`` that accepts ``ignore_index=None``
    (meaning "ignore nothing") in addition to an explicit class index, plus
    optional per-class rescaling weights.

    Args:
        weight: Optional per-class rescaling weights.
        ignore_index: Class index excluded from the loss, or None to keep all
            classes (internally mapped to PyTorch's default sentinel -100).
        reduction: Reduction applied to the output ('mean', 'sum', 'none').
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = None,
        reduction: str = "mean",
    ):
        super().__init__()
        self.weight = weight
        # -100 is PyTorch's "nothing is ignored by default" sentinel.
        self.ignore_index = -100 if ignore_index is None else ignore_index
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the cross entropy between logits and integer targets.

        Args:
            input: Logits of shape (N, C, H, W) with C = number of classes.
            target: Class indices of shape (N, H, W).

        Returns:
            The reduced (or per-pixel, for reduction='none') loss tensor.
        """
        return F.cross_entropy(
            input,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def landcover_iou(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    mode: str = "mean",
    boundary_weight_map: Optional[torch.Tensor] = None,
) -> Union[float, Tuple[float, List[float], List[int]]]:
    """
    Calculate IoU for landcover classification with multiple weighting options.

    Three calculation modes are supported:
    1. "mean": unweighted mean IoU over classes that appear in pred or target
    2. "perclass_frequency": IoU weighted by each class's pixel frequency
    3. "boundary_weighted": IoU weighted by a per-pixel boundary weight map

    Args:
        pred: Predicted classes (N, H, W) or logits (N, C, H, W)
        target: Ground truth (N, H, W)
        num_classes: Number of classes
        ignore_index: Class index to ignore (default: None)
        smooth: Smoothing factor to avoid division by zero
        mode: IoU calculation mode ("mean", "perclass_frequency", "boundary_weighted")
        boundary_weight_map: Per-pixel weights (N, H, W); required for
            "boundary_weighted" mode.

    Returns:
        If mode == "mean": float (mean IoU)
        If mode == "perclass_frequency": tuple (weighted IoU, per-class IoUs, class counts)
        If mode == "boundary_weighted": float (boundary-weighted IoU)

    Raises:
        ValueError: If mode is unknown, or boundary_weight_map is missing for
            the "boundary_weighted" mode.
    """
    # Logits in -> hard class predictions via argmax over the channel axis.
    if pred.dim() == 4:
        pred = torch.argmax(pred, dim=1)

    assert (
        pred.shape == target.shape
    ), f"Shape mismatch: pred {pred.shape}, target {target.shape}"

    # Pixels that participate in the score.
    if ignore_index is None:
        valid_mask = torch.ones_like(target, dtype=torch.bool)
    else:
        valid_mask = target != ignore_index

    # Class ids that participate in the score (ignore_index is excluded).
    active_classes = [
        c for c in range(num_classes) if ignore_index is None or c != ignore_index
    ]

    if mode == "mean":
        # Classes absent from both pred and target (union == 0) are skipped
        # so they do not drag the average toward 0.
        per_class = []
        for c in active_classes:
            p = (pred == c) & valid_mask
            t = (target == c) & valid_mask
            inter = (p & t).sum().float()
            union = (p | t).sum().float()
            if union > 0:
                per_class.append(((inter + smooth) / (union + smooth)).item())
        return sum(per_class) / len(per_class) if per_class else 0.0

    if mode == "perclass_frequency":
        # Flatten to 1-D, dropping ignored pixels when applicable.
        if ignore_index is None:
            flat_target = target.view(-1)
            flat_pred = pred.view(-1)
        else:
            flat_target = target[valid_mask]
            flat_pred = pred[valid_mask]

        n_valid = flat_target.numel()
        per_class: List[float] = []
        counts: List[int] = []

        for c in active_classes:
            p = flat_pred == c
            t = flat_target == c
            inter = (p & t).sum().float()
            union = (p | t).sum().float()
            if union > 0:
                per_class.append(((inter + smooth) / (union + smooth)).item())
                counts.append(int(t.sum().item()))
            else:
                # Absent class: zero IoU with zero weight.
                per_class.append(0.0)
                counts.append(0)

        # Weight each class's IoU by its share of the valid pixels.
        if sum(counts) > 0:
            fw_iou = sum(v * (n / n_valid) for v, n in zip(per_class, counts))
        else:
            fw_iou = 0.0
        return fw_iou, per_class, counts

    if mode == "boundary_weighted":
        if boundary_weight_map is None:
            raise ValueError("boundary_weight_map required for boundary_weighted mode")

        per_class = []
        class_weights = []
        for c in active_classes:
            p = (pred == c) & valid_mask
            t = (target == c) & valid_mask
            # Scale intersection/union by the per-pixel boundary weights.
            inter = ((p & t).float() * boundary_weight_map).sum()
            union = ((p | t).float() * boundary_weight_map).sum()
            if union > 0:
                per_class.append(((inter + smooth) / (union + smooth)).item())
                class_weights.append(union.item())

        total_w = sum(class_weights)
        if total_w > 0:
            return sum(v * w for v, w in zip(per_class, class_weights)) / total_w
        return 0.0

    raise ValueError(
        f"Unknown mode: {mode}. Use 'mean', 'perclass_frequency', or 'boundary_weighted'"
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_landcover_loss_function(
    loss_name: str = "crossentropy",
    num_classes: int = 2,
    ignore_index: Optional[int] = None,
    class_weights: Optional[torch.Tensor] = None,
    use_class_weights: bool = False,
    focal_alpha: float = 1.0,
    focal_gamma: float = 2.0,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """
    Build a loss function configured for landcover classification.

    Args:
        loss_name: Loss selector ("crossentropy", "focal"); any other name
            falls back to the stock ``nn.CrossEntropyLoss``.
        num_classes: Number of classes (kept for interface parity).
        ignore_index: Class index to ignore, or None for no ignoring.
        class_weights: Manual per-class weight tensor.
        use_class_weights: Whether class_weights should actually be applied.
        focal_alpha: Alpha parameter for focal loss.
        focal_gamma: Gamma parameter for focal loss.
        device: Device the class weights are moved to (auto-detected if None).

    Returns:
        Configured loss module.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve the (optional) class weights once for every branch.
    weights = None
    if use_class_weights and class_weights is not None:
        weights = class_weights.to(device)

    # -100 is PyTorch's sentinel for "nothing ignored".
    fallback_ignore = -100 if ignore_index is None else ignore_index

    name = loss_name.lower()

    if name == "crossentropy":
        return LandcoverCrossEntropyLoss(
            weight=weights,
            ignore_index=ignore_index,
            reduction="mean",
        )

    if name == "focal":
        return FocalLoss(
            alpha=focal_alpha,
            gamma=focal_gamma,
            ignore_index=fallback_ignore,
            reduction="mean",
            weight=weights,
        )

    # Anything else: fall back to the standard PyTorch loss.
    return nn.CrossEntropyLoss(
        weight=weights,
        ignore_index=fallback_ignore,
        reduction="mean",
    )
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def compute_class_weights(
    labels_dir: str,
    num_classes: int,
    ignore_index: Optional[int] = None,
    custom_multipliers: Optional[Dict[int, float]] = None,
    max_weight: float = 50.0,
    use_inverse_frequency: bool = True,
) -> torch.Tensor:
    """
    Compute class weights for imbalanced datasets with optional custom multipliers and maximum weight cap.

    Scans every raster in ``labels_dir``, counts per-class pixels, and derives
    a weight per class (inverse-frequency or uniform), then applies optional
    per-class multipliers and finally caps each weight at ``max_weight``.
    Progress and the resulting weights are printed to stdout.

    Args:
        labels_dir: Directory containing label files
        num_classes: Number of classes
        ignore_index: Class index to ignore when computing weights (default: None)
        custom_multipliers: Custom multipliers for specific classes after inverse frequency calculation.
            Format: {class_id: multiplier}
            Example: {1: 0.5, 7: 2.0} - reduce class 1 weight by half, double class 7 weight
        max_weight: Maximum allowed weight value to prevent extreme values (default: 50.0)
        use_inverse_frequency: Whether to compute inverse frequency weights.
            - True (default): Compute inverse frequency weights, then apply custom multipliers
            - False: Use uniform weights (1.0) for all classes, then apply custom multipliers

    Returns:
        Tensor of class weights (num_classes,) with custom adjustments and maximum weight cap applied

    Raises:
        ValueError: If no valid pixels were found in any label file.
    """
    # Imported lazily so rasterio stays an optional dependency for callers
    # that never compute weights.
    import os
    import rasterio
    from collections import Counter

    # Count pixels for each class
    class_counts = Counter()
    total_pixels = 0

    # Get all label files
    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
    label_files = [
        os.path.join(labels_dir, f)
        for f in os.listdir(labels_dir)
        if f.lower().endswith(label_extensions)
    ]

    print(f"Computing class weights from {len(label_files)} label files...")

    for label_file in label_files:
        try:
            with rasterio.open(label_file) as src:
                # First band is assumed to hold the class map.
                label_data = src.read(1)
                for class_id in range(num_classes):
                    if ignore_index is not None and class_id == ignore_index:
                        continue
                    count = (label_data == class_id).sum()
                    class_counts[class_id] += int(count)
                    total_pixels += int(count)
        except Exception as e:
            # Best-effort scan: skip unreadable rasters instead of aborting.
            print(f"Warning: Could not read {label_file}: {e}")
            continue

    if total_pixels == 0:
        raise ValueError("No valid pixels found in label files")

    # Initialize weights
    weights = torch.ones(num_classes)

    if use_inverse_frequency:
        # Compute inverse frequency weights
        for class_id in range(num_classes):
            if ignore_index is not None and class_id == ignore_index:
                # Ignored class contributes nothing to the loss.
                weights[class_id] = 0.0
            elif class_counts[class_id] > 0:
                # Inverse frequency: total_pixels / class_pixels
                weights[class_id] = total_pixels / class_counts[class_id]
            else:
                # Class absent from the dataset: zero weight.
                weights[class_id] = 0.0

        # Normalize to have mean weight of 1.0
        non_zero_weights = weights[weights > 0]
        if len(non_zero_weights) > 0:
            weights = weights / non_zero_weights.mean()
    else:
        # Use uniform weights (all 1.0)
        for class_id in range(num_classes):
            if ignore_index is not None and class_id == ignore_index:
                weights[class_id] = 0.0

    # Apply custom multipliers if provided
    if custom_multipliers:
        print(f"\n🎯 Applying custom multipliers: {custom_multipliers}")
        for class_id, multiplier in custom_multipliers.items():
            if class_id < 0 or class_id >= num_classes:
                print(f"Warning: Invalid class_id {class_id}, skipping")
                continue

            original_weight = weights[class_id].item()
            weights[class_id] = weights[class_id] * multiplier
            print(
                f"  Class {class_id}: {original_weight:.4f} × {multiplier} = {weights[class_id].item():.4f}"
            )
    else:
        print("\nℹ️ No custom multipliers provided, using computed weights as-is")

    # Apply maximum weight cap to prevent extreme values
    weights_capped = False
    print(f"\n🔒 Applying maximum weight cap of {max_weight}...")
    for class_id in range(num_classes):
        if weights[class_id] > max_weight:
            print(
                f"  Class {class_id}: {weights[class_id].item():.4f} → {max_weight} (capped)"
            )
            weights[class_id] = max_weight
            weights_capped = True

    if not weights_capped:
        print("  No weights exceeded the cap")

    print(f"\nClass pixel counts: {dict(class_counts)}")
    print(f"\nFinal class weights:")
    for class_id in range(num_classes):
        pixel_count = class_counts.get(class_id, 0)
        percent = (pixel_count / total_pixels * 100) if total_pixels > 0 else 0
        print(
            f"  Class {class_id}: weight={weights[class_id].item():.4f}, "
            f"pixels={pixel_count:,} ({percent:.2f}%)"
        )

    if ignore_index is not None and 0 <= ignore_index < num_classes:
        print(f"\n⚠️ Note: Class {ignore_index} (ignore_index) has weight 0.0")

    return weights
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def train_segmentation_landcover(
    images_dir: str,
    labels_dir: str,
    output_dir: str,
    input_format: str = "directory",
    architecture: str = "unet",
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    num_channels: int = 3,
    num_classes: int = 2,
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    weight_decay: float = 1e-4,
    seed: int = 42,
    val_split: float = 0.2,
    print_freq: int = 10,
    verbose: bool = True,
    save_best_only: bool = True,
    plot_curves: bool = False,
    device: Optional[torch.device] = None,
    checkpoint_path: Optional[str] = None,
    resume_training: bool = False,
    target_size: Optional[Tuple[int, int]] = None,
    resize_mode: str = "resize",
    num_workers: Optional[int] = None,
    loss_function: str = "crossentropy",
    ignore_index: Optional[int] = None,
    use_class_weights: bool = False,
    focal_alpha: float = 1.0,
    focal_gamma: float = 2.0,
    custom_multipliers: Optional[Dict[int, float]] = None,
    max_class_weight: float = 50.0,
    use_inverse_frequency: bool = True,
    validation_iou_mode: str = "standard",
    boundary_alpha: float = 1.0,
    training_callback: Optional[callable] = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train a semantic segmentation model with landcover-specific enhancements.

    Thin wrapper around ``geoai.train.train_segmentation_model`` that adds
    landcover-oriented options: configurable loss ("crossentropy", "focal"),
    flexible ignore_index handling, inverse-frequency class weighting with
    custom multipliers and a weight cap, and alternative validation IoU
    modes ("standard", "perclass_frequency", "boundary_weighted").

    Most arguments mirror the base trainer and are forwarded unchanged; the
    landcover-specific additions are:

    Args:
        loss_function: Loss function name ("crossentropy", "focal").
        ignore_index: Class index to ignore (e.g. 0 for background), or None
            to include all classes.
        use_class_weights: Compute and report per-class weights before training.
        focal_alpha: Focal loss alpha parameter.
        focal_gamma: Focal loss gamma parameter.
        custom_multipliers: {class_id: multiplier} adjustments applied to the
            computed class weights.
        max_class_weight: Cap applied to every computed class weight.
        use_inverse_frequency: Inverse-frequency vs. uniform base weights.
        validation_iou_mode: Validation IoU flavor passed to the base trainer
            ("standard", "perclass_frequency", "boundary_weighted").
        boundary_alpha: Boundary emphasis factor for the boundary-weighted mode.
        training_callback: Optional callback for automatic metric tracking.
        **kwargs: Forwarded untouched to the base training function.

    Returns:
        The trained model.

    Raises:
        ImportError: If the geoai package is not installed.
    """
    try:
        from geoai.train import train_segmentation_model
    except ImportError:
        raise ImportError("geoai package not found. Install with: pip install geoai-py")

    # The base trainer types ignore_index as Union[int, bool]; False means
    # "do not ignore anything".
    base_ignore_index = False if ignore_index is None else ignore_index

    class_weights = None
    if use_class_weights:
        banner = "=" * 60
        if verbose:
            print("\n" + banner)
            print("COMPUTING CLASS WEIGHTS")
            print(banner)

        # NOTE(review): the computed weights are reported here but are not
        # forwarded to the base trainer — presumably it recomputes them from
        # use_class_weights and the related kwargs; confirm against geoai.train.
        class_weights = compute_class_weights(
            labels_dir=labels_dir,
            num_classes=num_classes,
            ignore_index=-100 if ignore_index is None else ignore_index,
            custom_multipliers=custom_multipliers,
            max_weight=max_class_weight,
            use_inverse_frequency=use_inverse_frequency,
        )

        if verbose:
            print(banner + "\n")

    # Delegate the actual training loop to the base geoai implementation.
    return train_segmentation_model(
        images_dir=images_dir,
        labels_dir=labels_dir,
        output_dir=output_dir,
        input_format=input_format,
        architecture=architecture,
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        num_channels=num_channels,
        num_classes=num_classes,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        seed=seed,
        val_split=val_split,
        print_freq=print_freq,
        verbose=verbose,
        save_best_only=save_best_only,
        plot_curves=plot_curves,
        device=device,
        checkpoint_path=checkpoint_path,
        resume_training=resume_training,
        target_size=target_size,
        resize_mode=resize_mode,
        num_workers=num_workers,
        loss_function=loss_function,
        ignore_index=base_ignore_index,
        use_class_weights=use_class_weights,
        focal_alpha=focal_alpha,
        focal_gamma=focal_gamma,
        custom_multipliers=custom_multipliers,
        max_class_weight=max_class_weight,
        use_inverse_frequency=use_inverse_frequency,
        validation_iou_mode=validation_iou_mode,
        boundary_alpha=boundary_alpha,
        training_callback=training_callback,
        **kwargs,
    )
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
# Public API of this module; keeps `from landcover_train import *` limited
# to the landcover loss, metric, weighting, and training entry points.
__all__ = [
    "FocalLoss",
    "LandcoverCrossEntropyLoss",
    "landcover_iou",
    "get_landcover_loss_function",
    "compute_class_weights",
    "train_segmentation_landcover",
]
|