rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/const.py
CHANGED
|
@@ -1,23 +1,17 @@
|
|
|
1
1
|
"""Constants."""
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
from rslearn.utils import PixelBounds, Projection
|
|
6
|
-
|
|
7
|
-
WGS84_EPSG = 4326
|
|
8
|
-
"""The EPSG code for WGS-84."""
|
|
9
|
-
|
|
10
|
-
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
11
|
-
"""The Projection for WGS-84 assuming 1 degree per pixel.
|
|
12
|
-
|
|
13
|
-
This can be used to create STGeometry with shapes in longitude/latitude coordinates.
|
|
14
|
-
"""
|
|
15
|
-
|
|
16
|
-
WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
|
|
17
|
-
"""The bounds of the WGS-84 projection."""
|
|
3
|
+
from rslearn.utils.geometry import WGS84_BOUNDS, WGS84_EPSG, WGS84_PROJECTION
|
|
18
4
|
|
|
19
5
|
TILE_SIZE = 512
|
|
20
6
|
"""Default tile size. TODO: remove this or move it elsewhere."""
|
|
21
7
|
|
|
22
8
|
SHAPEFILE_AUX_EXTENSIONS = [".cpg", ".dbf", ".prj", ".sbn", ".sbx", ".shx", ".txt"]
|
|
23
9
|
"""Extensions of potential auxiliary files to .shp file."""
|
|
10
|
+
|
|
11
|
+
__all__ = (
|
|
12
|
+
"WGS84_PROJECTION",
|
|
13
|
+
"WGS84_EPSG",
|
|
14
|
+
"WGS84_BOUNDS",
|
|
15
|
+
"TILE_SIZE",
|
|
16
|
+
"SHAPEFILE_AUX_EXTENSIONS",
|
|
17
|
+
)
|
rslearn/data_sources/__init__.py
CHANGED
|
@@ -10,32 +10,17 @@ Each source supports operations to lookup items that match with spatiotemporal
|
|
|
10
10
|
geometries, and ingest those items.
|
|
11
11
|
"""
|
|
12
12
|
|
|
13
|
-
import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
|
|
23
|
-
"""Loads a data source from config dict.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
config: the LayerConfig containing this data source.
|
|
27
|
-
ds_path: the dataset root directory.
|
|
28
|
-
"""
|
|
29
|
-
name = config.data_source.name
|
|
30
|
-
module_name = ".".join(name.split(".")[:-1])
|
|
31
|
-
class_name = name.split(".")[-1]
|
|
32
|
-
module = importlib.import_module(module_name)
|
|
33
|
-
class_ = getattr(module, class_name)
|
|
34
|
-
return class_.from_config(config, ds_path)
|
|
35
|
-
|
|
13
|
+
from .data_source import (
|
|
14
|
+
DataSource,
|
|
15
|
+
DataSourceContext,
|
|
16
|
+
Item,
|
|
17
|
+
ItemLookupDataSource,
|
|
18
|
+
RetrieveItemDataSource,
|
|
19
|
+
)
|
|
36
20
|
|
|
37
21
|
__all__ = (
|
|
38
22
|
"DataSource",
|
|
23
|
+
"DataSourceContext",
|
|
39
24
|
"Item",
|
|
40
25
|
"ItemLookupDataSource",
|
|
41
26
|
"RetrieveItemDataSource",
|
|
@@ -2,33 +2,41 @@
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
import json
|
|
5
|
+
import os
|
|
5
6
|
import shutil
|
|
7
|
+
import tempfile
|
|
6
8
|
import urllib.request
|
|
7
9
|
import zipfile
|
|
8
10
|
from collections.abc import Generator
|
|
9
|
-
from datetime import
|
|
11
|
+
from datetime import datetime
|
|
10
12
|
from typing import Any, BinaryIO
|
|
11
13
|
|
|
14
|
+
import affine
|
|
12
15
|
import boto3
|
|
13
16
|
import dateutil.parser
|
|
14
17
|
import fiona
|
|
15
18
|
import fiona.transform
|
|
16
|
-
import
|
|
19
|
+
import numpy.typing as npt
|
|
17
20
|
import rasterio
|
|
18
21
|
import shapely
|
|
22
|
+
import shapely.geometry
|
|
19
23
|
import tqdm
|
|
24
|
+
from rasterio.enums import Resampling
|
|
20
25
|
from upath import UPath
|
|
21
26
|
|
|
22
27
|
import rslearn.data_sources.utils
|
|
23
|
-
|
|
24
|
-
from rslearn.config import LayerConfig, RasterLayerConfig
|
|
28
|
+
from rslearn.config import LayerConfig
|
|
25
29
|
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
|
|
26
|
-
from rslearn.
|
|
27
|
-
from rslearn.
|
|
30
|
+
from rslearn.dataset import Window
|
|
31
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
32
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
28
33
|
from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
|
|
34
|
+
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
|
|
35
|
+
from rslearn.utils.grid_index import GridIndex
|
|
29
36
|
|
|
30
|
-
from .data_source import DataSource, Item, QueryConfig
|
|
31
|
-
|
|
37
|
+
from .data_source import DataSource, DataSourceContext, Item, QueryConfig
|
|
38
|
+
|
|
39
|
+
WRS2_GRID_SIZE = 1.0
|
|
32
40
|
|
|
33
41
|
|
|
34
42
|
class LandsatOliTirsItem(Item):
|
|
@@ -36,7 +44,7 @@ class LandsatOliTirsItem(Item):
|
|
|
36
44
|
|
|
37
45
|
def __init__(
|
|
38
46
|
self, name: str, geometry: STGeometry, blob_path: str, cloud_cover: float
|
|
39
|
-
):
|
|
47
|
+
) -> None:
|
|
40
48
|
"""Creates a new LandsatOliTirsItem.
|
|
41
49
|
|
|
42
50
|
Args:
|
|
@@ -58,7 +66,7 @@ class LandsatOliTirsItem(Item):
|
|
|
58
66
|
return d
|
|
59
67
|
|
|
60
68
|
@staticmethod
|
|
61
|
-
def deserialize(d: dict) ->
|
|
69
|
+
def deserialize(d: dict) -> "LandsatOliTirsItem":
|
|
62
70
|
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
63
71
|
if "name" not in d:
|
|
64
72
|
d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
|
|
@@ -71,7 +79,7 @@ class LandsatOliTirsItem(Item):
|
|
|
71
79
|
)
|
|
72
80
|
|
|
73
81
|
|
|
74
|
-
class LandsatOliTirs(DataSource):
|
|
82
|
+
class LandsatOliTirs(DataSource, TileStore):
|
|
75
83
|
"""A data source for Landsat 8/9 OLI-TIRS imagery on AWS.
|
|
76
84
|
|
|
77
85
|
Specifically, uses the usgs-landsat S3 bucket maintained by USGS. The data includes
|
|
@@ -90,53 +98,37 @@ class LandsatOliTirs(DataSource):
|
|
|
90
98
|
|
|
91
99
|
def __init__(
|
|
92
100
|
self,
|
|
93
|
-
|
|
94
|
-
metadata_cache_dir: UPath,
|
|
95
|
-
max_time_delta: timedelta = timedelta(days=30),
|
|
101
|
+
metadata_cache_dir: str,
|
|
96
102
|
sort_by: str | None = None,
|
|
103
|
+
context: DataSourceContext = DataSourceContext(),
|
|
97
104
|
) -> None:
|
|
98
105
|
"""Initialize a new LandsatOliTirs instance.
|
|
99
106
|
|
|
100
107
|
Args:
|
|
101
|
-
|
|
102
|
-
metadata_cache_dir: directory to cache product metadata files.
|
|
103
|
-
max_time_delta: maximum time before a query start time or after a
|
|
104
|
-
query end time to look for products. This is required due to the large
|
|
105
|
-
number of available products, and defaults to 30 days.
|
|
108
|
+
metadata_cache_dir: directory to cache product metadata files.
|
|
106
109
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
107
110
|
SpaceMode.WITHIN.
|
|
111
|
+
context: the data source context.
|
|
108
112
|
"""
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
113
|
+
# If the context provides a dataset path, we join the directory with it,
|
|
114
|
+
# otherwise we treat it directly as UPath.
|
|
115
|
+
if context.ds_path is not None:
|
|
116
|
+
self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
|
|
117
|
+
else:
|
|
118
|
+
self.metadata_cache_dir = UPath(metadata_cache_dir)
|
|
119
|
+
|
|
112
120
|
self.sort_by = sort_by
|
|
113
121
|
|
|
122
|
+
self.client = boto3.client("s3")
|
|
114
123
|
self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
|
|
115
|
-
|
|
116
124
|
self.metadata_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
117
125
|
|
|
118
|
-
|
|
119
|
-
def from_config(config: LayerConfig, ds_path: UPath) -> "LandsatOliTirs":
|
|
120
|
-
"""Creates a new LandsatOliTirs instance from a configuration dictionary."""
|
|
121
|
-
assert isinstance(config, RasterLayerConfig)
|
|
122
|
-
d = config.data_source.config_dict
|
|
123
|
-
kwargs = dict(
|
|
124
|
-
config=config,
|
|
125
|
-
metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
|
|
126
|
-
)
|
|
127
|
-
if "max_time_delta" in d:
|
|
128
|
-
kwargs["max_time_delta"] = timedelta(
|
|
129
|
-
seconds=pytimeparse.parse(d["max_time_delta"])
|
|
130
|
-
)
|
|
131
|
-
if "sort_by" in d:
|
|
132
|
-
kwargs["sort_by"] = d["sort_by"]
|
|
133
|
-
|
|
134
|
-
return LandsatOliTirs(**kwargs)
|
|
126
|
+
self.wrs2_index: GridIndex | None = None
|
|
135
127
|
|
|
136
128
|
def _read_products(
|
|
137
129
|
self, needed_year_pathrows: set[tuple[int, str, str]]
|
|
138
130
|
) -> Generator[LandsatOliTirsItem, None, None]:
|
|
139
|
-
"""Read
|
|
131
|
+
"""Read _stac.json files and yield relevant LandsatOliTirsItems.
|
|
140
132
|
|
|
141
133
|
Args:
|
|
142
134
|
needed_year_pathrows: set of (year, path, row) where we need to search for
|
|
@@ -155,7 +147,10 @@ class LandsatOliTirs(DataSource):
|
|
|
155
147
|
for obj in self.bucket.objects.filter(
|
|
156
148
|
Prefix=prefix, RequestPayer="requester"
|
|
157
149
|
):
|
|
158
|
-
|
|
150
|
+
# Only read the _stac.json files.
|
|
151
|
+
# Previously we used _MTL.json but those files don't have the full
|
|
152
|
+
# geometry of the Landsat scene, only the bounding box.
|
|
153
|
+
if not obj.key.endswith("_stac.json"):
|
|
159
154
|
continue
|
|
160
155
|
# Load JSON data.
|
|
161
156
|
buf = io.BytesIO()
|
|
@@ -163,33 +158,32 @@ class LandsatOliTirs(DataSource):
|
|
|
163
158
|
obj.key, buf, ExtraArgs={"RequestPayer": "requester"}
|
|
164
159
|
)
|
|
165
160
|
buf.seek(0)
|
|
166
|
-
|
|
167
|
-
metadata = product["LANDSAT_METADATA_FILE"]
|
|
168
|
-
image_attributes = metadata["IMAGE_ATTRIBUTES"]
|
|
169
|
-
projection_attributes = metadata["PROJECTION_ATTRIBUTES"]
|
|
161
|
+
stac_data = json.load(buf)
|
|
170
162
|
|
|
171
163
|
# Get polygon coordinates.
|
|
172
|
-
|
|
173
|
-
for corner_id in ["UL", "UR", "LR", "LL"]:
|
|
174
|
-
lon = projection_attributes[f"CORNER_{corner_id}_LON_PRODUCT"]
|
|
175
|
-
lat = projection_attributes[f"CORNER_{corner_id}_LAT_PRODUCT"]
|
|
176
|
-
coordinates.append((lon, lat))
|
|
164
|
+
shp = shapely.geometry.shape(stac_data["geometry"])
|
|
177
165
|
|
|
178
166
|
# Get datetime.
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
167
|
+
ts = dateutil.parser.isoparse(stac_data["properties"]["datetime"])
|
|
168
|
+
|
|
169
|
+
blob_path = obj.key.split("stac.json")[0]
|
|
170
|
+
time_range: tuple[datetime, datetime] = (ts, ts)
|
|
171
|
+
geometry = STGeometry(WGS84_PROJECTION, shp, time_range)
|
|
172
|
+
cloud_cover: float
|
|
173
|
+
if "eo:cloud_cover" in stac_data["properties"]:
|
|
174
|
+
cloud_cover = stac_data["properties"]["eo:cloud_cover"]
|
|
175
|
+
elif "landsat:cloud_cover_land" in stac_data["properties"]:
|
|
176
|
+
cloud_cover = stac_data["properties"][
|
|
177
|
+
"landsat:cloud_cover_land"
|
|
178
|
+
]
|
|
179
|
+
else:
|
|
180
|
+
cloud_cover = -1
|
|
187
181
|
items.append(
|
|
188
182
|
LandsatOliTirsItem(
|
|
189
|
-
name=
|
|
183
|
+
name=stac_data["id"],
|
|
190
184
|
geometry=geometry,
|
|
191
185
|
blob_path=blob_path,
|
|
192
|
-
cloud_cover=
|
|
186
|
+
cloud_cover=cloud_cover,
|
|
193
187
|
)
|
|
194
188
|
)
|
|
195
189
|
|
|
@@ -205,7 +199,7 @@ class LandsatOliTirs(DataSource):
|
|
|
205
199
|
|
|
206
200
|
yield from items
|
|
207
201
|
|
|
208
|
-
def
|
|
202
|
+
def _get_wrs2_polygons(self) -> list[tuple[shapely.Geometry, str, str]]:
|
|
209
203
|
"""Get polygons for each (path, row) in the WRS2 grid.
|
|
210
204
|
|
|
211
205
|
Returns:
|
|
@@ -216,6 +210,7 @@ class LandsatOliTirs(DataSource):
|
|
|
216
210
|
if not shp_fname.exists():
|
|
217
211
|
# Download and extract zip to cache dir.
|
|
218
212
|
zip_fname = self.metadata_cache_dir / f"{prefix}.zip"
|
|
213
|
+
print(f"Downloading {self.wrs2_url} to {zip_fname}")
|
|
219
214
|
with urllib.request.urlopen(self.wrs2_url) as response:
|
|
220
215
|
with zip_fname.open("wb") as f:
|
|
221
216
|
shutil.copyfileobj(response, f)
|
|
@@ -257,9 +252,22 @@ class LandsatOliTirs(DataSource):
|
|
|
257
252
|
polygons.append((shp, path, row))
|
|
258
253
|
return polygons
|
|
259
254
|
|
|
255
|
+
def _get_wrs2_index(self) -> GridIndex:
|
|
256
|
+
"""Get a grid index over the WRS2 polygons."""
|
|
257
|
+
if self.wrs2_index is not None:
|
|
258
|
+
return self.wrs2_index
|
|
259
|
+
|
|
260
|
+
# Index doesn't exist so we need to build it.
|
|
261
|
+
# We cache it with the object since it takes a bit of time to create it.
|
|
262
|
+
polygons = self._get_wrs2_polygons()
|
|
263
|
+
self.wrs2_index = GridIndex(WRS2_GRID_SIZE)
|
|
264
|
+
for polygon, path, row in polygons:
|
|
265
|
+
self.wrs2_index.insert(polygon.bounds, (polygon, path, row))
|
|
266
|
+
return self.wrs2_index
|
|
267
|
+
|
|
260
268
|
def get_items(
|
|
261
269
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
262
|
-
) -> list[list[list[
|
|
270
|
+
) -> list[list[list[LandsatOliTirsItem]]]:
|
|
263
271
|
"""Get a list of items in the data source intersecting the given geometries.
|
|
264
272
|
|
|
265
273
|
Args:
|
|
@@ -269,7 +277,7 @@ class LandsatOliTirs(DataSource):
|
|
|
269
277
|
Returns:
|
|
270
278
|
List of groups of items that should be retrieved for each geometry.
|
|
271
279
|
"""
|
|
272
|
-
|
|
280
|
+
wrs2_index = self._get_wrs2_index()
|
|
273
281
|
needed_year_pathrows = set()
|
|
274
282
|
wgs84_geometries = [
|
|
275
283
|
geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
|
|
@@ -280,13 +288,13 @@ class LandsatOliTirs(DataSource):
|
|
|
280
288
|
"Landsat on AWS requires geometry time ranges to be set"
|
|
281
289
|
)
|
|
282
290
|
cur_pathrows = set()
|
|
283
|
-
for polygon, path, row in
|
|
291
|
+
for polygon, path, row in wrs2_index.query(wgs84_geometry.shp.bounds):
|
|
284
292
|
if wgs84_geometry.shp.intersects(polygon):
|
|
285
293
|
cur_pathrows.add((path, row))
|
|
286
294
|
for path, row in cur_pathrows:
|
|
287
295
|
for year in range(
|
|
288
|
-
|
|
289
|
-
|
|
296
|
+
wgs84_geometry.time_range[0].year,
|
|
297
|
+
wgs84_geometry.time_range[1].year + 1,
|
|
290
298
|
):
|
|
291
299
|
needed_year_pathrows.add((year, path, row))
|
|
292
300
|
|
|
@@ -301,18 +309,22 @@ class LandsatOliTirs(DataSource):
|
|
|
301
309
|
cur_items.append(item)
|
|
302
310
|
|
|
303
311
|
if self.sort_by == "cloud_cover":
|
|
304
|
-
|
|
312
|
+
cur_items.sort(
|
|
313
|
+
key=lambda item: item.cloud_cover if item.cloud_cover >= 0 else 100
|
|
314
|
+
)
|
|
305
315
|
elif self.sort_by is not None:
|
|
306
316
|
raise ValueError(f"invalid sort_by setting ({self.sort_by})")
|
|
307
317
|
|
|
308
|
-
cur_groups =
|
|
309
|
-
|
|
318
|
+
cur_groups: list[list[LandsatOliTirsItem]] = (
|
|
319
|
+
rslearn.data_sources.utils.match_candidate_items_to_window(
|
|
320
|
+
geometry, cur_items, query_config
|
|
321
|
+
)
|
|
310
322
|
)
|
|
311
323
|
groups.append(cur_groups)
|
|
312
324
|
|
|
313
325
|
return groups
|
|
314
326
|
|
|
315
|
-
def get_item_by_name(self, name: str) ->
|
|
327
|
+
def get_item_by_name(self, name: str) -> LandsatOliTirsItem:
|
|
316
328
|
"""Gets an item by name."""
|
|
317
329
|
# Product name is like LC08_L1TP_046027_20230715_20230724_02_T1.
|
|
318
330
|
# We want to use _read_products so we need to extract:
|
|
@@ -330,12 +342,14 @@ class LandsatOliTirs(DataSource):
|
|
|
330
342
|
return item
|
|
331
343
|
raise ValueError(f"item {name} not found")
|
|
332
344
|
|
|
333
|
-
def deserialize_item(self, serialized_item: Any) ->
|
|
345
|
+
def deserialize_item(self, serialized_item: Any) -> LandsatOliTirsItem:
|
|
334
346
|
"""Deserializes an item from JSON-decoded data."""
|
|
335
347
|
assert isinstance(serialized_item, dict)
|
|
336
348
|
return LandsatOliTirsItem.deserialize(serialized_item)
|
|
337
349
|
|
|
338
|
-
def retrieve_item(
|
|
350
|
+
def retrieve_item(
|
|
351
|
+
self, item: LandsatOliTirsItem
|
|
352
|
+
) -> Generator[tuple[str, BinaryIO], None, None]:
|
|
339
353
|
"""Retrieves the rasters corresponding to an item as file streams."""
|
|
340
354
|
for band in self.bands:
|
|
341
355
|
buf = io.BytesIO()
|
|
@@ -350,8 +364,8 @@ class LandsatOliTirs(DataSource):
|
|
|
350
364
|
|
|
351
365
|
def ingest(
|
|
352
366
|
self,
|
|
353
|
-
tile_store:
|
|
354
|
-
items: list[
|
|
367
|
+
tile_store: TileStoreWithLayer,
|
|
368
|
+
items: list[LandsatOliTirsItem],
|
|
355
369
|
geometries: list[list[STGeometry]],
|
|
356
370
|
) -> None:
|
|
357
371
|
"""Ingest items into the given tile store.
|
|
@@ -364,28 +378,158 @@ class LandsatOliTirs(DataSource):
|
|
|
364
378
|
for item, cur_geometries in zip(items, geometries):
|
|
365
379
|
for band in self.bands:
|
|
366
380
|
band_names = [band]
|
|
367
|
-
|
|
368
|
-
tile_store, (item.name, "_".join(band_names))
|
|
369
|
-
)
|
|
370
|
-
needed_projections = get_needed_projections(
|
|
371
|
-
cur_tile_store, band_names, self.config.band_sets, cur_geometries
|
|
372
|
-
)
|
|
373
|
-
if not needed_projections:
|
|
381
|
+
if tile_store.is_raster_ready(item.name, band_names):
|
|
374
382
|
continue
|
|
375
383
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
384
|
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
385
|
+
fname = os.path.join(tmp_dir, f"{band}.tif")
|
|
386
|
+
self.bucket.download_file(
|
|
387
|
+
item.blob_path + f"{band}.TIF",
|
|
388
|
+
fname,
|
|
389
|
+
ExtraArgs={"RequestPayer": "requester"},
|
|
390
|
+
)
|
|
391
|
+
tile_store.write_raster_file(item.name, band_names, UPath(fname))
|
|
392
|
+
|
|
393
|
+
# The functions below are to emulate TileStore functionality so we can easily
|
|
394
|
+
# support materialization directly from the COGs.
|
|
395
|
+
def is_raster_ready(
|
|
396
|
+
self, layer_name: str, item_name: str, bands: list[str]
|
|
397
|
+
) -> bool:
|
|
398
|
+
"""Checks if this raster has been written to the store.
|
|
399
|
+
|
|
400
|
+
Args:
|
|
401
|
+
layer_name: the layer name or alias.
|
|
402
|
+
item_name: the item.
|
|
403
|
+
bands: the list of bands identifying which specific raster to read.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
whether there is a raster in the store matching the source, item, and
|
|
407
|
+
bands.
|
|
408
|
+
"""
|
|
409
|
+
# Always ready since we read it directly from the AWS bucket.
|
|
410
|
+
return True
|
|
411
|
+
|
|
412
|
+
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
|
|
413
|
+
"""Get the sets of bands that have been stored for the specified item.
|
|
414
|
+
|
|
415
|
+
Args:
|
|
416
|
+
layer_name: the layer name or alias.
|
|
417
|
+
item_name: the item.
|
|
418
|
+
|
|
419
|
+
Returns:
|
|
420
|
+
a list of lists of bands that are in the tile store (with one raster
|
|
421
|
+
stored corresponding to each inner list). If no rasters are ready for
|
|
422
|
+
this item, returns empty list.
|
|
423
|
+
"""
|
|
424
|
+
return [[band] for band in self.bands]
|
|
425
|
+
|
|
426
|
+
def get_raster_bounds(
|
|
427
|
+
self, layer_name: str, item_name: str, bands: list[str], projection: Projection
|
|
428
|
+
) -> PixelBounds:
|
|
429
|
+
"""Get the bounds of the raster in the specified projection.
|
|
430
|
+
|
|
431
|
+
Args:
|
|
432
|
+
layer_name: the layer name or alias.
|
|
433
|
+
item_name: the item to check.
|
|
434
|
+
bands: the list of bands identifying which specific raster to read. These
|
|
435
|
+
bands must match the bands of a stored raster.
|
|
436
|
+
projection: the projection to get the raster's bounds in.
|
|
437
|
+
|
|
438
|
+
Returns:
|
|
439
|
+
the bounds of the raster in the projection.
|
|
440
|
+
"""
|
|
441
|
+
item = self.get_item_by_name(item_name)
|
|
442
|
+
geom = item.geometry.to_projection(projection)
|
|
443
|
+
return (
|
|
444
|
+
int(geom.shp.bounds[0]),
|
|
445
|
+
int(geom.shp.bounds[1]),
|
|
446
|
+
int(geom.shp.bounds[2]),
|
|
447
|
+
int(geom.shp.bounds[3]),
|
|
448
|
+
)
|
|
449
|
+
|
|
450
|
+
def read_raster(
|
|
451
|
+
self,
|
|
452
|
+
layer_name: str,
|
|
453
|
+
item_name: str,
|
|
454
|
+
bands: list[str],
|
|
455
|
+
projection: Projection,
|
|
456
|
+
bounds: PixelBounds,
|
|
457
|
+
resampling: Resampling = Resampling.bilinear,
|
|
458
|
+
) -> npt.NDArray[Any]:
|
|
459
|
+
"""Read raster data from the store.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
layer_name: the layer name or alias.
|
|
463
|
+
item_name: the item to read.
|
|
464
|
+
bands: the list of bands identifying which specific raster to read. These
|
|
465
|
+
bands must match the bands of a stored raster.
|
|
466
|
+
projection: the projection to read in.
|
|
467
|
+
bounds: the bounds to read.
|
|
468
|
+
resampling: the resampling method to use in case reprojection is needed.
|
|
469
|
+
|
|
470
|
+
Returns:
|
|
471
|
+
the raster data
|
|
472
|
+
"""
|
|
473
|
+
# Landsat assets have single band per asset.
|
|
474
|
+
assert len(bands) == 1
|
|
475
|
+
band = bands[0]
|
|
476
|
+
|
|
477
|
+
# Get the item since it has the blob path.
|
|
478
|
+
item = self.get_item_by_name(item_name)
|
|
479
|
+
|
|
480
|
+
# Create pre-signed URL for rasterio access.
|
|
481
|
+
# We do this because accessing via URL is much faster since rasterio can use
|
|
482
|
+
# the URL directly.
|
|
483
|
+
blob_key = item.blob_path + f"{band}.TIF"
|
|
484
|
+
url = self.client.generate_presigned_url(
|
|
485
|
+
"get_object",
|
|
486
|
+
Params={
|
|
487
|
+
"Bucket": self.bucket_name,
|
|
488
|
+
"Key": blob_key,
|
|
489
|
+
"RequestPayer": "requester",
|
|
490
|
+
},
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
# Construct the transform to use for the warped dataset.
|
|
494
|
+
wanted_transform = affine.Affine(
|
|
495
|
+
projection.x_resolution,
|
|
496
|
+
0,
|
|
497
|
+
bounds[0] * projection.x_resolution,
|
|
498
|
+
0,
|
|
499
|
+
projection.y_resolution,
|
|
500
|
+
bounds[1] * projection.y_resolution,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
with rasterio.open(url) as src:
|
|
504
|
+
with rasterio.vrt.WarpedVRT(
|
|
505
|
+
src,
|
|
506
|
+
crs=projection.crs,
|
|
507
|
+
transform=wanted_transform,
|
|
508
|
+
width=bounds[2] - bounds[0],
|
|
509
|
+
height=bounds[3] - bounds[1],
|
|
510
|
+
resampling=resampling,
|
|
511
|
+
) as vrt:
|
|
512
|
+
return vrt.read()
|
|
513
|
+
|
|
514
|
+
def materialize(
|
|
515
|
+
self,
|
|
516
|
+
window: Window,
|
|
517
|
+
item_groups: list[list[LandsatOliTirsItem]],
|
|
518
|
+
layer_name: str,
|
|
519
|
+
layer_cfg: LayerConfig,
|
|
520
|
+
) -> None:
|
|
521
|
+
"""Materialize data for the window.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
window: the window to materialize
|
|
525
|
+
item_groups: the items from get_items
|
|
526
|
+
layer_name: the name of this layer
|
|
527
|
+
layer_cfg: the config of this layer
|
|
528
|
+
"""
|
|
529
|
+
RasterMaterializer().materialize(
|
|
530
|
+
TileStoreWithLayer(self, layer_name),
|
|
531
|
+
window,
|
|
532
|
+
layer_name,
|
|
533
|
+
layer_cfg,
|
|
534
|
+
item_groups,
|
|
535
|
+
)
|