rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
|
@@ -6,27 +6,24 @@ import tempfile
|
|
|
6
6
|
from datetime import timedelta
|
|
7
7
|
from typing import Any, Literal
|
|
8
8
|
|
|
9
|
-
import affine
|
|
10
|
-
import numpy.typing as npt
|
|
11
9
|
import pystac
|
|
12
10
|
import pystac_client
|
|
13
|
-
import rasterio
|
|
14
11
|
import requests
|
|
15
12
|
import shapely
|
|
16
13
|
from earthdaily import EDSClient, EDSConfig
|
|
17
|
-
from rasterio.enums import Resampling
|
|
18
14
|
from upath import UPath
|
|
19
15
|
|
|
20
|
-
from rslearn.config import
|
|
16
|
+
from rslearn.config import QueryConfig
|
|
21
17
|
from rslearn.const import WGS84_PROJECTION
|
|
22
|
-
from rslearn.data_sources import
|
|
18
|
+
from rslearn.data_sources import DataSourceContext, Item
|
|
19
|
+
from rslearn.data_sources.direct_materialize_data_source import (
|
|
20
|
+
DirectMaterializeDataSource,
|
|
21
|
+
)
|
|
23
22
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
24
|
-
from rslearn.dataset import Window
|
|
25
|
-
from rslearn.dataset.materialize import RasterMaterializer
|
|
26
23
|
from rslearn.log_utils import get_logger
|
|
27
|
-
from rslearn.tile_stores import
|
|
24
|
+
from rslearn.tile_stores import TileStoreWithLayer
|
|
28
25
|
from rslearn.utils.fsspec import join_upath
|
|
29
|
-
from rslearn.utils.geometry import
|
|
26
|
+
from rslearn.utils.geometry import STGeometry
|
|
30
27
|
|
|
31
28
|
logger = get_logger(__name__)
|
|
32
29
|
|
|
@@ -62,7 +59,7 @@ class EarthDailyItem(Item):
|
|
|
62
59
|
)
|
|
63
60
|
|
|
64
61
|
|
|
65
|
-
class EarthDaily(
|
|
62
|
+
class EarthDaily(DirectMaterializeDataSource[EarthDailyItem]):
|
|
66
63
|
"""A data source for EarthDaily data.
|
|
67
64
|
|
|
68
65
|
This requires the following environment variables to be set:
|
|
@@ -111,8 +108,9 @@ class EarthDaily(DataSource, TileStore):
|
|
|
111
108
|
services "legacy" and "internal" are not supported.
|
|
112
109
|
context: the data source context.
|
|
113
110
|
"""
|
|
111
|
+
super().__init__(asset_bands=asset_bands)
|
|
112
|
+
|
|
114
113
|
self.collection_name = collection_name
|
|
115
|
-
self.asset_bands = asset_bands
|
|
116
114
|
self.query = query
|
|
117
115
|
self.sort_by = sort_by
|
|
118
116
|
self.sort_ascending = sort_ascending
|
|
@@ -221,6 +219,47 @@ class EarthDaily(DataSource, TileStore):
|
|
|
221
219
|
|
|
222
220
|
return item
|
|
223
221
|
|
|
222
|
+
# --- DirectMaterializeDataSource implementation ---
|
|
223
|
+
|
|
224
|
+
def get_asset_url(self, item_name: str, asset_key: str) -> str:
    """Look up the location from which one asset of an item can be read.

    Args:
        item_name: name of the item whose asset is wanted.
        asset_key: key of the asset within that item.

    Returns:
        the URL stored for that asset on the item.
    """
    urls_by_asset = self.get_item_by_name(item_name).asset_urls
    return urls_by_asset[asset_key]
|
|
236
|
+
|
|
237
|
+
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
    """Return the band sets that can be read for the given item.

    Args:
        layer_name: the layer name or alias.
        item_name: the item.

    Returns:
        one list of band names per configured asset that the item provides.
    """
    if not self.skip_items_missing_assets:
        # The item may lack some assets, so consult its (cached) STAC metadata via
        # get_item_by_name and keep only the assets it actually has.
        present_assets = self.get_item_by_name(item_name).asset_urls
        return [
            bands
            for key, bands in self.asset_bands.items()
            if key in present_assets
        ]

    # When items missing assets are skipped, every remaining item is assumed to
    # carry all configured assets.
    return list(self.asset_bands.values())
|
|
260
|
+
|
|
261
|
+
# --- DataSource implementation ---
|
|
262
|
+
|
|
224
263
|
def get_items(
|
|
225
264
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
226
265
|
) -> list[list[list[EarthDailyItem]]]:
|
|
@@ -285,9 +324,8 @@ class EarthDaily(DataSource, TileStore):
|
|
|
285
324
|
|
|
286
325
|
return groups
|
|
287
326
|
|
|
288
|
-
def deserialize_item(self, serialized_item: dict) -> EarthDailyItem:
    """Rebuild an EarthDailyItem from its JSON-decoded representation.

    Args:
        serialized_item: the JSON-decoded item data.

    Returns:
        the reconstructed EarthDailyItem.
    """
    restored = EarthDailyItem.deserialize(serialized_item)
    return restored
|
|
292
330
|
|
|
293
331
|
def ingest(
|
|
@@ -341,144 +379,3 @@ class EarthDaily(DataSource, TileStore):
|
|
|
341
379
|
item.name,
|
|
342
380
|
asset_key,
|
|
343
381
|
)
|
|
344
|
-
|
|
345
|
-
def is_raster_ready(
|
|
346
|
-
self, layer_name: str, item_name: str, bands: list[str]
|
|
347
|
-
) -> bool:
|
|
348
|
-
"""Checks if this raster has been written to the store.
|
|
349
|
-
|
|
350
|
-
Args:
|
|
351
|
-
layer_name: the layer name or alias.
|
|
352
|
-
item_name: the item.
|
|
353
|
-
bands: the list of bands identifying which specific raster to read.
|
|
354
|
-
|
|
355
|
-
Returns:
|
|
356
|
-
whether there is a raster in the store matching the source, item, and
|
|
357
|
-
bands.
|
|
358
|
-
"""
|
|
359
|
-
# Always ready since we wrap accesses to EarthDaily.
|
|
360
|
-
return True
|
|
361
|
-
|
|
362
|
-
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
|
|
363
|
-
"""Get the sets of bands that have been stored for the specified item.
|
|
364
|
-
|
|
365
|
-
Args:
|
|
366
|
-
layer_name: the layer name or alias.
|
|
367
|
-
item_name: the item.
|
|
368
|
-
"""
|
|
369
|
-
if self.skip_items_missing_assets:
|
|
370
|
-
# In this case we can assume that the item has all of the assets.
|
|
371
|
-
return list(self.asset_bands.values())
|
|
372
|
-
|
|
373
|
-
# Otherwise we have to lookup the STAC item to see which assets it has.
|
|
374
|
-
# Here we use get_item_by_name since it handles caching.
|
|
375
|
-
item = self.get_item_by_name(item_name)
|
|
376
|
-
all_bands = []
|
|
377
|
-
for asset_key, band_names in self.asset_bands.items():
|
|
378
|
-
if asset_key not in item.asset_urls:
|
|
379
|
-
continue
|
|
380
|
-
all_bands.append(band_names)
|
|
381
|
-
return all_bands
|
|
382
|
-
|
|
383
|
-
def _get_asset_by_band(self, bands: list[str]) -> str:
|
|
384
|
-
"""Get the name of the asset based on the band names."""
|
|
385
|
-
for asset_key, asset_bands in self.asset_bands.items():
|
|
386
|
-
if bands == asset_bands:
|
|
387
|
-
return asset_key
|
|
388
|
-
|
|
389
|
-
raise ValueError(f"no raster with bands {bands}")
|
|
390
|
-
|
|
391
|
-
def get_raster_bounds(
|
|
392
|
-
self, layer_name: str, item_name: str, bands: list[str], projection: Projection
|
|
393
|
-
) -> PixelBounds:
|
|
394
|
-
"""Get the bounds of the raster in the specified projection.
|
|
395
|
-
|
|
396
|
-
Args:
|
|
397
|
-
layer_name: the layer name or alias.
|
|
398
|
-
item_name: the item to check.
|
|
399
|
-
bands: the list of bands identifying which specific raster to read. These
|
|
400
|
-
bands must match the bands of a stored raster.
|
|
401
|
-
projection: the projection to get the raster's bounds in.
|
|
402
|
-
|
|
403
|
-
Returns:
|
|
404
|
-
the bounds of the raster in the projection.
|
|
405
|
-
"""
|
|
406
|
-
item = self.get_item_by_name(item_name)
|
|
407
|
-
geom = item.geometry.to_projection(projection)
|
|
408
|
-
return (
|
|
409
|
-
int(geom.shp.bounds[0]),
|
|
410
|
-
int(geom.shp.bounds[1]),
|
|
411
|
-
int(geom.shp.bounds[2]),
|
|
412
|
-
int(geom.shp.bounds[3]),
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
def read_raster(
|
|
416
|
-
self,
|
|
417
|
-
layer_name: str,
|
|
418
|
-
item_name: str,
|
|
419
|
-
bands: list[str],
|
|
420
|
-
projection: Projection,
|
|
421
|
-
bounds: PixelBounds,
|
|
422
|
-
resampling: Resampling = Resampling.bilinear,
|
|
423
|
-
) -> npt.NDArray[Any]:
|
|
424
|
-
"""Read raster data from the store.
|
|
425
|
-
|
|
426
|
-
Args:
|
|
427
|
-
layer_name: the layer name or alias.
|
|
428
|
-
item_name: the item to read.
|
|
429
|
-
bands: the list of bands identifying which specific raster to read. These
|
|
430
|
-
bands must match the bands of a stored raster.
|
|
431
|
-
projection: the projection to read in.
|
|
432
|
-
bounds: the bounds to read.
|
|
433
|
-
resampling: the resampling method to use in case reprojection is needed.
|
|
434
|
-
|
|
435
|
-
Returns:
|
|
436
|
-
the raster data
|
|
437
|
-
"""
|
|
438
|
-
asset_key = self._get_asset_by_band(bands)
|
|
439
|
-
item = self.get_item_by_name(item_name)
|
|
440
|
-
asset_url = item.asset_urls[asset_key]
|
|
441
|
-
|
|
442
|
-
# Construct the transform to use for the warped dataset.
|
|
443
|
-
wanted_transform = affine.Affine(
|
|
444
|
-
projection.x_resolution,
|
|
445
|
-
0,
|
|
446
|
-
bounds[0] * projection.x_resolution,
|
|
447
|
-
0,
|
|
448
|
-
projection.y_resolution,
|
|
449
|
-
bounds[1] * projection.y_resolution,
|
|
450
|
-
)
|
|
451
|
-
|
|
452
|
-
with rasterio.open(asset_url) as src:
|
|
453
|
-
with rasterio.vrt.WarpedVRT(
|
|
454
|
-
src,
|
|
455
|
-
crs=projection.crs,
|
|
456
|
-
transform=wanted_transform,
|
|
457
|
-
width=bounds[2] - bounds[0],
|
|
458
|
-
height=bounds[3] - bounds[1],
|
|
459
|
-
resampling=resampling,
|
|
460
|
-
) as vrt:
|
|
461
|
-
return vrt.read()
|
|
462
|
-
|
|
463
|
-
def materialize(
|
|
464
|
-
self,
|
|
465
|
-
window: Window,
|
|
466
|
-
item_groups: list[list[Item]],
|
|
467
|
-
layer_name: str,
|
|
468
|
-
layer_cfg: LayerConfig,
|
|
469
|
-
) -> None:
|
|
470
|
-
"""Materialize data for the window.
|
|
471
|
-
|
|
472
|
-
Args:
|
|
473
|
-
window: the window to materialize
|
|
474
|
-
item_groups: the items from get_items
|
|
475
|
-
layer_name: the name of this layer
|
|
476
|
-
layer_cfg: the config of this layer
|
|
477
|
-
"""
|
|
478
|
-
RasterMaterializer().materialize(
|
|
479
|
-
TileStoreWithLayer(self, layer_name),
|
|
480
|
-
window,
|
|
481
|
-
layer_name,
|
|
482
|
-
layer_cfg,
|
|
483
|
-
item_groups,
|
|
484
|
-
)
|
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
"""Data sources backed by EarthDataHub-hosted datasets."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
import os
|
|
7
|
+
import tempfile
|
|
8
|
+
from datetime import UTC, datetime, timedelta
|
|
9
|
+
from typing import Any, Literal
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import rasterio
|
|
13
|
+
import shapely
|
|
14
|
+
import xarray as xr
|
|
15
|
+
from rasterio.transform import from_origin
|
|
16
|
+
from upath import UPath
|
|
17
|
+
|
|
18
|
+
from rslearn.config import QueryConfig, SpaceMode
|
|
19
|
+
from rslearn.const import WGS84_EPSG, WGS84_PROJECTION
|
|
20
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
21
|
+
from rslearn.log_utils import get_logger
|
|
22
|
+
from rslearn.tile_stores import TileStoreWithLayer
|
|
23
|
+
from rslearn.utils.geometry import STGeometry
|
|
24
|
+
|
|
25
|
+
logger = get_logger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _floor_to_utc_day(t: datetime) -> datetime:
|
|
29
|
+
"""Return the UTC day boundary (00:00) for a datetime.
|
|
30
|
+
|
|
31
|
+
If `t` is naive, it is treated as UTC. If `t` is timezone-aware, it is converted
|
|
32
|
+
to UTC before flooring.
|
|
33
|
+
"""
|
|
34
|
+
if t.tzinfo is None:
|
|
35
|
+
t = t.replace(tzinfo=UTC)
|
|
36
|
+
else:
|
|
37
|
+
t = t.astimezone(UTC)
|
|
38
|
+
return datetime(t.year, t.month, t.day, tzinfo=UTC)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _bounds_to_lon_ranges_0_360(
|
|
42
|
+
bounds: tuple[float, float, float, float],
|
|
43
|
+
) -> list[tuple[float, float]]:
|
|
44
|
+
"""Convert lon bounds to one or two non-wrapping ranges in [0, 360].
|
|
45
|
+
|
|
46
|
+
Expects non-wrapping bounds where min_lon <= max_lon.
|
|
47
|
+
"""
|
|
48
|
+
min_lon, _, max_lon, _ = bounds
|
|
49
|
+
|
|
50
|
+
if min_lon >= 0 and max_lon >= 0:
|
|
51
|
+
return [(min_lon, max_lon)]
|
|
52
|
+
if min_lon < 0 and max_lon < 0:
|
|
53
|
+
return [(min_lon + 360.0, max_lon + 360.0)]
|
|
54
|
+
|
|
55
|
+
# Bounds cross 0 degrees (e.g. [-5, 5]) which wraps in the 0..360 convention.
|
|
56
|
+
return [(min_lon + 360.0, 360.0), (0.0, max_lon)]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _snap_bounds_outward(
|
|
60
|
+
bounds: tuple[float, float, float, float],
|
|
61
|
+
step_degrees: float,
|
|
62
|
+
) -> tuple[float, float, float, float]:
|
|
63
|
+
"""Snap lon/lat bounds outward to a fixed grid.
|
|
64
|
+
|
|
65
|
+
ERA5(-Land) datasets are provided on a regular 0.1° lat/lon grid. When a window is
|
|
66
|
+
much smaller than the grid spacing, its geographic bounds may not contain any grid
|
|
67
|
+
*centers*, which leads to empty `xarray.sel(..., slice(...))` selections.
|
|
68
|
+
|
|
69
|
+
Snapping outward ensures at least one grid cell is selected when there is any
|
|
70
|
+
overlap.
|
|
71
|
+
"""
|
|
72
|
+
min_lon, min_lat, max_lon, max_lat = bounds
|
|
73
|
+
if min_lon > max_lon or min_lat > max_lat:
|
|
74
|
+
raise ValueError(f"invalid bounds: {bounds}")
|
|
75
|
+
|
|
76
|
+
min_lon = math.floor(min_lon / step_degrees) * step_degrees
|
|
77
|
+
min_lat = math.floor(min_lat / step_degrees) * step_degrees
|
|
78
|
+
max_lon = math.ceil(max_lon / step_degrees) * step_degrees
|
|
79
|
+
max_lat = math.ceil(max_lat / step_degrees) * step_degrees
|
|
80
|
+
|
|
81
|
+
# Ensure non-empty after snapping.
|
|
82
|
+
if max_lon == min_lon:
|
|
83
|
+
max_lon = min_lon + step_degrees
|
|
84
|
+
if max_lat == min_lat:
|
|
85
|
+
max_lat = min_lat + step_degrees
|
|
86
|
+
|
|
87
|
+
return (min_lon, min_lat, max_lon, max_lat)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ERA5LandDailyUTCv1(DataSource[Item]):
    """ERA5-Land daily UTC (v1) hosted on EarthDataHub.

    This data source reads from the EarthDataHub Zarr store and writes daily GeoTIFFs
    into the dataset tile store.

    Supported bands:
    - d2m: 2m dewpoint temperature (units: K)
    - e: evaporation (units: m of water equivalent)
    - pev: potential evaporation (units: m)
    - ro: runoff (units: m)
    - sp: surface pressure (units: Pa)
    - ssr: surface net short-wave (solar) radiation (units: J m-2)
    - ssrd: surface short-wave (solar) radiation downwards (units: J m-2)
    - str: surface net long-wave (thermal) radiation (units: J m-2)
    - swvl1: volumetric soil water layer 1 (units: m3 m-3)
    - swvl2: volumetric soil water layer 2 (units: m3 m-3)
    - t2m: 2m temperature (units: K, or degrees C if temperature_unit="celsius")
    - tp: total precipitation (units: m)
    - u10: 10m U wind component (units: m s-1)
    - v10: 10m V wind component (units: m s-1)

    Authentication:
        EarthDataHub uses token-based auth. Configure your netrc file so HTTP clients
        can attach the token automatically. On Linux and MacOS the netrc path is
        `~/.netrc`.
    """

    DEFAULT_ZARR_URL = (
        "https://data.earthdatahub.destine.eu/era5/era5-land-daily-utc-v1.zarr"
    )
    ALLOWED_BANDS = {
        "d2m",
        "e",
        "pev",
        "ro",
        "sp",
        "ssr",
        "ssrd",
        "str",
        "swvl1",
        "swvl2",
        "t2m",
        "tp",
        "u10",
        "v10",
    }
    # Spacing of the regular ERA5-Land lat/lon grid, in degrees.
    PIXEL_SIZE_DEGREES = 0.1

    def __init__(
        self,
        band_names: list[str] | None = None,
        zarr_url: str = DEFAULT_ZARR_URL,
        bounds: list[float] | None = None,
        temperature_unit: Literal["celsius", "kelvin"] = "kelvin",
        trust_env: bool = True,
        context: DataSourceContext = DataSourceContext(),
    ) -> None:
        """Initialize a new ERA5LandDailyUTCv1 instance.

        Args:
            band_names: list of bands to ingest. When a LayerConfig is provided via
                context, this argument is ignored and the bands are inferred from
                that layer's band sets (in order, without duplicates). It is
                required when no LayerConfig is available.
            zarr_url: URL/path to the EarthDataHub Zarr store.
            bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat]
                in degrees (WGS84). For best performance, set bounds to your area of
                interest. Longitude ranges crossing the dateline are not supported.
            temperature_unit: units to return for `t2m` ("celsius" or "kelvin").
                Other bands, including `d2m`, are returned unconverted.
            trust_env: if True (default), allow the underlying HTTP client to read
                environment configuration (including netrc) for auth/proxies.
            context: rslearn data source context.

        Raises:
            ValueError: if bounds are malformed, if no bands can be determined, or
                if an unsupported band is requested.
        """
        self.zarr_url = zarr_url
        if bounds is not None:
            if len(bounds) != 4:
                raise ValueError(
                    "ERA5LandDailyUTCv1 bounds must be [min_lon, min_lat, max_lon, max_lat] "
                    f"(got {bounds!r})."
                )
            min_lon, min_lat, max_lon, max_lat = bounds
            if min_lon > max_lon:
                raise ValueError(
                    "ERA5LandDailyUTCv1 does not yet support longitude ranges that cross the dateline "
                    f"(got bounds min_lon={min_lon}, max_lon={max_lon})."
                )
            if min_lat > max_lat:
                raise ValueError(
                    "ERA5LandDailyUTCv1 bounds must have min_lat <= max_lat "
                    f"(got bounds min_lat={min_lat}, max_lat={max_lat})."
                )
        self.bounds = bounds
        self.temperature_unit = temperature_unit
        self.trust_env = trust_env

        self.band_names: list[str]
        if context.layer_config is not None:
            # The layer config takes precedence over an explicit band_names argument.
            self.band_names = []
            for band_set in context.layer_config.band_sets:
                for band in band_set.bands:
                    if band not in self.band_names:
                        self.band_names.append(band)
        elif band_names is not None:
            self.band_names = band_names
        else:
            raise ValueError(
                "band_names must be set if layer_config is not in the context"
            )

        invalid_bands = [b for b in self.band_names if b not in self.ALLOWED_BANDS]
        if invalid_bands:
            raise ValueError(
                f"unsupported ERA5LandDailyUTCv1 band(s): {invalid_bands}; "
                f"supported: {sorted(self.ALLOWED_BANDS)}"
            )

        # Lazily opened dataset handle, see _get_dataset.
        self._ds: Any | None = None

    def _get_dataset(self) -> xr.Dataset:
        """Open (and memoize) the backing ERA5-Land Zarr dataset."""
        if self._ds is not None:
            return self._ds

        storage_options: dict[str, Any] | None = None
        if self.zarr_url.startswith("http://") or self.zarr_url.startswith("https://"):
            storage_options = {"client_kwargs": {"trust_env": self.trust_env}}

        self._ds = xr.open_dataset(
            self.zarr_url,
            engine="zarr",
            chunks=None,  # No dask
            storage_options=storage_options,
        )
        return self._ds

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get daily items intersecting the given geometries.

        Returns one item per UTC day that intersects each requested geometry time
        range. Each item's spatial extent is the configured bounds, or the whole
        globe if no bounds were configured.

        Raises:
            ValueError: if the space mode is not MOSAIC or a geometry has no time
                range.
        """
        if query_config.space_mode != SpaceMode.MOSAIC:
            raise ValueError("expected mosaic space mode in the query configuration")

        if self.bounds is not None:
            min_lon, min_lat, max_lon, max_lat = self.bounds
            item_shp = shapely.box(min_lon, min_lat, max_lon, max_lat)
        else:
            item_shp = shapely.box(-180, -90, 180, 90)

        all_groups: list[list[list[Item]]] = []
        for geometry in geometries:
            if geometry.time_range is None:
                raise ValueError("expected all geometries to have a time range")

            start, end = geometry.time_range
            # _floor_to_utc_day always returns an aware datetime; treat a naive end
            # as UTC as well so the comparison below does not mix naive and aware
            # datetimes (which raises TypeError).
            if end.tzinfo is None:
                end = end.replace(tzinfo=UTC)
            cur_day = _floor_to_utc_day(start)
            cur_groups: list[list[Item]] = []
            while cur_day < end:
                next_day = cur_day + timedelta(days=1)
                item_name = f"era5land_dailyutc_v1_{cur_day.year:04d}{cur_day.month:02d}{cur_day.day:02d}"
                item_geom = STGeometry(WGS84_PROJECTION, item_shp, (cur_day, next_day))
                cur_groups.append([Item(item_name, item_geom)])
                cur_day = next_day

            all_groups.append(cur_groups)

        return all_groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserialize an `Item` previously produced by this data source."""
        assert isinstance(serialized_item, dict)
        return Item.deserialize(serialized_item)

    def _get_effective_bounds(
        self, geometries: list[STGeometry]
    ) -> tuple[float, float, float, float]:
        """Compute an effective WGS84 bounding box for ingestion.

        If `self.bounds` is set, it is used as-is; otherwise, the bounds are derived
        from the union of the provided geometries (after projecting each to WGS84).
        With no geometries and no configured bounds, the returned box is inverted
        (min > max).
        """
        if self.bounds is not None:
            min_lon, min_lat, max_lon, max_lat = self.bounds
            return (min_lon, min_lat, max_lon, max_lat)

        min_lon = 180.0
        min_lat = 90.0
        max_lon = -180.0
        max_lat = -90.0
        for geom in geometries:
            wgs84 = geom.to_projection(WGS84_PROJECTION)
            b = wgs84.shp.bounds
            min_lon = min(min_lon, b[0])
            min_lat = min(min_lat, b[1])
            max_lon = max(max_lon, b[2])
            max_lat = max(max_lat, b[3])
        return (min_lon, min_lat, max_lon, max_lat)

    def _write_geotiff(
        self,
        tif_path: str,
        lat: np.ndarray,
        lon: np.ndarray,
        band_arrays: list[np.ndarray],
        min_lon: float | None = None,
    ) -> None:
        """Write a GeoTIFF with WGS84 georeferencing from ERA5(-Land) arrays.

        Args:
            tif_path: destination GeoTIFF path.
            lat: 1D latitude coordinate (expected descending, north-to-south).
            lon: 1D longitude coordinate (0..360 in the source dataset).
            band_arrays: band arrays with shape (lat, lon), one per output band.
            min_lon: optional western edge (degrees, -180..180 convention) of the
                requested area. Longitudes are wrapped into a 360° interval that
                starts half a pixel west of it, so areas touching the antimeridian
                stay contiguous. When omitted, longitudes are wrapped into
                [-180, 180).

        Raises:
            ValueError: if the coordinates are not 1D, no band arrays are given, or
                the latitude coordinate is ascending.
        """
        if lat.ndim != 1 or lon.ndim != 1:
            raise ValueError("expected 1D latitude/longitude coordinates")
        if len(band_arrays) == 0:
            raise ValueError("expected at least one band array")

        if min_lon is None:
            west_ref = -180.0
        else:
            # Half a pixel of slack so a column centered exactly on min_lon is not
            # wrapped around to the far eastern end of the interval.
            west_ref = min_lon - self.PIXEL_SIZE_DEGREES / 2
        # Wrap into [west_ref, west_ref + 360) by subtracting whole turns only, so
        # values already inside the interval are left bit-for-bit unchanged.
        lon = lon - 360.0 * np.floor((lon - west_ref) / 360.0)

        # Sort west-to-east and drop duplicated columns (e.g. the 180° meridian when
        # it was selected by both halves of a split longitude range).
        lon, keep_idx = np.unique(lon, return_index=True)
        band_arrays = [a[:, keep_idx] for a in band_arrays]

        if len(lon) > 1:
            dx = float(lon[1] - lon[0])
        else:
            dx = self.PIXEL_SIZE_DEGREES
        if len(lat) > 1:
            dy = float(abs(lat[1] - lat[0]))
        else:
            dy = self.PIXEL_SIZE_DEGREES

        # ERA5-Land latitude is descending (north to south). This matches GeoTIFF row order.
        if len(lat) > 1 and lat[1] > lat[0]:
            raise ValueError("expected latitude coordinate to be descending")

        # Coordinates are pixel centers; shift by half a pixel to get the corner.
        west = float(lon.min() - dx / 2)
        north = float(lat.max() + dy / 2)
        transform = from_origin(west, north, dx, dy)
        crs = f"EPSG:{WGS84_EPSG}"

        array = np.stack(band_arrays, axis=0).astype(np.float32)
        with rasterio.open(
            tif_path,
            "w",
            driver="GTiff",
            height=array.shape[1],
            width=array.shape[2],
            count=array.shape[0],
            dtype=array.dtype,
            crs=crs,
            transform=transform,
        ) as dst:
            dst.write(array)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest daily ERA5-Land rasters for the requested items/geometries.

        For each item (one UTC day), this reads the corresponding slice from the
        EarthDataHub Zarr store (subsetting by lat/lon for performance), then writes
        a GeoTIFF into the dataset tile store. Items whose raster is already in the
        tile store are skipped.

        Args:
            tile_store: the tile store to write rasters into.
            items: the items to ingest (one per UTC day).
            geometries: for each item, the geometries it must cover; used to derive
                the spatial subset when no `bounds` were configured.

        Raises:
            ValueError: if an item has no time range or the selection is empty.
        """
        ds = self._get_dataset()

        # Label slices in xarray include both endpoints, but snapped bounds and the
        # stored grid coordinates can differ by floating-point rounding (e.g.
        # 3 * 0.1 != 0.3). Pad the slices by a tiny fraction of a pixel so grid cells
        # lying exactly on a bound are not dropped.
        tol = self.PIXEL_SIZE_DEGREES * 1e-3

        for item, item_geoms in zip(items, geometries):
            if tile_store.is_raster_ready(item.name, self.band_names):
                continue

            if item.geometry.time_range is None:
                raise ValueError("expected item to have a time range")

            day_start = _floor_to_utc_day(item.geometry.time_range[0])
            day_str = f"{day_start.year:04d}-{day_start.month:02d}-{day_start.day:02d}"

            bounds = _snap_bounds_outward(
                self._get_effective_bounds(item_geoms),
                step_degrees=self.PIXEL_SIZE_DEGREES,
            )
            lon_ranges_0_360 = _bounds_to_lon_ranges_0_360(bounds)
            min_lat = bounds[1]
            max_lat = bounds[3]

            # Subset the dataset before loading, for performance. Latitude is
            # descending in the dataset, so the slice runs north to south.
            sel_kwargs_base: dict[str, Any] = dict(
                valid_time=day_str,
                latitude=slice(max_lat + tol, min_lat - tol),
            )

            band_arrays: list[np.ndarray] = []
            lat: np.ndarray | None = None
            lon: np.ndarray | None = None
            for band in self.band_names:
                parts = [
                    ds[band].sel(**sel_kwargs_base, longitude=slice(lo - tol, hi + tol))
                    for (lo, hi) in lon_ranges_0_360
                ]
                da = parts[0] if len(parts) == 1 else xr.concat(parts, dim="longitude")

                if band == "t2m" and self.temperature_unit == "celsius":
                    da = da - 273.15

                da = da.load()
                if lat is None:
                    lat = da["latitude"].to_numpy()
                    lon = da["longitude"].to_numpy()
                band_arrays.append(da.to_numpy())

            # lat/lon remain None only if there were no bands to read.
            if lat is None or lon is None or lat.size == 0 or lon.size == 0:
                raise ValueError(
                    f"ERA5LandDailyUTCv1 selection returned empty grid for item {item.name} "
                    f"(bounds={bounds})"
                )

            with tempfile.TemporaryDirectory() as tmp_dir:
                local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
                self._write_geotiff(
                    local_tif_fname, lat, lon, band_arrays, min_lon=bounds[0]
                )
                tile_store.write_raster_file(
                    item.name, self.band_names, UPath(local_tif_fname)
                )
|
|
@@ -5,7 +5,6 @@ import os
|
|
|
5
5
|
import tempfile
|
|
6
6
|
import zipfile
|
|
7
7
|
from datetime import UTC, datetime, timedelta
|
|
8
|
-
from typing import Any
|
|
9
8
|
|
|
10
9
|
import fiona
|
|
11
10
|
import requests
|
|
@@ -153,7 +152,7 @@ class EuroCrops(DataSource[EuroCropsItem]):
|
|
|
153
152
|
groups.append(cur_groups)
|
|
154
153
|
return groups
|
|
155
154
|
|
|
156
|
-
def deserialize_item(self, serialized_item: dict) -> EuroCropsItem:
    """Rebuild a EuroCropsItem from its JSON-decoded representation.

    Args:
        serialized_item: the JSON-decoded item data.

    Returns:
        the reconstructed EuroCropsItem.
    """
    restored = EuroCropsItem.deserialize(serialized_item)
    return restored
|
|
159
158
|
|