rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/data_sources/__init__.py
CHANGED
|
@@ -10,40 +10,17 @@ Each source supports operations to lookup items that match with spatiotemporal
|
|
|
10
10
|
geometries, and ingest those items.
|
|
11
11
|
"""
|
|
12
12
|
|
|
13
|
-
import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource
|
|
22
|
-
|
|
23
|
-
logger = get_logger(__name__)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@functools.cache
|
|
27
|
-
def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
|
|
28
|
-
"""Loads a data source from config dict.
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
config: the LayerConfig containing this data source.
|
|
32
|
-
ds_path: the dataset root directory.
|
|
33
|
-
"""
|
|
34
|
-
logger.debug("getting a data source for dataset at %s", ds_path)
|
|
35
|
-
if config.data_source is None:
|
|
36
|
-
raise ValueError("No data source specified")
|
|
37
|
-
name = config.data_source.name
|
|
38
|
-
module_name = ".".join(name.split(".")[:-1])
|
|
39
|
-
class_name = name.split(".")[-1]
|
|
40
|
-
module = importlib.import_module(module_name)
|
|
41
|
-
class_ = getattr(module, class_name)
|
|
42
|
-
return class_.from_config(config, ds_path)
|
|
43
|
-
|
|
13
|
+
from .data_source import (
|
|
14
|
+
DataSource,
|
|
15
|
+
DataSourceContext,
|
|
16
|
+
Item,
|
|
17
|
+
ItemLookupDataSource,
|
|
18
|
+
RetrieveItemDataSource,
|
|
19
|
+
)
|
|
44
20
|
|
|
45
21
|
__all__ = (
|
|
46
22
|
"DataSource",
|
|
23
|
+
"DataSourceContext",
|
|
47
24
|
"Item",
|
|
48
25
|
"ItemLookupDataSource",
|
|
49
26
|
"RetrieveItemDataSource",
|
|
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
|
|
|
25
25
|
from upath import UPath
|
|
26
26
|
|
|
27
27
|
import rslearn.data_sources.utils
|
|
28
|
-
from rslearn.config import LayerConfig
|
|
28
|
+
from rslearn.config import LayerConfig
|
|
29
29
|
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
|
|
30
30
|
from rslearn.dataset import Window
|
|
31
31
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
@@ -34,7 +34,7 @@ from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
|
|
|
34
34
|
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
|
|
35
35
|
from rslearn.utils.grid_index import GridIndex
|
|
36
36
|
|
|
37
|
-
from .data_source import DataSource, Item, QueryConfig
|
|
37
|
+
from .data_source import DataSource, DataSourceContext, Item, QueryConfig
|
|
38
38
|
|
|
39
39
|
WRS2_GRID_SIZE = 1.0
|
|
40
40
|
|
|
@@ -98,20 +98,25 @@ class LandsatOliTirs(DataSource, TileStore):
|
|
|
98
98
|
|
|
99
99
|
def __init__(
|
|
100
100
|
self,
|
|
101
|
-
|
|
102
|
-
metadata_cache_dir: UPath,
|
|
101
|
+
metadata_cache_dir: str,
|
|
103
102
|
sort_by: str | None = None,
|
|
103
|
+
context: DataSourceContext = DataSourceContext(),
|
|
104
104
|
) -> None:
|
|
105
105
|
"""Initialize a new LandsatOliTirs instance.
|
|
106
106
|
|
|
107
107
|
Args:
|
|
108
|
-
|
|
109
|
-
metadata_cache_dir: directory to cache product metadata files.
|
|
108
|
+
metadata_cache_dir: directory to cache product metadata files.
|
|
110
109
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
111
110
|
SpaceMode.WITHIN.
|
|
111
|
+
context: the data source context.
|
|
112
112
|
"""
|
|
113
|
-
|
|
114
|
-
|
|
113
|
+
# If context is provided, we join the directory with the dataset path,
|
|
114
|
+
# otherwise we treat it directly as UPath.
|
|
115
|
+
if context.ds_path is not None:
|
|
116
|
+
self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
|
|
117
|
+
else:
|
|
118
|
+
self.metadata_cache_dir = UPath(metadata_cache_dir)
|
|
119
|
+
|
|
115
120
|
self.sort_by = sort_by
|
|
116
121
|
|
|
117
122
|
self.client = boto3.client("s3")
|
|
@@ -120,21 +125,6 @@ class LandsatOliTirs(DataSource, TileStore):
|
|
|
120
125
|
|
|
121
126
|
self.wrs2_index: GridIndex | None = None
|
|
122
127
|
|
|
123
|
-
@staticmethod
|
|
124
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
|
|
125
|
-
"""Creates a new LandsatOliTirs instance from a configuration dictionary."""
|
|
126
|
-
if config.data_source is None:
|
|
127
|
-
raise ValueError(f"data_source is required for config dict {config}")
|
|
128
|
-
d = config.data_source.config_dict
|
|
129
|
-
kwargs = dict(
|
|
130
|
-
config=config,
|
|
131
|
-
metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
|
|
132
|
-
)
|
|
133
|
-
if "sort_by" in d:
|
|
134
|
-
kwargs["sort_by"] = d["sort_by"]
|
|
135
|
-
|
|
136
|
-
return LandsatOliTirs(**kwargs)
|
|
137
|
-
|
|
138
128
|
def _read_products(
|
|
139
129
|
self, needed_year_pathrows: set[tuple[int, str, str]]
|
|
140
130
|
) -> Generator[LandsatOliTirsItem, None, None]:
|
|
@@ -536,7 +526,6 @@ class LandsatOliTirs(DataSource, TileStore):
|
|
|
536
526
|
layer_name: the name of this layer
|
|
537
527
|
layer_cfg: the config of this layer
|
|
538
528
|
"""
|
|
539
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
540
529
|
RasterMaterializer().materialize(
|
|
541
530
|
TileStoreWithLayer(self, layer_name),
|
|
542
531
|
window,
|
|
@@ -22,7 +22,6 @@ from rasterio.crs import CRS
|
|
|
22
22
|
from upath import UPath
|
|
23
23
|
|
|
24
24
|
import rslearn.data_sources.utils
|
|
25
|
-
from rslearn.config import RasterLayerConfig
|
|
26
25
|
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_EPSG, WGS84_PROJECTION
|
|
27
26
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
28
27
|
from rslearn.utils import GridIndex, Projection, STGeometry, daterange
|
|
@@ -32,6 +31,7 @@ from rslearn.utils.raster_format import get_raster_projection_and_bounds
|
|
|
32
31
|
from .copernicus import get_harmonize_callback, get_sentinel2_tiles
|
|
33
32
|
from .data_source import (
|
|
34
33
|
DataSource,
|
|
34
|
+
DataSourceContext,
|
|
35
35
|
Item,
|
|
36
36
|
ItemLookupDataSource,
|
|
37
37
|
QueryConfig,
|
|
@@ -83,16 +83,15 @@ class Naip(DataSource):
|
|
|
83
83
|
|
|
84
84
|
def __init__(
|
|
85
85
|
self,
|
|
86
|
-
|
|
87
|
-
index_cache_dir: UPath,
|
|
86
|
+
index_cache_dir: str,
|
|
88
87
|
use_rtree_index: bool = False,
|
|
89
88
|
states: list[str] | None = None,
|
|
90
89
|
years: list[int] | None = None,
|
|
90
|
+
context: DataSourceContext = DataSourceContext(),
|
|
91
91
|
) -> None:
|
|
92
92
|
"""Initialize a new Naip instance.
|
|
93
93
|
|
|
94
94
|
Args:
|
|
95
|
-
config: the LayerConfig of the layer containing this data source.
|
|
96
95
|
index_cache_dir: directory to cache index shapefiles.
|
|
97
96
|
use_rtree_index: whether to create an rtree index to enable faster lookups
|
|
98
97
|
(default false)
|
|
@@ -100,9 +99,15 @@ class Naip(DataSource):
|
|
|
100
99
|
the search. If use_rtree_index is enabled, the rtree will only be
|
|
101
100
|
populated with data from these states.
|
|
102
101
|
years: optional list of years to restrict the search
|
|
102
|
+
context: the data source context.
|
|
103
103
|
"""
|
|
104
|
-
|
|
105
|
-
|
|
104
|
+
# If context is provided, we join the directory with the dataset path,
|
|
105
|
+
# otherwise we treat it directly as UPath.
|
|
106
|
+
if context.ds_path is not None:
|
|
107
|
+
self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
|
|
108
|
+
else:
|
|
109
|
+
self.index_cache_dir = UPath(index_cache_dir)
|
|
110
|
+
|
|
106
111
|
self.states = states
|
|
107
112
|
self.years = years
|
|
108
113
|
|
|
@@ -119,22 +124,6 @@ class Naip(DataSource):
|
|
|
119
124
|
|
|
120
125
|
self.rtree_index = get_cached_rtree(self.index_cache_dir, build_fn)
|
|
121
126
|
|
|
122
|
-
@staticmethod
|
|
123
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
|
|
124
|
-
"""Creates a new Naip instance from a configuration dictionary."""
|
|
125
|
-
if config.data_source is None:
|
|
126
|
-
raise ValueError(f"data_source is required for config dict {config}")
|
|
127
|
-
d = config.data_source.config_dict
|
|
128
|
-
kwargs = dict(
|
|
129
|
-
config=config,
|
|
130
|
-
index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
|
|
131
|
-
)
|
|
132
|
-
simple_optionals = ["use_rtree_index", "states", "years"]
|
|
133
|
-
for k in simple_optionals:
|
|
134
|
-
if k in d:
|
|
135
|
-
kwargs[k] = d[k]
|
|
136
|
-
return Naip(**kwargs)
|
|
137
|
-
|
|
138
127
|
def _download_manifest(self) -> UPath:
|
|
139
128
|
"""Download the manifest that enumerates files in the bucket.
|
|
140
129
|
|
|
@@ -460,51 +449,37 @@ class Sentinel2(
|
|
|
460
449
|
|
|
461
450
|
def __init__(
|
|
462
451
|
self,
|
|
463
|
-
config: RasterLayerConfig,
|
|
464
452
|
modality: Sentinel2Modality,
|
|
465
|
-
metadata_cache_dir:
|
|
453
|
+
metadata_cache_dir: str,
|
|
466
454
|
sort_by: str | None = None,
|
|
467
455
|
harmonize: bool = False,
|
|
456
|
+
context: DataSourceContext = DataSourceContext(),
|
|
468
457
|
) -> None:
|
|
469
458
|
"""Initialize a new Sentinel2 instance.
|
|
470
459
|
|
|
471
460
|
Args:
|
|
472
|
-
config: the LayerConfig of the layer containing this data source.
|
|
473
461
|
modality: L1C or L2A.
|
|
474
462
|
metadata_cache_dir: directory to cache product metadata files.
|
|
475
463
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
476
464
|
SpaceMode.WITHIN.
|
|
477
465
|
harmonize: harmonize pixel values across different processing baselines,
|
|
478
466
|
see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
|
|
467
|
+
context: the data source context.
|
|
479
468
|
""" # noqa: E501
|
|
480
|
-
|
|
469
|
+
# If context is provided, we join the directory with the dataset path,
|
|
470
|
+
# otherwise we treat it directly as UPath.
|
|
471
|
+
if context.ds_path is not None:
|
|
472
|
+
self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
|
|
473
|
+
else:
|
|
474
|
+
self.metadata_cache_dir = UPath(metadata_cache_dir)
|
|
475
|
+
|
|
481
476
|
self.modality = modality
|
|
482
|
-
self.metadata_cache_dir = metadata_cache_dir
|
|
483
477
|
self.sort_by = sort_by
|
|
484
478
|
self.harmonize = harmonize
|
|
485
479
|
|
|
486
480
|
bucket_name = self.bucket_names[modality]
|
|
487
481
|
self.bucket = boto3.resource("s3").Bucket(bucket_name)
|
|
488
482
|
|
|
489
|
-
@staticmethod
|
|
490
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
|
|
491
|
-
"""Creates a new Sentinel2 instance from a configuration dictionary."""
|
|
492
|
-
if config.data_source is None:
|
|
493
|
-
raise ValueError("Sentinel2 data source requires a data source config")
|
|
494
|
-
d = config.data_source.config_dict
|
|
495
|
-
kwargs = dict(
|
|
496
|
-
config=config,
|
|
497
|
-
modality=Sentinel2Modality(d["modality"]),
|
|
498
|
-
metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
|
|
499
|
-
)
|
|
500
|
-
|
|
501
|
-
simple_optionals = ["sort_by", "harmonize"]
|
|
502
|
-
for k in simple_optionals:
|
|
503
|
-
if k in d:
|
|
504
|
-
kwargs[k] = d[k]
|
|
505
|
-
|
|
506
|
-
return Sentinel2(**kwargs)
|
|
507
|
-
|
|
508
483
|
def _read_products(
|
|
509
484
|
self, needed_cell_months: set[tuple[str, int, int]]
|
|
510
485
|
) -> Generator[Sentinel2Item, None, None]:
|
|
@@ -7,7 +7,6 @@ from typing import Any
|
|
|
7
7
|
import boto3
|
|
8
8
|
from upath import UPath
|
|
9
9
|
|
|
10
|
-
from rslearn.config import RasterLayerConfig
|
|
11
10
|
from rslearn.data_sources.copernicus import (
|
|
12
11
|
CopernicusItem,
|
|
13
12
|
Sentinel1OrbitDirection,
|
|
@@ -19,7 +18,7 @@ from rslearn.log_utils import get_logger
|
|
|
19
18
|
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
20
19
|
from rslearn.utils.geometry import STGeometry
|
|
21
20
|
|
|
22
|
-
from .data_source import DataSource, QueryConfig
|
|
21
|
+
from .data_source import DataSource, DataSourceContext, QueryConfig
|
|
23
22
|
|
|
24
23
|
WRS2_GRID_SIZE = 1.0
|
|
25
24
|
|
|
@@ -45,11 +44,13 @@ class Sentinel1(DataSource, TileStore):
|
|
|
45
44
|
def __init__(
|
|
46
45
|
self,
|
|
47
46
|
orbit_direction: Sentinel1OrbitDirection | None = None,
|
|
47
|
+
context: DataSourceContext = DataSourceContext(),
|
|
48
48
|
) -> None:
|
|
49
49
|
"""Initialize a new Sentinel1 instance.
|
|
50
50
|
|
|
51
51
|
Args:
|
|
52
52
|
orbit_direction: optional orbit direction to filter by.
|
|
53
|
+
context: the data source context.
|
|
53
54
|
"""
|
|
54
55
|
self.client = boto3.client("s3")
|
|
55
56
|
self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
|
|
@@ -59,18 +60,6 @@ class Sentinel1(DataSource, TileStore):
|
|
|
59
60
|
orbit_direction=orbit_direction,
|
|
60
61
|
)
|
|
61
62
|
|
|
62
|
-
@staticmethod
|
|
63
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
|
|
64
|
-
"""Creates a new Sentinel1 instance from a configuration dictionary."""
|
|
65
|
-
if config.data_source is None:
|
|
66
|
-
raise ValueError(f"data_source is required for config dict {config}")
|
|
67
|
-
d = config.data_source.config_dict
|
|
68
|
-
kwargs: dict[str, Any] = {}
|
|
69
|
-
if "orbit_direction" in d:
|
|
70
|
-
d["orbit_direction"] = Sentinel1OrbitDirection[d["orbit_direction"]]
|
|
71
|
-
|
|
72
|
-
return Sentinel1(**kwargs)
|
|
73
|
-
|
|
74
63
|
def get_items(
|
|
75
64
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
76
65
|
) -> list[list[list[CopernicusItem]]]:
|
|
@@ -14,9 +14,9 @@ from dateutil.relativedelta import relativedelta
|
|
|
14
14
|
from rasterio.transform import from_origin
|
|
15
15
|
from upath import UPath
|
|
16
16
|
|
|
17
|
-
from rslearn.config import QueryConfig,
|
|
17
|
+
from rslearn.config import QueryConfig, SpaceMode
|
|
18
18
|
from rslearn.const import WGS84_EPSG, WGS84_PROJECTION
|
|
19
|
-
from rslearn.data_sources import DataSource, Item
|
|
19
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
20
20
|
from rslearn.log_utils import get_logger
|
|
21
21
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
22
22
|
from rslearn.utils.geometry import STGeometry
|
|
@@ -55,59 +55,40 @@ class ERA5LandMonthlyMeans(DataSource):
|
|
|
55
55
|
|
|
56
56
|
def __init__(
|
|
57
57
|
self,
|
|
58
|
-
band_names: list[str],
|
|
58
|
+
band_names: list[str] | None = None,
|
|
59
59
|
api_key: str | None = None,
|
|
60
|
+
context: DataSourceContext = DataSourceContext(),
|
|
60
61
|
):
|
|
61
62
|
"""Initialize a new ERA5LandMonthlyMeans instance.
|
|
62
63
|
|
|
63
64
|
Args:
|
|
64
65
|
band_names: list of band names to acquire. These should correspond to CDS
|
|
65
|
-
variable names but with "_" replaced with "-".
|
|
66
|
+
variable names but with "_" replaced with "-". This will only be used
|
|
67
|
+
if the layer config is missing from the context.
|
|
66
68
|
api_key: the API key. If not set, it should be set via the CDSAPI_KEY
|
|
67
69
|
environment variable.
|
|
70
|
+
context: the data source context.
|
|
68
71
|
"""
|
|
69
|
-
self.band_names
|
|
72
|
+
self.band_names: list[str]
|
|
73
|
+
if context.layer_config is not None:
|
|
74
|
+
self.band_names = []
|
|
75
|
+
for band_set in context.layer_config.band_sets:
|
|
76
|
+
for band in band_set.bands:
|
|
77
|
+
if band in self.band_names:
|
|
78
|
+
continue
|
|
79
|
+
self.band_names.append(band)
|
|
80
|
+
elif band_names is not None:
|
|
81
|
+
self.band_names = band_names
|
|
82
|
+
else:
|
|
83
|
+
raise ValueError(
|
|
84
|
+
"band_names must be set if layer_config is not in the context"
|
|
85
|
+
)
|
|
70
86
|
|
|
71
87
|
self.client = cdsapi.Client(
|
|
72
88
|
url=self.api_url,
|
|
73
89
|
key=api_key,
|
|
74
90
|
)
|
|
75
91
|
|
|
76
|
-
@staticmethod
|
|
77
|
-
def from_config(
|
|
78
|
-
config: RasterLayerConfig, ds_path: UPath
|
|
79
|
-
) -> "ERA5LandMonthlyMeans":
|
|
80
|
-
"""Creates a new ERA5LandMonthlyMeans instance from a configuration dictionary.
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
config: the LayerConfig of the layer containing this data source
|
|
84
|
-
ds_path: the path to the data source
|
|
85
|
-
|
|
86
|
-
Returns:
|
|
87
|
-
A new ERA5LandMonthlyMeans instance
|
|
88
|
-
"""
|
|
89
|
-
if config.data_source is None:
|
|
90
|
-
raise ValueError("data_source is required")
|
|
91
|
-
d = config.data_source.config_dict
|
|
92
|
-
|
|
93
|
-
# Determine band names based on the configured band sets.
|
|
94
|
-
band_names = []
|
|
95
|
-
for band_set in config.band_sets:
|
|
96
|
-
for band in band_set.bands:
|
|
97
|
-
if band in band_names:
|
|
98
|
-
continue
|
|
99
|
-
band_names.append(band)
|
|
100
|
-
kwargs: dict[str, Any] = dict(
|
|
101
|
-
band_names=band_names,
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
simple_optionals = ["api_key"]
|
|
105
|
-
for k in simple_optionals:
|
|
106
|
-
if k in d:
|
|
107
|
-
kwargs[k] = d[k]
|
|
108
|
-
|
|
109
|
-
return ERA5LandMonthlyMeans(**kwargs)
|
|
110
|
-
|
|
111
92
|
def get_items(
|
|
112
93
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
113
94
|
) -> list[list[list[Item]]]:
|
|
@@ -23,9 +23,9 @@ import requests
|
|
|
23
23
|
import shapely
|
|
24
24
|
from upath import UPath
|
|
25
25
|
|
|
26
|
-
from rslearn.config import QueryConfig
|
|
26
|
+
from rslearn.config import QueryConfig
|
|
27
27
|
from rslearn.const import WGS84_PROJECTION
|
|
28
|
-
from rslearn.data_sources.data_source import DataSource, Item
|
|
28
|
+
from rslearn.data_sources.data_source import DataSource, DataSourceContext, Item
|
|
29
29
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
30
30
|
from rslearn.log_utils import get_logger
|
|
31
31
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
@@ -306,6 +306,7 @@ class Copernicus(DataSource):
|
|
|
306
306
|
sort_by: str | None = None,
|
|
307
307
|
sort_desc: bool = False,
|
|
308
308
|
timeout: float = 10,
|
|
309
|
+
context: DataSourceContext = DataSourceContext(),
|
|
309
310
|
):
|
|
310
311
|
"""Create a new Copernicus.
|
|
311
312
|
|
|
@@ -332,6 +333,7 @@ class Copernicus(DataSource):
|
|
|
332
333
|
sort_desc: for sort_by, sort in descending order instead of ascending
|
|
333
334
|
order.
|
|
334
335
|
timeout: timeout for requests.
|
|
336
|
+
context: the data source context.
|
|
335
337
|
"""
|
|
336
338
|
self.glob_to_bands = glob_to_bands
|
|
337
339
|
self.query_filter = query_filter
|
|
@@ -351,30 +353,6 @@ class Copernicus(DataSource):
|
|
|
351
353
|
self.username = os.environ["COPERNICUS_USERNAME"]
|
|
352
354
|
self.password = os.environ["COPERNICUS_PASSWORD"]
|
|
353
355
|
|
|
354
|
-
@staticmethod
|
|
355
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Copernicus":
|
|
356
|
-
"""Creates a new Copernicus instance from a configuration dictionary."""
|
|
357
|
-
if config.data_source is None:
|
|
358
|
-
raise ValueError("config.data_source is required")
|
|
359
|
-
d = config.data_source.config_dict
|
|
360
|
-
kwargs: dict[str, Any] = dict(
|
|
361
|
-
glob_to_bands=d["glob_to_bands"],
|
|
362
|
-
)
|
|
363
|
-
|
|
364
|
-
simple_optionals = [
|
|
365
|
-
"access_token",
|
|
366
|
-
"query_filter",
|
|
367
|
-
"order_by",
|
|
368
|
-
"sort_by",
|
|
369
|
-
"sort_desc",
|
|
370
|
-
"timeout",
|
|
371
|
-
]
|
|
372
|
-
for k in simple_optionals:
|
|
373
|
-
if k in d:
|
|
374
|
-
kwargs[k] = d[k]
|
|
375
|
-
|
|
376
|
-
return Copernicus(**kwargs)
|
|
377
|
-
|
|
378
356
|
def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
|
|
379
357
|
"""Deserializes an item from JSON-decoded data."""
|
|
380
358
|
assert isinstance(serialized_item, dict)
|
|
@@ -763,23 +741,43 @@ class Sentinel2(Copernicus):
|
|
|
763
741
|
|
|
764
742
|
def __init__(
|
|
765
743
|
self,
|
|
766
|
-
assets: list[str],
|
|
767
744
|
product_type: Sentinel2ProductType,
|
|
768
745
|
harmonize: bool = False,
|
|
746
|
+
assets: list[str] | None = None,
|
|
747
|
+
context: DataSourceContext = DataSourceContext(),
|
|
769
748
|
**kwargs: Any,
|
|
770
749
|
):
|
|
771
750
|
"""Create a new Sentinel2.
|
|
772
751
|
|
|
773
752
|
Args:
|
|
774
|
-
assets: list of assets corresponding to keys in BANDS, e.g. ["TCI", "B08"].
|
|
775
753
|
product_type: desired product type, L1C or L2A.
|
|
776
754
|
harmonize: harmonize pixel values across different processing baselines,
|
|
777
755
|
see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
|
|
756
|
+
assets: the assets to download, or None to download all assets. This is
|
|
757
|
+
only used if the layer config is not in the context.
|
|
758
|
+
context: the data source context.
|
|
778
759
|
kwargs: additional arguments to pass to Copernicus.
|
|
779
760
|
"""
|
|
780
761
|
# Create glob to bands map.
|
|
762
|
+
# If the context is provided, we limit to needed assets based on the configured
|
|
763
|
+
# band sets.
|
|
764
|
+
if context.layer_config is not None:
|
|
765
|
+
needed_assets = []
|
|
766
|
+
for asset_key, asset_bands in Sentinel2.BANDS.items():
|
|
767
|
+
# See if the bands provided by this asset intersect with the bands in
|
|
768
|
+
# at least one configured band set.
|
|
769
|
+
for band_set in context.layer_config.band_sets:
|
|
770
|
+
if not set(band_set.bands).intersection(set(asset_bands)):
|
|
771
|
+
continue
|
|
772
|
+
needed_assets.append(asset_key)
|
|
773
|
+
break
|
|
774
|
+
elif assets is not None:
|
|
775
|
+
needed_assets = assets
|
|
776
|
+
else:
|
|
777
|
+
needed_assets = list(Sentinel2.BANDS.keys())
|
|
778
|
+
|
|
781
779
|
glob_to_bands = {}
|
|
782
|
-
for asset_key in
|
|
780
|
+
for asset_key in needed_assets:
|
|
783
781
|
band_names = self.BANDS[asset_key]
|
|
784
782
|
glob_pattern = self.GLOB_PATTERNS[product_type][asset_key]
|
|
785
783
|
glob_to_bands[glob_pattern] = band_names
|
|
@@ -788,46 +786,13 @@ class Sentinel2(Copernicus):
|
|
|
788
786
|
query_filter = f"Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq '{quote(product_type.value)}')"
|
|
789
787
|
|
|
790
788
|
super().__init__(
|
|
789
|
+
context=context,
|
|
791
790
|
glob_to_bands=glob_to_bands,
|
|
792
791
|
query_filter=query_filter,
|
|
793
792
|
**kwargs,
|
|
794
793
|
)
|
|
795
794
|
self.harmonize = harmonize
|
|
796
795
|
|
|
797
|
-
@staticmethod
|
|
798
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
|
|
799
|
-
"""Creates a new Sentinel2 instance from a configuration dictionary."""
|
|
800
|
-
if config.data_source is None:
|
|
801
|
-
raise ValueError("config.data_source is required")
|
|
802
|
-
d = config.data_source.config_dict
|
|
803
|
-
|
|
804
|
-
# Determine needed assets based on the configured band sets.
|
|
805
|
-
needed_assets: set[str] = set()
|
|
806
|
-
for asset_key, asset_bands in Sentinel2.BANDS.items():
|
|
807
|
-
for band_set in config.band_sets:
|
|
808
|
-
if not set(band_set.bands).intersection(set(asset_bands)):
|
|
809
|
-
continue
|
|
810
|
-
needed_assets.add(asset_key)
|
|
811
|
-
|
|
812
|
-
kwargs: dict[str, Any] = dict(
|
|
813
|
-
assets=list(needed_assets),
|
|
814
|
-
product_type=Sentinel2ProductType[d["product_type"]],
|
|
815
|
-
)
|
|
816
|
-
|
|
817
|
-
simple_optionals = [
|
|
818
|
-
"harmonize",
|
|
819
|
-
"access_token",
|
|
820
|
-
"order_by",
|
|
821
|
-
"sort_by",
|
|
822
|
-
"sort_desc",
|
|
823
|
-
"timeout",
|
|
824
|
-
]
|
|
825
|
-
for k in simple_optionals:
|
|
826
|
-
if k in d:
|
|
827
|
-
kwargs[k] = d[k]
|
|
828
|
-
|
|
829
|
-
return Sentinel2(**kwargs)
|
|
830
|
-
|
|
831
796
|
# Override to support harmonization step.
|
|
832
797
|
def _process_product_zip(
|
|
833
798
|
self, tile_store: TileStoreWithLayer, item: CopernicusItem, local_zip_fname: str
|
|
@@ -922,6 +887,7 @@ class Sentinel1(Copernicus):
|
|
|
922
887
|
product_type: Sentinel1ProductType,
|
|
923
888
|
polarisation: Sentinel1Polarisation,
|
|
924
889
|
orbit_direction: Sentinel1OrbitDirection | None = None,
|
|
890
|
+
context: DataSourceContext = DataSourceContext(),
|
|
925
891
|
**kwargs: Any,
|
|
926
892
|
):
|
|
927
893
|
"""Create a new Sentinel1.
|
|
@@ -930,6 +896,7 @@ class Sentinel1(Copernicus):
|
|
|
930
896
|
product_type: desired product type.
|
|
931
897
|
polarisation: desired polarisation(s).
|
|
932
898
|
orbit_direction: optional orbit direction to filter by.
|
|
899
|
+
context: the data source context.
|
|
933
900
|
kwargs: additional arguments to pass to Copernicus.
|
|
934
901
|
"""
|
|
935
902
|
# Create query filter based on the product type.
|
|
@@ -945,31 +912,3 @@ class Sentinel1(Copernicus):
|
|
|
945
912
|
query_filter=query_filter,
|
|
946
913
|
**kwargs,
|
|
947
914
|
)
|
|
948
|
-
|
|
949
|
-
@staticmethod
|
|
950
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
|
|
951
|
-
"""Creates a new Sentinel1 instance from a configuration dictionary."""
|
|
952
|
-
if config.data_source is None:
|
|
953
|
-
raise ValueError("config.data_source is required")
|
|
954
|
-
d = config.data_source.config_dict
|
|
955
|
-
|
|
956
|
-
kwargs: dict[str, Any] = dict(
|
|
957
|
-
product_type=Sentinel1ProductType[d["product_type"]],
|
|
958
|
-
polarisation=Sentinel1Polarisation[d["polarisation"]],
|
|
959
|
-
)
|
|
960
|
-
|
|
961
|
-
if "orbit_direction" in d:
|
|
962
|
-
kwargs["orbit_direction"] = Sentinel1OrbitDirection[d["orbit_direction"]]
|
|
963
|
-
|
|
964
|
-
simple_optionals = [
|
|
965
|
-
"access_token",
|
|
966
|
-
"order_by",
|
|
967
|
-
"sort_by",
|
|
968
|
-
"sort_desc",
|
|
969
|
-
"timeout",
|
|
970
|
-
]
|
|
971
|
-
for k in simple_optionals:
|
|
972
|
-
if k in d:
|
|
973
|
-
kwargs[k] = d[k]
|
|
974
|
-
|
|
975
|
-
return Sentinel1(**kwargs)
|
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
from collections.abc import Generator
|
|
4
4
|
from typing import Any, BinaryIO, Generic, TypeVar
|
|
5
5
|
|
|
6
|
+
from upath import UPath
|
|
7
|
+
|
|
6
8
|
from rslearn.config import LayerConfig, QueryConfig
|
|
7
9
|
from rslearn.dataset import Window
|
|
8
10
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
@@ -127,3 +129,27 @@ class RetrieveItemDataSource(DataSource[ItemType]):
|
|
|
127
129
|
) -> Generator[tuple[str, BinaryIO], None, None]:
|
|
128
130
|
"""Retrieves the rasters corresponding to an item as file streams."""
|
|
129
131
|
raise NotImplementedError
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class DataSourceContext:
|
|
135
|
+
"""This context is passed to every data source.
|
|
136
|
+
|
|
137
|
+
When initializing data sources within rslearn, we always set the ds_path and
|
|
138
|
+
layer_config. However, for convenience (for users directly initializing the data
|
|
139
|
+
sources externally), each data source should allow for initialization when one or
|
|
140
|
+
both are missing.
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
def __init__(
|
|
144
|
+
self, ds_path: UPath | None = None, layer_config: LayerConfig | None = None
|
|
145
|
+
):
|
|
146
|
+
"""Create a new DataSourceContext.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
ds_path: the path of the underlying dataset.
|
|
150
|
+
layer_config: the LayerConfig for the layer that the data source is for.
|
|
151
|
+
"""
|
|
152
|
+
# We don't use dataclass here because otherwise jsonargparse will ignore our
|
|
153
|
+
# custom serializer/deserializer defined in rslearn.utils.jsonargparse.
|
|
154
|
+
self.ds_path = ds_path
|
|
155
|
+
self.layer_config = layer_config
|