rslearn 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,51 +24,55 @@ from rslearn.utils.geometry import STGeometry
24
24
  logger = get_logger(__name__)
25
25
 
26
26
 
27
- class ERA5LandMonthlyMeans(DataSource):
28
- """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
27
+ class ERA5Land(DataSource):
28
+ """Base class for ingesting ERA5 land data from the Copernicus Climate Data Store.
29
29
 
30
30
  An API key must be passed either in the configuration or via the CDSAPI_KEY
31
31
  environment variable. You can acquire an API key by going to the Climate Data Store
32
32
  website (https://cds.climate.copernicus.eu/), registering an account and logging
33
- in, and then
33
+ in, and then getting the API key from the user profile page.
34
34
 
35
35
  The band names should match CDS variable names (see the reference at
36
36
  https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation). However,
37
37
  replace "_" with "-" in the variable names when specifying bands in the layer
38
38
  configuration.
39
39
 
40
- This data source corresponds to the reanalysis-era5-land-monthly-means product.
41
-
42
- All requests to the API will be for the whole globe. Although the API supports arbitrary
43
- bounds in the requests, using the whole available area helps to reduce the total number of
44
- requests.
40
+ By default, all requests to the API will be for the whole globe. To speed up ingestion,
41
+ we recommend specifying the bounds of the area of interest, in particular for hourly data.
45
42
  """
46
43
 
47
44
  api_url = "https://cds.climate.copernicus.eu/api"
48
-
49
- # see: https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land-monthly-means
50
- DATASET = "reanalysis-era5-land-monthly-means"
51
- PRODUCT_TYPE = "monthly_averaged_reanalysis"
52
45
  DATA_FORMAT = "netcdf"
53
46
  DOWNLOAD_FORMAT = "unarchived"
54
47
  PIXEL_SIZE = 0.1 # degrees, native resolution is 9km
55
48
 
56
49
  def __init__(
57
50
  self,
51
+ dataset: str,
52
+ product_type: str,
58
53
  band_names: list[str] | None = None,
59
54
  api_key: str | None = None,
55
+ bounds: list[float] | None = None,
60
56
  context: DataSourceContext = DataSourceContext(),
61
57
  ):
62
- """Initialize a new ERA5LandMonthlyMeans instance.
58
+ """Initialize a new ERA5Land instance.
63
59
 
64
60
  Args:
61
+ dataset: the CDS dataset name (e.g., "reanalysis-era5-land-monthly-means").
62
+ product_type: the CDS product type (e.g., "monthly_averaged_reanalysis").
65
63
  band_names: list of band names to acquire. These should correspond to CDS
66
64
  variable names but with "_" replaced with "-". This will only be used
67
65
  if the layer config is missing from the context.
68
66
  api_key: the API key. If not set, it should be set via the CDSAPI_KEY
69
67
  environment variable.
68
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
69
+ If not specified, the whole globe will be used.
70
70
  context: the data source context.
71
71
  """
72
+ self.dataset = dataset
73
+ self.product_type = product_type
74
+ self.bounds = bounds
75
+
72
76
  self.band_names: list[str]
73
77
  if context.layer_config is not None:
74
78
  self.band_names = []
@@ -134,8 +138,11 @@ class ERA5LandMonthlyMeans(DataSource):
134
138
  # Collect Item list corresponding to the current month.
135
139
  items = []
136
140
  item_name = f"era5land_monthlyaveraged_{cur_date.year}_{cur_date.month}"
137
- # Space is the whole globe.
138
- bounds = (-180, -90, 180, 90)
141
+ # Use bounds if set, otherwise use whole globe
142
+ if self.bounds is not None:
143
+ bounds = self.bounds
144
+ else:
145
+ bounds = [-180, -90, 180, 90]
139
146
  # Time is just the given month.
140
147
  start_date = datetime(cur_date.year, cur_date.month, 1, tzinfo=UTC)
141
148
  time_range = (
@@ -172,7 +179,9 @@ class ERA5LandMonthlyMeans(DataSource):
172
179
  # But the list of variables should include the bands we want in the correct
173
180
  # order. And we can distinguish those bands from other "variables" because they
174
181
  # will be 3D while the others will be scalars or 1D.
175
- bands_data = []
182
+
183
+ band_arrays = []
184
+ num_time_steps = None
176
185
  for band_name in nc.variables:
177
186
  band_data = nc.variables[band_name]
178
187
  if len(band_data.shape) != 3:
@@ -182,18 +191,27 @@ class ERA5LandMonthlyMeans(DataSource):
182
191
  logger.debug(
183
192
  f"NC file {nc_path} has variable {band_name} with shape {band_data.shape}"
184
193
  )
185
- # Variable data is stored in a 3D array (1, height, width)
186
- if band_data.shape[0] != 1:
194
+ # Variable data is stored in a 3D array (time, height, width)
195
+ # For hourly data, time is number of days in the month x 24 hours
196
+ if num_time_steps is None:
197
+ num_time_steps = band_data.shape[0]
198
+ elif band_data.shape[0] != num_time_steps:
187
199
  raise ValueError(
188
- f"Bad shape for band {band_name}, expected 1 band but got {band_data.shape[0]}"
200
+ f"Variable {band_name} has {band_data.shape[0]} time steps, "
201
+ f"but expected {num_time_steps}"
189
202
  )
190
- bands_data.append(band_data[0, :, :])
203
+ # Original shape: (time, height, width)
204
+ band_array = np.array(band_data[:])
205
+ band_array = np.expand_dims(band_array, axis=1)
206
+ band_arrays.append(band_array)
191
207
 
192
- array = np.array(bands_data) # (num_bands, height, width)
193
- if array.shape[0] != len(self.band_names):
194
- raise ValueError(
195
- f"Expected to get {len(self.band_names)} bands but got {array.shape[0]}"
196
- )
208
+ # After concatenation: (time, num_variables, height, width)
209
+ stacked_array = np.concatenate(band_arrays, axis=1)
210
+
211
+ # After reshaping: (time x num_variables, height, width)
212
+ array = stacked_array.reshape(
213
+ -1, stacked_array.shape[2], stacked_array.shape[3]
214
+ )
197
215
 
198
216
  # Get metadata for the GeoTIFF
199
217
  lat = nc.variables["latitude"][:]
@@ -235,6 +253,58 @@ class ERA5LandMonthlyMeans(DataSource):
235
253
  ) as dst:
236
254
  dst.write(array)
237
255
 
256
+ def ingest(
257
+ self,
258
+ tile_store: TileStoreWithLayer,
259
+ items: list[Item],
260
+ geometries: list[list[STGeometry]],
261
+ ) -> None:
262
+ """Ingest items into the given tile store.
263
+
264
+ This method should be overridden by subclasses.
265
+
266
+ Args:
267
+ tile_store: the tile store to ingest into
268
+ items: the items to ingest
269
+ geometries: a list of geometries needed for each item
270
+ """
271
+ raise NotImplementedError("Subclasses must implement ingest method")
272
+
273
+
274
+ class ERA5LandMonthlyMeans(ERA5Land):
275
+ """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
276
+
277
+ This data source corresponds to the reanalysis-era5-land-monthly-means product.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ band_names: list[str] | None = None,
283
+ api_key: str | None = None,
284
+ bounds: list[float] | None = None,
285
+ context: DataSourceContext = DataSourceContext(),
286
+ ):
287
+ """Initialize a new ERA5LandMonthlyMeans instance.
288
+
289
+ Args:
290
+ band_names: list of band names to acquire. These should correspond to CDS
291
+ variable names but with "_" replaced with "-". This will only be used
292
+ if the layer config is missing from the context.
293
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
294
+ environment variable.
295
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
296
+ If not specified, the whole globe will be used.
297
+ context: the data source context.
298
+ """
299
+ super().__init__(
300
+ dataset="reanalysis-era5-land-monthly-means",
301
+ product_type="monthly_averaged_reanalysis",
302
+ band_names=band_names,
303
+ api_key=api_key,
304
+ bounds=bounds,
305
+ context=context,
306
+ )
307
+
238
308
  def ingest(
239
309
  self,
240
310
  tile_store: TileStoreWithLayer,
@@ -256,25 +326,142 @@ class ERA5LandMonthlyMeans(DataSource):
256
326
  continue
257
327
 
258
328
  # Send the request to the CDS API
259
- # If area is not specified, the whole globe will be requested
329
+ if self.bounds is not None:
330
+ min_lon, min_lat, max_lon, max_lat = self.bounds
331
+ area = [max_lat, min_lon, min_lat, max_lon]
332
+ else:
333
+ area = [90, -180, -90, 180] # Whole globe
334
+
260
335
  request = {
261
- "product_type": [self.PRODUCT_TYPE],
336
+ "product_type": [self.product_type],
262
337
  "variable": variable_names,
263
338
  "year": [f"{item.geometry.time_range[0].year}"], # type: ignore
264
339
  "month": [
265
340
  f"{item.geometry.time_range[0].month:02d}" # type: ignore
266
341
  ],
267
342
  "time": ["00:00"],
343
+ "area": area,
344
+ "data_format": self.DATA_FORMAT,
345
+ "download_format": self.DOWNLOAD_FORMAT,
346
+ }
347
+ logger.debug(
348
+ f"CDS API request for year={request['year']} month={request['month']} area={area}"
349
+ )
350
+ with tempfile.TemporaryDirectory() as tmp_dir:
351
+ local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
352
+ local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
353
+ self.client.retrieve(self.dataset, request, local_nc_fname)
354
+ self._convert_nc_to_tif(
355
+ UPath(local_nc_fname),
356
+ UPath(local_tif_fname),
357
+ )
358
+ tile_store.write_raster_file(
359
+ item.name, self.band_names, UPath(local_tif_fname)
360
+ )
361
+
362
+
363
+ class ERA5LandHourly(ERA5Land):
364
+ """A data source for ingesting ERA5 land hourly data from the Copernicus Climate Data Store.
365
+
366
+ This data source corresponds to the reanalysis-era5-land product.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ band_names: list[str] | None = None,
372
+ api_key: str | None = None,
373
+ bounds: list[float] | None = None,
374
+ context: DataSourceContext = DataSourceContext(),
375
+ ):
376
+ """Initialize a new ERA5LandHourly instance.
377
+
378
+ Args:
379
+ band_names: list of band names to acquire. These should correspond to CDS
380
+ variable names but with "_" replaced with "-". This will only be used
381
+ if the layer config is missing from the context.
382
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
383
+ environment variable.
384
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
385
+ If not specified, the whole globe will be used.
386
+ context: the data source context.
387
+ """
388
+ super().__init__(
389
+ dataset="reanalysis-era5-land",
390
+ product_type="reanalysis",
391
+ band_names=band_names,
392
+ api_key=api_key,
393
+ bounds=bounds,
394
+ context=context,
395
+ )
396
+
397
+ def ingest(
398
+ self,
399
+ tile_store: TileStoreWithLayer,
400
+ items: list[Item],
401
+ geometries: list[list[STGeometry]],
402
+ ) -> None:
403
+ """Ingest items into the given tile store.
404
+
405
+ Args:
406
+ tile_store: the tile store to ingest into
407
+ items: the items to ingest
408
+ geometries: a list of geometries needed for each item
409
+ """
410
+ # for CDS variable names, replace "-" with "_"
411
+ variable_names = [band.replace("-", "_") for band in self.band_names]
412
+
413
+ for item in items:
414
+ if tile_store.is_raster_ready(item.name, self.band_names):
415
+ continue
416
+
417
+ # Send the request to the CDS API
418
+ # If area is not specified, the whole globe will be requested
419
+ time_range = item.geometry.time_range
420
+ if time_range is None:
421
+ raise ValueError("Item must have a time range")
422
+
423
+ # For hourly data, request all days in the month and all 24 hours
424
+ start_time = time_range[0]
425
+
426
+ # Get all days in the month
427
+ year = start_time.year
428
+ month = start_time.month
429
+ # Get the last day of the month
430
+ if month == 12:
431
+ last_day = 31
432
+ else:
433
+ next_month = datetime(year, month + 1, 1, tzinfo=UTC)
434
+ last_day = (next_month - relativedelta(days=1)).day
435
+
436
+ days = [f"{day:02d}" for day in range(1, last_day + 1)]
437
+
438
+ # Get all 24 hours
439
+ hours = [f"{hour:02d}:00" for hour in range(24)]
440
+
441
+ if self.bounds is not None:
442
+ min_lon, min_lat, max_lon, max_lat = self.bounds
443
+ area = [max_lat, min_lon, min_lat, max_lon]
444
+ else:
445
+ area = [90, -180, -90, 180] # Whole globe
446
+
447
+ request = {
448
+ "product_type": [self.product_type],
449
+ "variable": variable_names,
450
+ "year": [f"{year}"],
451
+ "month": [f"{month:02d}"],
452
+ "day": days,
453
+ "time": hours,
454
+ "area": area,
268
455
  "data_format": self.DATA_FORMAT,
269
456
  "download_format": self.DOWNLOAD_FORMAT,
270
457
  }
