geoai-py 0.18.2__py2.py3-none-any.whl → 0.20.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,685 @@
1
+ """
2
+ Landcover Classification Training Module
3
+
4
+ This module extends the base geoai training functionality with specialized
5
+ components for discrete landcover classification, including:
6
+ - Enhanced loss functions with boundary weighting
7
+ - Per-class frequency weighting for imbalanced datasets
8
+ - Configurable ignore_index handling
9
+ - Additional validation metrics
10
+
11
+ Key Features:
12
+ - Maintains full compatibility with base geoai workflow
13
+ - Adds optional advanced loss computation modes
14
+ - Provides flexible ignore_index configuration
15
+ - Optimized for multi-class landcover segmentation
16
+
17
+ Author: ValHab Project
18
+ Date: November 2025
19
+ """
20
+
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance in segmentation.

    Down-weights easy, well-classified pixels so optimization focuses on
    hard examples, which helps with heavily imbalanced landcover classes.

    Reference: Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
    Focal loss for dense object detection. ICCV.

    Args:
        alpha: Weighting factor in range (0,1) to balance positive/negative examples
        gamma: Exponent of the modulating factor (1 - p_t)^gamma
        ignore_index: Specifies a target value that is ignored
        reduction: Specifies the reduction to apply to the output
        weight: Manual rescaling weight given to each class
    """

    def __init__(
        self, alpha=1.0, gamma=2.0, ignore_index=-100, reduction="mean", weight=None
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.weight = weight

    def forward(self, inputs, targets):
        """
        Forward pass of focal loss.

        Args:
            inputs: Predictions (N, C, H, W) where C = number of classes
            targets: Ground truth (N, H, W) with class indices

        Returns:
            Scalar loss for "mean"/"sum" reduction, per-pixel tensor otherwise
        """
        # Per-pixel cross entropy; pixels equal to ignore_index contribute
        # exactly 0 to this tensor.
        ce_loss = F.cross_entropy(
            inputs,
            targets,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction="none",
        )

        # Probability the model assigned to the true class: p_t = exp(-CE).
        p_t = torch.exp(-ce_loss)

        # Focal modulation: easy pixels (p_t -> 1) are down-weighted.
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss

        if self.reduction == "mean":
            # BUGFIX: average over *valid* pixels only. Ignored pixels carry a
            # zero loss term, so dividing by the total pixel count (as
            # focal_loss.mean() did) silently shrinks the loss whenever
            # ignore_index targets are present. Note: unlike F.cross_entropy,
            # the denominator here is the valid-pixel count, not the sum of
            # class weights.
            valid = targets != self.ignore_index
            num_valid = valid.sum()
            if num_valid == 0:
                # No valid pixels: return a zero that keeps the autograd graph.
                return focal_loss.sum()
            return focal_loss.sum() / num_valid
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss
86
+
87
+
88
class LandcoverCrossEntropyLoss(nn.Module):
    """
    Cross-entropy loss with landcover-friendly ignore_index handling.

    Behaves like ``nn.CrossEntropyLoss`` but accepts ``ignore_index=None``
    to mean "ignore nothing"; the None case is mapped internally onto
    PyTorch's -100 sentinel value.

    Args:
        weight: Optional per-class rescaling weights.
        ignore_index: Class index to exclude from the loss, or None to
            include every class (default: None).
        reduction: Reduction applied to the output ('mean', 'sum', 'none').
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = None,
        reduction: str = "mean",
    ):
        super().__init__()
        self.weight = weight
        # Translate the friendlier None into PyTorch's -100 sentinel.
        self.ignore_index = -100 if ignore_index is None else ignore_index
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the cross entropy between predictions and ground truth.

        Args:
            input: Logits of shape (N, C, H, W).
            target: Class indices of shape (N, H, W).

        Returns:
            The reduced (or per-pixel) loss tensor.
        """
        options = {
            "weight": self.weight,
            "ignore_index": self.ignore_index,
            "reduction": self.reduction,
        }
        return F.cross_entropy(input, target, **options)
132
+
133
+
134
def _landcover_valid_mask(
    target: torch.Tensor, ignore_index: Optional[int]
) -> torch.Tensor:
    """Return a boolean mask selecting the pixels that count toward IoU."""
    if ignore_index is not None:
        return target != ignore_index
    return torch.ones_like(target, dtype=torch.bool)


def _landcover_iou_mean(pred, target, num_classes, ignore_index, smooth, valid_mask):
    """Unweighted mean IoU; classes absent from both pred and target are skipped."""
    ious = []
    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue

        pred_cls = (pred == cls) & valid_mask
        target_cls = (target == cls) & valid_mask

        intersection = (pred_cls & target_cls).sum().float()
        union = (pred_cls | target_cls).sum().float()

        if union > 0:
            ious.append(((intersection + smooth) / (union + smooth)).item())

    return sum(ious) / len(ious) if ious else 0.0


def _landcover_iou_frequency(
    pred, target, num_classes, ignore_index, smooth, valid_mask
):
    """Frequency-weighted IoU: each class contributes by its valid-pixel share."""
    # Flatten and drop ignored pixels so frequencies reflect valid pixels only.
    if ignore_index is not None:
        target_flat = target[valid_mask]
        pred_flat = pred[valid_mask]
    else:
        target_flat = target.view(-1)
        pred_flat = pred.view(-1)

    total_valid_pixels = target_flat.numel()
    ious = []
    class_counts = []

    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue

        pred_cls = pred_flat == cls
        target_cls = target_flat == cls

        intersection = (pred_cls & target_cls).sum().float()
        union = (pred_cls | target_cls).sum().float()

        if union > 0:
            ious.append(((intersection + smooth) / (union + smooth)).item())
            class_counts.append(int(target_cls.sum().item()))
        else:
            # Class appears in neither pred nor target: zero IoU, zero weight.
            ious.append(0.0)
            class_counts.append(0)

    if sum(class_counts) > 0:
        weights = [count / total_valid_pixels for count in class_counts]
        weighted_iou = sum(iou * weight for iou, weight in zip(ious, weights))
    else:
        weighted_iou = 0.0

    return weighted_iou, ious, class_counts


def _landcover_iou_boundary(
    pred, target, num_classes, ignore_index, smooth, valid_mask, boundary_weight_map
):
    """Boundary-weighted IoU: pixels with larger boundary weights dominate."""
    if boundary_weight_map is None:
        raise ValueError("boundary_weight_map required for boundary_weighted mode")

    ious = []
    weights = []

    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue

        pred_cls = (pred == cls) & valid_mask
        target_cls = (target == cls) & valid_mask

        # Scale intersection/union per pixel by the boundary weight map.
        weighted_intersection = (pred_cls & target_cls).float() * boundary_weight_map
        weighted_union = (pred_cls | target_cls).float() * boundary_weight_map

        intersection_sum = weighted_intersection.sum()
        union_sum = weighted_union.sum()

        if union_sum > 0:
            ious.append(((intersection_sum + smooth) / (union_sum + smooth)).item())
            # Each class's contribution is weighted by its (weighted) union mass.
            weights.append(union_sum.item())

    if sum(weights) > 0:
        return sum(iou * w for iou, w in zip(ious, weights)) / sum(weights)
    return 0.0


def landcover_iou(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    mode: str = "mean",
    boundary_weight_map: Optional[torch.Tensor] = None,
) -> Union[float, Tuple[float, List[float], List[int]]]:
    """
    Calculate IoU for landcover classification with multiple weighting options.

    Supports three IoU calculation modes:
    1. "mean": Simple mean IoU across all classes
    2. "perclass_frequency": Weight by per-class pixel frequency
    3. "boundary_weighted": Weight by distance to class boundaries

    Args:
        pred: Predicted classes (N, H, W) or logits (N, C, H, W)
        target: Ground truth (N, H, W)
        num_classes: Number of classes
        ignore_index: Class index to ignore (default: None)
        smooth: Smoothing factor to avoid division by zero
        mode: IoU calculation mode ("mean", "perclass_frequency", "boundary_weighted")
        boundary_weight_map: Optional boundary weights (N, H, W)

    Returns:
        If mode == "mean": float (mean IoU)
        If mode == "perclass_frequency": tuple (weighted IoU, per-class IoUs, class counts)
        If mode == "boundary_weighted": float (boundary-weighted IoU)

    Raises:
        ValueError: If pred/target shapes disagree, if boundary_weight_map is
            missing in "boundary_weighted" mode, or if mode is unknown.
    """
    # Convert logits (N, C, H, W) to hard class predictions (N, H, W).
    if pred.dim() == 4:
        pred = torch.argmax(pred, dim=1)

    # Raise instead of assert so the check survives `python -O`.
    if pred.shape != target.shape:
        raise ValueError(f"Shape mismatch: pred {pred.shape}, target {target.shape}")

    valid_mask = _landcover_valid_mask(target, ignore_index)

    if mode == "mean":
        return _landcover_iou_mean(
            pred, target, num_classes, ignore_index, smooth, valid_mask
        )
    if mode == "perclass_frequency":
        return _landcover_iou_frequency(
            pred, target, num_classes, ignore_index, smooth, valid_mask
        )
    if mode == "boundary_weighted":
        return _landcover_iou_boundary(
            pred,
            target,
            num_classes,
            ignore_index,
            smooth,
            valid_mask,
            boundary_weight_map,
        )
    raise ValueError(
        f"Unknown mode: {mode}. Use 'mean', 'perclass_frequency', or 'boundary_weighted'"
    )
285
+
286
+
287
def get_landcover_loss_function(
    loss_name: str = "crossentropy",
    num_classes: int = 2,
    ignore_index: Optional[int] = None,
    class_weights: Optional[torch.Tensor] = None,
    use_class_weights: bool = False,
    focal_alpha: float = 1.0,
    focal_gamma: float = 2.0,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """
    Build a loss function configured for landcover classification.

    Args:
        loss_name: Loss selector: "crossentropy", "focal", or anything else
            to fall back to a plain ``nn.CrossEntropyLoss``.
        num_classes: Number of classes (kept for API symmetry; not used here).
        ignore_index: Class index to ignore, or None to include all classes.
        class_weights: Optional per-class weight tensor.
        use_class_weights: If True, move ``class_weights`` to ``device`` and
            pass them to the loss.
        focal_alpha: Alpha parameter for focal loss.
        focal_gamma: Gamma parameter for focal loss.
        device: Device for the weight tensor; auto-detected when None.

    Returns:
        Configured loss function module.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve class weights once; every branch below shares this logic.
    weights = None
    if use_class_weights and class_weights is not None:
        weights = class_weights.to(device)

    name = loss_name.lower()

    if name == "crossentropy":
        # LandcoverCrossEntropyLoss handles ignore_index=None natively.
        return LandcoverCrossEntropyLoss(
            weight=weights,
            ignore_index=ignore_index,
            reduction="mean",
        )

    # Remaining losses expect PyTorch's -100 sentinel for "nothing ignored".
    resolved_ignore = -100 if ignore_index is None else ignore_index

    if name == "focal":
        return FocalLoss(
            alpha=focal_alpha,
            gamma=focal_gamma,
            ignore_index=resolved_ignore,
            reduction="mean",
            weight=weights,
        )

    # Unknown name: fall back to the standard PyTorch cross entropy.
    return nn.CrossEntropyLoss(
        weight=weights,
        ignore_index=resolved_ignore,
        reduction="mean",
    )
