rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/models/concatenate_features.py +6 -1
  30. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  31. rslearn/train/data_module.py +27 -27
  32. rslearn/train/dataset.py +109 -62
  33. rslearn/train/lightning_module.py +1 -1
  34. rslearn/train/model_context.py +3 -3
  35. rslearn/train/prediction_writer.py +69 -41
  36. rslearn/train/tasks/classification.py +1 -1
  37. rslearn/train/tasks/detection.py +5 -5
  38. rslearn/train/tasks/regression.py +1 -1
  39. rslearn/utils/__init__.py +2 -0
  40. rslearn/utils/geometry.py +21 -0
  41. rslearn/utils/m2m_api.py +251 -0
  42. rslearn/utils/retry_session.py +43 -0
  43. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  44. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
  45. rslearn/data_sources/earthdata_srtm.py +0 -282
  46. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,425 @@
1
+ """Data sources backed by EarthDataHub-hosted datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import os
7
+ import tempfile
8
+ from datetime import UTC, datetime, timedelta
9
+ from typing import Any, Literal
10
+
11
+ import numpy as np
12
+ import rasterio
13
+ import shapely
14
+ import xarray as xr
15
+ from rasterio.transform import from_origin
16
+ from upath import UPath
17
+
18
+ from rslearn.config import QueryConfig, SpaceMode
19
+ from rslearn.const import WGS84_EPSG, WGS84_PROJECTION
20
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
21
+ from rslearn.log_utils import get_logger
22
+ from rslearn.tile_stores import TileStoreWithLayer
23
+ from rslearn.utils.geometry import STGeometry
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
+ def _floor_to_utc_day(t: datetime) -> datetime:
29
+ """Return the UTC day boundary (00:00) for a datetime.
30
+
31
+ If `t` is naive, it is treated as UTC. If `t` is timezone-aware, it is converted
32
+ to UTC before flooring.
33
+ """
34
+ if t.tzinfo is None:
35
+ t = t.replace(tzinfo=UTC)
36
+ else:
37
+ t = t.astimezone(UTC)
38
+ return datetime(t.year, t.month, t.day, tzinfo=UTC)
39
+
40
+
41
def _bounds_to_lon_ranges_0_360(
    bounds: tuple[float, float, float, float],
) -> list[tuple[float, float]]:
    """Express the longitude extent of `bounds` in the 0..360 convention.

    The input is (min_lon, min_lat, max_lon, max_lat) and is expected to be
    non-wrapping, i.e. min_lon <= max_lon. Only the longitudes are used.

    Returns:
        a list with one (lo, hi) range when both longitudes have the same sign,
        or two ranges when the extent straddles 0 degrees (which is a wrap point
        in the 0..360 convention).
    """
    west, _south, east, _north = bounds

    if west < 0 and east < 0:
        # Entirely in the western hemisphere: shift both ends by a full turn.
        return [(west + 360.0, east + 360.0)]
    if west >= 0 and east >= 0:
        # Already valid in 0..360.
        return [(west, east)]

    # Straddles 0 degrees (e.g. [-5, 5]): split into [355, 360] and [0, 5].
    return [(west + 360.0, 360.0), (0.0, east)]
57
+
58
+
59
def _snap_bounds_outward(
    bounds: tuple[float, float, float, float],
    step_degrees: float,
) -> tuple[float, float, float, float]:
    """Expand lon/lat bounds so every edge lies on a multiple of `step_degrees`.

    ERA5(-Land) datasets are provided on a regular 0.1° lat/lon grid. A window much
    smaller than the grid spacing may not contain any grid *centers*, in which case
    `xarray.sel(..., slice(...))` would select nothing. Flooring the minimums and
    ceiling the maximums guarantees at least one grid cell is selected whenever
    there is any overlap.

    Args:
        bounds: (min_lon, min_lat, max_lon, max_lat) in degrees.
        step_degrees: the grid spacing to snap to.

    Returns:
        the snapped (min_lon, min_lat, max_lon, max_lat); each axis spans at least
        one step.

    Raises:
        ValueError: if a minimum is greater than the corresponding maximum.
    """
    west, south, east, north = bounds
    if west > east or south > north:
        raise ValueError(f"invalid bounds: {bounds}")

    def snap_down(value: float) -> float:
        return math.floor(value / step_degrees) * step_degrees

    def snap_up(value: float) -> float:
        return math.ceil(value / step_degrees) * step_degrees

    west, south = snap_down(west), snap_down(south)
    east, north = snap_up(east), snap_up(north)

    # A bound lying exactly on a grid line snaps to a zero-width extent; widen it
    # by one step so the result is never empty.
    if east == west:
        east = west + step_degrees
    if north == south:
        north = south + step_degrees

    return (west, south, east, north)
88
+
89
+
90
class ERA5LandDailyUTCv1(DataSource[Item]):
    """ERA5-Land daily UTC (v1) hosted on EarthDataHub.

    This data source reads from the EarthDataHub Zarr store and writes daily GeoTIFFs
    into the dataset tile store. Each item corresponds to one UTC day.

    Supported bands:
    - d2m: 2m dewpoint temperature (units: K)
    - e: evaporation (units: m of water equivalent)
    - pev: potential evaporation (units: m)
    - ro: runoff (units: m)
    - sp: surface pressure (units: Pa)
    - ssr: surface net short-wave (solar) radiation (units: J m-2)
    - ssrd: surface short-wave (solar) radiation downwards (units: J m-2)
    - str: surface net long-wave (thermal) radiation (units: J m-2)
    - swvl1: volumetric soil water layer 1 (units: m3 m-3)
    - swvl2: volumetric soil water layer 2 (units: m3 m-3)
    - t2m: 2m temperature (units: K)
    - tp: total precipitation (units: m)
    - u10: 10m U wind component (units: m s-1)
    - v10: 10m V wind component (units: m s-1)

    Authentication:
        EarthDataHub uses token-based auth. Configure your netrc file so HTTP clients
        can attach the token automatically. On Linux and MacOS the netrc path is
        `~/.netrc`.
    """

    DEFAULT_ZARR_URL = (
        "https://data.earthdatahub.destine.eu/era5/era5-land-daily-utc-v1.zarr"
    )
    ALLOWED_BANDS = {
        "d2m",
        "e",
        "pev",
        "ro",
        "sp",
        "ssr",
        "ssrd",
        "str",
        "swvl1",
        "swvl2",
        "t2m",
        "tp",
        "u10",
        "v10",
    }
    PIXEL_SIZE_DEGREES = 0.1

    def __init__(
        self,
        band_names: list[str] | None = None,
        zarr_url: str = DEFAULT_ZARR_URL,
        bounds: list[float] | None = None,
        temperature_unit: Literal["celsius", "kelvin"] = "kelvin",
        trust_env: bool = True,
        context: DataSourceContext = DataSourceContext(),
    ) -> None:
        """Initialize a new ERA5LandDailyUTCv1 instance.

        Args:
            band_names: list of bands to ingest. If omitted and a LayerConfig is
                provided via context, bands are inferred from that layer's band sets.
                When a LayerConfig is present it takes precedence over this argument.
            zarr_url: URL/path to the EarthDataHub Zarr store.
            bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat]
                in degrees (WGS84). For best performance, set bounds to your area of
                interest.
            temperature_unit: units to return for `t2m` ("celsius" or "kelvin").
            trust_env: if True (default), allow the underlying HTTP client to read
                environment configuration (including netrc) for auth/proxies.
            context: rslearn data source context.

        Raises:
            ValueError: if `bounds` is malformed, if no bands can be determined, or
                if any requested band is not in `ALLOWED_BANDS`.
        """
        self.zarr_url = zarr_url
        if bounds is not None:
            if len(bounds) != 4:
                raise ValueError(
                    "ERA5LandDailyUTCv1 bounds must be [min_lon, min_lat, max_lon, max_lat] "
                    f"(got {bounds!r})."
                )
            min_lon, min_lat, max_lon, max_lat = bounds
            if min_lon > max_lon:
                raise ValueError(
                    "ERA5LandDailyUTCv1 does not yet support longitude ranges that cross the dateline "
                    f"(got bounds min_lon={min_lon}, max_lon={max_lon})."
                )
            if min_lat > max_lat:
                raise ValueError(
                    "ERA5LandDailyUTCv1 bounds must have min_lat <= max_lat "
                    f"(got bounds min_lat={min_lat}, max_lat={max_lat})."
                )
        self.bounds = bounds
        self.temperature_unit = temperature_unit
        self.trust_env = trust_env

        self.band_names: list[str]
        if context.layer_config is not None:
            # Collect the bands of every band set, keeping first-seen order and
            # dropping repeats.
            self.band_names = []
            for band_set in context.layer_config.band_sets:
                for band in band_set.bands:
                    if band not in self.band_names:
                        self.band_names.append(band)
        elif band_names is not None:
            self.band_names = band_names
        else:
            raise ValueError(
                "band_names must be set if layer_config is not in the context"
            )

        invalid_bands = [b for b in self.band_names if b not in self.ALLOWED_BANDS]
        if invalid_bands:
            raise ValueError(
                f"unsupported ERA5LandDailyUTCv1 band(s): {invalid_bands}; "
                f"supported: {sorted(self.ALLOWED_BANDS)}"
            )

        # Lazily opened dataset handle; populated on first _get_dataset() call.
        self._ds: Any | None = None

    def _get_dataset(self) -> xr.Dataset:
        """Open (and memoize) the backing ERA5-Land Zarr dataset.

        For http(s) URLs, `trust_env` is forwarded to the HTTP client through
        `storage_options`; for any other URL/path no storage options are passed.
        """
        if self._ds is not None:
            return self._ds

        storage_options: dict[str, Any] | None = None
        if self.zarr_url.startswith(("http://", "https://")):
            storage_options = {"client_kwargs": {"trust_env": self.trust_env}}

        self._ds = xr.open_dataset(
            self.zarr_url,
            engine="zarr",
            chunks=None,  # No dask
            storage_options=storage_options,
        )
        return self._ds

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get daily items intersecting the given geometries.

        Returns one item per UTC day that intersects each requested geometry time
        range. Every item shares the same spatial footprint: the configured
        `bounds` if set, otherwise the whole globe.

        Args:
            geometries: the geometries to look up; each must have a time range.
            query_config: the query configuration; only MOSAIC space mode is
                supported.

        Returns:
            for each geometry, a list of single-item groups (one group per day).

        Raises:
            ValueError: if the space mode is not MOSAIC or a geometry has no time
                range.
        """
        if query_config.space_mode != SpaceMode.MOSAIC:
            raise ValueError("expected mosaic space mode in the query configuration")

        if self.bounds is not None:
            min_lon, min_lat, max_lon, max_lat = self.bounds
            item_shp = shapely.box(min_lon, min_lat, max_lon, max_lat)
        else:
            item_shp = shapely.box(-180, -90, 180, 90)

        all_groups: list[list[list[Item]]] = []
        for geometry in geometries:
            if geometry.time_range is None:
                raise ValueError("expected all geometries to have a time range")

            start, end = geometry.time_range
            # The day boundaries below are timezone-aware (UTC). A naive `end` is
            # treated as UTC, matching how a naive `start` is handled when it is
            # floored; otherwise the aware/naive comparison raises TypeError.
            if end.tzinfo is None:
                end = end.replace(tzinfo=UTC)

            cur_day = _floor_to_utc_day(start)
            cur_groups: list[list[Item]] = []
            while cur_day < end:
                next_day = cur_day + timedelta(days=1)
                item_name = f"era5land_dailyutc_v1_{cur_day.year:04d}{cur_day.month:02d}{cur_day.day:02d}"
                item_geom = STGeometry(WGS84_PROJECTION, item_shp, (cur_day, next_day))
                cur_groups.append([Item(item_name, item_geom)])
                cur_day = next_day

            all_groups.append(cur_groups)

        return all_groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserialize an `Item` previously produced by this data source.

        Args:
            serialized_item: the JSON-decoded item; must be a dict.
        """
        assert isinstance(serialized_item, dict)
        return Item.deserialize(serialized_item)

    def _get_effective_bounds(
        self, geometries: list[STGeometry]
    ) -> tuple[float, float, float, float]:
        """Compute an effective WGS84 bounding box for ingestion.

        If `self.bounds` is set, it is used as-is; otherwise, the bounds are derived
        from the union of the provided geometries (after projecting each to WGS84).

        Returns:
            (min_lon, min_lat, max_lon, max_lat). With no `self.bounds` and an empty
            `geometries` list the result is inverted (min > max).
        """
        if self.bounds is not None:
            min_lon, min_lat, max_lon, max_lat = self.bounds
            return (min_lon, min_lat, max_lon, max_lat)

        # Start from an inverted box so the first geometry initializes every edge.
        min_lon = 180.0
        min_lat = 90.0
        max_lon = -180.0
        max_lat = -90.0
        for geom in geometries:
            wgs84 = geom.to_projection(WGS84_PROJECTION)
            b = wgs84.shp.bounds
            min_lon = min(min_lon, b[0])
            min_lat = min(min_lat, b[1])
            max_lon = max(max_lon, b[2])
            max_lat = max(max_lat, b[3])
        return (min_lon, min_lat, max_lon, max_lat)

    def _write_geotiff(
        self,
        tif_path: str,
        lat: np.ndarray,
        lon: np.ndarray,
        band_arrays: list[np.ndarray],
    ) -> None:
        """Write a GeoTIFF with WGS84 georeferencing from ERA5(-Land) arrays.

        Args:
            tif_path: destination GeoTIFF path.
            lat: 1D latitude coordinate (expected descending, north-to-south).
            lon: 1D longitude coordinate (0..360 in the source dataset).
            band_arrays: band arrays with shape (lat, lon), one per output band.

        Raises:
            ValueError: if the coordinates are not 1D, no band arrays are given, or
                the latitude coordinate is ascending.
        """
        if lat.ndim != 1 or lon.ndim != 1:
            raise ValueError("expected 1D latitude/longitude coordinates")
        if len(band_arrays) == 0:
            raise ValueError("expected at least one band array")

        # Convert longitude to [-180, 180) and reorder so GeoTIFF coordinates match
        # common WGS84 conventions and rslearn windows.
        lon = ((lon + 180) % 360) - 180
        # np.unique sorts ascending and drops exact duplicates. A duplicate column
        # can appear when the selection spans the full longitude range, because the
        # two 0..360 slices (e.g. 180..360 and 0..180) are label-based and may both
        # include the same grid column; a repeated value would make dx zero below.
        lon, lon_idx = np.unique(lon, return_index=True)
        band_arrays = [a[:, lon_idx] for a in band_arrays]

        # With a single row/column there is no spacing to measure, so fall back to
        # the nominal grid resolution.
        if len(lon) > 1:
            dx = float(lon[1] - lon[0])
        else:
            dx = self.PIXEL_SIZE_DEGREES
        if len(lat) > 1:
            dy = float(abs(lat[1] - lat[0]))
        else:
            dy = self.PIXEL_SIZE_DEGREES

        # ERA5-Land latitude is descending (north to south). This matches GeoTIFF row order.
        if len(lat) > 1 and lat[1] > lat[0]:
            raise ValueError("expected latitude coordinate to be descending")

        # Coordinates are cell centers; the transform origin is the outer corner of
        # the north-west cell, hence the half-pixel offsets.
        west = float(lon.min() - dx / 2)
        north = float(lat.max() + dy / 2)
        transform = from_origin(west, north, dx, dy)
        crs = f"EPSG:{WGS84_EPSG}"

        array = np.stack(band_arrays, axis=0).astype(np.float32)
        with rasterio.open(
            tif_path,
            "w",
            driver="GTiff",
            height=array.shape[1],
            width=array.shape[2],
            count=array.shape[0],
            dtype=array.dtype,
            crs=crs,
            transform=transform,
        ) as dst:
            dst.write(array)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest daily ERA5-Land rasters for the requested items/geometries.

        For each item (one UTC day), this reads the corresponding slice from the
        EarthDataHub Zarr store (subsetting by lat/lon for performance), then writes
        a GeoTIFF into the dataset tile store. Items that the tile store already
        reports as ready are skipped.

        Args:
            tile_store: the tile store to write the rasters into.
            items: the items to ingest.
            geometries: for each item, the geometries that need it; used to derive
                the spatial subset when `bounds` was not configured.

        Raises:
            ValueError: if an item has no time range, no bands are configured, or
                the spatial selection is empty.
        """
        ds = self._get_dataset()

        for item, item_geoms in zip(items, geometries):
            if tile_store.is_raster_ready(item.name, self.band_names):
                continue

            if item.geometry.time_range is None:
                raise ValueError("expected item to have a time range")

            day_start = _floor_to_utc_day(item.geometry.time_range[0])
            day_str = f"{day_start.year:04d}-{day_start.month:02d}-{day_start.day:02d}"

            bounds = _snap_bounds_outward(
                self._get_effective_bounds(item_geoms),
                step_degrees=self.PIXEL_SIZE_DEGREES,
            )
            lon_ranges_0_360 = _bounds_to_lon_ranges_0_360(bounds)
            min_lat = bounds[1]
            max_lat = bounds[3]

            # Subset the dataset before computing, for performance.
            # Latitude is descending in the dataset.
            sel_kwargs_base: dict[str, Any] = dict(
                valid_time=day_str,
                latitude=slice(max_lat, min_lat),
            )

            band_arrays: list[np.ndarray] = []
            lat: np.ndarray | None = None
            lon: np.ndarray | None = None
            for band in self.band_names:
                if len(lon_ranges_0_360) == 1:
                    da = ds[band].sel(
                        **sel_kwargs_base,
                        longitude=slice(lon_ranges_0_360[0][0], lon_ranges_0_360[0][1]),
                    )
                else:
                    # The extent straddles 0 degrees: read both pieces and stitch
                    # them along longitude.
                    parts = [
                        ds[band].sel(**sel_kwargs_base, longitude=slice(lo, hi))
                        for (lo, hi) in lon_ranges_0_360
                    ]
                    da = xr.concat(parts, dim="longitude")

                if band == "t2m" and self.temperature_unit == "celsius":
                    da = da - 273.15

                da = da.load()
                if lat is None:
                    # Coordinates are taken from the first band only.
                    lat = da["latitude"].to_numpy()
                    lon = da["longitude"].to_numpy()
                band_arrays.append(da.to_numpy())

            # Explicit check rather than assert so it still runs under `python -O`.
            if lat is None or lon is None:
                raise ValueError(
                    "ERA5LandDailyUTCv1 has no bands configured, so there is "
                    f"nothing to ingest for item {item.name}"
                )
            if lat.size == 0 or lon.size == 0:
                raise ValueError(
                    f"ERA5LandDailyUTCv1 selection returned empty grid for item {item.name} "
                    f"(bounds={bounds})"
                )

            with tempfile.TemporaryDirectory() as tmp_dir:
                local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
                self._write_geotiff(local_tif_fname, lat, lon, band_arrays)
                tile_store.write_raster_file(
                    item.name, self.band_names, UPath(local_tif_fname)
                )
