rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
|
@@ -6,23 +6,20 @@ from collections.abc import Callable
|
|
|
6
6
|
from datetime import timedelta
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
|
-
import affine
|
|
10
9
|
import numpy as np
|
|
11
10
|
import numpy.typing as npt
|
|
12
11
|
import rasterio
|
|
13
12
|
import requests
|
|
14
|
-
from rasterio.enums import Resampling
|
|
15
13
|
from upath import UPath
|
|
16
14
|
|
|
17
|
-
from rslearn.
|
|
15
|
+
from rslearn.data_sources.direct_materialize_data_source import (
|
|
16
|
+
DirectMaterializeDataSource,
|
|
17
|
+
)
|
|
18
18
|
from rslearn.data_sources.stac import SourceItem, StacDataSource
|
|
19
|
-
from rslearn.dataset import Window
|
|
20
|
-
from rslearn.dataset.manage import RasterMaterializer
|
|
21
19
|
from rslearn.log_utils import get_logger
|
|
22
|
-
from rslearn.tile_stores import
|
|
23
|
-
from rslearn.utils import
|
|
20
|
+
from rslearn.tile_stores import TileStoreWithLayer
|
|
21
|
+
from rslearn.utils import STGeometry
|
|
24
22
|
from rslearn.utils.fsspec import join_upath
|
|
25
|
-
from rslearn.utils.geometry import PixelBounds
|
|
26
23
|
from rslearn.utils.raster_format import get_raster_projection_and_bounds
|
|
27
24
|
|
|
28
25
|
from .data_source import (
|
|
@@ -32,7 +29,7 @@ from .data_source import (
|
|
|
32
29
|
logger = get_logger(__name__)
|
|
33
30
|
|
|
34
31
|
|
|
35
|
-
class Sentinel2(
|
|
32
|
+
class Sentinel2(DirectMaterializeDataSource[SourceItem], StacDataSource):
|
|
36
33
|
"""A data source for Sentinel-2 L2A imagery on AWS from s3://sentinel-cogs.
|
|
37
34
|
|
|
38
35
|
The S3 bucket has COGs so this data source supports direct materialization. It also
|
|
@@ -97,31 +94,36 @@ class Sentinel2(StacDataSource, TileStore):
|
|
|
97
94
|
cache_upath.mkdir(parents=True, exist_ok=True)
|
|
98
95
|
|
|
99
96
|
# Determine which assets we need based on the bands in the layer config.
|
|
100
|
-
|
|
97
|
+
asset_bands: dict[str, list[str]]
|
|
101
98
|
if context.layer_config is not None:
|
|
102
|
-
|
|
99
|
+
asset_bands = {}
|
|
103
100
|
for asset_key, band_names in self.ASSET_BANDS.items():
|
|
104
101
|
# See if the bands provided by this asset intersect with the bands in
|
|
105
102
|
# at least one configured band set.
|
|
106
103
|
for band_set in context.layer_config.band_sets:
|
|
107
104
|
if not set(band_set.bands).intersection(set(band_names)):
|
|
108
105
|
continue
|
|
109
|
-
|
|
106
|
+
asset_bands[asset_key] = band_names
|
|
110
107
|
break
|
|
111
108
|
elif assets is not None:
|
|
112
|
-
|
|
109
|
+
asset_bands = {
|
|
113
110
|
asset_key: self.ASSET_BANDS[asset_key] for asset_key in assets
|
|
114
111
|
}
|
|
115
112
|
else:
|
|
116
|
-
|
|
113
|
+
asset_bands = dict(self.ASSET_BANDS)
|
|
114
|
+
|
|
115
|
+
# Initialize DirectMaterializeDataSource with asset_bands
|
|
116
|
+
DirectMaterializeDataSource.__init__(self, asset_bands=asset_bands)
|
|
117
117
|
|
|
118
|
-
|
|
118
|
+
# Initialize StacDataSource
|
|
119
|
+
StacDataSource.__init__(
|
|
120
|
+
self,
|
|
119
121
|
endpoint=self.STAC_ENDPOINT,
|
|
120
122
|
collection_name=self.COLLECTION_NAME,
|
|
121
123
|
query=query,
|
|
122
124
|
sort_by=sort_by,
|
|
123
125
|
sort_ascending=sort_ascending,
|
|
124
|
-
required_assets=list(
|
|
126
|
+
required_assets=list(asset_bands.keys()),
|
|
125
127
|
cache_dir=cache_upath,
|
|
126
128
|
properties_to_record=[self.HARMONIZE_PROPERTY_NAME],
|
|
127
129
|
)
|
|
@@ -129,6 +131,42 @@ class Sentinel2(StacDataSource, TileStore):
|
|
|
129
131
|
self.harmonize = harmonize
|
|
130
132
|
self.timeout = timeout
|
|
131
133
|
|
|
134
|
+
# --- DirectMaterializeDataSource implementation ---
|
|
135
|
+
|
|
136
|
+
def get_asset_url(self, item_name: str, asset_key: str) -> str:
|
|
137
|
+
"""Get the URL to read the asset for the given item and asset key.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
item_name: the name of the item.
|
|
141
|
+
asset_key: the key identifying which asset to get.
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
the URL to read the asset from.
|
|
145
|
+
"""
|
|
146
|
+
item = self.get_item_by_name(item_name)
|
|
147
|
+
return item.asset_urls[asset_key]
|
|
148
|
+
|
|
149
|
+
def get_read_callback(
|
|
150
|
+
self, item_name: str, asset_key: str
|
|
151
|
+
) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
|
|
152
|
+
"""Return a callback to harmonize Sentinel-2 data if needed.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
item_name: the name of the item being read.
|
|
156
|
+
asset_key: the key identifying which asset is being read.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
A callback function for harmonization, or None if not needed.
|
|
160
|
+
"""
|
|
161
|
+
# Visual bands do not need harmonization.
|
|
162
|
+
if not self.harmonize or asset_key == "visual":
|
|
163
|
+
return None
|
|
164
|
+
|
|
165
|
+
item = self.get_item_by_name(item_name)
|
|
166
|
+
return self._get_harmonize_callback(item)
|
|
167
|
+
|
|
168
|
+
# --- Harmonization helpers ---
|
|
169
|
+
|
|
132
170
|
def _get_harmonize_callback(
|
|
133
171
|
self, item: SourceItem
|
|
134
172
|
) -> Callable[[npt.NDArray], npt.NDArray] | None:
|
|
@@ -223,152 +261,3 @@ class Sentinel2(StacDataSource, TileStore):
|
|
|
223
261
|
item.name,
|
|
224
262
|
asset_key,
|
|
225
263
|
)
|
|
226
|
-
|
|
227
|
-
def is_raster_ready(
|
|
228
|
-
self, layer_name: str, item_name: str, bands: list[str]
|
|
229
|
-
) -> bool:
|
|
230
|
-
"""Checks if this raster has been written to the store.
|
|
231
|
-
|
|
232
|
-
Args:
|
|
233
|
-
layer_name: the layer name or alias.
|
|
234
|
-
item_name: the item.
|
|
235
|
-
bands: the list of bands identifying which specific raster to read.
|
|
236
|
-
|
|
237
|
-
Returns:
|
|
238
|
-
whether there is a raster in the store matching the source, item, and
|
|
239
|
-
bands.
|
|
240
|
-
"""
|
|
241
|
-
# Always ready since we wrap accesses to underlying API.
|
|
242
|
-
return True
|
|
243
|
-
|
|
244
|
-
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
|
|
245
|
-
"""Get the sets of bands that have been stored for the specified item.
|
|
246
|
-
|
|
247
|
-
Args:
|
|
248
|
-
layer_name: the layer name or alias.
|
|
249
|
-
item_name: the item.
|
|
250
|
-
|
|
251
|
-
Returns:
|
|
252
|
-
a list of lists of bands that are in the tile store (with one raster
|
|
253
|
-
stored corresponding to each inner list). If no rasters are ready for
|
|
254
|
-
this item, returns empty list.
|
|
255
|
-
"""
|
|
256
|
-
return list(self.asset_bands.values())
|
|
257
|
-
|
|
258
|
-
def _get_asset_by_band(self, bands: list[str]) -> str:
|
|
259
|
-
"""Get the name of the asset based on the band names."""
|
|
260
|
-
for asset_key, asset_bands in self.asset_bands.items():
|
|
261
|
-
if bands == asset_bands:
|
|
262
|
-
return asset_key
|
|
263
|
-
|
|
264
|
-
raise ValueError(f"no known asset with bands {bands}")
|
|
265
|
-
|
|
266
|
-
def get_raster_bounds(
|
|
267
|
-
self, layer_name: str, item_name: str, bands: list[str], projection: Projection
|
|
268
|
-
) -> PixelBounds:
|
|
269
|
-
"""Get the bounds of the raster in the specified projection.
|
|
270
|
-
|
|
271
|
-
Args:
|
|
272
|
-
layer_name: the layer name or alias.
|
|
273
|
-
item_name: the item to check.
|
|
274
|
-
bands: the list of bands identifying which specific raster to read. These
|
|
275
|
-
bands must match the bands of a stored raster.
|
|
276
|
-
projection: the projection to get the raster's bounds in.
|
|
277
|
-
|
|
278
|
-
Returns:
|
|
279
|
-
the bounds of the raster in the projection.
|
|
280
|
-
"""
|
|
281
|
-
item = self.get_item_by_name(item_name)
|
|
282
|
-
geom = item.geometry.to_projection(projection)
|
|
283
|
-
return (
|
|
284
|
-
int(geom.shp.bounds[0]),
|
|
285
|
-
int(geom.shp.bounds[1]),
|
|
286
|
-
int(geom.shp.bounds[2]),
|
|
287
|
-
int(geom.shp.bounds[3]),
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
def read_raster(
|
|
291
|
-
self,
|
|
292
|
-
layer_name: str,
|
|
293
|
-
item_name: str,
|
|
294
|
-
bands: list[str],
|
|
295
|
-
projection: Projection,
|
|
296
|
-
bounds: PixelBounds,
|
|
297
|
-
resampling: Resampling = Resampling.bilinear,
|
|
298
|
-
) -> npt.NDArray[Any]:
|
|
299
|
-
"""Read raster data from the store.
|
|
300
|
-
|
|
301
|
-
Args:
|
|
302
|
-
layer_name: the layer name or alias.
|
|
303
|
-
item_name: the item to read.
|
|
304
|
-
bands: the list of bands identifying which specific raster to read. These
|
|
305
|
-
bands must match the bands of a stored raster.
|
|
306
|
-
projection: the projection to read in.
|
|
307
|
-
bounds: the bounds to read.
|
|
308
|
-
resampling: the resampling method to use in case reprojection is needed.
|
|
309
|
-
|
|
310
|
-
Returns:
|
|
311
|
-
the raster data
|
|
312
|
-
"""
|
|
313
|
-
asset_key = self._get_asset_by_band(bands)
|
|
314
|
-
item = self.get_item_by_name(item_name)
|
|
315
|
-
asset_url = item.asset_urls[asset_key]
|
|
316
|
-
|
|
317
|
-
# Construct the transform to use for the warped dataset.
|
|
318
|
-
wanted_transform = affine.Affine(
|
|
319
|
-
projection.x_resolution,
|
|
320
|
-
0,
|
|
321
|
-
bounds[0] * projection.x_resolution,
|
|
322
|
-
0,
|
|
323
|
-
projection.y_resolution,
|
|
324
|
-
bounds[1] * projection.y_resolution,
|
|
325
|
-
)
|
|
326
|
-
|
|
327
|
-
# Read from the raster under the specified projection/bounds.
|
|
328
|
-
with rasterio.open(asset_url) as src:
|
|
329
|
-
with rasterio.vrt.WarpedVRT(
|
|
330
|
-
src,
|
|
331
|
-
crs=projection.crs,
|
|
332
|
-
transform=wanted_transform,
|
|
333
|
-
width=bounds[2] - bounds[0],
|
|
334
|
-
height=bounds[3] - bounds[1],
|
|
335
|
-
resampling=resampling,
|
|
336
|
-
) as vrt:
|
|
337
|
-
raw_data = vrt.read()
|
|
338
|
-
|
|
339
|
-
# We can return the data now if harmonization is not needed.
|
|
340
|
-
if not self.harmonize or bands == self.ASSET_BANDS["visual"]:
|
|
341
|
-
return raw_data
|
|
342
|
-
|
|
343
|
-
# Otherwise we apply the harmonize_callback.
|
|
344
|
-
item = self.get_item_by_name(item_name)
|
|
345
|
-
harmonize_callback = self._get_harmonize_callback(item)
|
|
346
|
-
|
|
347
|
-
if harmonize_callback is None:
|
|
348
|
-
return raw_data
|
|
349
|
-
|
|
350
|
-
array = harmonize_callback(raw_data)
|
|
351
|
-
return array
|
|
352
|
-
|
|
353
|
-
def materialize(
|
|
354
|
-
self,
|
|
355
|
-
window: Window,
|
|
356
|
-
item_groups: list[list[SourceItem]],
|
|
357
|
-
layer_name: str,
|
|
358
|
-
layer_cfg: LayerConfig,
|
|
359
|
-
) -> None:
|
|
360
|
-
"""Materialize data for the window.
|
|
361
|
-
|
|
362
|
-
Args:
|
|
363
|
-
window: the window to materialize
|
|
364
|
-
item_groups: the items from get_items
|
|
365
|
-
layer_name: the name of this layer
|
|
366
|
-
layer_cfg: the config of this layer
|
|
367
|
-
"""
|
|
368
|
-
RasterMaterializer().materialize(
|
|
369
|
-
TileStoreWithLayer(self, layer_name),
|
|
370
|
-
window,
|
|
371
|
-
layer_name,
|
|
372
|
-
layer_cfg,
|
|
373
|
-
item_groups,
|
|
374
|
-
)
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import os
|
|
4
4
|
import tempfile
|
|
5
5
|
from datetime import UTC, datetime
|
|
6
|
-
from typing import Any
|
|
7
6
|
|
|
8
7
|
import cdsapi
|
|
9
8
|
import netCDF4
|
|
@@ -160,9 +159,8 @@ class ERA5Land(DataSource):
|
|
|
160
159
|
|
|
161
160
|
return all_groups
|
|
162
161
|
|
|
163
|
-
def deserialize_item(self, serialized_item:
|
|
162
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
164
163
|
"""Deserializes an item from JSON-decoded data."""
|
|
165
|
-
assert isinstance(serialized_item, dict)
|
|
166
164
|
return Item.deserialize(serialized_item)
|
|
167
165
|
|
|
168
166
|
def _convert_nc_to_tif(self, nc_path: UPath, tif_path: UPath) -> None:
|
|
@@ -353,9 +353,8 @@ class Copernicus(DataSource):
|
|
|
353
353
|
self.username = os.environ["COPERNICUS_USERNAME"]
|
|
354
354
|
self.password = os.environ["COPERNICUS_PASSWORD"]
|
|
355
355
|
|
|
356
|
-
def deserialize_item(self, serialized_item:
|
|
356
|
+
def deserialize_item(self, serialized_item: dict) -> CopernicusItem:
|
|
357
357
|
"""Deserializes an item from JSON-decoded data."""
|
|
358
|
-
assert isinstance(serialized_item, dict)
|
|
359
358
|
return CopernicusItem.deserialize(serialized_item)
|
|
360
359
|
|
|
361
360
|
def _get(self, path: str) -> dict[str, Any]:
|
|
@@ -76,7 +76,7 @@ class DataSource(Generic[ItemType]):
|
|
|
76
76
|
"""
|
|
77
77
|
raise NotImplementedError
|
|
78
78
|
|
|
79
|
-
def deserialize_item(self, serialized_item:
|
|
79
|
+
def deserialize_item(self, serialized_item: dict) -> ItemType:
|
|
80
80
|
"""Deserializes an item from JSON-decoded data."""
|
|
81
81
|
raise NotImplementedError
|
|
82
82
|
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
"""Base class for data sources that support direct materialization via TileStore."""
|
|
2
|
+
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any, Generic
|
|
6
|
+
|
|
7
|
+
import affine
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
import rasterio
|
|
10
|
+
import rasterio.vrt
|
|
11
|
+
from rasterio.enums import Resampling
|
|
12
|
+
|
|
13
|
+
from rslearn.config import LayerConfig
|
|
14
|
+
from rslearn.data_sources.data_source import DataSource, ItemType
|
|
15
|
+
from rslearn.dataset import Window
|
|
16
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
17
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
18
|
+
from rslearn.utils.geometry import PixelBounds, Projection
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemType]):
|
|
22
|
+
"""Base class for data sources that support direct materialization via TileStore.
|
|
23
|
+
|
|
24
|
+
This class provides common TileStore functionality for data sources that can read
|
|
25
|
+
raster data on-demand from remote sources (like cloud buckets or APIs) without
|
|
26
|
+
first ingesting into a local tile store.
|
|
27
|
+
|
|
28
|
+
Subclasses must implement:
|
|
29
|
+
- get_asset_url(): Get the URL for an asset given item name and asset key
|
|
30
|
+
- get_item_by_name(): Get an item by its name
|
|
31
|
+
|
|
32
|
+
Subclasses may optionally override:
|
|
33
|
+
- get_raster_bands(): By default, we assume that items have all assets. If
|
|
34
|
+
items may have a subset of assets, override get_raster_bands to return
|
|
35
|
+
the sets of bands available for that item.
|
|
36
|
+
- get_read_callback(): Returns a callback to transform the raster array,
|
|
37
|
+
for post-processing like Sentinel-2 harmonization.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(self, asset_bands: dict[str, list[str]]):
|
|
41
|
+
"""Initialize the DirectMaterializeDataSource.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
asset_bands: mapping from asset key to the list of band names in that asset.
|
|
45
|
+
"""
|
|
46
|
+
self.asset_bands = asset_bands
|
|
47
|
+
|
|
48
|
+
def _get_asset_key_by_bands(self, bands: list[str]) -> str:
|
|
49
|
+
"""Get the asset key based on the band names.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
bands: list of band names to look up.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
the asset key that provides those bands.
|
|
56
|
+
|
|
57
|
+
Raises:
|
|
58
|
+
ValueError: if no asset provides those bands.
|
|
59
|
+
"""
|
|
60
|
+
for asset_key, asset_bands in self.asset_bands.items():
|
|
61
|
+
if bands == asset_bands:
|
|
62
|
+
return asset_key
|
|
63
|
+
raise ValueError(f"no known asset with bands {bands}")
|
|
64
|
+
|
|
65
|
+
# --- Methods that subclasses must implement ---
|
|
66
|
+
|
|
67
|
+
@abstractmethod
|
|
68
|
+
def get_asset_url(self, item_name: str, asset_key: str) -> str:
|
|
69
|
+
"""Get the URL to read the asset for the given item and asset key.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
item_name: the name of the item.
|
|
73
|
+
asset_key: the key identifying which asset to get.
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
the URL to read the asset from (must be readable by rasterio).
|
|
77
|
+
"""
|
|
78
|
+
raise NotImplementedError
|
|
79
|
+
|
|
80
|
+
def get_item_by_name(self, name: str) -> ItemType:
|
|
81
|
+
"""Get an item by its name.
|
|
82
|
+
|
|
83
|
+
Subclasses must implement this method, either directly or by inheriting from
|
|
84
|
+
a class that provides it (e.g., StacDataSource).
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
name: the name of the item to get.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
the item object.
|
|
91
|
+
"""
|
|
92
|
+
raise NotImplementedError
|
|
93
|
+
|
|
94
|
+
# --- Optional hooks for subclasses ---
|
|
95
|
+
|
|
96
|
+
def get_read_callback(
|
|
97
|
+
self, item_name: str, asset_key: str
|
|
98
|
+
) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
|
|
99
|
+
"""Return a callback to post-process raster data (e.g., harmonization).
|
|
100
|
+
|
|
101
|
+
Subclasses can override this to apply transformations to the raw raster data
|
|
102
|
+
after reading, such as harmonization for Sentinel-2 data.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
item_name: the name of the item being read.
|
|
106
|
+
asset_key: the key identifying which asset is being read.
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
A callback function that takes an array and returns a modified array,
|
|
110
|
+
or None if no post-processing is needed.
|
|
111
|
+
"""
|
|
112
|
+
return None
|
|
113
|
+
|
|
114
|
+
# --- TileStore implementation ---
|
|
115
|
+
|
|
116
|
+
def is_raster_ready(
|
|
117
|
+
self, layer_name: str, item_name: str, bands: list[str]
|
|
118
|
+
) -> bool:
|
|
119
|
+
"""Checks if this raster has been written to the store.
|
|
120
|
+
|
|
121
|
+
For remote-backed tile stores, this always returns True since data is
|
|
122
|
+
read on-demand from the remote source.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
layer_name: the layer name or alias.
|
|
126
|
+
item_name: the item.
|
|
127
|
+
bands: the list of bands identifying which specific raster to read.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
True, since data is always available from the remote source.
|
|
131
|
+
"""
|
|
132
|
+
return True
|
|
133
|
+
|
|
134
|
+
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
|
|
135
|
+
"""Get the sets of bands that have been stored for the specified item.
|
|
136
|
+
|
|
137
|
+
By default, returns all band sets from the asset_bands configuration.
|
|
138
|
+
Subclasses can override this if not all items have all assets.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
layer_name: the layer name or alias.
|
|
142
|
+
item_name: the item.
|
|
143
|
+
|
|
144
|
+
Returns:
|
|
145
|
+
a list of lists of bands available for this item.
|
|
146
|
+
"""
|
|
147
|
+
return list(self.asset_bands.values())
|
|
148
|
+
|
|
149
|
+
def get_raster_bounds(
|
|
150
|
+
self, layer_name: str, item_name: str, bands: list[str], projection: Projection
|
|
151
|
+
) -> PixelBounds:
|
|
152
|
+
"""Get the bounds of the raster in the specified projection.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
layer_name: the layer name or alias.
|
|
156
|
+
item_name: the item to check.
|
|
157
|
+
bands: the list of bands identifying which specific raster to read.
|
|
158
|
+
projection: the projection to get the raster's bounds in.
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
the bounds of the raster in the projection.
|
|
162
|
+
"""
|
|
163
|
+
item = self.get_item_by_name(item_name)
|
|
164
|
+
geom = item.geometry.to_projection(projection)
|
|
165
|
+
return (
|
|
166
|
+
int(geom.shp.bounds[0]),
|
|
167
|
+
int(geom.shp.bounds[1]),
|
|
168
|
+
int(geom.shp.bounds[2]),
|
|
169
|
+
int(geom.shp.bounds[3]),
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
def _read_raster_from_url(
|
|
173
|
+
self,
|
|
174
|
+
url: str,
|
|
175
|
+
projection: Projection,
|
|
176
|
+
bounds: PixelBounds,
|
|
177
|
+
resampling: Resampling,
|
|
178
|
+
) -> npt.NDArray[Any]:
|
|
179
|
+
"""Read raster data from a URL with reprojection.
|
|
180
|
+
|
|
181
|
+
This is the common logic for reading raster data from a URL and reprojecting
|
|
182
|
+
it to the target projection and bounds using rasterio's WarpedVRT.
|
|
183
|
+
|
|
184
|
+
Args:
|
|
185
|
+
url: the URL to read from (must be readable by rasterio).
|
|
186
|
+
projection: the projection to read in.
|
|
187
|
+
bounds: the bounds to read.
|
|
188
|
+
resampling: the resampling method to use.
|
|
189
|
+
|
|
190
|
+
Returns:
|
|
191
|
+
the raster data as a numpy array.
|
|
192
|
+
"""
|
|
193
|
+
# Construct the transform to use for the warped dataset.
|
|
194
|
+
wanted_transform = affine.Affine(
|
|
195
|
+
projection.x_resolution,
|
|
196
|
+
0,
|
|
197
|
+
bounds[0] * projection.x_resolution,
|
|
198
|
+
0,
|
|
199
|
+
projection.y_resolution,
|
|
200
|
+
bounds[1] * projection.y_resolution,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
with rasterio.open(url) as src:
|
|
204
|
+
with rasterio.vrt.WarpedVRT(
|
|
205
|
+
src,
|
|
206
|
+
crs=projection.crs,
|
|
207
|
+
transform=wanted_transform,
|
|
208
|
+
width=bounds[2] - bounds[0],
|
|
209
|
+
height=bounds[3] - bounds[1],
|
|
210
|
+
resampling=resampling,
|
|
211
|
+
) as vrt:
|
|
212
|
+
return vrt.read()
|
|
213
|
+
|
|
214
|
+
def read_raster(
|
|
215
|
+
self,
|
|
216
|
+
layer_name: str,
|
|
217
|
+
item_name: str,
|
|
218
|
+
bands: list[str],
|
|
219
|
+
projection: Projection,
|
|
220
|
+
bounds: PixelBounds,
|
|
221
|
+
resampling: Resampling = Resampling.bilinear,
|
|
222
|
+
) -> npt.NDArray[Any]:
|
|
223
|
+
"""Read raster data from the store.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
layer_name: the layer name or alias.
|
|
227
|
+
item_name: the item to read.
|
|
228
|
+
bands: the list of bands identifying which specific raster to read.
|
|
229
|
+
projection: the projection to read in.
|
|
230
|
+
bounds: the bounds to read.
|
|
231
|
+
resampling: the resampling method to use in case reprojection is needed.
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
the raster data as a numpy array.
|
|
235
|
+
"""
|
|
236
|
+
# Get the asset key for the requested bands
|
|
237
|
+
asset_key = self._get_asset_key_by_bands(bands)
|
|
238
|
+
|
|
239
|
+
# Get the asset URL from the subclass
|
|
240
|
+
asset_url = self.get_asset_url(item_name, asset_key)
|
|
241
|
+
|
|
242
|
+
# Read the raster data
|
|
243
|
+
raw_data = self._read_raster_from_url(asset_url, projection, bounds, resampling)
|
|
244
|
+
|
|
245
|
+
# Apply any post-processing callback
|
|
246
|
+
callback = self.get_read_callback(item_name, asset_key)
|
|
247
|
+
if callback is not None:
|
|
248
|
+
raw_data = callback(raw_data)
|
|
249
|
+
|
|
250
|
+
return raw_data
|
|
251
|
+
|
|
252
|
+
def materialize(
|
|
253
|
+
self,
|
|
254
|
+
window: Window,
|
|
255
|
+
item_groups: list[list[ItemType]],
|
|
256
|
+
layer_name: str,
|
|
257
|
+
layer_cfg: LayerConfig,
|
|
258
|
+
) -> None:
|
|
259
|
+
"""Materialize data for the window.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
window: the window to materialize.
|
|
263
|
+
item_groups: the items from get_items.
|
|
264
|
+
layer_name: the name of this layer.
|
|
265
|
+
layer_cfg: the config of this layer.
|
|
266
|
+
"""
|
|
267
|
+
RasterMaterializer().materialize(
|
|
268
|
+
TileStoreWithLayer(self, layer_name),
|
|
269
|
+
window,
|
|
270
|
+
layer_name,
|
|
271
|
+
layer_cfg,
|
|
272
|
+
item_groups,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
# --- TileStore methods that are not supported ---
|
|
276
|
+
|
|
277
|
+
def write_raster(
|
|
278
|
+
self,
|
|
279
|
+
layer_name: str,
|
|
280
|
+
item_name: str,
|
|
281
|
+
bands: list[str],
|
|
282
|
+
projection: Projection,
|
|
283
|
+
bounds: PixelBounds,
|
|
284
|
+
array: npt.NDArray[Any],
|
|
285
|
+
) -> None:
|
|
286
|
+
"""Write raster data to the store.
|
|
287
|
+
|
|
288
|
+
This is not supported for remote-backed tile stores.
|
|
289
|
+
"""
|
|
290
|
+
raise NotImplementedError(
|
|
291
|
+
"DirectMaterializeDataSource does not support writing raster data"
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
def write_raster_file(
|
|
295
|
+
self, layer_name: str, item_name: str, bands: list[str], fname: Any
|
|
296
|
+
) -> None:
|
|
297
|
+
"""Write a raster file to the store.
|
|
298
|
+
|
|
299
|
+
This is not supported for remote-backed tile stores.
|
|
300
|
+
"""
|
|
301
|
+
raise NotImplementedError(
|
|
302
|
+
"DirectMaterializeDataSource does not support writing raster files"
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
def is_vector_ready(self, layer_name: str, item_name: str) -> bool:
|
|
306
|
+
"""Checks if this vector item has been written to the store.
|
|
307
|
+
|
|
308
|
+
This is not supported for remote-backed tile stores.
|
|
309
|
+
"""
|
|
310
|
+
raise NotImplementedError(
|
|
311
|
+
"DirectMaterializeDataSource does not support vector operations"
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
def read_vector(
|
|
315
|
+
self,
|
|
316
|
+
layer_name: str,
|
|
317
|
+
item_name: str,
|
|
318
|
+
projection: Projection,
|
|
319
|
+
bounds: PixelBounds,
|
|
320
|
+
) -> Any:
|
|
321
|
+
"""Read vector data from the store.
|
|
322
|
+
|
|
323
|
+
This is not supported for remote-backed tile stores.
|
|
324
|
+
"""
|
|
325
|
+
raise NotImplementedError(
|
|
326
|
+
"DirectMaterializeDataSource does not support vector operations"
|
|
327
|
+
)
|
|
328
|
+
|
|
329
|
+
def write_vector(self, layer_name: str, item_name: str, features: Any) -> None:
|
|
330
|
+
"""Write vector data to the store.
|
|
331
|
+
|
|
332
|
+
This is not supported for remote-backed tile stores.
|
|
333
|
+
"""
|
|
334
|
+
raise NotImplementedError(
|
|
335
|
+
"DirectMaterializeDataSource does not support vector operations"
|
|
336
|
+
)
|