271
458
  logger.debug(
272
- f"CDS API request for the whole globe for year={request['year']} month={request['month']}"
459
+ f"CDS API request for year={request['year']} month={request['month']} days={len(days)} hours={len(hours)} area={area}"
273
460
  )
274
461
  with tempfile.TemporaryDirectory() as tmp_dir:
275
462
  local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
276
463
  local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
277
- self.client.retrieve(self.DATASET, request, local_nc_fname)
464
+ self.client.retrieve(self.dataset, request, local_nc_fname)
278
465
  self._convert_nc_to_tif(
279
466
  UPath(local_nc_fname),
280
467
  UPath(local_tif_fname),
@@ -765,6 +765,22 @@ class Sentinel2(DataSource):
765
765
  for item in self._read_bigquery(
766
766
  time_range=geometry.time_range, wgs84_bbox=wgs84_bbox
767
767
  ):
768
+ # Get the item from XML to get its exact geometry (BigQuery only knows
769
+ # the bounding box of the item).
770
+ try:
771
+ item = self.get_item_by_name(item.name)
772
+ except CorruptItemException as e:
773
+ logger.warning("skipping corrupt item %s: %s", item.name, e.message)
774
+ continue
775
+ except MissingXMLException:
776
+ # Sometimes a scene that appears in the BigQuery index does not
777
+ # actually have an XML file on GCS. Since we know this happens
778
+ # occasionally, we ignore the error here.
779
+ logger.warning(
780
+ "skipping item %s that is missing XML file", item.name
781
+ )
782
+ continue
783
+
768
784
  candidates[idx].append(item)
769
785
 
770
786
  return candidates