geoai-py 0.15.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -34,6 +34,14 @@ try:
 except ImportError:
     SMP_AVAILABLE = False
 
+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
 
 def parse_coco_annotations(
     coco_json_path: str, images_dir: str, labels_dir: str
@@ -2025,14 +2033,14 @@ def get_smp_model(
     )
 
 
-def dice_coefficient(
+def f1_score(
     pred: torch.Tensor,
     target: torch.Tensor,
     smooth: float = 1e-6,
     num_classes: Optional[int] = None,
 ) -> float:
     """
-    Calculate Dice coefficient for segmentation (binary or multi-class).
+    Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).
 
     Args:
         pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -2041,7 +2049,7 @@ def dice_coefficient(
         num_classes (int, optional): Number of classes. If None, auto-detected.
 
     Returns:
-        float: Mean Dice coefficient across all classes.
+        float: Mean F1 score across all classes.
     """
     # Convert predictions to class predictions
     if pred.dim() == 3:  # [C, H, W] format
@@ -2056,8 +2064,8 @@ def dice_coefficient(
     if num_classes is None:
         num_classes = max(pred_classes.max().item(), target.max().item()) + 1
 
-    # Calculate Dice for each class and average
-    dice_scores = []
+    # Calculate F1 score for each class and average
+    f1_scores = []
     for class_id in range(num_classes):
         pred_class = (pred_classes == class_id).float()
         target_class = (target == class_id).float()
@@ -2066,10 +2074,10 @@ def dice_coefficient(
         union = pred_class.sum() + target_class.sum()
 
         if union > 0:
-            dice = (2.0 * intersection + smooth) / (union + smooth)
-            dice_scores.append(dice.item())
+            f1 = (2.0 * intersection + smooth) / (union + smooth)
+            f1_scores.append(f1.item())
 
-    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
 
 
 def iou_coefficient(
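
Note: the rename is cosmetic; the computation is the same Dice/F1 formulation. A minimal usage sketch, assuming only torch and the f1_score definition above (the masks are made-up toy data):

    import torch

    pred = torch.tensor([[0, 1], [1, 1]])    # predicted class indices, [H, W]
    target = torch.tensor([[0, 1], [0, 1]])  # ground truth, [H, W]

    score = f1_score(pred, target, num_classes=2)  # averaged over both classes
    print(f"F1/Dice: {score:.4f}")
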
@@ -2119,6 +2127,108 @@ def iou_coefficient(
     return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
 
 
+def precision_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate precision score for segmentation (binary or multi-class).
+
+    Precision = TP / (TP + FP), where:
+        - TP (True Positives): Correctly predicted positive pixels
+        - FP (False Positives): Incorrectly predicted positive pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean precision score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate precision for each class and average
+    precision_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        predicted_positives = pred_class.sum()
+
+        if predicted_positives > 0:
+            precision = (true_positives + smooth) / (predicted_positives + smooth)
+            precision_scores.append(precision.item())
+
+    return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
+
+
+def recall_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
+
+    Recall = TP / (TP + FN), where:
+        - TP (True Positives): Correctly predicted positive pixels
+        - FN (False Negatives): Incorrectly predicted negative pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean recall score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate recall for each class and average
+    recall_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        actual_positives = target_class.sum()
+
+        if actual_positives > 0:
+            recall = (true_positives + smooth) / (actual_positives + smooth)
+            recall_scores.append(recall.item())
+
+    return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
+
+
 def train_semantic_one_epoch(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
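
A quick sanity check of the two new metrics on the same toy masks (only torch and the definitions above are assumed):

    import torch

    pred = torch.tensor([[0, 1], [1, 1]])
    target = torch.tensor([[0, 1], [0, 1]])

    p = precision_score(pred, target, num_classes=2)
    r = recall_score(pred, target, num_classes=2)
    # with smooth ~ 0: precision = (1/1 + 2/3) / 2, recall = (1/2 + 2/2) / 2
    print(f"precision={p:.3f}, recall={r:.3f}")
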
@@ -2200,13 +2310,15 @@ def evaluate_semantic(
         num_classes (int): Number of classes for evaluation metrics.
 
     Returns:
-        dict: Evaluation metrics including loss, IoU, and Dice.
+        dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
     """
     model.eval()
 
     total_loss = 0
-    dice_scores = []
+    f1_scores = []
     iou_scores = []
+    precision_scores = []
+    recall_scores = []
     num_batches = len(data_loader)
 
     with torch.no_grad():
@@ -2222,17 +2334,31 @@ def evaluate_semantic(
 
             # Calculate metrics for each sample in the batch
             for pred, target in zip(outputs, targets):
-                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                f1 = f1_score(pred, target, num_classes=num_classes)
                 iou = iou_coefficient(pred, target, num_classes=num_classes)
-                dice_scores.append(dice)
+                precision = precision_score(pred, target, num_classes=num_classes)
+                recall = recall_score(pred, target, num_classes=num_classes)
+                f1_scores.append(f1)
                 iou_scores.append(iou)
+                precision_scores.append(precision)
+                recall_scores.append(recall)
 
     # Calculate metrics
     avg_loss = total_loss / num_batches
-    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
     avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
-
-    return {"loss": avg_loss, "Dice": avg_dice, "IoU": avg_iou}
+    avg_precision = (
+        sum(precision_scores) / len(precision_scores) if precision_scores else 0
+    )
+    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
+
+    return {
+        "loss": avg_loss,
+        "F1": avg_f1,
+        "IoU": avg_iou,
+        "Precision": avg_precision,
+        "Recall": avg_recall,
+    }
 
 
 def train_segmentation_model(
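
evaluate_semantic now returns five keys instead of three; a sketch of consuming the new dict (the numbers are placeholders, key names come from the hunk above):

    metrics = {"loss": 0.31, "F1": 0.88, "IoU": 0.79, "Precision": 0.90, "Recall": 0.86}
    for name in ("loss", "IoU", "F1", "Precision", "Recall"):
        print(f"{name}: {metrics[name]:.4f}")
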
@@ -2261,6 +2387,7 @@ def train_segmentation_model(
     target_size: Optional[Tuple[int, int]] = None,
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
+    early_stopping_patience: Optional[int] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2313,6 +2440,8 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+        early_stopping_patience (int, optional): Number of epochs with no improvement after which
+            training will be stopped. If None, early stopping is disabled. Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
@@ -2542,7 +2671,7 @@ def train_segmentation_model(
         print(f"Using {torch.cuda.device_count()} GPUs for training")
         model = torch.nn.DataParallel(model)
 
-    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use F1Loss)
     criterion = torch.nn.CrossEntropyLoss()
 
     # Set up optimizer
@@ -2560,8 +2689,11 @@ def train_segmentation_model(
     train_losses = []
     val_losses = []
     val_ious = []
-    val_dices = []
+    val_f1s = []
+    val_precisions = []
+    val_recalls = []
     start_epoch = 0
+    epochs_without_improvement = 0
 
     # Load checkpoint if provided
     if checkpoint_path is not None:
@@ -2596,8 +2728,15 @@ def train_segmentation_model(
             val_losses = checkpoint["val_losses"]
         if "val_ious" in checkpoint:
             val_ious = checkpoint["val_ious"]
-        if "val_dices" in checkpoint:
-            val_dices = checkpoint["val_dices"]
+        if "val_f1s" in checkpoint:
+            val_f1s = checkpoint["val_f1s"]
+        # Also check for old val_dices format for backward compatibility
+        elif "val_dices" in checkpoint:
+            val_f1s = checkpoint["val_dices"]
+        if "val_precisions" in checkpoint:
+            val_precisions = checkpoint["val_precisions"]
+        if "val_recalls" in checkpoint:
+            val_recalls = checkpoint["val_recalls"]
 
         print(f"Resuming training from epoch {start_epoch}")
         print(f"Previous best IoU: {best_iou:.4f}")
@@ -2637,7 +2776,9 @@ def train_segmentation_model(
         )
         val_losses.append(eval_metrics["loss"])
         val_ious.append(eval_metrics["IoU"])
-        val_dices.append(eval_metrics["Dice"])
+        val_f1s.append(eval_metrics["F1"])
+        val_precisions.append(eval_metrics["Precision"])
+        val_recalls.append(eval_metrics["Recall"])
 
         # Update learning rate
         lr_scheduler.step(eval_metrics["loss"])
@@ -2648,14 +2789,28 @@ def train_segmentation_model(
             f"Train Loss: {train_loss:.4f}, "
             f"Val Loss: {eval_metrics['loss']:.4f}, "
             f"Val IoU: {eval_metrics['IoU']:.4f}, "
-            f"Val Dice: {eval_metrics['Dice']:.4f}"
+            f"Val F1: {eval_metrics['F1']:.4f}, "
+            f"Val Precision: {eval_metrics['Precision']:.4f}, "
+            f"Val Recall: {eval_metrics['Recall']:.4f}"
         )
 
-        # Save best model
+        # Save best model and check for early stopping
         if eval_metrics["IoU"] > best_iou:
             best_iou = eval_metrics["IoU"]
+            epochs_without_improvement = 0
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+        else:
+            epochs_without_improvement += 1
+            if (
+                early_stopping_patience is not None
+                and epochs_without_improvement >= early_stopping_patience
+            ):
+                print(
+                    f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
+                )
+                print(f"Best validation IoU: {best_iou:.4f}")
+                break
 
         # Save checkpoint every 10 epochs (if not save_best_only)
         if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
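
The patience logic added above is the standard early-stopping pattern; distilled into a self-contained sketch (toy IoU values, patience of 3 assumed):

    early_stopping_patience = 3
    best_iou = 0.0
    epochs_without_improvement = 0

    for epoch, val_iou in enumerate([0.50, 0.62, 0.61, 0.60, 0.59]):  # toy values
        if val_iou > best_iou:
            best_iou = val_iou
            epochs_without_improvement = 0  # reset on any improvement
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(f"early stop at epoch {epoch}, best IoU {best_iou:.2f}")
                break
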
@@ -2673,7 +2828,9 @@ def train_segmentation_model(
                     "train_losses": train_losses,
                     "val_losses": val_losses,
                     "val_ious": val_ious,
-                    "val_dices": val_dices,
+                    "val_f1s": val_f1s,
+                    "val_precisions": val_precisions,
+                    "val_recalls": val_recalls,
                 },
                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
             )
@@ -2686,7 +2843,9 @@ def train_segmentation_model(
         "train_losses": train_losses,
         "val_losses": val_losses,
         "val_ious": val_ious,
-        "val_dices": val_dices,
+        "val_f1s": val_f1s,
+        "val_precisions": val_precisions,
+        "val_recalls": val_recalls,
     }
     torch.save(history, os.path.join(output_dir, "training_history.pth"))
 
@@ -2702,7 +2861,9 @@ def train_segmentation_model(
         f.write(f"Total epochs: {num_epochs}\n")
         f.write(f"Best validation IoU: {best_iou:.4f}\n")
         f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
-        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
+        f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
+        f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
         f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
 
     print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2731,10 +2892,10 @@ def train_segmentation_model(
     plt.grid(True)
 
     plt.subplot(1, 3, 3)
-    plt.plot(val_dices, label="Val Dice")
-    plt.title("Dice Score")
+    plt.plot(val_f1s, label="Val F1")
+    plt.title("F1 Score")
     plt.xlabel("Epoch")
-    plt.ylabel("Dice")
+    plt.ylabel("F1")
     plt.legend()
     plt.grid(True)
 
@@ -2764,6 +2925,7 @@ def semantic_inference_on_geotiff(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2785,6 +2947,8 @@ def semantic_inference_on_geotiff(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3001,7 +3165,7 @@ def semantic_inference_on_geotiff(
         prob_meta = meta.copy()
         prob_meta.update({"count": num_classes, "dtype": "float32"})
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3015,6 +3179,36 @@ def semantic_inference_on_geotiff(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = meta.copy()
+            single_band_meta.update({"count": 1, "dtype": "float32"})
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = (
+                    prob_accumulator[class_idx, valid_pixels]
+                    / count_accumulator[valid_pixels]
+                )
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
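
With save_class_probabilities=True, the per-class filenames are derived from probability_path; a sketch of the naming scheme (the path is hypothetical):

    import os

    probability_path = "output/probability.tif"
    prob_base, prob_ext = os.path.splitext(probability_path)
    paths = [f"{prob_base}_class_{i}{prob_ext}" for i in range(2)]  # num_classes=2
    print(paths)  # ['output/probability_class_0.tif', 'output/probability_class_1.tif']
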
@@ -3031,6 +3225,7 @@ def semantic_inference_on_image(
     binary_output: bool = True,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -3053,6 +3248,8 @@ def semantic_inference_on_image(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3331,7 +3528,7 @@ def semantic_inference_on_image(
             "transform": transform,
         }
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
        with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3342,6 +3539,39 @@ def semantic_inference_on_image(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = {
+                "driver": "GTiff",
+                "height": height,
+                "width": width,
+                "count": 1,
+                "dtype": "float32",
+                "transform": transform,
+            }
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -3359,6 +3589,7 @@ def semantic_segmentation(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -3381,11 +3612,16 @@ def semantic_segmentation(
         batch_size (int): Batch size for inference.
         device (torch.device, optional): Device to run inference on.
         probability_path (str, optional): Path to save probability map. If provided,
-            the normalized class probabilities will be saved as a multi-band raster.
+            the normalized class probabilities will be saved as a multi-band raster
+            where each band contains probabilities for each class.
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
             Must be between 0 and 1.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Files will be named like
+            "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
+            as probability_path. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
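
A hedged end-to-end call using the new flag. Only the keyword arguments shown in this hunk are confirmed by the diff; the input/output/model argument names below are assumptions based on the rest of the geoai API, and all paths are hypothetical:

    from geoai.train import semantic_segmentation

    semantic_segmentation(
        input_path="scene.tif",              # assumed parameter name
        output_path="mask.tif",              # assumed parameter name
        model_path="output/best_model.pth",  # assumed parameter name
        probability_path="output/probability.tif",
        save_class_probabilities=True,  # also writes probability_class_0.tif, _class_1.tif
    )
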
@@ -3462,6 +3698,7 @@ def semantic_segmentation(
             device=device,
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3482,6 +3719,7 @@ def semantic_segmentation(
             binary_output=True,  # Convert to binary output for better visualization
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3903,3 +4141,307 @@ def instance_segmentation_batch(
             continue
 
     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def lightly_train_model(
+    data_dir: str,
+    output_dir: str,
+    model: str = "torchvision/resnet50",
+    method: str = "dinov2_distillation",
+    epochs: int = 100,
+    batch_size: int = 64,
+    learning_rate: float = 1e-4,
+    **kwargs: Any,
+) -> str:
+    """
+    Train a model using Lightly Train for self-supervised pretraining.
+
+    Args:
+        data_dir (str): Directory containing unlabeled images for training.
+        output_dir (str): Directory to save training outputs and model checkpoints.
+        model (str): Model architecture to train. Supports models from torchvision,
+            timm, ultralytics, etc. Default is "torchvision/resnet50".
+        method (str): Self-supervised learning method. Options include:
+            - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
+            - "dino": Works with both CNNs and ViTs
+            - "dinov2": Requires ViT models only
+            - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
+            Default is "dinov2_distillation".
+        epochs (int): Number of training epochs. Default is 100.
+        batch_size (int): Batch size for training. Default is 64.
+        learning_rate (float): Learning rate for training. Default is 1e-4.
+        **kwargs: Additional arguments passed to lightly_train.train().
+
+    Returns:
+        str: Path to the exported model file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        ValueError: If data_dir does not exist, is empty, or incompatible model/method.
+
+    Note:
+        Model/Method compatibility:
+        - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
+        - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
+
+    Example:
+        >>> # For CNN models (ResNet, EfficientNet)
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="torchvision/resnet50",
+        ...     method="simclr",  # Use simclr for CNNs
+        ...     epochs=50
+        ... )
+        >>> # For ViT models
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="timm/vit_base_patch16_224",
+        ...     method="dinov2",  # dinov2 requires ViT
+        ...     epochs=50
+        ... )
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise ValueError(f"Data directory does not exist: {data_dir}")
+
+    # Check if data directory contains images
+    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
+
+    if not image_files:
+        raise ValueError(f"No image files found in {data_dir}")
+
+    # Validate model/method compatibility
+    is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
+
+    if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
+        raise ValueError(
+            f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
+            f"Solutions:\n"
+            f"  1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
+            f"  2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
+            f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
+            f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
+        )
+
+    print(f"Found {len(image_files)} images in {data_dir}")
+    print(f"Starting self-supervised pretraining with {method} method...")
+    print(f"Model: {model}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Detect if running in notebook environment and set appropriate configuration
+    def is_notebook():
+        try:
+            from IPython import get_ipython
+
+            if get_ipython() is not None:
+                return True
+        except (ImportError, NameError):
+            pass
+        return False
+
+    # Force single-device training in notebooks to avoid DDP strategy issues
+    if is_notebook():
+        # Only override if not explicitly set by user
+        if "accelerator" not in kwargs:
+            # Use CPU in notebooks to avoid DDP incompatibility
+            # Users can still override by passing accelerator='gpu'
+            kwargs["accelerator"] = "cpu"
+        if "devices" not in kwargs:
+            kwargs["devices"] = 1  # Force single device
+
+    # Train the model using Lightly Train
+    lightly_train.train(
+        out=output_dir,
+        data=data_dir,
+        model=model,
+        method=method,
+        epochs=epochs,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    # Return path to the exported model
+    exported_model_path = os.path.join(
+        output_dir, "exported_models", "exported_last.pt"
+    )
+
+    if os.path.exists(exported_model_path):
+        print(
+            f"Model training completed. Exported model saved to: {exported_model_path}"
+        )
+        return exported_model_path
+    else:
+        # Check for alternative export paths
+        possible_paths = [
+            os.path.join(output_dir, "exported_models", "exported_best.pt"),
+            os.path.join(output_dir, "checkpoints", "last.ckpt"),
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                print(f"Model training completed. Exported model saved to: {path}")
+                return path
+
+        print(f"Model training completed. Output saved to: {output_dir}")
+        return output_dir
+
+
+def load_lightly_pretrained_model(
+    model_path: str,
+    model_architecture: str = "torchvision/resnet50",
+    device: str = None,
+) -> torch.nn.Module:
+    """
+    Load a pretrained model from Lightly Train.
+
+    Args:
+        model_path (str): Path to the pretrained model file (.pt format).
+        model_architecture (str): Architecture of the model to load.
+            Default is "torchvision/resnet50".
+        device (str): Device to load the model on. If None, uses CPU.
+
+    Returns:
+        torch.nn.Module: Loaded pretrained model ready for fine-tuning.
+
+    Raises:
+        FileNotFoundError: If model_path does not exist.
+        ImportError: If required libraries are not available.
+
+    Example:
+        >>> model = load_lightly_pretrained_model(
+        ...     model_path="path/to/pretrained_model.pt",
+        ...     model_architecture="torchvision/resnet50",
+        ...     device="cuda"
+        ... )
+        >>> # Fine-tune the model with your existing training pipeline
+    """
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    print(f"Loading pretrained model from: {model_path}")
+
+    # Load the model based on architecture
+    if model_architecture.startswith("torchvision/"):
+        model_name = model_architecture.replace("torchvision/", "")
+
+        # Import the model from torchvision
+        if hasattr(torchvision.models, model_name):
+            model = getattr(torchvision.models, model_name)()
+        else:
+            raise ValueError(f"Unknown torchvision model: {model_name}")
+
+    elif model_architecture.startswith("timm/"):
+        try:
+            import timm
+
+            model_name = model_architecture.replace("timm/", "")
+            model = timm.create_model(model_name)
+        except ImportError:
+            raise ImportError(
+                "timm is required for TIMM models. Install with: pip install timm"
+            )
+
+    else:
+        # For other architectures, try to import from torchvision as default
+        try:
+            model = getattr(torchvision.models, model_architecture)()
+        except AttributeError:
+            raise ValueError(f"Unsupported model architecture: {model_architecture}")
+
+    # Load the pretrained weights
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    except TypeError:
+        # For backward compatibility with older PyTorch versions
+        state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+
+    print(f"Successfully loaded pretrained model: {model_architecture}")
+    return model
+
+
+def lightly_embed_images(
+    data_dir: str,
+    model_path: str,
+    output_path: str,
+    model_architecture: str = None,  # Deprecated, kept for backwards compatibility
+    batch_size: int = 64,
+    **kwargs: Any,
+) -> str:
+    """
+    Generate embeddings for images using a Lightly Train pretrained model.
+
+    Args:
+        data_dir (str): Directory containing images to embed.
+        model_path (str): Path to the pretrained model checkpoint file (.ckpt).
+        output_path (str): Path to save the embeddings (as .pt file).
+        model_architecture (str): Architecture of the pretrained model (deprecated,
+            kept for backwards compatibility but not used). The model architecture
+            is automatically loaded from the checkpoint.
+        batch_size (int): Batch size for embedding generation. Default is 64.
+        **kwargs: Additional arguments passed to lightly_train.embed().
+            Supported kwargs include: image_size, num_workers, accelerator, etc.
+
+    Returns:
+        str: Path to the saved embeddings file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        FileNotFoundError: If data_dir or model_path does not exist.
+
+    Note:
+        The model_path should point to a .ckpt file from the training output,
+        typically located at: output_dir/checkpoints/last.ckpt
+
+    Example:
+        >>> embeddings_path = lightly_embed_images(
+        ...     data_dir="path/to/images",
+        ...     model_path="output_dir/checkpoints/last.ckpt",
+        ...     output_path="embeddings.pt",
+        ...     batch_size=32
+        ... )
+        >>> print(f"Embeddings saved to: {embeddings_path}")
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
+
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+    print(f"Generating embeddings for images in: {data_dir}")
+    print(f"Using pretrained model: {model_path}")
+
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Generate embeddings using Lightly Train
+    # Note: model_architecture is not used - it's inferred from the checkpoint
+    lightly_train.embed(
+        out=output_path,
+        data=data_dir,
+        checkpoint=model_path,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    print(f"Embeddings saved to: {output_path}")
+    return output_path
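
Taken together, the three new helpers form a pretrain -> load -> embed workflow; a hedged sketch assuming lightly-train is installed, all paths hypothetical, and that lightly_train_model finds an exported .pt:

    from geoai.train import (
        lightly_train_model,
        load_lightly_pretrained_model,
        lightly_embed_images,
    )

    # 1. self-supervised pretraining on unlabeled imagery (CNN + simclr pairing)
    model_path = lightly_train_model(
        data_dir="data/unlabeled_tiles",
        output_dir="output/lightly",
        model="torchvision/resnet50",
        method="simclr",
        epochs=10,
    )

    # 2. load the exported weights for fine-tuning
    model = load_lightly_pretrained_model(
        model_path=model_path,
        model_architecture="torchvision/resnet50",
    )

    # 3. embed images with the training checkpoint (.ckpt, not the exported .pt)
    lightly_embed_images(
        data_dir="data/unlabeled_tiles",
        model_path="output/lightly/checkpoints/last.ckpt",
        output_path="output/embeddings.pt",
        batch_size=32,
    )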