360
+
361
+
362
def compute_class_weights(
    labels_dir: str,
    num_classes: int,
    ignore_index: Optional[int] = None,
    custom_multipliers: Optional[Dict[int, float]] = None,
    max_weight: float = 50.0,
    use_inverse_frequency: bool = True,
) -> torch.Tensor:
    """
    Compute class weights for imbalanced datasets with optional custom multipliers and maximum weight cap.

    Args:
        labels_dir: Directory containing label files
        num_classes: Number of classes
        ignore_index: Class index to ignore when computing weights (default: None)
        custom_multipliers: Custom multipliers for specific classes after inverse frequency calculation.
            Format: {class_id: multiplier}
            Example: {1: 0.5, 7: 2.0} - reduce class 1 weight by half, double class 7 weight
        max_weight: Maximum allowed weight value to prevent extreme values (default: 50.0)
        use_inverse_frequency: Whether to compute inverse frequency weights.
            - True (default): Compute inverse frequency weights, then apply custom multipliers
            - False: Use uniform weights (1.0) for all classes, then apply custom multipliers

    Returns:
        Tensor of class weights (num_classes,) with custom adjustments and maximum weight cap applied

    Raises:
        ValueError: If no valid pixels are found across all label files.
    """
    import os
    from collections import Counter

    import numpy as np
    import rasterio

    # ---- Count pixels per class across all label rasters ----
    class_counts = Counter()
    total_pixels = 0

    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
    label_files = [
        os.path.join(labels_dir, f)
        for f in os.listdir(labels_dir)
        if f.lower().endswith(label_extensions)
    ]

    print(f"Computing class weights from {len(label_files)} label files...")

    for label_file in label_files:
        try:
            with rasterio.open(label_file) as src:
                # Only band 1 is read; labels are assumed single-band.
                label_data = src.read(1)
            # PERF: one np.unique pass per file instead of scanning the full
            # raster once per class (O(pixels) vs O(num_classes * pixels)).
            values, counts = np.unique(label_data, return_counts=True)
            for value, count in zip(values, counts):
                class_id = int(value)
                # Out-of-range label values (e.g. nodata like 255) are
                # excluded, matching the original per-class loop.
                if class_id < 0 or class_id >= num_classes:
                    continue
                if ignore_index is not None and class_id == ignore_index:
                    continue
                class_counts[class_id] += int(count)
                total_pixels += int(count)
        except Exception as e:
            # Best-effort: a single unreadable file should not abort counting.
            print(f"Warning: Could not read {label_file}: {e}")
            continue

    if total_pixels == 0:
        raise ValueError("No valid pixels found in label files")

    # ---- Base weights: inverse frequency or uniform ----
    weights = torch.ones(num_classes)

    if use_inverse_frequency:
        for class_id in range(num_classes):
            if ignore_index is not None and class_id == ignore_index:
                weights[class_id] = 0.0
            elif class_counts[class_id] > 0:
                # Inverse frequency: total_pixels / class_pixels
                weights[class_id] = total_pixels / class_counts[class_id]
            else:
                # Unseen class: give it zero weight rather than infinity.
                weights[class_id] = 0.0

        # Normalize so the non-zero weights have mean 1.0.
        non_zero_weights = weights[weights > 0]
        if len(non_zero_weights) > 0:
            weights = weights / non_zero_weights.mean()
    else:
        # Uniform weights (all 1.0), except the ignored class.
        for class_id in range(num_classes):
            if ignore_index is not None and class_id == ignore_index:
                weights[class_id] = 0.0

    # ---- Apply user-supplied per-class multipliers ----
    if custom_multipliers:
        print(f"\n🎯 Applying custom multipliers: {custom_multipliers}")
        for class_id, multiplier in custom_multipliers.items():
            if class_id < 0 or class_id >= num_classes:
                print(f"Warning: Invalid class_id {class_id}, skipping")
                continue

            original_weight = weights[class_id].item()
            weights[class_id] = weights[class_id] * multiplier
            print(
                f"  Class {class_id}: {original_weight:.4f} × {multiplier} = {weights[class_id].item():.4f}"
            )
    else:
        print("\nℹ️ No custom multipliers provided, using computed weights as-is")

    # ---- Cap extreme weights ----
    weights_capped = False
    print(f"\n🔒 Applying maximum weight cap of {max_weight}...")
    for class_id in range(num_classes):
        if weights[class_id] > max_weight:
            print(
                f"  Class {class_id}: {weights[class_id].item():.4f} → {max_weight} (capped)"
            )
            weights[class_id] = max_weight
            weights_capped = True

    if not weights_capped:
        print("  No weights exceeded the cap")

    # ---- Report final weights ----
    print(f"\nClass pixel counts: {dict(class_counts)}")
    print(f"\nFinal class weights:")
    for class_id in range(num_classes):
        pixel_count = class_counts.get(class_id, 0)
        percent = (pixel_count / total_pixels * 100) if total_pixels > 0 else 0
        print(
            f"  Class {class_id}: weight={weights[class_id].item():.4f}, "
            f"pixels={pixel_count:,} ({percent:.2f}%)"
        )

    if ignore_index is not None and 0 <= ignore_index < num_classes:
        print(f"\n⚠️ Note: Class {ignore_index} (ignore_index) has weight 0.0")

    return weights
