geoai-py 0.8.1__py2.py3-none-any.whl → 0.8.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.8.1"
5
+ __version__ = "0.8.3"
6
6
 
7
7
 
8
8
  import os
geoai/geoai.py CHANGED
@@ -32,6 +32,11 @@ from .train import (
32
32
  train_segmentation_model,
33
33
  semantic_segmentation,
34
34
  semantic_segmentation_batch,
35
+ train_instance_segmentation_model,
36
+ instance_segmentation,
37
+ instance_segmentation_batch,
38
+ get_instance_segmentation_model,
39
+ instance_segmentation_inference_on_geotiff,
35
40
  )
36
41
  from .utils import *
37
42
 
geoai/train.py CHANGED
@@ -750,6 +750,15 @@ def train_MaskRCNN_model(
750
750
  start_epoch = 0
751
751
  best_iou = 0
752
752
 
753
+ # Initialize training history
754
+ training_history = {
755
+ "train_loss": [],
756
+ "val_loss": [],
757
+ "val_iou": [],
758
+ "epochs": [],
759
+ "lr": [],
760
+ }
761
+
753
762
  # Load pretrained model if provided
754
763
  if pretrained_model_path:
755
764
  if not os.path.exists(pretrained_model_path):
@@ -800,6 +809,13 @@ def train_MaskRCNN_model(
800
809
  # Evaluate
801
810
  eval_metrics = evaluate(model, val_loader, device)
802
811
 
812
+ # Record training history
813
+ training_history["train_loss"].append(train_loss)
814
+ training_history["val_loss"].append(eval_metrics["loss"])
815
+ training_history["val_iou"].append(eval_metrics["IoU"])
816
+ training_history["epochs"].append(epoch + 1)
817
+ training_history["lr"].append(optimizer.param_groups[0]["lr"])
818
+
803
819
  # Print metrics
804
820
  print(
805
821
  f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {eval_metrics['loss']:.4f}, Val IoU: {eval_metrics['IoU']:.4f}"
@@ -811,33 +827,11 @@ def train_MaskRCNN_model(
811
827
  print(f"Saving best model with IoU: {best_iou:.4f}")
812
828
  torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
813
829
 
814
- # Save checkpoint every 10 epochs
815
- if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
816
- torch.save(
817
- {
818
- "epoch": epoch,
819
- "model_state_dict": model.state_dict(),
820
- "optimizer_state_dict": optimizer.state_dict(),
821
- "scheduler_state_dict": lr_scheduler.state_dict(),
822
- "best_iou": best_iou,
823
- },
824
- os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
825
- )
826
-
827
830
  # Save final model
828
831
  torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
829
832
 
830
- # Save full checkpoint of final state
831
- torch.save(
832
- {
833
- "epoch": num_epochs - 1,
834
- "model_state_dict": model.state_dict(),
835
- "optimizer_state_dict": optimizer.state_dict(),
836
- "scheduler_state_dict": lr_scheduler.state_dict(),
837
- "best_iou": best_iou,
838
- },
839
- os.path.join(output_dir, "final_checkpoint.pth"),
840
- )
833
+ # Save training history
834
+ torch.save(training_history, os.path.join(output_dir, "training_history.pth"))
841
835
 
842
836
  # Load best model for evaluation and visualization
843
837
  model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
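The training loop above records per-epoch metrics in a training_history dictionary and saves it as training_history.pth in output_dir. A minimal sketch of inspecting that file after training; the path and the best-epoch lookup below are illustrative assumptions, not part of the package:

    import torch

    # Load the history dictionary written by train_MaskRCNN_model
    history = torch.load("output/training_history.pth")  # placeholder path

    # Keys follow the dictionary initialized above
    print(sorted(history))  # ['epochs', 'lr', 'train_loss', 'val_iou', 'val_loss']
    best = max(history["val_iou"])
    best_epoch = history["epochs"][history["val_iou"].index(best)]
    print(f"Best validation IoU {best:.4f} at epoch {best_epoch}")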
@@ -1101,6 +1095,237 @@ def inference_on_geotiff(
1101
1095
  return output_path, inference_time
1102
1096
 
1103
1097
 
1098
+ def instance_segmentation_inference_on_geotiff(
1099
+ model,
1100
+ geotiff_path,
1101
+ output_path,
1102
+ window_size=512,
1103
+ overlap=256,
1104
+ confidence_threshold=0.5,
1105
+ batch_size=4,
1106
+ num_channels=3,
1107
+ device=None,
1108
+ **kwargs,
1109
+ ):
1110
+ """
1111
+ Perform instance segmentation inference on a large GeoTIFF using a sliding window approach.
1112
+
1113
+ This function collects all detections first, then applies non-maximum suppression
1114
+ to handle overlapping detections from different windows, preventing artifacts.
1115
+
1116
+ Args:
1117
+ model (torch.nn.Module): Trained model for inference.
1118
+ geotiff_path (str): Path to input GeoTIFF file.
1119
+ output_path (str): Path to save output instance mask GeoTIFF.
1120
+ window_size (int): Size of sliding window for inference.
1121
+ overlap (int): Overlap between adjacent windows.
1122
+ confidence_threshold (float): Confidence threshold for predictions (0-1).
1123
+ batch_size (int): Batch size for inference.
1124
+ num_channels (int): Number of channels to use from the input image.
1125
+ device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
1126
+ **kwargs: Additional arguments.
1127
+
1128
+ Returns:
1129
+ tuple: Tuple containing output path and inference time in seconds.
1130
+ """
1131
+ if device is None:
1132
+ device = get_device()
1133
+
1134
+ # Put model in evaluation mode
1135
+ model.to(device)
1136
+ model.eval()
1137
+
1138
+ # Open the GeoTIFF
1139
+ with rasterio.open(geotiff_path) as src:
1140
+ # Read metadata
1141
+ meta = src.meta
1142
+ height = src.height
1143
+ width = src.width
1144
+
1145
+ # Update metadata for output raster
1146
+ out_meta = meta.copy()
1147
+ out_meta.update(
1148
+ {"count": 1, "dtype": "uint16"} # uint16 to support many instances
1149
+ )
1150
+
1151
+ # Store all detections globally for NMS
1152
+ all_detections = []
1153
+
1154
+ # Calculate the number of windows needed to cover the entire image
1155
+ steps_y = math.ceil((height - overlap) / (window_size - overlap))
1156
+ steps_x = math.ceil((width - overlap) / (window_size - overlap))
1157
+
1158
+ # Ensure we cover the entire image
1159
+ last_y = height - window_size
1160
+ last_x = width - window_size
1161
+
1162
+ total_windows = steps_y * steps_x
1163
+ print(
1164
+ f"Processing {total_windows} windows with size {window_size}x{window_size} and overlap {overlap}..."
1165
+ )
1166
+
1167
+ # Create progress bar
1168
+ pbar = tqdm(total=total_windows)
1169
+
1170
+ # Process in batches
1171
+ batch_inputs = []
1172
+ batch_positions = []
1173
+ batch_count = 0
1174
+
1175
+ start_time = time.time()
1176
+
1177
+ # Slide window over the image
1178
+ for i in range(steps_y + 1): # +1 to ensure we reach the edge
1179
+ y = min(i * (window_size - overlap), last_y)
1180
+ y = max(0, y) # Prevent negative indices
1181
+
1182
+ if y > last_y and i > 0: # Skip if we've already covered the entire height
1183
+ continue
1184
+
1185
+ for j in range(steps_x + 1): # +1 to ensure we reach the edge
1186
+ x = min(j * (window_size - overlap), last_x)
1187
+ x = max(0, x) # Prevent negative indices
1188
+
1189
+ if (
1190
+ x > last_x and j > 0
1191
+ ): # Skip if we've already covered the entire width
1192
+ continue
1193
+
1194
+ # Read window
1195
+ window = src.read(window=Window(x, y, window_size, window_size))
1196
+
1197
+ # Check if window is valid
1198
+ if window.shape[1] == 0 or window.shape[2] == 0:
1199
+ continue
1200
+
1201
+ # Handle edge cases where window might be smaller than expected
1202
+ actual_height, actual_width = window.shape[1], window.shape[2]
1203
+
1204
+ # Convert to [C, H, W] format and normalize
1205
+ image = window.astype(np.float32) / 255.0
1206
+
1207
+ # Handle different number of channels
1208
+ if image.shape[0] > num_channels:
1209
+ image = image[:num_channels]
1210
+ elif image.shape[0] < num_channels:
1211
+ # Pad with zeros if less than expected channels
1212
+ padded = np.zeros(
1213
+ (num_channels, image.shape[1], image.shape[2]), dtype=np.float32
1214
+ )
1215
+ padded[: image.shape[0]] = image
1216
+ image = padded
1217
+
1218
+ # Convert to tensor
1219
+ image_tensor = torch.tensor(image, device=device)
1220
+
1221
+ # Add to batch
1222
+ batch_inputs.append(image_tensor)
1223
+ batch_positions.append((y, x, actual_height, actual_width))
1224
+ batch_count += 1
1225
+
1226
+ # Process batch when it reaches the batch size or at the end
1227
+ if batch_count == batch_size or (i == steps_y and j == steps_x):
1228
+ # Forward pass
1229
+ with torch.no_grad():
1230
+ outputs = model(batch_inputs)
1231
+
1232
+ # Process each output in the batch
1233
+ for idx, output in enumerate(outputs):
1234
+ y_pos, x_pos, h, w = batch_positions[idx]
1235
+
1236
+ # Process each detected instance
1237
+ if len(output["scores"]) > 0:
1238
+ # Get instances that meet confidence threshold
1239
+ keep = output["scores"] > confidence_threshold
1240
+ masks = output["masks"][keep].squeeze(1)
1241
+ scores = output["scores"][keep]
1242
+ boxes = output["boxes"][keep]
1243
+
1244
+ # Convert to global coordinates and store
1245
+ for k in range(len(masks)):
1246
+ mask = masks[k].cpu().numpy() > 0.5
1247
+ score = scores[k].cpu().item()
1248
+ box = boxes[k].cpu().numpy()
1249
+
1250
+ # Convert box to global coordinates
1251
+ global_box = [
1252
+ box[0] + x_pos,
1253
+ box[1] + y_pos,
1254
+ box[2] + x_pos,
1255
+ box[3] + y_pos,
1256
+ ]
1257
+
1258
+ # Create global mask
1259
+ global_mask = np.zeros((height, width), dtype=bool)
1260
+ global_mask[y_pos : y_pos + h, x_pos : x_pos + w] = mask
1261
+
1262
+ all_detections.append(
1263
+ {
1264
+ "mask": global_mask,
1265
+ "score": score,
1266
+ "box": global_box,
1267
+ }
1268
+ )
1269
+
1270
+ # Reset batch
1271
+ batch_inputs = []
1272
+ batch_positions = []
1273
+ batch_count = 0
1274
+
1275
+ # Update progress bar
1276
+ pbar.update(len(outputs))
1277
+
1278
+ # Close progress bar
1279
+ pbar.close()
1280
+
1281
+ print(f"Collected {len(all_detections)} detections before NMS")
1282
+
1283
+ # Apply Non-Maximum Suppression to handle overlapping detections
1284
+ if len(all_detections) > 0:
1285
+ # Convert to tensors for NMS
1286
+ boxes = torch.tensor([det["box"] for det in all_detections])
1287
+ scores = torch.tensor([det["score"] for det in all_detections])
1288
+
1289
+ # Apply NMS with IoU threshold
1290
+ nms_threshold = 0.3 # IoU threshold for NMS
1291
+ keep_indices = torchvision.ops.nms(boxes, scores, nms_threshold)
1292
+
1293
+ # Keep only the selected detections
1294
+ final_detections = [all_detections[i] for i in keep_indices]
1295
+ print(f"After NMS: {len(final_detections)} detections")
1296
+
1297
+ # Create final instance mask
1298
+ instance_mask = np.zeros((height, width), dtype=np.uint16)
1299
+
1300
+ # Sort by score (highest first) for consistent ordering
1301
+ final_detections.sort(key=lambda x: x["score"], reverse=True)
1302
+
1303
+ # Assign unique IDs to each detection
1304
+ for instance_id, detection in enumerate(final_detections, 1):
1305
+ mask = detection["mask"]
1306
+ # Only assign to pixels that are not already assigned
1307
+ available_pixels = (instance_mask == 0) & mask
1308
+ instance_mask[available_pixels] = instance_id
1309
+ else:
1310
+ # No detections found
1311
+ instance_mask = np.zeros((height, width), dtype=np.uint16)
1312
+
1313
+ # Record time
1314
+ inference_time = time.time() - start_time
1315
+ print(f"Instance segmentation completed in {inference_time:.2f} seconds")
1316
+ print(
1317
+ f"Final instances: {len(final_detections) if len(all_detections) > 0 else 0}"
1318
+ )
1319
+
1320
+ # Save output
1321
+ with rasterio.open(output_path, "w", **out_meta) as dst:
1322
+ dst.write(instance_mask, 1)
1323
+
1324
+ print(f"Saved instance segmentation to {output_path}")
1325
+
1326
+ return output_path, inference_time
1327
+
1328
+
1104
1329
  def object_detection(
1105
1330
  input_path,
1106
1331
  output_path,
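The sliding-window inference above deduplicates objects detected in more than one overlapping window by gathering every detection in global pixel coordinates and then running a single non-maximum suppression pass over all of them. A stand-alone sketch of that merge step; merge_window_detections is a hypothetical helper name, and the detection dicts with "box" and "score" keys and the 0.3 IoU threshold mirror what the function builds above:

    import torch
    import torchvision

    def merge_window_detections(detections, iou_threshold=0.3):
        """Keep the highest-scoring detection among overlapping boxes from different windows."""
        if not detections:
            return []
        boxes = torch.tensor([d["box"] for d in detections], dtype=torch.float32)
        scores = torch.tensor([d["score"] for d in detections], dtype=torch.float32)
        keep = torchvision.ops.nms(boxes, scores, iou_threshold)
        # Sort survivors by score so instance IDs are assigned deterministically
        return sorted((detections[int(i)] for i in keep), key=lambda d: d["score"], reverse=True)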
@@ -3067,3 +3292,220 @@ def semantic_segmentation_batch(
3067
3292
  continue
3068
3293
 
3069
3294
  print(f"Batch processing completed. Results saved to {output_dir}")
3295
+
3296
+
3297
+ def train_instance_segmentation_model(
3298
+ images_dir,
3299
+ labels_dir,
3300
+ output_dir,
3301
+ num_classes=2,
3302
+ num_channels=3,
3303
+ batch_size=4,
3304
+ num_epochs=10,
3305
+ learning_rate=0.005,
3306
+ seed=42,
3307
+ val_split=0.2,
3308
+ visualize=False,
3309
+ device=None,
3310
+ verbose=True,
3311
+ **kwargs,
3312
+ ):
3313
+ """
3314
+ Train an instance segmentation model using Mask R-CNN.
3315
+
3316
+ This is a wrapper function for train_MaskRCNN_model with clearer naming.
3317
+
3318
+ Args:
3319
+ images_dir (str): Directory containing image GeoTIFF files.
3320
+ labels_dir (str): Directory containing label GeoTIFF files.
3321
+ output_dir (str): Directory to save model checkpoints and results.
3322
+ num_classes (int): Number of classes (including background). Defaults to 2.
3323
+ num_channels (int): Number of input channels. Defaults to 3.
3324
+ batch_size (int): Batch size for training. Defaults to 4.
3325
+ num_epochs (int): Number of training epochs. Defaults to 10.
3326
+ learning_rate (float): Initial learning rate. Defaults to 0.005.
3327
+ seed (int): Random seed for reproducibility. Defaults to 42.
3328
+ val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
3329
+ visualize (bool): Whether to generate visualizations. Defaults to False.
3330
+ device (torch.device): Device to train on. If None, uses CUDA if available.
3331
+ verbose (bool): If True, prints detailed training progress. Defaults to True.
3332
+ **kwargs: Additional arguments passed to train_MaskRCNN_model.
3333
+
3334
+ Returns:
3335
+ None: Model weights are saved to output_dir.
3336
+ """
3337
+ # Create model with the specified number of classes
3338
+ model = get_instance_segmentation_model(
3339
+ num_classes=num_classes, num_channels=num_channels, pretrained=True
3340
+ )
3341
+
3342
+ return train_MaskRCNN_model(
3343
+ images_dir=images_dir,
3344
+ labels_dir=labels_dir,
3345
+ output_dir=output_dir,
3346
+ num_channels=num_channels,
3347
+ model=model,
3348
+ batch_size=batch_size,
3349
+ num_epochs=num_epochs,
3350
+ learning_rate=learning_rate,
3351
+ seed=seed,
3352
+ val_split=val_split,
3353
+ visualize=visualize,
3354
+ device=device,
3355
+ verbose=verbose,
3356
+ **kwargs,
3357
+ )
3358
+
3359
+
3360
+ def instance_segmentation(
3361
+ input_path,
3362
+ output_path,
3363
+ model_path,
3364
+ window_size=512,
3365
+ overlap=256,
3366
+ confidence_threshold=0.5,
3367
+ batch_size=4,
3368
+ num_channels=3,
3369
+ num_classes=2,
3370
+ device=None,
3371
+ **kwargs,
3372
+ ):
3373
+ """
3374
+ Perform instance segmentation on a GeoTIFF using a pre-trained Mask R-CNN model.
3375
+
3376
+ This is a wrapper function for object_detection with clearer naming.
3377
+
3378
+ Args:
3379
+ input_path (str): Path to input GeoTIFF file.
3380
+ output_path (str): Path to save output mask GeoTIFF.
3381
+ model_path (str): Path to trained model weights.
3382
+ window_size (int): Size of sliding window for inference. Defaults to 512.
3383
+ overlap (int): Overlap between adjacent windows. Defaults to 256.
3384
+ confidence_threshold (float): Confidence threshold for predictions (0-1). Defaults to 0.5.
3385
+ batch_size (int): Batch size for inference. Defaults to 4.
3386
+ num_channels (int): Number of channels in the input image and model. Defaults to 3.
3387
+ num_classes (int): Number of classes (including background). Defaults to 2.
3388
+ device (torch.device): Device to run inference on. If None, uses CUDA if available.
3389
+ **kwargs: Additional arguments passed to object_detection.
3390
+
3391
+ Returns:
3392
+ None: Output mask is saved to output_path.
3393
+ """
3394
+ # Create model with the specified number of classes
3395
+ model = get_instance_segmentation_model(
3396
+ num_classes=num_classes, num_channels=num_channels, pretrained=True
3397
+ )
3398
+
3399
+ # Load the trained model
3400
+ if device is None:
3401
+ device = get_device()
3402
+
3403
+ model.load_state_dict(torch.load(model_path, map_location=device))
3404
+ model.to(device)
3405
+
3406
+ # Use the proper instance segmentation inference function
3407
+ return instance_segmentation_inference_on_geotiff(
3408
+ model=model,
3409
+ geotiff_path=input_path,
3410
+ output_path=output_path,
3411
+ window_size=window_size,
3412
+ overlap=overlap,
3413
+ confidence_threshold=confidence_threshold,
3414
+ batch_size=batch_size,
3415
+ num_channels=num_channels,
3416
+ device=device,
3417
+ **kwargs,
3418
+ )
3419
+
3420
+
3421
+ def instance_segmentation_batch(
3422
+ input_dir,
3423
+ output_dir,
3424
+ model_path,
3425
+ window_size=512,
3426
+ overlap=256,
3427
+ confidence_threshold=0.5,
3428
+ batch_size=4,
3429
+ num_channels=3,
3430
+ num_classes=2,
3431
+ device=None,
3432
+ **kwargs,
3433
+ ):
3434
+ """
3435
+ Perform instance segmentation on multiple GeoTIFF files using a pre-trained Mask R-CNN model.
3436
+
3437
+ This is a wrapper function for object_detection_batch with clearer naming.
3438
+
3439
+ Args:
3440
+ input_dir (str): Directory containing input GeoTIFF files.
3441
+ output_dir (str): Directory to save output mask GeoTIFF files.
3442
+ model_path (str): Path to trained model weights.
3443
+ window_size (int): Size of sliding window for inference. Defaults to 512.
3444
+ overlap (int): Overlap between adjacent windows. Defaults to 256.
3445
+ confidence_threshold (float): Confidence threshold for predictions (0-1). Defaults to 0.5.
3446
+ batch_size (int): Batch size for inference. Defaults to 4.
3447
+ num_channels (int): Number of channels in the input image and model. Defaults to 3.
3448
+ num_classes (int): Number of classes (including background). Defaults to 2.
3449
+ device (torch.device): Device to run inference on. If None, uses CUDA if available.
3450
+ **kwargs: Additional arguments passed to object_detection_batch.
3451
+
3452
+ Returns:
3453
+ None: Output masks are saved to output_dir.
3454
+ """
3455
+ # Create model with the specified number of classes
3456
+ model = get_instance_segmentation_model(
3457
+ num_classes=num_classes, num_channels=num_channels, pretrained=True
3458
+ )
3459
+
3460
+ # Load the trained model
3461
+ if device is None:
3462
+ device = get_device()
3463
+
3464
+ model.load_state_dict(torch.load(model_path, map_location=device))
3465
+ model.to(device)
3466
+
3467
+ # Process all GeoTIFF files in the input directory
3468
+ import glob
3469
+
3470
+ input_files = glob.glob(os.path.join(input_dir, "*.tif")) + glob.glob(
3471
+ os.path.join(input_dir, "*.tiff")
3472
+ )
3473
+
3474
+ if not input_files:
3475
+ print(f"No GeoTIFF files found in {input_dir}")
3476
+ return
3477
+
3478
+ # Create output directory if it doesn't exist
3479
+ os.makedirs(output_dir, exist_ok=True)
3480
+
3481
+ print(f"Processing {len(input_files)} files...")
3482
+
3483
+ for input_file in input_files:
3484
+ try:
3485
+ # Generate output filename
3486
+ base_name = os.path.splitext(os.path.basename(input_file))[0]
3487
+ output_file = os.path.join(output_dir, f"{base_name}_instances.tif")
3488
+
3489
+ print(f"Processing {input_file}...")
3490
+
3491
+ # Run instance segmentation inference
3492
+ instance_segmentation_inference_on_geotiff(
3493
+ model=model,
3494
+ geotiff_path=input_file,
3495
+ output_path=output_file,
3496
+ window_size=window_size,
3497
+ overlap=overlap,
3498
+ confidence_threshold=confidence_threshold,
3499
+ batch_size=batch_size,
3500
+ num_channels=num_channels,
3501
+ device=device,
3502
+ **kwargs,
3503
+ )
3504
+
3505
+ print(f"Saved result to {output_file}")
3506
+
3507
+ except Exception as e:
3508
+ print(f"Error processing {input_file}: {str(e)}")
3509
+ continue
3510
+
3511
+ print(f"Batch processing completed. Results saved to {output_dir}")
geoai/utils.py CHANGED
@@ -157,7 +157,8 @@ def view_image(
157
157
  image: Union[np.ndarray, torch.Tensor],
158
158
  transpose: bool = False,
159
159
  bdx: Optional[int] = None,
160
- scale_factor: float = 1.0,
160
+ clip_percentiles: Optional[Tuple[float, float]] = (2, 98),
161
+ gamma: Optional[float] = None,
161
162
  figsize: Tuple[int, int] = (10, 5),
162
163
  axis_off: bool = True,
163
164
  title: Optional[str] = None,
@@ -185,7 +186,7 @@ def view_image(
185
186
  elif isinstance(image, str):
186
187
  image = rasterio.open(image).read().transpose(1, 2, 0)
187
188
 
188
- plt.figure(figsize=figsize)
189
+ ax = plt.figure(figsize=figsize)
189
190
 
190
191
  if transpose:
191
192
  image = image.transpose(1, 2, 0)
@@ -196,8 +197,14 @@ def view_image(
196
197
  if len(image.shape) > 2 and image.shape[2] > 3:
197
198
  image = image[:, :, 0:3]
198
199
 
199
- if scale_factor != 1.0:
200
- image = np.clip(image * scale_factor, 0, 1)
200
+ if clip_percentiles is not None:
201
+ p_low, p_high = clip_percentiles
202
+ lower = np.percentile(image, p_low)
203
+ upper = np.percentile(image, p_high)
204
+ image = np.clip((image - lower) / (upper - lower), 0, 1)
205
+
206
+ if gamma is not None:
207
+ image = np.power(image, gamma)
201
208
 
202
209
  plt.imshow(image, **kwargs)
203
210
  if axis_off:
@@ -207,6 +214,8 @@ def view_image(
207
214
  plt.show()
208
215
  plt.close()
209
216
 
217
+ return ax
218
+
210
219
 
211
220
  def plot_images(
212
221
  images: Iterable[torch.Tensor],
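view_image now stretches the display range with percentile clipping (and an optional gamma) rather than a fixed scale_factor. A short stand-alone sketch of the same stretch; percentile_stretch is a hypothetical helper, assuming a NumPy array of raw pixel values:

    import numpy as np

    def percentile_stretch(image, clip_percentiles=(2, 98), gamma=None):
        """Map the given percentiles to 0 and 1, clip, and optionally apply gamma."""
        p_low, p_high = clip_percentiles
        lower, upper = np.percentile(image, [p_low, p_high])
        stretched = np.clip((image - lower) / (upper - lower), 0, 1)
        if gamma is not None:
            stretched = np.power(stretched, gamma)
        return stretched

Calling view_image(image, clip_percentiles=(2, 98), gamma=0.8) applies the same transformation before the image is displayed.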
@@ -5658,7 +5667,7 @@ def orthogonalize(
5658
5667
  if len(df) == 0:
5659
5668
  return ring
5660
5669
 
5661
- # If we have a triangle-like result (3 segments), return the original shape
5670
+ # If we have a triangle-like result (3 segments or fewer), return the original shape
5662
5671
  if len(df) <= 3:
5663
5672
  return ring
5664
5673
 
@@ -5669,8 +5678,116 @@ def orthogonalize(
5669
5678
  if len(joined_ring) == 0 or len(joined_ring[0]) < 3:
5670
5679
  return ring
5671
5680
 
5672
- # Basic validation: if result has 3 or fewer points (triangle), use original
5673
- if len(joined_ring[0]) <= 3:
5681
+ # Enhanced validation: check for triangular result and geometric validity
5682
+ result_coords = joined_ring[0]
5683
+
5684
+ # If result has 3 or fewer points (triangle), use original
5685
+ if len(result_coords) <= 3: # 2 points + closing point (degenerate)
5686
+ return ring
5687
+
5688
+ # Additional validation: check for degenerate geometry
5689
+ # Calculate area ratio to detect if the shape got severely distorted
5690
+ def calculate_polygon_area(coords):
5691
+ if len(coords) < 3:
5692
+ return 0
5693
+ area = 0
5694
+ n = len(coords)
5695
+ for i in range(n):
5696
+ j = (i + 1) % n
5697
+ area += coords[i][0] * coords[j][1]
5698
+ area -= coords[j][0] * coords[i][1]
5699
+ return abs(area) / 2
5700
+
5701
+ original_area = calculate_polygon_area(ring)
5702
+ result_area = calculate_polygon_area(result_coords)
5703
+
5704
+ # If the area changed dramatically (shrank below 30% of the original or grew beyond 300%), use original
5705
+ if original_area > 0 and result_area > 0:
5706
+ area_ratio = result_area / original_area
5707
+ if area_ratio < 0.3 or area_ratio > 3.0:
5708
+ return ring
5709
+
5710
+ # Check for triangular spikes and problematic artifacts
5711
+ very_acute_angle_count = 0
5712
+ triangular_spike_detected = False
5713
+
5714
+ for i in range(len(result_coords) - 1): # -1 to exclude closing point
5715
+ p1 = result_coords[i - 1]
5716
+ p2 = result_coords[i]
5717
+ p3 = result_coords[(i + 1) % (len(result_coords) - 1)]
5718
+
5719
+ # Calculate angle at p2
5720
+ v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]])
5721
+ v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]])
5722
+
5723
+ v1_norm = np.linalg.norm(v1)
5724
+ v2_norm = np.linalg.norm(v2)
5725
+
5726
+ if v1_norm > 0 and v2_norm > 0:
5727
+ cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
5728
+ cos_angle = np.clip(cos_angle, -1, 1)
5729
+ angle = np.arccos(cos_angle)
5730
+
5731
+ # Count very acute angles (< 20 degrees) - these are likely spikes
5732
+ if angle < np.pi / 9: # 20 degrees
5733
+ very_acute_angle_count += 1
5734
+ # If it's very acute with short sides, it's definitely a spike
5735
+ if v1_norm < 5 or v2_norm < 5:
5736
+ triangular_spike_detected = True
5737
+
5738
+ # Check for excessively long edges that might be artifacts
5739
+ edge_lengths = []
5740
+ for i in range(len(result_coords) - 1):
5741
+ edge_len = np.sqrt(
5742
+ (result_coords[i + 1][0] - result_coords[i][0]) ** 2
5743
+ + (result_coords[i + 1][1] - result_coords[i][1]) ** 2
5744
+ )
5745
+ edge_lengths.append(edge_len)
5746
+
5747
+ excessive_edge_detected = False
5748
+ if len(edge_lengths) > 0:
5749
+ avg_edge_length = np.mean(edge_lengths)
5750
+ max_edge_length = np.max(edge_lengths)
5751
+ # Only reject if edge is extremely disproportionate (8x average)
5752
+ if max_edge_length > avg_edge_length * 8:
5753
+ excessive_edge_detected = True
5754
+
5755
+ # Check for triangular artifacts by detecting spikes that extend beyond bounds
5756
+ # Calculate original bounds
5757
+ orig_xs = [p[0] for p in ring]
5758
+ orig_ys = [p[1] for p in ring]
5759
+ orig_min_x, orig_max_x = min(orig_xs), max(orig_xs)
5760
+ orig_min_y, orig_max_y = min(orig_ys), max(orig_ys)
5761
+ orig_width = orig_max_x - orig_min_x
5762
+ orig_height = orig_max_y - orig_min_y
5763
+
5764
+ # Calculate result bounds
5765
+ result_xs = [p[0] for p in result_coords]
5766
+ result_ys = [p[1] for p in result_coords]
5767
+ result_min_x, result_max_x = min(result_xs), max(result_xs)
5768
+ result_min_y, result_max_y = min(result_ys), max(result_ys)
5769
+
5770
+ # Stricter bounds checking to catch triangular artifacts
5771
+ bounds_extension_detected = False
5772
+ # More conservative: only allow 10% extension
5773
+ tolerance_x = max(orig_width * 0.1, 1.0) # 10% tolerance, at least 1 unit
5774
+ tolerance_y = max(orig_height * 0.1, 1.0) # 10% tolerance, at least 1 unit
5775
+
5776
+ if (
5777
+ result_min_x < orig_min_x - tolerance_x
5778
+ or result_max_x > orig_max_x + tolerance_x
5779
+ or result_min_y < orig_min_y - tolerance_y
5780
+ or result_max_y > orig_max_y + tolerance_y
5781
+ ):
5782
+ bounds_extension_detected = True
5783
+
5784
+ # Reject if we detect triangular spikes, excessive edges, or bounds violations
5785
+ if (
5786
+ triangular_spike_detected
5787
+ or very_acute_angle_count > 2 # Multiple very acute angles
5788
+ or excessive_edge_detected
5789
+ or bounds_extension_detected
5790
+ ): # Any significant bounds extension
5674
5791
  return ring
5675
5792
 
5676
5793
  # Convert back to a list and ensure it's closed
@@ -5887,37 +6004,86 @@ def orthogonalize(
5887
6004
  }
5888
6005
  )
5889
6006
 
6007
+ # Improved fix: Prevent merging that would create triangular or problematic shapes
5890
6008
  if (
5891
- len(ortho_list) > 0 and ortho_list[0]["angle"] == ortho_list[-1]["angle"]
6009
+ len(ortho_list) > 3 and ortho_list[0]["angle"] == ortho_list[-1]["angle"]
5892
6010
  ): # join first and last segment if they're in same direction
5893
- totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
5894
- merge_cx = (
5895
- (ortho_list[0]["cx"] * ortho_list[0]["len"])
5896
- + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
5897
- ) / totlen
5898
-
5899
- merge_cy = (
5900
- (ortho_list[0]["cy"] * ortho_list[0]["len"])
5901
- + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
5902
- ) / totlen
5903
-
5904
- rot_angle = ortho_list[0]["angle"]
5905
- X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
5906
- X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
5907
- Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
5908
- Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
5909
-
5910
- ortho_list[-1] = {
5911
- "x1": X1,
5912
- "y1": Y1,
5913
- "x2": X2,
5914
- "y2": Y2,
5915
- "len": totlen,
5916
- "cx": merge_cx,
5917
- "cy": merge_cy,
5918
- "angle": rot_angle,
5919
- }
5920
- ortho_list = ortho_list[1:]
6011
+ # Check if merging would result in 3 or 4 segments (potentially triangular)
6012
+ resulting_segments = len(ortho_list) - 1
6013
+ if resulting_segments <= 4:
6014
+ # For very small polygons, be extra cautious about merging
6015
+ # Calculate the spatial relationship between first and last segments
6016
+ first_center = np.array([ortho_list[0]["cx"], ortho_list[0]["cy"]])
6017
+ last_center = np.array([ortho_list[-1]["cx"], ortho_list[-1]["cy"]])
6018
+ center_distance = np.linalg.norm(first_center - last_center)
6019
+
6020
+ # Get average segment length for comparison
6021
+ avg_length = sum(seg["len"] for seg in ortho_list) / len(ortho_list)
6022
+
6023
+ # Only merge if segments are close enough and it won't create degenerate shapes
6024
+ if center_distance > avg_length * 1.5:
6025
+ # Skip merging - segments are too far apart
6026
+ pass
6027
+ else:
6028
+ # Proceed with merging only for well-connected segments
6029
+ totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
6030
+ merge_cx = (
6031
+ (ortho_list[0]["cx"] * ortho_list[0]["len"])
6032
+ + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
6033
+ ) / totlen
6034
+
6035
+ merge_cy = (
6036
+ (ortho_list[0]["cy"] * ortho_list[0]["len"])
6037
+ + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
6038
+ ) / totlen
6039
+
6040
+ rot_angle = ortho_list[0]["angle"]
6041
+ X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
6042
+ X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
6043
+ Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
6044
+ Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
6045
+
6046
+ ortho_list[-1] = {
6047
+ "x1": X1,
6048
+ "y1": Y1,
6049
+ "x2": X2,
6050
+ "y2": Y2,
6051
+ "len": totlen,
6052
+ "cx": merge_cx,
6053
+ "cy": merge_cy,
6054
+ "angle": rot_angle,
6055
+ }
6056
+ ortho_list = ortho_list[1:]
6057
+ else:
6058
+ # For larger polygons, proceed with standard merging
6059
+ totlen = ortho_list[0]["len"] + ortho_list[-1]["len"]
6060
+ merge_cx = (
6061
+ (ortho_list[0]["cx"] * ortho_list[0]["len"])
6062
+ + (ortho_list[-1]["cx"] * ortho_list[-1]["len"])
6063
+ ) / totlen
6064
+
6065
+ merge_cy = (
6066
+ (ortho_list[0]["cy"] * ortho_list[0]["len"])
6067
+ + (ortho_list[-1]["cy"] * ortho_list[-1]["len"])
6068
+ ) / totlen
6069
+
6070
+ rot_angle = ortho_list[0]["angle"]
6071
+ X1 = merge_cx - (totlen / 2) * math.cos(rot_angle)
6072
+ X2 = merge_cx + (totlen / 2) * math.cos(rot_angle)
6073
+ Y1 = merge_cy - (totlen / 2) * math.sin(rot_angle)
6074
+ Y2 = merge_cy + (totlen / 2) * math.sin(rot_angle)
6075
+
6076
+ ortho_list[-1] = {
6077
+ "x1": X1,
6078
+ "y1": Y1,
6079
+ "x2": X2,
6080
+ "y2": Y2,
6081
+ "len": totlen,
6082
+ "cx": merge_cx,
6083
+ "cy": merge_cy,
6084
+ "angle": rot_angle,
6085
+ }
6086
+ ortho_list = ortho_list[1:]
5921
6087
  ortho_df = pd.DataFrame(ortho_list)
5922
6088
  return ortho_df
5923
6089
 
@@ -6026,12 +6192,49 @@ def orthogonalize(
6026
6192
  np.sqrt((x4 - x3) ** 2 + (y4 - y3) ** 2),
6027
6193
  )
6028
6194
 
6029
- # If intersection is too far away, use the endpoint of the first segment instead
6030
- if dist_to_seg1 > max_len * 0.5 or dist_to_seg2 > max_len * 0.5:
6031
- ring.append([x2, y2])
6195
+ # Improved intersection validation
6196
+ # Calculate angle between segments to detect sharp corners
6197
+ v1 = np.array([x2 - x1, y2 - y1])
6198
+ v2 = np.array([x4 - x3, y4 - y3])
6199
+ v1_norm = np.linalg.norm(v1)
6200
+ v2_norm = np.linalg.norm(v2)
6201
+
6202
+ if v1_norm > 0 and v2_norm > 0:
6203
+ cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
6204
+ cos_angle = np.clip(cos_angle, -1, 1)
6205
+ angle = np.arccos(cos_angle)
6206
+
6207
+ # Check for very sharp angles that could create triangular artifacts
6208
+ is_sharp_angle = (
6209
+ angle < np.pi / 6 or angle > 5 * np.pi / 6
6210
+ ) # <30° or >150°
6211
+ else:
6212
+ is_sharp_angle = False
6213
+
6214
+ # Determine whether to use intersection or segment endpoint
6215
+ if (
6216
+ dist_to_seg1 > max_len * 0.5
6217
+ or dist_to_seg2 > max_len * 0.5
6218
+ or is_sharp_angle
6219
+ ):
6220
+ # Use a more conservative approach for problematic intersections
6221
+ # Use the closer endpoint between segments
6222
+ dist_x2_to_seg2 = min(
6223
+ np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2),
6224
+ np.sqrt((x2 - x4) ** 2 + (y2 - y4) ** 2),
6225
+ )
6226
+ dist_x3_to_seg1 = min(
6227
+ np.sqrt((x3 - x1) ** 2 + (y3 - y1) ** 2),
6228
+ np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2),
6229
+ )
6230
+
6231
+ if dist_x2_to_seg2 <= dist_x3_to_seg1:
6232
+ ring.append([x2, y2])
6233
+ else:
6234
+ ring.append([x3, y3])
6032
6235
  else:
6033
6236
  ring.append(intersection)
6034
- except Exception as e:
6237
+ except Exception:
6035
6238
  # If intersection calculation fails, use the endpoint of the first segment
6036
6239
  ring.append([x2, y2])
6037
6240
 
@@ -6057,11 +6260,42 @@ def orthogonalize(
6057
6260
  np.sqrt((x4 - x3) ** 2 + (y4 - y3) ** 2),
6058
6261
  )
6059
6262
 
6060
- if dist_to_seg1 > max_len * 0.5 or dist_to_seg2 > max_len * 0.5:
6061
- ring.append([x2, y2])
6263
+ # Apply same sharp angle detection for closing segment
6264
+ v1 = np.array([x2 - x1, y2 - y1])
6265
+ v2 = np.array([x4 - x3, y4 - y3])
6266
+ v1_norm = np.linalg.norm(v1)
6267
+ v2_norm = np.linalg.norm(v2)
6268
+
6269
+ if v1_norm > 0 and v2_norm > 0:
6270
+ cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
6271
+ cos_angle = np.clip(cos_angle, -1, 1)
6272
+ angle = np.arccos(cos_angle)
6273
+ is_sharp_angle = angle < np.pi / 6 or angle > 5 * np.pi / 6
6274
+ else:
6275
+ is_sharp_angle = False
6276
+
6277
+ if (
6278
+ dist_to_seg1 > max_len * 0.5
6279
+ or dist_to_seg2 > max_len * 0.5
6280
+ or is_sharp_angle
6281
+ ):
6282
+ # Use conservative approach for closing segment
6283
+ dist_x2_to_seg2 = min(
6284
+ np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2),
6285
+ np.sqrt((x2 - x4) ** 2 + (y2 - y4) ** 2),
6286
+ )
6287
+ dist_x3_to_seg1 = min(
6288
+ np.sqrt((x3 - x1) ** 2 + (y3 - y1) ** 2),
6289
+ np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2),
6290
+ )
6291
+
6292
+ if dist_x2_to_seg2 <= dist_x3_to_seg1:
6293
+ ring.append([x2, y2])
6294
+ else:
6295
+ ring.append([x3, y3])
6062
6296
  else:
6063
6297
  ring.append(intersection)
6064
- except Exception as e:
6298
+ except Exception:
6065
6299
  # If intersection calculation fails, use the endpoint of the last segment
6066
6300
  ring.append([x2, y2])
6067
6301
 
@@ -7219,41 +7453,61 @@ def plot_performance_metrics(history_path, figsize=(15, 5), verbose=True):
7219
7453
  """
7220
7454
  history = torch.load(history_path)
7221
7455
 
7456
+ # Handle different key naming conventions
7457
+ train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
7458
+ val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
7459
+ val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
7460
+ val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
7461
+
7462
+ # Determine number of subplots based on available metrics
7463
+ has_dice = val_dice_key in history
7464
+ n_plots = 3 if has_dice else 2
7465
+ figsize = (15, 5) if has_dice else (10, 5)
7466
+
7222
7467
  plt.figure(figsize=figsize)
7223
7468
 
7224
- plt.subplot(1, 3, 1)
7225
- plt.plot(history["train_losses"], label="Train Loss")
7226
- plt.plot(history["val_losses"], label="Val Loss")
7469
+ # Plot loss
7470
+ plt.subplot(1, n_plots, 1)
7471
+ if train_loss_key in history:
7472
+ plt.plot(history[train_loss_key], label="Train Loss")
7473
+ if val_loss_key in history:
7474
+ plt.plot(history[val_loss_key], label="Val Loss")
7227
7475
  plt.title("Loss")
7228
7476
  plt.xlabel("Epoch")
7229
7477
  plt.ylabel("Loss")
7230
7478
  plt.legend()
7231
7479
  plt.grid(True)
7232
7480
 
7233
- plt.subplot(1, 3, 2)
7234
- plt.plot(history["val_ious"], label="Val IoU")
7481
+ # Plot IoU
7482
+ plt.subplot(1, n_plots, 2)
7483
+ if val_iou_key in history:
7484
+ plt.plot(history[val_iou_key], label="Val IoU")
7235
7485
  plt.title("IoU Score")
7236
7486
  plt.xlabel("Epoch")
7237
7487
  plt.ylabel("IoU")
7238
7488
  plt.legend()
7239
7489
  plt.grid(True)
7240
7490
 
7241
- plt.subplot(1, 3, 3)
7242
- plt.plot(history["val_dices"], label="Val Dice")
7243
- plt.title("Dice Score")
7244
- plt.xlabel("Epoch")
7245
- plt.ylabel("Dice")
7246
- plt.legend()
7247
- plt.grid(True)
7491
+ # Plot Dice if available
7492
+ if has_dice:
7493
+ plt.subplot(1, n_plots, 3)
7494
+ plt.plot(history[val_dice_key], label="Val Dice")
7495
+ plt.title("Dice Score")
7496
+ plt.xlabel("Epoch")
7497
+ plt.ylabel("Dice")
7498
+ plt.legend()
7499
+ plt.grid(True)
7248
7500
 
7249
7501
  plt.tight_layout()
7250
7502
  plt.show()
7251
7503
 
7252
7504
  if verbose:
7253
- print(f"Best IoU: {max(history['val_ious']):.4f}")
7254
- print(f"Best Dice: {max(history['val_dices']):.4f}")
7255
- print(f"Final IoU: {history['val_ious'][-1]:.4f}")
7256
- print(f"Final Dice: {history['val_dices'][-1]:.4f}")
7505
+ if val_iou_key in history:
7506
+ print(f"Best IoU: {max(history[val_iou_key]):.4f}")
7507
+ print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
7508
+ if val_dice_key in history:
7509
+ print(f"Best Dice: {max(history[val_dice_key]):.4f}")
7510
+ print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
7257
7511
 
7258
7512
 
7259
7513
  def get_device():
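plot_performance_metrics now accepts either key convention, so it can read both the segmentation trainer's plural keys (train_losses, val_ious, val_dices) and the training_history.pth written by train_MaskRCNN_model, which uses the singular forms and has no Dice entry. A minimal usage sketch; the path is a placeholder:

    from geoai.utils import plot_performance_metrics

    # Only the loss and IoU panels are drawn when the history has no Dice metric
    plot_performance_metrics("models/buildings/training_history.pth", verbose=True)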
@@ -7280,105 +7534,123 @@ def plot_prediction_comparison(
7280
7534
  prediction_colormap: str = "gray",
7281
7535
  ground_truth_colormap: str = "gray",
7282
7536
  original_colormap: Optional[str] = None,
7537
+ indexes: Optional[List[int]] = None,
7538
+ divider: Optional[float] = None,
7283
7539
  ):
7284
- """
7285
- Plot original image, prediction image, and optionally ground truth image side by side.
7540
+ """Plot original image, prediction, and optional ground truth side by side.
7541
+
7542
+ Supports input as file paths, NumPy arrays, or PIL Images. For multi-band
7543
+ images, selected channels can be specified via `indexes`. If the image data
7544
+ is not normalized (e.g., Sentinel-2 [0, 10000]), the `divider` can be used
7545
+ to scale values for visualization.
7286
7546
 
7287
7547
  Args:
7288
- original_image: Original input image (file path, numpy array, or PIL Image)
7289
- prediction_image: Prediction/segmentation mask (file path, numpy array, or PIL Image)
7290
- ground_truth_image: Optional ground truth mask (file path, numpy array, or PIL Image)
7291
- titles: Optional list of titles for each subplot
7292
- figsize: Figure size tuple (width, height)
7293
- save_path: Optional path to save the plot
7294
- show_plot: Whether to display the plot
7295
- prediction_colormap: Colormap for prediction image
7296
- ground_truth_colormap: Colormap for ground truth image
7297
- original_colormap: Colormap for original image (None for RGB)
7548
+ original_image (Union[str, np.ndarray, Image.Image]):
7549
+ Original input image as a file path, NumPy array, or PIL Image.
7550
+ prediction_image (Union[str, np.ndarray, Image.Image]):
7551
+ Predicted segmentation mask image.
7552
+ ground_truth_image (Optional[Union[str, np.ndarray, Image.Image]], optional):
7553
+ Ground truth mask image. Defaults to None.
7554
+ titles (Optional[List[str]], optional):
7555
+ List of titles for the subplots. If not provided, default titles are used.
7556
+ figsize (Tuple[int, int], optional):
7557
+ Size of the entire figure in inches. Defaults to (15, 5).
7558
+ save_path (Optional[str], optional):
7559
+ If specified, saves the figure to this path. Defaults to None.
7560
+ show_plot (bool, optional):
7561
+ Whether to display the figure using plt.show(). Defaults to True.
7562
+ prediction_colormap (str, optional):
7563
+ Colormap to use for the prediction mask. Defaults to "gray".
7564
+ ground_truth_colormap (str, optional):
7565
+ Colormap to use for the ground truth mask. Defaults to "gray".
7566
+ original_colormap (Optional[str], optional):
7567
+ Colormap to use for the original image if it's grayscale. Defaults to None.
7568
+ indexes (Optional[List[int]], optional):
7569
+ List of band/channel indexes (0-based for NumPy, 1-based for rasterio) to extract from the original image.
7570
+ Useful for multi-band imagery like Sentinel-2. Defaults to None.
7571
+ divider (Optional[float], optional):
7572
+ Value to divide the original image by for normalization (e.g., 10000 for reflectance). Defaults to None.
7298
7573
 
7299
7574
  Returns:
7300
- matplotlib.figure.Figure: The figure object
7575
+ matplotlib.figure.Figure:
7576
+ The generated matplotlib figure object.
7301
7577
  """
7302
7578
 
7303
- def _load_image(img_input):
7579
+ def _load_image(img_input, indexes=None):
7304
7580
  """Helper function to load image from various input types."""
7305
7581
  if isinstance(img_input, str):
7306
- # File path
7307
7582
  if img_input.lower().endswith((".tif", ".tiff")):
7308
- # Handle GeoTIFF files
7309
7583
  with rasterio.open(img_input) as src:
7310
- img = src.read()
7311
- if img.shape[0] == 1:
7312
- # Single band
7313
- img = img[0]
7584
+ if indexes:
7585
+ img = src.read(indexes) # 1-based
7586
+ img = (
7587
+ np.transpose(img, (1, 2, 0)) if len(indexes) > 1 else img[0]
7588
+ )
7314
7589
  else:
7315
- # Multi-band, transpose to (H, W, C)
7316
- img = np.transpose(img, (1, 2, 0))
7590
+ img = src.read()
7591
+ if img.shape[0] == 1:
7592
+ img = img[0]
7593
+ else:
7594
+ img = np.transpose(img, (1, 2, 0))
7317
7595
  else:
7318
- # Regular image file
7319
7596
  img = np.array(Image.open(img_input))
7320
7597
  elif isinstance(img_input, Image.Image):
7321
- # PIL Image
7322
7598
  img = np.array(img_input)
7323
7599
  elif isinstance(img_input, np.ndarray):
7324
- # NumPy array
7325
7600
  img = img_input
7601
+ if indexes is not None and img.ndim == 3:
7602
+ img = img[:, :, indexes]
7326
7603
  else:
7327
7604
  raise ValueError(f"Unsupported image type: {type(img_input)}")
7328
-
7329
7605
  return img
7330
7606
 
7331
7607
  # Load images
7332
- original = _load_image(original_image)
7608
+ original = _load_image(original_image, indexes=indexes)
7333
7609
  prediction = _load_image(prediction_image)
7334
7610
  ground_truth = (
7335
7611
  _load_image(ground_truth_image) if ground_truth_image is not None else None
7336
7612
  )
7337
7613
 
7338
- # Determine number of subplots
7339
- num_plots = 3 if ground_truth is not None else 2
7614
+ # Apply divider normalization if requested
7615
+ if divider is not None and isinstance(original, np.ndarray) and original.ndim == 3:
7616
+ original = np.clip(original.astype(np.float32) / divider, 0, 1)
7340
7617
 
7341
- # Create figure and subplots
7618
+ # Determine layout
7619
+ num_plots = 3 if ground_truth is not None else 2
7342
7620
  fig, axes = plt.subplots(1, num_plots, figsize=figsize)
7343
7621
  if num_plots == 2:
7344
7622
  axes = [axes[0], axes[1]]
7345
7623
 
7346
- # Default titles
7347
7624
  if titles is None:
7348
7625
  titles = ["Original Image", "Prediction"]
7349
7626
  if ground_truth is not None:
7350
7627
  titles.append("Ground Truth")
7351
7628
 
7352
- # Plot original image
7353
- if len(original.shape) == 3 and original.shape[2] in [3, 4]:
7354
- # RGB or RGBA image
7629
+ # Plot original
7630
+ if original.ndim == 3 and original.shape[2] in [3, 4]:
7355
7631
  axes[0].imshow(original)
7356
7632
  else:
7357
- # Grayscale or single channel
7358
7633
  axes[0].imshow(original, cmap=original_colormap)
7359
7634
  axes[0].set_title(titles[0])
7360
7635
  axes[0].axis("off")
7361
7636
 
7362
- # Plot prediction image
7637
+ # Prediction
7363
7638
  axes[1].imshow(prediction, cmap=prediction_colormap)
7364
7639
  axes[1].set_title(titles[1])
7365
7640
  axes[1].axis("off")
7366
7641
 
7367
- # Plot ground truth if provided
7642
+ # Ground truth
7368
7643
  if ground_truth is not None:
7369
7644
  axes[2].imshow(ground_truth, cmap=ground_truth_colormap)
7370
7645
  axes[2].set_title(titles[2])
7371
7646
  axes[2].axis("off")
7372
7647
 
7373
- # Adjust layout
7374
7648
  plt.tight_layout()
7375
7649
 
7376
- # Save if requested
7377
7650
  if save_path:
7378
7651
  plt.savefig(save_path, dpi=300, bbox_inches="tight")
7379
7652
  print(f"Plot saved to: {save_path}")
7380
7653
 
7381
- # Show plot
7382
7654
  if show_plot:
7383
7655
  plt.show()
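The new indexes and divider arguments make plot_prediction_comparison usable with multi-band, unscaled imagery such as Sentinel-2 surface reflectance. A minimal usage sketch, assuming hypothetical file paths and a 0-10000 reflectance scale:

    from geoai.utils import plot_prediction_comparison

    plot_prediction_comparison(
        original_image="s2_scene.tif",        # multi-band GeoTIFF (placeholder path)
        prediction_image="s2_prediction.tif",
        ground_truth_image="s2_labels.tif",
        indexes=[4, 3, 2],                    # 1-based rasterio bands, e.g. an RGB composite
        divider=10000,                        # rescale reflectance to [0, 1] for display
    )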
7384
7656
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: geoai-py
3
- Version: 0.8.1
3
+ Version: 0.8.3
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -0,0 +1,17 @@
1
+ geoai/__init__.py,sha256=Ks7SlSE7TkToYHB3EhXxIQLsYXuzfwHjyEUKJjkd8SA,3765
2
+ geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
3
+ geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
4
+ geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
5
+ geoai/geoai.py,sha256=_Ar7PJgWpN86tm1YhzLVj7r1lYA3faRJT5brf1WzcwQ,9631
6
+ geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
7
+ geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
8
+ geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
9
+ geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
10
+ geoai/train.py,sha256=0uIz3i3sH-Knyq9wBrKLy22f--oLk2m-fCdhfwJosNU,129112
11
+ geoai/utils.py,sha256=jNSXXQ056iyeDHDnkPUJ0_TLe-NonGbXh2YCN-gM50c,297270
12
+ geoai_py-0.8.3.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
+ geoai_py-0.8.3.dist-info/METADATA,sha256=ifNb98sm8g-Bsicc_R_9ljKahRXjibbEbSxxG8ga2mk,6661
14
+ geoai_py-0.8.3.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
+ geoai_py-0.8.3.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
16
+ geoai_py-0.8.3.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
17
+ geoai_py-0.8.3.dist-info/RECORD,,
@@ -1,17 +0,0 @@
1
- geoai/__init__.py,sha256=NRaQhG5-sDLmpYXZXhbJ_LPkCEQbT-71qDAspoNEyxk,3765
2
- geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
3
- geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
4
- geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
5
- geoai/geoai.py,sha256=HU6epCjpk228J65ZXmxY6GKlHg_ncmWV3UQUr_f8QTM,9447
6
- geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
7
- geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
8
- geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
9
- geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
10
- geoai/train.py,sha256=Mrsb0yMVnprQHld3zDvA-puc-r8hGm1XgG0j2GGIn7E,112845
11
- geoai/utils.py,sha256=WQ_h39LpU65i-LP6GfFJLZBpiY0fyPdFc7dZcuyZ24s,284611
12
- geoai_py-0.8.1.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
- geoai_py-0.8.1.dist-info/METADATA,sha256=zOOCsS3BTaqIjl3DPsbxDDba4kK7kDKgNk2ytC_sYJ8,6661
14
- geoai_py-0.8.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
- geoai_py-0.8.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
16
- geoai_py-0.8.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
17
- geoai_py-0.8.1.dist-info/RECORD,,