@@ -5,7 +5,6 @@ import os
5
5
  import tempfile
6
6
  import zipfile
7
7
  from datetime import UTC, datetime, timedelta
8
- from typing import Any
9
8
 
10
9
  import fiona
11
10
  import requests
@@ -153,7 +152,7 @@ class EuroCrops(DataSource[EuroCropsItem]):
153
152
  groups.append(cur_groups)
154
153
  return groups
155
154
 
156
- def deserialize_item(self, serialized_item: Any) -> EuroCropsItem:
155
+ def deserialize_item(self, serialized_item: dict) -> EuroCropsItem:
157
156
  """Deserializes an item from JSON-decoded data."""
158
157
  return EuroCropsItem.deserialize(serialized_item)
159
158
 
@@ -820,9 +820,8 @@ class Sentinel2(DataSource):
820
820
  groups.append(cur_groups)
821
821
  return groups
822
822
 
823
- def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
823
+ def deserialize_item(self, serialized_item: dict) -> Sentinel2Item:
824
824
  """Deserializes an item from JSON-decoded data."""
825
- assert isinstance(serialized_item, dict)
826
825
  return Sentinel2Item.deserialize(serialized_item)
827
826
 
828
827
  def retrieve_item(
@@ -235,9 +235,8 @@ class GEE(DataSource, TileStore):
235
235
 
236
236
  return groups
237
237
 
238
- def deserialize_item(self, serialized_item: Any) -> Item:
238
+ def deserialize_item(self, serialized_item: dict) -> Item:
239
239
  """Deserializes an item from JSON-decoded data."""
240
- assert isinstance(serialized_item, dict)
241
240
  return Item.deserialize(serialized_item)
242
241
 
243
242
  def item_to_image(self, item: Item) -> ee.image.Image: