geoai-py 0.15.0__py2.py3-none-any.whl → 0.18.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +16 -1
- geoai/agents/__init__.py +4 -0
- geoai/agents/catalog_models.py +51 -0
- geoai/agents/catalog_tools.py +907 -0
- geoai/agents/geo_agents.py +934 -42
- geoai/agents/stac_models.py +67 -0
- geoai/agents/stac_tools.py +435 -0
- geoai/change_detection.py +32 -7
- geoai/download.py +5 -1
- geoai/geoai.py +3 -0
- geoai/timm_segment.py +4 -1
- geoai/tools/__init__.py +65 -0
- geoai/tools/cloudmask.py +431 -0
- geoai/tools/multiclean.py +357 -0
- geoai/train.py +694 -35
- geoai/utils.py +752 -208
- {geoai_py-0.15.0.dist-info → geoai_py-0.18.0.dist-info}/METADATA +6 -2
- geoai_py-0.18.0.dist-info/RECORD +33 -0
- geoai_py-0.15.0.dist-info/RECORD +0 -26
- {geoai_py-0.15.0.dist-info → geoai_py-0.18.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.18.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.18.0.dist-info}/top_level.txt +0 -0
geoai/train.py
CHANGED
@@ -34,6 +34,14 @@ try:
 except ImportError:
     SMP_AVAILABLE = False
 
+# Additional imports for Lightly Train
+try:
+    import lightly_train
+
+    LIGHTLY_TRAIN_AVAILABLE = True
+except ImportError:
+    LIGHTLY_TRAIN_AVAILABLE = False
+
 
 def parse_coco_annotations(
     coco_json_path: str, images_dir: str, labels_dir: str
@@ -1428,8 +1436,12 @@ def instance_segmentation_inference_on_geotiff(
     # Apply Non-Maximum Suppression to handle overlapping detections
     if len(all_detections) > 0:
         # Convert to tensors for NMS
-        boxes = torch.tensor(
-
+        boxes = torch.tensor(
+            [det["box"] for det in all_detections], dtype=torch.float32
+        )
+        scores = torch.tensor(
+            [det["score"] for det in all_detections], dtype=torch.float32
+        )
 
         # Apply NMS with IoU threshold
         nms_threshold = 0.3  # IoU threshold for NMS
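The boxes and scores tensors assembled above are the standard inputs to torchvision's NMS, applied with the 0.3 IoU threshold set just below. A minimal, self-contained sketch of that filtering step (the detection values are made up for illustration):

import torch
from torchvision.ops import nms

# Hypothetical detections accumulated across tiles; "box" is (x1, y1, x2, y2).
all_detections = [
    {"box": [10.0, 10.0, 50.0, 50.0], "score": 0.9},
    {"box": [12.0, 11.0, 52.0, 49.0], "score": 0.7},  # heavily overlaps the first box
]
boxes = torch.tensor([det["box"] for det in all_detections], dtype=torch.float32)
scores = torch.tensor([det["score"] for det in all_detections], dtype=torch.float32)

keep = nms(boxes, scores, iou_threshold=0.3)  # indices of detections to keep
kept_detections = [all_detections[i] for i in keep.tolist()]
print(len(kept_detections))  # 1 -- the lower-scoring overlapping box is suppressed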
@@ -1909,6 +1921,96 @@ class SemanticRandomHorizontalFlip:
         return image, mask
 
 
+class SemanticRandomVerticalFlip:
+    """Random vertical flip transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Flip image and mask along height dimension
+            image = torch.flip(image, dims=[1])
+            mask = torch.flip(mask, dims=[0])
+        return image, mask
+
+
+class SemanticRandomRotation90:
+    """Random 90-degree rotation transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Randomly rotate by 90, 180, or 270 degrees
+            k = random.randint(1, 3)
+            image = torch.rot90(image, k, dims=[1, 2])
+            mask = torch.rot90(mask, k, dims=[0, 1])
+        return image, mask
+
+
+class SemanticBrightnessAdjustment:
+    """Random brightness adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, brightness_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize brightness adjustment transform.
+
+        Args:
+            brightness_range: Tuple of (min, max) brightness factors.
+            prob: Probability of applying the transform.
+        """
+        self.brightness_range = brightness_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random brightness adjustment
+            factor = self.brightness_range[0] + random.random() * (
+                self.brightness_range[1] - self.brightness_range[0]
+            )
+            image = torch.clamp(image * factor, 0, 1)
+        return image, mask
+
+
+class SemanticContrastAdjustment:
+    """Random contrast adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, contrast_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize contrast adjustment transform.
+
+        Args:
+            contrast_range: Tuple of (min, max) contrast factors.
+            prob: Probability of applying the transform.
+        """
+        self.contrast_range = contrast_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random contrast adjustment
+            factor = self.contrast_range[0] + random.random() * (
+                self.contrast_range[1] - self.contrast_range[0]
+            )
+            mean = image.mean(dim=(1, 2), keepdim=True)
+            image = torch.clamp((image - mean) * factor + mean, 0, 1)
+        return image, mask
+
+
 def get_semantic_transform(train: bool) -> Any:
     """
     Get transforms for semantic segmentation data augmentation.
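The four augmentations added above are plain (image, mask) callables, so they can be chained into a single callable and passed as the train_transforms argument introduced later in this diff. A minimal sketch; the Compose-style helper is illustrative and not part of geoai:

import torch
from geoai.train import (
    SemanticBrightnessAdjustment,
    SemanticRandomRotation90,
    SemanticRandomVerticalFlip,
)

class ComposeSemanticTransforms:
    """Apply a list of (image, mask) transforms in order (illustrative helper)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask

augment = ComposeSemanticTransforms(
    [
        SemanticRandomVerticalFlip(prob=0.5),
        SemanticRandomRotation90(prob=0.5),
        SemanticBrightnessAdjustment(brightness_range=(0.8, 1.2), prob=0.5),
    ]
)

image = torch.rand(3, 256, 256)         # CHW image scaled to [0, 1]
mask = torch.randint(0, 2, (256, 256))  # HW class mask
image, mask = augment(image, mask)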
@@ -2025,14 +2127,14 @@ def get_smp_model(
     )
 
 
-def dice_coefficient(
+def f1_score(
     pred: torch.Tensor,
     target: torch.Tensor,
     smooth: float = 1e-6,
     num_classes: Optional[int] = None,
 ) -> float:
     """
-    Calculate Dice coefficient for segmentation (binary or multi-class).
+    Calculate F1 score (also known as Dice coefficient) for segmentation (binary or multi-class).
 
     Args:
         pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
@@ -2041,7 +2143,7 @@ def dice_coefficient(
         num_classes (int, optional): Number of classes. If None, auto-detected.
 
     Returns:
-        float: Mean
+        float: Mean F1 score across all classes.
     """
     # Convert predictions to class predictions
     if pred.dim() == 3:  # [C, H, W] format
@@ -2056,8 +2158,8 @@ def dice_coefficient(
     if num_classes is None:
         num_classes = max(pred_classes.max().item(), target.max().item()) + 1
 
-    # Calculate
-
+    # Calculate F1 score for each class and average
+    f1_scores = []
     for class_id in range(num_classes):
         pred_class = (pred_classes == class_id).float()
         target_class = (target == class_id).float()
@@ -2066,10 +2168,10 @@ def dice_coefficient(
         union = pred_class.sum() + target_class.sum()
 
         if union > 0:
-
-
+            f1 = (2.0 * intersection + smooth) / (union + smooth)
+            f1_scores.append(f1.item())
 
-    return sum(
+    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
 
 
 def iou_coefficient(
@@ -2119,6 +2221,108 @@ def iou_coefficient(
     return sum(iou_scores) / len(iou_scores) if iou_scores else 0.0
 
 
+def precision_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate precision score for segmentation (binary or multi-class).
+
+    Precision = TP / (TP + FP), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FP (False Positives): Incorrectly predicted positive pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean precision score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate precision for each class and average
+    precision_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        predicted_positives = pred_class.sum()
+
+        if predicted_positives > 0:
+            precision = (true_positives + smooth) / (predicted_positives + smooth)
+            precision_scores.append(precision.item())
+
+    return sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
+
+
+def recall_score(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    smooth: float = 1e-6,
+    num_classes: Optional[int] = None,
+) -> float:
+    """
+    Calculate recall score (also known as sensitivity) for segmentation (binary or multi-class).
+
+    Recall = TP / (TP + FN), where:
+    - TP (True Positives): Correctly predicted positive pixels
+    - FN (False Negatives): Incorrectly predicted negative pixels
+
+    Args:
+        pred (torch.Tensor): Predicted mask (probabilities or logits) with shape [C, H, W] or [H, W].
+        target (torch.Tensor): Ground truth mask with shape [H, W].
+        smooth (float): Smoothing factor to avoid division by zero.
+        num_classes (int, optional): Number of classes. If None, auto-detected.
+
+    Returns:
+        float: Mean recall score across all classes.
+    """
+    # Convert predictions to class predictions
+    if pred.dim() == 3:  # [C, H, W] format
+        pred = torch.softmax(pred, dim=0)
+        pred_classes = torch.argmax(pred, dim=0)
+    elif pred.dim() == 2:  # [H, W] format
+        pred_classes = pred
+    else:
+        raise ValueError(f"Unexpected prediction dimensions: {pred.shape}")
+
+    # Auto-detect number of classes if not provided
+    if num_classes is None:
+        num_classes = max(pred_classes.max().item(), target.max().item()) + 1
+
+    # Calculate recall for each class and average
+    recall_scores = []
+    for class_id in range(num_classes):
+        pred_class = (pred_classes == class_id).float()
+        target_class = (target == class_id).float()
+
+        true_positives = (pred_class * target_class).sum()
+        actual_positives = target_class.sum()
+
+        if actual_positives > 0:
+            recall = (true_positives + smooth) / (actual_positives + smooth)
+            recall_scores.append(recall.item())
+
+    return sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
+
+
 def train_semantic_one_epoch(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
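A rough usage sketch for the new metric helpers, assuming they are imported from geoai.train as defined in this diff (toy tensors, not real predictions):

import torch
from geoai.train import f1_score, iou_coefficient, precision_score, recall_score

target = torch.tensor([[0, 0], [1, 1]])  # ground-truth mask, shape [H, W]
logits = torch.zeros(2, 2, 2)            # predictions, shape [C, H, W]
logits[1, 1, :] = 5.0                    # strongly predict class 1 on the bottom row

# The argmax of the logits matches the target exactly, so every metric is ~1.0.
print(f1_score(logits, target))
print(iou_coefficient(logits, target))
print(precision_score(logits, target))
print(recall_score(logits, target))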
@@ -2200,13 +2404,15 @@ def evaluate_semantic(
         num_classes (int): Number of classes for evaluation metrics.
 
     Returns:
-        dict: Evaluation metrics including loss, IoU, and
+        dict: Evaluation metrics including loss, IoU, F1, precision, and recall.
     """
     model.eval()
 
     total_loss = 0
-
+    f1_scores = []
     iou_scores = []
+    precision_scores = []
+    recall_scores = []
     num_batches = len(data_loader)
 
     with torch.no_grad():
@@ -2222,17 +2428,31 @@ def evaluate_semantic(
 
             # Calculate metrics for each sample in the batch
             for pred, target in zip(outputs, targets):
-
+                f1 = f1_score(pred, target, num_classes=num_classes)
                 iou = iou_coefficient(pred, target, num_classes=num_classes)
-
+                precision = precision_score(pred, target, num_classes=num_classes)
+                recall = recall_score(pred, target, num_classes=num_classes)
+                f1_scores.append(f1)
                 iou_scores.append(iou)
+                precision_scores.append(precision)
+                recall_scores.append(recall)
 
     # Calculate metrics
     avg_loss = total_loss / num_batches
-
+    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
    avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
-
-
+    avg_precision = (
+        sum(precision_scores) / len(precision_scores) if precision_scores else 0
+    )
+    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0
+
+    return {
+        "loss": avg_loss,
+        "F1": avg_f1,
+        "IoU": avg_iou,
+        "Precision": avg_precision,
+        "Recall": avg_recall,
+    }
 
 
 def train_segmentation_model(
@@ -2261,6 +2481,9 @@ def train_segmentation_model(
     target_size: Optional[Tuple[int, int]] = None,
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
+    early_stopping_patience: Optional[int] = None,
+    train_transforms: Optional[Callable] = None,
+    val_transforms: Optional[Callable] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2313,6 +2536,17 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
+            Both image and mask should be torch.Tensor objects. The image tensor is expected to be in
+            CHW format (channels, height, width), and the mask tensor in HW format (height, width).
+            If None, uses default transforms (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            The image tensor is expected to be in CHW format (channels, height, width), and the mask tensor in HW format (height, width).
+            Both image and mask should be torch.Tensor objects. If None, uses default transforms
+            (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            If None, uses default transforms (no augmentation). Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
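A sketch of how the new training options might be passed. Only parameters introduced or referenced in this hunk are shown; the data and model arguments that train_segmentation_model also requires are not part of this hunk, so the actual call is left as a comment:

from geoai.train import train_segmentation_model

training_kwargs = dict(
    output_dir="output/building_seg",  # placeholder path; model weights are saved here
    early_stopping_patience=10,        # new: stop after 10 epochs without IoU improvement
    train_transforms=None,             # new: or a custom (image, mask) callable such as the Compose sketch above
    val_transforms=None,               # new: None keeps the default validation transform (no augmentation)
)
# train_segmentation_model(..., **training_kwargs)  # remaining data/model arguments omitted here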
@@ -2455,10 +2689,22 @@ def train_segmentation_model(
         print("No resizing needed.")
 
     # Create datasets
+    # Use custom transforms if provided, otherwise use default transforms
+    train_transform = (
+        train_transforms
+        if train_transforms is not None
+        else get_semantic_transform(train=True)
+    )
+    val_transform = (
+        val_transforms
+        if val_transforms is not None
+        else get_semantic_transform(train=False)
+    )
+
     train_dataset = SemanticSegmentationDataset(
         train_imgs,
         train_labels,
-        transforms=
+        transforms=train_transform,
         num_channels=num_channels,
         target_size=target_size,
         resize_mode=resize_mode,
@@ -2467,7 +2713,7 @@ def train_segmentation_model(
     val_dataset = SemanticSegmentationDataset(
         val_imgs,
         val_labels,
-        transforms=
+        transforms=val_transform,
         num_channels=num_channels,
         target_size=target_size,
         resize_mode=resize_mode,
@@ -2542,7 +2788,7 @@ def train_segmentation_model(
         print(f"Using {torch.cuda.device_count()} GPUs for training")
         model = torch.nn.DataParallel(model)
 
-    # Set up loss function (CrossEntropyLoss for multi-class, can also use
+    # Set up loss function (CrossEntropyLoss for multi-class, can also use F1Loss)
     criterion = torch.nn.CrossEntropyLoss()
 
     # Set up optimizer
@@ -2560,8 +2806,11 @@ def train_segmentation_model(
     train_losses = []
     val_losses = []
     val_ious = []
-
+    val_f1s = []
+    val_precisions = []
+    val_recalls = []
     start_epoch = 0
+    epochs_without_improvement = 0
 
     # Load checkpoint if provided
     if checkpoint_path is not None:
@@ -2596,8 +2845,15 @@ def train_segmentation_model(
             val_losses = checkpoint["val_losses"]
         if "val_ious" in checkpoint:
             val_ious = checkpoint["val_ious"]
-        if "
-
+        if "val_f1s" in checkpoint:
+            val_f1s = checkpoint["val_f1s"]
+        # Also check for old val_dices format for backward compatibility
+        elif "val_dices" in checkpoint:
+            val_f1s = checkpoint["val_dices"]
+        if "val_precisions" in checkpoint:
+            val_precisions = checkpoint["val_precisions"]
+        if "val_recalls" in checkpoint:
+            val_recalls = checkpoint["val_recalls"]
 
         print(f"Resuming training from epoch {start_epoch}")
         print(f"Previous best IoU: {best_iou:.4f}")
@@ -2637,7 +2893,9 @@ def train_segmentation_model(
         )
         val_losses.append(eval_metrics["loss"])
         val_ious.append(eval_metrics["IoU"])
-
+        val_f1s.append(eval_metrics["F1"])
+        val_precisions.append(eval_metrics["Precision"])
+        val_recalls.append(eval_metrics["Recall"])
 
         # Update learning rate
         lr_scheduler.step(eval_metrics["loss"])
@@ -2648,14 +2906,28 @@ def train_segmentation_model(
             f"Train Loss: {train_loss:.4f}, "
             f"Val Loss: {eval_metrics['loss']:.4f}, "
             f"Val IoU: {eval_metrics['IoU']:.4f}, "
-            f"Val
+            f"Val F1: {eval_metrics['F1']:.4f}, "
+            f"Val Precision: {eval_metrics['Precision']:.4f}, "
+            f"Val Recall: {eval_metrics['Recall']:.4f}"
         )
 
-        # Save best model
+        # Save best model and check for early stopping
         if eval_metrics["IoU"] > best_iou:
             best_iou = eval_metrics["IoU"]
+            epochs_without_improvement = 0
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
+        else:
+            epochs_without_improvement += 1
+            if (
+                early_stopping_patience is not None
+                and epochs_without_improvement >= early_stopping_patience
+            ):
+                print(
+                    f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement"
+                )
+                print(f"Best validation IoU: {best_iou:.4f}")
+                break
 
         # Save checkpoint every 10 epochs (if not save_best_only)
         if not save_best_only and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
@@ -2673,7 +2945,9 @@ def train_segmentation_model(
                     "train_losses": train_losses,
                     "val_losses": val_losses,
                     "val_ious": val_ious,
-                    "
+                    "val_f1s": val_f1s,
+                    "val_precisions": val_precisions,
+                    "val_recalls": val_recalls,
                 },
                 os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
             )
@@ -2686,7 +2960,9 @@ def train_segmentation_model(
         "train_losses": train_losses,
         "val_losses": val_losses,
         "val_ious": val_ious,
-        "
+        "val_f1s": val_f1s,
+        "val_precisions": val_precisions,
+        "val_recalls": val_recalls,
     }
     torch.save(history, os.path.join(output_dir, "training_history.pth"))
 
@@ -2702,7 +2978,9 @@ def train_segmentation_model(
         f.write(f"Total epochs: {num_epochs}\n")
         f.write(f"Best validation IoU: {best_iou:.4f}\n")
         f.write(f"Final validation IoU: {val_ious[-1]:.4f}\n")
-        f.write(f"Final validation
+        f.write(f"Final validation F1: {val_f1s[-1]:.4f}\n")
+        f.write(f"Final validation Precision: {val_precisions[-1]:.4f}\n")
+        f.write(f"Final validation Recall: {val_recalls[-1]:.4f}\n")
         f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
 
     print(f"Training complete! Best IoU: {best_iou:.4f}")
@@ -2731,10 +3009,10 @@ def train_segmentation_model(
     plt.grid(True)
 
     plt.subplot(1, 3, 3)
-    plt.plot(
-    plt.title("
+    plt.plot(val_f1s, label="Val F1")
+    plt.title("F1 Score")
     plt.xlabel("Epoch")
-    plt.ylabel("
+    plt.ylabel("F1")
     plt.legend()
     plt.grid(True)
 
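The history dict saved to training_history.pth above now carries the F1, precision, and recall series alongside loss and IoU. A small sketch of inspecting it after training (the output directory name is a placeholder):

import os
import torch

history = torch.load(os.path.join("output/building_seg", "training_history.pth"))
print(sorted(history))  # ['train_losses', 'val_f1s', 'val_ious', 'val_losses', 'val_precisions', 'val_recalls']
best_epoch = max(range(len(history["val_f1s"])), key=lambda i: history["val_f1s"][i])
print(f"Best val F1 {history['val_f1s'][best_epoch]:.4f} at epoch {best_epoch + 1}")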
@@ -2764,6 +3042,7 @@ def semantic_inference_on_geotiff(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -2785,6 +3064,8 @@ def semantic_inference_on_geotiff(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3001,7 +3282,7 @@ def semantic_inference_on_geotiff(
         prob_meta = meta.copy()
         prob_meta.update({"count": num_classes, "dtype": "float32"})
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
         with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3015,6 +3296,36 @@ def semantic_inference_on_geotiff(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = meta.copy()
+            single_band_meta.update({"count": 1, "dtype": "float32"})
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = (
+                    prob_accumulator[class_idx, valid_pixels]
+                    / count_accumulator[valid_pixels]
+                )
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -3031,6 +3342,7 @@ def semantic_inference_on_image(
     binary_output: bool = True,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> Tuple[str, float]:
@@ -3053,6 +3365,8 @@ def semantic_inference_on_image(
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
@@ -3331,7 +3645,7 @@ def semantic_inference_on_image(
             "transform": transform,
         }
 
-        # Save normalized probabilities
+        # Save normalized probabilities as multi-band raster
        with rasterio.open(probability_path, "w", **prob_meta) as dst:
             for class_idx in range(num_classes):
                 # Normalize probabilities
@@ -3342,6 +3656,39 @@ def semantic_inference_on_image(
         if not quiet:
             print(f"Saved probability map to {probability_path}")
 
+        # Save individual class probabilities if requested
+        if save_class_probabilities:
+            # Prepare single-band metadata
+            single_band_meta = {
+                "driver": "GTiff",
+                "height": height,
+                "width": width,
+                "count": 1,
+                "dtype": "float32",
+                "transform": transform,
+            }
+
+            # Get base filename and extension
+            prob_base = os.path.splitext(probability_path)[0]
+            prob_ext = os.path.splitext(probability_path)[1]
+
+            for class_idx in range(num_classes):
+                # Create filename for this class
+                class_prob_path = f"{prob_base}_class_{class_idx}{prob_ext}"
+
+                # Normalize probabilities
+                prob_band = np.zeros((height, width), dtype=np.float32)
+                prob_band[valid_pixels] = normalized_probs[class_idx, valid_pixels]
+
+                # Save single-band file
+                with rasterio.open(class_prob_path, "w", **single_band_meta) as dst:
+                    dst.write(prob_band, 1)
+
+                if not quiet:
+                    print(
+                        f"Saved class {class_idx} probability to {class_prob_path}"
+                    )
+
     return output_path, inference_time
 
 
@@ -3359,6 +3706,7 @@ def semantic_segmentation(
     device: Optional[torch.device] = None,
     probability_path: Optional[str] = None,
     probability_threshold: Optional[float] = None,
+    save_class_probabilities: bool = False,
     quiet: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -3381,11 +3729,16 @@ def semantic_segmentation(
         batch_size (int): Batch size for inference.
         device (torch.device, optional): Device to run inference on.
         probability_path (str, optional): Path to save probability map. If provided,
-            the normalized class probabilities will be saved as a multi-band raster
+            the normalized class probabilities will be saved as a multi-band raster
+            where each band contains probabilities for each class.
         probability_threshold (float, optional): Probability threshold for binary classification.
             Only used when num_classes=2. If provided, pixels with class 1 probability >= threshold
             are classified as class 1, otherwise class 0. If None (default), uses argmax.
             Must be between 0 and 1.
+        save_class_probabilities (bool): If True and probability_path is provided, saves each
+            class probability as a separate single-band file. Files will be named like
+            "probability_class_0.tif", "probability_class_1.tif", etc. in the same directory
+            as probability_path. Defaults to False.
         quiet (bool): If True, suppress progress bar. Defaults to False.
         **kwargs: Additional arguments.
 
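A sketch of the new probability outputs. Only the inference parameters documented in this hunk are shown; the input, model, and output arguments that semantic_segmentation also takes are not part of this hunk, so the call is left as a comment:

inference_kwargs = dict(
    probability_path="output/probability.tif",  # multi-band map, one band per class
    probability_threshold=0.5,                  # binary (num_classes=2) case only; otherwise argmax
    save_class_probabilities=True,              # also write one single-band file per class
    quiet=False,
)
# semantic_segmentation(..., **inference_kwargs)  # remaining arguments omitted here
# With two classes this writes output/probability.tif plus
# output/probability_class_0.tif and output/probability_class_1.tif.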
@@ -3462,6 +3815,7 @@ def semantic_segmentation(
             device=device,
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3482,6 +3836,7 @@ def semantic_segmentation(
             binary_output=True,  # Convert to binary output for better visualization
             probability_path=probability_path,
             probability_threshold=probability_threshold,
+            save_class_probabilities=save_class_probabilities,
             quiet=quiet,
             **kwargs,
         )
@@ -3903,3 +4258,307 @@ def instance_segmentation_batch(
             continue
 
     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def lightly_train_model(
+    data_dir: str,
+    output_dir: str,
+    model: str = "torchvision/resnet50",
+    method: str = "dinov2_distillation",
+    epochs: int = 100,
+    batch_size: int = 64,
+    learning_rate: float = 1e-4,
+    **kwargs: Any,
+) -> str:
+    """
+    Train a model using Lightly Train for self-supervised pretraining.
+
+    Args:
+        data_dir (str): Directory containing unlabeled images for training.
+        output_dir (str): Directory to save training outputs and model checkpoints.
+        model (str): Model architecture to train. Supports models from torchvision,
+            timm, ultralytics, etc. Default is "torchvision/resnet50".
+        method (str): Self-supervised learning method. Options include:
+            - "simclr": Works with CNN models (ResNet, EfficientNet, etc.)
+            - "dino": Works with both CNNs and ViTs
+            - "dinov2": Requires ViT models only
+            - "dinov2_distillation": Requires ViT models only (recommended for ViTs)
+            Default is "dinov2_distillation".
+        epochs (int): Number of training epochs. Default is 100.
+        batch_size (int): Batch size for training. Default is 64.
+        learning_rate (float): Learning rate for training. Default is 1e-4.
+        **kwargs: Additional arguments passed to lightly_train.train().
+
+    Returns:
+        str: Path to the exported model file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        ValueError: If data_dir does not exist, is empty, or incompatible model/method.
+
+    Note:
+        Model/Method compatibility:
+        - CNN models (ResNet, EfficientNet): Use "simclr" or "dino"
+        - ViT models: Use "dinov2", "dinov2_distillation", or "dino"
+
+    Example:
+        >>> # For CNN models (ResNet, EfficientNet)
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="torchvision/resnet50",
+        ...     method="simclr",  # Use simclr for CNNs
+        ...     epochs=50
+        ... )
+        >>> # For ViT models
+        >>> model_path = lightly_train_model(
+        ...     data_dir="path/to/unlabeled/images",
+        ...     output_dir="path/to/output",
+        ...     model="timm/vit_base_patch16_224",
+        ...     method="dinov2",  # dinov2 requires ViT
+        ...     epochs=50
+        ... )
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise ValueError(f"Data directory does not exist: {data_dir}")
+
+    # Check if data directory contains images
+    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.tif", "*.tiff", "*.bmp"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(data_dir, "**", ext), recursive=True))
+
+    if not image_files:
+        raise ValueError(f"No image files found in {data_dir}")
+
+    # Validate model/method compatibility
+    is_vit_model = "vit" in model.lower() or "vision_transformer" in model.lower()
+
+    if method in ["dinov2", "dinov2_distillation"] and not is_vit_model:
+        raise ValueError(
+            f"Method '{method}' requires a Vision Transformer (ViT) model, but got '{model}'.\n"
+            f"Solutions:\n"
+            f" 1. Use a ViT model: model='timm/vit_base_patch16_224'\n"
+            f" 2. Use a CNN-compatible method: method='simclr' or method='dino'\n"
+            f"\nFor CNN models (ResNet, EfficientNet), use 'simclr' or 'dino'.\n"
+            f"For ViT models, use 'dinov2', 'dinov2_distillation', or 'dino'."
+        )
+
+    print(f"Found {len(image_files)} images in {data_dir}")
+    print(f"Starting self-supervised pretraining with {method} method...")
+    print(f"Model: {model}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Detect if running in notebook environment and set appropriate configuration
+    def is_notebook():
+        try:
+            from IPython import get_ipython
+
+            if get_ipython() is not None:
+                return True
+        except (ImportError, NameError):
+            pass
+        return False
+
+    # Force single-device training in notebooks to avoid DDP strategy issues
+    if is_notebook():
+        # Only override if not explicitly set by user
+        if "accelerator" not in kwargs:
+            # Use CPU in notebooks to avoid DDP incompatibility
+            # Users can still override by passing accelerator='gpu'
+            kwargs["accelerator"] = "cpu"
+        if "devices" not in kwargs:
+            kwargs["devices"] = 1  # Force single device
+
+    # Train the model using Lightly Train
+    lightly_train.train(
+        out=output_dir,
+        data=data_dir,
+        model=model,
+        method=method,
+        epochs=epochs,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    # Return path to the exported model
+    exported_model_path = os.path.join(
+        output_dir, "exported_models", "exported_last.pt"
+    )
+
+    if os.path.exists(exported_model_path):
+        print(
+            f"Model training completed. Exported model saved to: {exported_model_path}"
+        )
+        return exported_model_path
+    else:
+        # Check for alternative export paths
+        possible_paths = [
+            os.path.join(output_dir, "exported_models", "exported_best.pt"),
+            os.path.join(output_dir, "checkpoints", "last.ckpt"),
+        ]
+
+        for path in possible_paths:
+            if os.path.exists(path):
+                print(f"Model training completed. Exported model saved to: {path}")
+                return path
+
+        print(f"Model training completed. Output saved to: {output_dir}")
+        return output_dir
+
+
+def load_lightly_pretrained_model(
+    model_path: str,
+    model_architecture: str = "torchvision/resnet50",
+    device: str = None,
+) -> torch.nn.Module:
+    """
+    Load a pretrained model from Lightly Train.
+
+    Args:
+        model_path (str): Path to the pretrained model file (.pt format).
+        model_architecture (str): Architecture of the model to load.
+            Default is "torchvision/resnet50".
+        device (str): Device to load the model on. If None, uses CPU.
+
+    Returns:
+        torch.nn.Module: Loaded pretrained model ready for fine-tuning.
+
+    Raises:
+        FileNotFoundError: If model_path does not exist.
+        ImportError: If required libraries are not available.
+
+    Example:
+        >>> model = load_lightly_pretrained_model(
+        ...     model_path="path/to/pretrained_model.pt",
+        ...     model_architecture="torchvision/resnet50",
+        ...     device="cuda"
+        ... )
+        >>> # Fine-tune the model with your existing training pipeline
+    """
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+
+    print(f"Loading pretrained model from: {model_path}")
+
+    # Load the model based on architecture
+    if model_architecture.startswith("torchvision/"):
+        model_name = model_architecture.replace("torchvision/", "")
+
+        # Import the model from torchvision
+        if hasattr(torchvision.models, model_name):
+            model = getattr(torchvision.models, model_name)()
+        else:
+            raise ValueError(f"Unknown torchvision model: {model_name}")
+
+    elif model_architecture.startswith("timm/"):
+        try:
+            import timm
+
+            model_name = model_architecture.replace("timm/", "")
+            model = timm.create_model(model_name)
+        except ImportError:
+            raise ImportError(
+                "timm is required for TIMM models. Install with: pip install timm"
+            )
+
+    else:
+        # For other architectures, try to import from torchvision as default
+        try:
+            model = getattr(torchvision.models, model_architecture)()
+        except AttributeError:
+            raise ValueError(f"Unsupported model architecture: {model_architecture}")
+
+    # Load the pretrained weights
+    try:
+        state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    except TypeError:
+        # For backward compatibility with older PyTorch versions
+        state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+
+    print(f"Successfully loaded pretrained model: {model_architecture}")
+    return model
+
+
+def lightly_embed_images(
+    data_dir: str,
+    model_path: str,
+    output_path: str,
+    model_architecture: str = None,  # Deprecated, kept for backwards compatibility
+    batch_size: int = 64,
+    **kwargs: Any,
+) -> str:
+    """
+    Generate embeddings for images using a Lightly Train pretrained model.
+
+    Args:
+        data_dir (str): Directory containing images to embed.
+        model_path (str): Path to the pretrained model checkpoint file (.ckpt).
+        output_path (str): Path to save the embeddings (as .pt file).
+        model_architecture (str): Architecture of the pretrained model (deprecated,
+            kept for backwards compatibility but not used). The model architecture
+            is automatically loaded from the checkpoint.
+        batch_size (int): Batch size for embedding generation. Default is 64.
+        **kwargs: Additional arguments passed to lightly_train.embed().
+            Supported kwargs include: image_size, num_workers, accelerator, etc.
+
+    Returns:
+        str: Path to the saved embeddings file.
+
+    Raises:
+        ImportError: If lightly-train is not installed.
+        FileNotFoundError: If data_dir or model_path does not exist.
+
+    Note:
+        The model_path should point to a .ckpt file from the training output,
+        typically located at: output_dir/checkpoints/last.ckpt
+
+    Example:
+        >>> embeddings_path = lightly_embed_images(
+        ...     data_dir="path/to/images",
+        ...     model_path="output_dir/checkpoints/last.ckpt",
+        ...     output_path="embeddings.pt",
+        ...     batch_size=32
+        ... )
+        >>> print(f"Embeddings saved to: {embeddings_path}")
+    """
+    if not LIGHTLY_TRAIN_AVAILABLE:
+        raise ImportError(
+            "lightly-train is not installed. Please install it with: "
+            "pip install lightly-train"
+        )
+
+    if not os.path.exists(data_dir):
+        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
+
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+    print(f"Generating embeddings for images in: {data_dir}")
+    print(f"Using pretrained model: {model_path}")
+
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Generate embeddings using Lightly Train
+    # Note: model_architecture is not used - it's inferred from the checkpoint
+    lightly_train.embed(
+        out=output_path,
+        data=data_dir,
+        checkpoint=model_path,
+        batch_size=batch_size,
+        **kwargs,
+    )
+
+    print(f"Embeddings saved to: {output_path}")
+    return output_path