geoai-py 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.2.0"
5
+ __version__ = "0.2.2"
6
6
 
7
7
 
8
8
  import os
@@ -29,5 +29,7 @@ def set_proj_lib_path():
29
29
  return
30
30
 
31
31
 
32
- set_proj_lib_path()
32
+ if "google.colab" not in sys.modules:
33
+ set_proj_lib_path()
34
+
33
35
  from .geoai import *
geoai/common.py CHANGED
@@ -114,6 +114,8 @@ def viz_image(
114
114
 
115
115
  if isinstance(image, torch.Tensor):
116
116
  image = image.cpu().numpy()
117
+ elif isinstance(image, str):
118
+ image = rio.open(image).read().transpose(1, 2, 0)
117
119
 
118
120
  plt.figure(figsize=figsize)
119
121
 
geoai/geoai.py CHANGED
@@ -1,3 +1,4 @@
1
1
  """Main module."""
2
2
 
3
- from .common import viz_raster, viz_image, plot_batch, calc_stats
3
+ from .common import *
4
+ from .preprocess import *
geoai/preprocess.py CHANGED
@@ -3,6 +3,7 @@ import math
3
3
  import os
4
4
  from PIL import Image
5
5
  from pathlib import Path
6
+ import requests
6
7
  import warnings
7
8
  import xml.etree.ElementTree as ET
8
9
  import numpy as np
@@ -11,6 +12,7 @@ import geopandas as gpd
11
12
  import pandas as pd
12
13
  from rasterio.windows import Window
13
14
  from rasterio import features
15
+ from rasterio.plot import show
14
16
  from shapely.geometry import box, shape
15
17
  import matplotlib.pyplot as plt
16
18
  from tqdm import tqdm
@@ -20,6 +22,691 @@ import torchgeo
20
22
  import torch
21
23
 
22
24
 
25
def download_file(url, output_path=None, overwrite=False, timeout=50, chunk_size=1024):
    """Download a file from a URL with a progress bar.

    The payload is streamed into a temporary ``<output_path>.part`` file that
    is renamed into place only after the transfer completes. An interrupted
    download therefore never leaves a truncated file that a later call (with
    ``overwrite=False``) would mistake for a finished download.

    Args:
        url (str): The URL of the file to download.
        output_path (str, optional): The path where the downloaded file will be
            saved. If not provided, the last component of the URL path is used
            (query string and fragment removed, percent-escapes decoded),
            relative to the current working directory.
        overwrite (bool, optional): Whether to overwrite the file if it already
            exists. Defaults to False.
        timeout (float, optional): Timeout in seconds passed to
            ``requests.get``. Defaults to 50.
        chunk_size (int, optional): Number of bytes read per chunk.
            Defaults to 1024.

    Returns:
        str: The path to the downloaded (or already existing) file.

    Raises:
        ValueError: If ``output_path`` is not given and no filename can be
            derived from the URL.
        requests.HTTPError: If the server responds with an error status.
    """
    from urllib.parse import unquote, urlparse

    if output_path is None:
        # Use only the URL path, so ".../x.tif?raw=true" becomes "x.tif".
        output_path = os.path.basename(unquote(urlparse(url).path))
        if not output_path:
            raise ValueError(
                f"Could not determine a filename from URL: {url}. "
                "Please provide output_path."
            )

    if os.path.exists(output_path) and not overwrite:
        print(f"File already exists: {output_path}")
        return output_path

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    tmp_path = output_path + ".part"
    try:
        # The context manager releases the connection even if the
        # download fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors

            # Total size is unknown (None) when the server omits the header.
            total_size = int(response.headers.get("content-length", 0)) or None

            with open(tmp_path, "wb") as file, tqdm(
                desc=os.path.basename(output_path),
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # filter out keep-alive new chunks
                        file.write(chunk)
                        progress_bar.update(len(chunk))

        # Only a complete download ever appears at output_path.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Remove the partial file (also on Ctrl-C) and propagate the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return output_path
73
+
74
+
75
def get_raster_info(raster_path):
    """Collect metadata and per-band statistics for a raster dataset.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: Basic raster metadata with the keys ``driver``, ``width``,
        ``height``, ``count``, ``dtype`` (of the first band), ``crs`` (a
        string, or "No CRS defined"), ``transform``, ``bounds``,
        ``resolution`` (``(transform[0], -transform[4])``) and ``nodata``,
        plus ``band_stats``: one dict per band with ``band``, ``min``,
        ``max``, ``mean`` and ``std`` computed from a masked read.
    """
    with rasterio.open(raster_path) as dataset:
        geo = dataset.transform
        crs_text = dataset.crs.to_string() if dataset.crs else "No CRS defined"

        summary = dict(
            driver=dataset.driver,
            width=dataset.width,
            height=dataset.height,
            count=dataset.count,
            dtype=dataset.dtypes[0],
            crs=crs_text,
            transform=geo,
            bounds=dataset.bounds,
            resolution=(geo[0], -geo[4]),
            nodata=dataset.nodata,
        )

        per_band = []
        for band_index in range(1, dataset.count + 1):
            # A masked read keeps masked pixels out of the statistics.
            values = dataset.read(band_index, masked=True)
            per_band.append(
                dict(
                    band=band_index,
                    min=float(values.min()),
                    max=float(values.max()),
                    mean=float(values.mean()),
                    std=float(values.std()),
                )
            )

    summary["band_stats"] = per_band
    return summary
116
+
117
+
118
def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a raster dataset and optionally show a preview.

    Args:
        raster_path (str): Path to the raster file
        show_preview (bool, optional): Whether to display a visual preview of the raster.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: The dictionary produced by ``get_raster_info`` if successful,
        None otherwise (the error message is printed).
    """
    try:
        info = get_raster_info(raster_path)

        # Print basic information
        print(f"===== RASTER INFORMATION: {raster_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Dimensions: {info['width']} x {info['height']} pixels")
        print(f"Number of bands: {info['count']}")
        print(f"Data type: {info['dtype']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Georeferenced Bounds: {info['bounds']}")
        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
        print(f"NoData Value: {info['nodata']}")

        # Print band statistics
        print("\n----- Band Statistics -----")
        for band_stat in info["band_stats"]:
            print(f"Band {band_stat['band']}:")
            print(f" Min: {band_stat['min']:.2f}")
            print(f" Max: {band_stat['max']:.2f}")
            print(f" Mean: {band_stat['mean']:.2f}")
            print(f" Std Dev: {band_stat['std']:.2f}")

        # Show a preview if requested
        if show_preview:
            with rasterio.open(raster_path) as src:
                # Draw on an explicit Axes so the colorbar attaches to this image.
                fig, ax = plt.subplots(figsize=figsize)
                if src.count >= 3:
                    # Bands 1-3 as a (3, rows, cols) masked array.
                    rgb = src.read([1, 2, 3], masked=True)
                    if rgb.dtype != np.uint8:
                        # imshow expects RGB floats in [0, 1] (or uint8 in
                        # [0, 255]). Stretch other dtypes jointly across the
                        # three bands so they are not clipped and color balance
                        # is kept.
                        rgb = rgb.astype("float64")
                        lo, hi = rgb.min(), rgb.max()
                        if hi > lo:
                            rgb = (rgb - lo) / (hi - lo)
                        else:
                            rgb = np.ma.zeros(rgb.shape)
                    # Masked pixels are shown as 0 (black).
                    ax.imshow(np.ma.filled(rgb, 0).transpose(1, 2, 0))
                    ax.set_title(f"RGB Preview: {raster_path}")
                else:
                    # Show first band for single-band images
                    image = ax.imshow(src.read(1, masked=True), cmap="viridis")
                    ax.set_title(f"Band 1 Preview: {raster_path}")
                    fig.colorbar(image, ax=ax, label="Pixel Value")
            plt.show()

        return info

    except Exception as e:
        print(f"Error reading raster: {str(e)}")
        return None
176
+
177
+
178
def get_raster_info_gdal(raster_path):
    """Describe a raster dataset using the GDAL Python bindings.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: ``driver``, ``width``, ``height``, ``count``, ``projection``
        and ``geotransform``. If the geotransform is truthy, also
        ``resolution`` (absolute pixel sizes) and ``origin`` (top-left
        coordinate). Plus ``bands``: one dict per band with ``band``,
        ``datatype``, ``min``, ``max``, ``mean``, ``std`` and ``nodata``.
        Returns None (after printing an error) when ``gdal.Open`` returns None.
    """

    from osgeo import gdal

    dataset = gdal.Open(raster_path)
    if dataset is None:
        print(f"Error: Could not open {raster_path}")
        return None

    geotransform = dataset.GetGeoTransform()
    info = {
        "driver": dataset.GetDriver().ShortName,
        "width": dataset.RasterXSize,
        "height": dataset.RasterYSize,
        "count": dataset.RasterCount,
        "projection": dataset.GetProjection(),
        "geotransform": geotransform,
    }

    if geotransform:
        x_origin, pixel_width, _, y_origin, _, pixel_height = geotransform
        info["resolution"] = (abs(pixel_width), abs(pixel_height))
        info["origin"] = (x_origin, y_origin)

    def describe_band(index):
        # Statistics are requested with approx_ok=True and force=True.
        band = dataset.GetRasterBand(index)
        stats = band.GetStatistics(True, True)
        return {
            "band": index,
            "datatype": gdal.GetDataTypeName(band.DataType),
            "min": stats[0],
            "max": stats[1],
            "mean": stats[2],
            "std": stats[3],
            "nodata": band.GetNoDataValue(),
        }

    info["bands"] = [describe_band(i) for i in range(1, dataset.RasterCount + 1)]

    # Drop the reference so GDAL can close the dataset.
    dataset = None

    return info
235
+
236
+
237
def get_vector_info(vector_path):
    """Collect basic information about a vector dataset using GeoPandas.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: ``file_path``, ``driver`` (upper-cased file extension, not the
        actual OGR driver), ``feature_count``, ``crs``, ``geometry_type`` (a
        string of geometry-type counts), ``attribute_count``,
        ``attribute_names``, ``bounds`` ([minx, miny, maxx, maxy]) and
        ``attribute_stats`` (min/max/mean/std/null_count per numeric column).
    """
    gdf = gpd.read_file(vector_path)

    # Use the actual name of the active geometry column rather than assuming
    # "geometry", so attribute_names and attribute_count always agree.
    geometry_column = gdf.geometry.name
    attribute_names = [col for col in gdf.columns if col != geometry_column]

    info = {
        "file_path": vector_path,
        "driver": os.path.splitext(vector_path)[1][1:].upper(),  # Format from extension
        "feature_count": len(gdf),
        "crs": str(gdf.crs),
        "geometry_type": str(gdf.geom_type.value_counts().to_dict()),
        "attribute_count": len(attribute_names),
        "attribute_names": attribute_names,
        "bounds": gdf.total_bounds.tolist(),
    }

    # Geometry columns are never numeric, so select_dtypes already excludes them.
    numeric_columns = gdf.select_dtypes(include=["number"]).columns
    info["attribute_stats"] = {
        col: {
            "min": gdf[col].min(),
            "max": gdf[col].max(),
            "mean": gdf[col].mean(),
            "std": gdf[col].std(),
            "null_count": gdf[col].isna().sum(),
        }
        for col in numeric_columns
    }

    return info
277
+
278
+
279
def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a vector dataset and optionally show a preview.

    Args:
        vector_path (str): Path to the vector file
        show_preview (bool, optional): Whether to display a visual preview of the vector data.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: The dictionary produced by ``get_vector_info`` if successful,
        None otherwise (the error message is printed).
    """
    try:
        info = get_vector_info(vector_path)

        # Print basic information
        print(f"===== VECTOR INFORMATION: {vector_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Feature count: {info['feature_count']}")
        print(f"Geometry types: {info['geometry_type']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Bounds: {info['bounds']}")
        print(f"Number of attributes: {info['attribute_count']}")
        # Column names are not guaranteed to be strings.
        print(f"Attribute names: {', '.join(map(str, info['attribute_names']))}")

        # Print attribute statistics
        if info["attribute_stats"]:
            print("\n----- Attribute Statistics -----")
            for attr, stats in info["attribute_stats"].items():
                print(f"Attribute: {attr}")
                for stat_name, stat_value in stats.items():
                    # Floats get fixed precision; counts and other values print as-is.
                    if isinstance(stat_value, float):
                        print(f" {stat_name}: {stat_value:.4f}")
                    else:
                        print(f" {stat_name}: {stat_value}")

        # Show a preview if requested
        if show_preview:
            gdf = gpd.read_file(vector_path)
            fig, ax = plt.subplots(figsize=figsize)
            gdf.plot(ax=ax, cmap="viridis")
            ax.set_title(f"Preview: {vector_path}")
            plt.tight_layout()
            plt.show()

        return info

    except Exception as e:
        print(f"Error reading vector data: {str(e)}")
        return None
332
+
333
+
334
+ # Alternative implementation using OGR directly
335
def get_vector_info_ogr(vector_path):
    """Describe a vector dataset and each of its layers using OGR.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: ``file_path``, ``driver``, ``layer_count`` and ``layers``. Each
        layer entry holds ``name``, ``feature_count``, ``geometry_type``,
        ``spatial_ref`` (WKT, or the string "None"), ``extent`` (as returned
        by ``Layer.GetExtent``: minx, maxx, miny, maxy) and ``fields`` (name,
        type, width, precision). Returns None (after printing an error) when
        ``ogr.Open`` returns None.
    """
    from osgeo import ogr

    ogr.RegisterAll()

    source = ogr.Open(vector_path)
    if source is None:
        print(f"Error: Could not open {vector_path}")
        return None

    layer_total = source.GetLayerCount()
    info = {
        "file_path": vector_path,
        "driver": source.GetDriver().GetName(),
        "layer_count": layer_total,
        "layers": [],
    }

    for layer_index in range(layer_total):
        layer = source.GetLayer(layer_index)
        srs = layer.GetSpatialRef()
        definition = layer.GetLayerDefn()

        fields = []
        for field_index in range(definition.GetFieldCount()):
            field = definition.GetFieldDefn(field_index)
            fields.append(
                {
                    "name": field.GetName(),
                    "type": field.GetTypeName(),
                    "width": field.GetWidth(),
                    "precision": field.GetPrecision(),
                }
            )

        info["layers"].append(
            {
                "name": layer.GetName(),
                "feature_count": layer.GetFeatureCount(),
                "geometry_type": ogr.GeometryTypeToName(layer.GetGeomType()),
                "spatial_ref": srs.ExportToWkt() if srs else "None",
                "extent": layer.GetExtent(),
                "fields": fields,
            }
        )

    # Drop the reference so OGR can close the data source.
    source = None

    return info
396
+
397
+
398
def analyze_vector_attributes(vector_path, attribute_name):
    """Analyze a specific attribute in a vector dataset and plot its distribution.

    Numeric (non-boolean) attributes get summary statistics and a histogram.
    All other attributes, including booleans, get value counts and a bar plot
    of the 10 most frequent values. Plots are skipped when the attribute has
    no non-null values.

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to analyze

    Returns:
        dict: Dictionary containing analysis results for the attribute, or
        None if the attribute is missing or an error occurred.
    """
    try:
        gdf = gpd.read_file(vector_path)

        # Check if attribute exists
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return None

        attr = gdf[attribute_name]

        # pandas treats bool as numeric, but matplotlib's histogram cannot bin
        # booleans, so summarize them as categories instead.
        is_numeric = pd.api.types.is_numeric_dtype(
            attr
        ) and not pd.api.types.is_bool_dtype(attr)

        if is_numeric:
            analysis = {
                "attribute": attribute_name,
                "type": "numeric",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "min": attr.min(),
                "max": attr.max(),
                "mean": attr.mean(),
                "median": attr.median(),
                "std": attr.std(),
                "unique_values": attr.nunique(),
            }

            values = attr.dropna()
            if not values.empty:
                plt.figure(figsize=(10, 6))
                plt.hist(values, bins=20, alpha=0.7, color="blue")
                plt.title(f"Histogram of {attribute_name}")
                plt.xlabel(attribute_name)
                plt.ylabel("Frequency")
                plt.grid(True, alpha=0.3)
                plt.show()

        else:
            counts = attr.value_counts()
            analysis = {
                "attribute": attribute_name,
                "type": "categorical",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "unique_values": attr.nunique(),
                "value_counts": counts.to_dict(),
            }

            # Bar plot for the most frequent categories (nothing to plot if all null)
            top_n = min(10, attr.nunique())
            if top_n > 0:
                plt.figure(figsize=(10, 6))
                counts.head(top_n).plot(kind="bar", color="skyblue")
                plt.title(f"Top {top_n} values for {attribute_name}")
                plt.xlabel(attribute_name)
                plt.ylabel("Count")
                plt.xticks(rotation=45)
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
                plt.show()

        return analysis

    except Exception as e:
        print(f"Error analyzing attribute: {str(e)}")
        return None
472
+
473
+
474
def visualize_vector_by_attribute(
    vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
):
    """Create a thematic map visualization of vector data based on an attribute.

    Numeric attributes are drawn with a continuous colormap. Other attributes
    are drawn as categories. Axis labels are "Longitude"/"Latitude" unless the
    data has a projected (non-geographic) CRS, in which case they are "X"/"Y".

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to visualize
        cmap (str, optional): Matplotlib colormap name for numeric attributes.
            Defaults to 'viridis'.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        bool: True if the map was drawn, False if the attribute is missing or
        an error occurred (the error message is printed).
    """
    try:
        gdf = gpd.read_file(vector_path)

        # Check if attribute exists
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return False

        fig, ax = plt.subplots(figsize=figsize)

        if pd.api.types.is_numeric_dtype(gdf[attribute_name]):
            # Continuous data
            gdf.plot(column=attribute_name, cmap=cmap, legend=True, ax=ax)
        else:
            # Categorical data
            gdf.plot(column=attribute_name, categorical=True, legend=True, ax=ax)

        ax.set_title(f"{os.path.basename(vector_path)} - {attribute_name}")

        # Only label axes as lon/lat when coordinates are geographic (or unknown).
        crs = gdf.crs
        if crs is not None and not crs.is_geographic:
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
        else:
            ax.set_xlabel("Longitude")
            ax.set_ylabel("Latitude")

        plt.tight_layout()
        plt.show()
        return True

    except Exception as e:
        print(f"Error visualizing data: {str(e)}")
        return False
521
+
522
+
523
def clip_raster_by_bbox(
    input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
):
    """
    Clip a raster dataset using a bounding box and optionally select specific bands.

    With a geographic bbox, the requested area is snapped outward to whole
    source pixels and clipped to the raster extent. The output therefore keeps
    the source resolution and stays aligned with the source pixel grid.

    Args:
        input_raster (str): Path to the input raster file.
        output_raster (str): Path where the clipped raster will be saved.
        bbox (tuple): Bounding box coordinates either as:
            - Geographic coordinates (minx, miny, maxx, maxy) if bbox_type="geo"
            - Pixel indices (min_row, min_col, max_row, max_col) if bbox_type="pixel"
        bands (list, optional): List of band indices to keep (1-based indexing).
            If None, all bands will be kept.
        bbox_type (str, optional): Type of bounding box coordinates. Either "geo" for
            geographic coordinates or "pixel" for row/column indices.
            Default is "geo".
        bbox_crs (str or dict, optional): CRS of the bbox if different from the raster CRS.
            Can be provided as EPSG code (e.g., "EPSG:4326") or
            as a proj4 string. Only applies when bbox_type="geo".
            If None, assumes bbox is in the same CRS as the raster.

    Returns:
        str: Path to the clipped output raster.

    Raises:
        ImportError: If required dependencies are not installed.
        ValueError: If the bbox is invalid or does not overlap the raster, bands are
            empty or out of range, or bbox_type is invalid.
        RuntimeError: If the clipping operation fails.

    Examples:
        # Clip using geographic coordinates in the same CRS as the raster
        >>> clip_raster_by_bbox('input.tif', 'clipped_geo.tif', (100, 200, 300, 400))
        'clipped_geo.tif'

        # Clip using WGS84 coordinates when the raster is in a different CRS
        >>> clip_raster_by_bbox('input.tif', 'clipped_wgs84.tif', (-122.5, 37.7, -122.4, 37.8),
        ...                     bbox_crs="EPSG:4326")
        'clipped_wgs84.tif'

        # Clip using row/column indices
        >>> clip_raster_by_bbox('input.tif', 'clipped_pixel.tif', (50, 100, 150, 200),
        ...                     bbox_type="pixel")
        'clipped_pixel.tif'

        # Clip with band selection
        >>> clip_raster_by_bbox('input.tif', 'clipped_bands.tif', (100, 200, 300, 400),
        ...                     bands=[1, 3])
        'clipped_bands.tif'
    """
    import math

    # Validate arguments before touching the file system
    if bbox_type not in ["geo", "pixel"]:
        raise ValueError("bbox_type must be either 'geo' or 'pixel'")

    if len(bbox) != 4:
        raise ValueError("bbox must contain exactly 4 values")

    if bands is not None and len(bands) == 0:
        raise ValueError("bands must contain at least one band index")

    from rasterio.warp import transform_bounds

    with rasterio.open(input_raster) as src:
        src_crs = src.crs

        if bbox_type == "geo":
            minx, miny, maxx, maxy = bbox

            if minx >= maxx or miny >= maxy:
                raise ValueError(
                    "Invalid geographic bbox. Expected (minx, miny, maxx, maxy) where minx < maxx and miny < maxy"
                )

            # If bbox_crs is provided and different from the source CRS, transform the bbox
            if bbox_crs is not None and bbox_crs != src_crs:
                try:
                    minx, miny, maxx, maxy = transform_bounds(
                        bbox_crs, src_crs, minx, miny, maxx, maxy
                    )
                except Exception as e:
                    raise ValueError(
                        f"Failed to transform bbox from {bbox_crs} to {src_crs}: {str(e)}"
                    ) from e

            # src.window() returns fractional offsets and sizes. Snap outward to
            # whole pixels (eps absorbs floating-point noise when the bbox lies
            # exactly on pixel edges), then clip to the raster extent.
            frac = src.window(minx, miny, maxx, maxy)
            eps = 1e-6
            col_start = max(0, math.floor(frac.col_off + eps))
            row_start = max(0, math.floor(frac.row_off + eps))
            col_stop = min(src.width, math.ceil(frac.col_off + frac.width - eps))
            row_stop = min(src.height, math.ceil(frac.row_off + frac.height - eps))

            if col_stop <= col_start or row_stop <= row_start:
                raise ValueError("Bounding box results in an empty window")

            window = Window(
                col_start, row_start, col_stop - col_start, row_stop - row_start
            )

        else:  # bbox_type == "pixel"
            min_row, min_col, max_row, max_col = bbox

            if min_row >= max_row or min_col >= max_col:
                raise ValueError(
                    "Invalid pixel bbox. Expected (min_row, min_col, max_row, max_col) where min_row < max_row and min_col < max_col"
                )

            if (
                min_row < 0
                or min_col < 0
                or max_row > src.height
                or max_col > src.width
            ):
                raise ValueError(
                    f"Pixel indices out of bounds. Raster dimensions are {src.height} rows x {src.width} columns"
                )

            window = Window(min_col, min_row, max_col - min_col, max_row - min_row)

        window_width = int(window.width)
        window_height = int(window.height)

        if window_width <= 0 or window_height <= 0:
            raise ValueError("Bounding box results in an empty window")

        # Handle band selection
        if bands is None:
            bands_to_read = list(range(1, src.count + 1))
        else:
            if not all(1 <= b <= src.count for b in bands):
                raise ValueError(f"Band indices must be between 1 and {src.count}")
            bands_to_read = list(bands)

        out_meta = src.meta.copy()
        out_meta.update(
            {
                "height": window_height,
                "width": window_width,
                # window_transform keeps the source pixel size and orientation
                # exactly (no stretching to fit requested bounds).
                "transform": src.window_transform(window),
                "count": len(bands_to_read),
            }
        )

        # (bands, rows, cols); np.stack also gives a leading axis for one band.
        clipped_data = np.stack(
            [src.read(band_idx, window=window) for band_idx in bands_to_read]
        )

        with rasterio.open(output_raster, "w", **out_meta) as dst:
            dst.write(clipped_data)

    return output_raster
708
+
709
+
23
710
  def raster_to_vector(
24
711
  raster_path,
25
712
  output_path=None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: geoai-py
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -24,8 +24,8 @@ Requires-Dist: jupyter-server-proxy
24
24
  Requires-Dist: leafmap
25
25
  Requires-Dist: localtileserver
26
26
  Requires-Dist: overturemaps
27
- Requires-Dist: planetary_computer
28
- Requires-Dist: pystac_client
27
+ Requires-Dist: planetary-computer
28
+ Requires-Dist: pystac-client
29
29
  Requires-Dist: rasterio
30
30
  Requires-Dist: rioxarray
31
31
  Requires-Dist: scikit-learn
@@ -62,6 +62,8 @@ GeoAI bridges the gap between AI and geospatial analysis, providing tools for pr
62
62
 
63
63
  ## 🚀 Key Features
64
64
 
65
+ ❗ **Important note:** The GeoAI package is under active development, and new features are added regularly. Not all features listed below are available in the current release. If you have a feature request or would like to contribute, please let us know!
66
+
65
67
  ### 📊 Advanced Geospatial Data Visualization
66
68
 
67
69
  - Interactive multi-layer visualization of vector, raster, and point cloud data
@@ -112,6 +114,12 @@ pip install geoai-py
112
114
  conda install -c conda-forge geoai
113
115
  ```
114
116
 
117
+ ### Using mamba
118
+
119
+ ```bash
120
+ mamba install -c conda-forge geoai
121
+ ```
122
+
115
123
  ## 📋 Documentation
116
124
 
117
125
  Comprehensive documentation is available at [https://geoai.gishub.org](https://geoai.gishub.org), including:
@@ -0,0 +1,13 @@
1
+ geoai/__init__.py,sha256=yEbFyHPNijxgK-75tatrRELZ9TUdZVYo2uPlxCeBFDA,923
2
+ geoai/common.py,sha256=NdfkQKMPHkwr0B5sDpH5Q_7Nt2AmYt9Gw-KE88NsQ5s,15222
3
+ geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
4
+ geoai/extract.py,sha256=Fh29d5Fj60YiqhMs62lzkd9T_ONTp2UZ4j98We769sg,31563
5
+ geoai/geoai.py,sha256=BCEtHil0P5cettJdMIhblg1pRaV-vHNQFaYmBrtYP3g,68
6
+ geoai/preprocess.py,sha256=pYtf3-eZY76SKd17MvEZ1qNUvblYW5kzQLvZ-ZM4Wwg,106833
7
+ geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
8
+ geoai_py-0.2.2.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
+ geoai_py-0.2.2.dist-info/METADATA,sha256=baREpHpvCvfktqiMSWNI-FGOVme8NAj0UkaJhS6Bkm4,5701
10
+ geoai_py-0.2.2.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
+ geoai_py-0.2.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
+ geoai_py-0.2.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
+ geoai_py-0.2.2.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- geoai/__init__.py,sha256=9U65kN5xD2e8TOuHCXuyvHFDZnuzKxIt-a1yhnzSD34,880
2
- geoai/common.py,sha256=Q_KTGUtJj3RwBGdQLTNGjjtBh_beJ8iBkDg3xR7yx6c,15131
3
- geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
4
- geoai/extract.py,sha256=Fh29d5Fj60YiqhMs62lzkd9T_ONTp2UZ4j98We769sg,31563
5
- geoai/geoai.py,sha256=TmR7x1uL51G5oAjw0AQWnC5VQtLWDygyFLrDIj46xNc,86
6
- geoai/preprocess.py,sha256=dI3N-xdtDUXYY46nb_SSC7c5G_F1qvGC0HF0bUWKb8A,82824
7
- geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
8
- geoai_py-0.2.0.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
- geoai_py-0.2.0.dist-info/METADATA,sha256=ZY0Rx5zUI9-l8B_lFChyp9VUpt17CbvrpUAgNv85TGo,5373
10
- geoai_py-0.2.0.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
- geoai_py-0.2.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
- geoai_py-0.2.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
- geoai_py-0.2.0.dist-info/RECORD,,