geoai-py 0.15.0__py2.py3-none-any.whl → 0.18.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -34,6 +34,14 @@ try:
 except ImportError:
     SMP_AVAILABLE = False
 
+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
 
 def parse_coco_annotations(
     coco_json_path: str, images_dir: str, labels_dir: str
@@ -1428,8 +1436,12 @@ def instance_segmentation_inference_on_geotiff(
     # Apply Non-Maximum Suppression to handle overlapping detections
     if len(all_detections) > 0:
         # Convert to tensors for NMS
-        boxes = torch.tensor([det["box"] for det in all_detections])
-        scores = torch.tensor([det["score"] for det in all_detections])
+        boxes = torch.tensor(
+            [det["box"] for det in all_detections], dtype=torch.float32
+        )
+        scores = torch.tensor(
+            [det["score"] for det in all_detections], dtype=torch.float32
+        )
 
         # Apply NMS with IoU threshold
         nms_threshold = 0.3  # IoU threshold for NMS
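The explicit dtype=torch.float32 matters because torch.tensor infers an integer dtype when every box coordinate happens to be an integer, and torchvision.ops.nms expects floating-point inputs. A minimal sketch of the call this feeds into (the detection values are made up for illustration):

    import torch
    from torchvision.ops import nms

    # Two heavily overlapping toy detections
    all_detections = [
        {"box": [10, 10, 50, 50], "score": 0.9},
        {"box": [12, 11, 52, 49], "score": 0.6},
    ]

    boxes = torch.tensor([det["box"] for det in all_detections], dtype=torch.float32)
    scores = torch.tensor([det["score"] for det in all_detections], dtype=torch.float32)

    keep = nms(boxes, scores, iou_threshold=0.3)  # -> tensor([0]); the lower-scoring box is suppressed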
@@ -1909,6 +1921,96 @@ class SemanticRandomHorizontalFlip:
1909
1921
  return image, mask
1910
1922
 
1911
1923
 
1924
+ class SemanticRandomVerticalFlip:
1925
+ """Random vertical flip transform for semantic segmentation."""
1926
+
1927
+ def __init__(self, prob: float = 0.5) -> None:
1928
+ self.prob = prob
1929
+
1930
+ def __call__(
1931
+ self, image: torch.Tensor, mask: torch.Tensor
1932
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1933
+ if random.random() < self.prob:
1934
+ # Flip image and mask along height dimension
1935
+ image = torch.flip(image, dims=[1])
1936
+ mask = torch.flip(mask, dims=[0])
1937
+ return image, mask
1938
+
1939
+
1940
+ class SemanticRandomRotation90:
1941
+ """Random 90-degree rotation transform for semantic segmentation."""
1942
+
1943
+ def __init__(self, prob: float = 0.5) -> None:
1944
+ self.prob = prob
1945
+
1946
+ def __call__(
1947
+ self, image: torch.Tensor, mask: torch.Tensor
1948
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1949
+ if random.random() < self.prob:
1950
+ # Randomly rotate by 90, 180, or 270 degrees
1951
+ k = random.randint(1, 3)
1952
+ image = torch.rot90(image, k, dims=[1, 2])
1953
+ mask = torch.rot90(mask, k, dims=[0, 1])
1954
+ return image, mask
1955
+
1956
+
1957
+ class SemanticBrightnessAdjustment:
1958
+ """Random brightness adjustment transform for semantic segmentation."""
1959
+
1960
+ def __init__(
1961
+ self, brightness_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
1962
+ ) -> None:
1963
+ """
1964
+ Initialize brightness adjustment transform.
1965
+
1966
+ Args:
1967
+ brightness_range: Tuple of (min, max) brightness factors.
1968
+ prob: Probability of applying the transform.
1969
+ """
1970
+ self.brightness_range = brightness_range
1971
+ self.prob = prob
1972
+
1973
+ def __call__(
1974
+ self, image: torch.Tensor, mask: torch.Tensor
1975
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1976
+ if random.random() < self.prob:
1977
+ # Apply random brightness adjustment
1978
+ factor = self.brightness_range[0] + random.random() * (
1979
+ self.brightness_range[1] - self.brightness_range[0]
1980
+ )
1981
+ image = torch.clamp(image * factor, 0, 1)
1982
+ return image, mask
1983
+
1984
+
1985
+ class SemanticContrastAdjustment:
1986
+ """Random contrast adjustment transform for semantic segmentation."""
1987
+
1988
+ def __init__(
1989
+ self, contrast_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
1990
+ ) -> None:
1991
+ """
1992
+ Initialize contrast adjustment transform.
1993
+
1994
+ Args:
1995
+ contrast_range: Tuple of (min, max) contrast factors.
1996
+ prob: Probability of applying the transform.
1997
+ """
1998
+ self.contrast_range = contrast_range
1999
+ self.prob = prob
2000
+
2001
+ def __call__(
2002
+ self, image: torch.Tensor, mask: torch.Tensor
2003
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
2004
+ if random.random() < self.prob:
2005
+ # Apply random contrast adjustment
2006
+ factor = self.contrast_range[0] + random.random() * (
2007
+ self.contrast_range[1] - self.contrast_range[0]
2008
+ )
2009
+ mean = image.mean(dim=(1, 2), keepdim=True)
2010
+ image = torch.clamp((image - mean) * factor + mean, 0, 1)
2011
+ return image, mask
2012
+
2013
+
1912
2014
  def get_semantic_transform(train: bool) -> Any:
1913
2015
  """
1914
2016
  Get transforms for semantic segmentation data augmentation.
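Each of these new transform classes is a plain callable that takes (image, mask) tensors and returns the transformed pair, so they can be chained and handed to the train_transforms parameter introduced later in this diff. A small sketch; ComposeSemantic is a hypothetical helper written here for illustration, and SemanticRandomHorizontalFlip (already present in 0.15.0) is assumed to accept a prob argument like the new classes:

    import torch
    from geoai.train import (
        SemanticRandomHorizontalFlip,
        SemanticRandomVerticalFlip,
        SemanticRandomRotation90,
        SemanticBrightnessAdjustment,
        SemanticContrastAdjustment,
    )

    class ComposeSemantic:
        """Apply a sequence of (image, mask) transforms in order."""

        def __init__(self, transforms):
            self.transforms = transforms

        def __call__(self, image: torch.Tensor, mask: torch.Tensor):
            for t in self.transforms:
                image, mask = t(image, mask)
            return image, mask

    train_aug = ComposeSemantic(
        [
            SemanticRandomHorizontalFlip(prob=0.5),
            SemanticRandomVerticalFlip(prob=0.5),
            SemanticRandomRotation90(prob=0.5),
            SemanticBrightnessAdjustment(brightness_range=(0.8, 1.2), prob=0.5),
            SemanticContrastAdjustment(contrast_range=(0.8, 1.2), prob=0.5),
        ]
    )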
@@ -2025,14 +2127,14 @@ def get_smp_model(
2025
2127
  )
2026
2128
 
2027
2129
 
2028
- def dice_coefficient(
2130
+ def f1_score(
2029
2131
  pred: torch.Tensor,
2030
2132
  target: torch.Tensor,
2031
2133
  smooth: float = 1e-6,
2032
2134
  num_classes: Optional[int] = None,
2033
2135
  ) -> float:
2034
2136
  """
2035
- Calculate Dice coefficient for segmentation (binary or multi-class).
2137
+ Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).
2036
2138
 
2037
2139
  Args:
2038
2140
  pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -2041,7 +2143,7 @@ def dice_coefficient(
2041
2143
  num_classes (int, optional): Number of classes. If None, auto-detected.
2042
2144
 
2043
2145
  Returns:
2044
- float: Mean Dice coefficient across all classes.
2146
+ float: Mean F1 score across all classes.
2045
2147
  """
2046
2148
  # Convert predictions to class predictions
2047
2149
  if pred.dim() == 3: # [C, H, W] format
@@ -2056,8 +2158,8 @@ def dice_coefficient(
2056
2158
  if num_classes is None:
2057
2159
  num_classes = max(pred_classes.max().item(), target.max().item()) + 1
2058
2160
 
2059
- # Calculate Dice for each class and average
2060
- dice_scores = []
2161
+ # Calculate F1 score for each class and average
2162
+ f1_scores = []
2061
2163
  for class_id in range(num_classes):
2062
2164
  pred_class = (pred_classes == class_id).float()
2063
2165
  target_class = (target == class_id).float()
@@ -2066,10 +2168,10 @@ def dice_coefficient(
2066
2168
  union = pred_class.sum() + target_class.sum()
2067
2169
 
2068
2170
  if union > 0:
2069
- dice = (2.0 * intersection + smooth) / (union + smooth)
2070
- dice_scores.append(dice.item())
2171
+ f1 = (2.0 * intersection + smooth) / (union + smooth)
2172
+ f1_scores.append(f1.item())
2071
2173
 
2072
- return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
2174
+ return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
2073
2175
 
2074
2176
 
2075
2177
  def iou_coefficient(
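The rename from dice_coefficient to f1_score is purely a relabeling: the quantity computed here, 2 * |pred ∩ target| / (|pred| + |target|) = 2*TP / (2*TP + FP + FN), is simultaneously the Dice coefficient and the F1 score (the harmonic mean of precision and recall), so the values are numerically identical to those returned by the 0.15.0 dice_coefficient.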
@@ -2119,6 +2221,108 @@ def iou_coefficient(
2119
2221
  return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
2120
2222
 
2121
2223
 
2224
+ def precision_score(
2225
+ pred: torch.Tensor,
2226
+ target: torch.Tensor,
2227
+ smooth: float = 1e-6,
2228
+ num_classes: Optional[int] = None,
2229
+ ) -> float:
2230
+ """
2231
+ Calculate precision score for segmentation (binary or multi-class).
2232
+
2233
+ Precision = TP / (TP + FP), where:
2234
+ - TP (True Positives): Correctly predicted positive pixels
2235
+ - FP (False Positives): Incorrectly predicted positive pixels
2236
+
2237
+ Args:
2238
+ pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
2239
+ target (torch.Tensor): Ground truth mask with shape [H, W].
2240
+ smooth (float): Smoothing factor to avoid division by zero.
2241
+ num_classes (int, optional): Number of classes. If None, auto-detected.
2242
+
2243
+ Returns:
2244
+ float: Mean precision score across all classes.
2245
+ """
2246
+ # Convert predictions to class predictions
2247
+ if pred.dim() == 3: # [C, H, W] format
2248
+ pred = torch.softmax(pred, dim=0)
2249
+ pred_classes = torch.argmax(pred, dim=0)
2250
+ elif pred.dim() == 2: # [H, W] format
2251
+ pred_classes = pred
2252
+ else:
2253
+ raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
2254
+
2255
+ # Auto-detect number of classes if not provided
2256
+ if num_classes is None:
2257
+ num_classes = max(pred_classes.max().item(), target.max().item()) + 1
2258
+
2259
+ # Calculate precision for each class and average
2260
+ precision_scores = []
2261
+ for class_id in range(num_classes):
2262
+ pred_class = (pred_classes == class_id).float()
2263
+ target_class = (target == class_id).float()
2264
+
2265
+ true_positives = (pred_class * target_class).sum()
2266
+ predicted_positives = pred_class.sum()
2267
+
2268
+ if predicted_positives > 0:
2269
+ precision = (true_positives + smooth) / (predicted_positives + smooth)
2270
+ precision_scores.append(precision.item())
2271
+
2272
+ return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
2273
+
2274
+
2275
+ def recall_score(
2276
+ pred: torch.Tensor,
2277
+ target: torch.Tensor,
2278
+ smooth: float = 1e-6,
2279
+ num_classes: Optional[int] = None,
2280
+ ) -> float:
2281
+ """
2282
+ Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
2283
+
2284
+ Recall = TP / (TP + FN), where:
2285
+ - TP (True Positives): Correctly predicted positive pixels
2286
+ - FN (False Negatives): Incorrectly predicted negative pixels
2287
+
2288
+ Args:
2289
+ pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
2290
+ target (torch.Tensor): Ground truth mask with shape [H, W].
2291
+ smooth (float): Smoothing factor to avoid division by zero.
2292
+ num_classes (int, optional): Number of classes. If None, auto-detected.
2293
+
2294
+ Returns:
2295
+ float: Mean recall score across all classes.
2296
+ """
2297
+ # Convert predictions to class predictions
2298
+ if pred.dim() == 3: # [C, H, W] format
2299
+ pred = torch.softmax(pred, dim=0)
2300
+ pred_classes = torch.argmax(pred, dim=0)
2301
+ elif pred.dim() == 2: # [H, W] format
2302
+ pred_classes = pred
2303
+ else:
2304
+ raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
2305
+
2306
+ # Auto-detect number of classes if not provided
2307
+ if num_classes is None:
2308
+ num_classes = max(pred_classes.max().item(), target.max().item()) + 1
2309
+
2310
+ # Calculate recall for each class and average
2311
+ recall_scores = []
2312
+ for class_id in range(num_classes):
2313
+ pred_class = (pred_classes == class_id).float()
2314
+ target_class = (target == class_id).float()
2315
+
2316
+ true_positives = (pred_class * target_class).sum()
2317
+ actual_positives = target_class.sum()
2318
+
2319
+ if actual_positives > 0:
2320
+ recall = (true_positives + smooth) / (actual_positives + smooth)
2321
+ recall_scores.append(recall.item())
2322
+
2323
+ return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
2324
+
2325
+
2122
2326
  def train_semantic_one_epoch(
2123
2327
  model: torch.nn.Module,
2124
2328
  optimizer: torch.optim.Optimizer,
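All four metric helpers share one calling convention: a prediction given either as logits/probabilities of shape [C, H, W] or as hard class labels of shape [H, W], an integer target of shape [H, W], and the per-class scores averaged over classes. A quick check on a toy pair:

    import torch
    from geoai.train import f1_score, iou_coefficient, precision_score, recall_score

    target = torch.tensor([[0, 0, 1, 1],
                           [0, 0, 1, 1]])
    pred = torch.tensor([[0, 0, 0, 1],
                         [0, 0, 1, 1]])  # one class-1 pixel predicted as class 0

    print(precision_score(pred, target, num_classes=2))  # ~0.90 (mean of 0.8 and 1.0)
    print(recall_score(pred, target, num_classes=2))     # ~0.875 (mean of 1.0 and 0.75)
    print(f1_score(pred, target, num_classes=2))         # ~0.87
    print(iou_coefficient(pred, target, num_classes=2))  # ~0.78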
@@ -2200,13 +2404,15 @@ def evaluate_semantic(
2200
2404
  num_classes (int): Number of classes for evaluation metrics.
2201
2405
 
2202
2406
  Returns:
2203
- dict: Evaluation metrics including loss, IoU, and Dice.
2407
+ dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
2204
2408
  """
2205
2409
  model.eval()
2206
2410
 
2207
2411
  total_loss = 0
2208
- dice_scores = []
2412
+ f1_scores = []
2209
2413
  iou_scores = []
2414
+ precision_scores = []
2415
+ recall_scores = []
2210
2416
  num_batches = len(data_loader)
2211
2417
 
2212
2418
  with torch.no_grad():
@@ -2222,17 +2428,31 @@ def evaluate_semantic(
2222
2428
 
2223
2429
  # Calculate metrics for each sample in the batch
2224
2430
  for pred, target in zip(outputs, targets):
2225
- dice = dice_coefficient(pred, target, num_classes=num_classes)
2431
+ f1 = f1_score(pred, target, num_classes=num_classes)
2226
2432
  iou = iou_coefficient(pred, target, num_classes=num_classes)
2227
- dice_scores.append(dice)
2433
+ precision = precision_score(pred, target, num_classes=num_classes)
2434
+ recall = recall_score(pred, target, num_classes=num_classes)
2435
+ f1_scores.append(f1)
2228
2436
  iou_scores.append(iou)
2437
+ precision_scores.append(precision)
2438
+ recall_scores.append(recall)
2229
2439
 
2230
2440
  # Calculate metrics
2231
2441
  avg_loss = total_loss / num_batches
2232
- avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
2442
+ avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
2233
2443
  avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
2234
-
2235
- return {"loss": avg_loss, "Dice": avg_dice, "IoU": avg_iou}
2444
+ avg_precision = (
2445
+ sum(precision_scores) / len(precision_scores) if precision_scores else 0
2446
+ )
2447
+ avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
2448
+
2449
+ return {
2450
+ "loss": avg_loss,
2451
+ "F1": avg_f1,
2452
+ "IoU": avg_iou,
2453
+ "Precision": avg_precision,
2454
+ "Recall": avg_recall,
2455
+ }
2236
2456
 
2237
2457
 
2238
2458
  def train_segmentation_model(
@@ -2261,6 +2481,9 @@ def train_segmentation_model(
2261
2481
  target_size: Optional[Tuple[int, int]] = None,
2262
2482
  resize_mode: str = "resize",
2263
2483
  num_workers: Optional[int] = None,
2484
+ early_stopping_patience: Optional[int] = None,
2485
+ train_transforms: Optional[Callable] = None,
2486
+ val_transforms: Optional[Callable] = None,
2264
2487
  **kwargs: Any,
2265
2488
  ) -> torch.nn.Module:
2266
2489
  """
@@ -2313,6 +2536,17 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+        early_stopping_patience (int, optional): Number of epochs with no improvement in validation IoU
+            after which training is stopped. If None, early stopping is disabled. Defaults to None.
+        train_transforms (callable, optional): Custom transforms for training data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            The image tensor is expected to be in CHW format (channels, height, width), and the mask
+            tensor in HW format (height, width). Both should be torch.Tensor objects. If None, uses the
+            default training transforms (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask),
+            with the same tensor formats as train_transforms. If None, uses the default validation
+            transforms (no augmentation). Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
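Taken together, the three new parameters allow custom augmentation plus patience-based stopping in a single call. A sketch; the images_dir and labels_dir keyword names are assumptions based on the wider geoai API rather than shown in this diff, and train_aug refers to the ComposeSemantic sketch earlier:

    from geoai.train import train_segmentation_model

    train_segmentation_model(
        images_dir="tiles/images",      # assumed parameter name
        labels_dir="tiles/labels",      # assumed parameter name
        output_dir="output/models",
        num_epochs=100,
        early_stopping_patience=10,     # stop after 10 epochs with no new best IoU
        train_transforms=train_aug,     # callable(image, mask) -> (image, mask)
        val_transforms=None,            # keep the default validation transforms
    )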
@@ -2455,10 +2689,22 @@ def train_segmentation_model(
2455
2689
  print("No resizing needed.")
2456
2690
 
2457
2691
  # Create datasets
2692
+ # Use custom transforms if provided, otherwise use default transforms
2693
+ train_transform = (
2694
+ train_transforms
2695
+ if train_transforms is not None
2696
+ else get_semantic_transform(train=True)
2697
+ )
2698
+ val_transform = (
2699
+ val_transforms
2700
+ if val_transforms is not None
2701
+ else get_semantic_transform(train=False)
2702
+ )
2703
+
2458
2704
  train_dataset = SemanticSegmentationDataset(
2459
2705
  train_imgs,
2460
2706
  train_labels,
2461
- transforms=get_semantic_transform(train=True),
2707
+ transforms=train_transform,
2462
2708
  num_channels=num_channels,
2463
2709
  target_size=target_size,
2464
2710
  resize_mode=resize_mode,
@@ -2467,7 +2713,7 @@ def train_segmentation_model(
2467
2713
  val_dataset = SemanticSegmentationDataset(
2468
2714
  val_imgs,
2469
2715
  val_labels,
2470
- transforms=get_semantic_transform(train=False),
2716
+ transforms=val_transform,
2471
2717
  num_channels=num_channels,
2472
2718
  target_size=target_size,
2473
2719
  resize_mode=resize_mode,
@@ -2542,7 +2788,7 @@ def train_segmentation_model(
2542
2788
  print(f"Using {torch.cuda.device_count()} GPUs for training")
2543
2789
  model = torch.nn.DataParallel(model)
2544
2790
 
2545
- # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
2791
+ # Set up loss function (CrossEntropyLoss for multi-class; smp.losses.DiceLoss is another option)
2546
2792
  criterion = torch.nn.CrossEntropyLoss()
2547
2793
 
2548
2794
  # Set up optimizer
@@ -2560,8 +2806,11 @@ def train_segmentation_model(
2560
2806
  train_losses = []
2561
2807
  val_losses = []
2562
2808
  val_ious = []
2563
- val_dices = []
2809
+ val_f1s = []
2810
+ val_precisions = []
2811
+ val_recalls = []
2564
2812
  start_epoch = 0
2813
+ epochs_without_improvement = 0
2565
2814
 
2566
2815
  # Load checkpoint if provided
2567
2816
  if checkpoint_path is not None:
@@ -2596,8 +2845,15 @@ def train_segmentation_model(
2596
2845
  val_losses = checkpoint["val_losses"]
2597
2846
  if "val_ious" in checkpoint:
2598
2847
  val_ious = checkpoint["val_ious"]
2599
- if "val_dices" in checkpoint:
2600
- val_dices = checkpoint["val_dices"]
2848
+ if "val_f1s" in checkpoint:
2849
+ val_f1s = checkpoint["val_f1s"]
2850
+ # Also check for old val_dices format for backward compatibility
2851
+ elif "val_dices" in checkpoint:
2852
+ val_f1s = checkpoint["val_dices"]
2853
+ if "val_precisions" in checkpoint:
2854
+ val_precisions = checkpoint["val_precisions"]
2855
+ if "val_recalls" in checkpoint:
2856
+ val_recalls = checkpoint["val_recalls"]
2601
2857
 
2602
2858
  print(f"Resuming training from epoch {start_epoch}")
2603
2859
  print(f"Previous best IoU: {best_iou:.4f}")
@@ -2637,7 +2893,9 @@ def train_segmentation_model(
2637
2893
  )
2638
2894
  val_losses.append(eval_metrics["loss"])
2639
2895
  val_ious.append(eval_metrics["IoU"])
2640
- val_dices.append(eval_metrics["Dice"])
2896
+ val_f1s.append(eval_metrics["F1"])
2897
+ val_precisions.append(eval_metrics["Precision"])
2898
+ val_recalls.append(eval_metrics["Recall"])
2641
2899
 
2642
2900
  # Update learning rate
2643
2901
  lr_scheduler.step(eval_metrics["loss"])
@@ -2648,14 +2906,28 @@ def train_segmentation_model(
2648
2906
  f"Train Loss: {train_loss:.4f}, "
2649
2907
  f"Val Loss: {eval_metrics['loss']:.4f}, "
2650
2908
  f"Val IoU: {eval_metrics['IoU']:.4f}, "
2651
- f"Val Dice: {eval_metrics['Dice']:.4f}"
2909
+ f"Val F1: {eval_metrics['F1']:.4f}, "
2910
+ f"Val Precision: {eval_metrics['Precision']:.4f}, "
2911
+ f"Val Recall: {eval_metrics['Recall']:.4f}"
2652
2912
  )
2653
2913
 
2654
- # Save best model
2914
+ # Save best model and check for early stopping
2655
2915
  if eval_metrics["IoU"] > best_iou:
2656
2916
  best_iou = eval_metrics["IoU"]
2917
+ epochs_without_improvement = 0
2657
2918
  print(f"Saving best model with IoU: {best_iou:.4f}")
2658
2919
  torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
2920
+ else:
2921
+ epochs_without_improvement += 1
2922
+ if (
2923
+ early_stopping_patience is not None
2924
+ and epochs_without_improvement >= early_stopping_patience
2925
+ ):
2926
+ print(
2927
+ f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
2928
+ )
2929
+ print(f"Best validation IoU: {best_iou:.4f}")
2930
+ break
2659
2931
 
2660
2932
  # Save checkpoint every 10 epochs (if not save_best_only)
2661
2933
  if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
@@ -2673,7 +2945,9 @@ def train_segmentation_model(
2673
2945
  "train_losses": train_losses,
2674
2946
  "val_losses": val_losses,
2675
2947
  "val_ious": val_ious,
2676
- "val_dices": val_dices,
2948
+ "val_f1s": val_f1s,
2949
+ "val_precisions": val_precisions,
2950
+ "val_recalls": val_recalls,
2677
2951
  },
2678
2952
  os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
2679
2953
  )
@@ -2686,7 +2960,9 @@ def train_segmentation_model(
2686
2960
  "train_losses": train_losses,
2687
2961
  "val_losses": val_losses,
2688
2962
  "val_ious": val_ious,
2689
- "val_dices": val_dices,
2963
+ "val_f1s": val_f1s,
2964
+ "val_precisions": val_precisions,
2965
+ "val_recalls": val_recalls,
2690
2966
  }
2691
2967
  torch.save(history, os.path.join(output_dir, "training_history.pth"))
2692
2968
 
@@ -2702,7 +2978,9 @@ def train_segmentation_model(
2702
2978
  f.write(f"Total epochs: {num_epochs}\n")
2703
2979
  f.write(f"Best validation IoU: {best_iou:.4f}\n")
2704
2980
  f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
2705
- f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
2981
+ f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
2982
+ f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
2983
+ f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
2706
2984
  f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
2707
2985
 
2708
2986
  print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2731,10 +3009,10 @@ def train_segmentation_model(
2731
3009
  plt.grid(True)
2732
3010
 
2733
3011
  plt.subplot(1, 3, 3)
2734
- plt.plot(val_dices, label="Val Dice")
2735
- plt.title("Dice Score")
3012
+ plt.plot(val_f1s, label="Val F1")
3013
+ plt.title("F1 Score")
2736
3014
  plt.xlabel("Epoch")
2737
- plt.ylabel("Dice")
3015
+ plt.ylabel("F1")
2738
3016
  plt.legend()
2739
3017
  plt.grid(True)
2740
3018
 
@@ -2764,6 +3042,7 @@ def semantic_inference_on_geotiff(
2764
3042
  device: Optional[torch.device] = None,
2765
3043
  probability_path: Optional[str] = None,
2766
3044
  probability_threshold: Optional[float] = None,
3045
+ save_class_probabilities: bool = False,
2767
3046
  quiet: bool = False,
2768
3047
  **kwargs: Any,
2769
3048
  ) -> Tuple[str, float]:
@@ -2785,6 +3064,8 @@ def semantic_inference_on_geotiff(
2785
3064
  probability_threshold (float, optional): Probability threshold for binary classification.
2786
3065
  Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
2787
3066
  are classified as class 1, otherwise class 0. If None (default), uses argmax.
3067
+ save_class_probabilities (bool): If True and probability_path is provided, saves each
3068
+ class probability as a separate single-band file. Defaults to False.
2788
3069
  quiet (bool): If True, suppress progress bar. Defaults to False.
2789
3070
  **kwargs: Additional arguments.
2790
3071
 
@@ -3001,7 +3282,7 @@ def semantic_inference_on_geotiff(
3001
3282
  prob_meta = meta.copy()
3002
3283
  prob_meta.update({"count": num_classes, "dtype": "float32"})
3003
3284
 
3004
- # Save normalized probabilities
3285
+ # Save normalized probabilities as multi-band raster
3005
3286
  with rasterio.open(probability_path, "w", **prob_meta) as dst:
3006
3287
  for class_idx in range(num_classes):
3007
3288
  # Normalize probabilities
@@ -3015,6 +3296,36 @@ def semantic_inference_on_geotiff(
3015
3296
  if not quiet:
3016
3297
  print(f"Saved probability map to {probability_path}")
3017
3298
 
3299
+ # Save individual class probabilities if requested
3300
+ if save_class_probabilities:
3301
+ # Prepare single-band metadata
3302
+ single_band_meta = meta.copy()
3303
+ single_band_meta.update({"count": 1, "dtype": "float32"})
3304
+
3305
+ # Get base filename and extension
3306
+ prob_base = os.path.splitext(probability_path)[0]
3307
+ prob_ext = os.path.splitext(probability_path)[1]
3308
+
3309
+ for class_idx in range(num_classes):
3310
+ # Create filename for this class
3311
+ class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
3312
+
3313
+ # Normalize probabilities
3314
+ prob_band = np.zeros((height, width), dtype=np.float32)
3315
+ prob_band[valid_pixels] = (
3316
+ prob_accumulator[class_idx, valid_pixels]
3317
+ / count_accumulator[valid_pixels]
3318
+ )
3319
+
3320
+ # Save single-band file
3321
+ with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
3322
+ dst.write(prob_band, 1)
3323
+
3324
+ if not quiet:
3325
+ print(
3326
+ f"Saved class {class_idx} probability to {class_prob_path}"
3327
+ )
3328
+
3018
3329
  return output_path, inference_time
3019
3330
 
3020
3331
 
@@ -3031,6 +3342,7 @@ def semantic_inference_on_image(
3031
3342
  binary_output: bool = True,
3032
3343
  probability_path: Optional[str] = None,
3033
3344
  probability_threshold: Optional[float] = None,
3345
+ save_class_probabilities: bool = False,
3034
3346
  quiet: bool = False,
3035
3347
  **kwargs: Any,
3036
3348
  ) -> Tuple[str, float]:
@@ -3053,6 +3365,8 @@ def semantic_inference_on_image(
3053
3365
  probability_threshold (float, optional): Probability threshold for binary classification.
3054
3366
  Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
3055
3367
  are classified as class 1, otherwise class 0. If None (default), uses argmax.
3368
+ save_class_probabilities (bool): If True and probability_path is provided, saves each
3369
+ class probability as a separate single-band file. Defaults to False.
3056
3370
  quiet (bool): If True, suppress progress bar. Defaults to False.
3057
3371
  **kwargs: Additional arguments.
3058
3372
 
@@ -3331,7 +3645,7 @@ def semantic_inference_on_image(
3331
3645
  "transform": transform,
3332
3646
  }
3333
3647
 
3334
- # Save normalized probabilities
3648
+ # Save normalized probabilities as multi-band raster
3335
3649
  with rasterio.open(probability_path, "w", **prob_meta) as dst:
3336
3650
  for class_idx in range(num_classes):
3337
3651
  # Normalize probabilities
@@ -3342,6 +3656,39 @@ def semantic_inference_on_image(
3342
3656
  if not quiet:
3343
3657
  print(f"Saved probability map to {probability_path}")
3344
3658
 
3659
+ # Save individual class probabilities if requested
3660
+ if save_class_probabilities:
3661
+ # Prepare single-band metadata
3662
+ single_band_meta = {
3663
+ "driver": "GTiff",
3664
+ "height": height,
3665
+ "width": width,
3666
+ "count": 1,
3667
+ "dtype": "float32",
3668
+ "transform": transform,
3669
+ }
3670
+
3671
+ # Get base filename and extension
3672
+ prob_base = os.path.splitext(probability_path)[0]
3673
+ prob_ext = os.path.splitext(probability_path)[1]
3674
+
3675
+ for class_idx in range(num_classes):
3676
+ # Create filename for this class
3677
+ class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
3678
+
3679
+ # Normalize probabilities
3680
+ prob_band = np.zeros((height, width), dtype=np.float32)
3681
+ prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
3682
+
3683
+ # Save single-band file
3684
+ with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
3685
+ dst.write(prob_band, 1)
3686
+
3687
+ if not quiet:
3688
+ print(
3689
+ f"Saved class {class_idx} probability to {class_prob_path}"
3690
+ )
3691
+
3345
3692
  return output_path, inference_time
3346
3693
 
3347
3694
 
@@ -3359,6 +3706,7 @@ def semantic_segmentation(
3359
3706
  device: Optional[torch.device] = None,
3360
3707
  probability_path: Optional[str] = None,
3361
3708
  probability_threshold: Optional[float] = None,
3709
+ save_class_probabilities: bool = False,
3362
3710
  quiet: bool = False,
3363
3711
  **kwargs: Any,
3364
3712
  ) -> None:
@@ -3381,11 +3729,16 @@ def semantic_segmentation(
3381
3729
  batch_size (int): Batch size for inference.
3382
3730
  device (torch.device, optional): Device to run inference on.
3383
3731
  probability_path (str, optional): Path to save probability map. If provided,
3384
- the normalized class probabilities will be saved as a multi-band raster.
3732
+ the normalized class probabilities will be saved as a multi-band raster
3733
+ where each band contains probabilities for each class.
3385
3734
  probability_threshold (float, optional): Probability threshold for binary classification.
3386
3735
  Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
3387
3736
  are classified as class 1, otherwise class 0. If None (default), uses argmax.
3388
3737
  Must be between 0 and 1.
3738
+ save_class_probabilities (bool): If True and probability_path is provided, saves each
3739
+ class probability as a separate single-band file. Files will be named like
3740
+ "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
3741
+ as probability_path. Defaults to False.
3389
3742
  quiet (bool): If True, suppress progress bar. Defaults to False.
3390
3743
  **kwargs: Additional arguments.
3391
3744
 
@@ -3462,6 +3815,7 @@ def semantic_segmentation(
3462
3815
  device=device,
3463
3816
  probability_path=probability_path,
3464
3817
  probability_threshold=probability_threshold,
3818
+ save_class_probabilities=save_class_probabilities,
3465
3819
  quiet=quiet,
3466
3820
  **kwargs,
3467
3821
  )
@@ -3482,6 +3836,7 @@ def semantic_segmentation(
3482
3836
  binary_output=True, # Convert to binary output for better visualization
3483
3837
  probability_path=probability_path,
3484
3838
  probability_threshold=probability_threshold,
3839
+ save_class_probabilities=save_class_probabilities,
3485
3840
  quiet=quiet,
3486
3841
  **kwargs,
3487
3842
  )
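With probability_path set and the new save_class_probabilities flag enabled, a single call now writes the class map, a multi-band probability raster, and one single-band raster per class (named by appending _class_<idx> to the probability filename). A sketch; input_path, output_path, and model_path are assumed keyword names not shown in this diff:

    from geoai.train import semantic_segmentation

    semantic_segmentation(
        input_path="scene.tif",                       # assumed parameter name
        output_path="scene_classes.tif",              # assumed parameter name
        model_path="output/models/best_model.pth",    # assumed parameter name
        num_classes=2,
        probability_path="scene_probs.tif",
        probability_threshold=0.6,      # only used for binary (num_classes=2) outputs
        save_class_probabilities=True,  # also writes scene_probs_class_0.tif and scene_probs_class_1.tif
    )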
@@ -3903,3 +4258,307 @@ def instance_segmentation_batch(
3903
4258
  continue
3904
4259
 
3905
4260
  print(f"Batch processing completed. Results saved to {output_dir}")
4261
+
4262
+
4263
+ def lightly_train_model(
4264
+ data_dir: str,
4265
+ output_dir: str,
4266
+ model: str = "torchvision/resnet50",
4267
+ method: str = "dinov2_distillation",
4268
+ epochs: int = 100,
4269
+ batch_size: int = 64,
4270
+ learning_rate: float = 1e-4,
4271
+ **kwargs: Any,
4272
+ ) -> str:
4273
+ """
4274
+ Train a model using Lightly Train for self-supervised pretraining.
4275
+
4276
+ Args:
4277
+ data_dir (str): Directory containing unlabeled images for training.
4278
+ output_dir (str): Directory to save training outputs and model checkpoints.
4279
+ model (str): Model architecture to train. Supports models from torchvision,
4280
+ timm, ultralytics, etc. Default is "torchvision/resnet50".
4281
+ method (str): Self-supervised learning method. Options include:
4282
+ - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
4283
+ - "dino": Works with both CNNs and ViTs
4284
+ - "dinov2": Requires ViT models only
4285
+ - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
4286
+ Default is "dinov2_distillation".
4287
+ epochs (int): Number of training epochs. Default is 100.
4288
+ batch_size (int): Batch size for training. Default is 64.
4289
+ learning_rate (float): Learning rate for training. Default is 1e-4.
4290
+ **kwargs: Additional arguments passed to lightly_train.train().
4291
+
4292
+ Returns:
4293
+ str: Path to the exported model file.
4294
+
4295
+ Raises:
4296
+ ImportError: If lightly-train is not installed.
4297
+ ValueError: If data_dir does not exist, is empty, or incompatible model/method.
4298
+
4299
+ Note:
4300
+ Model/Method compatibility:
4301
+ - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
4302
+ - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
4303
+
4304
+ Example:
4305
+ >>> # For CNN models (ResNet, EfficientNet)
4306
+ >>> model_path = lightly_train_model(
4307
+ ... data_dir="path/to/unlabeled/images",
4308
+ ... output_dir="path/to/output",
4309
+ ... model="torchvision/resnet50",
4310
+ ... method="simclr", # Use simclr for CNNs
4311
+ ... epochs=50
4312
+ ... )
4313
+ >>> # For ViT models
4314
+ >>> model_path = lightly_train_model(
4315
+ ... data_dir="path/to/unlabeled/images",
4316
+ ... output_dir="path/to/output",
4317
+ ... model="timm/vit_base_patch16_224",
4318
+ ... method="dinov2", # dinov2 requires ViT
4319
+ ... epochs=50
4320
+ ... )
4321
+ """
4322
+ if not LIGHTLY_TRAIN_AVAILABLE:
4323
+ raise ImportError(
4324
+ "lightly-train is not installed. Please install it with: "
4325
+ "pip install lightly-train"
4326
+ )
4327
+
4328
+ if not os.path.exists(data_dir):
4329
+ raise ValueError(f"Data directory does not exist: {data_dir}")
4330
+
4331
+ # Check if data directory contains images
4332
+ image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
4333
+ image_files = []
4334
+ for ext in image_extensions:
4335
+ image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
4336
+
4337
+ if not image_files:
4338
+ raise ValueError(f"No image files found in {data_dir}")
4339
+
4340
+ # Validate model/method compatibility
4341
+ is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
4342
+
4343
+ if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
4344
+ raise ValueError(
4345
+ f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
4346
+ f"Solutions:\n"
4347
+ f" 1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
4348
+ f" 2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
4349
+ f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
4350
+ f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
4351
+ )
4352
+
4353
+ print(f"Found {len(image_files)} images in {data_dir}")
4354
+ print(f"Starting self-supervised pretraining with {method} method...")
4355
+ print(f"Model: {model}")
4356
+
4357
+ # Create output directory
4358
+ os.makedirs(output_dir, exist_ok=True)
4359
+
4360
+ # Detect if running in notebook environment and set appropriate configuration
4361
+ def is_notebook():
4362
+ try:
4363
+ from IPython import get_ipython
4364
+
4365
+ if get_ipython() is not None:
4366
+ return True
4367
+ except (ImportError, NameError):
4368
+ pass
4369
+ return False
4370
+
4371
+ # Force single-device training in notebooks to avoid DDP strategy issues
4372
+ if is_notebook():
4373
+ # Only override if not explicitly set by user
4374
+ if "accelerator" not in kwargs:
4375
+ # Use CPU in notebooks to avoid DDP incompatibility
4376
+ # Users can still override by passing accelerator='gpu'
4377
+ kwargs["accelerator"] = "cpu"
4378
+ if "devices" not in kwargs:
4379
+ kwargs["devices"] = 1 # Force single device
4380
+
4381
+ # Train the model using Lightly Train
4382
+ lightly_train.train(
4383
+ out=output_dir,
4384
+ data=data_dir,
4385
+ model=model,
4386
+ method=method,
4387
+ epochs=epochs,
4388
+ batch_size=batch_size,
4389
+ **kwargs,
4390
+ )
4391
+
4392
+ # Return path to the exported model
4393
+ exported_model_path = os.path.join(
4394
+ output_dir, "exported_models", "exported_last.pt"
4395
+ )
4396
+
4397
+ if os.path.exists(exported_model_path):
4398
+ print(
4399
+ f"Model training completed. Exported model saved to: {exported_model_path}"
4400
+ )
4401
+ return exported_model_path
4402
+ else:
4403
+ # Check for alternative export paths
4404
+ possible_paths = [
4405
+ os.path.join(output_dir, "exported_models", "exported_best.pt"),
4406
+ os.path.join(output_dir, "checkpoints", "last.ckpt"),
4407
+ ]
4408
+
4409
+ for path in possible_paths:
4410
+ if os.path.exists(path):
4411
+ print(f"Model training completed. Exported model saved to: {path}")
4412
+ return path
4413
+
4414
+ print(f"Model training completed. Output saved to: {output_dir}")
4415
+ return output_dir
4416
+
4417
+
4418
+ def load_lightly_pretrained_model(
4419
+ model_path: str,
4420
+ model_architecture: str = "torchvision/resnet50",
4421
+ device: str = None,
4422
+ ) -> torch.nn.Module:
4423
+ """
4424
+ Load a pretrained model from Lightly Train.
4425
+
4426
+ Args:
4427
+ model_path (str): Path to the pretrained model file (.pt format).
4428
+ model_architecture (str): Architecture of the model to load.
4429
+ Default is "torchvision/resnet50".
4430
+ device (str): Device to load the model on. If None, uses CPU.
4431
+
4432
+ Returns:
4433
+ torch.nn.Module: Loaded pretrained model ready for fine-tuning.
4434
+
4435
+ Raises:
4436
+ FileNotFoundError: If model_path does not exist.
4437
+ ImportError: If required libraries are not available.
4438
+
4439
+ Example:
4440
+ >>> model = load_lightly_pretrained_model(
4441
+ ... model_path="path/to/pretrained_model.pt",
4442
+ ... model_architecture="torchvision/resnet50",
4443
+ ... device="cuda"
4444
+ ... )
4445
+ >>> # Fine-tune the model with your existing training pipeline
4446
+ """
4447
+ if not os.path.exists(model_path):
4448
+ raise FileNotFoundError(f"Model file not found: {model_path}")
4449
+
4450
+ print(f"Loading pretrained model from: {model_path}")
4451
+
4452
+ # Load the model based on architecture
4453
+ if model_architecture.startswith("torchvision/"):
4454
+ model_name = model_architecture.replace("torchvision/", "")
4455
+
4456
+ # Import the model from torchvision
4457
+ if hasattr(torchvision.models, model_name):
4458
+ model = getattr(torchvision.models, model_name)()
4459
+ else:
4460
+ raise ValueError(f"Unknown torchvision model: {model_name}")
4461
+
4462
+ elif model_architecture.startswith("timm/"):
4463
+ try:
4464
+ import timm
4465
+
4466
+ model_name = model_architecture.replace("timm/", "")
4467
+ model = timm.create_model(model_name)
4468
+ except ImportError:
4469
+ raise ImportError(
4470
+ "timm is required for TIMM models. Install with: pip install timm"
4471
+ )
4472
+
4473
+ else:
4474
+ # For other architectures, try to import from torchvision as default
4475
+ try:
4476
+ model = getattr(torchvision.models, model_architecture)()
4477
+ except AttributeError:
4478
+ raise ValueError(f"Unsupported model architecture: {model_architecture}")
4479
+
4480
+ # Load the pretrained weights
4481
+ try:
4482
+ state_dict = torch.load(model_path, map_location=device, weights_only=True)
4483
+ except TypeError:
4484
+ # For backward compatibility with older PyTorch versions
4485
+ state_dict = torch.load(model_path, map_location=device)
4486
+ model.load_state_dict(state_dict)
4487
+
4488
+ print(f"Successfully loaded pretrained model: {model_architecture}")
4489
+ return model
4490
+
4491
+
4492
+ def lightly_embed_images(
4493
+ data_dir: str,
4494
+ model_path: str,
4495
+ output_path: str,
4496
+ model_architecture: str = None, # Deprecated, kept for backwards compatibility
4497
+ batch_size: int = 64,
4498
+ **kwargs: Any,
4499
+ ) -> str:
4500
+ """
4501
+ Generate embeddings for images using a Lightly Train pretrained model.
4502
+
4503
+ Args:
4504
+ data_dir (str): Directory containing images to embed.
4505
+ model_path (str): Path to the pretrained model checkpoint file (.ckpt).
4506
+ output_path (str): Path to save the embeddings (as .pt file).
4507
+ model_architecture (str): Architecture of the pretrained model (deprecated,
4508
+ kept for backwards compatibility but not used). The model architecture
4509
+ is automatically loaded from the checkpoint.
4510
+ batch_size (int): Batch size for embedding generation. Default is 64.
4511
+ **kwargs: Additional arguments passed to lightly_train.embed().
4512
+ Supported kwargs include: image_size, num_workers, accelerator, etc.
4513
+
4514
+ Returns:
4515
+ str: Path to the saved embeddings file.
4516
+
4517
+ Raises:
4518
+ ImportError: If lightly-train is not installed.
4519
+ FileNotFoundError: If data_dir or model_path does not exist.
4520
+
4521
+ Note:
4522
+ The model_path should point to a .ckpt file from the training output,
4523
+ typically located at: output_dir/checkpoints/last.ckpt
4524
+
4525
+ Example:
4526
+ >>> embeddings_path = lightly_embed_images(
4527
+ ... data_dir="path/to/images",
4528
+ ... model_path="output_dir/checkpoints/last.ckpt",
4529
+ ... output_path="embeddings.pt",
4530
+ ... batch_size=32
4531
+ ... )
4532
+ >>> print(f"Embeddings saved to: {embeddings_path}")
4533
+ """
4534
+ if not LIGHTLY_TRAIN_AVAILABLE:
4535
+ raise ImportError(
4536
+ "lightly-train is not installed. Please install it with: "
4537
+ "pip install lightly-train"
4538
+ )
4539
+
4540
+ if not os.path.exists(data_dir):
4541
+ raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
4542
+
4543
+ if not os.path.exists(model_path):
4544
+ raise FileNotFoundError(f"Model file does not exist: {model_path}")
4545
+
4546
+ print(f"Generating embeddings for images in: {data_dir}")
4547
+ print(f"Using pretrained model: {model_path}")
4548
+
4549
+ output_dir = os.path.dirname(output_path)
4550
+ if output_dir:
4551
+ os.makedirs(output_dir, exist_ok=True)
4552
+
4553
+ # Generate embeddings using Lightly Train
4554
+ # Note: model_architecture is not used - it's inferred from the checkpoint
4555
+ lightly_train.embed(
4556
+ out=output_path,
4557
+ data=data_dir,
4558
+ checkpoint=model_path,
4559
+ batch_size=batch_size,
4560
+ **kwargs,
4561
+ )
4562
+
4563
+ print(f"Embeddings saved to: {output_path}")
4564
+ return output_path
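
The three Lightly Train helpers are meant to be used together: pretrain a backbone on unlabeled imagery, reload the exported weights for supervised fine-tuning, and optionally embed an image collection with the training checkpoint. A sketch assuming lightly-train is installed and the directories exist:

    from geoai.train import (
        lightly_train_model,
        load_lightly_pretrained_model,
        lightly_embed_images,
    )

    # 1. Self-supervised pretraining (simclr pairs with CNN backbones such as ResNet)
    exported = lightly_train_model(
        data_dir="unlabeled_tiles",
        output_dir="ssl_output",
        model="torchvision/resnet50",
        method="simclr",
        epochs=50,
        batch_size=64,
    )

    # 2. Reload the exported weights as a torchvision model for fine-tuning
    backbone = load_lightly_pretrained_model(
        model_path=exported,
        model_architecture="torchvision/resnet50",
        device="cpu",
    )

    # 3. Embed images with the training checkpoint (.ckpt, not the exported .pt)
    embeddings_path = lightly_embed_images(
        data_dir="unlabeled_tiles",
        model_path="ssl_output/checkpoints/last.ckpt",
        output_path="ssl_output/embeddings.pt",
        batch_size=32,
    )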