geoai-py 0.8.0__py2.py3-none-any.whl → 0.8.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/train.py +8 -8
- geoai/utils.py +148 -40
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/METADATA +1 -1
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/RECORD +9 -9
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/WHEEL +0 -0
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.8.0.dist-info → geoai_py-0.8.2.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/train.py
CHANGED
@@ -511,7 +511,7 @@ def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None
|
|
511
511
|
|
512
512
|
# Create output directory if needed
|
513
513
|
if output_dir:
|
514
|
-
os.makedirs(output_dir, exist_ok=True)
|
514
|
+
os.makedirs(os.path.abspath(output_dir), exist_ok=True)
|
515
515
|
|
516
516
|
# Select random samples
|
517
517
|
indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
|
@@ -648,7 +648,7 @@ def train_MaskRCNN_model(
|
|
648
648
|
torch.backends.cudnn.benchmark = False
|
649
649
|
|
650
650
|
# Create output directory
|
651
|
-
os.makedirs(output_dir, exist_ok=True)
|
651
|
+
os.makedirs(os.path.abspath(output_dir), exist_ok=True)
|
652
652
|
|
653
653
|
# Get device
|
654
654
|
if device is None:
|
@@ -1216,7 +1216,7 @@ def object_detection_batch(
|
|
1216
1216
|
)
|
1217
1217
|
|
1218
1218
|
if not os.path.exists(output_dir):
|
1219
|
-
os.makedirs(output_dir, exist_ok=True)
|
1219
|
+
os.makedirs(os.path.abspath(output_dir), exist_ok=True)
|
1220
1220
|
|
1221
1221
|
if not os.path.exists(model_path):
|
1222
1222
|
try:
|
@@ -1905,7 +1905,7 @@ def train_segmentation_model(
|
|
1905
1905
|
torch.backends.cudnn.benchmark = False
|
1906
1906
|
|
1907
1907
|
# Create output directory
|
1908
|
-
os.makedirs(output_dir, exist_ok=True)
|
1908
|
+
os.makedirs(os.path.abspath(output_dir), exist_ok=True)
|
1909
1909
|
|
1910
1910
|
# Get device
|
1911
1911
|
if device is None:
|
@@ -2511,7 +2511,7 @@ def semantic_inference_on_geotiff(
|
|
2511
2511
|
print(f"Inference completed in {inference_time:.2f} seconds")
|
2512
2512
|
|
2513
2513
|
# Save output
|
2514
|
-
out_dir = os.path.dirname(output_path)
|
2514
|
+
out_dir = os.path.abspath(os.path.dirname(output_path))
|
2515
2515
|
os.makedirs(out_dir, exist_ok=True)
|
2516
2516
|
with rasterio.open(output_path, "w", **out_meta) as dst:
|
2517
2517
|
dst.write(mask, 1)
|
@@ -2778,7 +2778,7 @@ def semantic_inference_on_image(
|
|
2778
2778
|
# Change extension to PNG if binary output to preserve exact values
|
2779
2779
|
output_path_png = os.path.splitext(output_path)[0] + ".png"
|
2780
2780
|
output_img = Image.fromarray(mask, mode="L")
|
2781
|
-
out_dir = os.path.dirname(output_path)
|
2781
|
+
out_dir = os.path.abspath(os.path.dirname(output_path))
|
2782
2782
|
os.makedirs(out_dir, exist_ok=True)
|
2783
2783
|
output_img.save(output_path_png)
|
2784
2784
|
if not quiet:
|
@@ -2887,7 +2887,7 @@ def semantic_segmentation(
|
|
2887
2887
|
)
|
2888
2888
|
else:
|
2889
2889
|
# Create output directory if it doesn't exist
|
2890
|
-
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
2890
|
+
os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
|
2891
2891
|
|
2892
2892
|
semantic_inference_on_image(
|
2893
2893
|
model=model,
|
@@ -2962,7 +2962,7 @@ def semantic_segmentation_batch(
|
|
2962
2962
|
raise FileNotFoundError(f"Input directory does not exist: {input_dir}")
|
2963
2963
|
|
2964
2964
|
# Create output directory if it doesn't exist
|
2965
|
-
os.makedirs(output_dir, exist_ok=True)
|
2965
|
+
os.makedirs(os.path.abspath(output_dir), exist_ok=True)
|
2966
2966
|
|
2967
2967
|
# Get all supported image files
|
2968
2968
|
image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
|
geoai/utils.py
CHANGED
@@ -5,6 +5,7 @@ import glob
|
|
5
5
|
import json
|
6
6
|
import math
|
7
7
|
import os
|
8
|
+
import subprocess
|
8
9
|
import warnings
|
9
10
|
import xml.etree.ElementTree as ET
|
10
11
|
from collections import OrderedDict
|
@@ -7279,106 +7280,213 @@ def plot_prediction_comparison(
|
|
7279
7280
|
prediction_colormap: str = "gray",
|
7280
7281
|
ground_truth_colormap: str = "gray",
|
7281
7282
|
original_colormap: Optional[str] = None,
|
7283
|
+
indexes: Optional[List[int]] = None,
|
7284
|
+
divider: Optional[float] = None,
|
7282
7285
|
):
|
7283
|
-
"""
|
7284
|
-
|
7286
|
+
"""Plot original image, prediction, and optional ground truth side by side.
|
7287
|
+
|
7288
|
+
Supports input as file paths, NumPy arrays, or PIL Images. For multi-band
|
7289
|
+
images, selected channels can be specified via `indexes`. If the image data
|
7290
|
+
is not normalized (e.g., Sentinel-2 [0, 10000]), the `divider` can be used
|
7291
|
+
to scale values for visualization.
|
7285
7292
|
|
7286
7293
|
Args:
|
7287
|
-
original_image
|
7288
|
-
|
7289
|
-
|
7290
|
-
|
7291
|
-
|
7292
|
-
|
7293
|
-
|
7294
|
-
|
7295
|
-
|
7296
|
-
|
7294
|
+
original_image (Union[str, np.ndarray, Image.Image]):
|
7295
|
+
Original input image as a file path, NumPy array, or PIL Image.
|
7296
|
+
prediction_image (Union[str, np.ndarray, Image.Image]):
|
7297
|
+
Predicted segmentation mask image.
|
7298
|
+
ground_truth_image (Optional[Union[str, np.ndarray, Image.Image]], optional):
|
7299
|
+
Ground truth mask image. Defaults to None.
|
7300
|
+
titles (Optional[List[str]], optional):
|
7301
|
+
List of titles for the subplots. If not provided, default titles are used.
|
7302
|
+
figsize (Tuple[int, int], optional):
|
7303
|
+
Size of the entire figure in inches. Defaults to (15, 5).
|
7304
|
+
save_path (Optional[str], optional):
|
7305
|
+
If specified, saves the figure to this path. Defaults to None.
|
7306
|
+
show_plot (bool, optional):
|
7307
|
+
Whether to display the figure using plt.show(). Defaults to True.
|
7308
|
+
prediction_colormap (str, optional):
|
7309
|
+
Colormap to use for the prediction mask. Defaults to "gray".
|
7310
|
+
ground_truth_colormap (str, optional):
|
7311
|
+
Colormap to use for the ground truth mask. Defaults to "gray".
|
7312
|
+
original_colormap (Optional[str], optional):
|
7313
|
+
Colormap to use for the original image if it's grayscale. Defaults to None.
|
7314
|
+
indexes (Optional[List[int]], optional):
|
7315
|
+
List of band/channel indexes (0-based for NumPy, 1-based for rasterio) to extract from the original image.
|
7316
|
+
Useful for multi-band imagery like Sentinel-2. Defaults to None.
|
7317
|
+
divider (Optional[float], optional):
|
7318
|
+
Value to divide the original image by for normalization (e.g., 10000 for reflectance). Defaults to None.
|
7297
7319
|
|
7298
7320
|
Returns:
|
7299
|
-
matplotlib.figure.Figure:
|
7321
|
+
matplotlib.figure.Figure:
|
7322
|
+
The generated matplotlib figure object.
|
7300
7323
|
"""
|
7301
7324
|
|
7302
|
-
def _load_image(img_input):
|
7325
|
+
def _load_image(img_input, indexes=None):
|
7303
7326
|
"""Helper function to load image from various input types."""
|
7304
7327
|
if isinstance(img_input, str):
|
7305
|
-
# File path
|
7306
7328
|
if img_input.lower().endswith((".tif", ".tiff")):
|
7307
|
-
# Handle GeoTIFF files
|
7308
7329
|
with rasterio.open(img_input) as src:
|
7309
|
-
|
7310
|
-
|
7311
|
-
|
7312
|
-
|
7330
|
+
if indexes:
|
7331
|
+
img = src.read(indexes) # 1-based
|
7332
|
+
img = (
|
7333
|
+
np.transpose(img, (1, 2, 0)) if len(indexes) > 1 else img[0]
|
7334
|
+
)
|
7313
7335
|
else:
|
7314
|
-
|
7315
|
-
img
|
7336
|
+
img = src.read()
|
7337
|
+
if img.shape[0] == 1:
|
7338
|
+
img = img[0]
|
7339
|
+
else:
|
7340
|
+
img = np.transpose(img, (1, 2, 0))
|
7316
7341
|
else:
|
7317
|
-
# Regular image file
|
7318
7342
|
img = np.array(Image.open(img_input))
|
7319
7343
|
elif isinstance(img_input, Image.Image):
|
7320
|
-
# PIL Image
|
7321
7344
|
img = np.array(img_input)
|
7322
7345
|
elif isinstance(img_input, np.ndarray):
|
7323
|
-
# NumPy array
|
7324
7346
|
img = img_input
|
7347
|
+
if indexes is not None and img.ndim == 3:
|
7348
|
+
img = img[:, :, indexes]
|
7325
7349
|
else:
|
7326
7350
|
raise ValueError(f"Unsupported image type: {type(img_input)}")
|
7327
|
-
|
7328
7351
|
return img
|
7329
7352
|
|
7330
7353
|
# Load images
|
7331
|
-
original = _load_image(original_image)
|
7354
|
+
original = _load_image(original_image, indexes=indexes)
|
7332
7355
|
prediction = _load_image(prediction_image)
|
7333
7356
|
ground_truth = (
|
7334
7357
|
_load_image(ground_truth_image) if ground_truth_image is not None else None
|
7335
7358
|
)
|
7336
7359
|
|
7337
|
-
#
|
7338
|
-
|
7360
|
+
# Apply divider normalization if requested
|
7361
|
+
if divider is not None and isinstance(original, np.ndarray) and original.ndim == 3:
|
7362
|
+
original = np.clip(original.astype(np.float32) / divider, 0, 1)
|
7339
7363
|
|
7340
|
-
#
|
7364
|
+
# Determine layout
|
7365
|
+
num_plots = 3 if ground_truth is not None else 2
|
7341
7366
|
fig, axes = plt.subplots(1, num_plots, figsize=figsize)
|
7342
7367
|
if num_plots == 2:
|
7343
7368
|
axes = [axes[0], axes[1]]
|
7344
7369
|
|
7345
|
-
# Default titles
|
7346
7370
|
if titles is None:
|
7347
7371
|
titles = ["Original Image", "Prediction"]
|
7348
7372
|
if ground_truth is not None:
|
7349
7373
|
titles.append("Ground Truth")
|
7350
7374
|
|
7351
|
-
# Plot original
|
7352
|
-
if
|
7353
|
-
# RGB or RGBA image
|
7375
|
+
# Plot original
|
7376
|
+
if original.ndim == 3 and original.shape[2] in [3, 4]:
|
7354
7377
|
axes[0].imshow(original)
|
7355
7378
|
else:
|
7356
|
-
# Grayscale or single channel
|
7357
7379
|
axes[0].imshow(original, cmap=original_colormap)
|
7358
7380
|
axes[0].set_title(titles[0])
|
7359
7381
|
axes[0].axis("off")
|
7360
7382
|
|
7361
|
-
#
|
7383
|
+
# Prediction
|
7362
7384
|
axes[1].imshow(prediction, cmap=prediction_colormap)
|
7363
7385
|
axes[1].set_title(titles[1])
|
7364
7386
|
axes[1].axis("off")
|
7365
7387
|
|
7366
|
-
#
|
7388
|
+
# Ground truth
|
7367
7389
|
if ground_truth is not None:
|
7368
7390
|
axes[2].imshow(ground_truth, cmap=ground_truth_colormap)
|
7369
7391
|
axes[2].set_title(titles[2])
|
7370
7392
|
axes[2].axis("off")
|
7371
7393
|
|
7372
|
-
# Adjust layout
|
7373
7394
|
plt.tight_layout()
|
7374
7395
|
|
7375
|
-
# Save if requested
|
7376
7396
|
if save_path:
|
7377
7397
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
7378
7398
|
print(f"Plot saved to: {save_path}")
|
7379
7399
|
|
7380
|
-
# Show plot
|
7381
7400
|
if show_plot:
|
7382
7401
|
plt.show()
|
7383
7402
|
|
7384
7403
|
return fig
|
7404
|
+
|
7405
|
+
|
7406
|
+
def get_raster_resolution(image_path: str) -> Tuple[float, float]:
|
7407
|
+
"""Get pixel resolution from the raster using rasterio.
|
7408
|
+
|
7409
|
+
Args:
|
7410
|
+
image_path: The path to the raster image.
|
7411
|
+
|
7412
|
+
Returns:
|
7413
|
+
A tuple of (x resolution, y resolution).
|
7414
|
+
"""
|
7415
|
+
with rasterio.open(image_path) as src:
|
7416
|
+
res = src.res
|
7417
|
+
return res
|
7418
|
+
|
7419
|
+
|
7420
|
+
def stack_bands(
|
7421
|
+
input_files: List[str],
|
7422
|
+
output_file: str,
|
7423
|
+
resolution: Optional[float] = None,
|
7424
|
+
dtype: Optional[str] = None, # e.g., "UInt16", "Float32"
|
7425
|
+
temp_vrt: str = "stack.vrt",
|
7426
|
+
overwrite: bool = False,
|
7427
|
+
compress: str = "DEFLATE",
|
7428
|
+
output_format: str = "COG",
|
7429
|
+
extra_gdal_translate_args: Optional[List[str]] = None,
|
7430
|
+
) -> str:
|
7431
|
+
"""
|
7432
|
+
Stack bands from multiple images into a single multi-band GeoTIFF.
|
7433
|
+
|
7434
|
+
Args:
|
7435
|
+
input_files (List[str]): List of input image paths.
|
7436
|
+
output_file (str): Path to the output stacked image.
|
7437
|
+
resolution (float, optional): Output resolution. If None, inferred from first image.
|
7438
|
+
dtype (str, optional): Output data type (e.g., "UInt16", "Float32").
|
7439
|
+
temp_vrt (str): Temporary VRT filename.
|
7440
|
+
overwrite (bool): Whether to overwrite the output file.
|
7441
|
+
compress (str): Compression method.
|
7442
|
+
output_format (str): GDAL output format (default is "COG").
|
7443
|
+
extra_gdal_translate_args (List[str], optional): Extra arguments for gdal_translate.
|
7444
|
+
|
7445
|
+
Returns:
|
7446
|
+
str: Path to the output file.
|
7447
|
+
"""
|
7448
|
+
if not input_files:
|
7449
|
+
raise ValueError("No input files provided.")
|
7450
|
+
|
7451
|
+
if os.path.exists(output_file) and not overwrite:
|
7452
|
+
print(f"Output file already exists: {output_file}")
|
7453
|
+
return output_file
|
7454
|
+
|
7455
|
+
# Infer resolution if not provided
|
7456
|
+
if resolution is None:
|
7457
|
+
resolution_x, resolution_y = get_raster_resolution(input_files[0])
|
7458
|
+
else:
|
7459
|
+
resolution_x = resolution_y = resolution
|
7460
|
+
|
7461
|
+
# Step 1: Build VRT
|
7462
|
+
vrt_cmd = ["gdalbuildvrt", "-separate", temp_vrt] + input_files
|
7463
|
+
subprocess.run(vrt_cmd, check=True)
|
7464
|
+
|
7465
|
+
# Step 2: Translate VRT to output GeoTIFF
|
7466
|
+
translate_cmd = [
|
7467
|
+
"gdal_translate",
|
7468
|
+
"-tr",
|
7469
|
+
str(resolution_x),
|
7470
|
+
str(resolution_y),
|
7471
|
+
temp_vrt,
|
7472
|
+
output_file,
|
7473
|
+
"-of",
|
7474
|
+
output_format,
|
7475
|
+
"-co",
|
7476
|
+
f"COMPRESS={compress}",
|
7477
|
+
]
|
7478
|
+
|
7479
|
+
if dtype:
|
7480
|
+
translate_cmd.insert(1, "-ot")
|
7481
|
+
translate_cmd.insert(2, dtype)
|
7482
|
+
|
7483
|
+
if extra_gdal_translate_args:
|
7484
|
+
translate_cmd += extra_gdal_translate_args
|
7485
|
+
|
7486
|
+
subprocess.run(translate_cmd, check=True)
|
7487
|
+
|
7488
|
+
# Step 3: Clean up VRT
|
7489
|
+
if os.path.exists(temp_vrt):
|
7490
|
+
os.remove(temp_vrt)
|
7491
|
+
|
7492
|
+
return output_file
|
@@ -1,4 +1,4 @@
|
|
1
|
-
geoai/__init__.py,sha256=
|
1
|
+
geoai/__init__.py,sha256=wKdVPgcUCURuMI3TNzqh1MPw96Bs5XL7mNUST75bOjw,3765
|
2
2
|
geoai/classify.py,sha256=O8fah3DOBDMZW7V_qfDYsUnjB-9Wo5fjA-0e4wvUeAE,35054
|
3
3
|
geoai/download.py,sha256=EQpcrcqMsYhDpd7bpjf4hGS5xL2oO-jsjngLgjjP3cE,46599
|
4
4
|
geoai/extract.py,sha256=vyHH1k5zaXiy1SLdCLXxbWiNLp8XKdu_MXZoREMtAOQ,119102
|
@@ -7,11 +7,11 @@ geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
|
|
7
7
|
geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
|
8
8
|
geoai/segment.py,sha256=pThAyq8kjgVDhMwHMiWkZ2qL5ykzA5lRg7tyMmSEBxk,43434
|
9
9
|
geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
|
10
|
-
geoai/train.py,sha256=
|
11
|
-
geoai/utils.py,sha256=
|
12
|
-
geoai_py-0.8.
|
13
|
-
geoai_py-0.8.
|
14
|
-
geoai_py-0.8.
|
15
|
-
geoai_py-0.8.
|
16
|
-
geoai_py-0.8.
|
17
|
-
geoai_py-0.8.
|
10
|
+
geoai/train.py,sha256=Mrsb0yMVnprQHld3zDvA-puc-r8hGm1XgG0j2GGIn7E,112845
|
11
|
+
geoai/utils.py,sha256=_JVEhFUzOdDS_Rmco8c4BJVnrs2VvUVHKf4ubtim5fg,286109
|
12
|
+
geoai_py-0.8.2.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
|
13
|
+
geoai_py-0.8.2.dist-info/METADATA,sha256=6LIaLEjT5AtbV5OOJwMT7NQCMI6pHcKXsfQbakWl6Ns,6661
|
14
|
+
geoai_py-0.8.2.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
|
15
|
+
geoai_py-0.8.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
|
16
|
+
geoai_py-0.8.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
|
17
|
+
geoai_py-0.8.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|