geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +21 -1
- geoai/change_detection.py +16 -6
- geoai/geoai.py +3 -0
- geoai/timm_segment.py +1097 -0
- geoai/timm_train.py +658 -0
- geoai/train.py +796 -107
- geoai/utils.py +1427 -245
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/METADATA +9 -1
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/RECORD +13 -11
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/licenses/LICENSE +1 -2
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/top_level.txt +0 -0
geoai/train.py
CHANGED
@@ -34,6 +34,97 @@ try:
 except ImportError:
     SMP_AVAILABLE = False

+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
+
+def parse_coco_annotations(
+    coco_json_path: str, images_dir: str, labels_dir: str
+) -> Tuple[List[str], List[str]]:
+    """
+    Parse COCO format annotations and return lists of image and label paths.
+
+    Args:
+        coco_json_path (str): Path to COCO annotations JSON file (instances.json).
+        images_dir (str): Directory containing image files.
+        labels_dir (str): Directory containing label mask files.
+
+    Returns:
+        Tuple[List[str], List[str]]: Lists of image paths and corresponding label paths.
+    """
+    import json
+
+    with open(coco_json_path, "r") as f:
+        coco_data = json.load(f)
+
+    # Create mapping from image_id to filename
+    image_files = []
+    label_files = []
+
+    for img_info in coco_data["images"]:
+        img_filename = img_info["file_name"]
+        img_path = os.path.join(images_dir, img_filename)
+
+        # Derive label filename (same as image filename)
+        label_path = os.path.join(labels_dir, img_filename)
+
+        if os.path.exists(img_path) and os.path.exists(label_path):
+            image_files.append(img_path)
+            label_files.append(label_path)
+
+    return image_files, label_files
+
+
+def parse_yolo_annotations(
+    data_dir: str, images_subdir: str = "images", labels_subdir: str = "labels"
+) -> Tuple[List[str], List[str]]:
+    """
+    Parse YOLO format annotations and return lists of image and label paths.
+
+    YOLO format structure:
+    - data_dir/images/: Contains image files (.tif, .png, .jpg)
+    - data_dir/labels/: Contains label masks (.tif, .png) and YOLO .txt files
+    - data_dir/classes.txt: Class names (one per line)
+
+    Args:
+        data_dir (str): Root directory containing YOLO-format data.
+        images_subdir (str): Subdirectory name for images. Defaults to 'images'.
+        labels_subdir (str): Subdirectory name for labels. Defaults to 'labels'.
+
+    Returns:
+        Tuple[List[str], List[str]]: Lists of image paths and corresponding label paths.
+    """
+    images_dir = os.path.join(data_dir, images_subdir)
+    labels_dir = os.path.join(data_dir, labels_subdir)
+
+    if not os.path.exists(images_dir):
+        raise FileNotFoundError(f"Images directory not found: {images_dir}")
+    if not os.path.exists(labels_dir):
+        raise FileNotFoundError(f"Labels directory not found: {labels_dir}")
+
+    # Get all image files
+    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    image_files = []
+    label_files = []
+
+    for img_file in os.listdir(images_dir):
+        if img_file.lower().endswith(image_extensions):
+            img_path = os.path.join(images_dir, img_file)
+
+            # Find corresponding label mask (same filename)
+            label_path = os.path.join(labels_dir, img_file)
+
+            if os.path.exists(label_path):
+                image_files.append(img_path)
+                label_files.append(label_path)
+
+    return sorted(image_files), sorted(label_files)
+
+

 def get_instance_segmentation_model(
     num_classes: int = 2, num_channels: int = 3, pretrained: bool = True
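The two parser helpers added above return parallel, path-matched lists, so they can also be called on their own to sanity-check a dataset before launching training. A minimal sketch, assuming the package is installed and a YOLO-style folder layout as described in the docstring (the directory name is a placeholder):

from geoai.train import parse_yolo_annotations

# Expects data/yolo_dataset/images/ and data/yolo_dataset/labels/ with matching filenames
image_files, label_files = parse_yolo_annotations("data/yolo_dataset")
print(f"{len(image_files)} image/label pairs found")
for img, lbl in list(zip(image_files, label_files))[:3]:
    print(img, "->", lbl)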
@@ -617,6 +708,7 @@ def train_MaskRCNN_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     num_channels: int = 3,
     model: Optional[torch.nn.Module] = None,
     pretrained: bool = True,
@@ -640,9 +732,17 @@ def train_MaskRCNN_model(
     the backbone or to continue training from a specific checkpoint.

     Args:
-        images_dir (str): Directory containing image GeoTIFF files
-        labels_dir (str): Directory containing label GeoTIFF files
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         num_channels (int, optional): Number of input channels. If None, auto-detected.
             Defaults to 3.
         model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
@@ -688,45 +788,63 @@ def train_MaskRCNN_model(
     device = get_device()
     print(f"Using device: {device}")

-    # Get all image and label files
-    # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
-    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    # Get all image and label files based on input format
+    if input_format.lower() == "coco":
+        # Parse COCO format annotations
+        if verbose:
+            print(f"Loading COCO format annotations from {labels_dir}")
+        # For COCO format, labels_dir is path to instances.json
+        # Labels are typically in a "labels" directory parallel to "annotations"
+        coco_root = os.path.dirname(os.path.dirname(labels_dir))  # Go up two levels
+        labels_directory = os.path.join(coco_root, "labels")
+        image_files, label_files = parse_coco_annotations(
+            labels_dir, images_dir, labels_directory
+        )
+    elif input_format.lower() == "yolo":
+        # Parse YOLO format annotations
+        if verbose:
+            print(f"Loading YOLO format data from {images_dir}")
+        image_files, label_files = parse_yolo_annotations(images_dir)
+    else:
+        # Default: directory format
+        # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
+        image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+        label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+
+        image_files = sorted(
+            [
+                os.path.join(images_dir, f)
+                for f in os.listdir(images_dir)
+                if f.lower().endswith(image_extensions)
+            ]
+        )
+        label_files = sorted(
+            [
+                os.path.join(labels_dir, f)
+                for f in os.listdir(labels_dir)
+                if f.lower().endswith(label_extensions)
+            ]
+        )

-    image_files = sorted(
-        [
-            os.path.join(images_dir, f)
-            for f in os.listdir(images_dir)
-            if f.lower().endswith(image_extensions)
-        ]
-    )
-    label_files = sorted(
-        [
-            os.path.join(labels_dir, f)
-            for f in os.listdir(labels_dir)
-            if f.lower().endswith(label_extensions)
-        ]
-    )
+    # Ensure matching files
+    if len(image_files) != len(label_files):
+        print("Warning: Number of image files and label files don't match!")
+        # Find matching files by basename
+        basenames = [os.path.basename(f) for f in image_files]
+        label_files = [
+            os.path.join(labels_dir, os.path.basename(f))
+            for f in image_files
+            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+        ]
+        image_files = [
+            f
+            for f, b in zip(image_files, basenames)
+            if os.path.exists(os.path.join(labels_dir, b))
+        ]
+        print(f"Using {len(image_files)} matching files")

     print(f"Found {len(image_files)} image files and {len(label_files)} label files")

-    # Ensure matching files
-    if len(image_files) != len(label_files):
-        print("Warning: Number of image files and label files don't match!")
-        # Find matching files by basename
-        basenames = [os.path.basename(f) for f in image_files]
-        label_files = [
-            os.path.join(labels_dir, os.path.basename(f))
-            for f in image_files
-            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
-        ]
-        image_files = [
-            f
-            for f, b in zip(image_files, basenames)
-            if os.path.exists(os.path.join(labels_dir, b))
-        ]
-        print(f"Using {len(image_files)} matching files")
-
     # Split data into train and validation sets
     train_imgs, val_imgs, train_labels, val_labels = train_test_split(
         image_files, label_files, test_size=val_split, random_state=seed
@@ -1915,14 +2033,14 @@ def get_smp_model(
     )


-def dice_coefficient(
+def f1_score(
     pred: torch.Tensor,
     target: torch.Tensor,
     smooth: float = 1e-6,
     num_classes: Optional[int] = None,
 ) -> float:
     """
-    Calculate Dice coefficient for segmentation (binary or multi-class).
+    Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).

     Args:
         pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -1931,7 +2049,7 @@ def dice_coefficient(
         num_classes (int, optional): Number of classes. If None, auto-detected.

     Returns:
-        float: Mean Dice coefficient across all classes.
+        float: Mean F1 score across all classes.
     """
     # Convert predictions to class predictions
     if pred.dim() == 3:  # [C, H, W] format
@@ -1946,8 +2064,8 @@ def dice_coefficient(
     if num_classes is None:
         num_classes = max(pred_classes.max().item(), target.max().item()) + 1

-    # Calculate Dice coefficient for each class and average
-    dice_scores = []
+    # Calculate F1 score for each class and average
+    f1_scores = []
     for class_id in range(num_classes):
         pred_class = (pred_classes == class_id).float()
         target_class = (target == class_id).float()
@@ -1956,10 +2074,10 @@ def dice_coefficient(
         union = pred_class.sum() + target_class.sum()

         if union > 0:
-            dice = (2.0 * intersection + smooth) / (union + smooth)
-            dice_scores.append(dice.item())
+            f1 = (2.0 * intersection + smooth) / (union + smooth)
+            f1_scores.append(f1.item())

-    return sum(dice_scores) / len(dice_scores) if dice_scores else 0.0
+    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0


 def iou_coefficient(
@@ -2009,6 +2127,108 @@ def iou_coefficient(
     return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0


+def precision_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate precision score for segmentation (binary or multi-class).
+
+    Precision = TP / (TP + FP), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FP (False Positives): Incorrectly predicted positive pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean precision score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate precision for each class and average
+    precision_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        predicted_positives = pred_class.sum()
+
+        if predicted_positives > 0:
+            precision = (true_positives + smooth) / (predicted_positives + smooth)
+            precision_scores.append(precision.item())
+
+    return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
+
+
+def recall_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
+
+    Recall = TP / (TP + FN), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FN (False Negatives): Incorrectly predicted negative pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean recall score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate recall for each class and average
+    recall_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        actual_positives = target_class.sum()
+
+        if actual_positives > 0:
+            recall = (true_positives + smooth) / (actual_positives + smooth)
+            recall_scores.append(recall.item())
+
+    return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
+
+
 def train_semantic_one_epoch(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
@@ -2090,13 +2310,15 @@ def evaluate_semantic(
         num_classes (int): Number of classes for evaluation metrics.

     Returns:
-        dict: Evaluation metrics including loss, IoU, and Dice score.
+        dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
     """
     model.eval()

     total_loss = 0
-    dice_scores = []
+    f1_scores = []
     iou_scores = []
+    precision_scores = []
+    recall_scores = []
     num_batches = len(data_loader)

     with torch.no_grad():
@@ -2112,23 +2334,38 @@ def evaluate_semantic(

             # Calculate metrics for each sample in the batch
             for pred, target in zip(outputs, targets):
-                dice = dice_coefficient(pred, target, num_classes=num_classes)
+                f1 = f1_score(pred, target, num_classes=num_classes)
                 iou = iou_coefficient(pred, target, num_classes=num_classes)
-                dice_scores.append(dice)
+                precision = precision_score(pred, target, num_classes=num_classes)
+                recall = recall_score(pred, target, num_classes=num_classes)
+                f1_scores.append(f1)
                 iou_scores.append(iou)
+                precision_scores.append(precision)
+                recall_scores.append(recall)

     # Calculate metrics
     avg_loss = total_loss / num_batches
-    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
     avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
-
-    return {"loss": avg_loss, "IoU": avg_iou, "Dice": avg_dice}
+    avg_precision = (
+        sum(precision_scores) / len(precision_scores) if precision_scores else 0
+    )
+    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
+
+    return {
+        "loss": avg_loss,
+        "F1": avg_f1,
+        "IoU": avg_iou,
+        "Precision": avg_precision,
+        "Recall": avg_recall,
+    }


 def train_segmentation_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     architecture: str = "unet",
     encoder_name: str = "resnet34",
     encoder_weights: Optional[str] = "imagenet",
@@ -2150,6 +2387,7 @@ def train_segmentation_model(
     target_size: Optional[Tuple[int, int]] = None,
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
+    early_stopping_patience: Optional[int] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2160,9 +2398,17 @@ def train_segmentation_model(
     this approach treats the task as pixel-level binary classification.

     Args:
-        images_dir (str): Directory containing image GeoTIFF files
-        labels_dir (str): Directory containing label GeoTIFF files
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         architecture (str): Model architecture ('unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
             'pspnet', 'linknet', 'manet'). Defaults to 'unet'.
         encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'resnet50', 'efficientnet-b0').
@@ -2194,6 +2440,8 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+        early_stopping_patience (int, optional): Number of epochs with no improvement after which
+            training will be stopped. If None, early stopping is disabled. Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
@@ -2225,45 +2473,63 @@ def train_segmentation_model(
     device = get_device()
     print(f"Using device: {device}")

-    # Get all image and label files
-    # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
-    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-    label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+    # Get all image and label files based on input format
+    if input_format.lower() == "coco":
+        # Parse COCO format annotations
+        if verbose:
+            print(f"Loading COCO format annotations from {labels_dir}")
+        # For COCO format, labels_dir is path to instances.json
+        # Labels are typically in a "labels" directory parallel to "annotations"
+        coco_root = os.path.dirname(os.path.dirname(labels_dir))  # Go up two levels
+        labels_directory = os.path.join(coco_root, "labels")
+        image_files, label_files = parse_coco_annotations(
+            labels_dir, images_dir, labels_directory
+        )
+    elif input_format.lower() == "yolo":
+        # Parse YOLO format annotations
+        if verbose:
+            print(f"Loading YOLO format data from {images_dir}")
+        image_files, label_files = parse_yolo_annotations(images_dir)
+    else:
+        # Default: directory format
+        # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
+        image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+        label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
+
+        image_files = sorted(
+            [
+                os.path.join(images_dir, f)
+                for f in os.listdir(images_dir)
+                if f.lower().endswith(image_extensions)
+            ]
+        )
+        label_files = sorted(
+            [
+                os.path.join(labels_dir, f)
+                for f in os.listdir(labels_dir)
+                if f.lower().endswith(label_extensions)
+            ]
+        )

-    image_files = sorted(
-        [
-            os.path.join(images_dir, f)
-            for f in os.listdir(images_dir)
-            if f.lower().endswith(image_extensions)
-        ]
-    )
-    label_files = sorted(
-        [
-            os.path.join(labels_dir, f)
-            for f in os.listdir(labels_dir)
-            if f.lower().endswith(label_extensions)
-        ]
-    )
+    # Ensure matching files
+    if len(image_files) != len(label_files):
+        print("Warning: Number of image files and label files don't match!")
+        # Find matching files by basename
+        basenames = [os.path.basename(f) for f in image_files]
+        label_files = [
+            os.path.join(labels_dir, os.path.basename(f))
+            for f in image_files
+            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
+        ]
+        image_files = [
+            f
+            for f, b in zip(image_files, basenames)
+            if os.path.exists(os.path.join(labels_dir, b))
+        ]
+        print(f"Using {len(image_files)} matching files")

     print(f"Found {len(image_files)} image files and {len(label_files)} label files")

-    # Ensure matching files
-    if len(image_files) != len(label_files):
-        print("Warning: Number of image files and label files don't match!")
-        # Find matching files by basename
-        basenames = [os.path.basename(f) for f in image_files]
-        label_files = [
-            os.path.join(labels_dir, os.path.basename(f))
-            for f in image_files
-            if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
-        ]
-        image_files = [
-            f
-            for f, b in zip(image_files, basenames)
-            if os.path.exists(os.path.join(labels_dir, b))
-        ]
-        print(f"Using {len(image_files)} matching files")
-
     if len(image_files) == 0:
         raise FileNotFoundError("No matching image and label files found")

@@ -2405,7 +2671,7 @@ def train_segmentation_model(
         print(f"Using {torch.cuda.device_count()} GPUs for training")
         model = torch.nn.DataParallel(model)

-    # Set up loss function (CrossEntropyLoss for multi-class, can also use DiceLoss)
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use F1Loss)
     criterion = torch.nn.CrossEntropyLoss()

     # Set up optimizer
@@ -2423,8 +2689,11 @@ def train_segmentation_model(
     train_losses = []
     val_losses = []
     val_ious = []
-    val_dices = []
+    val_f1s = []
+    val_precisions = []
+    val_recalls = []
     start_epoch = 0
+    epochs_without_improvement = 0

     # Load checkpoint if provided
     if checkpoint_path is not None:
@@ -2459,8 +2728,15 @@ def train_segmentation_model(
             val_losses = checkpoint["val_losses"]
         if "val_ious" in checkpoint:
             val_ious = checkpoint["val_ious"]
-        if "val_dices" in checkpoint:
-            val_dices = checkpoint["val_dices"]
+        if "val_f1s" in checkpoint:
+            val_f1s = checkpoint["val_f1s"]
+        # Also check for old val_dices format for backward compatibility
+        elif "val_dices" in checkpoint:
+            val_f1s = checkpoint["val_dices"]
+        if "val_precisions" in checkpoint:
+            val_precisions = checkpoint["val_precisions"]
+        if "val_recalls" in checkpoint:
+            val_recalls = checkpoint["val_recalls"]

         print(f"Resuming training from epoch {start_epoch}")
         print(f"Previous best IoU: {best_iou:.4f}")
@@ -2500,7 +2776,9 @@ def train_segmentation_model(
         )
         val_losses.append(eval_metrics["loss"])
         val_ious.append(eval_metrics["IoU"])
-        val_dices.append(eval_metrics["Dice"])
+        val_f1s.append(eval_metrics["F1"])
+        val_precisions.append(eval_metrics["Precision"])
+        val_recalls.append(eval_metrics["Recall"])

         # Update learning rate
         lr_scheduler.step(eval_metrics["loss"])
@@ -2511,14 +2789,28 @@ def train_segmentation_model(
             f"Train Loss: {train_loss:.4f}, "
             f"Val Loss: {eval_metrics['loss']:.4f}, "
             f"Val IoU: {eval_metrics['IoU']:.4f}, "
-            f"Val Dice: {eval_metrics['Dice']:.4f}"
+            f"Val F1: {eval_metrics['F1']:.4f}, "
+            f"Val Precision: {eval_metrics['Precision']:.4f}, "
+            f"Val Recall: {eval_metrics['Recall']:.4f}"
         )

-        # Save best model
+        # Save best model and check for early stopping
         if eval_metrics["IoU"] > best_iou:
             best_iou = eval_metrics["IoU"]
+            epochs_without_improvement = 0
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+        else:
+            epochs_without_improvement += 1
+            if (
+                early_stopping_patience is not None
+                and epochs_without_improvement >= early_stopping_patience
+            ):
+                print(
+                    f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
+                )
+                print(f"Best validation IoU: {best_iou:.4f}")
+                break

         # Save checkpoint every 10 epochs (if not save_best_only)
         if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
@@ -2536,7 +2828,9 @@ def train_segmentation_model(
                     "train_losses": train_losses,
                     "val_losses": val_losses,
                     "val_ious": val_ious,
-                    "val_dices": val_dices,
+                    "val_f1s": val_f1s,
+                    "val_precisions": val_precisions,
+                    "val_recalls": val_recalls,
                 },
                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
             )
@@ -2549,7 +2843,9 @@ def train_segmentation_model(
         "train_losses": train_losses,
         "val_losses": val_losses,
         "val_ious": val_ious,
-        "val_dices": val_dices,
+        "val_f1s": val_f1s,
+        "val_precisions": val_precisions,
+        "val_recalls": val_recalls,
     }
     torch.save(history, os.path.join(output_dir, "training_history.pth"))

@@ -2565,7 +2861,9 @@ def train_segmentation_model(
         f.write(f"Total epochs: {num_epochs}\n")
         f.write(f"Best validation IoU: {best_iou:.4f}\n")
         f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
-        f.write(f"Final validation Dice: {val_dices[-1]:.4f}\n")
+        f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
+        f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
+        f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
         f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")

     print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2594,10 +2892,10 @@ def train_segmentation_model(
     plt.grid(True)

     plt.subplot(1, 3, 3)
-    plt.plot(val_dices, label="Val Dice")
-    plt.title("Dice Score")
+    plt.plot(val_f1s, label="Val F1")
+    plt.title("F1 Score")
     plt.xlabel("Epoch")
-    plt.ylabel("Dice")
+    plt.ylabel("F1")
     plt.legend()
     plt.grid(True)

@@ -2627,6 +2925,7 @@ def semantic_inference_on_geotiff(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2648,6 +2947,8 @@ def semantic_inference_on_geotiff(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.

@@ -2864,7 +3165,7 @@ def semantic_inference_on_geotiff(
         prob_meta = meta.copy()
         prob_meta.update({"count": num_classes, "dtype": "float32"})

-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -2878,6 +3179,36 @@ def semantic_inference_on_geotiff(
         if not quiet:
             print(f"Saved probability map to {probability_path}")

+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = meta.copy()
+            single_band_meta.update({"count": 1, "dtype": "float32"})
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = (
+                    prob_accumulator[class_idx, valid_pixels]
+                    / count_accumulator[valid_pixels]
+                )
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time


@@ -2894,6 +3225,7 @@ def semantic_inference_on_image(
     binary_output: bool = True,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2916,6 +3248,8 @@ def semantic_inference_on_image(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.

@@ -3194,7 +3528,7 @@ def semantic_inference_on_image(
             "transform": transform,
         }

-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3205,6 +3539,39 @@ def semantic_inference_on_image(
         if not quiet:
             print(f"Saved probability map to {probability_path}")

+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = {
+                "driver": "GTiff",
+                "height": height,
+                "width": width,
+                "count": 1,
+                "dtype": "float32",
+                "transform": transform,
+            }
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time


@@ -3222,6 +3589,7 @@ def semantic_segmentation(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -3244,11 +3612,16 @@ def semantic_segmentation(
         batch_size (int): Batch size for inference.
         device (torch.device, optional): Device to run inference on.
         probability_path (str, optional): Path to save probability map. If provided,
-            the normalized class probabilities will be saved as a multi-band raster
+            the normalized class probabilities will be saved as a multi-band raster
+            where each band contains probabilities for each class.
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
             Must be between 0 and 1.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Files will be named like
+            "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
+            as probability_path. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.

@@ -3325,6 +3698,7 @@ def semantic_segmentation(
             device=device,
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3345,6 +3719,7 @@ def semantic_segmentation(
             binary_output=True,  # Convert to binary output for better visualization
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3527,6 +3902,7 @@ def train_instance_segmentation_model(
     images_dir: str,
     labels_dir: str,
     output_dir: str,
+    input_format: str = "directory",
     num_classes: int = 2,
     num_channels: int = 3,
     batch_size: int = 4,
@@ -3545,9 +3921,17 @@ def train_instance_segmentation_model(
     This is a wrapper function for train_MaskRCNN_model with clearer naming.

     Args:
-        images_dir (str): Directory containing image GeoTIFF files
-        labels_dir (str): Directory containing label GeoTIFF files
+        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
+            or root directory containing images/ subdirectory (for 'yolo' format),
+            or directory containing images (for 'coco' format).
+        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
+            or path to COCO annotations JSON file (for 'coco' format),
+            or not used (for 'yolo' format - labels are in images_dir/labels/).
         output_dir (str): Directory to save model checkpoints and results.
+        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
+            - 'directory': Standard directory structure with separate images_dir and labels_dir
+            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
+            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
         num_classes (int): Number of classes (including background). Defaults to 2.
         num_channels (int): Number of input channels. Defaults to 3.
         batch_size (int): Batch size for training. Defaults to 4.
@@ -3572,6 +3956,7 @@ def train_instance_segmentation_model(
         images_dir=images_dir,
         labels_dir=labels_dir,
         output_dir=output_dir,
+        input_format=input_format,
         num_channels=num_channels,
         model=model,
         batch_size=batch_size,
@@ -3756,3 +4141,307 @@ def instance_segmentation_batch(
             continue

     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def lightly_train_model(
+    data_dir: str,
+    output_dir: str,
+    model: str = "torchvision/resnet50",
+    method: str = "dinov2_distillation",
+    epochs: int = 100,
+    batch_size: int = 64,
+    learning_rate: float = 1e-4,
+    **kwargs: Any,
+) -> str:
+    """
+    Train a model using Lightly Train for self-supervised pretraining.
+
+    Args:
+        data_dir (str): Directory containing unlabeled images for training.
+        output_dir (str): Directory to save training outputs and model checkpoints.
+        model (str): Model architecture to train. Supports models from torchvision,
+            timm, ultralytics, etc. Default is "torchvision/resnet50".
+        method (str): Self-supervised learning method. Options include:
+            - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
+            - "dino": Works with both CNNs and ViTs
+            - "dinov2": Requires ViT models only
+            - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
+            Default is "dinov2_distillation".
+        epochs (int): Number of training epochs. Default is 100.
+        batch_size (int): Batch size for training. Default is 64.
+        learning_rate (float): Learning rate for training. Default is 1e-4.
+        **kwargs: Additional arguments passed to lightly_train.train().
+
+    Returns:
+        str: Path to the exported model file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        ValueError: If data_dir does not exist, is empty, or incompatible model/method.
+
+    Note:
+        Model/Method compatibility:
+        - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
+        - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
+
+    Example:
+        >>> # For CNN models (ResNet, EfficientNet)
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="torchvision/resnet50",
+        ...     method="simclr",  # Use simclr for CNNs
+        ...     epochs=50
+        ... )
+        >>> # For ViT models
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="timm/vit_base_patch16_224",
+        ...     method="dinov2",  # dinov2 requires ViT
+        ...     epochs=50
+        ... )
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise ValueError(f"Data directory does not exist: {data_dir}")
+
+    # Check if data directory contains images
+    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
+
+    if not image_files:
+        raise ValueError(f"No image files found in {data_dir}")
+
+    # Validate model/method compatibility
+    is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
+
+    if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
+        raise ValueError(
+            f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
+            f"Solutions:\n"
+            f"  1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
+            f"  2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
+            f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
+            f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
+        )
+
+    print(f"Found {len(image_files)} images in {data_dir}")
+    print(f"Starting self-supervised pretraining with {method} method...")
+    print(f"Model: {model}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Detect if running in notebook environment and set appropriate configuration
+    def is_notebook():
+        try:
+            from IPython import get_ipython
+
+            if get_ipython() is not None:
+                return True
+        except (ImportError, NameError):
+            pass
+        return False
+
+    # Force single-device training in notebooks to avoid DDP strategy issues
+    if is_notebook():
+        # Only override if not explicitly set by user
+        if "accelerator" not in kwargs:
+            # Use CPU in notebooks to avoid DDP incompatibility
+            # Users can still override by passing accelerator='gpu'
+            kwargs["accelerator"] = "cpu"
+        if "devices" not in kwargs:
+            kwargs["devices"] = 1  # Force single device
+
+    # Train the model using Lightly Train
+    lightly_train.train(
+        out=output_dir,
+        data=data_dir,
+        model=model,
+        method=method,
+        epochs=epochs,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    # Return path to the exported model
+    exported_model_path = os.path.join(
+        output_dir, "exported_models", "exported_last.pt"
+    )
+
+    if os.path.exists(exported_model_path):
+        print(
+            f"Model training completed. Exported model saved to: {exported_model_path}"
+        )
+        return exported_model_path
+    else:
+        # Check for alternative export paths
+        possible_paths = [
+            os.path.join(output_dir, "exported_models", "exported_best.pt"),
+            os.path.join(output_dir, "checkpoints", "last.ckpt"),
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                print(f"Model training completed. Exported model saved to: {path}")
+                return path
+
+        print(f"Model training completed. Output saved to: {output_dir}")
+        return output_dir
+
+
+def load_lightly_pretrained_model(
+    model_path: str,
+    model_architecture: str = "torchvision/resnet50",
+    device: str = None,
+) -> torch.nn.Module:
+    """
+    Load a pretrained model from Lightly Train.
+
+    Args:
+        model_path (str): Path to the pretrained model file (.pt format).
+        model_architecture (str): Architecture of the model to load.
+            Default is "torchvision/resnet50".
+        device (str): Device to load the model on. If None, uses CPU.
+
+    Returns:
+        torch.nn.Module: Loaded pretrained model ready for fine-tuning.
+
+    Raises:
+        FileNotFoundError: If model_path does not exist.
+        ImportError: If required libraries are not available.
+
+    Example:
+        >>> model = load_lightly_pretrained_model(
+        ...     model_path="path/to/pretrained_model.pt",
+        ...     model_architecture="torchvision/resnet50",
+        ...     device="cuda"
+        ... )
+        >>> # Fine-tune the model with your existing training pipeline
+    """
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    print(f"Loading pretrained model from: {model_path}")
+
+    # Load the model based on architecture
+    if model_architecture.startswith("torchvision/"):
+        model_name = model_architecture.replace("torchvision/", "")
+
+        # Import the model from torchvision
+        if hasattr(torchvision.models, model_name):
+            model = getattr(torchvision.models, model_name)()
+        else:
+            raise ValueError(f"Unknown torchvision model: {model_name}")
+
+    elif model_architecture.startswith("timm/"):
+        try:
+            import timm
+
+            model_name = model_architecture.replace("timm/", "")
+            model = timm.create_model(model_name)
+        except ImportError:
+            raise ImportError(
+                "timm is required for TIMM models. Install with: pip install timm"
+            )
+
+    else:
+        # For other architectures, try to import from torchvision as default
+        try:
+            model = getattr(torchvision.models, model_architecture)()
+        except AttributeError:
+            raise ValueError(f"Unsupported model architecture: {model_architecture}")
+
+    # Load the pretrained weights
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    except TypeError:
+        # For backward compatibility with older PyTorch versions
+        state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+
+    print(f"Successfully loaded pretrained model: {model_architecture}")
+    return model
+
+
+def lightly_embed_images(
+    data_dir: str,
+    model_path: str,
+    output_path: str,
+    model_architecture: str = None,  # Deprecated, kept for backwards compatibility
+    batch_size: int = 64,
+    **kwargs: Any,
+) -> str:
+    """
+    Generate embeddings for images using a Lightly Train pretrained model.
+
+    Args:
+        data_dir (str): Directory containing images to embed.
+        model_path (str): Path to the pretrained model checkpoint file (.ckpt).
+        output_path (str): Path to save the embeddings (as .pt file).
+        model_architecture (str): Architecture of the pretrained model (deprecated,
+            kept for backwards compatibility but not used). The model architecture
+            is automatically loaded from the checkpoint.
+        batch_size (int): Batch size for embedding generation. Default is 64.
+        **kwargs: Additional arguments passed to lightly_train.embed().
+            Supported kwargs include: image_size, num_workers, accelerator, etc.
+
+    Returns:
+        str: Path to the saved embeddings file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        FileNotFoundError: If data_dir or model_path does not exist.
+
+    Note:
+        The model_path should point to a .ckpt file from the training output,
+        typically located at: output_dir/checkpoints/last.ckpt
+
+    Example:
+        >>> embeddings_path = lightly_embed_images(
+        ...     data_dir="path/to/images",
+        ...     model_path="output_dir/checkpoints/last.ckpt",
+        ...     output_path="embeddings.pt",
+        ...     batch_size=32
+        ... )
+        >>> print(f"Embeddings saved to: {embeddings_path}")
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
+
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+    print(f"Generating embeddings for images in: {data_dir}")
+    print(f"Using pretrained model: {model_path}")
+
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Generate embeddings using Lightly Train
+    # Note: model_architecture is not used - it's inferred from the checkpoint
+    lightly_train.embed(
+        out=output_path,
+        data=data_dir,
+        checkpoint=model_path,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    print(f"Embeddings saved to: {output_path}")
+    return output_path