geoai-py 0.8.1__py2.py3-none-any.whl → 0.8.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/geoai.py +5 -0
- geoai/train.py +466 -24
- geoai/utils.py +371 -99
- {geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/METADATA +1 -1
- geoai_py-0.8.3.dist-info/RECORD +17 -0
- geoai_py-0.8.1.dist-info/RECORD +0 -17
- {geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/WHEEL +0 -0
- {geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/geoai.py
CHANGED
@@ -32,6 +32,11 @@ from .train import (
     train_segmentation_model,
     semantic_segmentation,
     semantic_segmentation_batch,
+    train_instance_segmentation_model,
+    instance_segmentation,
+    instance_segmentation_batch,
+    get_instance_segmentation_model,
+    instance_segmentation_inference_on_geotiff,
 )
 from .utils import *

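The five names added above are the new instance-segmentation API introduced in geoai/train.py in this release (see the train.py hunks below). A minimal import sketch, assuming geoai keeps its usual pattern of re-exporting these symbols at the package level (geoai/__init__.py itself changes by only one line, most likely the version string):

from geoai.geoai import (
    train_instance_segmentation_model,
    instance_segmentation,
    instance_segmentation_batch,
)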
geoai/train.py
CHANGED
@@ -750,6 +750,15 @@ def train_MaskRCNN_model(
     start_epoch = 0
     best_iou = 0

+    # Initialize training history
+    training_history = {
+        "train_loss": [],
+        "val_loss": [],
+        "val_iou": [],
+        "epochs": [],
+        "lr": [],
+    }
+
     # Load pretrained model if provided
     if pretrained_model_path:
         if not os.path.exists(pretrained_model_path):
@@ -800,6 +809,13 @@ def train_MaskRCNN_model(
         # Evaluate
         eval_metrics = evaluate(model, val_loader, device)

+        # Record training history
+        training_history["train_loss"].append(train_loss)
+        training_history["val_loss"].append(eval_metrics["loss"])
+        training_history["val_iou"].append(eval_metrics["IoU"])
+        training_history["epochs"].append(epoch + 1)
+        training_history["lr"].append(optimizer.param_groups[0]["lr"])
+
         # Print metrics
         print(
             f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {eval_metrics['loss']:.4f}, Val IoU: {eval_metrics['IoU']:.4f}"
@@ -811,33 +827,11 @@ def train_MaskRCNN_model(
             print(f"Saving best model with IoU: {best_iou:.4f}")
             torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))

-        # Save checkpoint every 10 epochs
-        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
-            torch.save(
-                {
-                    "epoch": epoch,
-                    "model_state_dict": model.state_dict(),
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "scheduler_state_dict": lr_scheduler.state_dict(),
-                    "best_iou": best_iou,
-                },
-                os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
-            )
-
     # Save final model
     torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))

-    # Save
-    torch.save(
-        {
-            "epoch": num_epochs - 1,
-            "model_state_dict": model.state_dict(),
-            "optimizer_state_dict": optimizer.state_dict(),
-            "scheduler_state_dict": lr_scheduler.state_dict(),
-            "best_iou": best_iou,
-        },
-        os.path.join(output_dir, "final_checkpoint.pth"),
-    )
+    # Save training history
+    torch.save(training_history, os.path.join(output_dir, "training_history.pth"))

     # Load best model for evaluation and visualization
     model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
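Net effect of the three train.py hunks above: the per-epoch and final full checkpoints (model, optimizer, and scheduler state) are no longer written, and the metrics recorded each epoch are saved once as training_history.pth. A minimal sketch of reading that file back; the output directory is hypothetical, the keys are the ones initialized in the first hunk:

import torch

history = torch.load("output/training_history.pth")  # hypothetical output_dir

# Find the best validation IoU and the epoch it occurred in.
best_iou = max(history["val_iou"])
best_epoch = history["epochs"][history["val_iou"].index(best_iou)]
print(f"Best val IoU {best_iou:.4f} at epoch {best_epoch}")
print(f"Final learning rate: {history['lr'][-1]}")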
@@ -1101,6 +1095,237 @@ def inference_on_geotiff(
     return output_path, inference_time


+def instance_segmentation_inference_on_geotiff(
+    model,
+    geotiff_path,
+    output_path,
+    window_size=512,
+    overlap=256,
+    confidence_threshold=0.5,
+    batch_size=4,
+    num_channels=3,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform instance segmentation inference on a large GeoTIFF using a sliding window approach.
+
+    This function collects all detections first, then applies non-maximum suppression
+    to handle overlapping detections from different windows, preventing artifacts.
+
+    Args:
+        model (torch.nn.Module): Trained model for inference.
+        geotiff_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output instance mask GeoTIFF.
+        window_size (int): Size of sliding window for inference.
+        overlap (int): Overlap between adjacent windows.
+        confidence_threshold (float): Confidence threshold for predictions (0-1).
+        batch_size (int): Batch size for inference.
+        num_channels (int): Number of channels to use from the input image.
+        device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
+        **kwargs: Additional arguments.
+
+    Returns:
+        tuple: Tuple containing output path and inference time in seconds.
+    """
+    if device is None:
+        device = get_device()
+
+    # Put model in evaluation mode
+    model.to(device)
+    model.eval()
+
+    # Open the GeoTIFF
+    with rasterio.open(geotiff_path) as src:
+        # Read metadata
+        meta = src.meta
+        height = src.height
+        width = src.width
+
+        # Update metadata for output raster
+        out_meta = meta.copy()
+        out_meta.update(
+            {"count": 1, "dtype": "uint16"}  # uint16 to support many instances
+        )
+
+        # Store all detections globally for NMS
+        all_detections = []
+
+        # Calculate the number of windows needed to cover the entire image
+        steps_y = math.ceil((height - overlap) / (window_size - overlap))
+        steps_x = math.ceil((width - overlap) / (window_size - overlap))
+
+        # Ensure we cover the entire image
+        last_y = height - window_size
+        last_x = width - window_size
+
+        total_windows = steps_y * steps_x
+        print(
+            f"Processing {total_windows} windows with size {window_size}x{window_size} and overlap {overlap}..."
+        )
+
+        # Create progress bar
+        pbar = tqdm(total=total_windows)
+
+        # Process in batches
+        batch_inputs = []
+        batch_positions = []
+        batch_count = 0
+
+        start_time = time.time()
+
+        # Slide window over the image
+        for i in range(steps_y + 1):  # +1 to ensure we reach the edge
+            y = min(i * (window_size - overlap), last_y)
+            y = max(0, y)  # Prevent negative indices
+
+            if y > last_y and i > 0:  # Skip if we've already covered the entire height
+                continue
+
+            for j in range(steps_x + 1):  # +1 to ensure we reach the edge
+                x = min(j * (window_size - overlap), last_x)
+                x = max(0, x)  # Prevent negative indices
+
+                if (
+                    x > last_x and j > 0
+                ):  # Skip if we've already covered the entire width
+                    continue
+
+                # Read window
+                window = src.read(window=Window(x, y, window_size, window_size))
+
+                # Check if window is valid
+                if window.shape[1] == 0 or window.shape[2] == 0:
+                    continue
+
+                # Handle edge cases where window might be smaller than expected
+                actual_height, actual_width = window.shape[1], window.shape[2]
+
+                # Convert to [C, H, W] format and normalize
+                image = window.astype(np.float32) / 255.0
+
+                # Handle different number of channels
+                if image.shape[0] > num_channels:
+                    image = image[:num_channels]
+                elif image.shape[0] < num_channels:
+                    # Pad with zeros if less than expected channels
+                    padded = np.zeros(
+                        (num_channels, image.shape[1], image.shape[2]), dtype=np.float32
+                    )
+                    padded[: image.shape[0]] = image
+                    image = padded
+
+                # Convert to tensor
+                image_tensor = torch.tensor(image, device=device)
+
+                # Add to batch
+                batch_inputs.append(image_tensor)
+                batch_positions.append((y, x, actual_height, actual_width))
+                batch_count += 1
+
+                # Process batch when it reaches the batch size or at the end
+                if batch_count == batch_size or (i == steps_y and j == steps_x):
+                    # Forward pass
+                    with torch.no_grad():
+                        outputs = model(batch_inputs)
+
+                    # Process each output in the batch
+                    for idx, output in enumerate(outputs):
+                        y_pos, x_pos, h, w = batch_positions[idx]
+
+                        # Process each detected instance
+                        if len(output["scores"]) > 0:
+                            # Get instances that meet confidence threshold
+                            keep = output["scores"] > confidence_threshold
+                            masks = output["masks"][keep].squeeze(1)
+                            scores = output["scores"][keep]
+                            boxes = output["boxes"][keep]
+
+                            # Convert to global coordinates and store
+                            for k in range(len(masks)):
+                                mask = masks[k].cpu().numpy() > 0.5
+                                score = scores[k].cpu().item()
+                                box = boxes[k].cpu().numpy()
+
+                                # Convert box to global coordinates
+                                global_box = [
+                                    box[0] + x_pos,
+                                    box[1] + y_pos,
+                                    box[2] + x_pos,
+                                    box[3] + y_pos,
+                                ]
+
+                                # Create global mask
+                                global_mask = np.zeros((height, width), dtype=bool)
+                                global_mask[y_pos : y_pos + h, x_pos : x_pos + w] = mask
+
+                                all_detections.append(
+                                    {
+                                        "mask": global_mask,
+                                        "score": score,
+                                        "box": global_box,
+                                    }
+                                )
+
+                    # Reset batch
+                    batch_inputs = []
+                    batch_positions = []
+                    batch_count = 0
+
+                    # Update progress bar
+                    pbar.update(len(outputs))
+
+        # Close progress bar
+        pbar.close()
+
+    print(f"Collected {len(all_detections)} detections before NMS")
+
+    # Apply Non-Maximum Suppression to handle overlapping detections
+    if len(all_detections) > 0:
+        # Convert to tensors for NMS
+        boxes = torch.tensor([det["box"] for det in all_detections])
+        scores = torch.tensor([det["score"] for det in all_detections])
+
+        # Apply NMS with IoU threshold
+        nms_threshold = 0.3  # IoU threshold for NMS
+        keep_indices = torchvision.ops.nms(boxes, scores, nms_threshold)
+
+        # Keep only the selected detections
+        final_detections = [all_detections[i] for i in keep_indices]
+        print(f"After NMS: {len(final_detections)} detections")
+
+        # Create final instance mask
+        instance_mask = np.zeros((height, width), dtype=np.uint16)
+
+        # Sort by score (highest first) for consistent ordering
+        final_detections.sort(key=lambda x: x["score"], reverse=True)
+
+        # Assign unique IDs to each detection
+        for instance_id, detection in enumerate(final_detections, 1):
+            mask = detection["mask"]
+            # Only assign to pixels that are not already assigned
+            available_pixels = (instance_mask == 0) & mask
+            instance_mask[available_pixels] = instance_id
+    else:
+        # No detections found
+        instance_mask = np.zeros((height, width), dtype=np.uint16)
+
+    # Record time
+    inference_time = time.time() - start_time
+    print(f"Instance segmentation completed in {inference_time:.2f} seconds")
+    print(
+        f"Final instances: {len(final_detections) if len(all_detections) > 0 else 0}"
+    )
+
+    # Save output
+    with rasterio.open(output_path, "w", **out_meta) as dst:
+        dst.write(instance_mask, 1)
+
+    print(f"Saved instance segmentation to {output_path}")
+
+    return output_path, inference_time
+
+
 def object_detection(
     input_path,
     output_path,
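A hedged usage sketch of the new function: the model comes from get_instance_segmentation_model (added later in this diff, and called here with the same arguments the wrappers use), and the file paths are hypothetical. The defaults above tile the raster with a half-window overlap and merge the per-window detections with a single global NMS pass:

import torch
from geoai.train import (
    get_instance_segmentation_model,
    instance_segmentation_inference_on_geotiff,
)

# Hypothetical paths; num_classes includes the background class.
model = get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True)
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))

output_path, seconds = instance_segmentation_inference_on_geotiff(
    model=model,
    geotiff_path="scene.tif",
    output_path="scene_instances.tif",
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
)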
@@ -3067,3 +3292,220 @@ def semantic_segmentation_batch(
             continue

     print(f"Batch processing completed. Results saved to {output_dir}")
+
+
+def train_instance_segmentation_model(
+    images_dir,
+    labels_dir,
+    output_dir,
+    num_classes=2,
+    num_channels=3,
+    batch_size=4,
+    num_epochs=10,
+    learning_rate=0.005,
+    seed=42,
+    val_split=0.2,
+    visualize=False,
+    device=None,
+    verbose=True,
+    **kwargs,
+):
+    """
+    Train an instance segmentation model using Mask R-CNN.
+
+    This is a wrapper function for train_MaskRCNN_model with clearer naming.
+
+    Args:
+        images_dir (str): Directory containing image GeoTIFF files.
+        labels_dir (str): Directory containing label GeoTIFF files.
+        output_dir (str): Directory to save model checkpoints and results.
+        num_classes (int): Number of classes (including background). Defaults to 2.
+        num_channels (int): Number of input channels. Defaults to 3.
+        batch_size (int): Batch size for training. Defaults to 4.
+        num_epochs (int): Number of training epochs. Defaults to 10.
+        learning_rate (float): Initial learning rate. Defaults to 0.005.
+        seed (int): Random seed for reproducibility. Defaults to 42.
+        val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
+        visualize (bool): Whether to generate visualizations. Defaults to False.
+        device (torch.device): Device to train on. If None, uses CUDA if available.
+        verbose (bool): If True, prints detailed training progress. Defaults to True.
+        **kwargs: Additional arguments passed to train_MaskRCNN_model.
+
+    Returns:
+        None: Model weights are saved to output_dir.
+    """
+    # Create model with the specified number of classes
+    model = get_instance_segmentation_model(
+        num_classes=num_classes, num_channels=num_channels, pretrained=True
+    )
+
+    return train_MaskRCNN_model(
+        images_dir=images_dir,
+        labels_dir=labels_dir,
+        output_dir=output_dir,
+        num_channels=num_channels,
+        model=model,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        learning_rate=learning_rate,
+        seed=seed,
+        val_split=val_split,
+        visualize=visualize,
+        device=device,
+        verbose=verbose,
+        **kwargs,
+    )
+
+
+def instance_segmentation(
+    input_path,
+    output_path,
+    model_path,
+    window_size=512,
+    overlap=256,
+    confidence_threshold=0.5,
+    batch_size=4,
+    num_channels=3,
+    num_classes=2,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform instance segmentation on a GeoTIFF using a pre-trained Mask R-CNN model.
+
+    This is a wrapper function for object_detection with clearer naming.
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save output mask GeoTIFF.
+        model_path (str): Path to trained model weights.
+        window_size (int): Size of sliding window for inference. Defaults to 512.
+        overlap (int): Overlap between adjacent windows. Defaults to 256.
+        confidence_threshold (float): Confidence threshold for predictions (0-1). Defaults to 0.5.
+        batch_size (int): Batch size for inference. Defaults to 4.
+        num_channels (int): Number of channels in the input image and model. Defaults to 3.
+        num_classes (int): Number of classes (including background). Defaults to 2.
+        device (torch.device): Device to run inference on. If None, uses CUDA if available.
+        **kwargs: Additional arguments passed to object_detection.
+
+    Returns:
+        None: Output mask is saved to output_path.
+    """
+    # Create model with the specified number of classes
+    model = get_instance_segmentation_model(
+        num_classes=num_classes, num_channels=num_channels, pretrained=True
+    )
+
+    # Load the trained model
+    if device is None:
+        device = get_device()
+
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+
+    # Use the proper instance segmentation inference function
+    return instance_segmentation_inference_on_geotiff(
+        model=model,
+        geotiff_path=input_path,
+        output_path=output_path,
+        window_size=window_size,
+        overlap=overlap,
+        confidence_threshold=confidence_threshold,
+        batch_size=batch_size,
+        num_channels=num_channels,
+        device=device,
+        **kwargs,
+    )
+
+
+def instance_segmentation_batch(
+    input_dir,
+    output_dir,
+    model_path,
+    window_size=512,
+    overlap=256,
+    confidence_threshold=0.5,
+    batch_size=4,
+    num_channels=3,
+    num_classes=2,
+    device=None,
+    **kwargs,
+):
+    """
+    Perform instance segmentation on multiple GeoTIFF files using a pre-trained Mask R-CNN model.
+
+    This is a wrapper function for object_detection_batch with clearer naming.
+
+    Args:
+        input_dir (str): Directory containing input GeoTIFF files.
+        output_dir (str): Directory to save output mask GeoTIFF files.
+        model_path (str): Path to trained model weights.
+        window_size (int): Size of sliding window for inference. Defaults to 512.
+        overlap (int): Overlap between adjacent windows. Defaults to 256.
+        confidence_threshold (float): Confidence threshold for predictions (0-1). Defaults to 0.5.
+        batch_size (int): Batch size for inference. Defaults to 4.
+        num_channels (int): Number of channels in the input image and model. Defaults to 3.
+        num_classes (int): Number of classes (including background). Defaults to 2.
+        device (torch.device): Device to run inference on. If None, uses CUDA if available.
+        **kwargs: Additional arguments passed to object_detection_batch.
+
+    Returns:
+        None: Output masks are saved to output_dir.
+    """
+    # Create model with the specified number of classes
+    model = get_instance_segmentation_model(
+        num_classes=num_classes, num_channels=num_channels, pretrained=True
+    )
+
+    # Load the trained model
+    if device is None:
+        device = get_device()
+
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+
+    # Process all GeoTIFF files in the input directory
+    import glob
+
+    input_files = glob.glob(os.path.join(input_dir, "*.tif")) + glob.glob(
+        os.path.join(input_dir, "*.tiff")
+    )
+
+    if not input_files:
+        print(f"No GeoTIFF files found in {input_dir}")
+        return
+
+    # Create output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Processing {len(input_files)} files...")
+
+    for input_file in input_files:
+        try:
+            # Generate output filename
+            base_name = os.path.splitext(os.path.basename(input_file))[0]
+            output_file = os.path.join(output_dir, f"{base_name}_instances.tif")
+
+            print(f"Processing {input_file}...")
+
+            # Run instance segmentation inference
+            instance_segmentation_inference_on_geotiff(
+                model=model,
+                geotiff_path=input_file,
+                output_path=output_file,
+                window_size=window_size,
+                overlap=overlap,
+                confidence_threshold=confidence_threshold,
+                batch_size=batch_size,
+                num_channels=num_channels,
+                device=device,
+                **kwargs,
+            )
+
+            print(f"Saved result to {output_file}")
+
+        except Exception as e:
+            print(f"Error processing {input_file}: {str(e)}")
+            continue

+    print(f"Batch processing completed. Results saved to {output_dir}")
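Taken together, the wrappers above give a complete train-then-predict workflow. A minimal sketch with hypothetical directories of paired image/label GeoTIFF tiles:

from geoai.train import train_instance_segmentation_model, instance_segmentation

# Train Mask R-CNN on hypothetical tiles; weights land in output_dir.
train_instance_segmentation_model(
    images_dir="tiles/images",
    labels_dir="tiles/labels",
    output_dir="runs/maskrcnn",
    num_epochs=10,
)

# Run sliding-window inference with the best checkpoint.
instance_segmentation(
    input_path="scene.tif",
    output_path="scene_instances.tif",
    model_path="runs/maskrcnn/best_model.pth",
)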
geoai/utils.py
CHANGED
@@ -157,7 +157,8 @@ def view_image(
     image: Union[np.ndarray, torch.Tensor],
     transpose: bool = False,
     bdx: Optional[int] = None,
-
+    clip_percentiles: Optional[Tuple[float, float]] = (2, 98),
+    gamma: Optional[float] = None,
     figsize: Tuple[int, int] = (10, 5),
     axis_off: bool = True,
     title: Optional[str] = None,
@@ -185,7 +186,7 @@ def view_image(
     elif isinstance(image, str):
         image = rasterio.open(image).read().transpose(1, 2, 0)

-    plt.figure(figsize=figsize)
+    ax = plt.figure(figsize=figsize)

     if transpose:
         image = image.transpose(1, 2, 0)
@@ -196,8 +197,14 @@ def view_image(
     if len(image.shape) > 2 and image.shape[2] > 3:
         image = image[:, :, 0:3]

-    if
-
+    if clip_percentiles is not None:
+        p_low, p_high = clip_percentiles
+        lower = np.percentile(image, p_low)
+        upper = np.percentile(image, p_high)
+        image = np.clip((image - lower) / (upper - lower), 0, 1)
+
+    if gamma is not None:
+        image = np.power(image, gamma)

     plt.imshow(image, **kwargs)
     if axis_off:
@@ -207,6 +214,8 @@ def view_image(
     plt.show()
     plt.close()

+    return ax
+

 def plot_images(
     images: Iterable[torch.Tensor],
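The new clip_percentiles and gamma arguments implement a standard percentile contrast stretch. A self-contained sketch of the equivalent NumPy math (function name hypothetical):

import numpy as np

def percentile_stretch(image, clip_percentiles=(2, 98), gamma=None):
    # Clip to the given percentiles and rescale to [0, 1],
    # matching the view_image hunk above.
    lower, upper = np.percentile(image, clip_percentiles)
    image = np.clip((image - lower) / (upper - lower), 0, 1)
    if gamma is not None:
        image = np.power(image, gamma)  # gamma < 1 brightens, > 1 darkens
    return image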
@@ -5658,7 +5667,7 @@ def orthogonalize(
     if len(df) == 0:
         return ring

-    # If we have a triangle-like result (3 segments), return the original shape
+    # If we have a triangle-like result (3 segments or less), return the original shape
     if len(df) <= 3:
         return ring

@@ -5669,8 +5678,116 @@ def orthogonalize(
     if len(joined_ring) == 0 or len(joined_ring[0]) < 3:
         return ring

-    #
-
+    # Enhanced validation: check for triangular result and geometric validity
+    result_coords = joined_ring[0]
+
+    # If result has 3 or fewer points (triangle), use original
+    if len(result_coords) <= 3:  # 2 points + closing point (degenerate)
+        return ring
+
+    # Additional validation: check for degenerate geometry
+    # Calculate area ratio to detect if the shape got severely distorted
+    def calculate_polygon_area(coords):
+        if len(coords) < 3:
+            return 0
+        area = 0
+        n = len(coords)
+        for i in range(n):
+            j = (i + 1) % n
+            area += coords[i][0] * coords[j][1]
+            area -= coords[j][0] * coords[i][1]
+        return abs(area) / 2
+
+    original_area = calculate_polygon_area(ring)
+    result_area = calculate_polygon_area(result_coords)
+
+    # If the area changed dramatically (more than 30% shrinkage or 300% growth), use original
+    if original_area > 0 and result_area > 0:
+        area_ratio = result_area / original_area
+        if area_ratio < 0.3 or area_ratio > 3.0:
+            return ring
+
+    # Check for triangular spikes and problematic artifacts
+    very_acute_angle_count = 0
+    triangular_spike_detected = False
+
+    for i in range(len(result_coords) - 1):  # -1 to exclude closing point
+        p1 = result_coords[i - 1]
+        p2 = result_coords[i]
+        p3 = result_coords[(i + 1) % (len(result_coords) - 1)]
+
+        # Calculate angle at p2
+        v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
+        v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])
+
+        v1_norm = np.linalg.norm(v1)
+        v2_norm = np.linalg.norm(v2)
+
+        if v1_norm > 0 and v2_norm > 0:
+            cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
+            cos_angle = np.clip(cos_angle, -1, 1)
+            angle = np.arccos(cos_angle)
+
+            # Count very acute angles (< 20 degrees) - these are likely spikes
+            if angle < np.pi / 9:  # 20 degrees
+                very_acute_angle_count += 1
+                # If it's very acute with short sides, it's definitely a spike
+                if v1_norm < 5 or v2_norm < 5:
+                    triangular_spike_detected = True
+
+    # Check for excessively long edges that might be artifacts
+    edge_lengths = []
+    for i in range(len(result_coords) - 1):
+        edge_len = np.sqrt(
+            (result_coords[i + 1][0] - result_coords[i][0]) ** 2
+            + (result_coords[i + 1][1] - result_coords[i][1]) ** 2
+        )
+        edge_lengths.append(edge_len)
+
+    excessive_edge_detected = False
+    if len(edge_lengths) > 0:
+        avg_edge_length = np.mean(edge_lengths)
+        max_edge_length = np.max(edge_lengths)
+        # Only reject if edge is extremely disproportionate (8x average)
+        if max_edge_length > avg_edge_length * 8:
+            excessive_edge_detected = True
+
+    # Check for triangular artifacts by detecting spikes that extend beyond bounds
+    # Calculate original bounds
+    orig_xs = [p[0] for p in ring]
+    orig_ys = [p[1] for p in ring]
+    orig_min_x, orig_max_x = min(orig_xs), max(orig_xs)
+    orig_min_y, orig_max_y = min(orig_ys), max(orig_ys)
+    orig_width = orig_max_x - orig_min_x
+    orig_height = orig_max_y - orig_min_y
+
+    # Calculate result bounds
+    result_xs = [p[0] for p in result_coords]
+    result_ys = [p[1] for p in result_coords]
+    result_min_x, result_max_x = min(result_xs), max(result_xs)
+    result_min_y, result_max_y = min(result_ys), max(result_ys)
+
+    # Stricter bounds checking to catch triangular artifacts
+    bounds_extension_detected = False
+    # More conservative: only allow 10% extension
+    tolerance_x = max(orig_width * 0.1, 1.0)  # 10% tolerance, at least 1 unit
+    tolerance_y = max(orig_height * 0.1, 1.0)  # 10% tolerance, at least 1 unit
+
+    if (
+        result_min_x < orig_min_x - tolerance_x
+        or result_max_x > orig_max_x + tolerance_x
+        or result_min_y < orig_min_y - tolerance_y
+        or result_max_y > orig_max_y + tolerance_y
+    ):
+        bounds_extension_detected = True
+
+    # Reject if we detect triangular spikes, excessive edges, or bounds violations
+    if (
+        triangular_spike_detected
+        or very_acute_angle_count > 2  # Multiple very acute angles
+        or excessive_edge_detected
+        or bounds_extension_detected
+    ):  # Any significant bounds extension
         return ring

     # Convert back to a list and ensure it's closed
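The calculate_polygon_area helper above is the shoelace formula. A quick standalone sanity check of the area-ratio guard (the function body is copied from the hunk):

def calculate_polygon_area(coords):
    # Shoelace formula over a closed coordinate ring.
    if len(coords) < 3:
        return 0
    area = 0
    n = len(coords)
    for i in range(n):
        j = (i + 1) % n
        area += coords[i][0] * coords[j][1]
        area -= coords[j][0] * coords[i][1]
    return abs(area) / 2

# A 4x4 square has area 16; results whose area falls below 0.3x or
# above 3.0x of the original are rejected by the check above.
assert calculate_polygon_area([[0, 0], [4, 0], [4, 4], [0, 4]]) == 16.0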
@@ -5887,37 +6004,86 @@ def orthogonalize(
             }
         )

+    # Improved fix: Prevent merging that would create triangular or problematic shapes
     if (
-        len(ortho_list) >
+        len(ortho_list) > 3 and ortho_list[0]["angle"] == ortho_list[-1]["angle"]
     ):  # join first and last segment if they're in same direction
-        totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
-        merge_cx = (
-            (ortho_list[0]["cx"] * ortho_list[0]["len"])
-            + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
-        ) / totlen
-
-        merge_cy = (
-            (ortho_list[0]["cy"] * ortho_list[0]["len"])
-            + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
-        ) / totlen
-
-        rot_angle = ortho_list[0]["angle"]
-        X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
-        X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
-        Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
-        Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
-
-        ortho_list[-1] = {
-            "x1": X1,
-            "y1": Y1,
-            "x2": X2,
-            "y2": Y2,
-            "len": totlen,
-            "cx": merge_cx,
-            "cy": merge_cy,
-            "angle": rot_angle,
-        }
-        ortho_list = ortho_list[1:]
+        # Check if merging would result in 3 or 4 segments (potentially triangular)
+        resulting_segments = len(ortho_list) - 1
+        if resulting_segments <= 4:
+            # For very small polygons, be extra cautious about merging
+            # Calculate the spatial relationship between first and last segments
+            first_center = np.array([ortho_list[0]["cx"], ortho_list[0]["cy"]])
+            last_center = np.array([ortho_list[-1]["cx"], ortho_list[-1]["cy"]])
+            center_distance = np.linalg.norm(first_center - last_center)
+
+            # Get average segment length for comparison
+            avg_length = sum(seg["len"] for seg in ortho_list) / len(ortho_list)
+
+            # Only merge if segments are close enough and it won't create degenerate shapes
+            if center_distance > avg_length * 1.5:
+                # Skip merging - segments are too far apart
+                pass
+            else:
+                # Proceed with merging only for well-connected segments
+                totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
+                merge_cx = (
+                    (ortho_list[0]["cx"] * ortho_list[0]["len"])
+                    + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
+                ) / totlen
+
+                merge_cy = (
+                    (ortho_list[0]["cy"] * ortho_list[0]["len"])
+                    + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
+                ) / totlen
+
+                rot_angle = ortho_list[0]["angle"]
+                X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
+                X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
+                Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
+                Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
+
+                ortho_list[-1] = {
+                    "x1": X1,
+                    "y1": Y1,
+                    "x2": X2,
+                    "y2": Y2,
+                    "len": totlen,
+                    "cx": merge_cx,
+                    "cy": merge_cy,
+                    "angle": rot_angle,
+                }
+                ortho_list = ortho_list[1:]
+        else:
+            # For larger polygons, proceed with standard merging
+            totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
+            merge_cx = (
+                (ortho_list[0]["cx"] * ortho_list[0]["len"])
+                + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
+            ) / totlen
+
+            merge_cy = (
+                (ortho_list[0]["cy"] * ortho_list[0]["len"])
+                + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
+            ) / totlen
+
+            rot_angle = ortho_list[0]["angle"]
+            X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
+            X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
+            Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
+            Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
+
+            ortho_list[-1] = {
+                "x1": X1,
+                "y1": Y1,
+                "x2": X2,
+                "y2": Y2,
+                "len": totlen,
+                "cx": merge_cx,
+                "cy": merge_cy,
+                "angle": rot_angle,
+            }
+            ortho_list = ortho_list[1:]
     ortho_df = pd.DataFrame(ortho_list)
     return ortho_df

@@ -6026,12 +6192,49 @@ def orthogonalize(
             np.sqrt((x4 - x3) ** 2 + (y4 - y3) ** 2),
         )

-        #
-
-
+        # Improved intersection validation
+        # Calculate angle between segments to detect sharp corners
+        v1 = np.array([x2 - x1, y2 - y1])
+        v2 = np.array([x4 - x3, y4 - y3])
+        v1_norm = np.linalg.norm(v1)
+        v2_norm = np.linalg.norm(v2)
+
+        if v1_norm > 0 and v2_norm > 0:
+            cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
+            cos_angle = np.clip(cos_angle, -1, 1)
+            angle = np.arccos(cos_angle)
+
+            # Check for very sharp angles that could create triangular artifacts
+            is_sharp_angle = (
+                angle < np.pi / 6 or angle > 5 * np.pi / 6
+            )  # <30° or >150°
+        else:
+            is_sharp_angle = False
+
+        # Determine whether to use intersection or segment endpoint
+        if (
+            dist_to_seg1 > max_len * 0.5
+            or dist_to_seg2 > max_len * 0.5
+            or is_sharp_angle
+        ):
+            # Use a more conservative approach for problematic intersections
+            # Use the closer endpoint between segments
+            dist_x2_to_seg2 = min(
+                np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2),
+                np.sqrt((x2 - x4) ** 2 + (y2 - y4) ** 2),
+            )
+            dist_x3_to_seg1 = min(
+                np.sqrt((x3 - x1) ** 2 + (y3 - y1) ** 2),
+                np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2),
+            )
+
+            if dist_x2_to_seg2 <= dist_x3_to_seg1:
+                ring.append([x2, y2])
+            else:
+                ring.append([x3, y3])
         else:
             ring.append(intersection)
-    except Exception
+    except Exception:
         # If intersection calculation fails, use the endpoint of the first segment
         ring.append([x2, y2])

@@ -6057,11 +6260,42 @@ def orthogonalize(
             np.sqrt((x4 - x3) ** 2 + (y4 - y3) ** 2),
         )

-
-
+        # Apply same sharp angle detection for closing segment
+        v1 = np.array([x2 - x1, y2 - y1])
+        v2 = np.array([x4 - x3, y4 - y3])
+        v1_norm = np.linalg.norm(v1)
+        v2_norm = np.linalg.norm(v2)
+
+        if v1_norm > 0 and v2_norm > 0:
+            cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
+            cos_angle = np.clip(cos_angle, -1, 1)
+            angle = np.arccos(cos_angle)
+            is_sharp_angle = angle < np.pi / 6 or angle > 5 * np.pi / 6
+        else:
+            is_sharp_angle = False
+
+        if (
+            dist_to_seg1 > max_len * 0.5
+            or dist_to_seg2 > max_len * 0.5
+            or is_sharp_angle
+        ):
+            # Use conservative approach for closing segment
+            dist_x2_to_seg2 = min(
+                np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2),
+                np.sqrt((x2 - x4) ** 2 + (y2 - y4) ** 2),
+            )
+            dist_x3_to_seg1 = min(
+                np.sqrt((x3 - x1) ** 2 + (y3 - y1) ** 2),
+                np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2),
+            )
+
+            if dist_x2_to_seg2 <= dist_x3_to_seg1:
+                ring.append([x2, y2])
+            else:
+                ring.append([x3, y3])
         else:
             ring.append(intersection)
-    except Exception
+    except Exception:
         # If intersection calculation fails, use the endpoint of the last segment
         ring.append([x2, y2])

@@ -7219,41 +7453,61 @@ def plot_performance_metrics(history_path, figsize=(15, 5), verbose=True):
     """
     history = torch.load(history_path)

+    # Handle different key naming conventions
+    train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
+    val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
+    val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
+    val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
+
+    # Determine number of subplots based on available metrics
+    has_dice = val_dice_key in history
+    n_plots = 3 if has_dice else 2
+    figsize = (15, 5) if has_dice else (10, 5)
+
     plt.figure(figsize=figsize)

-
-    plt.
-
+    # Plot loss
+    plt.subplot(1, n_plots, 1)
+    if train_loss_key in history:
+        plt.plot(history[train_loss_key], label="Train Loss")
+    if val_loss_key in history:
+        plt.plot(history[val_loss_key], label="Val Loss")
     plt.title("Loss")
     plt.xlabel("Epoch")
     plt.ylabel("Loss")
     plt.legend()
     plt.grid(True)

-
-    plt.
+    # Plot IoU
+    plt.subplot(1, n_plots, 2)
+    if val_iou_key in history:
+        plt.plot(history[val_iou_key], label="Val IoU")
     plt.title("IoU Score")
     plt.xlabel("Epoch")
     plt.ylabel("IoU")
     plt.legend()
     plt.grid(True)

-
-
-
-
-
-
-
+    # Plot Dice if available
+    if has_dice:
+        plt.subplot(1, n_plots, 3)
+        plt.plot(history[val_dice_key], label="Val Dice")
+        plt.title("Dice Score")
+        plt.xlabel("Epoch")
+        plt.ylabel("Dice")
+        plt.legend()
+        plt.grid(True)

     plt.tight_layout()
     plt.show()

     if verbose:
-
-
-
-
+        if val_iou_key in history:
+            print(f"Best IoU: {max(history[val_iou_key]):.4f}")
+            print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
+        if val_dice_key in history:
+            print(f"Best Dice: {max(history[val_dice_key]):.4f}")
+            print(f"Final Dice: {history[val_dice_key][-1]:.4f}")


 def get_device():
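With the key fallbacks above, the same plotting function now accepts both the training_history.pth written by train_MaskRCNN_model in this release (singular keys such as "val_iou") and histories saved with plural keys ("val_ious", "val_dices"). A minimal call, path hypothetical:

from geoai.utils import plot_performance_metrics

plot_performance_metrics("runs/maskrcnn/training_history.pth")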
@@ -7280,105 +7534,123 @@ def plot_prediction_comparison(
     prediction_colormap: str = "gray",
     ground_truth_colormap: str = "gray",
     original_colormap: Optional[str] = None,
+    indexes: Optional[List[int]] = None,
+    divider: Optional[float] = None,
 ):
-    """
-
+    """Plot original image, prediction, and optional ground truth side by side.
+
+    Supports input as file paths, NumPy arrays, or PIL Images. For multi-band
+    images, selected channels can be specified via `indexes`. If the image data
+    is not normalized (e.g., Sentinel-2 [0, 10000]), the `divider` can be used
+    to scale values for visualization.

     Args:
-        original_image
-
-
-
-
-
-
-
-
-
+        original_image (Union[str, np.ndarray, Image.Image]):
+            Original input image as a file path, NumPy array, or PIL Image.
+        prediction_image (Union[str, np.ndarray, Image.Image]):
+            Predicted segmentation mask image.
+        ground_truth_image (Optional[Union[str, np.ndarray, Image.Image]], optional):
+            Ground truth mask image. Defaults to None.
+        titles (Optional[List[str]], optional):
+            List of titles for the subplots. If not provided, default titles are used.
+        figsize (Tuple[int, int], optional):
+            Size of the entire figure in inches. Defaults to (15, 5).
+        save_path (Optional[str], optional):
+            If specified, saves the figure to this path. Defaults to None.
+        show_plot (bool, optional):
+            Whether to display the figure using plt.show(). Defaults to True.
+        prediction_colormap (str, optional):
+            Colormap to use for the prediction mask. Defaults to "gray".
+        ground_truth_colormap (str, optional):
+            Colormap to use for the ground truth mask. Defaults to "gray".
+        original_colormap (Optional[str], optional):
+            Colormap to use for the original image if it's grayscale. Defaults to None.
+        indexes (Optional[List[int]], optional):
+            List of band/channel indexes (0-based for NumPy, 1-based for rasterio) to extract from the original image.
+            Useful for multi-band imagery like Sentinel-2. Defaults to None.
+        divider (Optional[float], optional):
+            Value to divide the original image by for normalization (e.g., 10000 for reflectance). Defaults to None.

     Returns:
-        matplotlib.figure.Figure:
+        matplotlib.figure.Figure:
+            The generated matplotlib figure object.
     """

-    def _load_image(img_input):
+    def _load_image(img_input, indexes=None):
         """Helper function to load image from various input types."""
         if isinstance(img_input, str):
-            # File path
             if img_input.lower().endswith((".tif", ".tiff")):
-                # Handle GeoTIFF files
                 with rasterio.open(img_input) as src:
-
-
-
-
+                    if indexes:
+                        img = src.read(indexes)  # 1-based
+                        img = (
+                            np.transpose(img, (1, 2, 0)) if len(indexes) > 1 else img[0]
+                        )
                     else:
-
-                        img
+                        img = src.read()
+                        if img.shape[0] == 1:
+                            img = img[0]
+                        else:
+                            img = np.transpose(img, (1, 2, 0))
             else:
-                # Regular image file
                 img = np.array(Image.open(img_input))
         elif isinstance(img_input, Image.Image):
-            # PIL Image
             img = np.array(img_input)
         elif isinstance(img_input, np.ndarray):
-            # NumPy array
             img = img_input
+            if indexes is not None and img.ndim == 3:
+                img = img[:, :, indexes]
         else:
             raise ValueError(f"Unsupported image type: {type(img_input)}")
-
         return img

     # Load images
-    original = _load_image(original_image)
+    original = _load_image(original_image, indexes=indexes)
     prediction = _load_image(prediction_image)
     ground_truth = (
         _load_image(ground_truth_image) if ground_truth_image is not None else None
     )

-    #
-
+    # Apply divider normalization if requested
+    if divider is not None and isinstance(original, np.ndarray) and original.ndim == 3:
+        original = np.clip(original.astype(np.float32) / divider, 0, 1)

-    #
+    # Determine layout
+    num_plots = 3 if ground_truth is not None else 2
     fig, axes = plt.subplots(1, num_plots, figsize=figsize)
     if num_plots == 2:
         axes = [axes[0], axes[1]]

-    # Default titles
     if titles is None:
         titles = ["Original Image", "Prediction"]
         if ground_truth is not None:
             titles.append("Ground Truth")

-    # Plot original
-    if
-    # RGB or RGBA image
+    # Plot original
+    if original.ndim == 3 and original.shape[2] in [3, 4]:
         axes[0].imshow(original)
     else:
-        # Grayscale or single channel
         axes[0].imshow(original, cmap=original_colormap)
     axes[0].set_title(titles[0])
     axes[0].axis("off")

-    #
+    # Prediction
     axes[1].imshow(prediction, cmap=prediction_colormap)
     axes[1].set_title(titles[1])
     axes[1].axis("off")

-    #
+    # Ground truth
     if ground_truth is not None:
         axes[2].imshow(ground_truth, cmap=ground_truth_colormap)
         axes[2].set_title(titles[2])
         axes[2].axis("off")

-    # Adjust layout
     plt.tight_layout()

-    # Save if requested
     if save_path:
         plt.savefig(save_path, dpi=300, bbox_inches="tight")
         print(f"Plot saved to: {save_path}")

-    # Show plot
     if show_plot:
         plt.show()

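A hedged sketch of the new multi-band handling in plot_prediction_comparison; the scene paths are hypothetical, and bands 4, 3, 2 with a divider of 10000 are the usual choice for a Sentinel-2 RGB composite (1-based indexes, as rasterio reads them):

from geoai.utils import plot_prediction_comparison

plot_prediction_comparison(
    original_image="s2_scene.tif",
    prediction_image="s2_prediction.tif",
    indexes=[4, 3, 2],  # red, green, blue bands (1-based for GeoTIFF input)
    divider=10000,      # scale reflectance to [0, 1] for display
)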
geoai_py-0.8.3.dist-info/RECORD
ADDED
@@ -0,0 +1,17 @@
+geoai/__init__.py,sha256=Ks7SlSE7TkToYHB3EhXxIQLsYXuzfwHjyEUKJjkd8SA,3765
+geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
+geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
+geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
+geoai/geoai.py,sha256=_Ar7PJgWpN86tm1YhzLVj7r1lYA3faRJT5brf1WzcwQ,9631
+geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
+geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
+geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
+geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
+geoai/train.py,sha256=0uIz3i3sH-Knyq9wBrKLy22f--oLk2m-fCdhfwJosNU,129112
+geoai/utils.py,sha256=jNSXXQ056iyeDHDnkPUJ0_TLe-NonGbXh2YCN-gM50c,297270
+geoai_py-0.8.3.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
+geoai_py-0.8.3.dist-info/METADATA,sha256=ifNb98sm8g-Bsicc_R_9ljKahRXjibbEbSxxG8ga2mk,6661
+geoai_py-0.8.3.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
+geoai_py-0.8.3.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
+geoai_py-0.8.3.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
+geoai_py-0.8.3.dist-info/RECORD,,
geoai_py-0.8.1.dist-info/RECORD
DELETED
@@ -1,17 +0,0 @@
-geoai/__init__.py,sha256=NRaQhG5-sDLmpYXZXhbJ_LPkCEQbT-71qDAspoNEyxk,3765
-geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
-geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
-geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
-geoai/geoai.py,sha256=HU6epCjpk228J65ZXmxY6GKlHg_ncmWV3UQUr_f8QTM,9447
-geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
-geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
-geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
-geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
-geoai/train.py,sha256=Mrsb0yMVnprQHld3zDvA-puc-r8hGm1XgG0j2GGIn7E,112845
-geoai/utils.py,sha256=WQ_h39LpU65i-LP6GfFJLZBpiY0fyPdFc7dZcuyZ24s,284611
-geoai_py-0.8.1.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
-geoai_py-0.8.1.dist-info/METADATA,sha256=zOOCsS3BTaqIjl3DPsbxDDba4kK7kDKgNk2ytC_sYJ8,6661
-geoai_py-0.8.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
-geoai_py-0.8.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
-geoai_py-0.8.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
-geoai_py-0.8.1.dist-info/RECORD,,
{geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/WHEEL
File without changes
{geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/entry_points.txt
File without changes
{geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/licenses/LICENSE
File without changes
{geoai_py-0.8.1.dist-info → geoai_py-0.8.3.dist-info}/top_level.txt
File without changes