491
+
492
+
493
def train_segmentation_landcover(
    images_dir: str,
    labels_dir: str,
    output_dir: str,
    input_format: str = "directory",
    architecture: str = "unet",
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    num_channels: int = 3,
    num_classes: int = 2,
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    weight_decay: float = 1e-4,
    seed: int = 42,
    val_split: float = 0.2,
    print_freq: int = 10,
    verbose: bool = True,
    save_best_only: bool = True,
    plot_curves: bool = False,
    device: Optional[torch.device] = None,
    checkpoint_path: Optional[str] = None,
    resume_training: bool = False,
    target_size: Optional[Tuple[int, int]] = None,
    resize_mode: str = "resize",
    num_workers: Optional[int] = None,
    loss_function: str = "crossentropy",
    ignore_index: Optional[int] = None,
    use_class_weights: bool = False,
    focal_alpha: float = 1.0,
    focal_gamma: float = 2.0,
    custom_multipliers: Optional[Dict[int, float]] = None,
    max_class_weight: float = 50.0,
    use_inverse_frequency: bool = True,
    validation_iou_mode: str = "standard",
    boundary_alpha: float = 1.0,
    training_callback: Optional[callable] = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train a semantic segmentation model with landcover-specific enhancements.

    Thin wrapper around ``geoai.train.train_segmentation_model`` that adds
    landcover-oriented options: configurable loss functions ("crossentropy",
    "focal"), optional inverse-frequency class weights with custom
    multipliers and a maximum weight cap, flexible ignore_index handling,
    and alternative validation IoU modes.

    Args:
        images_dir: Directory containing training images.
        labels_dir: Directory containing training labels.
        output_dir: Directory for checkpoints and training history.
        input_format: Data format ("directory", "COCO", "YOLO").
        architecture: Model architecture (default: "unet").
        encoder_name: Encoder backbone (default: "resnet34").
        encoder_weights: Pretrained weights ("imagenet" or None).
        num_channels: Number of input channels.
        num_classes: Number of output classes.
        batch_size: Training batch size.
        num_epochs: Number of training epochs.
        learning_rate: Initial learning rate.
        weight_decay: Weight decay for the optimizer.
        seed: Random seed for reproducibility.
        val_split: Validation split ratio.
        print_freq: Frequency of training progress prints.
        verbose: Enable verbose output.
        save_best_only: Only save the best checkpoint.
        plot_curves: Plot training curves at the end.
        device: Torch device (auto-detected by the base trainer if None).
        checkpoint_path: Path to a checkpoint for resuming training.
        resume_training: Whether to resume from checkpoint_path.
        target_size: Target (H, W) for resizing images, or None.
        resize_mode: How to resize ("resize", "crop", or "pad").
        num_workers: Number of dataloader workers (auto if None).
        loss_function: Loss function name ("crossentropy", "focal").
        ignore_index: Class index to ignore, or None to include all classes.
        use_class_weights: Compute class weights from labels_dir and use them.
        focal_alpha: Focal loss alpha parameter.
        focal_gamma: Focal loss gamma parameter.
        custom_multipliers: Per-class weight multipliers {class_id: multiplier}.
        max_class_weight: Maximum allowed class weight.
        use_inverse_frequency: Use inverse-frequency class weighting.
        validation_iou_mode: Validation IoU mode ("standard",
            "perclass_frequency", or "boundary_weighted").
        boundary_alpha: Boundary importance factor for the
            "boundary_weighted" IoU mode.
        training_callback: Optional callback for automatic metric tracking.
        **kwargs: Extra arguments forwarded to the base training function.

    Returns:
        The trained model.

    Raises:
        ImportError: If the geoai package is not installed.

    Example:
        >>> from landcover_train import train_segmentation_landcover
        >>>
        >>> model = train_segmentation_landcover(
        ...     images_dir="tiles/images",
        ...     labels_dir="tiles/labels",
        ...     output_dir="models/landcover_001",
        ...     num_classes=5,
        ...     loss_function="focal",
        ...     ignore_index=0,  # Ignore background
        ...     use_class_weights=True,
        ... )
    """
    # geoai is an optional dependency of this module; fail with guidance.
    try:
        from geoai.train import train_segmentation_model
    except ImportError:
        raise ImportError("geoai package not found. Install with: pip install geoai-py")

    # The base trainer models "no ignore" as False (Union[int, bool]),
    # whereas this wrapper uses Optional[int].
    ignore_idx_param = False if ignore_index is None else ignore_index

    # Optionally derive class weights from the label rasters up front.
    class_weights = None
    if use_class_weights:
        if verbose:
            banner = "=" * 60
            print("\n" + banner)
            print("COMPUTING CLASS WEIGHTS")
            print(banner)

        class_weights = compute_class_weights(
            labels_dir=labels_dir,
            num_classes=num_classes,
            ignore_index=ignore_index if ignore_index is not None else -100,
            custom_multipliers=custom_multipliers,
            max_weight=max_class_weight,
            use_inverse_frequency=use_inverse_frequency,
        )

        if verbose:
            print("=" * 60 + "\n")

    # Delegate the actual training loop to the base geoai trainer.
    return train_segmentation_model(
        images_dir=images_dir,
        labels_dir=labels_dir,
        output_dir=output_dir,
        input_format=input_format,
        architecture=architecture,
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        num_channels=num_channels,
        num_classes=num_classes,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        seed=seed,
        val_split=val_split,
        print_freq=print_freq,
        verbose=verbose,
        save_best_only=save_best_only,
        plot_curves=plot_curves,
        device=device,
        checkpoint_path=checkpoint_path,
        resume_training=resume_training,
        target_size=target_size,
        resize_mode=resize_mode,
        num_workers=num_workers,
        loss_function=loss_function,
        ignore_index=ignore_idx_param,
        use_class_weights=use_class_weights,
        focal_alpha=focal_alpha,
        focal_gamma=focal_gamma,
        custom_multipliers=custom_multipliers,
        max_class_weight=max_class_weight,
        use_inverse_frequency=use_inverse_frequency,
        validation_iou_mode=validation_iou_mode,
        boundary_alpha=boundary_alpha,
        training_callback=training_callback,
        **kwargs,
    )
675
+
676
+
677
# Public API of the landcover training module: loss classes, IoU metric,
# loss/weight factories, and the training entry point.
__all__ = [
    "FocalLoss",
    "LandcoverCrossEntropyLoss",
    "landcover_iou",
    "get_landcover_loss_function",
    "compute_class_weights",
    "train_segmentation_landcover",
]