geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -34,6 +34,97 @@ try:
 except ImportError:
     SMP_AVAILABLE = False
 
+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
+
+def parse_coco_annotations(
+    coco_json_path: str, images_dir: str, labels_dir: str
+) -> Tuple[List[str], List[str]]:
+    """
+    Parse COCO format annotations and return lists of image and label paths.
+
+    Args:
+        coco_json_path (str): Path to COCO annotations JSON file (instances.json).
+        images_dir (str): Directory containing image files.
+        labels_dir (str): Directory containing label mask files.
+
+    Returns:
+        Tuple[List[str], List[str]]: Lists of image paths and corresponding label paths.
+    """
+    import json
+
+    with open(coco_json_path, "r") as f:
+        coco_data = json.load(f)
+
+    # Create mapping from image_id to filename
+    image_files = []
+    label_files = []
+
+    for img_info in coco_data["images"]:
+        img_filename = img_info["file_name"]
+        img_path = os.path.join(images_dir, img_filename)
+
+        # Derive label filename (same as image filename)
+        label_path = os.path.join(labels_dir, img_filename)
+
+        if os.path.exists(img_path) and os.path.exists(label_path):
+            image_files.append(img_path)
+            label_files.append(label_path)
+
+    return image_files, label_files
+
+
+def parse_yolo_annotations(
+    data_dir: str, images_subdir: str = "images", labels_subdir: str = "labels"
+) -> Tuple[List[str], List[str]]:
+    """
+    Parse YOLO format annotations and return lists of image and label paths.
+
+    YOLO format structure:
+    - data_dir/images/: Contains image files (.tif, .png, .jpg)
+    - data_dir/labels/: Contains label masks (.tif, .png) and YOLO .txt files
+    - data_dir/classes.txt: Class names (one per line)
+
+    Args:
+        data_dir (str): Root directory containing YOLO-format data.
+        images_subdir (str): Subdirectory name for images. Defaults to 'images'.
+        labels_subdir (str): Subdirectory name for labels. Defaults to 'labels'.
+
+    Returns:
+        Tuple[List[str], List[str]]: Lists of image paths and corresponding label paths.
+    """
+    images_dir = os.path.join(data_dir, images_subdir)
+    labels_dir = os.path.join(data_dir, labels_subdir)
+
+    if not os.path.exists(images_dir):
+        raise FileNotFoundError(f"Images directory not found: {images_dir}")
+    if not os.path.exists(labels_dir):
+        raise FileNotFoundError(f"Labels directory not found: {labels_dir}")
+
+    # Get all image files
+    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    image_files = []
+    label_files = []
+
+    for img_file in os.listdir(images_dir):
+        if img_file.lower().endswith(image_extensions):
+            img_path = os.path.join(images_dir, img_file)
+
+            # Find corresponding label mask (same filename)
+            label_path = os.path.join(labels_dir, img_file)
+
+            if os.path.exists(label_path):
+                image_files.append(img_path)
+                label_files.append(label_path)
+
+    return sorted(image_files), sorted(label_files)
+
 
 def get_instance_segmentation_model(
     num_classes: int = 2, num_channels: int = 3, pretrained: bool = True
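Note: a minimal sketch of how the two new parsers are meant to be called. The directory layout below is illustrative, not something shipped with the package:

    # Hypothetical YOLO layout: data/yolo_site/images/*.tif and data/yolo_site/labels/*.tif
    images, labels = parse_yolo_annotations("data/yolo_site")
    print(f"{len(images)} image/label pairs")

    # Hypothetical COCO layout: annotations JSON plus parallel images/ and labels/ dirs
    images, labels = parse_coco_annotations(
        "data/coco_site/annotations/instances.json",
        "data/coco_site/images",
        "data/coco_site/labels",
    )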
@@ -617,6 +708,7 @@ def train_MaskRCNN_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     num_channels: int = 3,
     model: Optional[torch.nn.Module] = None,
     pretrained: bool = True,
@@ -640,9 +732,17 @@ def train_MaskRCNN_model(
     the backbone or to continue training from a specific checkpoint.
 
     Args:
-        images_dir (str): Directory containing image GeoTIFF files.
-        labels_dir (str): Directory containing label GeoTIFF files.
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         num_channels (int, optional): Number of input channels. If None, auto-detected.
             Defaults to 3.
         model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
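Note: a hedged sketch of the three input_format modes (paths illustrative; remaining arguments left at their defaults):

    # 'directory' (default): separate image and label folders
    train_MaskRCNN_model("tiles/images", "tiles/labels", "runs/maskrcnn")

    # 'coco': labels_dir is the path to instances.json; per the implementation below,
    # the mask directory is derived two levels up from the JSON, in <root>/labels
    train_MaskRCNN_model(
        "coco_site/images",
        "coco_site/annotations/instances.json",
        "runs/maskrcnn",
        input_format="coco",
    )

    # 'yolo': images_dir is the dataset root; labels_dir is ignored
    train_MaskRCNN_model("yolo_site", "", "runs/maskrcnn", input_format="yolo")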
@@ -688,45 +788,63 @@ def train_MaskRCNN_model(
     device = get_device()
     print(f"Using device: {device}")
 
-    # Get all image and label files
-    # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
-    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    # Get all image and label files based on input format
+    if input_format.lower() == "coco":
+        # Parse COCO format annotations
+        if verbose:
+            print(f"Loading COCO format annotations from {labels_dir}")
+        # For COCO format, labels_dir is path to instances.json
+        # Labels are typically in a "labels" directory parallel to "annotations"
+        coco_root = os.path.dirname(os.path.dirname(labels_dir))  # Go up two levels
+        labels_directory = os.path.join(coco_root, "labels")
+        image_files, label_files = parse_coco_annotations(
+            labels_dir, images_dir, labels_directory
+        )
+    elif input_format.lower() == "yolo":
+        # Parse YOLO format annotations
+        if verbose:
+            print(f"Loading YOLO format data from {images_dir}")
+        image_files, label_files = parse_yolo_annotations(images_dir)
+    else:
+        # Default: directory format
+        # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
+        image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+        label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+
+        image_files = sorted(
+            [
+                os.path.join(images_dir, f)
+                for f in os.listdir(images_dir)
+                if f.lower().endswith(image_extensions)
+            ]
+        )
+        label_files = sorted(
+            [
+                os.path.join(labels_dir, f)
+                for f in os.listdir(labels_dir)
+                if f.lower().endswith(label_extensions)
+            ]
+        )
 
-    image_files = sorted(
-        [
-            os.path.join(images_dir, f)
-            for f in os.listdir(images_dir)
-            if f.lower().endswith(image_extensions)
-        ]
-    )
-    label_files = sorted(
-        [
-            os.path.join(labels_dir, f)
-            for f in os.listdir(labels_dir)
-            if f.lower().endswith(label_extensions)
-        ]
-    )
+        # Ensure matching files
+        if len(image_files) != len(label_files):
+            print("Warning: Number of image files and label files don't match!")
+            # Find matching files by basename
+            basenames = [os.path.basename(f) for f in image_files]
+            label_files = [
+                os.path.join(labels_dir, os.path.basename(f))
+                for f in image_files
+                if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+            ]
+            image_files = [
+                f
+                for f, b in zip(image_files, basenames)
+                if os.path.exists(os.path.join(labels_dir, b))
+            ]
+            print(f"Using {len(image_files)} matching files")
 
     print(f"Found {len(image_files)} image files and {len(label_files)} label files")
 
-    # Ensure matching files
-    if len(image_files) != len(label_files):
-        print("Warning: Number of image files and label files don't match!")
-        # Find matching files by basename
-        basenames = [os.path.basename(f) for f in image_files]
-        label_files = [
-            os.path.join(labels_dir, os.path.basename(f))
-            for f in image_files
-            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
-        ]
-        image_files = [
-            f
-            for f, b in zip(image_files, basenames)
-            if os.path.exists(os.path.join(labels_dir, b))
-        ]
-        print(f"Using {len(image_files)} matching files")
-
     # Split data into train and validation sets
     train_imgs, val_imgs, train_labels, val_labels = train_test_split(
         image_files, label_files, test_size=val_split, random_state=seed
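Note: the basename-matching fallback above keeps only images whose filename also exists under labels_dir. A standalone illustration of that filter (a set stands in for the file system):

    import os

    image_files = ["img/a.tif", "img/b.tif", "img/c.tif"]
    labels_dir = "lbl"
    on_disk = {os.path.join(labels_dir, n) for n in ("a.tif", "c.tif")}  # pretend only these exist

    matched = [
        f for f in image_files
        if os.path.join(labels_dir, os.path.basename(f)) in on_disk
    ]
    print(matched)  # ['img/a.tif', 'img/c.tif']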
@@ -1915,14 +2033,14 @@ def get_smp_model(
     )
 
 
-def dice_coefficient(
+def f1_score(
     pred: torch.Tensor,
     target: torch.Tensor,
     smooth: float = 1e-6,
     num_classes: Optional[int] = None,
 ) -> float:
     """
-    Calculate Dice coefficient for segmentation (binary or multi-class).
+    Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).
 
     Args:
         pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -1931,7 +2049,7 @@ def dice_coefficient(
         num_classes (int, optional): Number of classes. If None, auto-detected.
 
     Returns:
-        float: Mean Dice coefficient across all classes.
+        float: Mean F1 score across all classes.
     """
     # Convert predictions to class predictions
     if pred.dim() == 3:  # [C, H, W] format
@@ -1946,8 +2064,8 @@ def dice_coefficient(
     if num_classes is None:
         num_classes = max(pred_classes.max().item(), target.max().item()) + 1
 
-    # Calculate Dice for each class and average
-    dice_scores = []
+    # Calculate F1 score for each class and average
+    f1_scores = []
     for class_id in range(num_classes):
         pred_class = (pred_classes == class_id).float()
         target_class = (target == class_id).float()
@@ -1956,10 +2074,10 @@ def dice_coefficient(
         union = pred_class.sum() + target_class.sum()
 
         if union > 0:
-            dice = (2.0 * intersection + smooth) / (union + smooth)
-            dice_scores.append(dice.item())
+            f1 = (2.0 * intersection + smooth) / (union + smooth)
+            f1_scores.append(f1.item())
 
-    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
 
 
 def iou_coefficient(
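Note: the rename is cosmetic; the computed quantity is still 2·|A∩B| / (|A|+|B|), i.e. the Dice coefficient, which for per-class binary masks equals the F1 score. A toy check (values chosen here, not from the package's tests):

    import torch

    target = torch.tensor([[0, 1], [1, 1]])
    pred = target.clone()          # perfect prediction in [H, W] format
    print(f1_score(pred, target))  # ~1.0 for both classes, so the mean is ~1.0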
@@ -2009,6 +2127,108 @@ def iou_coefficient(
     return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
 
 
+def precision_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate precision score for segmentation (binary or multi-class).
+
+    Precision = TP / (TP + FP), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FP (False Positives): Incorrectly predicted positive pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean precision score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate precision for each class and average
+    precision_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        predicted_positives = pred_class.sum()
+
+        if predicted_positives > 0:
+            precision = (true_positives + smooth) / (predicted_positives + smooth)
+            precision_scores.append(precision.item())
+
+    return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
+
+
+def recall_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
+
+    Recall = TP / (TP + FN), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FN (False Negatives): Incorrectly predicted negative pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean recall score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate recall for each class and average
+    recall_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        actual_positives = target_class.sum()
+
+        if actual_positives > 0:
+            recall = (true_positives + smooth) / (actual_positives + smooth)
+            recall_scores.append(recall.item())
+
+    return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
+
+
 def train_semantic_one_epoch(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
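Note: a quick sanity sketch for the two new metrics (toy tensors, illustrative only). The prediction adds one false positive for class 1, so precision drops while class-1 recall stays perfect:

    import torch

    target = torch.tensor([[0, 0], [0, 1]])
    pred = torch.tensor([[0, 1], [0, 1]])  # one extra class-1 pixel

    # per-class precision: class 0 -> 1.0, class 1 -> 0.5; mean = 0.75
    print(precision_score(pred, target, num_classes=2))
    # per-class recall: class 0 -> 2/3, class 1 -> 1.0; mean ~ 0.83
    print(recall_score(pred, target, num_classes=2))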
@@ -2090,13 +2310,15 @@ def evaluate_semantic(
         num_classes (int): Number of classes for evaluation metrics.
 
     Returns:
-        dict: Evaluation metrics including loss, IoU, and Dice.
+        dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
     """
     model.eval()
 
     total_loss = 0
-    dice_scores = []
+    f1_scores = []
     iou_scores = []
+    precision_scores = []
+    recall_scores = []
     num_batches = len(data_loader)
 
     with torch.no_grad():
@@ -2112,23 +2334,38 @@ def evaluate_semantic(
 
             # Calculate metrics for each sample in the batch
             for pred, target in zip(outputs, targets):
-                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                f1 = f1_score(pred, target, num_classes=num_classes)
                 iou = iou_coefficient(pred, target, num_classes=num_classes)
-                dice_scores.append(dice)
+                precision = precision_score(pred, target, num_classes=num_classes)
+                recall = recall_score(pred, target, num_classes=num_classes)
+                f1_scores.append(f1)
                 iou_scores.append(iou)
+                precision_scores.append(precision)
+                recall_scores.append(recall)
 
     # Calculate metrics
     avg_loss = total_loss / num_batches
-    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
     avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
-
-    return {"loss": avg_loss, "Dice": avg_dice, "IoU": avg_iou}
+    avg_precision = (
+        sum(precision_scores) / len(precision_scores) if precision_scores else 0
+    )
+    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
+
+    return {
+        "loss": avg_loss,
+        "F1": avg_f1,
+        "IoU": avg_iou,
+        "Precision": avg_precision,
+        "Recall": avg_recall,
+    }
 
 
 def train_segmentation_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     architecture: str = "unet",
     encoder_name: str = "resnet34",
     encoder_weights: Optional[str] = "imagenet",
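Note: a consumer of the extended metrics dict might read it like this. The model, loader, criterion, and device placeholders are assumptions, and the argument order is inferred from the docstring rather than the full signature:

    metrics = evaluate_semantic(model, val_loader, criterion, device, num_classes=2)
    print(
        f"IoU={metrics['IoU']:.3f} F1={metrics['F1']:.3f} "
        f"P={metrics['Precision']:.3f} R={metrics['Recall']:.3f}"
    )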
@@ -2150,6 +2387,7 @@ def train_segmentation_model(
     target_size: Optional[Tuple[int, int]] = None,
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
+    early_stopping_patience: Optional[int] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2160,9 +2398,17 @@ def train_segmentation_model(
     this approach treats the task as pixel-level binary classification.
 
     Args:
-        images_dir (str): Directory containing image GeoTIFF files.
-        labels_dir (str): Directory containing label GeoTIFF files.
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         architecture (str): Model architecture ('unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
             'pspnet', 'linknet', 'manet'). Defaults to 'unet'.
         encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'resnet50', 'efficientnet-b0').
@@ -2194,6 +2440,8 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+        early_stopping_patience (int, optional): Number of epochs with no improvement after which
+            training will be stopped. If None, early stopping is disabled. Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
 
     Returns:
         None: Model weights are saved to output_dir.
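Note: a hedged sketch of enabling the new early stopping (paths illustrative; everything else left at its defaults). Improvement is measured on validation IoU, per the implementation further down in this diff:

    train_segmentation_model(
        images_dir="tiles/images",
        labels_dir="tiles/labels",
        output_dir="runs/unet",
        early_stopping_patience=10,  # stop after 10 epochs without a new best IoU
    )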
@@ -2225,45 +2473,63 @@ def train_segmentation_model(
     device = get_device()
     print(f"Using device: {device}")
 
-    # Get all image and label files
-    # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
-    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    # Get all image and label files based on input format
+    if input_format.lower() == "coco":
+        # Parse COCO format annotations
+        if verbose:
+            print(f"Loading COCO format annotations from {labels_dir}")
+        # For COCO format, labels_dir is path to instances.json
+        # Labels are typically in a "labels" directory parallel to "annotations"
+        coco_root = os.path.dirname(os.path.dirname(labels_dir))  # Go up two levels
+        labels_directory = os.path.join(coco_root, "labels")
+        image_files, label_files = parse_coco_annotations(
+            labels_dir, images_dir, labels_directory
+        )
+    elif input_format.lower() == "yolo":
+        # Parse YOLO format annotations
+        if verbose:
+            print(f"Loading YOLO format data from {images_dir}")
+        image_files, label_files = parse_yolo_annotations(images_dir)
+    else:
+        # Default: directory format
+        # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
+        image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+        label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+
+        image_files = sorted(
+            [
+                os.path.join(images_dir, f)
+                for f in os.listdir(images_dir)
+                if f.lower().endswith(image_extensions)
+            ]
+        )
+        label_files = sorted(
+            [
+                os.path.join(labels_dir, f)
+                for f in os.listdir(labels_dir)
+                if f.lower().endswith(label_extensions)
+            ]
+        )
 
-    image_files = sorted(
-        [
-            os.path.join(images_dir, f)
-            for f in os.listdir(images_dir)
-            if f.lower().endswith(image_extensions)
-        ]
-    )
-    label_files = sorted(
-        [
-            os.path.join(labels_dir, f)
-            for f in os.listdir(labels_dir)
-            if f.lower().endswith(label_extensions)
-        ]
-    )
+        # Ensure matching files
+        if len(image_files) != len(label_files):
+            print("Warning: Number of image files and label files don't match!")
+            # Find matching files by basename
+            basenames = [os.path.basename(f) for f in image_files]
+            label_files = [
+                os.path.join(labels_dir, os.path.basename(f))
+                for f in image_files
+                if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+            ]
+            image_files = [
+                f
+                for f, b in zip(image_files, basenames)
+                if os.path.exists(os.path.join(labels_dir, b))
+            ]
+            print(f"Using {len(image_files)} matching files")
 
     print(f"Found {len(image_files)} image files and {len(label_files)} label files")
 
-    # Ensure matching files
-    if len(image_files) != len(label_files):
-        print("Warning: Number of image files and label files don't match!")
-        # Find matching files by basename
-        basenames = [os.path.basename(f) for f in image_files]
-        label_files = [
-            os.path.join(labels_dir, os.path.basename(f))
-            for f in image_files
-            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
-        ]
-        image_files = [
-            f
-            for f, b in zip(image_files, basenames)
-            if os.path.exists(os.path.join(labels_dir, b))
-        ]
-        print(f"Using {len(image_files)} matching files")
-
     if len(image_files) == 0:
         raise FileNotFoundError("No matching image and label files found")
 
@@ -2405,7 +2671,7 @@ def train_segmentation_model(
         print(f"Using {torch.cuda.device_count()} GPUs for training")
         model = torch.nn.DataParallel(model)
 
-    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use F1Loss)
    criterion = torch.nn.CrossEntropyLoss()
 
     # Set up optimizer
@@ -2423,8 +2689,11 @@ def train_segmentation_model(
     train_losses = []
     val_losses = []
     val_ious = []
-    val_dices = []
+    val_f1s = []
+    val_precisions = []
+    val_recalls = []
     start_epoch = 0
+    epochs_without_improvement = 0
 
     # Load checkpoint if provided
     if checkpoint_path is not None:
@@ -2459,8 +2728,15 @@ def train_segmentation_model(
         val_losses = checkpoint["val_losses"]
         if "val_ious" in checkpoint:
             val_ious = checkpoint["val_ious"]
-        if "val_dices" in checkpoint:
-            val_dices = checkpoint["val_dices"]
+        if "val_f1s" in checkpoint:
+            val_f1s = checkpoint["val_f1s"]
+        # Also check for old val_dices format for backward compatibility
+        elif "val_dices" in checkpoint:
+            val_f1s = checkpoint["val_dices"]
+        if "val_precisions" in checkpoint:
+            val_precisions = checkpoint["val_precisions"]
+        if "val_recalls" in checkpoint:
+            val_recalls = checkpoint["val_recalls"]
 
         print(f"Resuming training from epoch {start_epoch}")
         print(f"Previous best IoU: {best_iou:.4f}")
@@ -2500,7 +2776,9 @@ def train_segmentation_model(
         )
         val_losses.append(eval_metrics["loss"])
         val_ious.append(eval_metrics["IoU"])
-        val_dices.append(eval_metrics["Dice"])
+        val_f1s.append(eval_metrics["F1"])
+        val_precisions.append(eval_metrics["Precision"])
+        val_recalls.append(eval_metrics["Recall"])
 
         # Update learning rate
         lr_scheduler.step(eval_metrics["loss"])
@@ -2511,14 +2789,28 @@ def train_segmentation_model(
             f"Train Loss: {train_loss:.4f}, "
             f"Val Loss: {eval_metrics['loss']:.4f}, "
             f"Val IoU: {eval_metrics['IoU']:.4f}, "
-            f"Val Dice: {eval_metrics['Dice']:.4f}"
+            f"Val F1: {eval_metrics['F1']:.4f}, "
+            f"Val Precision: {eval_metrics['Precision']:.4f}, "
+            f"Val Recall: {eval_metrics['Recall']:.4f}"
         )
 
-        # Save best model
+        # Save best model and check for early stopping
         if eval_metrics["IoU"] > best_iou:
             best_iou = eval_metrics["IoU"]
+            epochs_without_improvement = 0
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+        else:
+            epochs_without_improvement += 1
+            if (
+                early_stopping_patience is not None
+                and epochs_without_improvement >= early_stopping_patience
+            ):
+                print(
+                    f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
+                )
+                print(f"Best validation IoU: {best_iou:.4f}")
+                break
 
         # Save checkpoint every 10 epochs (if not save_best_only)
         if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
@@ -2536,7 +2828,9 @@ def train_segmentation_model(
                     "train_losses": train_losses,
                     "val_losses": val_losses,
                     "val_ious": val_ious,
-                    "val_dices": val_dices,
+                    "val_f1s": val_f1s,
+                    "val_precisions": val_precisions,
+                    "val_recalls": val_recalls,
                 },
                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
             )
@@ -2549,7 +2843,9 @@ def train_segmentation_model(
         "train_losses": train_losses,
         "val_losses": val_losses,
         "val_ious": val_ious,
-        "val_dices": val_dices,
+        "val_f1s": val_f1s,
+        "val_precisions": val_precisions,
+        "val_recalls": val_recalls,
     }
     torch.save(history, os.path.join(output_dir, "training_history.pth"))
 
@@ -2565,7 +2861,9 @@ def train_segmentation_model(
         f.write(f"Total epochs: {num_epochs}\n")
         f.write(f"Best validation IoU: {best_iou:.4f}\n")
         f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
-        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
+        f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
+        f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
         f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
 
     print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2594,10 +2892,10 @@ def train_segmentation_model(
     plt.grid(True)
 
     plt.subplot(1, 3, 3)
-    plt.plot(val_dices, label="Val Dice")
-    plt.title("Dice Score")
+    plt.plot(val_f1s, label="Val F1")
+    plt.title("F1 Score")
     plt.xlabel("Epoch")
-    plt.ylabel("Dice")
+    plt.ylabel("F1")
     plt.legend()
     plt.grid(True)
 
@@ -2627,6 +2925,7 @@ def semantic_inference_on_geotiff(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2648,6 +2947,8 @@ def semantic_inference_on_geotiff(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -2864,7 +3165,7 @@ def semantic_inference_on_geotiff(
         prob_meta = meta.copy()
         prob_meta.update({"count": num_classes, "dtype": "float32"})
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -2878,6 +3179,36 @@ def semantic_inference_on_geotiff(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = meta.copy()
+            single_band_meta.update({"count": 1, "dtype": "float32"})
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = (
+                    prob_accumulator[class_idx, valid_pixels]
+                    / count_accumulator[valid_pixels]
+                )
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -2894,6 +3225,7 @@ def semantic_inference_on_image(
     binary_output: bool = True,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2916,6 +3248,8 @@ def semantic_inference_on_image(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3194,7 +3528,7 @@ def semantic_inference_on_image(
             "transform": transform,
         }
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3205,6 +3539,39 @@ def semantic_inference_on_image(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = {
+                "driver": "GTiff",
+                "height": height,
+                "width": width,
+                "count": 1,
+                "dtype": "float32",
+                "transform": transform,
+            }
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -3222,6 +3589,7 @@ def semantic_segmentation(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -3244,11 +3612,16 @@ def semantic_segmentation(
         batch_size (int): Batch size for inference.
         device (torch.device, optional): Device to run inference on.
         probability_path (str, optional): Path to save probability map. If provided,
-            the normalized class probabilities will be saved as a multi-band raster.
+            the normalized class probabilities will be saved as a multi-band raster
+            where each band contains probabilities for each class.
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
             Must be between 0 and 1.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Files will be named like
+            "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
+            as probability_path. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3325,6 +3698,7 @@ def semantic_segmentation(
             device=device,
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3345,6 +3719,7 @@ def semantic_segmentation(
             binary_output=True,  # Convert to binary output for better visualization
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
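Note: end-user view of the new flag (a sketch; the keyword names for the first three arguments are assumed from context, not copied from the full signature). With probability_path="probs.tif" and two classes, the per-class files come out as probs_class_0.tif and probs_class_1.tif, following the os.path.splitext naming in the code above:

    semantic_segmentation(
        input_path="scene.tif",                  # assumed parameter name
        output_path="mask.tif",                  # assumed parameter name
        model_path="runs/unet/best_model.pth",   # assumed parameter name
        probability_path="probs.tif",
        save_class_probabilities=True,           # also writes probs_class_0.tif, probs_class_1.tif, ...
    )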
@@ -3527,6 +3902,7 @@ def train_instance_segmentation_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     num_classes: int = 2,
     num_channels: int = 3,
     batch_size: int = 4,
@@ -3545,9 +3921,17 @@ def train_instance_segmentation_model(
     This is a wrapper function for train_MaskRCNN_model with clearer naming.
 
     Args:
-        images_dir (str): Directory containing image GeoTIFF files.
-        labels_dir (str): Directory containing label GeoTIFF files.
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         num_classes (int): Number of classes (including background). Defaults to 2.
         num_channels (int): Number of input channels. Defaults to 3.
         batch_size (int): Batch size for training. Defaults to 4.
@@ -3572,6 +3956,7 @@ def train_instance_segmentation_model(
         images_dir=images_dir,
         labels_dir=labels_dir,
         output_dir=output_dir,
+        input_format=input_format,
        num_channels=num_channels,
         model=model,
         batch_size=batch_size,
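Note: the wrapper simply forwards input_format to train_MaskRCNN_model, so the same three modes apply; e.g. a hedged YOLO-layout call (paths illustrative):

    train_instance_segmentation_model(
        images_dir="yolo_site",    # root containing images/ and labels/
        labels_dir="",             # unused for 'yolo'
        output_dir="runs/maskrcnn",
        input_format="yolo",
    )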
@@ -3756,3 +4141,307 @@ def instance_segmentation_batch(
             continue
 
     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def lightly_train_model(
+    data_dir: str,
+    output_dir: str,
+    model: str = "torchvision/resnet50",
+    method: str = "dinov2_distillation",
+    epochs: int = 100,
+    batch_size: int = 64,
+    learning_rate: float = 1e-4,
+    **kwargs: Any,
+) -> str:
+    """
+    Train a model using Lightly Train for self-supervised pretraining.
+
+    Args:
+        data_dir (str): Directory containing unlabeled images for training.
+        output_dir (str): Directory to save training outputs and model checkpoints.
+        model (str): Model architecture to train. Supports models from torchvision,
+            timm, ultralytics, etc. Default is "torchvision/resnet50".
+        method (str): Self-supervised learning method. Options include:
+            - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
+            - "dino": Works with both CNNs and ViTs
+            - "dinov2": Requires ViT models only
+            - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
+            Default is "dinov2_distillation".
+        epochs (int): Number of training epochs. Default is 100.
+        batch_size (int): Batch size for training. Default is 64.
+        learning_rate (float): Learning rate for training. Default is 1e-4.
+        **kwargs: Additional arguments passed to lightly_train.train().
+
+    Returns:
+        str: Path to the exported model file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        ValueError: If data_dir does not exist, is empty, or incompatible model/method.
+
+    Note:
+        Model/Method compatibility:
+        - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
+        - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
+
+    Example:
+        >>> # For CNN models (ResNet, EfficientNet)
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="torchvision/resnet50",
+        ...     method="simclr",  # Use simclr for CNNs
+        ...     epochs=50
+        ... )
+        >>> # For ViT models
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="timm/vit_base_patch16_224",
+        ...     method="dinov2",  # dinov2 requires ViT
+        ...     epochs=50
+        ... )
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise ValueError(f"Data directory does not exist: {data_dir}")
+
+    # Check if data directory contains images
+    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
+
+    if not image_files:
+        raise ValueError(f"No image files found in {data_dir}")
+
+    # Validate model/method compatibility
+    is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
+
+    if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
+        raise ValueError(
+            f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
+            f"Solutions:\n"
+            f"  1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
+            f"  2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
+            f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
+            f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
+        )
+
+    print(f"Found {len(image_files)} images in {data_dir}")
+    print(f"Starting self-supervised pretraining with {method} method...")
+    print(f"Model: {model}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Detect if running in notebook environment and set appropriate configuration
+    def is_notebook():
+        try:
+            from IPython import get_ipython
+
+            if get_ipython() is not None:
+                return True
+        except (ImportError, NameError):
+            pass
+        return False
+
+    # Force single-device training in notebooks to avoid DDP strategy issues
+    if is_notebook():
+        # Only override if not explicitly set by user
+        if "accelerator" not in kwargs:
+            # Use CPU in notebooks to avoid DDP incompatibility
+            # Users can still override by passing accelerator='gpu'
+            kwargs["accelerator"] = "cpu"
+        if "devices" not in kwargs:
+            kwargs["devices"] = 1  # Force single device
+
+    # Train the model using Lightly Train
+    lightly_train.train(
+        out=output_dir,
+        data=data_dir,
+        model=model,
+        method=method,
+        epochs=epochs,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    # Return path to the exported model
+    exported_model_path = os.path.join(
+        output_dir, "exported_models", "exported_last.pt"
+    )
+
+    if os.path.exists(exported_model_path):
+        print(
+            f"Model training completed. Exported model saved to: {exported_model_path}"
+        )
+        return exported_model_path
+    else:
+        # Check for alternative export paths
+        possible_paths = [
+            os.path.join(output_dir, "exported_models", "exported_best.pt"),
+            os.path.join(output_dir, "checkpoints", "last.ckpt"),
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                print(f"Model training completed. Exported model saved to: {path}")
+                return path
+
+        print(f"Model training completed. Output saved to: {output_dir}")
+        return output_dir
+
+
+def load_lightly_pretrained_model(
+    model_path: str,
+    model_architecture: str = "torchvision/resnet50",
+    device: str = None,
+) -> torch.nn.Module:
+    """
+    Load a pretrained model from Lightly Train.
+
+    Args:
+        model_path (str): Path to the pretrained model file (.pt format).
+        model_architecture (str): Architecture of the model to load.
+            Default is "torchvision/resnet50".
+        device (str): Device to load the model on. If None, uses CPU.
+
+    Returns:
+        torch.nn.Module: Loaded pretrained model ready for fine-tuning.
+
+    Raises:
+        FileNotFoundError: If model_path does not exist.
+        ImportError: If required libraries are not available.
+
+    Example:
+        >>> model = load_lightly_pretrained_model(
+        ...     model_path="path/to/pretrained_model.pt",
+        ...     model_architecture="torchvision/resnet50",
+        ...     device="cuda"
+        ... )
+        >>> # Fine-tune the model with your existing training pipeline
+    """
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    print(f"Loading pretrained model from: {model_path}")
+
+    # Load the model based on architecture
+    if model_architecture.startswith("torchvision/"):
+        model_name = model_architecture.replace("torchvision/", "")
+
+        # Import the model from torchvision
+        if hasattr(torchvision.models, model_name):
+            model = getattr(torchvision.models, model_name)()
+        else:
+            raise ValueError(f"Unknown torchvision model: {model_name}")
+
+    elif model_architecture.startswith("timm/"):
+        try:
+            import timm
+
+            model_name = model_architecture.replace("timm/", "")
+            model = timm.create_model(model_name)
+        except ImportError:
+            raise ImportError(
+                "timm is required for TIMM models. Install with: pip install timm"
+            )
+
+    else:
+        # For other architectures, try to import from torchvision as default
+        try:
+            model = getattr(torchvision.models, model_architecture)()
+        except AttributeError:
+            raise ValueError(f"Unsupported model architecture: {model_architecture}")
+
+    # Load the pretrained weights
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    except TypeError:
+        # For backward compatibility with older PyTorch versions
+        state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+
+    print(f"Successfully loaded pretrained model: {model_architecture}")
+    return model
+
+
+def lightly_embed_images(
+    data_dir: str,
+    model_path: str,
+    output_path: str,
+    model_architecture: str = None,  # Deprecated, kept for backwards compatibility
+    batch_size: int = 64,
+    **kwargs: Any,
+) -> str:
+    """
+    Generate embeddings for images using a Lightly Train pretrained model.
+
+    Args:
+        data_dir (str): Directory containing images to embed.
+        model_path (str): Path to the pretrained model checkpoint file (.ckpt).
+        output_path (str): Path to save the embeddings (as .pt file).
+        model_architecture (str): Architecture of the pretrained model (deprecated,
+            kept for backwards compatibility but not used). The model architecture
+            is automatically loaded from the checkpoint.
+        batch_size (int): Batch size for embedding generation. Default is 64.
+        **kwargs: Additional arguments passed to lightly_train.embed().
+            Supported kwargs include: image_size, num_workers, accelerator, etc.
+
+    Returns:
+        str: Path to the saved embeddings file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        FileNotFoundError: If data_dir or model_path does not exist.
+
+    Note:
+        The model_path should point to a .ckpt file from the training output,
+        typically located at: output_dir/checkpoints/last.ckpt
+
+    Example:
+        >>> embeddings_path = lightly_embed_images(
+        ...     data_dir="path/to/images",
+        ...     model_path="output_dir/checkpoints/last.ckpt",
+        ...     output_path="embeddings.pt",
+        ...     batch_size=32
+        ... )
+        >>> print(f"Embeddings saved to: {embeddings_path}")
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
+
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+    print(f"Generating embeddings for images in: {data_dir}")
+    print(f"Using pretrained model: {model_path}")
+
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Generate embeddings using Lightly Train
+    # Note: model_architecture is not used - it's inferred from the checkpoint
+    lightly_train.embed(
+        out=output_path,
+        data=data_dir,
+        checkpoint=model_path,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    print(f"Embeddings saved to: {output_path}")
+    return output_path
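Note: the three new helpers chain into a pretrain → embed → fine-tune workflow. A possible sequence (paths are illustrative, and the method choice follows the compatibility note above; how you consume the embeddings file afterwards is an assumption, not documented here):

    ckpt_path = lightly_train_model(
        data_dir="unlabeled_tiles",
        output_dir="ssl_out",
        model="torchvision/resnet50",
        method="simclr",  # CNN-compatible method
        epochs=50,
    )
    emb_path = lightly_embed_images(
        data_dir="unlabeled_tiles",
        model_path="ssl_out/checkpoints/last.ckpt",  # .ckpt, per the docstring note
        output_path="ssl_out/embeddings.pt",
    )
    backbone = load_lightly_pretrained_model(ckpt_path, "torchvision/resnet50")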