rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
"""Base classes for rslearn data sources."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Generator
|
|
4
|
-
from typing import Any, BinaryIO
|
|
4
|
+
from typing import Any, BinaryIO, Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from upath import UPath
|
|
5
7
|
|
|
6
8
|
from rslearn.config import LayerConfig, QueryConfig
|
|
7
9
|
from rslearn.dataset import Window
|
|
8
|
-
from rslearn.tile_stores import
|
|
10
|
+
from rslearn.tile_stores import TileStoreWithLayer
|
|
9
11
|
from rslearn.utils import STGeometry
|
|
10
12
|
|
|
11
13
|
|
|
@@ -51,7 +53,10 @@ class Item:
|
|
|
51
53
|
return hash(self.name)
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
|
|
56
|
+
ItemType = TypeVar("ItemType", bound="Item")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class DataSource(Generic[ItemType]):
|
|
55
60
|
"""A set of raster or vector files that can be retrieved.
|
|
56
61
|
|
|
57
62
|
Data sources should support at least one of ingest and materialize.
|
|
@@ -59,7 +64,7 @@ class DataSource:
|
|
|
59
64
|
|
|
60
65
|
def get_items(
|
|
61
66
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
62
|
-
) -> list[list[list[
|
|
67
|
+
) -> list[list[list[ItemType]]]:
|
|
63
68
|
"""Get a list of items in the data source intersecting the given geometries.
|
|
64
69
|
|
|
65
70
|
Args:
|
|
@@ -71,14 +76,14 @@ class DataSource:
|
|
|
71
76
|
"""
|
|
72
77
|
raise NotImplementedError
|
|
73
78
|
|
|
74
|
-
def deserialize_item(self, serialized_item: Any) ->
|
|
79
|
+
def deserialize_item(self, serialized_item: Any) -> ItemType:
|
|
75
80
|
"""Deserializes an item from JSON-decoded data."""
|
|
76
81
|
raise NotImplementedError
|
|
77
82
|
|
|
78
83
|
def ingest(
|
|
79
84
|
self,
|
|
80
|
-
tile_store:
|
|
81
|
-
items: list[
|
|
85
|
+
tile_store: TileStoreWithLayer,
|
|
86
|
+
items: list[ItemType],
|
|
82
87
|
geometries: list[list[STGeometry]],
|
|
83
88
|
) -> None:
|
|
84
89
|
"""Ingest items into the given tile store.
|
|
@@ -93,7 +98,7 @@ class DataSource:
|
|
|
93
98
|
def materialize(
|
|
94
99
|
self,
|
|
95
100
|
window: Window,
|
|
96
|
-
item_groups: list[list[
|
|
101
|
+
item_groups: list[list[ItemType]],
|
|
97
102
|
layer_name: str,
|
|
98
103
|
layer_cfg: LayerConfig,
|
|
99
104
|
) -> None:
|
|
@@ -108,17 +113,43 @@ class DataSource:
|
|
|
108
113
|
raise NotImplementedError
|
|
109
114
|
|
|
110
115
|
|
|
111
|
-
class ItemLookupDataSource(DataSource):
|
|
116
|
+
class ItemLookupDataSource(DataSource[ItemType]):
|
|
112
117
|
"""A data source that can look up items by name."""
|
|
113
118
|
|
|
114
|
-
def get_item_by_name(self, name: str) ->
|
|
119
|
+
def get_item_by_name(self, name: str) -> ItemType:
|
|
115
120
|
"""Gets an item by name."""
|
|
116
121
|
raise NotImplementedError
|
|
117
122
|
|
|
118
123
|
|
|
119
|
-
class RetrieveItemDataSource(DataSource):
|
|
124
|
+
class RetrieveItemDataSource(DataSource[ItemType]):
|
|
120
125
|
"""A data source that can retrieve items in their raw format."""
|
|
121
126
|
|
|
122
|
-
def retrieve_item(
|
|
127
|
+
def retrieve_item(
|
|
128
|
+
self, item: ItemType
|
|
129
|
+
) -> Generator[tuple[str, BinaryIO], None, None]:
|
|
123
130
|
"""Retrieves the rasters corresponding to an item as file streams."""
|
|
124
131
|
raise NotImplementedError
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class DataSourceContext:
    """Context object handed to every data source.

    rslearn itself always fills in both ds_path and layer_config when it
    constructs a data source. Users who build a data source directly may not have
    either of them, so every data source should still initialize correctly when
    one or both are left as None.
    """

    def __init__(
        self, ds_path: UPath | None = None, layer_config: LayerConfig | None = None
    ):
        """Create a new DataSourceContext.

        Args:
            ds_path: the path of the underlying dataset, if known.
            layer_config: the LayerConfig of the layer that the data source serves,
                if known.
        """
        # This is deliberately a plain class rather than a dataclass: with a
        # dataclass, jsonargparse would ignore the custom serializer/deserializer
        # defined in rslearn.utils.jsonargparse.
        self.layer_config = layer_config
        self.ds_path = ds_path
|
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
"""Data on EarthDaily."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import tempfile
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
import affine
|
|
10
|
+
import numpy.typing as npt
|
|
11
|
+
import pystac
|
|
12
|
+
import pystac_client
|
|
13
|
+
import rasterio
|
|
14
|
+
import requests
|
|
15
|
+
import shapely
|
|
16
|
+
from earthdaily import EDSClient, EDSConfig
|
|
17
|
+
from rasterio.enums import Resampling
|
|
18
|
+
from upath import UPath
|
|
19
|
+
|
|
20
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
21
|
+
from rslearn.const import WGS84_PROJECTION
|
|
22
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
23
|
+
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
24
|
+
from rslearn.dataset import Window
|
|
25
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
26
|
+
from rslearn.log_utils import get_logger
|
|
27
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
28
|
+
from rslearn.utils.fsspec import join_upath
|
|
29
|
+
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
|
|
30
|
+
|
|
31
|
+
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class EarthDailyItem(Item):
    """An item in the EarthDaily data source.

    On top of the name and geometry stored by the base Item, this keeps the
    mapping from asset key to asset URL.
    """

    def __init__(self, name: str, geometry: STGeometry, asset_urls: dict[str, str]):
        """Creates a new EarthDailyItem.

        Args:
            name: unique name of the item
            geometry: the spatial and temporal extent of the item
            asset_urls: map from asset key to the asset URL.
        """
        super().__init__(name, geometry)
        self.asset_urls = asset_urls

    def serialize(self) -> dict[str, Any]:
        """Serializes the item to a JSON-encodable dictionary.

        Returns:
            the base Item serialization with an added "asset_urls" entry.
        """
        serialized = super().serialize()
        serialized["asset_urls"] = self.asset_urls
        return serialized

    @staticmethod
    def deserialize(d: dict[str, Any]) -> "EarthDailyItem":
        """Deserializes an item from a JSON-decoded dictionary.

        Args:
            d: a dictionary as produced by serialize().

        Returns:
            the reconstructed EarthDailyItem.
        """
        # Let the parent class decode the shared fields (name and geometry), then
        # rebuild the item together with the EarthDaily-specific asset URLs.
        base_item = super(EarthDailyItem, EarthDailyItem).deserialize(d)
        return EarthDailyItem(base_item.name, base_item.geometry, d["asset_urls"])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class EarthDaily(DataSource, TileStore):
    """A data source for EarthDaily data.

    This requires the following environment variables to be set:
    - EDS_CLIENT_ID
    - EDS_SECRET
    - EDS_AUTH_URL
    - EDS_API_URL

    Besides ingestion, this class implements the TileStore read methods so that
    materialize() can read rasters straight from the asset URLs.
    """

    def __init__(
        self,
        collection_name: str,
        asset_bands: dict[str, list[str]],
        query: dict[str, Any] | None = None,
        sort_by: str | None = None,
        sort_ascending: bool = True,
        timeout: timedelta = timedelta(seconds=10),
        skip_items_missing_assets: bool = False,
        cache_dir: str | None = None,
        max_retries: int = 3,
        retry_backoff_factor: float = 5.0,
        service_name: Literal["platform"] = "platform",
        context: DataSourceContext = DataSourceContext(),
    ):
        """Initialize a new EarthDaily instance.

        Args:
            collection_name: the STAC collection name on EarthDaily.
            asset_bands: assets to ingest, mapping from asset name to the list of bands
                in that asset.
            query: optional query argument to STAC searches.
            sort_by: sort by this property in the STAC items.
            sort_ascending: whether to sort ascending (or descending).
            timeout: timeout for API requests.
            skip_items_missing_assets: skip STAC items that are missing any of the
                assets in asset_bands during get_items.
            cache_dir: optional directory to cache items by name, including asset URLs.
                If not set, there will be no cache and instead STAC requests will be
                needed each time.
            max_retries: the maximum number of retry attempts for HTTP requests that fail
                due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
            retry_backoff_factor: backoff factor for exponential retry delays between HTTP
                request attempts. The delay between retries is calculated using the formula:
                `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
            service_name: the service name, only "platform" is supported, the other
                services "legacy" and "internal" are not supported.
            context: the data source context.
        """
        self.collection_name = collection_name
        self.asset_bands = asset_bands
        self.query = query
        self.sort_by = sort_by
        self.sort_ascending = sort_ascending
        self.timeout = timeout
        self.skip_items_missing_assets = skip_items_missing_assets
        self.max_retries = max_retries
        self.retry_backoff_factor = retry_backoff_factor
        self.service_name = service_name

        if cache_dir is not None:
            # Use dataset path as root if provided.
            if context.ds_path is not None:
                self.cache_dir = join_upath(context.ds_path, cache_dir)
            else:
                self.cache_dir = UPath(cache_dir)

            self.cache_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.cache_dir = None

        # The clients are created lazily by _load_client.
        self.eds_client: EDSClient | None = None
        self.client: pystac_client.Client | None = None
        self.collection: pystac_client.CollectionClient | None = None

    def _load_client(
        self,
    ) -> tuple[EDSClient, pystac_client.Client, pystac_client.CollectionClient]:
        """Lazily load EDS client.

        We don't load it when creating the data source because it takes time and caller
        may not be calling get_items. Additionally, loading it during the get_items
        call enables leveraging the retry loop functionality in
        prepare_dataset_windows.

        Returns:
            a tuple (eds_client, client, collection).

        Raises:
            ValueError: if service_name is not "platform".
        """
        if (
            self.eds_client is not None
            and self.client is not None
            and self.collection is not None
        ):
            return self.eds_client, self.client, self.collection

        # Build everything in locals and only publish to self once every step has
        # succeeded. Otherwise a failure part-way through (e.g. get_collection
        # raising) would leave eds_client set, and the next call would return None
        # for the client and collection instead of retrying.
        eds_client = EDSClient(
            EDSConfig(
                max_retries=self.max_retries,
                retry_backoff_factor=self.retry_backoff_factor,
            )
        )

        if self.service_name != "platform":
            raise ValueError(f"Invalid service name: {self.service_name}")

        client = eds_client.platform.pystac_client
        collection = client.get_collection(self.collection_name)

        self.eds_client = eds_client
        self.client = client
        self.collection = collection
        return eds_client, client, collection

    def _stac_item_to_item(self, stac_item: pystac.Item) -> EarthDailyItem:
        """Convert a STAC item into an EarthDailyItem.

        Args:
            stac_item: the STAC item to convert.

        Returns:
            the corresponding EarthDailyItem. Its asset URLs only include assets
            that provide an alternate download href.

        Raises:
            ValueError: if the STAC item has neither a start/end datetime pair nor
                a datetime.
        """
        shp = shapely.geometry.shape(stac_item.geometry)

        # Prefer the explicit start/end range, and fall back to the single datetime.
        metadata = stac_item.common_metadata
        if metadata.start_datetime is not None and metadata.end_datetime is not None:
            time_range = (
                metadata.start_datetime,
                metadata.end_datetime,
            )
        elif stac_item.datetime is not None:
            time_range = (stac_item.datetime, stac_item.datetime)
        else:
            raise ValueError(
                f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
            )

        geom = STGeometry(WGS84_PROJECTION, shp, time_range)
        asset_urls = {
            asset_key: asset_obj.extra_fields["alternate"]["download"]["href"]
            for asset_key, asset_obj in stac_item.assets.items()
            if "alternate" in asset_obj.extra_fields
            and "download" in asset_obj.extra_fields["alternate"]
            and "href" in asset_obj.extra_fields["alternate"]["download"]
        }
        return EarthDailyItem(stac_item.id, geom, asset_urls)

    def get_item_by_name(self, name: str) -> EarthDailyItem:
        """Gets an item by name.

        Args:
            name: the name of the item to get

        Returns:
            the item object
        """
        # If cache_dir is set, we cache the item. First here we check if it is already
        # in the cache.
        cache_fname: UPath | None = None
        if self.cache_dir:
            cache_fname = self.cache_dir / f"{name}.json"
        if cache_fname is not None and cache_fname.exists():
            with cache_fname.open() as f:
                return EarthDailyItem.deserialize(json.load(f))

        # No cache or not in cache, so we need to make the STAC request.
        _, _, collection = self._load_client()
        stac_item = collection.get_item(name)
        item = self._stac_item_to_item(stac_item)

        # Finally we cache it if cache_dir is set.
        if cache_fname is not None:
            with cache_fname.open("w") as f:
                json.dump(item.serialize(), f)

        return item

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[EarthDailyItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            one list of item groups per input geometry, in the same order.
        """
        _, client, _ = self._load_client()

        groups = []
        for geometry in geometries:
            # Get potentially relevant items from the collection by performing one search
            # for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
            result = client.search(
                collections=[self.collection_name],
                intersects=shapely.to_geojson(wgs84_geometry.shp),
                datetime=wgs84_geometry.time_range,
                query=self.query,
            )
            stac_items = list(result.item_collection())
            logger.debug("STAC search yielded %d items", len(stac_items))

            if self.skip_items_missing_assets:
                # Filter out items that are missing any of the assets in self.asset_bands.
                good_stac_items = [
                    stac_item
                    for stac_item in stac_items
                    if all(
                        asset_key in stac_item.assets for asset_key in self.asset_bands
                    )
                ]
                logger.debug(
                    "skip_items_missing_assets filter from %d to %d items",
                    len(stac_items),
                    len(good_stac_items),
                )
                stac_items = good_stac_items

            if self.sort_by is not None:
                stac_items.sort(
                    key=lambda stac_item: stac_item.properties[self.sort_by],
                    reverse=not self.sort_ascending,
                )

            candidate_items = [
                # The only way to get the asset URLs is to get the item by name.
                self.get_item_by_name(stac_item.id)
                for stac_item in stac_items
            ]

            cur_groups = match_candidate_items_to_window(
                geometry, candidate_items, query_config
            )
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> EarthDailyItem:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return EarthDailyItem.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[EarthDailyItem],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            for asset_key, band_names in self.asset_bands.items():
                # Skip assets the item does not provide or that were already ingested.
                if asset_key not in item.asset_urls:
                    continue
                if tile_store.is_raster_ready(item.name, band_names):
                    continue

                asset_url = item.asset_urls[asset_key]
                with tempfile.TemporaryDirectory() as tmp_dir:
                    local_fname = os.path.join(tmp_dir, f"{asset_key}.tif")
                    logger.debug(
                        "EarthDaily download item %s asset %s to %s",
                        item.name,
                        asset_key,
                        local_fname,
                    )
                    # Stream the download to disk in chunks rather than holding the
                    # whole response body in memory.
                    with requests.get(
                        asset_url, stream=True, timeout=self.timeout.total_seconds()
                    ) as r:
                        r.raise_for_status()
                        with open(local_fname, "wb") as f:
                            for chunk in r.iter_content(chunk_size=8192):
                                f.write(chunk)

                    logger.debug(
                        "EarthDaily ingest item %s asset %s",
                        item.name,
                        asset_key,
                    )
                    tile_store.write_raster_file(
                        item.name, band_names, UPath(local_fname)
                    )

                logger.debug(
                    "EarthDaily done ingesting item %s asset %s",
                    item.name,
                    asset_key,
                )

    def is_raster_ready(
        self, layer_name: str, item_name: str, bands: list[str]
    ) -> bool:
        """Checks if this raster has been written to the store.

        Args:
            layer_name: the layer name or alias.
            item_name: the item.
            bands: the list of bands identifying which specific raster to read.

        Returns:
            whether there is a raster in the store matching the source, item, and
            bands.
        """
        # Always ready since we wrap accesses to EarthDaily.
        return True

    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
        """Get the sets of bands that have been stored for the specified item.

        Args:
            layer_name: the layer name or alias.
            item_name: the item.

        Returns:
            the band lists of the configured assets that are available for the item.
        """
        if self.skip_items_missing_assets:
            # In this case we can assume that the item has all of the assets.
            return list(self.asset_bands.values())

        # Otherwise we have to lookup the STAC item to see which assets it has.
        # Here we use get_item_by_name since it handles caching.
        item = self.get_item_by_name(item_name)
        return [
            band_names
            for asset_key, band_names in self.asset_bands.items()
            if asset_key in item.asset_urls
        ]

    def _get_asset_by_band(self, bands: list[str]) -> str:
        """Get the name of the asset based on the band names.

        Args:
            bands: the band list, which must equal the band list configured for one
                of the assets in asset_bands.

        Returns:
            the key of the first asset whose band list equals bands.

        Raises:
            ValueError: if no configured asset has exactly these bands.
        """
        for asset_key, asset_bands in self.asset_bands.items():
            if bands == asset_bands:
                return asset_key

        raise ValueError(f"no raster with bands {bands}")

    def get_raster_bounds(
        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
    ) -> PixelBounds:
        """Get the bounds of the raster in the specified projection.

        Args:
            layer_name: the layer name or alias.
            item_name: the item to check.
            bands: the list of bands identifying which specific raster to read. These
                bands must match the bands of a stored raster.
            projection: the projection to get the raster's bounds in.

        Returns:
            the bounds of the raster in the projection.
        """
        item = self.get_item_by_name(item_name)
        geom = item.geometry.to_projection(projection)
        # NOTE(review): int() truncates toward zero, so the bounds come from the
        # item geometry with fractional parts dropped. Confirm this is acceptable
        # for geometries with negative coordinates.
        return (
            int(geom.shp.bounds[0]),
            int(geom.shp.bounds[1]),
            int(geom.shp.bounds[2]),
            int(geom.shp.bounds[3]),
        )

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read raster data from the store.

        Args:
            layer_name: the layer name or alias.
            item_name: the item to read.
            bands: the list of bands identifying which specific raster to read. These
                bands must match the bands of a stored raster.
            projection: the projection to read in.
            bounds: the bounds to read.
            resampling: the resampling method to use in case reprojection is needed.

        Returns:
            the raster data
        """
        asset_key = self._get_asset_by_band(bands)
        item = self.get_item_by_name(item_name)
        asset_url = item.asset_urls[asset_key]

        # Construct the transform to use for the warped dataset.
        wanted_transform = affine.Affine(
            projection.x_resolution,
            0,
            bounds[0] * projection.x_resolution,
            0,
            projection.y_resolution,
            bounds[1] * projection.y_resolution,
        )

        # Read through a WarpedVRT so the data comes back in the requested
        # projection and pixel bounds.
        with rasterio.open(asset_url) as src:
            with rasterio.vrt.WarpedVRT(
                src,
                crs=projection.crs,
                transform=wanted_transform,
                width=bounds[2] - bounds[0],
                height=bounds[3] - bounds[1],
                resampling=resampling,
            ) as vrt:
                return vrt.read()

    def materialize(
        self,
        window: Window,
        item_groups: list[list[Item]],
        layer_name: str,
        layer_cfg: LayerConfig,
    ) -> None:
        """Materialize data for the window.

        Args:
            window: the window to materialize
            item_groups: the items from get_items
            layer_name: the name of this layer
            layer_cfg: the config of this layer
        """
        # This object acts as the tile store, so rasters are read via read_raster.
        RasterMaterializer().materialize(
            TileStoreWithLayer(self, layer_name),
            window,
            layer_name,
            layer_cfg,
            item_groups,
        )
|