rslearn 0.0.19__tar.gz → 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.19/rslearn.egg-info → rslearn-0.0.21}/PKG-INFO +1 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/pyproject.toml +1 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/climate_data_store.py +216 -29
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/anysat.py +35 -33
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clip.py +5 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/croma.py +11 -3
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/dinov3.py +2 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/faster_rcnn.py +2 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/galileo.py +58 -31
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/module_wrapper.py +6 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/molmo.py +4 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/model.py +95 -32
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/norm.py +5 -3
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/presto.py +45 -15
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/prithvi.py +9 -7
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/sam2_enc.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/satlaspretrain.py +4 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/simple_time_series.py +36 -16
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/ssl4eo_s12.py +19 -14
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/swin.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/terramind.py +5 -4
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/all_patches_dataset.py +34 -14
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/dataset.py +73 -8
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/model_context.py +35 -1
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/classification.py +8 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/detection.py +3 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/multi_task.py +2 -3
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/per_pixel_regression.py +14 -5
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/regression.py +8 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/segmentation.py +13 -4
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/task.py +2 -2
- rslearn-0.0.21/rslearn/train/transforms/concatenate.py +89 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/crop.py +22 -8
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/flip.py +13 -5
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/mask.py +11 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/normalize.py +46 -15
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/pad.py +15 -3
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/resize.py +18 -9
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/select_bands.py +11 -2
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/sentinel1.py +18 -3
- {rslearn-0.0.19 → rslearn-0.0.21/rslearn.egg-info}/PKG-INFO +1 -1
- rslearn-0.0.19/rslearn/train/transforms/concatenate.py +0 -49
- {rslearn-0.0.19 → rslearn-0.0.21}/LICENSE +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/NOTICE +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/README.md +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/const.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/file.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/storage.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/lightning_cli.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/main.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/attention_pooling.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/component.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/concatenate_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/py.typed +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/template_params.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.21}/setup.cfg +0 -0
|
@@ -24,51 +24,55 @@ from rslearn.utils.geometry import STGeometry
|
|
|
24
24
|
logger = get_logger(__name__)
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
class ERA5LandMonthlyMeans(DataSource):
|
|
28
|
-
"""
|
|
27
|
+
class ERA5Land(DataSource):
|
|
28
|
+
"""Base class for ingesting ERA5 land data from the Copernicus Climate Data Store.
|
|
29
29
|
|
|
30
30
|
An API key must be passed either in the configuration or via the CDSAPI_KEY
|
|
31
31
|
environment variable. You can acquire an API key by going to the Climate Data Store
|
|
32
32
|
website (https://cds.climate.copernicus.eu/), registering an account and logging
|
|
33
|
-
in, and then
|
|
33
|
+
in, and then getting the API key from the user profile page.
|
|
34
34
|
|
|
35
35
|
The band names should match CDS variable names (see the reference at
|
|
36
36
|
https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation). However,
|
|
37
37
|
replace "_" with "-" in the variable names when specifying bands in the layer
|
|
38
38
|
configuration.
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
All requests to the API will be for the whole globe. Although the API supports arbitrary
|
|
43
|
-
bounds in the requests, using the whole available area helps to reduce the total number of
|
|
44
|
-
requests.
|
|
40
|
+
By default, all requests to the API will be for the whole globe. To speed up ingestion,
|
|
41
|
+
we recommend specifying the bounds of the area of interest, in particular for hourly data.
|
|
45
42
|
"""
|
|
46
43
|
|
|
47
44
|
api_url = "https://cds.climate.copernicus.eu/api"
|
|
48
|
-
|
|
49
|
-
# see: https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land-monthly-means
|
|
50
|
-
DATASET = "reanalysis-era5-land-monthly-means"
|
|
51
|
-
PRODUCT_TYPE = "monthly_averaged_reanalysis"
|
|
52
45
|
DATA_FORMAT = "netcdf"
|
|
53
46
|
DOWNLOAD_FORMAT = "unarchived"
|
|
54
47
|
PIXEL_SIZE = 0.1 # degrees, native resolution is 9km
|
|
55
48
|
|
|
56
49
|
def __init__(
|
|
57
50
|
self,
|
|
51
|
+
dataset: str,
|
|
52
|
+
product_type: str,
|
|
58
53
|
band_names: list[str] | None = None,
|
|
59
54
|
api_key: str | None = None,
|
|
55
|
+
bounds: list[float] | None = None,
|
|
60
56
|
context: DataSourceContext = DataSourceContext(),
|
|
61
57
|
):
|
|
62
|
-
"""Initialize a new
|
|
58
|
+
"""Initialize a new ERA5Land instance.
|
|
63
59
|
|
|
64
60
|
Args:
|
|
61
|
+
dataset: the CDS dataset name (e.g., "reanalysis-era5-land-monthly-means").
|
|
62
|
+
product_type: the CDS product type (e.g., "monthly_averaged_reanalysis").
|
|
65
63
|
band_names: list of band names to acquire. These should correspond to CDS
|
|
66
64
|
variable names but with "_" replaced with "-". This will only be used
|
|
67
65
|
if the layer config is missing from the context.
|
|
68
66
|
api_key: the API key. If not set, it should be set via the CDSAPI_KEY
|
|
69
67
|
environment variable.
|
|
68
|
+
bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
|
|
69
|
+
If not specified, the whole globe will be used.
|
|
70
70
|
context: the data source context.
|
|
71
71
|
"""
|
|
72
|
+
self.dataset = dataset
|
|
73
|
+
self.product_type = product_type
|
|
74
|
+
self.bounds = bounds
|
|
75
|
+
|
|
72
76
|
self.band_names: list[str]
|
|
73
77
|
if context.layer_config is not None:
|
|
74
78
|
self.band_names = []
|
|
@@ -134,8 +138,11 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
134
138
|
# Collect Item list corresponding to the current month.
|
|
135
139
|
items = []
|
|
136
140
|
item_name = f"era5land_monthlyaveraged_{cur_date.year}_{cur_date.month}"
|
|
137
|
-
#
|
|
138
|
-
bounds
|
|
141
|
+
# Use bounds if set, otherwise use whole globe
|
|
142
|
+
if self.bounds is not None:
|
|
143
|
+
bounds = self.bounds
|
|
144
|
+
else:
|
|
145
|
+
bounds = [-180, -90, 180, 90]
|
|
139
146
|
# Time is just the given month.
|
|
140
147
|
start_date = datetime(cur_date.year, cur_date.month, 1, tzinfo=UTC)
|
|
141
148
|
time_range = (
|
|
@@ -172,7 +179,9 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
172
179
|
# But the list of variables should include the bands we want in the correct
|
|
173
180
|
# order. And we can distinguish those bands from other "variables" because they
|
|
174
181
|
# will be 3D while the others will be scalars or 1D.
|
|
175
|
-
|
|
182
|
+
|
|
183
|
+
band_arrays = []
|
|
184
|
+
num_time_steps = None
|
|
176
185
|
for band_name in nc.variables:
|
|
177
186
|
band_data = nc.variables[band_name]
|
|
178
187
|
if len(band_data.shape) != 3:
|
|
@@ -182,18 +191,27 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
182
191
|
logger.debug(
|
|
183
192
|
f"NC file {nc_path} has variable {band_name} with shape {band_data.shape}"
|
|
184
193
|
)
|
|
185
|
-
# Variable data is stored in a 3D array (
|
|
186
|
-
|
|
194
|
+
# Variable data is stored in a 3D array (time, height, width)
|
|
195
|
+
# For hourly data, time is number of days in the month x 24 hours
|
|
196
|
+
if num_time_steps is None:
|
|
197
|
+
num_time_steps = band_data.shape[0]
|
|
198
|
+
elif band_data.shape[0] != num_time_steps:
|
|
187
199
|
raise ValueError(
|
|
188
|
-
f"
|
|
200
|
+
f"Variable {band_name} has {band_data.shape[0]} time steps, "
|
|
201
|
+
f"but expected {num_time_steps}"
|
|
189
202
|
)
|
|
190
|
-
|
|
203
|
+
# Original shape: (time, height, width)
|
|
204
|
+
band_array = np.array(band_data[:])
|
|
205
|
+
band_array = np.expand_dims(band_array, axis=1)
|
|
206
|
+
band_arrays.append(band_array)
|
|
191
207
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
208
|
+
# After concatenation: (time, num_variables, height, width)
|
|
209
|
+
stacked_array = np.concatenate(band_arrays, axis=1)
|
|
210
|
+
|
|
211
|
+
# After reshaping: (time x num_variables, height, width)
|
|
212
|
+
array = stacked_array.reshape(
|
|
213
|
+
-1, stacked_array.shape[2], stacked_array.shape[3]
|
|
214
|
+
)
|
|
197
215
|
|
|
198
216
|
# Get metadata for the GeoTIFF
|
|
199
217
|
lat = nc.variables["latitude"][:]
|
|
@@ -235,6 +253,58 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
235
253
|
) as dst:
|
|
236
254
|
dst.write(array)
|
|
237
255
|
|
|
256
|
+
def ingest(
|
|
257
|
+
self,
|
|
258
|
+
tile_store: TileStoreWithLayer,
|
|
259
|
+
items: list[Item],
|
|
260
|
+
geometries: list[list[STGeometry]],
|
|
261
|
+
) -> None:
|
|
262
|
+
"""Ingest items into the given tile store.
|
|
263
|
+
|
|
264
|
+
This method should be overridden by subclasses.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
tile_store: the tile store to ingest into
|
|
268
|
+
items: the items to ingest
|
|
269
|
+
geometries: a list of geometries needed for each item
|
|
270
|
+
"""
|
|
271
|
+
raise NotImplementedError("Subclasses must implement ingest method")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class ERA5LandMonthlyMeans(ERA5Land):
|
|
275
|
+
"""A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
|
|
276
|
+
|
|
277
|
+
This data source corresponds to the reanalysis-era5-land-monthly-means product.
|
|
278
|
+
"""
|
|
279
|
+
|
|
280
|
+
def __init__(
|
|
281
|
+
self,
|
|
282
|
+
band_names: list[str] | None = None,
|
|
283
|
+
api_key: str | None = None,
|
|
284
|
+
bounds: list[float] | None = None,
|
|
285
|
+
context: DataSourceContext = DataSourceContext(),
|
|
286
|
+
):
|
|
287
|
+
"""Initialize a new ERA5LandMonthlyMeans instance.
|
|
288
|
+
|
|
289
|
+
Args:
|
|
290
|
+
band_names: list of band names to acquire. These should correspond to CDS
|
|
291
|
+
variable names but with "_" replaced with "-". This will only be used
|
|
292
|
+
if the layer config is missing from the context.
|
|
293
|
+
api_key: the API key. If not set, it should be set via the CDSAPI_KEY
|
|
294
|
+
environment variable.
|
|
295
|
+
bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
|
|
296
|
+
If not specified, the whole globe will be used.
|
|
297
|
+
context: the data source context.
|
|
298
|
+
"""
|
|
299
|
+
super().__init__(
|
|
300
|
+
dataset="reanalysis-era5-land-monthly-means",
|
|
301
|
+
product_type="monthly_averaged_reanalysis",
|
|
302
|
+
band_names=band_names,
|
|
303
|
+
api_key=api_key,
|
|
304
|
+
bounds=bounds,
|
|
305
|
+
context=context,
|
|
306
|
+
)
|
|
307
|
+
|
|
238
308
|
def ingest(
|
|
239
309
|
self,
|
|
240
310
|
tile_store: TileStoreWithLayer,
|
|
@@ -256,25 +326,142 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
256
326
|
continue
|
|
257
327
|
|
|
258
328
|
# Send the request to the CDS API
|
|
259
|
-
|
|
329
|
+
if self.bounds is not None:
|
|
330
|
+
min_lon, min_lat, max_lon, max_lat = self.bounds
|
|
331
|
+
area = [max_lat, min_lon, min_lat, max_lon]
|
|
332
|
+
else:
|
|
333
|
+
area = [90, -180, -90, 180] # Whole globe
|
|
334
|
+
|
|
260
335
|
request = {
|
|
261
|
-
"product_type": [self.
|
|
336
|
+
"product_type": [self.product_type],
|
|
262
337
|
"variable": variable_names,
|
|
263
338
|
"year": [f"{item.geometry.time_range[0].year}"], # type: ignore
|
|
264
339
|
"month": [
|
|
265
340
|
f"{item.geometry.time_range[0].month:02d}" # type: ignore
|
|
266
341
|
],
|
|
267
342
|
"time": ["00:00"],
|
|
343
|
+
"area": area,
|
|
344
|
+
"data_format": self.DATA_FORMAT,
|
|
345
|
+
"download_format": self.DOWNLOAD_FORMAT,
|
|
346
|
+
}
|
|
347
|
+
logger.debug(
|
|
348
|
+
f"CDS API request for year={request['year']} month={request['month']} area={area}"
|
|
349
|
+
)
|
|
350
|
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
351
|
+
local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
|
|
352
|
+
local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
|
|
353
|
+
self.client.retrieve(self.dataset, request, local_nc_fname)
|
|
354
|
+
self._convert_nc_to_tif(
|
|
355
|
+
UPath(local_nc_fname),
|
|
356
|
+
UPath(local_tif_fname),
|
|
357
|
+
)
|
|
358
|
+
tile_store.write_raster_file(
|
|
359
|
+
item.name, self.band_names, UPath(local_tif_fname)
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class ERA5LandHourly(ERA5Land):
|
|
364
|
+
"""A data source for ingesting ERA5 land hourly data from the Copernicus Climate Data Store.
|
|
365
|
+
|
|
366
|
+
This data source corresponds to the reanalysis-era5-land product.
|
|
367
|
+
"""
|
|
368
|
+
|
|
369
|
+
def __init__(
|
|
370
|
+
self,
|
|
371
|
+
band_names: list[str] | None = None,
|
|
372
|
+
api_key: str | None = None,
|
|
373
|
+
bounds: list[float] | None = None,
|
|
374
|
+
context: DataSourceContext = DataSourceContext(),
|
|
375
|
+
):
|
|
376
|
+
"""Initialize a new ERA5LandHourly instance.
|
|
377
|
+
|
|
378
|
+
Args:
|
|
379
|
+
band_names: list of band names to acquire. These should correspond to CDS
|
|
380
|
+
variable names but with "_" replaced with "-". This will only be used
|
|
381
|
+
if the layer config is missing from the context.
|
|
382
|
+
api_key: the API key. If not set, it should be set via the CDSAPI_KEY
|
|
383
|
+
environment variable.
|
|
384
|
+
bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
|
|
385
|
+
If not specified, the whole globe will be used.
|
|
386
|
+
context: the data source context.
|
|
387
|
+
"""
|
|
388
|
+
super().__init__(
|
|
389
|
+
dataset="reanalysis-era5-land",
|
|
390
|
+
product_type="reanalysis",
|
|
391
|
+
band_names=band_names,
|
|
392
|
+
api_key=api_key,
|
|
393
|
+
bounds=bounds,
|
|
394
|
+
context=context,
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
def ingest(
|
|
398
|
+
self,
|
|
399
|
+
tile_store: TileStoreWithLayer,
|
|
400
|
+
items: list[Item],
|
|
401
|
+
geometries: list[list[STGeometry]],
|
|
402
|
+
) -> None:
|
|
403
|
+
"""Ingest items into the given tile store.
|
|
404
|
+
|
|
405
|
+
Args:
|
|
406
|
+
tile_store: the tile store to ingest into
|
|
407
|
+
items: the items to ingest
|
|
408
|
+
geometries: a list of geometries needed for each item
|
|
409
|
+
"""
|
|
410
|
+
# for CDS variable names, replace "-" with "_"
|
|
411
|
+
variable_names = [band.replace("-", "_") for band in self.band_names]
|
|
412
|
+
|
|
413
|
+
for item in items:
|
|
414
|
+
if tile_store.is_raster_ready(item.name, self.band_names):
|
|
415
|
+
continue
|
|
416
|
+
|
|
417
|
+
# Send the request to the CDS API
|
|
418
|
+
# If area is not specified, the whole globe will be requested
|
|
419
|
+
time_range = item.geometry.time_range
|
|
420
|
+
if time_range is None:
|
|
421
|
+
raise ValueError("Item must have a time range")
|
|
422
|
+
|
|
423
|
+
# For hourly data, request all days in the month and all 24 hours
|
|
424
|
+
start_time = time_range[0]
|
|
425
|
+
|
|
426
|
+
# Get all days in the month
|
|
427
|
+
year = start_time.year
|
|
428
|
+
month = start_time.month
|
|
429
|
+
# Get the last day of the month
|
|
430
|
+
if month == 12:
|
|
431
|
+
last_day = 31
|
|
432
|
+
else:
|
|
433
|
+
next_month = datetime(year, month + 1, 1, tzinfo=UTC)
|
|
434
|
+
last_day = (next_month - relativedelta(days=1)).day
|
|
435
|
+
|
|
436
|
+
days = [f"{day:02d}" for day in range(1, last_day + 1)]
|
|
437
|
+
|
|
438
|
+
# Get all 24 hours
|
|
439
|
+
hours = [f"{hour:02d}:00" for hour in range(24)]
|
|
440
|
+
|
|
441
|
+
if self.bounds is not None:
|
|
442
|
+
min_lon, min_lat, max_lon, max_lat = self.bounds
|
|
443
|
+
area = [max_lat, min_lon, min_lat, max_lon]
|
|
444
|
+
else:
|
|
445
|
+
area = [90, -180, -90, 180] # Whole globe
|
|
446
|
+
|
|
447
|
+
request = {
|
|
448
|
+
"product_type": [self.product_type],
|
|
449
|
+
"variable": variable_names,
|
|
450
|
+
"year": [f"{year}"],
|
|
451
|
+
"month": [f"{month:02d}"],
|
|
452
|
+
"day": days,
|
|
453
|
+
"time": hours,
|
|
454
|
+
"area": area,
|
|
268
455
|
"data_format": self.DATA_FORMAT,
|
|
269
456
|
"download_format": self.DOWNLOAD_FORMAT,
|
|
270
457
|
}
|
|
271
458
|
logger.debug(
|
|
272
|
-
f"CDS API request for
|
|
459
|
+
f"CDS API request for year={request['year']} month={request['month']} days={len(days)} hours={len(hours)} area={area}"
|
|
273
460
|
)
|
|
274
461
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
275
462
|
local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
|
|
276
463
|
local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
|
|
277
|
-
self.client.retrieve(self.
|
|
464
|
+
self.client.retrieve(self.dataset, request, local_nc_fname)
|
|
278
465
|
self._convert_nc_to_tif(
|
|
279
466
|
UPath(local_nc_fname),
|
|
280
467
|
UPath(local_tif_fname),
|
|
@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
|
|
|
4
4
|
https://github.com/gastruc/AnySat for applicable license and copyright information.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
|
|
7
9
|
import torch
|
|
8
10
|
from einops import rearrange
|
|
9
11
|
|
|
@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
|
|
|
53
55
|
self,
|
|
54
56
|
modalities: list[str],
|
|
55
57
|
patch_size_meters: int,
|
|
56
|
-
dates: dict[str, list[int]],
|
|
57
58
|
output: str = "patch",
|
|
58
59
|
output_modality: str | None = None,
|
|
59
60
|
hub_repo: str = "gastruc/anysat",
|
|
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
|
|
|
85
86
|
if m not in MODALITY_RESOLUTIONS:
|
|
86
87
|
raise ValueError(f"Invalid modality: {m}")
|
|
87
88
|
|
|
88
|
-
if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
|
|
89
|
-
raise ValueError("`dates` keys must be time-series modalities only.")
|
|
90
|
-
for m in modalities:
|
|
91
|
-
if m in TIME_SERIES_MODALITIES and m not in dates:
|
|
92
|
-
raise ValueError(
|
|
93
|
-
f"Missing required dates for time-series modality '{m}'."
|
|
94
|
-
)
|
|
95
|
-
|
|
96
89
|
if patch_size_meters % 10 != 0:
|
|
97
90
|
raise ValueError(
|
|
98
91
|
"In AnySat, `patch_size` is in meters and must be a multiple of 10."
|
|
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):
|
|
|
106
99
|
|
|
107
100
|
self.modalities = modalities
|
|
108
101
|
self.patch_size_meters = int(patch_size_meters)
|
|
109
|
-
self.dates = dates
|
|
110
102
|
self.output = output
|
|
111
103
|
self.output_modality = output_modality
|
|
112
104
|
|
|
@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
|
|
|
119
111
|
)
|
|
120
112
|
self._embed_dim = 768 # base width, 'dense' returns 2x
|
|
121
113
|
|
|
114
|
+
@staticmethod
|
|
115
|
+
def time_ranges_to_doy(
|
|
116
|
+
time_ranges: list[tuple[datetime, datetime]],
|
|
117
|
+
device: torch.device,
|
|
118
|
+
) -> torch.Tensor:
|
|
119
|
+
"""Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
|
|
120
|
+
|
|
121
|
+
AnySat uses the day of year (doy) with each timestamp, so we take the midpoint
|
|
122
|
+
of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
|
|
123
|
+
time so that start_time == end_time == mid_time.
|
|
124
|
+
"""
|
|
125
|
+
doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
|
|
126
|
+
return torch.tensor(doys, dtype=torch.int32, device=device)
|
|
127
|
+
|
|
122
128
|
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
123
129
|
"""Forward pass for the AnySat model.
|
|
124
130
|
|
|
@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
|
|
|
139
145
|
raise ValueError(f"Modality '{modality}' not present in inputs.")
|
|
140
146
|
|
|
141
147
|
cur = torch.stack(
|
|
142
|
-
[inp[modality] for inp in inputs], dim=0
|
|
143
|
-
) # (B, C,
|
|
148
|
+
[inp[modality].image for inp in inputs], dim=0
|
|
149
|
+
) # (B, C, T, H, W)
|
|
144
150
|
|
|
145
151
|
if modality in TIME_SERIES_MODALITIES:
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
cur = rearrange(
|
|
149
|
-
cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
|
|
150
|
-
)
|
|
152
|
+
num_bands = cur.shape[1]
|
|
153
|
+
cur = rearrange(cur, "b c t h w -> b t c h w")
|
|
151
154
|
H, W = cur.shape[-2], cur.shape[-1]
|
|
155
|
+
|
|
156
|
+
if inputs[0][modality].timestamps is None:
|
|
157
|
+
raise ValueError(
|
|
158
|
+
f"Require timestamps for time series modality {modality}"
|
|
159
|
+
)
|
|
160
|
+
timestamps = torch.stack(
|
|
161
|
+
[
|
|
162
|
+
self.time_ranges_to_doy(inp[modality].timestamps, cur.device) # type: ignore
|
|
163
|
+
for inp in inputs
|
|
164
|
+
],
|
|
165
|
+
dim=0,
|
|
166
|
+
)
|
|
167
|
+
batch[f"{modality}_dates"] = timestamps
|
|
152
168
|
else:
|
|
169
|
+
# take the first (assumed only) timestep
|
|
170
|
+
cur = cur[:, :, 0]
|
|
153
171
|
num_bands = cur.shape[1]
|
|
154
172
|
H, W = cur.shape[-2], cur.shape[-1]
|
|
155
173
|
|
|
@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
|
|
|
173
191
|
"All modalities must share the same spatial extent (H*res, W*res)."
|
|
174
192
|
)
|
|
175
193
|
|
|
176
|
-
# Add *_dates
|
|
177
|
-
to_add = {}
|
|
178
|
-
for modality, x in list(batch.items()):
|
|
179
|
-
if modality in TIME_SERIES_MODALITIES:
|
|
180
|
-
B, T = x.shape[0], x.shape[1]
|
|
181
|
-
d = torch.as_tensor(
|
|
182
|
-
self.dates[modality], dtype=torch.long, device=x.device
|
|
183
|
-
)
|
|
184
|
-
if d.ndim != 1 or d.numel() != T:
|
|
185
|
-
raise ValueError(
|
|
186
|
-
f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
|
|
187
|
-
)
|
|
188
|
-
to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
|
|
189
|
-
|
|
190
|
-
batch.update(to_add)
|
|
191
|
-
|
|
192
194
|
kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
|
|
193
195
|
if self.output == "dense":
|
|
194
196
|
kwargs["output_modality"] = self.output_modality
|
|
@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
|
|
|
43
43
|
a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
|
|
44
44
|
"""
|
|
45
45
|
inputs = context.inputs
|
|
46
|
-
device = inputs[0]["image"].device
|
|
46
|
+
device = inputs[0]["image"].image.device
|
|
47
47
|
clip_inputs = self.processor(
|
|
48
|
-
images=[
|
|
48
|
+
images=[
|
|
49
|
+
inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
|
|
50
|
+
for inp in inputs
|
|
51
|
+
],
|
|
49
52
|
return_tensors="pt",
|
|
50
53
|
padding=True,
|
|
51
54
|
)
|
|
@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
|
|
|
175
175
|
sentinel1: torch.Tensor | None = None
|
|
176
176
|
sentinel2: torch.Tensor | None = None
|
|
177
177
|
if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
|
|
178
|
-
sentinel1 = torch.stack(
|
|
178
|
+
sentinel1 = torch.stack(
|
|
179
|
+
[inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
|
|
180
|
+
dim=0,
|
|
181
|
+
)
|
|
179
182
|
sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
|
|
180
183
|
if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
|
|
181
|
-
sentinel2 = torch.stack(
|
|
184
|
+
sentinel2 = torch.stack(
|
|
185
|
+
[inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
|
|
186
|
+
dim=0,
|
|
187
|
+
)
|
|
182
188
|
sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
|
|
183
189
|
|
|
184
190
|
outputs = self.model(
|
|
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
|
|
|
294
300
|
for modality in MODALITY_BANDS.keys():
|
|
295
301
|
if modality not in input_dict:
|
|
296
302
|
continue
|
|
297
|
-
input_dict[modality] = self.apply_image(
|
|
303
|
+
input_dict[modality].image = self.apply_image(
|
|
304
|
+
input_dict[modality].image, modality
|
|
305
|
+
)
|
|
298
306
|
return input_dict, target_dict
|
|
@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
|
|
|
104
104
|
a FeatureMaps with one feature map.
|
|
105
105
|
"""
|
|
106
106
|
cur = torch.stack(
|
|
107
|
-
[inp["image"] for inp in context.inputs],
|
|
107
|
+
[inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
|
|
108
|
+
dim=0,
|
|
108
109
|
) # (B, C, H, W)
|
|
109
110
|
|
|
110
111
|
if self.do_resizing and (
|
|
@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
|
|
|
210
210
|
),
|
|
211
211
|
)
|
|
212
212
|
|
|
213
|
-
|
|
213
|
+
# take the first (and assumed to be only) timestep
|
|
214
|
+
image_list = [inp["image"].image[:, 0] for inp in context.inputs]
|
|
214
215
|
images, targets = self.noop_transform(image_list, targets)
|
|
215
216
|
|
|
216
217
|
feature_dict = collections.OrderedDict()
|