geoai-py 0.4.0__py2.py3-none-any.whl → 0.4.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +76 -14
- geoai/download.py +644 -0
- geoai/extract.py +518 -0
- geoai/geoai.py +9 -1
- geoai/train.py +98 -12
- geoai/utils.py +260 -21
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/METADATA +8 -13
- geoai_py-0.4.2.dist-info/RECORD +15 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/WHEEL +1 -1
- geoai_py-0.4.0.dist-info/RECORD +0 -15
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info/licenses}/LICENSE +0 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/top_level.txt +0 -0
geoai/download.py
CHANGED
|
@@ -8,9 +8,11 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
|
8
8
|
import geopandas as gpd
|
|
9
9
|
import matplotlib.pyplot as plt
|
|
10
10
|
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
11
12
|
import planetary_computer as pc
|
|
12
13
|
import requests
|
|
13
14
|
import rioxarray
|
|
15
|
+
import xarray as xr
|
|
14
16
|
from pystac_client import Client
|
|
15
17
|
from shapely.geometry import box
|
|
16
18
|
from tqdm import tqdm
|
|
@@ -394,3 +396,645 @@ def extract_building_stats(geojson_file: str) -> Dict[str, Any]:
|
|
|
394
396
|
except Exception as e:
|
|
395
397
|
logger.error(f"Error extracting statistics: {str(e)}")
|
|
396
398
|
return {"error": str(e)}
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def download_pc_stac_item(
    item_url,
    bands=None,
    output_dir=None,
    show_progress=True,
    merge_bands=False,
    merged_filename=None,
    overwrite=False,
    cell_size=None,
):
    """
    Downloads a STAC item from Microsoft Planetary Computer with specified bands.

    This function fetches a STAC item by URL, signs the assets using Planetary Computer
    credentials, and downloads the specified bands with a progress bar. Can optionally
    merge bands into a single multi-band GeoTIFF.

    Args:
        item_url (str): The URL of the STAC item to download.
        bands (list, optional): List of specific bands to download (e.g., ['B01', 'B02']).
            If None, all available bands will be downloaded.
        output_dir (str, optional): Directory to save downloaded bands. If None,
            bands are returned as xarray DataArrays.
        show_progress (bool, optional): Whether to display a progress bar. Default is True.
        merge_bands (bool, optional): Whether to merge downloaded bands into a single
            multi-band GeoTIFF file. Default is False.
        merged_filename (str, optional): Filename for the merged bands. If None and
            merge_bands is True, uses "{item_id}_merged.tif".
        overwrite (bool, optional): Whether to overwrite existing files. Default is False.
        cell_size (float, optional): Resolution in meters for the merged output. If None,
            uses the resolution of the first band.

    Returns:
        dict: Dictionary mapping band names to their corresponding xarray DataArrays
            or file paths if output_dir is provided. If merge_bands is True, also
            includes a 'merged' key with the path to the merged file.

    Raises:
        ValueError: If the item cannot be retrieved or a requested band is not available.
    """
    from rasterio.enums import Resampling

    # Get the item ID from the URL.
    # Assumes the URL has the canonical shape
    # ".../collections/<collection>/items/<item_id>" — TODO confirm for all callers.
    item_id = item_url.split("/")[-1]
    collection = item_url.split("/collections/")[1].split("/items/")[0]

    # Connect to the Planetary Computer STAC API.
    # sign_inplace attaches SAS tokens to asset hrefs so they can be read directly.
    catalog = Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=pc.sign_inplace,
    )

    # Search for the specific item
    search = catalog.search(collections=[collection], ids=[item_id])

    # Get the first item from the search results
    # NOTE(review): search.get_items() is deprecated in newer pystac-client in
    # favor of items()/item_collection() — confirm the pinned version supports it.
    items = list(search.get_items())
    if not items:
        raise ValueError(f"Item with ID {item_id} not found")

    item = items[0]

    # Determine which bands to download
    available_assets = list(item.assets.keys())

    if bands is None:
        # If no bands specified, download all band assets
        # (heuristic: Sentinel-style band assets are named "B01", "B02", ...).
        bands_to_download = [
            asset for asset in available_assets if asset.startswith("B")
        ]
    else:
        # Verify all requested bands exist
        missing_bands = [band for band in bands if band not in available_assets]
        if missing_bands:
            raise ValueError(
                f"The following bands are not available: {missing_bands}. "
                f"Available assets are: {available_assets}"
            )
        bands_to_download = bands

    # Create output directory if specified and doesn't exist
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    result = {}
    band_data_arrays = []  # (band_name, DataArray) pairs collected for merging
    resampled_arrays = []  # arrays after optional resampling, in band order
    band_names = []  # Track band names in order

    # Set up progress bar.
    # When show_progress is False, progress_iter is the plain list, which is why
    # the tqdm-only calls below are guarded with isinstance(..., list) checks.
    progress_iter = (
        tqdm(bands_to_download, desc="Downloading bands")
        if show_progress
        else bands_to_download
    )

    # Download each requested band
    for band in progress_iter:
        if band not in item.assets:
            if show_progress and not isinstance(progress_iter, list):
                progress_iter.write(
                    f"Warning: Band {band} not found in assets, skipping."
                )
            continue

        band_url = item.assets[band].href

        if output_dir:
            file_path = os.path.join(output_dir, f"{item.id}_{band}.tif")

            # Check if file exists and skip if overwrite is False
            if os.path.exists(file_path) and not overwrite:
                if show_progress and not isinstance(progress_iter, list):
                    progress_iter.write(
                        f"File {file_path} already exists, skipping (use overwrite=True to force download)."
                    )
                # Still need to open the file to get the data for merging
                if merge_bands:
                    band_data = rioxarray.open_rasterio(file_path)
                    band_data_arrays.append((band, band_data))
                    band_names.append(band)
                result[band] = file_path
                continue

        if show_progress and not isinstance(progress_iter, list):
            progress_iter.set_description(f"Downloading {band}")

        # Lazily opens the remote COG via the signed href; data is pulled on write.
        band_data = rioxarray.open_rasterio(band_url)

        # Store the data array for potential merging later
        if merge_bands:
            band_data_arrays.append((band, band_data))
            band_names.append(band)

        if output_dir:
            file_path = os.path.join(output_dir, f"{item.id}_{band}.tif")
            band_data.rio.to_raster(file_path)
            result[band] = file_path
        else:
            result[band] = band_data

    # Merge bands if requested.
    # NOTE(review): merging silently does nothing when output_dir is None — the
    # docstring implies merge_bands requires output_dir; confirm and document upstream.
    if merge_bands and output_dir:
        if merged_filename is None:
            merged_filename = f"{item.id}_merged.tif"

        merged_path = os.path.join(output_dir, merged_filename)

        # Check if merged file exists and skip if overwrite is False
        if os.path.exists(merged_path) and not overwrite:
            if show_progress:
                print(
                    f"Merged file {merged_path} already exists, skipping (use overwrite=True to force creation)."
                )
            result["merged"] = merged_path
        else:
            if show_progress:
                print("Resampling and merging bands...")

            # Determine target cell size if not provided
            if cell_size is None and band_data_arrays:
                # Use the resolution of the first band (usually 10m for B02, B03, B04, B08)
                # Get the affine transform (containing resolution info)
                first_band_data = band_data_arrays[0][1]
                # Extract resolution from transform
                # (transform[0] is the pixel width; abs() in case of negative x-scale)
                cell_size = abs(first_band_data.rio.transform()[0])
                if show_progress:
                    print(f"Using detected resolution: {cell_size}m")
            elif cell_size is None:
                # Default to 10m if no bands are available
                cell_size = 10
                if show_progress:
                    print(f"Using default resolution: {cell_size}m")

            # Process bands in memory-efficient way
            for i, (band_name, data_array) in enumerate(band_data_arrays):
                if show_progress:
                    print(f"Processing band: {band_name}")

                # Get current resolution
                current_res = abs(data_array.rio.transform()[0])

                # Resample if needed
                if (
                    abs(current_res - cell_size) > 0.01
                ):  # Small tolerance for floating point comparison
                    if show_progress:
                        print(
                            f"Resampling {band_name} from {current_res}m to {cell_size}m"
                        )

                    # Use bilinear for downsampling (higher to lower resolution)
                    # Use nearest for upsampling (lower to higher resolution)
                    resampling_method = (
                        Resampling.bilinear
                        if current_res < cell_size
                        else Resampling.nearest
                    )

                    # Reproject in the same CRS, only changing the resolution.
                    resampled = data_array.rio.reproject(
                        data_array.rio.crs,
                        resolution=(cell_size, cell_size),
                        resampling=resampling_method,
                    )
                    resampled_arrays.append(resampled)
                else:
                    resampled_arrays.append(data_array)

            if show_progress:
                print("Stacking bands...")

            # Concatenate all resampled arrays along the band dimension
            try:
                merged_data = xr.concat(resampled_arrays, dim="band")

                if show_progress:
                    print(f"Writing merged data to {merged_path}...")

                # Add description metadata
                merged_data.attrs["description"] = (
                    f"Multi-band image containing {', '.join(band_names)}"
                )

                # Create a dictionary mapping band indices to band names
                # NOTE(review): band_descriptions is built but never used below —
                # the descriptions= argument to to_raster covers it; consider removing.
                band_descriptions = {}
                for i, name in enumerate(band_names):
                    band_descriptions[i + 1] = name

                # Write the merged data to file with band descriptions
                merged_data.rio.to_raster(
                    merged_path,
                    tags={"BAND_NAMES": ",".join(band_names)},
                    descriptions=band_names,
                )

                result["merged"] = merged_path

                if show_progress:
                    print(f"Merged bands saved to: {merged_path}")
                    print(f"Band order in merged file: {', '.join(band_names)}")
            except Exception as e:
                if show_progress:
                    print(f"Error during merging: {str(e)}")
                    print(f"Error details: {type(e).__name__}: {str(e)}")
                raise

    return result
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def pc_collection_list(
    endpoint="https://planetarycomputer.microsoft.com/api/stac/v1",
    detailed=False,
    filter_by=None,
    sort_by="id",
):
    """
    Retrieves and displays the list of available collections from Planetary Computer.

    This function connects to the Planetary Computer STAC API and retrieves the
    list of all available collections, with options to filter and sort the results.

    Args:
        endpoint (str, optional): STAC API endpoint URL.
            Defaults to "https://planetarycomputer.microsoft.com/api/stac/v1".
        detailed (bool, optional): Whether to return detailed information for each
            collection. If False, returns only basic info. Defaults to False.
        filter_by (dict, optional): Dictionary of field:value pairs to filter
            collections. For example, {"license": "CC-BY-4.0"}. Defaults to None.
        sort_by (str, optional): Field to sort the collections by.
            Defaults to "id".

    Returns:
        pandas.DataFrame: DataFrame containing collection information.

    Raises:
        ConnectionError: If there's an issue connecting to the API.
    """
    # Local import: the module does not reliably import datetime at the top
    # level, and the isinstance checks below would otherwise raise NameError
    # the first time detailed=True hits a datetime-typed temporal interval.
    import datetime

    # Initialize the STAC client
    try:
        catalog = Client.open(endpoint)
    except Exception as e:
        raise ConnectionError(
            f"Failed to connect to STAC API at {endpoint}: {str(e)}"
        ) from e

    # Get all collections
    try:
        collections = list(catalog.get_collections())
    except Exception as e:
        raise Exception(f"Error retrieving collections: {str(e)}") from e

    # Basic info to extract from all collections
    collection_info = []

    # Extract information based on detail level
    for collection in collections:
        # Basic information always included; long descriptions are truncated
        # to keep the resulting DataFrame readable.
        info = {
            "id": collection.id,
            "title": collection.title or "No title",
            "description": (
                collection.description[:100] + "..."
                if collection.description and len(collection.description) > 100
                else collection.description
            ),
        }

        # Add detailed information if requested
        if detailed:
            # Get temporal extent if available; intervals may contain
            # datetime objects or None (open-ended ranges).
            temporal_extent = "Unknown"
            if collection.extent and collection.extent.temporal:
                interval = (
                    collection.extent.temporal.intervals[0]
                    if collection.extent.temporal.intervals
                    else None
                )
                if interval:
                    start = interval[0] or "Unknown Start"
                    end = interval[1] or "Present"
                    if isinstance(start, datetime.datetime):
                        start = start.strftime("%Y-%m-%d")
                    if isinstance(end, datetime.datetime):
                        end = end.strftime("%Y-%m-%d")
                    temporal_extent = f"{start} to {end}"

            # Add additional details
            info.update(
                {
                    "license": collection.license or "Unknown",
                    "keywords": (
                        ", ".join(collection.keywords)
                        if collection.keywords
                        else "None"
                    ),
                    "temporal_extent": temporal_extent,
                    "asset_count": len(collection.assets) if collection.assets else 0,
                    "providers": (
                        ", ".join([p.name for p in collection.providers])
                        if collection.providers
                        else "Unknown"
                    ),
                }
            )

            # Add spatial extent if available
            if collection.extent and collection.extent.spatial:
                info["bbox"] = (
                    str(collection.extent.spatial.bboxes[0])
                    if collection.extent.spatial.bboxes
                    else "Unknown"
                )

        collection_info.append(info)

    # Convert to DataFrame for easier filtering and sorting
    df = pd.DataFrame(collection_info)

    # Apply filtering if specified (case-insensitive substring match per field)
    if filter_by:
        for field, value in filter_by.items():
            if field in df.columns:
                df = df[df[field].astype(str).str.contains(value, case=False, na=False)]

    # Apply sorting
    if sort_by in df.columns:
        df = df.sort_values(by=sort_by)

    print(f"Retrieved {len(df)} collections from Planetary Computer")

    return df
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def pc_stac_search(
    collection,
    bbox=None,
    time_range=None,
    query=None,
    limit=10,
    max_items=None,
    endpoint="https://planetarycomputer.microsoft.com/api/stac/v1",
):
    """
    Search for STAC items in the Planetary Computer catalog.

    Queries the Planetary Computer STAC API for items that match the given
    collection, optional bounding box, optional time range, and optional
    extra query parameters.

    Args:
        collection (str): The STAC collection ID to search within.
        bbox (list, optional): Bounding box coordinates [west, south, east, north].
            Defaults to None.
        time_range (str or tuple, optional): Time range as a string "start/end" or
            a tuple of (start, end) datetime objects. Defaults to None.
        query (dict, optional): Additional query parameters for filtering.
            Defaults to None.
        limit (int, optional): Number of items to return per page. Defaults to 10.
        max_items (int, optional): Maximum total number of items to return.
            Defaults to None (returns all matching items).
        endpoint (str, optional): STAC API endpoint URL.
            Defaults to "https://planetarycomputer.microsoft.com/api/stac/v1".

    Returns:
        list: List of STAC Item objects matching the search criteria.

    Raises:
        ValueError: If invalid parameters are provided.
        ConnectionError: If there's an issue connecting to the API.
    """
    import datetime

    # Open the STAC catalog, surfacing connectivity problems explicitly.
    try:
        catalog = Client.open(endpoint)
    except Exception as e:
        raise ConnectionError(f"Failed to connect to STAC API at {endpoint}: {str(e)}")

    # Normalize time_range into the "start/end" string form the API expects.
    time_str = None
    if time_range:
        if isinstance(time_range, str):
            time_str = time_range
        elif isinstance(time_range, tuple) and len(time_range) == 2:
            bounds = []
            for bound in time_range:
                # datetime objects are serialized to ISO 8601 strings.
                if isinstance(bound, datetime.datetime):
                    bound = bound.isoformat()
                bounds.append(bound)
            time_str = f"{bounds[0]}/{bounds[1]}"
        else:
            raise ValueError(
                "time_range must be a 'start/end' string or tuple of (start, end)"
            )

    # Build the search with all criteria in one call.
    search = catalog.search(
        collections=[collection], bbox=bbox, datetime=time_str, query=query, limit=limit
    )

    # Materialize the results, stopping early once max_items is reached.
    items = []
    try:
        if max_items:
            for found_item in search.get_items():
                items.append(found_item)
                if len(items) >= max_items:
                    break
        else:
            items.extend(search.get_items())
    except Exception as e:
        raise Exception(f"Error retrieving search results: {str(e)}")

    print(f"Found {len(items)} items matching search criteria")

    return items
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
def pc_stac_download(
    items,
    output_dir=".",
    asset_keys=None,
    max_workers=4,
    skip_existing=True,
    sign_urls=True,
):
    """
    Download assets from STAC items retrieved from the Planetary Computer.

    This function downloads specified assets from a list of STAC items to the
    specified output directory. It supports parallel downloads and can skip
    already downloaded files.

    Args:
        items (list or pystac.Item): STAC Item object or list of STAC Item objects.
        output_dir (str, optional): Directory where assets will be saved.
            Defaults to current directory.
        asset_keys (list, optional): List of asset keys to download. If None,
            downloads all available assets. Defaults to None.
        max_workers (int, optional): Maximum number of concurrent download threads.
            Defaults to 4.
        skip_existing (bool, optional): Skip download if the file already exists.
            Defaults to True.
        sign_urls (bool, optional): Whether to sign URLs for authenticated access.
            Defaults to True.

    Returns:
        dict: Dictionary mapping STAC item IDs to dictionaries of their downloaded
            assets {asset_key: file_path}.

    Raises:
        TypeError: If items is not a STAC Item or list of STAC Items.
        IOError: If there's an error writing the downloaded assets to disk.
    """
    import pystac
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Handle single item case
    if isinstance(items, pystac.Item):
        items = [items]
    elif not isinstance(items, list):
        raise TypeError("items must be a STAC Item or list of STAC Items")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Function to sign URLs if needed.
    # Best-effort: any signing failure falls back to the unsigned href.
    def get_signed_url(href):
        if not sign_urls:
            return href

        # Planetary Computer typically requires signing URLs for accessing data
        # Check if the URL is from Microsoft Planetary Computer
        # NOTE(review): PC asset hrefs are often hosted on *.blob.core.windows.net,
        # which would not match this substring check and thus never be signed —
        # verify against real item hrefs.
        if "planetarycomputer" in href:
            try:
                sign_url = "https://planetarycomputer.microsoft.com/api/sas/v1/sign"
                response = requests.get(sign_url, params={"href": href})
                response.raise_for_status()
                return response.json().get("href", href)
            except Exception as e:
                print(f"Warning: Failed to sign URL {href}: {str(e)}")
                return href
        return href

    # Function to download a single asset.
    # Returns (asset_key, output_path) on success, (asset_key, None) on failure,
    # so the caller can distinguish the two without exceptions crossing threads.
    def download_asset(item, asset_key, asset):
        item_id = item.id

        # Get the asset URL and sign it if needed
        asset_url = get_signed_url(asset.href)

        # Determine output filename
        if asset.media_type:
            # Use appropriate file extension based on media type
            if "tiff" in asset.media_type or "geotiff" in asset.media_type:
                ext = ".tif"
            elif "jpeg" in asset.media_type:
                ext = ".jpg"
            elif "png" in asset.media_type:
                ext = ".png"
            elif "json" in asset.media_type:
                ext = ".json"
            else:
                # Default extension based on the original URL
                # (query string stripped first so SAS tokens don't pollute the ext)
                ext = os.path.splitext(asset_url.split("?")[0])[1] or ".data"
        else:
            # Default extension based on the original URL
            ext = os.path.splitext(asset_url.split("?")[0])[1] or ".data"

        output_path = os.path.join(output_dir, f"{item_id}_{asset_key}{ext}")

        # Skip if file exists and skip_existing is True
        if skip_existing and os.path.exists(output_path):
            print(f"Skipping existing asset: {asset_key} -> {output_path}")
            return asset_key, output_path

        try:
            # Download the asset with progress bar, streaming in 8 KiB chunks
            # so large rasters are never held fully in memory.
            with requests.get(asset_url, stream=True) as r:
                r.raise_for_status()
                # content-length may be absent; tqdm then shows an open-ended bar.
                total_size = int(r.headers.get("content-length", 0))
                with open(output_path, "wb") as f:
                    with tqdm(
                        total=total_size,
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                        desc=f"Downloading {item_id}_{asset_key}",
                        ncols=100,
                    ) as pbar:
                        for chunk in r.iter_content(chunk_size=8192):
                            f.write(chunk)
                            pbar.update(len(chunk))

            return asset_key, output_path
        except Exception as e:
            print(f"Error downloading {asset_key} for item {item_id}: {str(e)}")
            if os.path.exists(output_path):
                os.remove(output_path)  # Clean up partial download
            return asset_key, None

    # Process all items and their assets
    results = {}

    for item in items:
        item_assets = {}
        item_id = item.id
        print(f"Processing STAC item: {item_id}")

        # Determine which assets to download
        if asset_keys:
            assets_to_download = {
                k: v for k, v in item.assets.items() if k in asset_keys
            }
            if not assets_to_download:
                # No requested key exists on this item: warn and skip it entirely
                # (the item then has no entry in the returned dict).
                print(
                    f"Warning: None of the specified asset keys {asset_keys} found in item {item_id}"
                )
                print(f"Available asset keys: {list(item.assets.keys())}")
                continue
        else:
            assets_to_download = item.assets

        # Download assets concurrently
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all download tasks
            future_to_asset = {
                executor.submit(download_asset, item, asset_key, asset): (
                    asset_key,
                    asset,
                )
                for asset_key, asset in assets_to_download.items()
            }

            # Process results as they complete; failed downloads (path is None)
            # are simply omitted from item_assets.
            for future in as_completed(future_to_asset):
                asset_key, asset = future_to_asset[future]
                try:
                    key, path = future.result()
                    if path:
                        item_assets[key] = path
                except Exception as e:
                    print(
                        f"Error processing asset {asset_key} for item {item_id}: {str(e)}"
                    )

        results[item_id] = item_assets

    # Count total downloaded assets
    total_assets = sum(len(assets) for assets in results.values())
    print(f"\nDownloaded {total_assets} assets for {len(results)} items")

    return results
|