rslearn 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +14 -20
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/concatenate_features.py +93 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/RECORD +47 -48
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py
CHANGED
|
@@ -1,19 +1,73 @@
|
|
|
1
1
|
"""Classes for storing configuration of a dataset."""
|
|
2
2
|
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
3
5
|
import json
|
|
6
|
+
import warnings
|
|
4
7
|
from datetime import timedelta
|
|
5
|
-
from enum import
|
|
6
|
-
from typing import Any
|
|
8
|
+
from enum import StrEnum
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Any
|
|
7
10
|
|
|
11
|
+
import jsonargparse
|
|
8
12
|
import numpy as np
|
|
9
13
|
import numpy.typing as npt
|
|
10
14
|
import pytimeparse
|
|
15
|
+
from pydantic import (
|
|
16
|
+
BaseModel,
|
|
17
|
+
BeforeValidator,
|
|
18
|
+
ConfigDict,
|
|
19
|
+
Field,
|
|
20
|
+
PlainSerializer,
|
|
21
|
+
field_validator,
|
|
22
|
+
model_validator,
|
|
23
|
+
)
|
|
11
24
|
from rasterio.enums import Resampling
|
|
25
|
+
from upath import UPath
|
|
12
26
|
|
|
27
|
+
from rslearn.log_utils import get_logger
|
|
13
28
|
from rslearn.utils import PixelBounds, Projection
|
|
29
|
+
from rslearn.utils.raster_format import RasterFormat
|
|
30
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
14
31
|
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from rslearn.data_sources.data_source import DataSource
|
|
15
34
|
|
|
16
|
-
|
|
35
|
+
logger = get_logger("__name__")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def ensure_timedelta(v: Any) -> Any:
|
|
39
|
+
"""Ensure the value is a timedelta.
|
|
40
|
+
|
|
41
|
+
If the value is a string, we try to parse it with pytimeparse.
|
|
42
|
+
|
|
43
|
+
This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].
|
|
44
|
+
"""
|
|
45
|
+
if isinstance(v, timedelta):
|
|
46
|
+
return v
|
|
47
|
+
if isinstance(v, str):
|
|
48
|
+
return pytimeparse.parse(v)
|
|
49
|
+
raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def ensure_optional_timedelta(v: Any) -> Any:
|
|
53
|
+
"""Like ensure_timedelta, but allows None as a value."""
|
|
54
|
+
if v is None:
|
|
55
|
+
return None
|
|
56
|
+
if isinstance(v, timedelta):
|
|
57
|
+
return v
|
|
58
|
+
if isinstance(v, str):
|
|
59
|
+
return pytimeparse.parse(v)
|
|
60
|
+
raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def serialize_optional_timedelta(v: timedelta | None) -> str | None:
|
|
64
|
+
"""Serialize an optional timedelta for compatibility with pytimeparse."""
|
|
65
|
+
if v is None:
|
|
66
|
+
return None
|
|
67
|
+
return str(v.total_seconds()) + "s"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class DType(StrEnum):
|
|
17
71
|
"""Data type of a raster."""
|
|
18
72
|
|
|
19
73
|
UINT8 = "uint8"
|
|
@@ -49,61 +103,28 @@ class DType(Enum):
|
|
|
49
103
|
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
50
104
|
|
|
51
105
|
|
|
52
|
-
|
|
53
|
-
"
|
|
54
|
-
"bilinear": Resampling.bilinear,
|
|
55
|
-
"cubic": Resampling.cubic,
|
|
56
|
-
"cubic_spline": Resampling.cubic_spline,
|
|
57
|
-
}
|
|
106
|
+
class ResamplingMethod(StrEnum):
|
|
107
|
+
"""An enum representing the rasterio Resampling."""
|
|
58
108
|
|
|
109
|
+
NEAREST = "nearest"
|
|
110
|
+
BILINEAR = "bilinear"
|
|
111
|
+
CUBIC = "cubic"
|
|
112
|
+
CUBIC_SPLINE = "cubic_spline"
|
|
59
113
|
|
|
60
|
-
|
|
61
|
-
|
|
114
|
+
def get_rasterio_resampling(self) -> Resampling:
|
|
115
|
+
"""Get the rasterio Resampling corresponding to this ResamplingMethod."""
|
|
116
|
+
return RESAMPLING_METHODS[self]
|
|
62
117
|
|
|
63
|
-
def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
|
|
64
|
-
"""Initialize a new RasterFormatConfig.
|
|
65
118
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
@staticmethod
|
|
74
|
-
def from_config(config: dict[str, Any]) -> "RasterFormatConfig":
|
|
75
|
-
"""Create a RasterFormatConfig from config dict.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
config: the config dict for this RasterFormatConfig
|
|
79
|
-
"""
|
|
80
|
-
return RasterFormatConfig(name=config["name"], config_dict=config)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class VectorFormatConfig:
|
|
84
|
-
"""A configuration specifying a VectorFormat."""
|
|
85
|
-
|
|
86
|
-
def __init__(self, name: str, config_dict: dict[str, Any] = {}) -> None:
|
|
87
|
-
"""Initialize a new VectorFormatConfig.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
name: the name of the VectorFormat to use.
|
|
91
|
-
config_dict: configuration to pass to the VectorFormat.
|
|
92
|
-
"""
|
|
93
|
-
self.name = name
|
|
94
|
-
self.config_dict = config_dict
|
|
95
|
-
|
|
96
|
-
@staticmethod
|
|
97
|
-
def from_config(config: dict[str, Any]) -> "VectorFormatConfig":
|
|
98
|
-
"""Create a VectorFormatConfig from config dict.
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
config: the config dict for this VectorFormatConfig
|
|
102
|
-
"""
|
|
103
|
-
return VectorFormatConfig(name=config["name"], config_dict=config)
|
|
119
|
+
RESAMPLING_METHODS = {
|
|
120
|
+
ResamplingMethod.NEAREST: Resampling.nearest,
|
|
121
|
+
ResamplingMethod.BILINEAR: Resampling.bilinear,
|
|
122
|
+
ResamplingMethod.CUBIC: Resampling.cubic,
|
|
123
|
+
ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
|
|
124
|
+
}
|
|
104
125
|
|
|
105
126
|
|
|
106
|
-
class BandSetConfig:
|
|
127
|
+
class BandSetConfig(BaseModel):
|
|
107
128
|
"""A configuration for a band set in a raster layer.
|
|
108
129
|
|
|
109
130
|
Each band set specifies one or more bands that should be stored together.
|
|
@@ -111,97 +132,67 @@ class BandSetConfig:
|
|
|
111
132
|
bands.
|
|
112
133
|
"""
|
|
113
134
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
135
|
+
dtype: DType = Field(description="Pixel value type to store the data under")
|
|
136
|
+
bands: list[str] = Field(
|
|
137
|
+
default_factory=lambda: [],
|
|
138
|
+
description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
|
|
139
|
+
)
|
|
140
|
+
num_bands: int | None = Field(
|
|
141
|
+
default=None,
|
|
142
|
+
description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
|
|
143
|
+
)
|
|
144
|
+
format: dict[str, Any] = Field(
|
|
145
|
+
default_factory=lambda: {
|
|
146
|
+
"class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
|
|
147
|
+
},
|
|
148
|
+
description="jsonargparse configuration for the RasterFormat to store the tiles in.",
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# Store images at a resolution higher or lower than the window resolution. This
|
|
152
|
+
# enables keeping source data at its native resolution, either to save storage
|
|
153
|
+
# space (for lower resolution data) or to retain details (for higher resolution
|
|
154
|
+
# data). If positive, store data at the window resolution divided by
|
|
155
|
+
# 2^(zoom_offset) (higher resolution). If negative, store data at the window
|
|
156
|
+
# resolution multiplied by 2^(-zoom_offset) (lower resolution).
|
|
157
|
+
zoom_offset: int = Field(
|
|
158
|
+
default=0,
|
|
159
|
+
description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
remap: dict[str, Any] | None = Field(
|
|
163
|
+
default=None,
|
|
164
|
+
description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# Optional list of names for the different possible values of each band. The length
|
|
168
|
+
# of this list must equal the number of bands. For example, [["forest", "desert"]]
|
|
169
|
+
# means that it is a single-band raster where values can be 0 (forest) or 1
|
|
170
|
+
# (desert).
|
|
171
|
+
class_names: list[list[str]] | None = Field(
|
|
172
|
+
default=None,
|
|
173
|
+
description="Optional list of names for the different possible values of each band.",
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
# Optional list of nodata values for this band set. This is used during
|
|
177
|
+
# materialization when creating mosaics, to determine which parts of the source
|
|
178
|
+
# images should be copied.
|
|
179
|
+
nodata_vals: list[float] | None = Field(
|
|
180
|
+
default=None, description="Optional nodata value for each band."
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
@model_validator(mode="after")
|
|
184
|
+
def after_validator(self) -> "BandSetConfig":
|
|
185
|
+
"""Ensure the BandSetConfig is valid, and handle the num_bands field."""
|
|
186
|
+
if (len(self.bands) == 0 and self.num_bands is None) or (
|
|
187
|
+
len(self.bands) != 0 and self.num_bands is not None
|
|
154
188
|
):
|
|
155
|
-
raise ValueError("exactly one of bands and num_bands must be
|
|
156
|
-
if bands is None:
|
|
157
|
-
assert num_bands is not None
|
|
158
|
-
bands = [f"B{idx}" for idx in range(num_bands)]
|
|
159
|
-
|
|
160
|
-
if class_names is not None and len(bands) != len(class_names):
|
|
161
|
-
raise ValueError(
|
|
162
|
-
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
self.config_dict = config_dict
|
|
166
|
-
self.bands = bands
|
|
167
|
-
self.dtype = dtype
|
|
168
|
-
self.zoom_offset = zoom_offset
|
|
169
|
-
self.remap = remap
|
|
170
|
-
self.class_names = class_names
|
|
171
|
-
self.nodata_vals = nodata_vals
|
|
172
|
-
|
|
173
|
-
if format is None:
|
|
174
|
-
self.format = {"name": "geotiff"}
|
|
175
|
-
else:
|
|
176
|
-
self.format = format
|
|
189
|
+
raise ValueError("exactly one of bands and num_bands must be specified")
|
|
177
190
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
191
|
+
if self.num_bands is not None:
|
|
192
|
+
self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
|
|
193
|
+
self.num_bands = None
|
|
181
194
|
|
|
182
|
-
|
|
183
|
-
def from_config(config: dict[str, Any]) -> "BandSetConfig":
|
|
184
|
-
"""Create a BandSetConfig from config dict.
|
|
185
|
-
|
|
186
|
-
Args:
|
|
187
|
-
config: the config dict for this BandSetConfig
|
|
188
|
-
"""
|
|
189
|
-
kwargs = dict(
|
|
190
|
-
config_dict=config,
|
|
191
|
-
dtype=DType(config["dtype"]),
|
|
192
|
-
)
|
|
193
|
-
for k in [
|
|
194
|
-
"bands",
|
|
195
|
-
"num_bands",
|
|
196
|
-
"format",
|
|
197
|
-
"zoom_offset",
|
|
198
|
-
"remap",
|
|
199
|
-
"class_names",
|
|
200
|
-
"nodata_vals",
|
|
201
|
-
]:
|
|
202
|
-
if k in config:
|
|
203
|
-
kwargs[k] = config[k]
|
|
204
|
-
return BandSetConfig(**kwargs) # type: ignore
|
|
195
|
+
return self
|
|
205
196
|
|
|
206
197
|
def get_final_projection_and_bounds(
|
|
207
198
|
self, projection: Projection, bounds: PixelBounds
|
|
@@ -236,30 +227,76 @@ class BandSetConfig:
|
|
|
236
227
|
)
|
|
237
228
|
return projection, bounds
|
|
238
229
|
|
|
230
|
+
@field_validator("format", mode="before")
|
|
231
|
+
@classmethod
|
|
232
|
+
def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
|
|
233
|
+
"""Support legacy format of the RasterFormat.
|
|
234
|
+
|
|
235
|
+
The legacy format sets 'name' instead of 'class_path', and uses custom parsing
|
|
236
|
+
for the init_args.
|
|
237
|
+
"""
|
|
238
|
+
if "name" not in v:
|
|
239
|
+
# New version, it is all good.
|
|
240
|
+
return v
|
|
241
|
+
|
|
242
|
+
warnings.warn(
|
|
243
|
+
"`format = {'name': ...}` is deprecated; "
|
|
244
|
+
"use `{'class_path': '...', 'init_args': {...}}` instead.",
|
|
245
|
+
DeprecationWarning,
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
legacy_name_to_class_path = {
|
|
249
|
+
"image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
|
|
250
|
+
"geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
|
|
251
|
+
"single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
|
|
252
|
+
}
|
|
253
|
+
if v["name"] not in legacy_name_to_class_path:
|
|
254
|
+
raise ValueError(
|
|
255
|
+
f"could not parse legacy format with unknown raster format {v['name']}"
|
|
256
|
+
)
|
|
257
|
+
init_args = dict(v)
|
|
258
|
+
class_path = legacy_name_to_class_path[init_args.pop("name")]
|
|
259
|
+
|
|
260
|
+
return dict(
|
|
261
|
+
class_path=class_path,
|
|
262
|
+
init_args=init_args,
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
def instantiate_raster_format(self) -> RasterFormat:
|
|
266
|
+
"""Instantiate the RasterFormat specified by this BandSetConfig."""
|
|
267
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
239
268
|
|
|
240
|
-
|
|
269
|
+
init_jsonargparse()
|
|
270
|
+
parser = jsonargparse.ArgumentParser()
|
|
271
|
+
parser.add_argument("--raster_format", type=RasterFormat)
|
|
272
|
+
cfg = parser.parse_object({"raster_format": self.format})
|
|
273
|
+
raster_format = parser.instantiate_classes(cfg).raster_format
|
|
274
|
+
return raster_format
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class SpaceMode(StrEnum):
|
|
241
278
|
"""Spatial matching mode when looking up items corresponding to a window."""
|
|
242
279
|
|
|
243
|
-
CONTAINS =
|
|
280
|
+
CONTAINS = "CONTAINS"
|
|
244
281
|
"""Items must contain the entire window."""
|
|
245
282
|
|
|
246
|
-
INTERSECTS =
|
|
283
|
+
INTERSECTS = "INTERSECTS"
|
|
247
284
|
"""Items must overlap any portion of the window."""
|
|
248
285
|
|
|
249
|
-
MOSAIC =
|
|
286
|
+
MOSAIC = "MOSAIC"
|
|
250
287
|
"""Groups of items should be computed that cover the entire window.
|
|
251
288
|
|
|
252
289
|
During materialization, items in each group are merged to form a mosaic in the
|
|
253
290
|
dataset.
|
|
254
291
|
"""
|
|
255
292
|
|
|
256
|
-
PER_PERIOD_MOSAIC =
|
|
293
|
+
PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
|
|
257
294
|
"""Create one mosaic per sub-period of the time range.
|
|
258
295
|
|
|
259
296
|
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
260
297
|
"""
|
|
261
298
|
|
|
262
|
-
COMPOSITE =
|
|
299
|
+
COMPOSITE = "COMPOSITE"
|
|
263
300
|
"""Creates one composite covering the entire window.
|
|
264
301
|
|
|
265
302
|
During querying all items intersecting the window are placed in one group.
|
|
@@ -270,188 +307,216 @@ class SpaceMode(Enum):
|
|
|
270
307
|
# TODO add PER_PERIOD_COMPOSITE
|
|
271
308
|
|
|
272
309
|
|
|
273
|
-
class TimeMode(
|
|
310
|
+
class TimeMode(StrEnum):
|
|
274
311
|
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
275
312
|
|
|
276
|
-
WITHIN =
|
|
313
|
+
WITHIN = "WITHIN"
|
|
277
314
|
"""Items must be within the window time range."""
|
|
278
315
|
|
|
279
|
-
NEAREST =
|
|
316
|
+
NEAREST = "NEAREST"
|
|
280
317
|
"""Select items closest to the window time range, up to max_matches."""
|
|
281
318
|
|
|
282
|
-
BEFORE =
|
|
319
|
+
BEFORE = "BEFORE"
|
|
283
320
|
"""Select items before the end of the window time range, up to max_matches."""
|
|
284
321
|
|
|
285
|
-
AFTER =
|
|
322
|
+
AFTER = "AFTER"
|
|
286
323
|
"""Select items after the start of the window time range, up to max_matches."""
|
|
287
324
|
|
|
288
325
|
|
|
289
|
-
class QueryConfig:
|
|
326
|
+
class QueryConfig(BaseModel):
|
|
290
327
|
"""A configuration for querying items in a data source."""
|
|
291
328
|
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
"
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
329
|
+
model_config = ConfigDict(frozen=True)
|
|
330
|
+
|
|
331
|
+
space_mode: SpaceMode = Field(
|
|
332
|
+
default=SpaceMode.MOSAIC,
|
|
333
|
+
description="Specifies how items should be matched with windows spatially.",
|
|
334
|
+
)
|
|
335
|
+
time_mode: TimeMode = Field(
|
|
336
|
+
default=TimeMode.WITHIN,
|
|
337
|
+
description="Specifies how items should be matched with windows temporally.",
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
# Minimum number of item groups. If there are fewer than this many matches, then no
|
|
341
|
+
# matches will be returned. This can be used to prevent unnecessary data ingestion
|
|
342
|
+
# if the user plans to discard windows that do not have a sufficient amount of data.
|
|
343
|
+
min_matches: int = Field(
|
|
344
|
+
default=0, description="The minimum number of item groups."
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
max_matches: int = Field(
|
|
348
|
+
default=1, description="The maximum number of item groups."
|
|
349
|
+
)
|
|
350
|
+
period_duration: Annotated[
|
|
351
|
+
timedelta,
|
|
352
|
+
BeforeValidator(ensure_timedelta),
|
|
353
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
354
|
+
] = Field(
|
|
355
|
+
default=timedelta(days=30),
|
|
356
|
+
description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class DataSourceConfig(BaseModel):
|
|
361
|
+
"""Configuration for a DataSource in a dataset layer."""
|
|
304
362
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
363
|
+
model_config = ConfigDict(frozen=True)
|
|
364
|
+
|
|
365
|
+
class_path: str = Field(description="Class path for the data source.")
|
|
366
|
+
init_args: dict[str, Any] = Field(
|
|
367
|
+
default_factory=lambda: {},
|
|
368
|
+
description="jsonargparse init args for the data source.",
|
|
369
|
+
)
|
|
370
|
+
query_config: QueryConfig = Field(
|
|
371
|
+
default_factory=lambda: QueryConfig(),
|
|
372
|
+
description="QueryConfig specifying how to match items with windows.",
|
|
373
|
+
)
|
|
374
|
+
time_offset: Annotated[
|
|
375
|
+
timedelta | None,
|
|
376
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
377
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
378
|
+
] = Field(
|
|
379
|
+
default=None,
|
|
380
|
+
description="Optional timedelta to add to the window's time range before matching.",
|
|
381
|
+
)
|
|
382
|
+
duration: Annotated[
|
|
383
|
+
timedelta | None,
|
|
384
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
385
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
386
|
+
] = Field(
|
|
387
|
+
default=None,
|
|
388
|
+
description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
|
|
389
|
+
)
|
|
390
|
+
ingest: bool = Field(
|
|
391
|
+
default=True,
|
|
392
|
+
description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
@model_validator(mode="before")
|
|
396
|
+
@classmethod
|
|
397
|
+
def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
|
|
398
|
+
"""Support legacy format of the DataSourceConfig.
|
|
399
|
+
|
|
400
|
+
The legacy format sets 'name' instead of 'class_path', and mixes the arguments
|
|
401
|
+
for the DataSource in with the DataSourceConfig keys.
|
|
316
402
|
"""
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
"space_mode": str(self.space_mode),
|
|
327
|
-
"time_mode": str(self.time_mode),
|
|
328
|
-
"min_matches": self.min_matches,
|
|
329
|
-
"max_matches": self.max_matches,
|
|
330
|
-
"period_duration": f"{self.period_duration.total_seconds()}s",
|
|
331
|
-
}
|
|
332
|
-
|
|
333
|
-
@staticmethod
|
|
334
|
-
def from_config(config: dict[str, Any]) -> "QueryConfig":
|
|
335
|
-
"""Create a QueryConfig from config dict.
|
|
403
|
+
if "name" not in d:
|
|
404
|
+
# New version, it is all good.
|
|
405
|
+
return d
|
|
406
|
+
|
|
407
|
+
warnings.warn(
|
|
408
|
+
"`Data source configuration {'name': ...}` is deprecated; "
|
|
409
|
+
"use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
|
|
410
|
+
DeprecationWarning,
|
|
411
|
+
)
|
|
336
412
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
kwargs["time_mode"] = TimeMode[config["time_mode"]]
|
|
345
|
-
if "period_duration" in config:
|
|
346
|
-
kwargs["period_duration"] = timedelta(
|
|
347
|
-
seconds=pytimeparse.parse(config["period_duration"])
|
|
348
|
-
)
|
|
349
|
-
for k in ["min_matches", "max_matches"]:
|
|
350
|
-
if k not in config:
|
|
413
|
+
# Split the dict into the base config that is in the pydantic model, and the
|
|
414
|
+
# source-specific options that should be moved to init_args dict.
|
|
415
|
+
class_path = d["name"]
|
|
416
|
+
base_config: dict[str, Any] = {}
|
|
417
|
+
ds_init_args: dict[str, Any] = {}
|
|
418
|
+
for k, v in d.items():
|
|
419
|
+
if k == "name":
|
|
351
420
|
continue
|
|
352
|
-
|
|
353
|
-
|
|
421
|
+
if k in cls.model_fields:
|
|
422
|
+
base_config[k] = v
|
|
423
|
+
else:
|
|
424
|
+
ds_init_args[k] = v
|
|
425
|
+
|
|
426
|
+
# Some legacy configs erroneously specify these keys, which are now caught by
|
|
427
|
+
# validation. But we still want those specific legacy configs to work.
|
|
428
|
+
if (
|
|
429
|
+
class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
|
|
430
|
+
and "max_cloud_cover" in ds_init_args
|
|
431
|
+
):
|
|
432
|
+
warnings.warn(
|
|
433
|
+
"Data source configuration specifies invalid 'max_cloud_cover' option.",
|
|
434
|
+
DeprecationWarning,
|
|
435
|
+
)
|
|
436
|
+
del ds_init_args["max_cloud_cover"]
|
|
354
437
|
|
|
438
|
+
base_config["class_path"] = class_path
|
|
439
|
+
base_config["init_args"] = ds_init_args
|
|
440
|
+
return base_config
|
|
355
441
|
|
|
356
|
-
class DataSourceConfig:
|
|
357
|
-
"""Configuration for a DataSource in a dataset layer."""
|
|
358
442
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
name: str,
|
|
362
|
-
query_config: QueryConfig,
|
|
363
|
-
config_dict: dict[str, Any],
|
|
364
|
-
time_offset: timedelta | None = None,
|
|
365
|
-
duration: timedelta | None = None,
|
|
366
|
-
ingest: bool = True,
|
|
367
|
-
) -> None:
|
|
368
|
-
"""Initializes a new DataSourceConfig.
|
|
369
|
-
|
|
370
|
-
Args:
|
|
371
|
-
name: the data source class name
|
|
372
|
-
query_config: the QueryConfig specifying how to match items with windows
|
|
373
|
-
config_dict: additional config passed to initialize the DataSource
|
|
374
|
-
time_offset: optional, add this timedelta to the window's time range before
|
|
375
|
-
matching
|
|
376
|
-
duration: optional, if window's time range is (t0, t1), then update to
|
|
377
|
-
(t0, t0 + duration)
|
|
378
|
-
ingest: whether to ingest this layer or directly materialize it
|
|
379
|
-
(default true)
|
|
380
|
-
"""
|
|
381
|
-
self.name = name
|
|
382
|
-
self.query_config = query_config
|
|
383
|
-
self.config_dict = config_dict
|
|
384
|
-
self.time_offset = time_offset
|
|
385
|
-
self.duration = duration
|
|
386
|
-
self.ingest = ingest
|
|
443
|
+
class LayerType(StrEnum):
|
|
444
|
+
"""The layer type (raster or vector)."""
|
|
387
445
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
return self.config_dict
|
|
446
|
+
RASTER = "raster"
|
|
447
|
+
VECTOR = "vector"
|
|
391
448
|
|
|
392
|
-
@staticmethod
|
|
393
|
-
def from_config(config: dict[str, Any]) -> "DataSourceConfig":
|
|
394
|
-
"""Create a DataSourceConfig from config dict.
|
|
395
449
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
"""
|
|
399
|
-
kwargs = dict(
|
|
400
|
-
name=config["name"],
|
|
401
|
-
query_config=QueryConfig.from_config(config.get("query_config", {})),
|
|
402
|
-
config_dict=config,
|
|
403
|
-
)
|
|
404
|
-
if "time_offset" in config:
|
|
405
|
-
kwargs["time_offset"] = timedelta(
|
|
406
|
-
seconds=pytimeparse.parse(config["time_offset"])
|
|
407
|
-
)
|
|
408
|
-
if "duration" in config:
|
|
409
|
-
kwargs["duration"] = timedelta(
|
|
410
|
-
seconds=pytimeparse.parse(config["duration"])
|
|
411
|
-
)
|
|
412
|
-
if "ingest" in config:
|
|
413
|
-
kwargs["ingest"] = config["ingest"]
|
|
414
|
-
return DataSourceConfig(**kwargs)
|
|
450
|
+
class CompositingMethod(StrEnum):
|
|
451
|
+
"""Method how to select pixels for the composite from corresponding items of a window."""
|
|
415
452
|
|
|
453
|
+
FIRST_VALID = "FIRST_VALID"
|
|
454
|
+
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
416
455
|
|
|
417
|
-
|
|
418
|
-
"""
|
|
456
|
+
MEAN = "MEAN"
|
|
457
|
+
"""Select per-pixel mean value of corresponding items of a window"""
|
|
419
458
|
|
|
420
|
-
|
|
421
|
-
|
|
459
|
+
MEDIAN = "MEDIAN"
|
|
460
|
+
"""Select per-pixel median value of corresponding items of a window"""
|
|
422
461
|
|
|
423
462
|
|
|
424
|
-
class LayerConfig:
|
|
463
|
+
class LayerConfig(BaseModel):
|
|
425
464
|
"""Configuration of a layer in a dataset."""
|
|
426
465
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
466
|
+
model_config = ConfigDict(frozen=True)
|
|
467
|
+
|
|
468
|
+
type: LayerType = Field(description="The LayerType (raster or vector).")
|
|
469
|
+
data_source: DataSourceConfig | None = Field(
|
|
470
|
+
default=None,
|
|
471
|
+
description="Optional DataSourceConfig if this layer is retrievable.",
|
|
472
|
+
)
|
|
473
|
+
alias: str | None = Field(
|
|
474
|
+
default=None, description="Alias for this layer to use in the tile store."
|
|
475
|
+
)
|
|
476
|
+
|
|
477
|
+
# Raster layer options.
|
|
478
|
+
band_sets: list[BandSetConfig] = Field(
|
|
479
|
+
default_factory=lambda: [],
|
|
480
|
+
description="For raster layers, the bands to store in this layer.",
|
|
481
|
+
)
|
|
482
|
+
resampling_method: ResamplingMethod = Field(
|
|
483
|
+
default=ResamplingMethod.BILINEAR,
|
|
484
|
+
description="For raster layers, how to resample rasters (if neeed), default bilinear resampling.",
|
|
485
|
+
)
|
|
486
|
+
compositing_method: CompositingMethod = Field(
|
|
487
|
+
default=CompositingMethod.FIRST_VALID,
|
|
488
|
+
description="For raster layers, how to compute pixel values in the composite of each window's items.",
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
# Vector layer options.
|
|
492
|
+
vector_format: dict[str, Any] = Field(
|
|
493
|
+
default_factory=lambda: {
|
|
494
|
+
"class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
|
|
495
|
+
},
|
|
496
|
+
description="For vector layers, the jsonargparse configuration for the VectorFormat.",
|
|
497
|
+
)
|
|
498
|
+
class_property_name: str | None = Field(
|
|
499
|
+
default=None,
|
|
500
|
+
description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
|
|
501
|
+
)
|
|
502
|
+
class_names: list[str] | None = Field(
|
|
503
|
+
default=None,
|
|
504
|
+
description="The list of classes that the class_property_name property could be set to.",
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
@model_validator(mode="after")
|
|
508
|
+
def after_validator(self) -> "LayerConfig":
|
|
509
|
+
"""Ensure the LayerConfig is valid."""
|
|
510
|
+
if self.type == LayerType.RASTER and len(self.band_sets) == 0:
|
|
511
|
+
raise ValueError(
|
|
512
|
+
"band sets must be specified and non-empty for raster layers"
|
|
513
|
+
)
|
|
434
514
|
|
|
435
|
-
|
|
436
|
-
layer_type: the LayerType (raster or vector)
|
|
437
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
438
|
-
alias: alias for this layer to use in the tile store
|
|
439
|
-
"""
|
|
440
|
-
self.layer_type = layer_type
|
|
441
|
-
self.data_source = data_source
|
|
442
|
-
self.alias = alias
|
|
443
|
-
|
|
444
|
-
def serialize(self) -> dict[str, Any]:
|
|
445
|
-
"""Serialize this LayerConfig to a config dict."""
|
|
446
|
-
return {
|
|
447
|
-
"layer_type": str(self.layer_type),
|
|
448
|
-
"data_source": self.data_source.serialize() if self.data_source else None,
|
|
449
|
-
"alias": self.alias,
|
|
450
|
-
}
|
|
515
|
+
return self
|
|
451
516
|
|
|
452
517
|
def __hash__(self) -> int:
|
|
453
518
|
"""Return a hash of this LayerConfig."""
|
|
454
|
-
return hash(json.dumps(self.
|
|
519
|
+
return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))
|
|
455
520
|
|
|
456
521
|
def __eq__(self, other: Any) -> bool:
|
|
457
522
|
"""Returns whether other is the same as this LayerConfig.
|
|
@@ -461,142 +526,71 @@ class LayerConfig:
|
|
|
461
526
|
"""
|
|
462
527
|
if not isinstance(other, LayerConfig):
|
|
463
528
|
return False
|
|
464
|
-
return self.
|
|
465
|
-
|
|
529
|
+
return self.model_dump() == other.model_dump()
|
|
466
530
|
|
|
467
|
-
|
|
468
|
-
|
|
531
|
+
@functools.cache
|
|
532
|
+
def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
|
|
533
|
+
"""Instantiate the data source specified by this config.
|
|
469
534
|
|
|
470
|
-
|
|
471
|
-
|
|
535
|
+
Args:
|
|
536
|
+
ds_path: optional dataset path to include in the DataSourceContext.
|
|
472
537
|
|
|
473
|
-
|
|
474
|
-
|
|
538
|
+
Returns:
|
|
539
|
+
the DataSource object.
|
|
540
|
+
"""
|
|
541
|
+
from rslearn.data_sources.data_source import DataSource, DataSourceContext
|
|
542
|
+
from rslearn.utils.jsonargparse import data_source_context_serializer
|
|
475
543
|
|
|
476
|
-
|
|
477
|
-
|
|
544
|
+
logger.debug("getting a data source for dataset at %s", ds_path)
|
|
545
|
+
if self.data_source is None:
|
|
546
|
+
raise ValueError("This layer does not specify a data source")
|
|
478
547
|
|
|
548
|
+
# Inject the DataSourceContext into the args.
|
|
549
|
+
context = DataSourceContext(
|
|
550
|
+
ds_path=ds_path,
|
|
551
|
+
layer_config=self,
|
|
552
|
+
)
|
|
553
|
+
ds_config: dict[str, Any] = {
|
|
554
|
+
"class_path": self.data_source.class_path,
|
|
555
|
+
"init_args": copy.deepcopy(self.data_source.init_args),
|
|
556
|
+
}
|
|
557
|
+
ds_config["init_args"]["context"] = data_source_context_serializer(context)
|
|
479
558
|
|
|
480
|
-
|
|
481
|
-
|
|
559
|
+
# Now we can parse with jsonargparse.
|
|
560
|
+
from rslearn.utils.jsonargparse import (
|
|
561
|
+
data_source_context_serializer,
|
|
562
|
+
init_jsonargparse,
|
|
563
|
+
)
|
|
482
564
|
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
data_source
|
|
488
|
-
|
|
489
|
-
alias: str | None = None,
|
|
490
|
-
compositing_method: CompositingMethod = CompositingMethod.FIRST_VALID,
|
|
491
|
-
):
|
|
492
|
-
"""Initialize a new RasterLayerConfig.
|
|
565
|
+
init_jsonargparse()
|
|
566
|
+
parser = jsonargparse.ArgumentParser()
|
|
567
|
+
parser.add_argument("--data_source", type=DataSource)
|
|
568
|
+
cfg = parser.parse_object({"data_source": ds_config})
|
|
569
|
+
data_source = parser.instantiate_classes(cfg).data_source
|
|
570
|
+
return data_source
|
|
493
571
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
compositing_method: how to compute pixel values in the composite of each windows items
|
|
501
|
-
"""
|
|
502
|
-
super().__init__(layer_type, data_source, alias)
|
|
503
|
-
self.band_sets = band_sets
|
|
504
|
-
self.resampling_method = resampling_method
|
|
505
|
-
self.compositing_method = compositing_method
|
|
572
|
+
def instantiate_vector_format(self) -> VectorFormat:
|
|
573
|
+
"""Instantiate the vector format specified by this config."""
|
|
574
|
+
if self.type != LayerType.VECTOR:
|
|
575
|
+
raise ValueError(
|
|
576
|
+
f"cannot instantiate vector format for layer with type {self.type}"
|
|
577
|
+
)
|
|
506
578
|
|
|
507
|
-
|
|
508
|
-
def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
|
|
509
|
-
"""Create a RasterLayerConfig from config dict.
|
|
579
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
510
580
|
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
""
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
}
|
|
518
|
-
if "data_source" in config:
|
|
519
|
-
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
520
|
-
if "resampling_method" in config:
|
|
521
|
-
kwargs["resampling_method"] = RESAMPLING_METHODS[
|
|
522
|
-
config["resampling_method"]
|
|
523
|
-
]
|
|
524
|
-
if "alias" in config:
|
|
525
|
-
kwargs["alias"] = config["alias"]
|
|
526
|
-
if "compositing_method" in config:
|
|
527
|
-
kwargs["compositing_method"] = CompositingMethod[
|
|
528
|
-
config["compositing_method"]
|
|
529
|
-
]
|
|
530
|
-
return RasterLayerConfig(**kwargs) # type: ignore
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
class VectorLayerConfig(LayerConfig):
|
|
534
|
-
"""Configuration of a vector layer."""
|
|
535
|
-
|
|
536
|
-
def __init__(
|
|
537
|
-
self,
|
|
538
|
-
layer_type: LayerType,
|
|
539
|
-
data_source: DataSourceConfig | None = None,
|
|
540
|
-
format: VectorFormatConfig = VectorFormatConfig("geojson"),
|
|
541
|
-
alias: str | None = None,
|
|
542
|
-
class_property_name: str | None = None,
|
|
543
|
-
class_names: list[str] | None = None,
|
|
544
|
-
):
|
|
545
|
-
"""Initialize a new VectorLayerConfig.
|
|
581
|
+
init_jsonargparse()
|
|
582
|
+
parser = jsonargparse.ArgumentParser()
|
|
583
|
+
parser.add_argument("--vector_format", type=VectorFormat)
|
|
584
|
+
cfg = parser.parse_object({"vector_format": self.vector_format})
|
|
585
|
+
vector_format = parser.instantiate_classes(cfg).vector_format
|
|
586
|
+
return vector_format
|
|
546
587
|
|
|
547
|
-
Args:
|
|
548
|
-
layer_type: the LayerType (must be vector)
|
|
549
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
550
|
-
format: the VectorFormatConfig, default storing as GeoJSON
|
|
551
|
-
alias: alias for this layer to use in the tile store
|
|
552
|
-
class_property_name: optional metadata field indicating that the GeoJSON
|
|
553
|
-
features contain a property that corresponds to a class label, and this
|
|
554
|
-
is the name of that property.
|
|
555
|
-
class_names: the list of classes that the class_property_name property
|
|
556
|
-
could be set to.
|
|
557
|
-
"""
|
|
558
|
-
super().__init__(layer_type, data_source, alias)
|
|
559
|
-
self.format = format
|
|
560
|
-
self.class_property_name = class_property_name
|
|
561
|
-
self.class_names = class_names
|
|
562
588
|
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
"""Create a VectorLayerConfig from config dict.
|
|
589
|
+
class DatasetConfig(BaseModel):
|
|
590
|
+
"""Overall dataset configuration."""
|
|
566
591
|
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
"""
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
573
|
-
if "format" in config:
|
|
574
|
-
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
|
|
575
|
-
|
|
576
|
-
simple_optionals = [
|
|
577
|
-
"alias",
|
|
578
|
-
"class_property_name",
|
|
579
|
-
"class_names",
|
|
580
|
-
]
|
|
581
|
-
for k in simple_optionals:
|
|
582
|
-
if k in config:
|
|
583
|
-
kwargs[k] = config[k]
|
|
584
|
-
|
|
585
|
-
# The "zoom_offset" option was removed.
|
|
586
|
-
# We should change how we create configuration so we can error on all
|
|
587
|
-
# non-existing config options, but for now we make sure to raise error if
|
|
588
|
-
# zoom_offset is set since it is no longer supported.
|
|
589
|
-
if "zoom_offset" in config:
|
|
590
|
-
raise ValueError("unsupported zoom_offset option in vector layer config")
|
|
591
|
-
|
|
592
|
-
return VectorLayerConfig(**kwargs) # type: ignore
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
def load_layer_config(config: dict[str, Any]) -> LayerConfig:
|
|
596
|
-
"""Load a LayerConfig from a config dict."""
|
|
597
|
-
layer_type = LayerType(config.get("type"))
|
|
598
|
-
if layer_type == LayerType.RASTER:
|
|
599
|
-
return RasterLayerConfig.from_config(config)
|
|
600
|
-
elif layer_type == LayerType.VECTOR:
|
|
601
|
-
return VectorLayerConfig.from_config(config)
|
|
602
|
-
raise ValueError(f"Unknown layer type {layer_type}")
|
|
592
|
+
layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
|
|
593
|
+
tile_store: dict[str, Any] = Field(
|
|
594
|
+
default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
|
|
595
|
+
description="jsonargparse configuration for the TileStore.",
|
|
596
|
+
)
|