rslearn 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/aws_open_data.py +11 -15
- rslearn/data_sources/aws_sentinel2_element84.py +374 -0
- rslearn/data_sources/climate_data_store.py +216 -29
- rslearn/data_sources/gcp_public_data.py +16 -0
- rslearn/data_sources/planetary_computer.py +28 -257
- rslearn/data_sources/soilgrids.py +331 -0
- rslearn/data_sources/stac.py +255 -0
- rslearn/models/attention_pooling.py +5 -2
- rslearn/models/olmoearth_pretrain/model.py +7 -8
- rslearn/train/dataset.py +44 -35
- rslearn/train/lightning_module.py +3 -3
- rslearn/train/tasks/embedding.py +2 -2
- rslearn/train/tasks/task.py +4 -2
- rslearn/utils/geometry.py +2 -2
- rslearn/utils/stac.py +173 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/METADATA +4 -1
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/RECORD +22 -18
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/WHEEL +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/top_level.txt +0 -0
rslearn/data_sources/climate_data_store.py
@@ -24,51 +24,55 @@ from rslearn.utils.geometry import STGeometry
 logger = get_logger(__name__)


-class
-    """
+class ERA5Land(DataSource):
+    """Base class for ingesting ERA5 land data from the Copernicus Climate Data Store.

     An API key must be passed either in the configuration or via the CDSAPI_KEY
     environment variable. You can acquire an API key by going to the Climate Data Store
     website (https://cds.climate.copernicus.eu/), registering an account and logging
-    in, and then
+    in, and then getting the API key from the user profile page.

     The band names should match CDS variable names (see the reference at
     https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation). However,
     replace "_" with "-" in the variable names when specifying bands in the layer
     configuration.

-
-
-    All requests to the API will be for the whole globe. Although the API supports arbitrary
-    bounds in the requests, using the whole available area helps to reduce the total number of
-    requests.
+    By default, all requests to the API will be for the whole globe. To speed up ingestion,
+    we recommend specifying the bounds of the area of interest, in particular for hourly data.
     """

     api_url = "https://cds.climate.copernicus.eu/api"
-
-    # see: https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land-monthly-means
-    DATASET = "reanalysis-era5-land-monthly-means"
-    PRODUCT_TYPE = "monthly_averaged_reanalysis"
     DATA_FORMAT = "netcdf"
     DOWNLOAD_FORMAT = "unarchived"
     PIXEL_SIZE = 0.1  # degrees, native resolution is 9km

     def __init__(
         self,
+        dataset: str,
+        product_type: str,
         band_names: list[str] | None = None,
         api_key: str | None = None,
+        bounds: list[float] | None = None,
         context: DataSourceContext = DataSourceContext(),
     ):
-        """Initialize a new
+        """Initialize a new ERA5Land instance.

         Args:
+            dataset: the CDS dataset name (e.g., "reanalysis-era5-land-monthly-means").
+            product_type: the CDS product type (e.g., "monthly_averaged_reanalysis").
             band_names: list of band names to acquire. These should correspond to CDS
                 variable names but with "_" replaced with "-". This will only be used
                 if the layer config is missing from the context.
             api_key: the API key. If not set, it should be set via the CDSAPI_KEY
                 environment variable.
+            bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
+                If not specified, the whole globe will be used.
             context: the data source context.
         """
+        self.dataset = dataset
+        self.product_type = product_type
+        self.bounds = bounds
+
         self.band_names: list[str]
         if context.layer_config is not None:
             self.band_names = []
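The docstring above describes the new configuration surface (API key, CDS-style band names, optional bounds). As a rough usage sketch only — the band names and bounding box below are example values, not defaults shipped with rslearn — constructing one of the new sources might look like:

```python
# Hypothetical usage sketch based on the constructor shown in this diff.
from rslearn.data_sources.climate_data_store import ERA5LandMonthlyMeans

source = ERA5LandMonthlyMeans(
    # CDS variable names with "_" replaced by "-", per the docstring above.
    band_names=["2m-temperature", "total-precipitation"],
    # api_key omitted: falls back to the CDSAPI_KEY environment variable.
    # [min_lon, min_lat, max_lon, max_lat]; omit to request the whole globe.
    bounds=[-10.0, 35.0, 30.0, 60.0],
)
```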
@@ -134,8 +138,11 @@ class ERA5LandMonthlyMeans(DataSource):
             # Collect Item list corresponding to the current month.
             items = []
             item_name = f"era5land_monthlyaveraged_{cur_date.year}_{cur_date.month}"
-            #
-            bounds
+            # Use bounds if set, otherwise use whole globe
+            if self.bounds is not None:
+                bounds = self.bounds
+            else:
+                bounds = [-180, -90, 180, 90]
             # Time is just the given month.
             start_date = datetime(cur_date.year, cur_date.month, 1, tzinfo=UTC)
             time_range = (
@@ -172,7 +179,9 @@ class ERA5LandMonthlyMeans(DataSource):
         # But the list of variables should include the bands we want in the correct
         # order. And we can distinguish those bands from other "variables" because they
         # will be 3D while the others will be scalars or 1D.
-
+
+        band_arrays = []
+        num_time_steps = None
         for band_name in nc.variables:
             band_data = nc.variables[band_name]
             if len(band_data.shape) != 3:
@@ -182,18 +191,27 @@ class ERA5LandMonthlyMeans(DataSource):
             logger.debug(
                 f"NC file {nc_path} has variable {band_name} with shape {band_data.shape}"
             )
-            # Variable data is stored in a 3D array (
-
+            # Variable data is stored in a 3D array (time, height, width)
+            # For hourly data, time is number of days in the month x 24 hours
+            if num_time_steps is None:
+                num_time_steps = band_data.shape[0]
+            elif band_data.shape[0] != num_time_steps:
                 raise ValueError(
-                    f"
+                    f"Variable {band_name} has {band_data.shape[0]} time steps, "
+                    f"but expected {num_time_steps}"
                 )
-
+            # Original shape: (time, height, width)
+            band_array = np.array(band_data[:])
+            band_array = np.expand_dims(band_array, axis=1)
+            band_arrays.append(band_array)

-
-
-
-
-
+        # After concatenation: (time, num_variables, height, width)
+        stacked_array = np.concatenate(band_arrays, axis=1)
+
+        # After reshaping: (time x num_variables, height, width)
+        array = stacked_array.reshape(
+            -1, stacked_array.shape[2], stacked_array.shape[3]
+        )

         # Get metadata for the GeoTIFF
         lat = nc.variables["latitude"][:]
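The comments above track the array shapes through the new stacking logic: each variable is read as (time, height, width), given a singleton variable axis, concatenated, and then flattened so that time and variable together form the band dimension. A minimal numpy sketch of the same bookkeeping (array sizes are arbitrary example values):

```python
import numpy as np

# Two fake variables, each shaped (time, height, width) as if read from the NetCDF file.
time, height, width = 3, 4, 5
band_arrays = []
for _ in range(2):
    band_array = np.zeros((time, height, width))
    band_array = np.expand_dims(band_array, axis=1)  # (time, 1, height, width)
    band_arrays.append(band_array)

stacked = np.concatenate(band_arrays, axis=1)  # (time, num_variables, height, width)
array = stacked.reshape(-1, stacked.shape[2], stacked.shape[3])
print(array.shape)  # (6, 4, 5): bands ordered t0/var0, t0/var1, t1/var0, ...
```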
@@ -235,6 +253,58 @@ class ERA5LandMonthlyMeans(DataSource):
         ) as dst:
             dst.write(array)

+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[Item],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        This method should be overridden by subclasses.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        raise NotImplementedError("Subclasses must implement ingest method")
+
+
+class ERA5LandMonthlyMeans(ERA5Land):
+    """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
+
+    This data source corresponds to the reanalysis-era5-land-monthly-means product.
+    """
+
+    def __init__(
+        self,
+        band_names: list[str] | None = None,
+        api_key: str | None = None,
+        bounds: list[float] | None = None,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Initialize a new ERA5LandMonthlyMeans instance.
+
+        Args:
+            band_names: list of band names to acquire. These should correspond to CDS
+                variable names but with "_" replaced with "-". This will only be used
+                if the layer config is missing from the context.
+            api_key: the API key. If not set, it should be set via the CDSAPI_KEY
+                environment variable.
+            bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
+                If not specified, the whole globe will be used.
+            context: the data source context.
+        """
+        super().__init__(
+            dataset="reanalysis-era5-land-monthly-means",
+            product_type="monthly_averaged_reanalysis",
+            band_names=band_names,
+            api_key=api_key,
+            bounds=bounds,
+            context=context,
+        )
+
     def ingest(
         self,
         tile_store: TileStoreWithLayer,
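With this refactor, the base class owns the CDS client plumbing while each concrete source only binds a dataset and product type and supplies its own ingest. A hypothetical further subclass (not part of this release; the class, dataset, and product names are placeholders) would follow the same pattern:

```python
# Hypothetical example only: how another CDS product could sit on the new ERA5Land base class.
class SomeOtherCDSProduct(ERA5Land):
    """Example subclass binding a fixed dataset and product type."""

    def __init__(
        self,
        band_names: list[str] | None = None,
        api_key: str | None = None,
        bounds: list[float] | None = None,
        context: DataSourceContext = DataSourceContext(),
    ):
        super().__init__(
            dataset="some-cds-dataset-name",  # placeholder
            product_type="some-product-type",  # placeholder
            band_names=band_names,
            api_key=api_key,
            bounds=bounds,
            context=context,
        )

    def ingest(self, tile_store, items, geometries) -> None:
        # Each subclass builds its own CDS request here.
        ...
```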
@@ -256,25 +326,142 @@ class ERA5LandMonthlyMeans(DataSource):
                 continue

             # Send the request to the CDS API
-
+            if self.bounds is not None:
+                min_lon, min_lat, max_lon, max_lat = self.bounds
+                area = [max_lat, min_lon, min_lat, max_lon]
+            else:
+                area = [90, -180, -90, 180]  # Whole globe
+
             request = {
-                "product_type": [self.
+                "product_type": [self.product_type],
                 "variable": variable_names,
                 "year": [f"{item.geometry.time_range[0].year}"],  # type: ignore
                 "month": [
                     f"{item.geometry.time_range[0].month:02d}"  # type: ignore
                 ],
                 "time": ["00:00"],
+                "area": area,
+                "data_format": self.DATA_FORMAT,
+                "download_format": self.DOWNLOAD_FORMAT,
+            }
+            logger.debug(
+                f"CDS API request for year={request['year']} month={request['month']} area={area}"
+            )
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
+                local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
+                self.client.retrieve(self.dataset, request, local_nc_fname)
+                self._convert_nc_to_tif(
+                    UPath(local_nc_fname),
+                    UPath(local_tif_fname),
+                )
+                tile_store.write_raster_file(
+                    item.name, self.band_names, UPath(local_tif_fname)
+                )
+
+
+class ERA5LandHourly(ERA5Land):
+    """A data source for ingesting ERA5 land hourly data from the Copernicus Climate Data Store.
+
+    This data source corresponds to the reanalysis-era5-land product.
+    """
+
+    def __init__(
+        self,
+        band_names: list[str] | None = None,
+        api_key: str | None = None,
+        bounds: list[float] | None = None,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Initialize a new ERA5LandHourly instance.
+
+        Args:
+            band_names: list of band names to acquire. These should correspond to CDS
+                variable names but with "_" replaced with "-". This will only be used
+                if the layer config is missing from the context.
+            api_key: the API key. If not set, it should be set via the CDSAPI_KEY
+                environment variable.
+            bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
+                If not specified, the whole globe will be used.
+            context: the data source context.
+        """
+        super().__init__(
+            dataset="reanalysis-era5-land",
+            product_type="reanalysis",
+            band_names=band_names,
+            api_key=api_key,
+            bounds=bounds,
+            context=context,
+        )
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[Item],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        # for CDS variable names, replace "-" with "_"
+        variable_names = [band.replace("-", "_") for band in self.band_names]
+
+        for item in items:
+            if tile_store.is_raster_ready(item.name, self.band_names):
+                continue
+
+            # Send the request to the CDS API
+            # If area is not specified, the whole globe will be requested
+            time_range = item.geometry.time_range
+            if time_range is None:
+                raise ValueError("Item must have a time range")
+
+            # For hourly data, request all days in the month and all 24 hours
+            start_time = time_range[0]
+
+            # Get all days in the month
+            year = start_time.year
+            month = start_time.month
+            # Get the last day of the month
+            if month == 12:
+                last_day = 31
+            else:
+                next_month = datetime(year, month + 1, 1, tzinfo=UTC)
+                last_day = (next_month - relativedelta(days=1)).day
+
+            days = [f"{day:02d}" for day in range(1, last_day + 1)]
+
+            # Get all 24 hours
+            hours = [f"{hour:02d}:00" for hour in range(24)]
+
+            if self.bounds is not None:
+                min_lon, min_lat, max_lon, max_lat = self.bounds
+                area = [max_lat, min_lon, min_lat, max_lon]
+            else:
+                area = [90, -180, -90, 180]  # Whole globe
+
+            request = {
+                "product_type": [self.product_type],
+                "variable": variable_names,
+                "year": [f"{year}"],
+                "month": [f"{month:02d}"],
+                "day": days,
+                "time": hours,
+                "area": area,
                 "data_format": self.DATA_FORMAT,
                 "download_format": self.DOWNLOAD_FORMAT,
             }
             logger.debug(
-                f"CDS API request for
+                f"CDS API request for year={request['year']} month={request['month']} days={len(days)} hours={len(hours)} area={area}"
             )
             with tempfile.TemporaryDirectory() as tmp_dir:
                 local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
                 local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
-                self.client.retrieve(self.
+                self.client.retrieve(self.dataset, request, local_nc_fname)
                 self._convert_nc_to_tif(
                     UPath(local_nc_fname),
                     UPath(local_tif_fname),
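Two details of the hourly request are easy to get wrong: the bounds accepted by the data source are [min_lon, min_lat, max_lon, max_lat], but the CDS "area" parameter is ordered [North, West, South, East], and the day/time lists must enumerate the whole month. A self-contained sketch of that conversion (the bounds and month are example values):

```python
from datetime import UTC, datetime
from dateutil.relativedelta import relativedelta

# Example bounds in [min_lon, min_lat, max_lon, max_lat] order.
bounds = [-10.0, 35.0, 30.0, 60.0]
min_lon, min_lat, max_lon, max_lat = bounds
area = [max_lat, min_lon, min_lat, max_lon]  # CDS expects [North, West, South, East]

# All days of the request month and all 24 hours, mirroring the logic in the diff.
year, month = 2024, 2
if month == 12:
    last_day = 31
else:
    next_month = datetime(year, month + 1, 1, tzinfo=UTC)
    last_day = (next_month - relativedelta(days=1)).day

days = [f"{day:02d}" for day in range(1, last_day + 1)]
hours = [f"{hour:02d}:00" for hour in range(24)]
print(area, len(days), len(hours))  # [60.0, -10.0, 35.0, 30.0] 29 24
```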
rslearn/data_sources/gcp_public_data.py
@@ -765,6 +765,22 @@ class Sentinel2(DataSource):
             for item in self._read_bigquery(
                 time_range=geometry.time_range, wgs84_bbox=wgs84_bbox
             ):
+                # Get the item from XML to get its exact geometry (BigQuery only knows
+                # the bounding box of the item).
+                try:
+                    item = self.get_item_by_name(item.name)
+                except CorruptItemException as e:
+                    logger.warning("skipping corrupt item %s: %s", item.name, e.message)
+                    continue
+                except MissingXMLException:
+                    # Sometimes a scene that appears in the BigQuery index does not
+                    # actually have an XML file on GCS. Since we know this happens
+                    # occasionally, we ignore the error here.
+                    logger.warning(
+                        "skipping item %s that is missing XML file", item.name
+                    )
+                    continue
+
                 candidates[idx].append(item)

         return candidates