geoai-py 0.15.0__py2.py3-none-any.whl → 0.17.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/agents/__init__.py +4 -0
- geoai/agents/catalog_models.py +51 -0
- geoai/agents/catalog_tools.py +907 -0
- geoai/agents/geo_agents.py +925 -41
- geoai/agents/stac_models.py +67 -0
- geoai/agents/stac_tools.py +435 -0
- geoai/change_detection.py +16 -6
- geoai/download.py +5 -1
- geoai/geoai.py +3 -0
- geoai/train.py +573 -31
- geoai/utils.py +752 -208
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/METADATA +2 -1
- geoai_py-0.17.0.dist-info/RECORD +30 -0
- geoai_py-0.15.0.dist-info/RECORD +0 -26
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/top_level.txt +0 -0
geoai/train.py
CHANGED
@@ -34,6 +34,14 @@ try:
 except ImportError:
     SMP_AVAILABLE = False
 
+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
 
 def parse_coco_annotations(
     coco_json_path: str, images_dir: str, labels_dir: str
@@ -2025,14 +2033,14 @@ def get_smp_model(
     )
 
 
-def dice_coefficient(
+def f1_score(
     pred: torch.Tensor,
     target: torch.Tensor,
     smooth: float = 1e-6,
     num_classes: Optional[int] = None,
 ) -> float:
     """
-    Calculate Dice coefficient for segmentation (binary or multi-class).
+    Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).
 
     Args:
         pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -2041,7 +2049,7 @@ def dice_coefficient(
         num_classes (int, optional): Number of classes. If None, auto-detected.
 
     Returns:
-        float: Mean Dice coefficient across all classes.
+        float: Mean F1 score across all classes.
     """
     # Convert predictions to class predictions
     if pred.dim() == 3:  # [C, H, W] format
@@ -2056,8 +2064,8 @@ def dice_coefficient(
     if num_classes is None:
         num_classes = max(pred_classes.max().item(), target.max().item()) + 1
 
-    # Calculate Dice coefficient for each class and average
-    dice_scores = []
+    # Calculate F1 score for each class and average
+    f1_scores = []
     for class_id in range(num_classes):
         pred_class = (pred_classes == class_id).float()
         target_class = (target == class_id).float()
@@ -2066,10 +2074,10 @@ def dice_coefficient(
         union = pred_class.sum() + target_class.sum()
 
         if union > 0:
-            dice = (2.0 * intersection + smooth) / (union + smooth)
-            dice_scores.append(dice.item())
+            f1 = (2.0 * intersection + smooth) / (union + smooth)
+            f1_scores.append(f1.item())
 
-    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
 
 
 def iou_coefficient(
@@ -2119,6 +2127,108 @@ def iou_coefficient(
     return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
 
 
+def precision_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate precision score for segmentation (binary or multi-class).
+
+    Precision = TP / (TP + FP), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FP (False Positives): Incorrectly predicted positive pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean precision score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate precision for each class and average
+    precision_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        predicted_positives = pred_class.sum()
+
+        if predicted_positives > 0:
+            precision = (true_positives + smooth) / (predicted_positives + smooth)
+            precision_scores.append(precision.item())
+
+    return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
+
+
+def recall_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
+
+    Recall = TP / (TP + FN), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FN (False Negatives): Incorrectly predicted negative pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean recall score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate recall for each class and average
+    recall_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        actual_positives = target_class.sum()
+
+        if actual_positives > 0:
+            recall = (true_positives + smooth) / (actual_positives + smooth)
+            recall_scores.append(recall.item())
+
+    return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
+
+
 def train_semantic_one_epoch(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
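The rename is cosmetic as far as the returned value goes: per class the function still computes (2.0 * intersection + smooth) / (union + smooth), i.e. F1 = 2TP / (2TP + FP + FN), which is exactly the Dice coefficient. A minimal sketch of exercising the new metric helpers on dummy tensors (illustrative only; it assumes the functions are importable from geoai.train as the module-level helpers shown in this diff):

import torch
from geoai.train import f1_score, iou_coefficient, precision_score, recall_score

pred = torch.randn(3, 64, 64)            # [C, H, W] logits for 3 classes
target = torch.randint(0, 3, (64, 64))   # [H, W] ground-truth class ids

print("F1:       ", f1_score(pred, target, num_classes=3))
print("IoU:      ", iou_coefficient(pred, target, num_classes=3))
print("Precision:", precision_score(pred, target, num_classes=3))
print("Recall:   ", recall_score(pred, target, num_classes=3))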
@@ -2200,13 +2310,15 @@ def evaluate_semantic(
         num_classes (int): Number of classes for evaluation metrics.
 
     Returns:
-        dict: Evaluation metrics including loss, IoU, and Dice coefficient.
+        dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
     """
     model.eval()
 
     total_loss = 0
-    dice_scores = []
+    f1_scores = []
     iou_scores = []
+    precision_scores = []
+    recall_scores = []
     num_batches = len(data_loader)
 
     with torch.no_grad():
@@ -2222,17 +2334,31 @@ def evaluate_semantic(
 
             # Calculate metrics for each sample in the batch
             for pred, target in zip(outputs, targets):
-                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                f1 = f1_score(pred, target, num_classes=num_classes)
                 iou = iou_coefficient(pred, target, num_classes=num_classes)
-                dice_scores.append(dice)
+                precision = precision_score(pred, target, num_classes=num_classes)
+                recall = recall_score(pred, target, num_classes=num_classes)
+                f1_scores.append(f1)
                 iou_scores.append(iou)
+                precision_scores.append(precision)
+                recall_scores.append(recall)
 
     # Calculate metrics
     avg_loss = total_loss / num_batches
-    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
    avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
-
-    return {"loss": avg_loss, "IoU": avg_iou, "Dice": avg_dice}
+    avg_precision = (
+        sum(precision_scores) / len(precision_scores) if precision_scores else 0
+    )
+    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
+
+    return {
+        "loss": avg_loss,
+        "F1": avg_f1,
+        "IoU": avg_iou,
+        "Precision": avg_precision,
+        "Recall": avg_recall,
+    }
 
 
 def train_segmentation_model(
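evaluate_semantic now returns a five-key dict ("loss", "F1", "IoU", "Precision", "Recall") instead of the old loss/IoU/Dice triple. A schematic consumer, as a sketch only: of the parameters, just num_classes is documented in this hunk, so the other argument names below are assumptions about the existing signature.

# Hypothetical call; model, data_loader, criterion, and device names are assumed.
metrics = evaluate_semantic(model, val_loader, criterion, device, num_classes=2)
print(
    f"loss={metrics['loss']:.4f}  F1={metrics['F1']:.4f}  IoU={metrics['IoU']:.4f}  "
    f"P={metrics['Precision']:.4f}  R={metrics['Recall']:.4f}"
)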
@@ -2261,6 +2387,7 @@ def train_segmentation_model(
     target_size: Optional[Tuple[int, int]] = None,
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
+    early_stopping_patience: Optional[int] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2313,6 +2440,8 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+        early_stopping_patience (int, optional): Number of epochs with no improvement after which
+            training will be stopped. If None, early stopping is disabled. Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
@@ -2542,7 +2671,7 @@ def train_segmentation_model(
         print(f"Using {torch.cuda.device_count()} GPUs for training")
         model = torch.nn.DataParallel(model)
 
-    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use F1Loss)
     criterion = torch.nn.CrossEntropyLoss()
 
     # Set up optimizer
@@ -2560,8 +2689,11 @@ def train_segmentation_model(
     train_losses = []
     val_losses = []
     val_ious = []
-    val_dices = []
+    val_f1s = []
+    val_precisions = []
+    val_recalls = []
     start_epoch = 0
+    epochs_without_improvement = 0
 
     # Load checkpoint if provided
     if checkpoint_path is not None:
@@ -2596,8 +2728,15 @@ def train_segmentation_model(
             val_losses = checkpoint["val_losses"]
         if "val_ious" in checkpoint:
             val_ious = checkpoint["val_ious"]
-        if "val_dices" in checkpoint:
-            val_dices = checkpoint["val_dices"]
+        if "val_f1s" in checkpoint:
+            val_f1s = checkpoint["val_f1s"]
+        # Also check for old val_dices format for backward compatibility
+        elif "val_dices" in checkpoint:
+            val_f1s = checkpoint["val_dices"]
+        if "val_precisions" in checkpoint:
+            val_precisions = checkpoint["val_precisions"]
+        if "val_recalls" in checkpoint:
+            val_recalls = checkpoint["val_recalls"]
 
         print(f"Resuming training from epoch {start_epoch}")
         print(f"Previous best IoU: {best_iou:.4f}")
@@ -2637,7 +2776,9 @@ def train_segmentation_model(
         )
         val_losses.append(eval_metrics["loss"])
         val_ious.append(eval_metrics["IoU"])
-        val_dices.append(eval_metrics["Dice"])
+        val_f1s.append(eval_metrics["F1"])
+        val_precisions.append(eval_metrics["Precision"])
+        val_recalls.append(eval_metrics["Recall"])
 
         # Update learning rate
         lr_scheduler.step(eval_metrics["loss"])
@@ -2648,14 +2789,28 @@ def train_segmentation_model(
             f"Train Loss: {train_loss:.4f}, "
             f"Val Loss: {eval_metrics['loss']:.4f}, "
             f"Val IoU: {eval_metrics['IoU']:.4f}, "
-            f"Val Dice: {eval_metrics['Dice']:.4f}"
+            f"Val F1: {eval_metrics['F1']:.4f}, "
+            f"Val Precision: {eval_metrics['Precision']:.4f}, "
+            f"Val Recall: {eval_metrics['Recall']:.4f}"
         )
 
-        # Save best model
+        # Save best model and check for early stopping
         if eval_metrics["IoU"] > best_iou:
             best_iou = eval_metrics["IoU"]
+            epochs_without_improvement = 0
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+        else:
+            epochs_without_improvement += 1
+            if (
+                early_stopping_patience is not None
+                and epochs_without_improvement >= early_stopping_patience
+            ):
+                print(
+                    f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
+                )
+                print(f"Best validation IoU: {best_iou:.4f}")
+                break
 
         # Save checkpoint every 10 epochs (if not save_best_only)
         if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
@@ -2673,7 +2828,9 @@ def train_segmentation_model(
                     "train_losses": train_losses,
                     "val_losses": val_losses,
                     "val_ious": val_ious,
-                    "val_dices": val_dices,
+                    "val_f1s": val_f1s,
+                    "val_precisions": val_precisions,
+                    "val_recalls": val_recalls,
                 },
                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
             )
@@ -2686,7 +2843,9 @@ def train_segmentation_model(
         "train_losses": train_losses,
         "val_losses": val_losses,
         "val_ious": val_ious,
-        "val_dices": val_dices,
+        "val_f1s": val_f1s,
+        "val_precisions": val_precisions,
+        "val_recalls": val_recalls,
     }
     torch.save(history, os.path.join(output_dir, "training_history.pth"))
 
@@ -2702,7 +2861,9 @@ def train_segmentation_model(
         f.write(f"Total epochs: {num_epochs}\n")
         f.write(f"Best validation IoU: {best_iou:.4f}\n")
         f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
-        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
+        f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
+        f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
         f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
 
     print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2731,10 +2892,10 @@ def train_segmentation_model(
     plt.grid(True)
 
     plt.subplot(1, 3, 3)
-    plt.plot(val_dices, label="Val Dice")
-    plt.title("Dice Score")
+    plt.plot(val_f1s, label="Val F1")
+    plt.title("F1 Score")
     plt.xlabel("Epoch")
-    plt.ylabel("Dice")
+    plt.ylabel("F1")
     plt.legend()
     plt.grid(True)
 
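Taken together, a training run now tracks F1, precision, and recall alongside IoU, records them in the checkpoints and training history, and can stop early on a stalled IoU. A sketch of opting in: early_stopping_patience, num_workers, output_dir, num_epochs, checkpoint_path, and save_best_only all appear in this diff, while the data-directory argument names below are assumptions about the rest of the signature.

from geoai.train import train_segmentation_model

train_segmentation_model(
    images_dir="tiles/images",      # assumed parameter name
    labels_dir="tiles/labels",      # assumed parameter name
    output_dir="runs/segmentation",
    num_epochs=100,
    num_workers=4,
    early_stopping_patience=10,     # stop after 10 epochs without an IoU improvement
)

Resuming from a pre-0.17 checkpoint still works: the loader above maps an old val_dices history into val_f1s.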
@@ -2764,6 +2925,7 @@ def semantic_inference_on_geotiff(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2785,6 +2947,8 @@ def semantic_inference_on_geotiff(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3001,7 +3165,7 @@ def semantic_inference_on_geotiff(
             prob_meta = meta.copy()
             prob_meta.update({"count": num_classes, "dtype": "float32"})
 
-            # Save normalized probabilities
+            # Save normalized probabilities as multi-band raster
             with rasterio.open(probability_path, "w", **prob_meta) as dst:
                 for class_idx in range(num_classes):
                     # Normalize probabilities
@@ -3015,6 +3179,36 @@ def semantic_inference_on_geotiff(
             if not quiet:
                 print(f"Saved probability map to {probability_path}")
 
+            # Save individual class probabilities if requested
+            if save_class_probabilities:
+                # Prepare single-band metadata
+                single_band_meta = meta.copy()
+                single_band_meta.update({"count": 1, "dtype": "float32"})
+
+                # Get base filename and extension
+                prob_base = os.path.splitext(probability_path)[0]
+                prob_ext = os.path.splitext(probability_path)[1]
+
+                for class_idx in range(num_classes):
+                    # Create filename for this class
+                    class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                    # Normalize probabilities
+                    prob_band = np.zeros((height, width), dtype=np.float32)
+                    prob_band[valid_pixels] = (
+                        prob_accumulator[class_idx, valid_pixels]
+                        / count_accumulator[valid_pixels]
+                    )
+
+                    # Save single-band file
+                    with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                        dst.write(prob_band, 1)
+
+                    if not quiet:
+                        print(
+                            f"Saved class {class_idx} probability to {class_prob_path}"
+                        )
+
     return output_path, inference_time
 
 
@@ -3031,6 +3225,7 @@ def semantic_inference_on_image(
     binary_output: bool = True,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -3053,6 +3248,8 @@ def semantic_inference_on_image(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3331,7 +3528,7 @@ def semantic_inference_on_image(
            "transform": transform,
         }
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3342,6 +3539,39 @@ def semantic_inference_on_image(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = {
+                "driver": "GTiff",
+                "height": height,
+                "width": width,
+                "count": 1,
+                "dtype": "float32",
+                "transform": transform,
+            }
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -3359,6 +3589,7 @@ def semantic_segmentation(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -3381,11 +3612,16 @@ def semantic_segmentation(
         batch_size (int): Batch size for inference.
         device (torch.device, optional): Device to run inference on.
         probability_path (str, optional): Path to save probability map. If provided,
-            the normalized class probabilities will be saved as a multi-band raster
+            the normalized class probabilities will be saved as a multi-band raster
+            where each band contains probabilities for each class.
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
             Must be between 0 and 1.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Files will be named like
+            "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
+            as probability_path. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3462,6 +3698,7 @@ def semantic_segmentation(
             device=device,
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3482,6 +3719,7 @@ def semantic_segmentation(
             binary_output=True,  # Convert to binary output for better visualization
             probability_path=probability_path,
            probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
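A sketch of the new per-class probability export through the high-level wrapper: probability_path, save_class_probabilities, and the "{base}_class_{idx}{ext}" output naming come from this diff, while the input/output/model argument names below are assumptions about the existing signature.

from geoai.train import semantic_segmentation

semantic_segmentation(
    input_path="scene.tif",             # assumed parameter name
    output_path="scene_mask.tif",       # assumed parameter name
    model_path="runs/best_model.pth",   # assumed parameter name
    probability_path="scene_prob.tif",  # multi-band probability raster
    save_class_probabilities=True,      # also writes scene_prob_class_0.tif, scene_prob_class_1.tif, ...
)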
@@ -3903,3 +4141,307 @@ def instance_segmentation_batch(
             continue
 
     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def lightly_train_model(
+    data_dir: str,
+    output_dir: str,
+    model: str = "torchvision/resnet50",
+    method: str = "dinov2_distillation",
+    epochs: int = 100,
+    batch_size: int = 64,
+    learning_rate: float = 1e-4,
+    **kwargs: Any,
+) -> str:
+    """
+    Train a model using Lightly Train for self-supervised pretraining.
+
+    Args:
+        data_dir (str): Directory containing unlabeled images for training.
+        output_dir (str): Directory to save training outputs and model checkpoints.
+        model (str): Model architecture to train. Supports models from torchvision,
+            timm, ultralytics, etc. Default is "torchvision/resnet50".
+        method (str): Self-supervised learning method. Options include:
+            - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
+            - "dino": Works with both CNNs and ViTs
+            - "dinov2": Requires ViT models only
+            - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
+            Default is "dinov2_distillation".
+        epochs (int): Number of training epochs. Default is 100.
+        batch_size (int): Batch size for training. Default is 64.
+        learning_rate (float): Learning rate for training. Default is 1e-4.
+        **kwargs: Additional arguments passed to lightly_train.train().
+
+    Returns:
+        str: Path to the exported model file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        ValueError: If data_dir does not exist, is empty, or incompatible model/method.
+
+    Note:
+        Model/Method compatibility:
+        - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
+        - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
+
+    Example:
+        >>> # For CNN models (ResNet, EfficientNet)
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="torchvision/resnet50",
+        ...     method="simclr",  # Use simclr for CNNs
+        ...     epochs=50
+        ... )
+        >>> # For ViT models
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="timm/vit_base_patch16_224",
+        ...     method="dinov2",  # dinov2 requires ViT
+        ...     epochs=50
+        ... )
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise ValueError(f"Data directory does not exist: {data_dir}")
+
+    # Check if data directory contains images
+    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
+
+    if not image_files:
+        raise ValueError(f"No image files found in {data_dir}")
+
+    # Validate model/method compatibility
+    is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
+
+    if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
+        raise ValueError(
+            f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
+            f"Solutions:\n"
+            f"  1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
+            f"  2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
+            f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
+            f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
+        )
+
+    print(f"Found {len(image_files)} images in {data_dir}")
+    print(f"Starting self-supervised pretraining with {method} method...")
+    print(f"Model: {model}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Detect if running in notebook environment and set appropriate configuration
+    def is_notebook():
+        try:
+            from IPython import get_ipython
+
+            if get_ipython() is not None:
+                return True
+        except (ImportError, NameError):
+            pass
+        return False
+
+    # Force single-device training in notebooks to avoid DDP strategy issues
+    if is_notebook():
+        # Only override if not explicitly set by user
+        if "accelerator" not in kwargs:
+            # Use CPU in notebooks to avoid DDP incompatibility
+            # Users can still override by passing accelerator='gpu'
+            kwargs["accelerator"] = "cpu"
+        if "devices" not in kwargs:
+            kwargs["devices"] = 1  # Force single device
+
+    # Train the model using Lightly Train
+    lightly_train.train(
+        out=output_dir,
+        data=data_dir,
+        model=model,
+        method=method,
+        epochs=epochs,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    # Return path to the exported model
+    exported_model_path = os.path.join(
+        output_dir, "exported_models", "exported_last.pt"
+    )
+
+    if os.path.exists(exported_model_path):
+        print(
+            f"Model training completed. Exported model saved to: {exported_model_path}"
+        )
+        return exported_model_path
+    else:
+        # Check for alternative export paths
+        possible_paths = [
+            os.path.join(output_dir, "exported_models", "exported_best.pt"),
+            os.path.join(output_dir, "checkpoints", "last.ckpt"),
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                print(f"Model training completed. Exported model saved to: {path}")
+                return path
+
+        print(f"Model training completed. Output saved to: {output_dir}")
+        return output_dir
+
+
+def load_lightly_pretrained_model(
+    model_path: str,
+    model_architecture: str = "torchvision/resnet50",
+    device: str = None,
+) -> torch.nn.Module:
+    """
+    Load a pretrained model from Lightly Train.
+
+    Args:
+        model_path (str): Path to the pretrained model file (.pt format).
+        model_architecture (str): Architecture of the model to load.
+            Default is "torchvision/resnet50".
+        device (str): Device to load the model on. If None, uses CPU.
+
+    Returns:
+        torch.nn.Module: Loaded pretrained model ready for fine-tuning.
+
+    Raises:
+        FileNotFoundError: If model_path does not exist.
+        ImportError: If required libraries are not available.
+
+    Example:
+        >>> model = load_lightly_pretrained_model(
+        ...     model_path="path/to/pretrained_model.pt",
+        ...     model_architecture="torchvision/resnet50",
+        ...     device="cuda"
+        ... )
+        >>> # Fine-tune the model with your existing training pipeline
+    """
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    print(f"Loading pretrained model from: {model_path}")
+
+    # Load the model based on architecture
+    if model_architecture.startswith("torchvision/"):
+        model_name = model_architecture.replace("torchvision/", "")
+
+        # Import the model from torchvision
+        if hasattr(torchvision.models, model_name):
+            model = getattr(torchvision.models, model_name)()
+        else:
+            raise ValueError(f"Unknown torchvision model: {model_name}")
+
+    elif model_architecture.startswith("timm/"):
+        try:
+            import timm
+
+            model_name = model_architecture.replace("timm/", "")
+            model = timm.create_model(model_name)
+        except ImportError:
+            raise ImportError(
+                "timm is required for TIMM models. Install with: pip install timm"
+            )
+
+    else:
+        # For other architectures, try to import from torchvision as default
+        try:
+            model = getattr(torchvision.models, model_architecture)()
+        except AttributeError:
+            raise ValueError(f"Unsupported model architecture: {model_architecture}")
+
+    # Load the pretrained weights
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    except TypeError:
+        # For backward compatibility with older PyTorch versions
+        state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+
+    print(f"Successfully loaded pretrained model: {model_architecture}")
+    return model
+
+
+def lightly_embed_images(
+    data_dir: str,
+    model_path: str,
+    output_path: str,
+    model_architecture: str = None,  # Deprecated, kept for backwards compatibility
+    batch_size: int = 64,
+    **kwargs: Any,
+) -> str:
+    """
+    Generate embeddings for images using a Lightly Train pretrained model.
+
+    Args:
+        data_dir (str): Directory containing images to embed.
+        model_path (str): Path to the pretrained model checkpoint file (.ckpt).
+        output_path (str): Path to save the embeddings (as .pt file).
+        model_architecture (str): Architecture of the pretrained model (deprecated,
+            kept for backwards compatibility but not used). The model architecture
+            is automatically loaded from the checkpoint.
+        batch_size (int): Batch size for embedding generation. Default is 64.
+        **kwargs: Additional arguments passed to lightly_train.embed().
+            Supported kwargs include: image_size, num_workers, accelerator, etc.
+
+    Returns:
+        str: Path to the saved embeddings file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        FileNotFoundError: If data_dir or model_path does not exist.
+
+    Note:
+        The model_path should point to a .ckpt file from the training output,
+        typically located at: output_dir/checkpoints/last.ckpt
+
+    Example:
+        >>> embeddings_path = lightly_embed_images(
+        ...     data_dir="path/to/images",
+        ...     model_path="output_dir/checkpoints/last.ckpt",
+        ...     output_path="embeddings.pt",
+        ...     batch_size=32
+        ... )
+        >>> print(f"Embeddings saved to: {embeddings_path}")
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
+
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+    print(f"Generating embeddings for images in: {data_dir}")
+    print(f"Using pretrained model: {model_path}")
+
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Generate embeddings using Lightly Train
+    # Note: model_architecture is not used - it's inferred from the checkpoint
+    lightly_train.embed(
+        out=output_path,
+        data=data_dir,
+        checkpoint=model_path,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    print(f"Embeddings saved to: {output_path}")
+    return output_path