geoai-py 0.8.0__py2.py3-none-any.whl → 0.8.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.8.0"
5
+ __version__ = "0.8.2"
6
6
 
7
7
 
8
8
  import os
geoai/train.py CHANGED
@@ -511,7 +511,7 @@ def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None
511
511
 
512
512
  # Create output directory if needed
513
513
  if output_dir:
514
- os.makedirs(output_dir, exist_ok=True)
514
+ os.makedirs(os.path.abspath(output_dir), exist_ok=True)
515
515
 
516
516
  # Select random samples
517
517
  indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
@@ -648,7 +648,7 @@ def train_MaskRCNN_model(
648
648
  torch.backends.cudnn.benchmark = False
649
649
 
650
650
  # Create output directory
651
- os.makedirs(output_dir, exist_ok=True)
651
+ os.makedirs(os.path.abspath(output_dir), exist_ok=True)
652
652
 
653
653
  # Get device
654
654
  if device is None:
@@ -1216,7 +1216,7 @@ def object_detection_batch(
1216
1216
  )
1217
1217
 
1218
1218
  if not os.path.exists(output_dir):
1219
- os.makedirs(output_dir, exist_ok=True)
1219
+ os.makedirs(os.path.abspath(output_dir), exist_ok=True)
1220
1220
 
1221
1221
  if not os.path.exists(model_path):
1222
1222
  try:
@@ -1905,7 +1905,7 @@ def train_segmentation_model(
1905
1905
  torch.backends.cudnn.benchmark = False
1906
1906
 
1907
1907
  # Create output directory
1908
- os.makedirs(output_dir, exist_ok=True)
1908
+ os.makedirs(os.path.abspath(output_dir), exist_ok=True)
1909
1909
 
1910
1910
  # Get device
1911
1911
  if device is None:
@@ -2511,7 +2511,7 @@ def semantic_inference_on_geotiff(
2511
2511
  print(f"Inference completed in {inference_time:.2f} seconds")
2512
2512
 
2513
2513
  # Save output
2514
- out_dir = os.path.dirname(output_path)
2514
+ out_dir = os.path.abspath(os.path.dirname(output_path))
2515
2515
  os.makedirs(out_dir, exist_ok=True)
2516
2516
  with rasterio.open(output_path, "w", **out_meta) as dst:
2517
2517
  dst.write(mask, 1)
@@ -2778,7 +2778,7 @@ def semantic_inference_on_image(
2778
2778
  # Change extension to PNG if binary output to preserve exact values
2779
2779
  output_path_png = os.path.splitext(output_path)[0] + ".png"
2780
2780
  output_img = Image.fromarray(mask, mode="L")
2781
- out_dir = os.path.dirname(output_path)
2781
+ out_dir = os.path.abspath(os.path.dirname(output_path))
2782
2782
  os.makedirs(out_dir, exist_ok=True)
2783
2783
  output_img.save(output_path_png)
2784
2784
  if not quiet:
@@ -2887,7 +2887,7 @@ def semantic_segmentation(
2887
2887
  )
2888
2888
  else:
2889
2889
  # Create output directory if it doesn't exist
2890
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
2890
+ os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
2891
2891
 
2892
2892
  semantic_inference_on_image(
2893
2893
  model=model,
@@ -2962,7 +2962,7 @@ def semantic_segmentation_batch(
2962
2962
  raise FileNotFoundError(f"Input directory does not exist: {input_dir}")
2963
2963
 
2964
2964
  # Create output directory if it doesn't exist
2965
- os.makedirs(output_dir, exist_ok=True)
2965
+ os.makedirs(os.path.abspath(output_dir), exist_ok=True)
2966
2966
 
2967
2967
  # Get all supported image files
2968
2968
  image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
geoai/utils.py CHANGED
@@ -5,6 +5,7 @@ import glob
5
5
  import json
6
6
  import math
7
7
  import os
8
+ import subprocess
8
9
  import warnings
9
10
  import xml.etree.ElementTree as ET
10
11
  from collections import OrderedDict
@@ -7279,106 +7280,213 @@ def plot_prediction_comparison(
7279
7280
  prediction_colormap: str = "gray",
7280
7281
  ground_truth_colormap: str = "gray",
7281
7282
  original_colormap: Optional[str] = None,
7283
+ indexes: Optional[List[int]] = None,
7284
+ divider: Optional[float] = None,
7282
7285
  ):
7283
- """
7284
- Plot original image, prediction image, and optionally ground truth image side by side.
7286
+ """Plot original image, prediction, and optional ground truth side by side.
7287
+
7288
+ Supports input as file paths, NumPy arrays, or PIL Images. For multi-band
7289
+ images, selected channels can be specified via `indexes`. If the image data
7290
+ is not normalized (e.g., Sentinel-2 [0, 10000]), the `divider` can be used
7291
+ to scale values for visualization.
7285
7292
 
7286
7293
  Args:
7287
- original_image: Original input image (file path, numpy array, or PIL Image)
7288
- prediction_image: Prediction/segmentation mask (file path, numpy array, or PIL Image)
7289
- ground_truth_image: Optional ground truth mask (file path, numpy array, or PIL Image)
7290
- titles: Optional list of titles for each subplot
7291
- figsize: Figure size tuple (width, height)
7292
- save_path: Optional path to save the plot
7293
- show_plot: Whether to display the plot
7294
- prediction_colormap: Colormap for prediction image
7295
- ground_truth_colormap: Colormap for ground truth image
7296
- original_colormap: Colormap for original image (None for RGB)
7294
+ original_image (Union[str, np.ndarray, Image.Image]):
7295
+ Original input image as a file path, NumPy array, or PIL Image.
7296
+ prediction_image (Union[str, np.ndarray, Image.Image]):
7297
+ Predicted segmentation mask image.
7298
+ ground_truth_image (Optional[Union[str, np.ndarray, Image.Image]], optional):
7299
+ Ground truth mask image. Defaults to None.
7300
+ titles (Optional[List[str]], optional):
7301
+ List of titles for the subplots. If not provided, default titles are used.
7302
+ figsize (Tuple[int, int], optional):
7303
+ Size of the entire figure in inches. Defaults to (15, 5).
7304
+ save_path (Optional[str], optional):
7305
+ If specified, saves the figure to this path. Defaults to None.
7306
+ show_plot (bool, optional):
7307
+ Whether to display the figure using plt.show(). Defaults to True.
7308
+ prediction_colormap (str, optional):
7309
+ Colormap to use for the prediction mask. Defaults to "gray".
7310
+ ground_truth_colormap (str, optional):
7311
+ Colormap to use for the ground truth mask. Defaults to "gray".
7312
+ original_colormap (Optional[str], optional):
7313
+ Colormap to use for the original image if it's grayscale. Defaults to None.
7314
+ indexes (Optional[List[int]], optional):
7315
+ List of band/channel indexes (0-based for NumPy, 1-based for rasterio) to extract from the original image.
7316
+ Useful for multi-band imagery like Sentinel-2. Defaults to None.
7317
+ divider (Optional[float], optional):
7318
+ Value to divide the original image by for normalization (e.g., 10000 for reflectance). Defaults to None.
7297
7319
 
7298
7320
  Returns:
7299
- matplotlib.figure.Figure: The figure object
7321
+ matplotlib.figure.Figure:
7322
+ The generated matplotlib figure object.
7300
7323
  """
7301
7324
 
7302
- def _load_image(img_input):
7325
+ def _load_image(img_input, indexes=None):
7303
7326
  """Helper function to load image from various input types."""
7304
7327
  if isinstance(img_input, str):
7305
- # File path
7306
7328
  if img_input.lower().endswith((".tif", ".tiff")):
7307
- # Handle GeoTIFF files
7308
7329
  with rasterio.open(img_input) as src:
7309
- img = src.read()
7310
- if img.shape[0] == 1:
7311
- # Single band
7312
- img = img[0]
7330
+ if indexes:
7331
+ img = src.read(indexes) # 1-based
7332
+ img = (
7333
+ np.transpose(img, (1, 2, 0)) if len(indexes) > 1 else img[0]
7334
+ )
7313
7335
  else:
7314
- # Multi-band, transpose to (H, W, C)
7315
- img = np.transpose(img, (1, 2, 0))
7336
+ img = src.read()
7337
+ if img.shape[0] == 1:
7338
+ img = img[0]
7339
+ else:
7340
+ img = np.transpose(img, (1, 2, 0))
7316
7341
  else:
7317
- # Regular image file
7318
7342
  img = np.array(Image.open(img_input))
7319
7343
  elif isinstance(img_input, Image.Image):
7320
- # PIL Image
7321
7344
  img = np.array(img_input)
7322
7345
  elif isinstance(img_input, np.ndarray):
7323
- # NumPy array
7324
7346
  img = img_input
7347
+ if indexes is not None and img.ndim == 3:
7348
+ img = img[:, :, indexes]
7325
7349
  else:
7326
7350
  raise ValueError(f"Unsupported image type: {type(img_input)}")
7327
-
7328
7351
  return img
7329
7352
 
7330
7353
  # Load images
7331
- original = _load_image(original_image)
7354
+ original = _load_image(original_image, indexes=indexes)
7332
7355
  prediction = _load_image(prediction_image)
7333
7356
  ground_truth = (
7334
7357
  _load_image(ground_truth_image) if ground_truth_image is not None else None
7335
7358
  )
7336
7359
 
7337
- # Determine number of subplots
7338
- num_plots = 3 if ground_truth is not None else 2
7360
+ # Apply divider normalization if requested
7361
+ if divider is not None and isinstance(original, np.ndarray) and original.ndim == 3:
7362
+ original = np.clip(original.astype(np.float32) / divider, 0, 1)
7339
7363
 
7340
- # Create figure and subplots
7364
+ # Determine layout
7365
+ num_plots = 3 if ground_truth is not None else 2
7341
7366
  fig, axes = plt.subplots(1, num_plots, figsize=figsize)
7342
7367
  if num_plots == 2:
7343
7368
  axes = [axes[0], axes[1]]
7344
7369
 
7345
- # Default titles
7346
7370
  if titles is None:
7347
7371
  titles = ["Original Image", "Prediction"]
7348
7372
  if ground_truth is not None:
7349
7373
  titles.append("Ground Truth")
7350
7374
 
7351
- # Plot original image
7352
- if len(original.shape) == 3 and original.shape[2] in [3, 4]:
7353
- # RGB or RGBA image
7375
+ # Plot original
7376
+ if original.ndim == 3 and original.shape[2] in [3, 4]:
7354
7377
  axes[0].imshow(original)
7355
7378
  else:
7356
- # Grayscale or single channel
7357
7379
  axes[0].imshow(original, cmap=original_colormap)
7358
7380
  axes[0].set_title(titles[0])
7359
7381
  axes[0].axis("off")
7360
7382
 
7361
- # Plot prediction image
7383
+ # Prediction
7362
7384
  axes[1].imshow(prediction, cmap=prediction_colormap)
7363
7385
  axes[1].set_title(titles[1])
7364
7386
  axes[1].axis("off")
7365
7387
 
7366
- # Plot ground truth if provided
7388
+ # Ground truth
7367
7389
  if ground_truth is not None:
7368
7390
  axes[2].imshow(ground_truth, cmap=ground_truth_colormap)
7369
7391
  axes[2].set_title(titles[2])
7370
7392
  axes[2].axis("off")
7371
7393
 
7372
- # Adjust layout
7373
7394
  plt.tight_layout()
7374
7395
 
7375
- # Save if requested
7376
7396
  if save_path:
7377
7397
  plt.savefig(save_path, dpi=300, bbox_inches="tight")
7378
7398
  print(f"Plot saved to: {save_path}")
7379
7399
 
7380
- # Show plot
7381
7400
  if show_plot:
7382
7401
  plt.show()
7383
7402
 
7384
7403
  return fig
7404
+
7405
+
7406
def get_raster_resolution(image_path: str) -> Tuple[float, float]:
    """Read the pixel size of a raster dataset with rasterio.

    Args:
        image_path: Path to a raster file that rasterio can open.

    Returns:
        The dataset's ``res`` attribute, an ``(x_resolution, y_resolution)``
        tuple.
    """
    with rasterio.open(image_path) as dataset:
        # The value is read while the dataset is open; the context manager
        # closes the file before the tuple reaches the caller.
        return dataset.res
7418
+
7419
+
7420
def stack_bands(
    input_files: List[str],
    output_file: str,
    resolution: Optional[float] = None,
    dtype: Optional[str] = None,  # e.g., "UInt16", "Float32"
    temp_vrt: str = "stack.vrt",
    overwrite: bool = False,
    compress: str = "DEFLATE",
    output_format: str = "COG",
    extra_gdal_translate_args: Optional[List[str]] = None,
) -> str:
    """Stack bands from multiple images into a single multi-band raster.

    An intermediate VRT is built with ``gdalbuildvrt -separate`` and then
    materialized with ``gdal_translate``. Both GDAL command-line tools must
    be available on ``PATH``.

    Args:
        input_files (List[str]): Input image paths, stacked in the given order.
        output_file (str): Path to the output stacked image. Its parent
            directory is created if it does not already exist.
        resolution (float, optional): Output pixel size, used for both x and
            y. If None, the resolution of the first input image is used.
        dtype (str, optional): GDAL output data type passed via ``-ot``
            (e.g., "UInt16", "Float32"). If None, ``-ot`` is not passed.
        temp_vrt (str): Path of the intermediate VRT file. Any existing file
            at this path is replaced, and the file is deleted before the
            function returns or raises.
        overwrite (bool): If False and ``output_file`` already exists, a
            message is printed and the existing path is returned untouched.
        compress (str): Value for the ``COMPRESS`` creation option.
        output_format (str): GDAL output driver name (default is "COG").
        extra_gdal_translate_args (List[str], optional): Extra arguments for
            gdal_translate, inserted before the source/destination paths.

    Returns:
        str: Path to the output file.

    Raises:
        ValueError: If ``input_files`` is empty.
        FileNotFoundError: If a GDAL command-line tool cannot be launched.
        subprocess.CalledProcessError: If a GDAL command exits with an error.
    """
    if not input_files:
        raise ValueError("No input files provided.")

    if os.path.exists(output_file) and not overwrite:
        print(f"Output file already exists: {output_file}")
        return output_file

    # Infer resolution from the first input if not provided
    if resolution is None:
        resolution_x, resolution_y = get_raster_resolution(input_files[0])
    else:
        resolution_x = resolution_y = resolution

    # gdal_translate cannot create the output inside a missing directory.
    # abspath keeps dirname non-empty when output_file is a bare filename.
    os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)

    # Step 1 command: build the VRT. -overwrite replaces any file already at
    # temp_vrt (e.g. one left behind by an interrupted earlier run); the
    # function owns and deletes that file anyway.
    vrt_cmd = ["gdalbuildvrt", "-overwrite", "-separate", temp_vrt, *input_files]

    # Step 2 command: all options first, then source and destination, which
    # is the documented `gdal_translate [options] src dst` form.
    translate_cmd = ["gdal_translate"]
    if dtype:
        translate_cmd += ["-ot", dtype]
    translate_cmd += [
        "-tr",
        str(resolution_x),
        str(resolution_y),
        "-of",
        output_format,
        "-co",
        f"COMPRESS={compress}",
    ]
    if extra_gdal_translate_args:
        translate_cmd += list(extra_gdal_translate_args)
    translate_cmd += [temp_vrt, output_file]

    try:
        subprocess.run(vrt_cmd, check=True)
        subprocess.run(translate_cmd, check=True)
    finally:
        # Step 3: always clean up the intermediate VRT, even when a GDAL
        # command failed, so no stray file is left in the working directory.
        if os.path.exists(temp_vrt):
            os.remove(temp_vrt)

    return output_file
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: geoai-py
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -1,4 +1,4 @@
1
- geoai/__init__.py,sha256=05t9D_IforM2DDpgmwh1BPzGOsmQqfRi1BC_-Ola9_w,3765
1
+ geoai/__init__.py,sha256=wKdVPgcUCURuMI3TNzqh1MPw96Bs5XL7mNUST75bOjw,3765
2
2
  geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
3
3
  geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
4
4
  geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
@@ -7,11 +7,11 @@ geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
7
7
  geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
8
8
  geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
9
9
  geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
10
- geoai/train.py,sha256=bwcCjDwiWF3aGo2gRqhd-6Xlhg3xZV_SeHx0kskWsPM,112709
11
- geoai/utils.py,sha256=Ucards_Kb3H77bFn-pVpnt0f1oez-Vu3eFaaZhRV5D4,281943
12
- geoai_py-0.8.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
- geoai_py-0.8.0.dist-info/METADATA,sha256=iYkcRGlikK5Tdy1gBmiCBsd6t7tQhPi77e39btQMKgk,6661
14
- geoai_py-0.8.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
- geoai_py-0.8.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
16
- geoai_py-0.8.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
17
- geoai_py-0.8.0.dist-info/RECORD,,
10
+ geoai/train.py,sha256=Mrsb0yMVnprQHld3zDvA-puc-r8hGm1XgG0j2GGIn7E,112845
11
+ geoai/utils.py,sha256=_JVEhFUzOdDS_Rmco8c4BJVnrs2VvUVHKf4ubtim5fg,286109
12
+ geoai_py-0.8.2.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
+ geoai_py-0.8.2.dist-info/METADATA,sha256=6LIaLEjT5AtbV5OOJwMT7NQCMI6pHcKXsfQbakWl6Ns,6661
14
+ geoai_py-0.8.2.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
+ geoai_py-0.8.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
16
+ geoai_py-0.8.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
17
+ geoai_py-0.8.2.dist-info/RECORD,,